Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
feat: Improve agent update performance (#1799)
parent a5e0698110
commit cdab671428
@@ -1,9 +1,9 @@
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import sqlalchemy as sa
-from sqlalchemy import Select, and_, func, insert, literal, or_, select, union_all
+from sqlalchemy import Select, and_, delete, func, insert, literal, or_, select, union_all
 from sqlalchemy.dialects.postgresql import insert as pg_insert
 
 from letta.constants import (
@@ -24,7 +24,6 @@ from letta.orm import Block as BlockModel
 from letta.orm import BlocksAgents
 from letta.orm import Group as GroupModel
 from letta.orm import IdentitiesAgents
-from letta.orm import Identity as IdentityModel
 from letta.orm import Source as SourceModel
 from letta.orm import SourcePassage, SourcesAgents
 from letta.orm import Tool as ToolModel
@@ -62,7 +61,6 @@ from letta.services.helpers.agent_manager_helper import (
     _apply_pagination,
     _apply_tag_filter,
     _process_relationship,
-    _process_tags,
     check_supports_structured_output,
     compile_system_message,
     derive_system_message,
@@ -147,6 +145,18 @@ class AgentManager:
 
             session.execute(stmt)
 
+    @staticmethod
+    @trace_method
+    def _replace_pivot_rows(session, table, agent_id: str, rows: list[dict]):
+        """
+        Replace all pivot rows for an agent with *exactly* the provided list.
+        Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
+        """
+        # delete all existing rows for this agent
+        session.execute(delete(table).where(table.c.agent_id == agent_id))
+        if rows:
+            AgentManager._bulk_insert_pivot(session, table, rows)
+
     # ======================================================================================================================
     # Basic CRUD operations
     # ======================================================================================================================
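The helper added above reduces every relationship update to two bulk statements: one DELETE that clears the agent's rows from the pivot table, then one multi-row INSERT that lets the database skip duplicates. Below is a minimal, self-contained sketch of that pattern; the agents_tools table, its columns, and the SQLite dialect are illustrative assumptions, not the repository's schema (the real helper operates on the ORM pivot tables and targets Postgres via pg_insert).

    # Sketch only: hypothetical pivot table and an in-memory SQLite engine.
    import sqlalchemy as sa
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    metadata = sa.MetaData()
    agents_tools = sa.Table(
        "agents_tools",
        metadata,
        sa.Column("agent_id", sa.String, primary_key=True),
        sa.Column("tool_id", sa.String, primary_key=True),
    )
    engine = sa.create_engine("sqlite://")
    metadata.create_all(engine)

    def replace_pivot_rows(conn, table, agent_id, rows):
        # One DELETE wipes whatever the agent currently has in the pivot table...
        conn.execute(sa.delete(table).where(table.c.agent_id == agent_id))
        # ...and one multi-row INSERT writes the desired set; ON CONFLICT DO NOTHING
        # makes accidental duplicates in `rows` harmless.
        if rows:
            conn.execute(sqlite_insert(table).values(rows).on_conflict_do_nothing())

    with engine.begin() as conn:
        replace_pivot_rows(
            conn,
            agents_tools,
            "agent-1",
            [{"agent_id": "agent-1", "tool_id": t} for t in ("tool-a", "tool-b")],
        )

The two-statement shape is what makes the update cheap: the cost no longer scales with loading and diffing the existing relationship collection in the ORM.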
@@ -322,86 +332,121 @@ class AgentManager:
         return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
 
     @enforce_types
-    def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
-        agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
+    def update_agent(
+        self,
+        agent_id: str,
+        agent_update: UpdateAgent,
+        actor: PydanticUser,
+    ) -> PydanticAgentState:
 
-        # If there are provided environment variables, add them in
-        if agent_update.tool_exec_environment_variables:
-            agent_state = self._set_environment_variables(
-                agent_id=agent_state.id,
-                env_vars=agent_update.tool_exec_environment_variables,
-                actor=actor,
-            )
+        new_tools = set(agent_update.tool_ids or [])
+        new_sources = set(agent_update.source_ids or [])
+        new_blocks = set(agent_update.block_ids or [])
+        new_idents = set(agent_update.identity_ids or [])
+        new_tags = set(agent_update.tags or [])
 
-        # Rebuild the system prompt if it's different
-        if agent_update.enable_sleeptime and agent_update.system is None:
-            agent_update.system = derive_system_message(
-                agent_type=agent_state.agent_type,
-                enable_sleeptime=agent_update.enable_sleeptime,
-                system=agent_update.system or agent_state.system,
-            )
-        if agent_update.system and agent_update.system != agent_state.system:
-            agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
-
-        return agent_state
+        with self.session_maker() as session, session.begin():
+            agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
+            agent.updated_at = datetime.now(timezone.utc)
+            agent.last_updated_by_id = actor.id
 
-    @enforce_types
-    def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
-        """
-        Update an existing agent.
-
-        Args:
-            agent_id: The ID of the agent to update.
-            agent_update: UpdateAgent object containing the updated fields.
-            actor: User performing the action.
-
-        Returns:
-            PydanticAgentState: The updated agent as a Pydantic model.
-        """
-        with self.session_maker() as session:
-            # Retrieve the existing agent
-            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
-
-            # Update scalar fields directly
-            scalar_fields = {
-                "name",
-                "system",
-                "llm_config",
-                "embedding_config",
-                "message_ids",
-                "tool_rules",
-                "description",
-                "metadata",
-                "project_id",
-                "template_id",
-                "base_template_id",
-                "message_buffer_autoclear",
-                "enable_sleeptime",
+            scalar_updates = {
+                "name": agent_update.name,
+                "system": agent_update.system,
+                "llm_config": agent_update.llm_config,
+                "embedding_config": agent_update.embedding_config,
+                "message_ids": agent_update.message_ids,
+                "tool_rules": agent_update.tool_rules,
+                "description": agent_update.description,
+                "project_id": agent_update.project_id,
+                "template_id": agent_update.template_id,
+                "base_template_id": agent_update.base_template_id,
+                "message_buffer_autoclear": agent_update.message_buffer_autoclear,
+                "enable_sleeptime": agent_update.enable_sleeptime,
             }
-            for field in scalar_fields:
-                value = getattr(agent_update, field, None)
-                if value is not None:
-                    if field == "metadata":
-                        setattr(agent, "metadata_", value)
-                    else:
-                        setattr(agent, field, value)
+            for col, val in scalar_updates.items():
+                if val is not None:
+                    setattr(agent, col, val)
 
-            # Update relationships using _process_relationship and _process_tags
+            if agent_update.metadata is not None:
+                agent.metadata_ = agent_update.metadata
+
+            aid = agent.id
+
             if agent_update.tool_ids is not None:
-                _process_relationship(session, agent, "tools", ToolModel, agent_update.tool_ids, replace=True)
+                self._replace_pivot_rows(
+                    session,
+                    ToolsAgents.__table__,
+                    aid,
+                    [{"agent_id": aid, "tool_id": tid} for tid in new_tools],
+                )
+                session.expire(agent, ["tools"])
+
             if agent_update.source_ids is not None:
-                _process_relationship(session, agent, "sources", SourceModel, agent_update.source_ids, replace=True)
+                self._replace_pivot_rows(
+                    session,
+                    SourcesAgents.__table__,
+                    aid,
+                    [{"agent_id": aid, "source_id": sid} for sid in new_sources],
+                )
+                session.expire(agent, ["sources"])
+
             if agent_update.block_ids is not None:
-                _process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
-            if agent_update.tags is not None:
-                _process_tags(agent, agent_update.tags, replace=True)
+                rows = []
+                if new_blocks:
+                    label_map = {
+                        bid: lbl
+                        for bid, lbl in session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(new_blocks)))
+                    }
+                    rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks]
+
+                self._replace_pivot_rows(session, BlocksAgents.__table__, aid, rows)
+                session.expire(agent, ["core_memory"])
+
             if agent_update.identity_ids is not None:
-                _process_relationship(session, agent, "identities", IdentityModel, agent_update.identity_ids, replace=True)
+                self._replace_pivot_rows(
+                    session,
+                    IdentitiesAgents.__table__,
+                    aid,
+                    [{"agent_id": aid, "identity_id": iid} for iid in new_idents],
+                )
+                session.expire(agent, ["identities"])
 
-            # Commit and refresh the agent
-            agent.update(session, actor=actor)
+            if agent_update.tags is not None:
+                self._replace_pivot_rows(
+                    session,
+                    AgentsTags.__table__,
+                    aid,
+                    [{"agent_id": aid, "tag": tag} for tag in new_tags],
+                )
+                session.expire(agent, ["tags"])
+
+            if agent_update.tool_exec_environment_variables is not None:
+                session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
+                env_rows = [
+                    {
+                        "agent_id": aid,
+                        "key": k,
+                        "value": v,
+                        "organization_id": agent.organization_id,
+                    }
+                    for k, v in agent_update.tool_exec_environment_variables.items()
+                ]
+                if env_rows:
+                    self._bulk_insert_pivot(session, AgentEnvironmentVariable.__table__, env_rows)
+                session.expire(agent, ["tool_exec_environment_variables"])
+
+            if agent_update.enable_sleeptime and agent_update.system is None:
+                agent.system = derive_system_message(
+                    agent_type=agent.agent_type,
+                    enable_sleeptime=agent_update.enable_sleeptime,
+                    system=agent.system,
+                )
+
+            session.flush()
+            session.refresh(agent)
 
             # Convert to PydanticAgentState and return
             return agent.to_pydantic()
 
     # TODO: Make this general and think about how to roll this into sqlalchemybase
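Stripped of the diff markers, the rewritten update path is: one session and one transaction, scalar fields applied only when they were actually supplied (None means "not provided"), updated_at and last_updated_by_id stamped up front, each relationship's pivot rows swapped in bulk, then a single flush and refresh before the ORM row is converted to the Pydantic schema. The sketch below mirrors that control flow; session_maker, AgentModelStub, and the update object are stand-ins for illustration, not the repository's actual classes.

    from datetime import datetime, timezone

    def update_agent_sketch(session_maker, AgentModelStub, agent_id, update):
        # One session, one transaction for the whole update.
        with session_maker() as session, session.begin():
            agent = session.get(AgentModelStub, agent_id)

            # Scalars: None means "field not supplied", so skip it instead of
            # overwriting the stored value with NULL.
            scalar_updates = {
                "name": update.name,
                "system": update.system,
                "description": update.description,
            }
            for col, val in scalar_updates.items():
                if val is not None:
                    setattr(agent, col, val)

            # Audit stamps are written once, up front.
            agent.updated_at = datetime.now(timezone.utc)

            # Relationship changes would go here as DELETE + bulk INSERT against
            # the pivot tables (see the _replace_pivot_rows sketch above), each
            # followed by session.expire(agent, [...]) so stale collections are
            # not reused.

            session.flush()         # emit the pending SQL inside this transaction
            session.refresh(agent)  # reload the row before serializing it
            return agent            # the real code returns agent.to_pydantic()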
@@ -4,7 +4,6 @@ import random
 import threading
 import time
 import uuid
-from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -126,12 +125,15 @@ def weather_tool(client):
 
 
 @pytest.mark.slow
-def test_parallel_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
+def test_sequential_mass_update_agents_complex(client, roll_dice_tool, weather_tool, rethink_tool):
     # 1) Create 30 agents WITHOUT the rethink_tool initially
+    AGENT_COUNT = 30
+    UPDATES_PER_AGENT = 50
+
     agent_ids = []
-    for i in range(5):
+    for i in range(AGENT_COUNT):
         agent = client.agents.create(
-            name=f"complex_agent_{i}_{uuid.uuid4().hex[:6]}",
+            name=f"seq_agent_{i}_{uuid.uuid4().hex[:6]}",
             tool_ids=[roll_dice_tool.id, weather_tool.id],
             include_base_tools=False,
             memory_blocks=[
|
||||
block_ids.append(blk.id)
|
||||
per_agent_blocks[aid] = block_ids
|
||||
|
||||
# 3) Dispatch 100 updates per agent in parallel
|
||||
total_updates = len(agent_ids) * 100
|
||||
# 3) Sequential updates: measure latency for each (agent, iteration)
|
||||
latencies = []
|
||||
total_ops = AGENT_COUNT * UPDATES_PER_AGENT
|
||||
|
||||
def do_update(agent_id: str):
|
||||
start = time.time()
|
||||
if random.random() < 0.5:
|
||||
client.agents.modify(agent_id=agent_id, tool_ids=[rethink_tool.id])
|
||||
else:
|
||||
bid = random.choice(per_agent_blocks[agent_id])
|
||||
client.agents.modify(agent_id=agent_id, block_ids=[bid])
|
||||
return time.time() - start
|
||||
idx = 0
|
||||
with tqdm(total=total_ops, desc="Sequential updates") as pbar:
|
||||
for aid in agent_ids:
|
||||
for _ in range(UPDATES_PER_AGENT):
|
||||
start = time.time()
|
||||
if random.random() < 0.5:
|
||||
client.agents.modify(agent_id=aid, tool_ids=[rethink_tool.id])
|
||||
else:
|
||||
bid = random.choice(per_agent_blocks[aid])
|
||||
client.agents.modify(agent_id=aid, block_ids=[bid])
|
||||
elapsed = time.time() - start
|
||||
|
||||
with ThreadPoolExecutor(max_workers=50) as executor:
|
||||
futures = [executor.submit(do_update, aid) for aid in agent_ids for _ in range(10)]
|
||||
for future in tqdm(as_completed(futures), total=total_updates, desc="Complex updates"):
|
||||
latencies.append(future.result())
|
||||
latencies.append(elapsed)
|
||||
idx += 1
|
||||
pbar.update(1)
|
||||
|
||||
# 4) Cleanup
|
||||
for aid in agent_ids:
|
||||
client.agents.delete(aid)
|
||||
|
||||
# 5) Plot latency distribution
|
||||
# 5) Line‐plot every single latency
|
||||
df = pd.DataFrame({"latency": latencies})
|
||||
plt.figure(figsize=(12, 6))
|
||||
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.hist(df["latency"], bins=30, edgecolor="black")
|
||||
plt.title("Update Latency Distribution")
|
||||
plt.xlabel("Latency (seconds)")
|
||||
plt.ylabel("Frequency")
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.boxplot(df["latency"], vert=False)
|
||||
plt.title("Update Latency Boxplot")
|
||||
plt.xlabel("Latency (seconds)")
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(df["latency"].values, marker=".", linestyle="-", alpha=0.7)
|
||||
plt.title("Sequential Update Latencies Over Time")
|
||||
plt.xlabel("Operation Index")
|
||||
plt.ylabel("Latency (s)")
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
plot_file = f"seq_update_latency_{int(time.time())}.png"
|
||||
plt.tight_layout()
|
||||
plot_file = f"complex_update_latency_{int(time.time())}.png"
|
||||
plt.savefig(plot_file)
|
||||
plt.close()
|
||||
|
||||
# 6) Report summary
|
||||
# 6) Summary
|
||||
mean = df["latency"].mean()
|
||||
median = df["latency"].median()
|
||||
minimum = df["latency"].min()
|
||||
maximum = df["latency"].max()
|
||||
stdev = df["latency"].std()
|
||||
|
||||
print("\n===== Complex Update Latency Statistics =====")
|
||||
print(f"Total updates: {len(latencies)}")
|
||||
print(f"Mean: {mean:.3f}s")
|
||||
print(f"Median: {median:.3f}s")
|
||||
print(f"Min: {minimum:.3f}s")
|
||||
print(f"Max: {maximum:.3f}s")
|
||||
print(f"Std: {stdev:.3f}s")
|
||||
print(f"Plot saved to: {plot_file}")
|
||||
|
||||
# Sanity assertion
|
||||
assert median < 2.0, f"Median update latency too high: {median:.3f}s"
|
||||
print("\n===== Sequential Complex Update Latencies =====")
|
||||
print(f"Total ops : {len(latencies)}")
|
||||
print(f"Mean : {mean:.3f}s")
|
||||
print(f"Median : {median:.3f}s")
|
||||
print(f"Min : {minimum:.3f}s")
|
||||
print(f"Max : {maximum:.3f}s")
|
||||
print(f"Std dev : {stdev:.3f}s")
|
||||
print(f"Plot saved: {plot_file}")
|
||||
|
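If only the closing summary matters, the same statistics (plus tail percentiles, which a sequential latency trace makes easy to read) can come from a single pandas describe call; a hedged equivalent with stand-in data:

    import pandas as pd

    # Stand-in values; in the test, `latencies` is filled by the sequential update loop.
    latencies = [0.84, 1.02, 0.91, 1.40, 0.88]

    summary = pd.Series(latencies, name="latency").describe(percentiles=[0.5, 0.95, 0.99])
    print(summary.to_string())  # count, mean, std, min, 50%, 95%, 99%, max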