mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
197 lines
6.7 KiB
Python
197 lines
6.7 KiB
Python
import datetime
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import uuid
|
|
|
|
import pytest
|
|
from faker import Faker
|
|
from tqdm import tqdm
|
|
|
|
from letta import create_client
|
|
from letta.orm import Base
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message
|
|
from letta.services.agent_manager import AgentManager
|
|
from letta.services.message_manager import MessageManager
|
|
from tests.integration_test_summarizer import LLM_CONFIG_DIR
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def truncate_database():
|
|
from letta.server.db import db_context
|
|
|
|
with db_context() as session:
|
|
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
|
session.execute(table.delete()) # Truncate table
|
|
session.commit()
|
|
|
|
|
|
@pytest.fixture(scope="function")
|
|
def client():
|
|
filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-sonnet.json")
|
|
config_data = json.load(open(filename, "r"))
|
|
llm_config = LLMConfig(**config_data)
|
|
client = create_client()
|
|
client.set_default_llm_config(llm_config)
|
|
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
|
|
|
yield client
|
|
|
|
|
|
def generate_tool_call_id():
|
|
"""Generates a unique tool call ID."""
|
|
return "toolu_" + uuid.uuid4().hex[:24]
|
|
|
|
|
|
def generate_timestamps(base_time):
|
|
"""Creates a sequence of timestamps for user, assistant, and tool messages."""
|
|
user_time = base_time
|
|
send_time = user_time + datetime.timedelta(seconds=random.randint(2, 5))
|
|
tool_time = send_time + datetime.timedelta(seconds=random.randint(1, 3))
|
|
next_group_time = tool_time + datetime.timedelta(seconds=random.randint(5, 10))
|
|
|
|
return user_time, send_time, tool_time, next_group_time
|
|
|
|
|
|
def get_conversation_pair():
|
|
fake = Faker()
|
|
return f"Where does {fake.name()} live?", f"{fake.address()}"
|
|
|
|
|
|
def create_user_message(agent_id, organization_id, message_text, timestamp):
|
|
"""Creates a user message dictionary."""
|
|
return {
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "text",
|
|
"text": json.dumps(
|
|
{"type": "user_message", "message": message_text, "time": timestamp.strftime("%Y-%m-%d %I:%M:%S %p PST-0800")}, indent=2
|
|
),
|
|
}
|
|
],
|
|
"organization_id": organization_id,
|
|
"agent_id": agent_id,
|
|
"model": None,
|
|
"name": None,
|
|
"tool_calls": None,
|
|
"tool_call_id": None,
|
|
}
|
|
|
|
|
|
def create_send_message(agent_id, organization_id, assistant_text, tool_call_id, timestamp):
|
|
"""Creates an assistant message dictionary."""
|
|
return {
|
|
"role": "assistant",
|
|
"content": [{"type": "text", "text": f"Assistant reply generated at {timestamp.strftime('%Y-%m-%d %I:%M:%S %p PST-0800')}."}],
|
|
"organization_id": organization_id,
|
|
"agent_id": agent_id,
|
|
"model": "claude-3-5-haiku-20241022",
|
|
"name": None,
|
|
"tool_calls": [
|
|
{
|
|
"id": tool_call_id,
|
|
"function": {
|
|
"name": "send_message",
|
|
"arguments": json.dumps(
|
|
{"message": assistant_text, "time": timestamp.strftime("%Y-%m-%d %I:%M:%S %p PST-0800")}, indent=2
|
|
),
|
|
},
|
|
"type": "function",
|
|
}
|
|
],
|
|
"tool_call_id": None,
|
|
}
|
|
|
|
|
|
def create_tool_message(agent_id, organization_id, tool_call_id, timestamp):
|
|
"""Creates a tool response message dictionary."""
|
|
return {
|
|
"role": "tool",
|
|
"content": [
|
|
{
|
|
"type": "text",
|
|
"text": json.dumps(
|
|
{"status": "OK", "message": "None", "time": timestamp.strftime("%Y-%m-%d %I:%M:%S %p PST-0800")}, indent=2
|
|
),
|
|
}
|
|
],
|
|
"organization_id": organization_id,
|
|
"agent_id": agent_id,
|
|
"model": "claude-3-5-haiku-20241022",
|
|
"name": "send_message",
|
|
"tool_calls": None,
|
|
"tool_call_id": tool_call_id,
|
|
}
|
|
|
|
|
|
@pytest.mark.parametrize("num_messages", [1000])
|
|
def test_many_messages_performance(client, num_messages):
|
|
"""Main test function to generate messages and insert them into the database."""
|
|
message_manager = MessageManager()
|
|
agent_manager = AgentManager()
|
|
actor = client.user
|
|
|
|
start_time = datetime.datetime.now()
|
|
last_event_time = start_time # Track last event time
|
|
|
|
def log_event(event):
|
|
nonlocal last_event_time
|
|
now = datetime.datetime.now()
|
|
total_elapsed = (now - start_time).total_seconds()
|
|
step_elapsed = (now - last_event_time).total_seconds()
|
|
print(f"[+{total_elapsed:.3f}s | Δ{step_elapsed:.3f}s] {event}")
|
|
last_event_time = now # Update last event time
|
|
|
|
log_event(f"Starting test with {num_messages} messages")
|
|
|
|
agent_state = client.create_agent(name="manager")
|
|
log_event(f"Created agent with ID {agent_state.id}")
|
|
|
|
message_group_size = 3
|
|
num_groups = math.ceil((num_messages - 4) / message_group_size)
|
|
base_time = datetime.datetime(2025, 2, 10, 16, 3, 22)
|
|
current_time = base_time
|
|
organization_id = "org-00000000-0000-4000-8000-000000000000"
|
|
|
|
all_messages = []
|
|
|
|
for _ in tqdm(range(num_groups)):
|
|
user_text, assistant_text = get_conversation_pair()
|
|
tool_call_id = generate_tool_call_id()
|
|
user_time, send_time, tool_time, current_time = generate_timestamps(current_time)
|
|
new_messages = [
|
|
Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)),
|
|
Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)),
|
|
Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)),
|
|
]
|
|
all_messages.extend(new_messages)
|
|
|
|
log_event(f"Finished generating {len(all_messages)} messages")
|
|
|
|
message_manager.create_many_messages(all_messages, actor=actor)
|
|
log_event("Inserted messages into the database")
|
|
|
|
agent_manager.set_in_context_messages(
|
|
agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user
|
|
)
|
|
log_event("Updated agent context with messages")
|
|
|
|
messages = message_manager.list_messages_for_agent(agent_id=agent_state.id, actor=client.user, limit=1000000000)
|
|
log_event(f"Retrieved {len(messages)} messages from the database")
|
|
|
|
assert len(messages) >= num_groups * message_group_size
|
|
|
|
response = client.send_message(
|
|
agent_id=agent_state.id,
|
|
role="user",
|
|
message="What have we been talking about?",
|
|
)
|
|
log_event("Sent message to agent and received response")
|
|
|
|
assert response
|
|
log_event("Test completed successfully")
|