fix: patch embedding_model null issue in tests (#1305)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>

parent 61a8f9d229
commit ddc3dabf4f
@@ -37,4 +37,4 @@ services:
       - OPENAI_API_KEY=${OPENAI_API_KEY}
     volumes:
       - ./configs/server_config.yaml:/root/.memgpt/config # config file
-      - ~/.memgpt/credentials:/root/.memgpt/credentials # credentials file
+      # ~/.memgpt/credentials:/root/.memgpt/credentials # credentials file
@@ -162,7 +162,7 @@ def quickstart(
 
     # Load the file from the relative path
     script_dir = os.path.dirname(__file__)  # Get the directory where the script is located
-    backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
+    backup_config_path = os.path.join(script_dir, "..", "..", "configs", "memgpt_hosted.json")
     try:
         with open(backup_config_path, "r", encoding="utf-8") as file:
             backup_config = json.load(file)
@@ -175,7 +175,7 @@ def quickstart(
         # Load the file from the relative path
         script_dir = os.path.dirname(__file__)  # Get the directory where the script is located
         print("SCRIPT", script_dir)
-        backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
+        backup_config_path = os.path.join(script_dir, "..", "..", "configs", "memgpt_hosted.json")
         print("FILE PATH", backup_config_path)
         try:
             with open(backup_config_path, "r", encoding="utf-8") as file:
@@ -214,7 +214,7 @@ def quickstart(
 
     # Load the file from the relative path
     script_dir = os.path.dirname(__file__)  # Get the directory where the script is located
-    backup_config_path = os.path.join(script_dir, "..", "configs", "openai.json")
+    backup_config_path = os.path.join(script_dir, "..", "..", "configs", "openai.json")
     try:
         with open(backup_config_path, "r", encoding="utf-8") as file:
             backup_config = json.load(file)
@@ -226,7 +226,7 @@ def quickstart(
     else:
         # Load the file from the relative path
         script_dir = os.path.dirname(__file__)  # Get the directory where the script is located
-        backup_config_path = os.path.join(script_dir, "..", "configs", "openai.json")
+        backup_config_path = os.path.join(script_dir, "..", "..", "configs", "openai.json")
        try:
             with open(backup_config_path, "r", encoding="utf-8") as file:
                 backup_config = json.load(file)
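All four quickstart hunks make the same correction: backup_config_path now climbs two directory levels instead of one before descending into configs/. A minimal sketch of how the two joins resolve, using a hypothetical script location (the diff does not show the file's actual path):

import os

# Hypothetical location of the module containing quickstart(); not shown in this diff.
script_dir = "/repo/memgpt/cli"

old_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
new_path = os.path.join(script_dir, "..", "..", "configs", "memgpt_hosted.json")

print(os.path.normpath(old_path))  # /repo/memgpt/configs/memgpt_hosted.json
print(os.path.normpath(new_path))  # /repo/configs/memgpt_hosted.json

The extra ".." is consistent with the package-level JSON defaults deleted in the next two hunks.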
@@ -1,12 +0,0 @@
-{
-    "context_window": 16384,
-    "model": "memgpt",
-    "model_endpoint_type": "openai",
-    "model_endpoint": "https://inference.memgpt.ai",
-    "model_wrapper": "chatml",
-    "embedding_endpoint_type": "hugging-face",
-    "embedding_endpoint": "https://embeddings.memgpt.ai",
-    "embedding_model": "BAAI/bge-large-en-v1.5",
-    "embedding_dim": 1024,
-    "embedding_chunk_size": 300
-}
@@ -1,12 +0,0 @@
-{
-    "context_window": 8192,
-    "model": "gpt-4",
-    "model_endpoint_type": "openai",
-    "model_endpoint": "https://api.openai.com/v1",
-    "model_wrapper": null,
-    "embedding_endpoint_type": "openai",
-    "embedding_endpoint": "https://api.openai.com/v1",
-    "embedding_model": "text-embedding-ada-002",
-    "embedding_dim": 1536,
-    "embedding_chunk_size": 300
-}
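Both deleted files are endpoint-default JSON configs, and both pin an explicit embedding_model. A minimal sketch of loading such a backup config defensively; the load_backup_config helper is illustrative, not part of MemGPT:

import json

def load_backup_config(path: str) -> dict:
    """Load a backup config JSON and fail fast if embedding_model is missing."""
    with open(path, "r", encoding="utf-8") as file:
        backup_config = json.load(file)
    # The null issue this commit patches: a config that omits embedding_model
    # would otherwise propagate None into the server's embedding settings.
    assert backup_config.get("embedding_model") is not None, backup_config
    return backup_config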
@@ -209,6 +209,7 @@ class SyncServer(LockingServer):
 
         # Initialize the connection to the DB
         self.config = MemGPTConfig.load()
+        print(f"server :: loading configuration from '{self.config.config_path}'")
         assert self.config.persona is not None, "Persona must be set in the config"
         assert self.config.human is not None, "Human must be set in the config"
 
@@ -260,6 +261,7 @@ class SyncServer(LockingServer):
             embedding_model=self.config.default_embedding_config.embedding_model,
             embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
         )
+        assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)
 
         # Initialize the metadata store
         self.ms = MetadataStore(self.config)
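The added assert surfaces the bug at startup instead of later: if the loaded config never set embedding_model, the server's embedding config carries None. A self-contained sketch of that failure mode, using a toy dataclass rather than the real memgpt.data_types.EmbeddingConfig:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyEmbeddingConfig:  # stand-in for illustration only
    embedding_endpoint_type: str
    embedding_endpoint: str
    embedding_dim: int
    embedding_model: Optional[str] = None  # an unset field silently becomes None

config = ToyEmbeddingConfig(
    embedding_endpoint_type="openai",
    embedding_endpoint="https://api.openai.com/v1",
    embedding_dim=1536,
    # embedding_model omitted, as in the test config fixed further down
)
assert config.embedding_model is not None, vars(config)  # raises AssertionError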
@@ -1,16 +1,11 @@
-import os
 import threading
 import time
 import uuid
 
 import pytest
-from dotenv import load_dotenv
 
 from memgpt import Admin
-from memgpt.credentials import MemGPTCredentials
-from memgpt.data_types import EmbeddingConfig, LLMConfig
-from memgpt.settings import settings
-from tests.config import TestMGPTConfig
+from tests.test_client import _reset_config, run_server
 
 test_base_url = "http://localhost:8283"
 
@@ -18,69 +13,6 @@ test_base_url = "http://localhost:8283"
 test_server_token = "test_server_token"
 
 
-def run_server():
-    pass
-
-    load_dotenv()
-
-    db_url = settings.pg_db
-    if os.getenv("OPENAI_API_KEY"):
-        config = TestMGPTConfig(
-            archival_storage_uri=db_url,
-            recall_storage_uri=db_url,
-            metadata_storage_uri=db_url,
-            archival_storage_type="postgres",
-            recall_storage_type="postgres",
-            metadata_storage_type="postgres",
-            # embeddings
-            default_embedding_config=EmbeddingConfig(
-                embedding_endpoint_type="openai",
-                embedding_endpoint="https://api.openai.com/v1",
-                embedding_dim=1536,
-            ),
-            # llms
-            default_llm_config=LLMConfig(
-                model_endpoint_type="openai",
-                model_endpoint="https://api.openai.com/v1",
-                model="gpt-4",
-            ),
-        )
-        credentials = MemGPTCredentials(
-            openai_key=os.getenv("OPENAI_API_KEY"),
-        )
-    else:  # hosted
-        config = TestMGPTConfig(
-            archival_storage_uri=db_url,
-            recall_storage_uri=db_url,
-            metadata_storage_uri=db_url,
-            archival_storage_type="postgres",
-            recall_storage_type="postgres",
-            metadata_storage_type="postgres",
-            # embeddings
-            default_embedding_config=EmbeddingConfig(
-                embedding_endpoint_type="hugging-face",
-                embedding_endpoint="https://embeddings.memgpt.ai",
-                embedding_model="BAAI/bge-large-en-v1.5",
-                embedding_dim=1024,
-            ),
-            # llms
-            default_llm_config=LLMConfig(
-                model_endpoint_type="vllm",
-                model_endpoint="https://api.memgpt.ai",
-                model="ehartford/dolphin-2.5-mixtral-8x7b",
-            ),
-        )
-        credentials = MemGPTCredentials()
-
-    config.save()
-    credentials.save()
-
-    # start server
-    from memgpt.server.rest_api.server import start_server
-
-    start_server(debug=True)
-
-
 @pytest.fixture(scope="session", autouse=True)
 def start_uvicorn_server():
     """Starts Uvicorn server in a background thread."""
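With the duplicated setup deleted, this module reuses run_server from tests.test_client, and the surviving start_uvicorn_server fixture launches it. The fixture body is not shown in this diff; a plausible sketch, assuming it follows its docstring:

import threading
import time

import pytest

from tests.test_client import run_server

@pytest.fixture(scope="session", autouse=True)
def start_uvicorn_server():
    """Starts Uvicorn server in a background thread."""
    thread = threading.Thread(target=run_server, daemon=True)  # daemon: dies with pytest
    thread.start()
    time.sleep(5)  # crude wait so the server is up before the first test
    yield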
@@ -102,6 +34,8 @@ def admin_client():
 
 
 def test_admin_client(admin_client):
+    _reset_config()
+
     # create a user
     user_id = uuid.uuid4()
     create_user1_response = admin_client.create_user(user_id)
@@ -28,10 +28,7 @@ test_user_id = uuid.uuid4()
 test_server_token = "test_server_token"
 
 
-def run_server():
-    pass
-
-    load_dotenv()
+def _reset_config():
 
     # Use os.getenv with a fallback to os.environ.get
     db_url = settings.pg_uri
@@ -48,8 +45,8 @@ def run_server():
         default_embedding_config=EmbeddingConfig(
             embedding_endpoint_type="openai",
             embedding_endpoint="https://api.openai.com/v1",
-            embedding_dim=1536,
+            embedding_model="text-embedding-ada-002",
+            embedding_dim=1536,
         ),
         # llms
         default_llm_config=LLMConfig(
@@ -87,10 +84,18 @@ def run_server():
 
     config.save()
     credentials.save()
+    print("_reset_config :: ", config.config_path)
+
+
+def run_server():
+
+    load_dotenv()
+
+    _reset_config()
 
     from memgpt.server.rest_api.server import start_server
 
-    print("Starting server...", config.config_path)
+    print("Starting server...")
     start_server(debug=True)
 
 
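After this split, _reset_config() only writes and saves the config and credentials, while run_server() loads the environment, resets the config, and blocks in start_server(). Each test below then opens with an explicit _reset_config() call; an illustrative (hypothetical) instance of the pattern:

def test_example(client, agent):
    _reset_config()  # rewrite the saved MemGPT config before touching the server

    models_response = client.list_models()
    assert models_response, "Listing models failed"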
@@ -146,6 +151,8 @@ def agent(client):
 
 
 def test_agent(client, agent):
+    _reset_config()
+
     # test client.rename_agent
     new_name = "RenamedTestAgent"
     client.rename_agent(agent_id=agent.id, new_name=new_name)
@@ -160,6 +167,8 @@ def test_agent(client, agent):
 
 
 def test_memory(client, agent):
+    _reset_config()
+
     memory_response = client.get_agent_memory(agent_id=agent.id)
     print("MEMORY", memory_response)
 
@@ -173,6 +182,8 @@ def test_memory(client, agent):
 
 
 def test_agent_interactions(client, agent):
+    _reset_config()
+
     message = "Hello, agent!"
     message_response = client.user_message(agent_id=str(agent.id), message=message)
 
@@ -182,6 +193,8 @@ def test_agent_interactions(client, agent):
 
 
 def test_archival_memory(client, agent):
+    _reset_config()
+
     memory_content = "Archival memory content"
     insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
     assert insert_response, "Inserting archival memory failed"
@@ -197,6 +210,8 @@ def test_archival_memory(client, agent):
 
 
 def test_messages(client, agent):
+    _reset_config()
+
     send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
     assert send_message_response, "Sending message failed"
 
@@ -205,6 +220,8 @@ def test_messages(client, agent):
 
 
 def test_humans_personas(client, agent):
+    _reset_config()
+
     humans_response = client.list_humans()
     print("HUMANS", humans_response)
 
@@ -232,6 +249,8 @@ def test_humans_personas(client, agent):
 
 
 def test_config(client, agent):
+    _reset_config()
+
     models_response = client.list_models()
     print("MODELS", models_response)
 
@@ -242,6 +261,7 @@ def test_config(client, agent):
 
 
 def test_sources(client, agent):
+    _reset_config()
 
     if not hasattr(client, "base_url"):
         pytest.skip("Skipping test_sources because base_url is None")
@@ -298,6 +318,7 @@ def test_sources(client, agent):
 
 
 def test_presets(client, agent):
+    _reset_config()
 
     new_preset = Preset(
         # user_id=client.user_id,