feat: Add bulk rethink memory for letta agent batches (#2002)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: cthomas <caren@letta.com>
Matthew Zhou 2025-05-06 06:24:59 +08:00 committed by GitHub
parent 172d0bcef1
commit 812f664202
4 changed files with 181 additions and 8 deletions

View File

@@ -66,7 +66,7 @@ class _ResumeContext:
request_status_updates: List[RequestStatusUpdateInfo]
async def execute_tool_wrapper(params: ToolExecutionParams):
async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[str, bool]]:
"""
Executes the tool in an out-of-process worker and returns:
(agent_id, (tool_result:str, success_flag:bool))
@@ -324,8 +324,13 @@ class LettaAgentBatch:
@trace_method
async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
sbx_cfg, sbx_env = self._build_sandbox()
params = [
ToolExecutionParams(
rethink_memory_tool_name = "rethink_memory"
tool_params = []
# TODO: This is a special case - we need to think about how to generalize this
# TODO: Rethink memory is a common op that is easily batchable, so we pull this logic out
rethink_memory_params = []
for aid in ctx.agent_ids:
param = ToolExecutionParams(
agent_id=aid,
tool_call_name=ctx.tool_call_name_map[aid],
tool_args=ctx.tool_call_args_map[aid],
@@ -334,10 +339,44 @@ class LettaAgentBatch:
sbx_config=sbx_cfg,
sbx_env_vars=sbx_env,
)
for aid in ctx.agent_ids
]
async with Pool() as pool:
return await pool.map(execute_tool_wrapper, params)
if ctx.tool_call_name_map[aid] == rethink_memory_tool_name:
rethink_memory_params.append(param)
else:
tool_params.append(param)
if rethink_memory_params:
return self._bulk_rethink_memory(rethink_memory_params)
if tool_params:
async with Pool() as pool:
return await pool.map(execute_tool_wrapper, tool_params)
@trace_method
def _bulk_rethink_memory(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]:
updates = {}
result = []
for param in params:
# Sanity check
# TODO: This is very brittle and done quickly for performance
# TODO: If the end tool is changed, this will break
# TODO: Move 'rethink_memory' to a native Letta tool that we control
if "new_memory" not in param.tool_args or "target_block_label" not in param.tool_args:
raise ValueError(f"Missing either `new_memory` or `target_block_label` in the tool args: {param.tool_args}")
# Find the block id/update
block_id = param.agent_state.memory.get_block(label=param.tool_args.get("target_block_label")).id
new_value = param.tool_args.get("new_memory")
# This is sensitive to multiple agents overwriting the same memory block
updates[block_id] = new_value
# TODO: This is quite ugly and confusing - this is mostly to align with the returns of other tools
result.append((param.agent_id, ("", True)))
self.block_manager.bulk_update_block_values(updates=updates, actor=self.actor)
return result
def _persist_tool_messages(
self,
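Taken together, the hunks above route `rethink_memory` calls away from the out-of-process worker pool and into a single bulk block update; note that the early `return` means a batch mixing both kinds of calls executes only the rethink path. A minimal standalone sketch of the partitioning step (the `FakeParams` type and sample values are hypothetical stand-ins; the real `ToolExecutionParams` also carries agent state and sandbox config):

```python
# Hypothetical stand-in for ToolExecutionParams, just to illustrate the routing.
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FakeParams:
    agent_id: str
    tool_call_name: str
    tool_args: Dict[str, str]


def partition_params(params: List[FakeParams]) -> Tuple[List[FakeParams], List[FakeParams]]:
    """Split rethink_memory calls (bulk DB path) from all other tools (worker-pool path)."""
    rethink, other = [], []
    for p in params:
        (rethink if p.tool_call_name == "rethink_memory" else other).append(p)
    return rethink, other


rethink, other = partition_params(
    [
        FakeParams("agent-1", "rethink_memory", {"target_block_label": "human", "new_memory": "..."}),
        FakeParams("agent-2", "send_message", {"message": "hi"}),
    ]
)
assert [p.agent_id for p in rethink] == ["agent-1"]
assert [p.agent_id for p in other] == ["agent-2"]
```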

View File

@@ -1,8 +1,9 @@
import os
from typing import List, Optional
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from letta.log import get_logger
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType
@@ -13,6 +14,8 @@ from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types, list_human_files, list_persona_files
logger = get_logger(__name__)
class BlockManager:
"""Manager class to handle business logic related to Blocks."""
@@ -349,3 +352,44 @@ class BlockManager:
session.commit()
return block.to_pydantic()
@enforce_types
def bulk_update_block_values(
self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
) -> Optional[List[PydanticBlock]]:
"""
Bulk-update the `value` field for multiple blocks in one transaction.
Args:
updates: mapping of block_id -> new value
actor: the user performing the update (for org scoping, permissions, audit)
return_hydrated: whether to return the pydantic Block objects that were updated
Returns:
the updated Block objects as Pydantic schemas
Note:
block_ids that don't exist or aren't visible to this actor are skipped with a warning
values longer than a block's `limit` are truncated to fit
"""
with self.session_maker() as session:
q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
blocks = q.all()
found_ids = {b.id for b in blocks}
missing = set(updates.keys()) - found_ids
if missing:
logger.warning(f"Block IDs not found or inaccessible, skipping during bulk update: {missing!r}")
for block in blocks:
new_val = updates[block.id]
if len(new_val) > block.limit:
logger.warning(f"Value length ({len(new_val)}) exceeds limit " f"({block.limit}) for block {block.id!r}, truncating...")
new_val = new_val[: block.limit]
block.value = new_val
session.commit()
if return_hydrated:
return [b.to_pydantic() for b in blocks]
return None
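A hedged usage sketch of the new manager method (the block IDs and values are made up, and the `PydanticUser` is assumed to come from elsewhere; the method name and keyword arguments are taken from the diff above):

```python
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager


def bulk_update_example(mgr: BlockManager, me: PydanticUser) -> None:
    # Hypothetical block IDs; real IDs come from previously created blocks.
    updates = {
        "block-123": "Alice prefers short answers.",
        "block-456": "Project deadline moved to June.",
    }
    # One transaction: unknown or other-org IDs are logged and skipped,
    # and over-limit values are truncated to each block's `limit`.
    mgr.bulk_update_block_values(updates, actor=me)  # returns None by default

    # Opt in to hydrated results when the caller needs the updated schemas.
    blocks = mgr.bulk_update_block_values(updates, actor=me, return_hydrated=True)
    if blocks:
        print([b.value for b in blocks])
```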

View File

@@ -1,3 +1,4 @@
import logging
import os
import random
import re
@@ -73,6 +74,7 @@ from letta.services.block_manager import BlockManager
from letta.services.organization_manager import OrganizationManager
from letta.settings import tool_settings
from tests.helpers.utils import comprehensive_agent_checks
from tests.utils import random_string
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
embedding_endpoint_type="hugging-face",
@@ -2759,6 +2761,88 @@ def test_batch_create_multiple_blocks(server: SyncServer, default_user):
assert expected_labels.issubset(all_labels)
def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncServer, default_user: PydanticUser, caplog):
mgr = BlockManager()
# create one block with a small limit
b = mgr.create_or_update_block(
PydanticBlock(label="human", value="orig", limit=5),
actor=default_user,
)
# prepare updates: one real id with an overlimit value, plus one missing id
long_val = random_string(10) # length > limit==5
updates = {
b.id: long_val,
"nonexistent-id": "whatever",
}
caplog.set_level(logging.WARNING)
result = mgr.bulk_update_block_values(updates, actor=default_user)
# default return_hydrated=False → should be None
assert result is None
# warnings should mention skipping the missing ID and truncation
assert "skipping during bulk update" in caplog.text
assert "truncating" in caplog.text
# confirm the value was truncated to `limit` characters
reloaded = mgr.get_blocks(actor=default_user, id=b.id)[0]
assert len(reloaded.value) == 5
assert reloaded.value == long_val[:5]
def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: PydanticUser):
mgr = BlockManager()
# create a block
b = mgr.create_or_update_block(
PydanticBlock(label="persona", value="foo", limit=20),
actor=default_user,
)
updates = {b.id: "new-val"}
updated = mgr.bulk_update_block_values(updates, actor=default_user, return_hydrated=True)
# with return_hydrated=True, we get back a list of schemas
assert isinstance(updated, list) and len(updated) == 1
assert updated[0].id == b.id
assert updated[0].value == "new-val"
def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: PydanticUser, other_user_different_org: PydanticUser, caplog):
mgr = BlockManager()
# one block in each org
mine = mgr.create_or_update_block(
PydanticBlock(label="human", value="mine", limit=100),
actor=default_user,
)
theirs = mgr.create_or_update_block(
PydanticBlock(label="human", value="theirs", limit=100),
actor=other_user_different_org,
)
updates = {
mine.id: "updated-mine",
theirs.id: "updated-theirs",
}
caplog.set_level(logging.WARNING)
mgr.bulk_update_block_values(updates, actor=default_user)
# mine should be updated...
reloaded_mine = mgr.get_blocks(actor=default_user, id=mine.id)[0]
assert reloaded_mine.value == "updated-mine"
# ...theirs should remain untouched
reloaded_theirs = mgr.get_blocks(actor=other_user_different_org, id=theirs.id)[0]
assert reloaded_theirs.value == "theirs"
# warning should mention skipping the other-org ID
assert "skipping during bulk update" in caplog.text
# ======================================================================================================================
# Block Manager Tests - Checkpointing
# ======================================================================================================================

View File

@@ -1,4 +1,6 @@
import os
import random
import string
import time
from datetime import datetime, timezone
from importlib import util
@@ -193,3 +195,7 @@ def wait_for_server(url, timeout=30, interval=0.5):
time.sleep(interval)
raise TimeoutError(f"Server at {url} did not start within {timeout} seconds")
def random_string(length: int) -> str:
return "".join(random.choices(string.ascii_letters + string.digits, k=length))