fix: Drop sqlalchemy Sequence on sequence_id (#1808)

Matthew Zhou 2025-04-19 22:08:59 -07:00 committed by GitHub
parent df2ae9dd1c
commit 5cc618331f
2 changed files with 45 additions and 41 deletions


@@ -1,7 +1,7 @@
 from typing import List, Optional
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
-from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, Sequence, event, text
+from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
 from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
 from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
@@ -48,7 +48,6 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
     # Monotonically increasing sequence for efficient/correct listing
     sequence_id: Mapped[int] = mapped_column(
         BigInteger,
-        Sequence("message_seq_id"),
         server_default=FetchedValue(),
         unique=True,
         nullable=False,
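
The change above removes the Python-side Sequence object but keeps server_default=FetchedValue(): the database remains responsible for generating sequence_id (for example via an existing sequence or identity default created by a migration), and SQLAlchemy simply reads the generated value back after the INSERT. Below is a minimal sketch of that pattern, using hypothetical model and table names rather than the repository's actual Message class:

# Sketch only: a BigInteger column whose value comes entirely from a database-side
# default and is re-fetched by SQLAlchemy after INSERT. Names are illustrative.
from sqlalchemy import BigInteger, FetchedValue
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleMessage(Base):  # hypothetical stand-in, not letta.orm.message.Message
    __tablename__ = "example_messages"

    id: Mapped[int] = mapped_column(primary_key=True)
    # No Sequence(...) here: the server-side default supplies the value, and
    # FetchedValue() tells the ORM to fetch it back once the row is inserted.
    sequence_id: Mapped[int] = mapped_column(
        BigInteger,
        server_default=FetchedValue(),
        unique=True,
        nullable=False,
    )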


@@ -4,6 +4,7 @@ import random
 import threading
 import time
 import uuid
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -125,15 +126,12 @@ def weather_tool(client):
 @pytest.mark.slow
-def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
+def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
     # 1) Create 30 agents WITHOUT the rethink_tool initially
-    AGENT_COUNT = 30
-    UPDATES_PER_AGENT = 50
     agent_ids = []
-    for i in range(AGENT_COUNT):
+    for i in range(5):
         agent = client.agents.create(
-            name=f"seq_agent_{i}_{uuid.uuid4().hex[:6]}",
+            name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
             tool_ids=[roll_dice_tool.id, weather_tool.id],
             include_base_tools=False,
             memory_blocks=[
@@ -160,56 +158,63 @@ def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_t
             block_ids.append(blk.id)
         per_agent_blocks[aid] = block_ids
-    # 3) Sequential updates: measure latency for each (agent, iteration)
+    # 3) Dispatch 100 updates per agent in parallel
+    total_updates = len(agent_ids) * 100
     latencies = []
-    total_ops = AGENT_COUNT * UPDATES_PER_AGENT
-    idx = 0
-    with tqdm(total=total_ops, desc="Sequential updates") as pbar:
-        for aid in agent_ids:
-            for _ in range(UPDATES_PER_AGENT):
-                start = time.time()
-                if random.random() < 0.5:
-                    client.agents.modify(agent_id=aid, tool_ids=[rethink_tool.id])
-                else:
-                    bid = random.choice(per_agent_blocks[aid])
-                    client.agents.modify(agent_id=aid, block_ids=[bid])
-                elapsed = time.time() - start
+    def do_update(agent_id: str):
+        start = time.time()
+        if random.random() < 0.5:
+            client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
+        else:
+            bid = random.choice(per_agent_blocks[agent_id])
+            client.agents.modify(agent_id=agent_id, block_ids=[bid])
+        return time.time() - start
-                latencies.append(elapsed)
-                idx += 1
-                pbar.update(1)
+    with ThreadPoolExecutor(max_workers=50) as executor:
+        futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)]
+        for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
+            latencies.append(future.result())
     # 4) Cleanup
     for aid in agent_ids:
         client.agents.delete(aid)
-    # 5) Lineplot every single latency
+    # 5) Plot latency distribution
     df = pd.DataFrame({"latency": latencies})
-    plt.figure(figsize=(10, 5))
-    plt.plot(df["latency"].values, marker=".", linestyle="-", alpha=0.7)
-    plt.title("Sequential Update Latencies Over Time")
-    plt.xlabel("Operation Index")
-    plt.ylabel("Latency (s)")
-    plt.grid(True, alpha=0.3)
+    plt.figure(figsize=(12, 6))
+    plt.subplot(1, 2, 1)
+    plt.hist(df["latency"], bins=30, edgecolor="black")
+    plt.title("Update Latency Distribution")
+    plt.xlabel("Latency (seconds)")
+    plt.ylabel("Frequency")
+    plt.subplot(1, 2, 2)
+    plt.boxplot(df["latency"], vert=False)
+    plt.title("Update Latency Boxplot")
+    plt.xlabel("Latency (seconds)")
-    plot_file = f"seq_update_latency_{int(time.time())}.png"
     plt.tight_layout()
+    plot_file = f"complex_update_latency_{int(time.time())}.png"
     plt.savefig(plot_file)
     plt.close()
-    # 6) Summary
+    # 6) Report summary
     mean = df["latency"].mean()
     median = df["latency"].median()
     minimum = df["latency"].min()
     maximum = df["latency"].max()
     stdev = df["latency"].std()
-    print("\n===== Sequential Complex Update Latencies =====")
-    print(f"Total ops : {len(latencies)}")
-    print(f"Mean : {mean:.3f}s")
-    print(f"Median : {median:.3f}s")
-    print(f"Min : {minimum:.3f}s")
-    print(f"Max : {maximum:.3f}s")
-    print(f"Std dev : {stdev:.3f}s")
-    print(f"Plot saved: {plot_file}")
+    print("\n===== Complex Update Latency Statistics =====")
+    print(f"Total updates: {len(latencies)}")
+    print(f"Mean: {mean:.3f}s")
+    print(f"Median: {median:.3f}s")
+    print(f"Min: {minimum:.3f}s")
+    print(f"Max: {maximum:.3f}s")
+    print(f"Std: {stdev:.3f}s")
+    print(f"Plot saved to: {plot_file}")
     # Sanity assertion
     assert median < 2.0, f"Median update latency too high: {median:.3f}s"
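
The rewritten test is an instance of the standard ThreadPoolExecutor + as_completed pattern for measuring per-call latency under concurrency. A standalone sketch of that pattern follows; fake_update() is a sleep-based stand-in for the client.agents.modify(...) calls, and its name and timings are illustrative, not from the repository:

# Sketch of the concurrent latency-measurement pattern used by the test above.
# fake_update() is a placeholder workload, not Letta code.
import random
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_update(agent_id: str) -> float:
    """Pretend to update one agent and return the elapsed wall-clock time."""
    start = time.time()
    time.sleep(random.uniform(0.01, 0.05))  # stand-in for client.agents.modify(...)
    return time.time() - start


agent_ids = [f"agent-{i}" for i in range(5)]
updates_per_agent = 10
latencies = []

with ThreadPoolExecutor(max_workers=50) as executor:
    futures = [
        executor.submit(fake_update, aid)
        for aid in agent_ids
        for _ in range(updates_per_agent)
    ]
    for future in as_completed(futures):
        latencies.append(future.result())  # .result() re-raises worker exceptions

print(f"total={len(latencies)} median={statistics.median(latencies):.3f}s")

Because as_completed() yields futures in completion order, a progress bar wrapped around it (as in the test) advances as soon as any update finishes, and future.result() surfaces exceptions from worker threads instead of dropping them silently.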