diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index b9e30ac07..1c63c079d 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -66,7 +66,7 @@ class _ResumeContext: request_status_updates: List[RequestStatusUpdateInfo] -async def execute_tool_wrapper(params: ToolExecutionParams): +async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[str, bool]]: """ Executes the tool in an out‑of‑process worker and returns: (agent_id, (tool_result:str, success_flag:bool)) @@ -324,8 +324,13 @@ class LettaAgentBatch: @trace_method async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]: sbx_cfg, sbx_env = self._build_sandbox() - params = [ - ToolExecutionParams( + rethink_memory_tool_name = "rethink_memory" + tool_params = [] + # TODO: This is a special case - we need to think about how to generalize this + # TODO: Rethink memory is a common op that is easily batchable, so we pull this logic out + rethink_memory_params = [] + for aid in ctx.agent_ids: + param = ToolExecutionParams( agent_id=aid, tool_call_name=ctx.tool_call_name_map[aid], tool_args=ctx.tool_call_args_map[aid], @@ -334,10 +339,44 @@ class LettaAgentBatch: sbx_config=sbx_cfg, sbx_env_vars=sbx_env, ) - for aid in ctx.agent_ids - ] - async with Pool() as pool: - return await pool.map(execute_tool_wrapper, params) + + if ctx.tool_call_name_map[aid] == rethink_memory_tool_name: + rethink_memory_params.append(param) + else: + tool_params.append(param) + + if rethink_memory_params: + return self._bulk_rethink_memory(rethink_memory_params) + + if tool_params: + async with Pool() as pool: + return await pool.map(execute_tool_wrapper, tool_params) + + @trace_method + def _bulk_rethink_memory(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]: + updates = {} + result = [] + for param in params: + # Sanity check + # TODO: This is very brittle and done quickly for 
performance + # TODO: If the end tool is changed, this will break + # TODO: Move 'rethink_memory' to a native Letta tool that we control + if "new_memory" not in param.tool_args or "target_block_label" not in param.tool_args: + raise ValueError(f"Missing either `new_memory` or `target_block_label` in the tool args: {param.tool_args}") + + # Find the block id/update + block_id = param.agent_state.memory.get_block(label=param.tool_args.get("target_block_label")).id + new_value = param.tool_args.get("new_memory") + + # This is sensitive to multiple agents overwriting the same memory block + updates[block_id] = new_value + + # TODO: This is quite ugly and confusing - this is mostly to align with the returns of other tools + result.append((param.agent_id, ("", True))) + + self.block_manager.bulk_update_block_values(updates=updates, actor=self.actor) + + return result def _persist_tool_messages( self, diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 40e025196..94b3e743c 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,8 +1,9 @@ import os -from typing import List, Optional +from typing import Dict, List, Optional from sqlalchemy.orm import Session +from letta.log import get_logger from letta.orm.block import Block as BlockModel from letta.orm.block_history import BlockHistory from letta.orm.enums import ActorType @@ -13,6 +14,8 @@ from letta.schemas.block import BlockUpdate, Human, Persona from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types, list_human_files, list_persona_files +logger = get_logger(__name__) + class BlockManager: """Manager class to handle business logic related to Blocks.""" @@ -349,3 +352,44 @@ class BlockManager: session.commit() return block.to_pydantic() + + @enforce_types + def bulk_update_block_values( + self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False + ) -> Optional[List[PydanticBlock]]: + """ + 
Bulk-update the `value` field for multiple blocks in one transaction. + + Args: + updates: mapping of block_id -> new value + actor: the user performing the update (for org scoping, permissions, audit) + return_hydrated: whether to return the pydantic Block objects that were updated + + Returns: + the updated Block objects as Pydantic schemas + + Raises: + NoResultFound if any block_id doesn’t exist or isn’t visible to this actor + ValueError if any new value exceeds its block’s limit + """ + with self.session_maker() as session: + q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id) + blocks = q.all() + + found_ids = {b.id for b in blocks} + missing = set(updates.keys()) - found_ids + if missing: + logger.warning(f"Block IDs not found or inaccessible, skipping during bulk update: {missing!r}") + + for block in blocks: + new_val = updates[block.id] + if len(new_val) > block.limit: + logger.warning(f"Value length ({len(new_val)}) exceeds limit " f"({block.limit}) for block {block.id!r}, truncating...") + new_val = new_val[: block.limit] + block.value = new_val + + session.commit() + + if return_hydrated: + return [b.to_pydantic() for b in blocks] + return None diff --git a/tests/test_managers.py b/tests/test_managers.py index fa69b8025..cf719f9b3 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1,3 +1,4 @@ +import logging import os import random import re @@ -73,6 +74,7 @@ from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager from letta.settings import tool_settings from tests.helpers.utils import comprehensive_agent_checks +from tests.utils import random_string DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig( embedding_endpoint_type="hugging-face", @@ -2759,6 +2761,88 @@ def test_batch_create_multiple_blocks(server: SyncServer, default_user): assert expected_labels.issubset(all_labels) +def 
def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncServer, default_user: PydanticUser, caplog):
    block_manager = BlockManager()

    # Create one block with a deliberately small value limit.
    block = block_manager.create_or_update_block(
        PydanticBlock(label="human", value="orig", limit=5),
        actor=default_user,
    )

    # One real id with an over-limit value, plus one id that does not exist.
    oversized_value = random_string(10)  # longer than limit == 5
    updates = {
        block.id: oversized_value,
        "nonexistent-id": "whatever",
    }

    caplog.set_level(logging.WARNING)
    outcome = block_manager.bulk_update_block_values(updates, actor=default_user)
    # Default return_hydrated=False -> no hydrated schemas are returned.
    assert outcome is None

    # Warnings should mention both the skipped id and the truncation.
    assert "skipping during bulk update" in caplog.text
    assert "truncating" in caplog.text

    # The stored value must have been clipped to `limit` characters.
    refreshed = block_manager.get_blocks(actor=default_user, id=block.id)[0]
    assert len(refreshed.value) == 5
    assert refreshed.value == oversized_value[:5]


def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: PydanticUser):
    block_manager = BlockManager()

    # Seed a single block to update.
    seed = block_manager.create_or_update_block(
        PydanticBlock(label="persona", value="foo", limit=20),
        actor=default_user,
    )

    hydrated = block_manager.bulk_update_block_values({seed.id: "new-val"}, actor=default_user, return_hydrated=True)

    # With return_hydrated=True we get the updated schemas back.
    assert isinstance(hydrated, list) and len(hydrated) == 1
    assert hydrated[0].id == seed.id
    assert hydrated[0].value == "new-val"


def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: PydanticUser, other_user_different_org: PydanticUser, caplog):
    block_manager = BlockManager()

    # One block per organization.
    mine = block_manager.create_or_update_block(
        PydanticBlock(label="human", value="mine", limit=100),
        actor=default_user,
    )
    theirs = block_manager.create_or_update_block(
        PydanticBlock(label="human", value="theirs", limit=100),
        actor=other_user_different_org,
    )

    caplog.set_level(logging.WARNING)
    block_manager.bulk_update_block_values(
        {mine.id: "updated-mine", theirs.id: "updated-theirs"},
        actor=default_user,
    )

    # The caller's own block is updated...
    reloaded_mine = block_manager.get_blocks(actor=default_user, id=mine.id)[0]
    assert reloaded_mine.value == "updated-mine"

    # ...while the other org's block is untouched.
    reloaded_theirs = block_manager.get_blocks(actor=other_user_different_org, id=theirs.id)[0]
    assert reloaded_theirs.value == "theirs"

    # The cross-org id is reported as skipped.
    assert "skipping during bulk update" in caplog.text
def random_string(length: int) -> str:
    """Return a random alphanumeric string of exactly *length* characters."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))