mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
143 lines
4.0 KiB
Python
143 lines
4.0 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
import uuid
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
|
|
from memgpt import Admin, create_client
|
|
from memgpt.config import MemGPTConfig
|
|
from memgpt.constants import DEFAULT_PRESET
|
|
from memgpt.credentials import MemGPTCredentials
|
|
from memgpt.data_types import Preset # TODO move to PresetModel
|
|
from memgpt.settings import settings
|
|
from tests.utils import create_config
|
|
|
|
# Identifiers shared by the tests in this module; the agent name and the
# user id each get a fresh random UUID on every import.
test_agent_name = f"test_client_{str(uuid.uuid4())}"
test_user_id = uuid.uuid4()

# Preset used by the tests: currently the library's default preset.
test_preset_name = DEFAULT_PRESET

# Module-level state, all initialised to None.
client = None
test_agent_state = None
test_agent_state_post_message = None

# admin credentials
test_server_token = "test_server_token"
|
|
|
|
|
|
def _reset_config():
    """Write a MemGPT config and credentials that route all storage to Postgres.

    The "openai" config (plus the key as a credential) is used when
    ``OPENAI_API_KEY`` is set in the environment; otherwise the
    "memgpt_hosted" config is used with empty credentials. The archival,
    recall and metadata stores are then pointed at ``settings.memgpt_pg_uri``
    and both the config and the credentials are saved.
    """
    # The Postgres URI comes from the project settings object.
    pg_uri = settings.memgpt_pg_uri

    if not os.getenv("OPENAI_API_KEY"):
        # hosted
        create_config("memgpt_hosted")
        creds = MemGPTCredentials()
    else:
        create_config("openai")
        creds = MemGPTCredentials(openai_key=os.getenv("OPENAI_API_KEY"))

    cfg = MemGPTConfig.load()

    # set to use postgres: first every storage URI, then every storage type
    stores = ("archival", "recall", "metadata")
    for store in stores:
        setattr(cfg, f"{store}_storage_uri", pg_uri)
    for store in stores:
        setattr(cfg, f"{store}_storage_type", "postgres")

    cfg.save()
    creds.save()
    print("_reset_config :: ", cfg.config_path)
|
|
|
|
|
|
def run_server():
    """Load the environment, reset the MemGPT config, and start the REST server.

    Used as the target of the background server thread created by the
    ``admin_client`` fixture.
    """
    load_dotenv()
    _reset_config()

    # Deferred import: resolved only after the config reset above has run.
    from memgpt.server.rest_api.server import start_server as launch_rest_server

    print("Starting server...")
    launch_rest_server(debug=True)
|
|
|
|
|
|
# Fixture to create clients with different configurations
|
|
# Fixture to create clients with different configurations
@pytest.fixture(
    params=[  # whether to use REST API server
        {"server": True},
        # {"server": False}  # TODO: add when implemented
    ],
    scope="module",
)
def admin_client(request):
    """Yield an ``Admin`` client connected to a MemGPT REST server.

    When ``MEMGPT_SERVER_URL`` is set, that server is used. Otherwise
    ``run_server`` is started in a daemon thread, the fixture waits five
    seconds, and ``http://localhost:8283`` is used. The ``test_user_id`` user
    is created through the admin client before it is yielded. There is no
    teardown step after the yield.

    Raises:
        NotImplementedError: if the fixture is parametrized with
            ``{"server": False}``, which is not supported yet.
    """
    if not request.param["server"]:
        # Without this guard ``admin`` would be unbound at the yield below.
        raise NotImplementedError("admin_client currently requires the REST API server")

    # get URL from environment
    server_url = os.getenv("MEMGPT_SERVER_URL")
    if server_url is None:
        # run server in thread
        # NOTE: must set MEMGPT_SERVER_PASS environment variable
        server_url = "http://localhost:8283"
        print("Starting server thread")
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        # Fixed grace period before the first request is sent to the new server.
        time.sleep(5)
    print("Running client tests with server:", server_url)

    # create user via admin client
    admin = Admin(server_url, test_server_token)
    admin.create_user(test_user_id)

    yield admin
|
|
|
|
|
|
def _intervals_overlap(first, second):
    """Return True when two ``(start, end)`` intervals overlap.

    Uses the symmetric test ``first.start < second.end and
    second.start < first.end``, so intervals with identical start times
    count as overlapping.
    """
    return first[0] < second[1] and second[0] < first[1]


def test_concurrent_messages(admin_client):
    """Send messages from five threads at once and check the sends overlap in time.

    Each worker creates its own user through ``admin_client``, builds a client
    with that user's API key, creates an agent, and sends one message while
    recording the wall-clock start and end of the call.
    """
    results = []  # one (start, end) pair per successful send, in completion order
    errors = []  # exceptions raised inside worker threads

    def _send_message():
        try:
            print("START SEND MESSAGE")
            response = admin_client.create_user()
            token = response.api_key
            client = create_client(base_url=admin_client.base_url, token=token)
            agent = client.create_agent()

            print("Agent created", agent.id)

            st = time.time()
            message = "Hello, how are you?"
            response = client.send_message(agent_id=agent.id, message=message, role="user")
            et = time.time()
            print(f"Message sent from {st} to {et}")
            print(response.messages)
            results.append((st, et))
        except Exception as e:
            # Keep the worker best-effort, but remember the failure so it can
            # be reported if too few sends succeed.
            print("ERROR", e)
            errors.append(e)

    threads = [threading.Thread(target=_send_message) for _ in range(5)]
    print("Starting threads...")
    for thread in threads:
        thread.start()
        print("CREATED THREAD")

    print("waiting for threads to finish...")
    for thread in threads:
        thread.join()

    # The overlap check needs two timings; report worker errors instead of
    # failing with a bare IndexError when fewer are available.
    assert len(results) >= 2, (
        f"Expected at least two successful sends, got {len(results)}; worker errors: {errors}"
    )

    # make sure runtimes are overlapping
    assert _intervals_overlap(results[0], results[1]), f"Threads should have overlapping runtimes {results}"
|