Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
feat: Add bulk rethink memory for letta agent batches (#2002)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: cthomas <caren@letta.com>
parent 172d0bcef1
commit 812f664202
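For context on the feature: `rethink_memory` rewrites one agent memory block wholesale, and this commit batches those rewrites across all agents in a batch. A minimal sketch of the tool-call arguments the new bulk path validates (the argument names come from `_bulk_rethink_memory` in the diff below; the values here are made up):

    tool_args = {
        "target_block_label": "human",  # label of the memory block to rewrite
        "new_memory": "The user's name is Ada; she prefers short answers.",  # full replacement text
    }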
@@ -66,7 +66,7 @@ class _ResumeContext:
     request_status_updates: List[RequestStatusUpdateInfo]


-async def execute_tool_wrapper(params: ToolExecutionParams):
+async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[str, bool]]:
     """
     Executes the tool in an out-of-process worker and returns:
         (agent_id, (tool_result:str, success_flag:bool))
@@ -324,8 +324,13 @@ class LettaAgentBatch:
     @trace_method
     async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
         sbx_cfg, sbx_env = self._build_sandbox()
-        params = [
-            ToolExecutionParams(
+        rethink_memory_tool_name = "rethink_memory"
+        tool_params = []
+        # TODO: This is a special case - we need to think about how to generalize this
+        # TODO: Rethink memory is a common op that is easily batchable, so we pull this logic out
+        rethink_memory_params = []
+        for aid in ctx.agent_ids:
+            param = ToolExecutionParams(
                 agent_id=aid,
                 tool_call_name=ctx.tool_call_name_map[aid],
                 tool_args=ctx.tool_call_args_map[aid],
@@ -334,10 +339,44 @@ class LettaAgentBatch:
                 sbx_config=sbx_cfg,
                 sbx_env_vars=sbx_env,
             )
-            for aid in ctx.agent_ids
-        ]
-        async with Pool() as pool:
-            return await pool.map(execute_tool_wrapper, params)
+
+            if ctx.tool_call_name_map[aid] == rethink_memory_tool_name:
+                rethink_memory_params.append(param)
+            else:
+                tool_params.append(param)
+
+        if rethink_memory_params:
+            return self._bulk_rethink_memory(rethink_memory_params)
+
+        if tool_params:
+            async with Pool() as pool:
+                return await pool.map(execute_tool_wrapper, tool_params)
+
+    @trace_method
+    def _bulk_rethink_memory(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]:
+        updates = {}
+        result = []
+        for param in params:
+            # Sanity check
+            # TODO: This is very brittle and done quickly for performance
+            # TODO: If the end tool is changed, this will break
+            # TODO: Move 'rethink_memory' to a native Letta tool that we control
+            if "new_memory" not in param.tool_args or "target_block_label" not in param.tool_args:
+                raise ValueError(f"Missing either `new_memory` or `target_block_label` in the tool args: {param.tool_args}")
+
+            # Find the block id/update
+            block_id = param.agent_state.memory.get_block(label=param.tool_args.get("target_block_label")).id
+            new_value = param.tool_args.get("new_memory")
+
+            # This is sensitive to multiple agents overwriting the same memory block
+            updates[block_id] = new_value
+
+            # TODO: This is quite ugly and confusing - this is mostly to align with the returns of other tools
+            result.append((param.agent_id, ("", True)))
+
+        self.block_manager.bulk_update_block_values(updates=updates, actor=self.actor)
+
+        return result
+
     def _persist_tool_messages(
         self,
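As the comment in `_bulk_rethink_memory` notes, collapsing per-agent calls into a single `{block_id: new_value}` dict is last-writer-wins when several agents target the same shared block. A self-contained sketch of that hazard (illustrative IDs and labels only, not Letta's API):

    # Two agents rewrite the same shared block; only the later value survives.
    block_ids = {"human": "block-123"}  # hypothetical label -> block id mapping
    calls = [
        ("agent-1", {"target_block_label": "human", "new_memory": "version A"}),
        ("agent-2", {"target_block_label": "human", "new_memory": "version B"}),
    ]
    updates = {}
    for agent_id, args in calls:
        # mirrors the dict assignment in the diff above
        updates[block_ids[args["target_block_label"]]] = args["new_memory"]
    assert updates == {"block-123": "version B"}  # agent-1's write is silently lost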
@@ -1,8 +1,9 @@
 import os
-from typing import List, Optional
+from typing import Dict, List, Optional

 from sqlalchemy.orm import Session

+from letta.log import get_logger
 from letta.orm.block import Block as BlockModel
 from letta.orm.block_history import BlockHistory
 from letta.orm.enums import ActorType
@@ -13,6 +14,8 @@ from letta.schemas.block import BlockUpdate, Human, Persona
 from letta.schemas.user import User as PydanticUser
 from letta.utils import enforce_types, list_human_files, list_persona_files

+logger = get_logger(__name__)
+

 class BlockManager:
     """Manager class to handle business logic related to Blocks."""
@@ -349,3 +352,44 @@ class BlockManager:

         session.commit()
         return block.to_pydantic()
+
+    @enforce_types
+    def bulk_update_block_values(
+        self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
+    ) -> Optional[List[PydanticBlock]]:
+        """
+        Bulk-update the `value` field for multiple blocks in one transaction.
+
+        Args:
+            updates: mapping of block_id -> new value
+            actor: the user performing the update (for org scoping, permissions, audit)
+            return_hydrated: whether to return the pydantic Block objects that were updated
+
+        Returns:
+            the updated Block objects as Pydantic schemas if `return_hydrated=True`, else None
+
+        Note:
+            block IDs that don't exist or aren't visible to this actor are skipped with
+            a warning, and values longer than a block's limit are truncated to fit
+        """
+        with self.session_maker() as session:
+            q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
+            blocks = q.all()
+
+            found_ids = {b.id for b in blocks}
+            missing = set(updates.keys()) - found_ids
+            if missing:
+                logger.warning(f"Block IDs not found or inaccessible, skipping during bulk update: {missing!r}")
+
+            for block in blocks:
+                new_val = updates[block.id]
+                if len(new_val) > block.limit:
+                    logger.warning(f"Value length ({len(new_val)}) exceeds limit ({block.limit}) for block {block.id!r}, truncating...")
+                    new_val = new_val[: block.limit]
+                block.value = new_val
+
+            session.commit()
+
+            if return_hydrated:
+                return [b.to_pydantic() for b in blocks]
+            return None
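A hedged usage sketch of the new manager method, assuming an existing `BlockManager` and a `PydanticUser` actor; the block IDs below are placeholders, not real values:

    mgr = BlockManager()
    updates = {
        "block-human-id": "updated human memory",      # placeholder IDs
        "block-persona-id": "updated persona memory",
    }
    # One transaction: unknown or other-org IDs are skipped with a warning,
    # and over-limit values are truncated to block.limit.
    blocks = mgr.bulk_update_block_values(updates, actor=actor, return_hydrated=True)
    for b in blocks or []:
        print(b.id, len(b.value))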
@@ -1,3 +1,4 @@
+import logging
 import os
 import random
 import re
@@ -73,6 +74,7 @@ from letta.services.block_manager import BlockManager
 from letta.services.organization_manager import OrganizationManager
 from letta.settings import tool_settings
 from tests.helpers.utils import comprehensive_agent_checks
+from tests.utils import random_string

 DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
     embedding_endpoint_type="hugging-face",
@@ -2759,6 +2761,88 @@ def test_batch_create_multiple_blocks(server: SyncServer, default_user):
     assert expected_labels.issubset(all_labels)


+def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncServer, default_user: PydanticUser, caplog):
+    mgr = BlockManager()
+
+    # create one block with a small limit
+    b = mgr.create_or_update_block(
+        PydanticBlock(label="human", value="orig", limit=5),
+        actor=default_user,
+    )
+
+    # prepare updates: one real id with an over-limit value, plus one missing id
+    long_val = random_string(10)  # length > limit == 5
+    updates = {
+        b.id: long_val,
+        "nonexistent-id": "whatever",
+    }
+
+    caplog.set_level(logging.WARNING)
+    result = mgr.bulk_update_block_values(updates, actor=default_user)
+    # default return_hydrated=False -> should be None
+    assert result is None
+
+    # warnings should mention skipping the missing ID and truncation
+    assert "skipping during bulk update" in caplog.text
+    assert "truncating" in caplog.text
+
+    # confirm the value was truncated to `limit` characters
+    reloaded = mgr.get_blocks(actor=default_user, id=b.id)[0]
+    assert len(reloaded.value) == 5
+    assert reloaded.value == long_val[:5]
+
+
+def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: PydanticUser):
+    mgr = BlockManager()
+
+    # create a block
+    b = mgr.create_or_update_block(
+        PydanticBlock(label="persona", value="foo", limit=20),
+        actor=default_user,
+    )
+
+    updates = {b.id: "new-val"}
+    updated = mgr.bulk_update_block_values(updates, actor=default_user, return_hydrated=True)
+
+    # with return_hydrated=True, we get back a list of schemas
+    assert isinstance(updated, list) and len(updated) == 1
+    assert updated[0].id == b.id
+    assert updated[0].value == "new-val"
+
+
+def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: PydanticUser, other_user_different_org: PydanticUser, caplog):
+    mgr = BlockManager()
+
+    # one block in each org
+    mine = mgr.create_or_update_block(
+        PydanticBlock(label="human", value="mine", limit=100),
+        actor=default_user,
+    )
+    theirs = mgr.create_or_update_block(
+        PydanticBlock(label="human", value="theirs", limit=100),
+        actor=other_user_different_org,
+    )
+
+    updates = {
+        mine.id: "updated-mine",
+        theirs.id: "updated-theirs",
+    }
+
+    caplog.set_level(logging.WARNING)
+    mgr.bulk_update_block_values(updates, actor=default_user)
+
+    # mine should be updated...
+    reloaded_mine = mgr.get_blocks(actor=default_user, id=mine.id)[0]
+    assert reloaded_mine.value == "updated-mine"
+
+    # ...theirs should remain untouched
+    reloaded_theirs = mgr.get_blocks(actor=other_user_different_org, id=theirs.id)[0]
+    assert reloaded_theirs.value == "theirs"
+
+    # warning should mention skipping the other-org ID
+    assert "skipping during bulk update" in caplog.text
+
+
 # ======================================================================================================================
 # Block Manager Tests - Checkpointing
 # ======================================================================================================================
@@ -1,4 +1,6 @@
 import os
+import random
+import string
 import time
 from datetime import datetime, timezone
 from importlib import util
@@ -193,3 +195,7 @@ def wait_for_server(url, timeout=30, interval=0.5):
         time.sleep(interval)

     raise TimeoutError(f"Server at {url} did not start within {timeout} seconds")
+
+
+def random_string(length: int) -> str:
+    return "".join(random.choices(string.ascii_letters + string.digits, k=length))
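A quick note on the new test helper: it draws uniformly from ASCII letters and digits, so collisions are merely improbable, not impossible. For example:

    s = random_string(10)
    assert len(s) == 10 and s.isalnum()  # letters + digits only, length as requested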