mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
fix: Drop sqlalchemy sequence on sequence (#1808)
This commit is contained in:
parent
df2ae9dd1c
commit
5cc618331f
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, Sequence, event, text
|
||||
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
@ -48,7 +48,6 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
Sequence("message_seq_id"),
|
||||
server_default=FetchedValue(),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
|
@ -4,6 +4,7 @@ import random
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
@ -125,15 +126,12 @@ def weather_tool(client):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
|
||||
def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
|
||||
# 1) Create 30 agents WITHOUT the rethink_tool initially
|
||||
AGENT_COUNT = 30
|
||||
UPDATES_PER_AGENT = 50
|
||||
|
||||
agent_ids = []
|
||||
for i in range(AGENT_COUNT):
|
||||
for i in range(5):
|
||||
agent = client.agents.create(
|
||||
name=f"seq_agent_{i}_{uuid.uuid4().hex[:6]}",
|
||||
name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
|
||||
tool_ids=[roll_dice_tool.id, weather_tool.id],
|
||||
include_base_tools=False,
|
||||
memory_blocks=[
|
||||
@ -160,56 +158,63 @@ def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_t
|
||||
block_ids.append(blk.id)
|
||||
per_agent_blocks[aid] = block_ids
|
||||
|
||||
# 3) Sequential updates: measure latency for each (agent, iteration)
|
||||
# 3) Dispatch 100 updates per agent in parallel
|
||||
total_updates = len(agent_ids) * 100
|
||||
latencies = []
|
||||
total_ops = AGENT_COUNT * UPDATES_PER_AGENT
|
||||
|
||||
idx = 0
|
||||
with tqdm(total=total_ops, desc="Sequential updates") as pbar:
|
||||
for aid in agent_ids:
|
||||
for _ in range(UPDATES_PER_AGENT):
|
||||
start = time.time()
|
||||
if random.random() < 0.5:
|
||||
client.agents.modify(agent_id=aid, tool_ids=[rethink_tool.id])
|
||||
else:
|
||||
bid = random.choice(per_agent_blocks[aid])
|
||||
client.agents.modify(agent_id=aid, block_ids=[bid])
|
||||
elapsed = time.time() - start
|
||||
def do_update(agent_id: str):
|
||||
start = time.time()
|
||||
if random.random() < 0.5:
|
||||
client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
|
||||
else:
|
||||
bid = random.choice(per_agent_blocks[agent_id])
|
||||
client.agents.modify(agent_id=agent_id, block_ids=[bid])
|
||||
return time.time() - start
|
||||
|
||||
latencies.append(elapsed)
|
||||
idx += 1
|
||||
pbar.update(1)
|
||||
with ThreadPoolExecutor(max_workers=50) as executor:
|
||||
futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)]
|
||||
for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
|
||||
latencies.append(future.result())
|
||||
|
||||
# 4) Cleanup
|
||||
for aid in agent_ids:
|
||||
client.agents.delete(aid)
|
||||
|
||||
# 5) Line‐plot every single latency
|
||||
# 5) Plot latency distribution
|
||||
df = pd.DataFrame({"latency": latencies})
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(df["latency"].values, marker=".", linestyle="-", alpha=0.7)
|
||||
plt.title("Sequential Update Latencies Over Time")
|
||||
plt.xlabel("Operation Index")
|
||||
plt.ylabel("Latency (s)")
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.hist(df["latency"], bins=30, edgecolor="black")
|
||||
plt.title("Update Latency Distribution")
|
||||
plt.xlabel("Latency (seconds)")
|
||||
plt.ylabel("Frequency")
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.boxplot(df["latency"], vert=False)
|
||||
plt.title("Update Latency Boxplot")
|
||||
plt.xlabel("Latency (seconds)")
|
||||
|
||||
plot_file = f"seq_update_latency_{int(time.time())}.png"
|
||||
plt.tight_layout()
|
||||
plot_file = f"complex_update_latency_{int(time.time())}.png"
|
||||
plt.savefig(plot_file)
|
||||
plt.close()
|
||||
|
||||
# 6) Summary
|
||||
# 6) Report summary
|
||||
mean = df["latency"].mean()
|
||||
median = df["latency"].median()
|
||||
minimum = df["latency"].min()
|
||||
maximum = df["latency"].max()
|
||||
stdev = df["latency"].std()
|
||||
|
||||
print("\n===== Sequential Complex Update Latencies =====")
|
||||
print(f"Total ops : {len(latencies)}")
|
||||
print(f"Mean : {mean:.3f}s")
|
||||
print(f"Median : {median:.3f}s")
|
||||
print(f"Min : {minimum:.3f}s")
|
||||
print(f"Max : {maximum:.3f}s")
|
||||
print(f"Std dev : {stdev:.3f}s")
|
||||
print(f"Plot saved: {plot_file}")
|
||||
print("\n===== Complex Update Latency Statistics =====")
|
||||
print(f"Total updates: {len(latencies)}")
|
||||
print(f"Mean: {mean:.3f}s")
|
||||
print(f"Median: {median:.3f}s")
|
||||
print(f"Min: {minimum:.3f}s")
|
||||
print(f"Max: {maximum:.3f}s")
|
||||
print(f"Std: {stdev:.3f}s")
|
||||
print(f"Plot saved to: {plot_file}")
|
||||
|
||||
# Sanity assertion
|
||||
assert median < 2.0, f"Median update latency too high: {median:.3f}s"
|
||||
|
Loading…
Reference in New Issue
Block a user