MemGPT/performance_tests/test_agent_mass_update.py
2025-04-19 22:08:59 -07:00

221 lines
6.7 KiB
Python

import logging
import os
import random
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
import matplotlib.pyplot as plt
import pandas as pd
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from tqdm import tqdm
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
# Silence chatty HTTP client libraries so per-request logs don't drown
# the progress bars and latency summaries this module prints.
for _noisy in ("httpx", "httpcore"):
    logging.getLogger(_noisy).setLevel(logging.WARNING)
# --- Server Management --- #
def _run_server():
    """Starts the Letta server in a background thread."""
    # Load .env first — presumably so environment configuration is in place
    # before the server module is imported below (TODO: confirm letta reads
    # env vars at import time).
    load_dotenv()
    # Deferred import keeps module import cheap for runs that point at an
    # already-running server via LETTA_SERVER_URL.
    from letta.server.rest_api.app import start_server
    # Serves requests indefinitely; callers run this in a daemon thread.
    start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
    """Ensures a server is running and returns its base URL.

    If LETTA_SERVER_URL is set, that external server is used as-is.
    Otherwise a server is launched in a background daemon thread and this
    fixture blocks until its TCP port accepts connections — the previous
    fixed ``time.sleep(2)`` both raced slow startups and wasted time on
    fast ones.
    """
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        _wait_for_server(url)
    return url


def _wait_for_server(url: str, timeout: float = 30.0, poll_interval: float = 0.1) -> None:
    """Block until a TCP connection to *url* succeeds or *timeout* expires.

    Args:
        url: Base URL of the server (host/port are extracted from it).
        timeout: Maximum seconds to wait for the listener to come up.
        poll_interval: Delay between connection attempts.

    Raises:
        RuntimeError: If the server is not reachable within *timeout*.
    """
    # Local imports: only needed on the launch-our-own-server path.
    import socket
    from urllib.parse import urlparse

    parsed = urlparse(url)
    host = parsed.hostname or "localhost"
    port = parsed.port or 8283
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # A successful connect means the listener is up — not a full
            # readiness check, but strictly better than a blind sleep.
            with socket.create_connection((host, port), timeout=1.0):
                return
        except OSError:
            time.sleep(poll_interval)
    raise RuntimeError(f"Letta server at {url} did not become reachable within {timeout}s")
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
    """Session-scoped REST client pointed at the test server."""
    yield Letta(base_url=server_url)
@pytest.fixture()
def roll_dice_tool(client):
    # NOTE: the nested function's source and docstring are what
    # upsert_from_function registers on the server as the tool definition,
    # so they must not be reworded/reformatted.
    def roll_dice():
        """
        Rolls a 6 sided die.
        Returns:
        str: The roll result.
        """
        # Canned result — presumably the impossible "10" on a d6 makes it
        # obvious when the model really called the tool (TODO: confirm intent).
        return "Rolled a 10!"
    # Create-or-update the tool on the server from the function above.
    tool = client.tools.upsert_from_function(func=roll_dice)
    # Yield the created tool
    yield tool
@pytest.fixture()
def rethink_tool(client):
    # NOTE: the nested function's source and docstring are serialized by
    # upsert_from_function into the server-side tool definition — leave the
    # text exactly as written.
    # `agent_state` is injected by the runtime when the tool executes; the
    # string annotation avoids importing AgentState here.
    def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str:  # type: ignore
        """
        Re-evaluate the memory in block_name, integrating new and updated facts.
        Replace outdated information with the most likely truths, avoiding redundancy with original memories.
        Ensure consistency with other memory blocks.
        Args:
        new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
        target_block_label (str): The name of the block to write to.
        Returns:
        str: None is always returned as this function does not produce a response.
        """
        # Overwrite the target block in-place with the rewritten memory.
        agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
        return None
    # Create-or-update the tool on the server from the function above.
    tool = client.tools.upsert_from_function(func=rethink_memory)
    yield tool
@pytest.fixture(scope="function")
def weather_tool(client):
    # NOTE: the nested function's source and docstring are serialized by
    # upsert_from_function into the server-side tool definition — leave the
    # text exactly as written. The import of `requests` is inside the
    # function so the serialized source is self-contained when executed
    # remotely.
    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.
        Parameters:
        location (str): The location to get the weather for.
        Returns:
        str: A formatted string describing the weather in the given location.
        Raises:
        RuntimeError: If the request to fetch weather data fails.
        """
        import requests
        url = f"https://wttr.in/{location}?format=%C+%t"
        response = requests.get(url)
        if response.status_code == 200:
            weather_data = response.text
            return f"The weather in {location} is {weather_data}."
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
    # Create-or-update the tool on the server from the function above.
    tool = client.tools.upsert_from_function(func=get_weather)
    # Yield the created tool
    yield tool
# --- Load Test --- #
@pytest.mark.slow
def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
    """Stress-test concurrent agent modifications.

    Creates a handful of agents, pre-creates memory blocks for each, then
    fires agent-modify requests in parallel — each randomly either attaching
    the rethink tool or one of the pre-created blocks — recording per-request
    latency. Saves a histogram/boxplot PNG, prints summary statistics, and
    asserts the median latency stays under a sanity threshold.
    """
    # Named constants replace the magic numbers; the old comments claimed
    # "30 agents" and "100 updates per agent" while the code created 5 agents
    # and submitted 10 updates each.
    num_agents = 5
    blocks_per_agent = 10
    updates_per_agent = 10

    # 1) Create the agents WITHOUT the rethink_tool, so attaching it later is
    #    a genuine modification.
    agent_ids = []
    for i in range(num_agents):
        agent = client.agents.create(
            name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
            tool_ids=[roll_dice_tool.id, weather_tool.id],
            include_base_tools=False,
            memory_blocks=[
                {"label": "human", "value": "Name: Matt"},
                {"label": "persona", "value": "Friendly agent"},
            ],
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
        )
        agent_ids.append(agent.id)

    # 2) Pre-create blocks *per* agent for the update phase to attach.
    per_agent_blocks = {}
    for aid in agent_ids:
        block_ids = []
        for j in range(blocks_per_agent):
            blk = client.blocks.create(
                label=f"{aid[:6]}_blk{j}",
                value="Precreated block content",
                description="Load-test block",
                limit=500,
                metadata={"idx": str(j)},
            )
            block_ids.append(blk.id)
        per_agent_blocks[aid] = block_ids

    # 3) Dispatch updates in parallel. Fix: total_updates was hard-coded to
    #    len(agent_ids) * 100 while only 10 futures per agent were submitted,
    #    which skewed the tqdm progress bar total.
    total_updates = num_agents * updates_per_agent
    latencies = []

    def do_update(agent_id: str):
        """Perform one random agent modification; return wall-clock latency (s)."""
        start = time.time()
        if random.random() < 0.5:
            # Attach the rethink tool (replaces the agent's tool list).
            client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
        else:
            # Attach one randomly chosen pre-created block.
            bid = random.choice(per_agent_blocks[agent_id])
            client.agents.modify(agent_id=agent_id, block_ids=[bid])
        return time.time() - start

    try:
        with ThreadPoolExecutor(max_workers=50) as executor:
            futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(updates_per_agent)]
            for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
                latencies.append(future.result())
    finally:
        # 4) Cleanup — in a finally block so load-test agents don't accumulate
        #    on the server when an update raises mid-run (previously leaked).
        for aid in agent_ids:
            client.agents.delete(aid)

    # 5) Plot the latency distribution to a timestamped PNG.
    df = pd.DataFrame({"latency": latencies})
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.hist(df["latency"], bins=30, edgecolor="black")
    plt.title("Update Latency Distribution")
    plt.xlabel("Latency (seconds)")
    plt.ylabel("Frequency")
    plt.subplot(1, 2, 2)
    plt.boxplot(df["latency"], vert=False)
    plt.title("Update Latency Boxplot")
    plt.xlabel("Latency (seconds)")
    plt.tight_layout()
    plot_file = f"complex_update_latency_{int(time.time())}.png"
    plt.savefig(plot_file)
    plt.close()

    # 6) Report summary statistics.
    mean = df["latency"].mean()
    median = df["latency"].median()
    minimum = df["latency"].min()
    maximum = df["latency"].max()
    stdev = df["latency"].std()
    print("\n===== Complex Update Latency Statistics =====")
    print(f"Total updates: {len(latencies)}")
    print(f"Mean: {mean:.3f}s")
    print(f"Median: {median:.3f}s")
    print(f"Min: {minimum:.3f}s")
    print(f"Max: {maximum:.3f}s")
    print(f"Std: {stdev:.3f}s")
    print(f"Plot saved to: {plot_file}")

    # Sanity assertion on overall responsiveness.
    assert median < 2.0, f"Median update latency too high: {median:.3f}s"