# MemGPT/letta/agents/letta_agent.py

import asyncio
import json
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from openai import AsyncStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.tracing import log_event, trace_method
from letta.utils import united_diff
logger = get_logger(__name__)
class LettaAgent(BaseAgent):
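    """Service-backed agent loop: sends LLM requests, executes tool calls,
    and persists the resulting messages via the injected managers."""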
def __init__(
self,
agent_id: str,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
use_assistant_message: bool = True,
):
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
        # TODO: Make this more general, factorable
        # Manager handles and message-formatting settings
self.block_manager = block_manager
self.passage_manager = passage_manager
self.use_assistant_message = use_assistant_message
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
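        """Run the synchronous agent loop for up to max_steps and return a LettaResponse."""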
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
put_inner_thoughts_first=True,
)
for step in range(max_steps):
response = await self._get_ai_reply(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=False,
# TODO: also pass in reasoning content
)
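            # NOTE: exactly one tool call is expected here, since tool use is forced in the request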
tool_call = response.choices[0].message.tool_calls[0]
persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver)
new_in_context_messages.extend(persisted_messages)
if not should_continue:
break
        # Update the agent's in-context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
@trace_method
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False
) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.
"""
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
put_inner_thoughts_first=True,
)
for step in range(max_steps):
stream = await self._get_ai_reply(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=True,
)
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
)
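            # Relay partial chunks to the client as server-sent events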
async for chunk in interface.process(stream):
yield f"data: {chunk.model_dump_json()}\n\n"
            # Once the stream is exhausted, extract the accumulated tool call and reasoning content
tool_call = interface.get_tool_call_object()
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
agent_state,
tool_rules_solver,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
pre_computed_tool_message_id=interface.letta_tool_message_id,
)
new_in_context_messages.extend(persisted_messages)
if not should_continue:
break
        # Update the agent's in-context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
# TODO: Also yield out a letta usage stats SSE
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@trace_method
async def _get_ai_reply(
self,
llm_client: LLMClientBase,
in_context_messages: List[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
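        """Rebuild memory if needed, filter tools by the tool rules, and issue the LLM request."""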
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
tools = [
t
for t in agent_state.tools
if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
]
        valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools={t.name for t in tools})
# TODO: Copied from legacy agent loop, so please be cautious
        # If the tool rules leave exactly one valid tool, force the model to call it
force_tool_call = None
if len(valid_tool_names) == 1:
force_tool_call = valid_tool_names[0]
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
response = await llm_client.send_llm_request_async(
messages=in_context_messages,
tools=allowed_tools,
force_tool_call=force_tool_call,
stream=stream,
)
return response
@trace_method
async def _handle_ai_response(
self,
tool_call: ToolCall,
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
) -> Tuple[List[Message], bool]:
"""
Now that streaming is done, handle the final AI response.
This might yield additional SSE tokens if we do stalling.
At the end, set self._continue_execution accordingly.
"""
tool_call_name = tool_call.function.name
tool_call_args_str = tool_call.function.arguments
try:
tool_args = json.loads(tool_call_args_str)
except json.JSONDecodeError:
tool_args = {}
        # Pop request_heartbeat (coerced to bool below)
request_heartbeat = tool_args.pop("request_heartbeat", False)
# Pre-emptively pop out inner_thoughts
tool_args.pop(INNER_THOUGHTS_KWARG, "")
        # Coercion is necessary because non-structured outputs sometimes produce the wrong type
if not isinstance(request_heartbeat, bool):
if isinstance(request_heartbeat, str):
request_heartbeat = request_heartbeat.lower() == "true"
else:
request_heartbeat = bool(request_heartbeat)
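        # Fall back to a synthetic id if the provider did not supply one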
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
tool_result, success_flag = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
agent_state=agent_state,
)
        # Register the tool call with the tool rule solver
        # and resolve whether or not to continue stepping
continue_stepping = request_heartbeat
tool_rules_solver.register_tool_call(tool_name=tool_call_name)
if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name):
continue_stepping = False
elif tool_rules_solver.has_children_tools(tool_name=tool_call_name):
continue_stepping = True
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
continue_stepping = True
        # Persist to DB
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
function_arguments=tool_args,
tool_call_id=tool_call_id,
function_call_success=success_flag,
function_response=tool_result,
actor=self.actor,
add_heartbeat_request_system_message=continue_stepping,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
pre_computed_tool_message_id=pre_computed_tool_message_id,
)
persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor)
return persisted_messages, continue_stepping
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
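        """Refresh memory from the DB and, if the compiled memory changed, rewrite the system message in place."""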
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
"""
Executes a tool and returns (result, success_flag).
"""
        # Look up the target tool on the agent state
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
if not target_tool:
return f"Tool not found: {tool_name}", False
        # TODO: This is temporary. Move this logic and code to executors
try:
if target_tool.name == "send_message_to_agents_matching_tags" and target_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
log_event(name="start_send_message_to_agents_matching_tags", attributes=tool_args)
results = await self._send_message_to_agents_matching_tags(**tool_args)
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
return json.dumps(results), True
else:
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
# TODO: Integrate sandbox result
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
tool_execution_result = await tool_execution_manager.execute_tool_async(
function_name=tool_name, function_args=tool_args, tool=target_tool
)
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
return tool_execution_result.func_return, True
except Exception as e:
return f"Failed to call tool. Error: {e}", False
@trace_method
async def _send_message_to_agents_matching_tags(
self, message: str, match_all: List[str], match_some: List[str]
) -> List[Dict[str, Any]]:
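        """Broadcast a message to every agent matching the given tags and gather their replies concurrently."""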
# Find matching agents
matching_agents = self.agent_manager.list_agents_matching_tags(actor=self.actor, match_all=match_all, match_some=match_some)
if not matching_agents:
return []
async def process_agent(agent_state: AgentState, message: str) -> Dict[str, Any]:
try:
letta_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
use_assistant_message=True,
)
                augmented_message = (
                    "[Incoming message from external Letta agent - to reply to this message, "
                    "make sure to use the 'send_message' tool at the end, and the system will notify "
                    "the sender of your response] "
                    f"{message}"
                )
letta_response = await letta_agent.step(
[MessageCreate(role=MessageRole.system, content=[TextContent(text=augmented_message)])]
)
messages = letta_response.messages
send_message_content = [message.content for message in messages if isinstance(message, AssistantMessage)]
return {
"agent_id": agent_state.id,
"agent_name": agent_state.name,
"response": send_message_content if send_message_content else ["<no response>"],
}
except Exception as e:
return {
"agent_id": agent_state.id,
"agent_name": agent_state.name,
"error": str(e),
"type": type(e).__name__,
}
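        # Fan out to all matching agents concurrently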
tasks = [asyncio.create_task(process_agent(agent_state=agent_state, message=message)) for agent_state in matching_agents]
results = await asyncio.gather(*tasks)
return results