fix: patch embedding_model null issue in tests (#1305)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Charles Packer 2024-04-27 21:33:00 -07:00 committed by GitHub
parent 61a8f9d229
commit ddc3dabf4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 37 additions and 104 deletions

View File

@ -37,4 +37,4 @@ services:
- OPENAI_API_KEY=${OPENAI_API_KEY}
volumes:
- ./configs/server_config.yaml:/root/.memgpt/config # config file
- ~/.memgpt/credentials:/root/.memgpt/credentials # credentials file
# ~/.memgpt/credentials:/root/.memgpt/credentials # credentials file

View File

@ -162,7 +162,7 @@ def quickstart(
# Load the file from the relative path
script_dir = os.path.dirname(__file__) # Get the directory where the script is located
backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
backup_config_path = os.path.join(script_dir, "..", "..", "configs", "memgpt_hosted.json")
try:
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
@ -175,7 +175,7 @@ def quickstart(
# Load the file from the relative path
script_dir = os.path.dirname(__file__) # Get the directory where the script is located
print("SCRIPT", script_dir)
backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
backup_config_path = os.path.join(script_dir, "..", "..", "configs", "memgpt_hosted.json")
print("FILE PATH", backup_config_path)
try:
with open(backup_config_path, "r", encoding="utf-8") as file:
@ -214,7 +214,7 @@ def quickstart(
# Load the file from the relative path
script_dir = os.path.dirname(__file__) # Get the directory where the script is located
backup_config_path = os.path.join(script_dir, "..", "configs", "openai.json")
backup_config_path = os.path.join(script_dir, "..", "..", "configs", "openai.json")
try:
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
@ -226,7 +226,7 @@ def quickstart(
else:
# Load the file from the relative path
script_dir = os.path.dirname(__file__) # Get the directory where the script is located
backup_config_path = os.path.join(script_dir, "..", "configs", "openai.json")
backup_config_path = os.path.join(script_dir, "..", "..", "configs", "openai.json")
try:
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)

View File

@ -1,12 +0,0 @@
{
"context_window": 16384,
"model": "memgpt",
"model_endpoint_type": "openai",
"model_endpoint": "https://inference.memgpt.ai",
"model_wrapper": "chatml",
"embedding_endpoint_type": "hugging-face",
"embedding_endpoint": "https://embeddings.memgpt.ai",
"embedding_model": "BAAI/bge-large-en-v1.5",
"embedding_dim": 1024,
"embedding_chunk_size": 300
}

View File

@ -1,12 +0,0 @@
{
"context_window": 8192,
"model": "gpt-4",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null,
"embedding_endpoint_type": "openai",
"embedding_endpoint": "https://api.openai.com/v1",
"embedding_model": "text-embedding-ada-002",
"embedding_dim": 1536,
"embedding_chunk_size": 300
}

View File

@ -209,6 +209,7 @@ class SyncServer(LockingServer):
# Initialize the connection to the DB
self.config = MemGPTConfig.load()
print(f"server :: loading configuration from '{self.config.config_path}'")
assert self.config.persona is not None, "Persona must be set in the config"
assert self.config.human is not None, "Human must be set in the config"
@ -260,6 +261,7 @@ class SyncServer(LockingServer):
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
)
assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)
# Initialize the metadata store
self.ms = MetadataStore(self.config)

View File

@ -1,16 +1,11 @@
import os
import threading
import time
import uuid
import pytest
from dotenv import load_dotenv
from memgpt import Admin
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig
from memgpt.settings import settings
from tests.config import TestMGPTConfig
from tests.test_client import _reset_config, run_server
test_base_url = "http://localhost:8283"
@ -18,69 +13,6 @@ test_base_url = "http://localhost:8283"
test_server_token = "test_server_token"
def run_server():
pass
load_dotenv()
db_url = settings.pg_db
if os.getenv("OPENAI_API_KEY"):
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
),
# llms
default_llm_config=LLMConfig(
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt-4",
),
)
credentials = MemGPTCredentials(
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_model="BAAI/bge-large-en-v1.5",
embedding_dim=1024,
),
# llms
default_llm_config=LLMConfig(
model_endpoint_type="vllm",
model_endpoint="https://api.memgpt.ai",
model="ehartford/dolphin-2.5-mixtral-8x7b",
),
)
credentials = MemGPTCredentials()
config.save()
credentials.save()
# start server
from memgpt.server.rest_api.server import start_server
start_server(debug=True)
@pytest.fixture(scope="session", autouse=True)
def start_uvicorn_server():
"""Starts Uvicorn server in a background thread."""
@ -102,6 +34,8 @@ def admin_client():
def test_admin_client(admin_client):
_reset_config()
# create a user
user_id = uuid.uuid4()
create_user1_response = admin_client.create_user(user_id)

View File

@ -28,10 +28,7 @@ test_user_id = uuid.uuid4()
test_server_token = "test_server_token"
def run_server():
pass
load_dotenv()
def _reset_config():
# Use os.getenv with a fallback to os.environ.get
db_url = settings.pg_uri
@ -48,8 +45,8 @@ def run_server():
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
embedding_model="text-embedding-ada-002",
embedding_dim=1536,
),
# llms
default_llm_config=LLMConfig(
@ -87,10 +84,18 @@ def run_server():
config.save()
credentials.save()
print("_reset_config :: ", config.config_path)
def run_server():
load_dotenv()
_reset_config()
from memgpt.server.rest_api.server import start_server
print("Starting server...", config.config_path)
print("Starting server...")
start_server(debug=True)
@ -146,6 +151,8 @@ def agent(client):
def test_agent(client, agent):
_reset_config()
# test client.rename_agent
new_name = "RenamedTestAgent"
client.rename_agent(agent_id=agent.id, new_name=new_name)
@ -160,6 +167,8 @@ def test_agent(client, agent):
def test_memory(client, agent):
_reset_config()
memory_response = client.get_agent_memory(agent_id=agent.id)
print("MEMORY", memory_response)
@ -173,6 +182,8 @@ def test_memory(client, agent):
def test_agent_interactions(client, agent):
_reset_config()
message = "Hello, agent!"
message_response = client.user_message(agent_id=str(agent.id), message=message)
@ -182,6 +193,8 @@ def test_agent_interactions(client, agent):
def test_archival_memory(client, agent):
_reset_config()
memory_content = "Archival memory content"
insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
assert insert_response, "Inserting archival memory failed"
@ -197,6 +210,8 @@ def test_archival_memory(client, agent):
def test_messages(client, agent):
_reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"
@ -205,6 +220,8 @@ def test_messages(client, agent):
def test_humans_personas(client, agent):
_reset_config()
humans_response = client.list_humans()
print("HUMANS", humans_response)
@ -232,6 +249,8 @@ def test_humans_personas(client, agent):
def test_config(client, agent):
_reset_config()
models_response = client.list_models()
print("MODELS", models_response)
@ -242,6 +261,7 @@ def test_config(client, agent):
def test_sources(client, agent):
_reset_config()
if not hasattr(client, "base_url"):
pytest.skip("Skipping test_sources because base_url is None")
@ -298,6 +318,7 @@ def test_sources(client, agent):
def test_presets(client, agent):
_reset_config()
new_preset = Preset(
# user_id=client.user_id,