From 5cc618331f8464223ef970e084c7f13f1caaac95 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Sat, 19 Apr 2025 22:08:59 -0700 Subject: [PATCH] fix: Drop sqlalchemy sequence on sequence (#1808) --- letta/orm/message.py | 3 +- performance_tests/test_agent_mass_update.py | 83 +++++++++++---------- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/letta/orm/message.py b/letta/orm/message.py index a9febdb8e..b5c65ec3b 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,7 +1,7 @@ from typing import List, Optional from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall -from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, Sequence, event, text +from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text from sqlalchemy.orm import Mapped, Session, mapped_column, relationship from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn @@ -48,7 +48,6 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): # Monotonically increasing sequence for efficient/correct listing sequence_id: Mapped[int] = mapped_column( BigInteger, - Sequence("message_seq_id"), server_default=FetchedValue(), unique=True, nullable=False, diff --git a/performance_tests/test_agent_mass_update.py b/performance_tests/test_agent_mass_update.py index 076a816e7..841462ef2 100644 --- a/performance_tests/test_agent_mass_update.py +++ b/performance_tests/test_agent_mass_update.py @@ -4,6 +4,7 @@ import random import threading import time import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed import matplotlib.pyplot as plt import pandas as pd @@ -125,15 +126,12 @@ def weather_tool(client): @pytest.mark.slow -def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool): +def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool): # 1) Create 30 agents WITHOUT the rethink_tool initially - AGENT_COUNT = 30 - UPDATES_PER_AGENT = 50 - agent_ids = [] - for i in range(AGENT_COUNT): + for i in range(5): agent = client.agents.create( - name=f"seq_agent_{i}_{uuid.uuid4().hex[:6]}", + name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}", tool_ids=[roll_dice_tool.id, weather_tool.id], include_base_tools=False, memory_blocks=[ @@ -160,56 +158,63 @@ def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_t block_ids.append(blk.id) per_agent_blocks[aid] = block_ids - # 3) Sequential updates: measure latency for each (agent, iteration) + # 3) Dispatch 100 updates per agent in parallel + total_updates = len(agent_ids) * 100 latencies = [] - total_ops = AGENT_COUNT * UPDATES_PER_AGENT - idx = 0 - with tqdm(total=total_ops, desc="Sequential updates") as pbar: - for aid in agent_ids: - for _ in range(UPDATES_PER_AGENT): - start = time.time() - if random.random() < 0.5: - client.agents.modify(agent_id=aid, tool_ids=[rethink_tool.id]) - else: - bid = random.choice(per_agent_blocks[aid]) - client.agents.modify(agent_id=aid, block_ids=[bid]) - elapsed = time.time() - start + def do_update(agent_id: str): + start = time.time() + if random.random() < 0.5: + client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id]) + else: + bid = random.choice(per_agent_blocks[agent_id]) + client.agents.modify(agent_id=agent_id, block_ids=[bid]) + return time.time() - start - latencies.append(elapsed) - idx += 1 - pbar.update(1) + with ThreadPoolExecutor(max_workers=50) as executor: + futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)] + for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"): + latencies.append(future.result()) # 4) Cleanup for aid in agent_ids: client.agents.delete(aid) - # 5) Line‐plot every single latency + # 5) Plot latency distribution df = pd.DataFrame({"latency": latencies}) - plt.figure(figsize=(10, 5)) - plt.plot(df["latency"].values, marker=".", linestyle="-", alpha=0.7) - plt.title("Sequential Update Latencies Over Time") - plt.xlabel("Operation Index") - plt.ylabel("Latency (s)") - plt.grid(True, alpha=0.3) + plt.figure(figsize=(12, 6)) + + plt.subplot(1, 2, 1) + plt.hist(df["latency"], bins=30, edgecolor="black") + plt.title("Update Latency Distribution") + plt.xlabel("Latency (seconds)") + plt.ylabel("Frequency") + + plt.subplot(1, 2, 2) + plt.boxplot(df["latency"], vert=False) + plt.title("Update Latency Boxplot") + plt.xlabel("Latency (seconds)") - plot_file = f"seq_update_latency_{int(time.time())}.png" plt.tight_layout() + plot_file = f"complex_update_latency_{int(time.time())}.png" plt.savefig(plot_file) plt.close() - # 6) Summary + # 6) Report summary mean = df["latency"].mean() median = df["latency"].median() minimum = df["latency"].min() maximum = df["latency"].max() stdev = df["latency"].std() - print("\n===== Sequential Complex Update Latencies =====") - print(f"Total ops : {len(latencies)}") - print(f"Mean : {mean:.3f}s") - print(f"Median : {median:.3f}s") - print(f"Min : {minimum:.3f}s") - print(f"Max : {maximum:.3f}s") - print(f"Std dev : {stdev:.3f}s") - print(f"Plot saved: {plot_file}") + print("\n===== Complex Update Latency Statistics =====") + print(f"Total updates: {len(latencies)}") + print(f"Mean: {mean:.3f}s") + print(f"Median: {median:.3f}s") + print(f"Min: {minimum:.3f}s") + print(f"Max: {maximum:.3f}s") + print(f"Std: {stdev:.3f}s") + print(f"Plot saved to: {plot_file}") + + # Sanity assertion + assert median < 2.0, f"Median update latency too high: {median:.3f}s"