feat: Improve agent update performance (#1799)

This commit is contained in:
Matthew Zhou 2025-04-18 17:46:56 -07:00 committed by GitHub
parent a5e0698110
commit cdab671428
2 changed files with 156 additions and 116 deletions

View File

@ -1,9 +1,9 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import sqlalchemy as sa
from sqlalchemy import Select, and_, func, insert, literal, or_, select, union_all
from sqlalchemy import Select, and_, delete, func, insert, literal, or_, select, union_all
from sqlalchemy.dialects.postgresql import insert as pg_insert
from letta.constants import (
@ -24,7 +24,6 @@ from letta.orm import Block as BlockModel
from letta.orm import BlocksAgents
from letta.orm import Group as GroupModel
from letta.orm import IdentitiesAgents
from letta.orm import Identity as IdentityModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
@ -62,7 +61,6 @@ from letta.services.helpers.agent_manager_helper import (
_apply_pagination,
_apply_tag_filter,
_process_relationship,
_process_tags,
check_supports_structured_output,
compile_system_message,
derive_system_message,
@ -147,6 +145,18 @@ class AgentManager:
session.execute(stmt)
@staticmethod
@trace_method
def _replace_pivot_rows(session, table, agent_id: str, rows: list[dict]):
"""
Replace all pivot rows for an agent with *exactly* the provided list.
Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
"""
# delete all existing rows for this agent
session.execute(delete(table).where(table.c.agent_id == agent_id))
if rows:
AgentManager._bulk_insert_pivot(session, table, rows)
# ======================================================================================================================
# Basic CRUD operations
# ======================================================================================================================
@ -322,86 +332,121 @@ class AgentManager:
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
@enforce_types
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
def update_agent(
self,
agent_id: str,
agent_update: UpdateAgent,
actor: PydanticUser,
) -> PydanticAgentState:
# If there are provided environment variables, add them in
if agent_update.tool_exec_environment_variables:
agent_state = self._set_environment_variables(
agent_id=agent_state.id,
env_vars=agent_update.tool_exec_environment_variables,
actor=actor,
)
new_tools = set(agent_update.tool_ids or [])
new_sources = set(agent_update.source_ids or [])
new_blocks = set(agent_update.block_ids or [])
new_idents = set(agent_update.identity_ids or [])
new_tags = set(agent_update.tags or [])
# Rebuild the system prompt if it's different
if agent_update.enable_sleeptime and agent_update.system is None:
agent_update.system = derive_system_message(
agent_type=agent_state.agent_type,
enable_sleeptime=agent_update.enable_sleeptime,
system=agent_update.system or agent_state.system,
)
if agent_update.system and agent_update.system != agent_state.system:
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
with self.session_maker() as session, session.begin():
return agent_state
agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
agent.updated_at = datetime.now(timezone.utc)
agent.last_updated_by_id = actor.id
@enforce_types
def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
"""
Update an existing agent.
Args:
agent_id: The ID of the agent to update.
agent_update: UpdateAgent object containing the updated fields.
actor: User performing the action.
Returns:
PydanticAgentState: The updated agent as a Pydantic model.
"""
with self.session_maker() as session:
# Retrieve the existing agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Update scalar fields directly
scalar_fields = {
"name",
"system",
"llm_config",
"embedding_config",
"message_ids",
"tool_rules",
"description",
"metadata",
"project_id",
"template_id",
"base_template_id",
"message_buffer_autoclear",
"enable_sleeptime",
scalar_updates = {
"name": agent_update.name,
"system": agent_update.system,
"llm_config": agent_update.llm_config,
"embedding_config": agent_update.embedding_config,
"message_ids": agent_update.message_ids,
"tool_rules": agent_update.tool_rules,
"description": agent_update.description,
"project_id": agent_update.project_id,
"template_id": agent_update.template_id,
"base_template_id": agent_update.base_template_id,
"message_buffer_autoclear": agent_update.message_buffer_autoclear,
"enable_sleeptime": agent_update.enable_sleeptime,
}
for field in scalar_fields:
value = getattr(agent_update, field, None)
if value is not None:
if field == "metadata":
setattr(agent, "metadata_", value)
else:
setattr(agent, field, value)
for col, val in scalar_updates.items():
if val is not None:
setattr(agent, col, val)
if agent_update.metadata is not None:
agent.metadata_ = agent_update.metadata
aid = agent.id
# Update relationships using _process_relationship and _process_tags
if agent_update.tool_ids is not None:
_process_relationship(session, agent, "tools", ToolModel, agent_update.tool_ids, replace=True)
self._replace_pivot_rows(
session,
ToolsAgents.__table__,
aid,
[{"agent_id": aid, "tool_id": tid} for tid in new_tools],
)
session.expire(agent, ["tools"])
if agent_update.source_ids is not None:
_process_relationship(session, agent, "sources", SourceModel, agent_update.source_ids, replace=True)
self._replace_pivot_rows(
session,
SourcesAgents.__table__,
aid,
[{"agent_id": aid, "source_id": sid} for sid in new_sources],
)
session.expire(agent, ["sources"])
if agent_update.block_ids is not None:
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
if agent_update.tags is not None:
_process_tags(agent, agent_update.tags, replace=True)
rows = []
if new_blocks:
label_map = {
bid: lbl
for bid, lbl in session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(new_blocks)))
}
rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks]
self._replace_pivot_rows(session, BlocksAgents.__table__, aid, rows)
session.expire(agent, ["core_memory"])
if agent_update.identity_ids is not None:
_process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True)
self._replace_pivot_rows(
session,
IdentitiesAgents.__table__,
aid,
[{"agent_id": aid, "identity_id": iid} for iid in new_idents],
)
session.expire(agent, ["identities"])
# Commit and refresh the agent
agent.update(session, actor=actor)
if agent_update.tags is not None:
self._replace_pivot_rows(
session,
AgentsTags.__table__,
aid,
[{"agent_id": aid, "tag": tag} for tag in new_tags],
)
session.expire(agent, ["tags"])
if agent_update.tool_exec_environment_variables is not None:
session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
env_rows = [
{
"agent_id": aid,
"key": k,
"value": v,
"organization_id": agent.organization_id,
}
for k, v in agent_update.tool_exec_environment_variables.items()
]
if env_rows:
self._bulk_insert_pivot(session, AgentEnvironmentVariable.__table__, env_rows)
session.expire(agent, ["tool_exec_environment_variables"])
if agent_update.enable_sleeptime and agent_update.system is None:
agent.system = derive_system_message(
agent_type=agent.agent_type,
enable_sleeptime=agent_update.enable_sleeptime,
system=agent.system,
)
session.flush()
session.refresh(agent)
# Convert to PydanticAgentState and return
return agent.to_pydantic()
# TODO: Make this general and think about how to roll this into sqlalchemybase

View File

@ -4,7 +4,6 @@ import random
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
import matplotlib.pyplot as plt
import pandas as pd
@ -126,12 +125,15 @@ def weather_tool(client):
@pytest.mark.slow
def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
# 1) Create 30 agents WITHOUT the rethink_tool initially
AGENT_COUNT = 30
UPDATES_PER_AGENT = 50
agent_ids = []
for i in range(5):
for i in range(AGENT_COUNT):
agent = client.agents.create(
name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
name=f"seq_agent_{i}_{uuid.uuid4().hex[:6]}",
tool_ids=[roll_dice_tool.id, weather_tool.id],
include_base_tools=False,
memory_blocks=[
@ -158,63 +160,56 @@ def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_too
block_ids.append(blk.id)
per_agent_blocks[aid] = block_ids
# 3) Dispatch 100 updates per agent in parallel
total_updates = len(agent_ids) * 100
# 3) Sequential updates: measure latency for each (agent, iteration)
latencies = []
total_ops = AGENT_COUNT * UPDATES_PER_AGENT
def do_update(agent_id: str):
start = time.time()
if random.random() < 0.5:
client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
else:
bid = random.choice(per_agent_blocks[agent_id])
client.agents.modify(agent_id=agent_id, block_ids=[bid])
return time.time() - start
idx = 0
with tqdm(total=total_ops, desc="Sequential updates") as pbar:
for aid in agent_ids:
for _ in range(UPDATES_PER_AGENT):
start = time.time()
if random.random() < 0.5:
client.agents.modify(agent_id=aid, tool_ids=[rethink_tool.id])
else:
bid = random.choice(per_agent_blocks[aid])
client.agents.modify(agent_id=aid, block_ids=[bid])
elapsed = time.time() - start
with ThreadPoolExecutor(max_workers=50) as executor:
futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)]
for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
latencies.append(future.result())
latencies.append(elapsed)
idx += 1
pbar.update(1)
# 4) Cleanup
for aid in agent_ids:
client.agents.delete(aid)
# 5) Plot latency distribution
# 5) Lineplot every single latency
df = pd.DataFrame({"latency": latencies})
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.hist(df["latency"], bins=30, edgecolor="black")
plt.title("Update Latency Distribution")
plt.xlabel("Latency (seconds)")
plt.ylabel("Frequency")
plt.subplot(1, 2, 2)
plt.boxplot(df["latency"], vert=False)
plt.title("Update Latency Boxplot")
plt.xlabel("Latency (seconds)")
plt.figure(figsize=(10, 5))
plt.plot(df["latency"].values, marker=".", linestyle="-", alpha=0.7)
plt.title("Sequential Update Latencies Over Time")
plt.xlabel("Operation Index")
plt.ylabel("Latency (s)")
plt.grid(True, alpha=0.3)
plot_file = f"seq_update_latency_{int(time.time())}.png"
plt.tight_layout()
plot_file = f"complex_update_latency_{int(time.time())}.png"
plt.savefig(plot_file)
plt.close()
# 6) Report summary
# 6) Summary
mean = df["latency"].mean()
median = df["latency"].median()
minimum = df["latency"].min()
maximum = df["latency"].max()
stdev = df["latency"].std()
print("\n===== Complex Update Latency Statistics =====")
print(f"Total updates: {len(latencies)}")
print(f"Mean: {mean:.3f}s")
print(f"Median: {median:.3f}s")
print(f"Min: {minimum:.3f}s")
print(f"Max: {maximum:.3f}s")
print(f"Std: {stdev:.3f}s")
print(f"Plot saved to: {plot_file}")
# Sanity assertion
assert median < 2.0, f"Median update latency too high: {median:.3f}s"
print("\n===== Sequential Complex Update Latencies =====")
print(f"Total ops : {len(latencies)}")
print(f"Mean : {mean:.3f}s")
print(f"Median : {median:.3f}s")
print(f"Min : {minimum:.3f}s")
print(f"Max : {maximum:.3f}s")
print(f"Std dev : {stdev:.3f}s")
print(f"Plot saved: {plot_file}")