# Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

import logging
import os
import random
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed

import matplotlib.pyplot as plt
import pandas as pd
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from tqdm import tqdm

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig

# Quiet per-request logging from the HTTP stack so load-test output stays readable
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)


# --- Server Management --- #


def _run_server():
    """Starts the Letta server (run on a background thread by the server_url fixture)."""
    load_dotenv()
    from letta.server.rest_api.app import start_server

    start_server(debug=True)


@pytest.fixture(scope="session")
def server_url():
    """Ensures a server is running and returns its base URL."""
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        time.sleep(2)  # Allow server startup time

    return url
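

# NOTE: the fixed 2-second sleep above is a startup heuristic and can race with
# a slow server boot. A more robust readiness probe could poll the base URL
# (sketch; assumes the server answers plain HTTP GETs once it is up):
#
#     import requests
#
#     for _ in range(30):
#         try:
#             requests.get(url, timeout=1)
#             break
#         except requests.exceptions.ConnectionError:
#             time.sleep(1)

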
# --- Client Setup --- #


@pytest.fixture(scope="session")
def client(server_url):
    """Creates a REST client for testing."""
    client = Letta(base_url=server_url)
    yield client


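# The client (and the server behind it) is shared across the whole test
# session, so individual tests are responsible for cleaning up any agents,
# tools, or blocks they create.

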
@pytest.fixture()
def roll_dice_tool(client):
    def roll_dice():
        """
        Rolls a 6-sided die.

        Returns:
            str: The roll result.
        """
        # Canned result (impossible for a real 6-sided die), which makes it
        # easy to verify that the tool's output, not a real roll, reached the agent.
        return "Rolled a 10!"

    tool = client.tools.upsert_from_function(func=roll_dice)
    # Yield the created tool
    yield tool


@pytest.fixture()
def rethink_tool(client):
    # `agent_state` is injected by the Letta runtime at execution time; it is
    # not part of the argument schema the model sees.
    def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str:  # type: ignore
        """
        Re-evaluate the memory in target_block_label, integrating new and updated facts.
        Replace outdated information with the most likely truths, avoiding redundancy with original memories.
        Ensure consistency with other memory blocks.

        Args:
            new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
            target_block_label (str): The name of the block to write to.

        Returns:
            str: None is always returned as this function does not produce a response.
        """
        agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
        return None

    tool = client.tools.upsert_from_function(func=rethink_memory)
    yield tool


@pytest.fixture(scope="function")
def weather_tool(client):
    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.

        Parameters:
            location (str): The location to get the weather for.

        Returns:
            str: A formatted string describing the weather in the given location.

        Raises:
            RuntimeError: If the request to fetch weather data fails.
        """
        import requests

        # wttr.in one-line format: %C is the condition text, %t the temperature
        url = f"https://wttr.in/{location}?format=%C+%t"

        response = requests.get(url)
        if response.status_code == 200:
            weather_data = response.text
            return f"The weather in {location} is {weather_data}."
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")

    tool = client.tools.upsert_from_function(func=get_weather)
    # Yield the created tool
    yield tool
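

# NOTE: get_weather calls the live wttr.in service, so tests using this fixture
# depend on the network. For deterministic runs, a canned stub with the same
# signature could be registered instead (sketch):
#
#     def get_weather(location: str) -> str:
#         """Fetches the current weather for a given location."""
#         return f"The weather in {location} is Sunny +25°C."

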
# --- Load Test --- #


@pytest.mark.slow
def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
    # 1) Create 5 agents WITHOUT the rethink_tool initially
    agent_ids = []
    for i in range(5):
        agent = client.agents.create(
            name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
            tool_ids=[roll_dice_tool.id, weather_tool.id],
            include_base_tools=False,
            memory_blocks=[
                {"label": "human", "value": "Name: Matt"},
                {"label": "persona", "value": "Friendly agent"},
            ],
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
        )
        agent_ids.append(agent.id)

    # 2) Pre-create 10 new blocks *per* agent
    per_agent_blocks = {}
    for aid in agent_ids:
        block_ids = []
        for j in range(10):
            blk = client.blocks.create(
                label=f"{aid[:6]}_blk{j}",
                value="Precreated block content",
                description="Load-test block",
                limit=500,
                metadata={"idx": str(j)},
            )
            block_ids.append(blk.id)
        per_agent_blocks[aid] = block_ids

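    # Pre-creating the blocks keeps block creation out of the timed path: each
    # timed update below only attaches an existing block.
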
    # 3) Dispatch 10 updates per agent in parallel; each update either swaps in
    #    the rethink tool or attaches one of the pre-created blocks
    total_updates = len(agent_ids) * 10
    latencies = []

    def do_update(agent_id: str):
        start = time.time()
        if random.random() < 0.5:
            client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
        else:
            bid = random.choice(per_agent_blocks[agent_id])
            client.agents.modify(agent_id=agent_id, block_ids=[bid])
        return time.time() - start
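
    # A single client instance is shared by all worker threads below; this
    # assumes the underlying HTTP transport is safe for concurrent requests.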
    with ThreadPoolExecutor(max_workers=50) as executor:
        futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)]
        for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
            latencies.append(future.result())

    # 4) Cleanup
    for aid in agent_ids:
        client.agents.delete(aid)

    # 5) Plot latency distribution
    df = pd.DataFrame({"latency": latencies})
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.hist(df["latency"], bins=30, edgecolor="black")
    plt.title("Update Latency Distribution")
    plt.xlabel("Latency (seconds)")
    plt.ylabel("Frequency")

    plt.subplot(1, 2, 2)
    plt.boxplot(df["latency"], vert=False)
    plt.title("Update Latency Boxplot")
    plt.xlabel("Latency (seconds)")

    plt.tight_layout()
    plot_file = f"complex_update_latency_{int(time.time())}.png"
    plt.savefig(plot_file)
    plt.close()

    # 6) Report summary
    mean = df["latency"].mean()
    median = df["latency"].median()
    minimum = df["latency"].min()
    maximum = df["latency"].max()
    stdev = df["latency"].std()

    print("\n===== Complex Update Latency Statistics =====")
    print(f"Total updates: {len(latencies)}")
    print(f"Mean: {mean:.3f}s")
    print(f"Median: {median:.3f}s")
    print(f"Min: {minimum:.3f}s")
    print(f"Max: {maximum:.3f}s")
    print(f"Std: {stdev:.3f}s")
    print(f"Plot saved to: {plot_file}")

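    # Tail latencies often matter more than the median under load; p95/p99
    # could be reported with pandas quantiles (sketch):
    #
    #     p95 = df["latency"].quantile(0.95)
    #     p99 = df["latency"].quantile(0.99)
    #     print(f"p95: {p95:.3f}s  p99: {p99:.3f}s")
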
    # Sanity assertion
    assert median < 2.0, f"Median update latency too high: {median:.3f}s"