MemGPT/letta/agents/letta_agent_batch.py
import json
import uuid
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union
from aiomultiprocess import Pool
from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _prepare_in_context_messages_async
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.jobs.types import RequestStatusUpdateInfo, StepStatusUpdateInfo
from letta.llm_api.llm_client import LLMClient
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageStreamStatus, ProviderType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
from letta.schemas.llm_batch_job import LLMBatchItem
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall as OpenAIToolCall
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
from letta.schemas.user import User
from letta.server.rest_api.utils import create_heartbeat_system_message, create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.job_manager import JobManager
from letta.services.llm_batch_manager import LLMBatchManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import tool_settings
from letta.tracing import log_event, trace_method
logger = get_logger(__name__)
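# Plain-data container for a single tool invocation. Instances cross a process
# boundary via aiomultiprocess, so every field must be picklable.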
@dataclass
class ToolExecutionParams:
agent_id: str
tool_call_name: str
tool_args: Dict[str, Any]
agent_state: AgentState
actor: User
sbx_config: SandboxConfig
sbx_env_vars: Dict[str, Any]
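# Snapshot of everything needed to resume agents after polling the provider:
# the persisted batch items, their agent states, the raw per-agent provider
# results, and the tool call extracted from each result (keyed by agent_id).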
@dataclass
class _ResumeContext:
batch_items: List[LLMBatchItem]
agent_ids: List[str]
agent_state_map: Dict[str, AgentState]
provider_results: Dict[str, Any]
tool_call_name_map: Dict[str, str]
tool_call_args_map: Dict[str, Dict[str, Any]]
should_continue_map: Dict[str, bool]
request_status_updates: List[RequestStatusUpdateInfo]
async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[str, bool]]:
"""
    Executes the tool in an out-of-process worker and returns:
(agent_id, (tool_result:str, success_flag:bool))
"""
# locate the tool on the agent
target_tool = next((t for t in params.agent_state.tools if t.name == params.tool_call_name), None)
if not target_tool:
return params.agent_id, (f"Tool not found: {params.tool_call_name}", False)
try:
mgr = ToolExecutionManager(
agent_state=params.agent_state,
actor=params.actor,
sandbox_config=params.sbx_config,
sandbox_env_vars=params.sbx_env_vars,
)
tool_execution_result = await mgr.execute_tool_async(
function_name=params.tool_call_name,
function_args=params.tool_args,
tool=target_tool,
)
return params.agent_id, (tool_execution_result.func_return, True)
except Exception as e:
return params.agent_id, (f"Failed to call tool. Error: {e}", False)
# TODO: Limitations ->
# TODO: Only works with Anthropic for now
class LettaAgentBatch(BaseAgent):
def __init__(
self,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
batch_manager: LLMBatchManager,
sandbox_config_manager: SandboxConfigManager,
job_manager: JobManager,
actor: User,
max_steps: int = 10,
):
self.message_manager = message_manager
self.agent_manager = agent_manager
self.block_manager = block_manager
self.passage_manager = passage_manager
self.batch_manager = batch_manager
self.sandbox_config_manager = sandbox_config_manager
self.job_manager = job_manager
self.actor = actor
self.max_steps = max_steps
@trace_method
async def step_until_request(
self,
batch_requests: List[LettaBatchRequest],
letta_batch_job_id: str,
agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
) -> LettaBatchResponse:
"""Carry out agent steps until the LLM request is sent."""
log_event(name="validate_inputs")
if not batch_requests:
raise ValueError("Empty list of batch_requests passed in!")
if agent_step_state_mapping is None:
agent_step_state_mapping = {}
log_event(name="load_and_prepare_agents")
        # Prepares: (1) agent states, (2) step states, (3) LLMBatchItems, (4) message batch_item_ids, (5) messages per agent, (6) tools per agent
agent_messages_mapping: dict[str, list[Message]] = {}
agent_tools_mapping: dict[str, list[dict]] = {}
# TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half formed pydantic object
agent_batch_item_mapping: dict[str, LLMBatchItem] = {}
# fetch agent states in batch
agent_mapping = {
agent_state.id: agent_state
for agent_state in await self.agent_manager.get_agents_by_ids_async(
agent_ids=[request.agent_id for request in batch_requests], include_relationships=["tools", "memory"], actor=self.actor
)
}
agent_states = []
for batch_request in batch_requests:
agent_id = batch_request.agent_id
agent_state = agent_mapping[agent_id]
agent_states.append(agent_state) # keeping this to maintain ordering, but may not be necessary
if agent_id not in agent_step_state_mapping:
agent_step_state_mapping[agent_id] = AgentStepState(
step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
)
llm_batch_item = LLMBatchItem(
llm_batch_id="", # TODO: This is hacky, it gets filled in later
agent_id=agent_state.id,
llm_config=agent_state.llm_config,
request_status=JobStatus.created,
step_status=AgentStepStatus.paused,
step_state=agent_step_state_mapping[agent_id],
)
agent_batch_item_mapping[agent_id] = llm_batch_item
# Fill in the batch_item_id for the message
for msg in batch_request.messages:
msg.batch_item_id = llm_batch_item.id
agent_messages_mapping[agent_id] = await self._prepare_in_context_messages_per_agent_async(
agent_state=agent_state, input_messages=batch_request.messages
)
agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver)
log_event(name="init_llm_client")
llm_client = LLMClient.create(
provider_type=agent_states[0].llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=self.actor,
)
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
log_event(name="send_llm_batch_request")
batch_response = await llm_client.send_llm_batch_request_async(
agent_messages_mapping=agent_messages_mapping,
agent_tools_mapping=agent_tools_mapping,
agent_llm_config_mapping=agent_llm_config_mapping,
)
log_event(name="persist_llm_batch_job")
llm_batch_job = await self.batch_manager.create_llm_batch_job_async(
llm_provider=ProviderType.anthropic, # TODO: Expand to more providers
create_batch_response=batch_response,
actor=self.actor,
status=JobStatus.running,
letta_batch_job_id=letta_batch_job_id,
)
log_event(name="prepare_batch_items")
batch_items = []
for state in agent_states:
llm_batch_item = agent_batch_item_mapping[state.id]
# TODO This is hacky
llm_batch_item.llm_batch_id = llm_batch_job.id
batch_items.append(llm_batch_item)
if batch_items:
log_event(name="bulk_create_batch_items")
batch_items_persisted = await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
log_event(name="return_batch_response")
return LettaBatchResponse(
letta_batch_id=llm_batch_job.letta_batch_job_id,
last_llm_batch_id=llm_batch_job.id,
status=llm_batch_job.status,
agent_count=len(agent_states),
last_polled_at=get_utc_time(),
created_at=llm_batch_job.created_at,
)
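    # A minimal driver sketch (illustrative only; the wiring around it is an
    # assumption, not part of this module). The intended cadence is: submit
    # once, poll the provider out-of-band, then resume when results land.
    #
    #     batch = LettaAgentBatch(..., actor=actor)  # wire up the managers above
    #     resp = await batch.step_until_request(batch_requests=reqs, letta_batch_job_id=job.id)
    #     # ... later, once polling shows the provider batch has completed ...
    #     resp = await batch.resume_step_after_request(
    #         letta_batch_id=resp.letta_batch_id, llm_batch_id=resp.last_llm_batch_id
    #     )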
@trace_method
async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse:
log_event(name="load_context")
llm_batch_job = await self.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=llm_batch_id, actor=self.actor)
ctx = await self._collect_resume_context(llm_batch_id)
log_event(name="update_statuses")
self._update_request_statuses(ctx.request_status_updates)
log_event(name="exec_tools")
exec_results = await self._execute_tools(ctx)
log_event(name="persist_messages")
msg_map = await self._persist_tool_messages(exec_results, ctx)
log_event(name="mark_steps_done")
self._mark_steps_complete(llm_batch_id, ctx.agent_ids)
log_event(name="prepare_next")
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
if len(next_reqs) == 0:
await self.job_manager.update_job_by_id_async(
job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor
)
return LettaBatchResponse(
letta_batch_id=llm_batch_job.letta_batch_job_id,
last_llm_batch_id=llm_batch_job.id,
status=JobStatus.completed,
agent_count=len(ctx.agent_ids),
last_polled_at=get_utc_time(),
created_at=llm_batch_job.created_at,
)
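        # Some agents requested another step; submit the next LLM batch immediately.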
return await self.step_until_request(
batch_requests=next_reqs,
letta_batch_job_id=letta_batch_id,
agent_step_state_mapping=next_step_state,
)
@trace_method
async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
"""
Collect context for resuming operations from completed batch items.
Args:
llm_batch_id: The ID of the batch to collect context for
Returns:
_ResumeContext object containing all necessary data for resumption
"""
# Fetch only completed batch items
batch_items = await self.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_id, request_status=JobStatus.completed)
# Exit early if no items to process
if not batch_items:
return _ResumeContext(
batch_items=[],
agent_ids=[],
agent_state_map={},
provider_results={},
tool_call_name_map={},
tool_call_args_map={},
should_continue_map={},
request_status_updates=[],
)
# Extract agent IDs and organize items by agent ID
agent_ids = [item.agent_id for item in batch_items]
batch_item_map = {item.agent_id: item for item in batch_items}
# Collect provider results
provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items}
# Fetch agent states in a single call
agent_states = await self.agent_manager.get_agents_by_ids_async(
agent_ids=agent_ids, include_relationships=["tools", "memory"], actor=self.actor
)
agent_state_map = {agent.id: agent for agent in agent_states}
# Process each agent's results
tool_call_results = self._process_agent_results(
agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id
)
return _ResumeContext(
batch_items=batch_items,
agent_ids=agent_ids,
agent_state_map=agent_state_map,
provider_results=provider_results,
tool_call_name_map=tool_call_results.name_map,
tool_call_args_map=tool_call_results.args_map,
should_continue_map=tool_call_results.cont_map,
request_status_updates=tool_call_results.status_updates,
)
def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
"""
Process the results for each agent, extracting tool calls and determining continuation status.
Returns:
A namedtuple containing name_map, args_map, cont_map, and status_updates
"""
from collections import namedtuple
ToolCallResults = namedtuple("ToolCallResults", ["name_map", "args_map", "cont_map", "status_updates"])
name_map, args_map, cont_map = {}, {}, {}
request_status_updates = []
for aid in agent_ids:
item = batch_item_map[aid]
result = provider_results[aid]
# Determine job status based on result type
status = self._determine_job_status(result)
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
# Process tool calls
name, args, cont = self._extract_tool_call_from_result(item, result)
name_map[aid], args_map[aid], cont_map[aid] = name, args, cont
return ToolCallResults(name_map, args_map, cont_map, request_status_updates)
def _determine_job_status(self, result):
"""Determine job status based on result type"""
if isinstance(result, BetaMessageBatchSucceededResult):
return JobStatus.completed
elif isinstance(result, BetaMessageBatchErroredResult):
return JobStatus.failed
elif isinstance(result, BetaMessageBatchCanceledResult):
return JobStatus.cancelled
else:
return JobStatus.expired
def _extract_tool_call_from_result(self, item, result):
"""Extract tool call information from a result"""
llm_client = LLMClient.create(
provider_type=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=self.actor,
)
# If result isn't a successful type, we can't extract a tool call
if not isinstance(result, BetaMessageBatchSucceededResult):
return None, None, False
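        # NOTE: indexes tool_calls[0], so this assumes the response carries exactly one tool call.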
tool_call = (
llm_client.convert_response_to_chat_completion(
response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
)
.choices[0]
.message.tool_calls[0]
)
return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
if updates:
self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
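    # Prefer the E2B cloud sandbox when an API key is configured; otherwise fall back to local execution.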
def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]:
sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL
cfg = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=sbx_type, actor=self.actor)
env = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(cfg.id, actor=self.actor, limit=100)
return cfg, env
@trace_method
async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
sbx_cfg, sbx_env = self._build_sandbox()
rethink_memory_tool_name = "rethink_memory"
tool_params = []
# TODO: This is a special case - we need to think about how to generalize this
# TODO: Rethink memory is a common op that is easily batchable, so we pull this logic out
rethink_memory_params = []
for aid in ctx.agent_ids:
param = ToolExecutionParams(
agent_id=aid,
tool_call_name=ctx.tool_call_name_map[aid],
tool_args=ctx.tool_call_args_map[aid],
agent_state=ctx.agent_state_map[aid],
actor=self.actor,
sbx_config=sbx_cfg,
sbx_env_vars=sbx_env,
)
if ctx.tool_call_name_map[aid] == rethink_memory_tool_name:
rethink_memory_params.append(param)
else:
tool_params.append(param)
        # Run both groups and merge their results; returning early from the
        # rethink path would silently drop the other agents' tool calls
        # whenever a batch mixes the two.
        results: List[Tuple[str, Tuple[str, bool]]] = []
        if rethink_memory_params:
            results.extend(await self._bulk_rethink_memory_async(rethink_memory_params))
        if tool_params:
            async with Pool() as pool:
                results.extend(await pool.map(execute_tool_wrapper, tool_params))
        return results
@trace_method
async def _bulk_rethink_memory_async(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]:
updates = {}
result = []
for param in params:
# Sanity check
# TODO: This is very brittle and done quickly for performance
# TODO: If the end tool is changed, this will break
# TODO: Move 'rethink_memory' to a native Letta tool that we control
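            # Illustrative args shape expected here (values are made up):
            #   {"target_block_label": "human", "new_memory": "Prefers to be called Sam."}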
if "new_memory" not in param.tool_args or "target_block_label" not in param.tool_args:
raise ValueError(f"Missing either `new_memory` or `target_block_label` in the tool args: {param.tool_args}")
# Find the block id/update
block_id = param.agent_state.memory.get_block(label=param.tool_args.get("target_block_label")).id
new_value = param.tool_args.get("new_memory")
# This is sensitive to multiple agents overwriting the same memory block
updates[block_id] = new_value
# TODO: This is quite ugly and confusing - this is mostly to align with the returns of other tools
result.append((param.agent_id, ("", True)))
await self.block_manager.bulk_update_block_values_async(updates=updates, actor=self.actor)
return result
async def _persist_tool_messages(
self,
exec_results: Sequence[Tuple[str, Tuple[str, bool]]],
ctx: _ResumeContext,
) -> Dict[str, List[Message]]:
# TODO: This is redundant, we should have this ready on the ctx
# TODO: I am doing it quick and dirty for now
agent_item_map: Dict[str, LLMBatchItem] = {item.agent_id: item for item in ctx.batch_items}
msg_map: Dict[str, List[Message]] = {}
for aid, (tool_res, success) in exec_results:
msgs = self._create_tool_call_messages(
llm_batch_item_id=agent_item_map[aid].id,
agent_state=ctx.agent_state_map[aid],
tool_call_name=ctx.tool_call_name_map[aid],
tool_call_args=ctx.tool_call_args_map[aid],
tool_exec_result=tool_res,
success_flag=success,
reasoning_content=None,
)
msg_map[aid] = msgs
# flatten & persist
await self.message_manager.create_many_messages_async([m for msgs in msg_map.values() for m in msgs], actor=self.actor)
return msg_map
def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None:
updates = [
StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids
]
self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates)
def _prepare_next_iteration(
self,
exec_results: Sequence[Tuple[str, Tuple[str, bool]]],
ctx: _ResumeContext,
msg_map: Dict[str, List[Message]],
) -> Tuple[List[LettaBatchRequest], Dict[str, AgentStepState]]:
# who continues?
continues = [aid for aid, cont in ctx.should_continue_map.items() if cont]
success_flag_map = {aid: flag for aid, (_res, flag) in exec_results}
batch_reqs: List[LettaBatchRequest] = []
for aid in continues:
heartbeat = create_heartbeat_system_message(
agent_id=aid,
model=ctx.agent_state_map[aid].llm_config.model,
function_call_success=success_flag_map[aid],
actor=self.actor,
)
batch_reqs.append(
LettaBatchRequest(
agent_id=aid, messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))]
)
)
        # extend in-context message ids when necessary
for aid, new_msgs in msg_map.items():
ast = ctx.agent_state_map[aid]
if not ast.message_buffer_autoclear:
self.agent_manager.set_in_context_messages(
agent_id=aid,
message_ids=ast.message_ids + [m.id for m in new_msgs],
actor=self.actor,
)
# bump step number
step_map = {
item.agent_id: item.step_state.model_copy(update={"step_number": item.step_state.step_number + 1}) for item in ctx.batch_items
}
return batch_reqs, step_map
def _create_tool_call_messages(
self,
llm_batch_item_id: str,
agent_state: AgentState,
tool_call_name: str,
tool_call_args: Dict[str, Any],
tool_exec_result: str,
success_flag: bool,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
) -> List[Message]:
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
function_arguments=tool_call_args,
tool_call_id=tool_call_id,
function_call_success=success_flag,
function_response=tool_exec_result,
actor=self.actor,
add_heartbeat_request_system_message=False,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=None,
pre_computed_tool_message_id=None,
llm_batch_item_id=llm_batch_item_id,
)
return tool_call_messages
# TODO: This is doing a lot of dict passing
# TODO: Make the passing here typed
def _extract_tool_call_and_decide_continue(
self, tool_call: OpenAIToolCall, agent_step_state: AgentStepState
) -> Tuple[str, Dict[str, Any], bool]:
"""
Now that streaming is done, handle the final AI response.
This might yield additional SSE tokens if we do stalling.
At the end, set self._continue_execution accordingly.
"""
tool_call_name = tool_call.function.name
tool_call_args_str = tool_call.function.arguments
try:
tool_args = json.loads(tool_call_args_str)
except json.JSONDecodeError:
logger.warning(f"Failed to JSON decode tool call argument string: {tool_call_args_str}")
tool_args = {}
# Get request heartbeats and coerce to bool
request_heartbeat = tool_args.pop("request_heartbeat", False)
# Pre-emptively pop out inner_thoughts
tool_args.pop(INNER_THOUGHTS_KWARG, "")
        # This coercion is necessary because non-structured outputs sometimes emit booleans as strings
if isinstance(request_heartbeat, str):
request_heartbeat = request_heartbeat.lower() == "true"
else:
request_heartbeat = bool(request_heartbeat)
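        # e.g. "True" -> True, "false" -> False, 1 -> True, None -> False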
continue_stepping = request_heartbeat
tool_rules_solver = agent_step_state.tool_rules_solver
tool_rules_solver.register_tool_call(tool_name=tool_call_name)
if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name):
continue_stepping = False
elif tool_rules_solver.has_children_tools(tool_name=tool_call_name):
continue_stepping = True
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
continue_stepping = True
step_count = agent_step_state.step_number
if step_count >= self.max_steps:
logger.warning("Hit max steps, stopping agent loop prematurely.")
continue_stepping = False
return tool_call_name, tool_args, continue_stepping
def _prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}]
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
async def _prepare_in_context_messages_per_agent_async(
self, agent_state: AgentState, input_messages: List[MessageCreate]
) -> List[Message]:
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
return in_context_messages
    # Not used in batch mode; required by the BaseAgent interface.
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
raise NotImplementedError
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
raise NotImplementedError