import asyncio
import json
import os
import time
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

from openai.types.beta.function_tool import FunctionTool as OpenAITool

from letta.agents.helpers import generate_step_id
from letta.constants import (
    CLI_WARNING_PREFIX,
    COMPOSIO_ENTITY_ENV_VAR_KEY,
    ERROR_MESSAGE_PREFIX,
    FIRST_MESSAGE_ATTEMPTS,
    FUNC_FAILED_HEARTBEAT_MESSAGE,
    LETTA_CORE_TOOL_MODULE_NAME,
    LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
    LLM_MAX_TOKENS,
    READ_ONLY_BLOCK_EDIT_ERROR,
    REQ_HEARTBEAT_MESSAGE,
    SEND_MESSAGE_TOOL_NAME,
)
from letta.errors import ContextWindowExceededError
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.functions import get_function_from_module
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.helpers import ToolRulesSolver
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.interface import AgentInterface
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.llm_api.llm_client import LLMClient
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.memory import summarize_messages
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.response_format import ResponseFormatType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.settings import settings, summarizer_settings, model_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
from letta.tracing import log_event, trace_method
from letta.utils import (
    count_tokens,
    get_friendly_error_msg,
    get_tool_call_id,
    log_telemetry,
    parse_json,
    # NOTE: strip_name_field_from_user_message and validate_json are assumed to be provided by letta.utils;
    # they are referenced by step_user_message below
    strip_name_field_from_user_message,
    validate_function_response,
    validate_json,
)

logger = get_logger(__name__)

class BaseAgent(ABC):
|
|
"""
|
|
Abstract class for all agents.
|
|
Only one interface is required: step.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def step(
|
|
self,
|
|
input_messages: List[MessageCreate],
|
|
) -> LettaUsageStatistics:
|
|
"""
|
|
Top-level event message handler for the agent.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
|
|
class Agent(BaseAgent):
|
|
def __init__(
|
|
self,
|
|
interface: Optional[Union[AgentInterface, StreamingRefreshCLIInterface]],
|
|
agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables)
|
|
user: User,
|
|
# extras
|
|
first_message_verify_mono: bool = True, # TODO move to config?
|
|
# MCP sessions, state held in-memory in the server
|
|
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
|
save_last_response: bool = False,
|
|
):
|
|
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
|
|
# Hold a copy of the state that was used to init the agent
|
|
self.agent_state = agent_state
|
|
assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}"
|
|
|
|
self.user = user
|
|
|
|
# initialize a tool rules solver
|
|
self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)
|
|
|
|
# gpt-4, gpt-3.5-turbo, ...
|
|
self.model = self.agent_state.llm_config.model
|
|
self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules)
|
|
|
|
# if there are tool rules, print out a warning
|
|
if not self.supports_structured_output and agent_state.tool_rules:
|
|
for rule in agent_state.tool_rules:
|
|
if not isinstance(rule, TerminalToolRule):
|
|
warnings.warn("Tool rules only work reliably for model backends that support structured outputs (e.g. OpenAI gpt-4o).")
|
|
break
|
|
|
|
# state managers
|
|
self.block_manager = BlockManager()
|
|
|
|
# Interface must implement:
|
|
# - internal_monologue
|
|
# - assistant_message
|
|
# - function_message
|
|
# ...
|
|
# Different interfaces can handle events differently
|
|
# e.g., print in CLI vs send a discord message with a discord bot
|
|
self.interface = interface
|
|
|
|
# Create the persistence manager object based on the AgentState info
|
|
self.message_manager = MessageManager()
|
|
self.passage_manager = PassageManager()
|
|
self.provider_manager = ProviderManager()
|
|
self.agent_manager = AgentManager()
|
|
self.job_manager = JobManager()
|
|
self.step_manager = StepManager()
|
|
self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()
|
|
|
|
# State needed for heartbeat pausing
|
|
|
|
self.first_message_verify_mono = first_message_verify_mono
|
|
|
|
# Controls if the convo memory pressure warning is triggered
|
|
# When an alert is sent in the message queue, set this to True (to avoid repeat alerts)
|
|
# When the summarizer is run, set this back to False (to reset)
|
|
self.agent_alerted_about_memory_pressure = False
|
|
|
|
# Load last function response from message history
|
|
self.last_function_response = self.load_last_function_response()
|
|
|
|
# Save last responses in memory
|
|
self.save_last_response = save_last_response
|
|
self.last_response_messages = []
|
|
|
|
# Logger that the Agent specifically can use, will also report the agent_state ID with the logs
|
|
self.logger = get_logger(agent_state.id)
|
|
|
|
# MCPClient, state/sessions managed by the server
|
|
self.mcp_clients = mcp_clients
|
|
|
|
def load_last_function_response(self):
|
|
"""Load the last function response from message history"""
|
|
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
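# Walk the in-context history from newest to oldest and return the most recent tool response's "message" field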
|
|
for i in range(len(in_context_messages) - 1, -1, -1):
|
|
msg = in_context_messages[i]
|
|
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
|
|
text_content = msg.content[0].text
|
|
try:
|
|
response_json = json.loads(text_content)
|
|
if response_json.get("message"):
|
|
return response_json["message"]
|
|
except (json.JSONDecodeError, KeyError):
|
|
raise ValueError(f"Invalid JSON format in message: {text_content}")
|
|
return None
|
|
|
|
def ensure_read_only_block_not_modified(self, new_memory: Memory) -> None:
|
|
"""
|
|
Throw an error if a read-only block has been modified
|
|
"""
|
|
for label in self.agent_state.memory.list_block_labels():
|
|
if self.agent_state.memory.get_block(label).read_only:
|
|
if new_memory.get_block(label).value != self.agent_state.memory.get_block(label).value:
|
|
raise ValueError(READ_ONLY_BLOCK_EDIT_ERROR)
|
|
|
|
def update_memory_if_changed(self, new_memory: Memory) -> bool:
|
|
"""
|
|
Update internal memory object and system prompt if there have been modifications.
|
|
|
|
Args:
|
|
new_memory (Memory): the new memory object to compare to the current memory object
|
|
|
|
Returns:
|
|
modified (bool): whether the memory was updated
|
|
"""
|
|
system_message = self.message_manager.get_message_by_id(message_id=self.agent_state.message_ids[0], actor=self.user)
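# The compiled memory is embedded in the system message text, so if the new compile is not a substring of it,
# at least one block value has changed and needs to be persisted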
|
|
if new_memory.compile() not in system_message.content[0].text:
|
|
# update the blocks (LRW) in the DB
|
|
for label in self.agent_state.memory.list_block_labels():
|
|
updated_value = new_memory.get_block(label).value
|
|
if updated_value != self.agent_state.memory.get_block(label).value:
|
|
# update the block if it's changed
|
|
block_id = self.agent_state.memory.get_block(label).id
|
|
self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user)
|
|
|
|
# refresh memory from DB (using block ids)
|
|
self.agent_state.memory = Memory(
|
|
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()],
|
|
prompt_template=get_prompt_template_for_agent_type(self.agent_state.agent_type),
|
|
)
|
|
|
|
# NOTE: don't do this since re-building the memory is handled at the start of the step
|
|
# rebuild memory - this records the last edited timestamp of the memory
|
|
# TODO: pass in update timestamp from block edit time
|
|
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
|
|
|
|
return True
|
|
return False
|
|
|
|
def _handle_function_error_response(
|
|
self,
|
|
error_msg: str,
|
|
tool_call_id: str,
|
|
function_name: str,
|
|
function_args: dict,
|
|
function_response: str,
|
|
messages: List[Message],
|
|
tool_returns: Optional[List[ToolReturn]] = None,
|
|
include_function_failed_message: bool = False,
|
|
group_id: Optional[str] = None,
|
|
) -> List[Message]:
|
|
"""
|
|
Handle error from function call response
|
|
"""
|
|
# Update tool rules
|
|
self.last_function_response = function_response
|
|
self.tool_rules_solver.register_tool_call(function_name)
|
|
|
|
# Extend conversation with function response
|
|
function_response = package_function_response(False, error_msg)
|
|
new_message = Message(
|
|
agent_id=self.agent_state.id,
|
|
# Base info OpenAI-style
|
|
model=self.model,
|
|
role="tool",
|
|
name=function_name, # NOTE: when role is 'tool', the 'name' is the function name, not agent name
|
|
content=[TextContent(text=function_response)],
|
|
tool_call_id=tool_call_id,
|
|
# Letta extras
|
|
tool_returns=tool_returns,
|
|
group_id=group_id,
|
|
)
|
|
messages.append(new_message)
|
|
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=0)
|
|
if include_function_failed_message:
|
|
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message)
|
|
|
|
# Return updated messages
|
|
return messages
|
|
|
|
def _runtime_override_tool_json_schema(
|
|
self,
|
|
functions_list: List[Dict | None],
|
|
) -> List[Dict | None]:
|
|
"""Override the tool JSON schema at runtime for a particular tool if conditions are met."""
|
|
|
|
# Currently just injects `send_message` with a `response_format` if provided to the agent.
|
|
if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text:
|
|
for func in functions_list:
|
|
if func["name"] == SEND_MESSAGE_TOOL_NAME:
|
|
if self.agent_state.response_format.type == ResponseFormatType.json_schema:
|
|
func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"]
|
|
if self.agent_state.response_format.type == ResponseFormatType.json_object:
|
|
func["parameters"]["properties"]["message"] = {
|
|
"type": "object",
|
|
"description": "Message contents. All unicode (including emojis) are supported.",
|
|
"additionalProperties": True,
|
|
"properties": {},
|
|
}
|
|
break
|
|
return functions_list
|
|
|
|
@trace_method
|
|
def _get_ai_reply(
|
|
self,
|
|
message_sequence: List[Message],
|
|
function_call: Optional[str] = None,
|
|
first_message: bool = False,
|
|
stream: bool = False, # TODO move to config?
|
|
empty_response_retry_limit: int = 3,
|
|
backoff_factor: float = 0.5, # delay multiplier for exponential backoff
|
|
max_delay: float = 10.0, # max delay between retries
|
|
step_count: Optional[int] = None,
|
|
last_function_failed: bool = False,
|
|
put_inner_thoughts_first: bool = True,
|
|
step_id: Optional[str] = None,
|
|
) -> ChatCompletionResponse | None:
|
|
"""Get response from LLM API with robust retry mechanism."""
|
|
log_telemetry(self.logger, "_get_ai_reply start")
|
|
available_tools = set([t.name for t in self.agent_state.tools])
|
|
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
|
|
|
# Get allowed tools or allow all if none are allowed
|
|
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
|
|
available_tools=available_tools, last_function_response=self.last_function_response
|
|
) or list(available_tools)
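# NOTE: if the solver returns an empty list (no rule restricts the next call), fall back to allowing every available tool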
|
|
|
|
# Don't allow a tool to be called if it failed last time
|
|
if last_function_failed and self.tool_rules_solver.tool_call_history:
|
|
allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]]
|
|
if not allowed_tool_names:
|
|
return None
|
|
|
|
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
|
|
allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)
|
|
|
|
# For the first message, force the initial tool if one is specified
|
|
force_tool_call = None
|
|
if (
|
|
step_count is not None
|
|
and step_count == 0
|
|
and not self.supports_structured_output
|
|
and len(self.tool_rules_solver.init_tool_rules) > 0
|
|
):
|
|
# TODO: This just seems wrong? What if there is more than one init tool rule?
|
|
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
|
|
# Force a tool call if exactly one tool is specified
|
|
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
|
|
force_tool_call = allowed_tool_names[0]
|
|
|
|
for attempt in range(1, empty_response_retry_limit + 1):
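# Retry loop: empty or malformed responses surface as ValueError and are retried with exponential backoff;
# any other exception is treated as non-retryable and re-raised immediately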
|
|
try:
|
|
log_telemetry(self.logger, "_get_ai_reply create start")
|
|
# New LLM client flow
|
|
llm_client = LLMClient.create(
|
|
provider_type=self.agent_state.llm_config.model_endpoint_type,
|
|
put_inner_thoughts_first=put_inner_thoughts_first,
|
|
actor=self.user,
|
|
)
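# If a provider-specific client exists and we are not streaming, use it; otherwise fall back to the legacy create() path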
|
|
|
|
if llm_client and not stream:
|
|
response = llm_client.send_llm_request(
|
|
messages=message_sequence,
|
|
llm_config=self.agent_state.llm_config,
|
|
tools=allowed_functions,
|
|
force_tool_call=force_tool_call,
|
|
telemetry_manager=self.telemetry_manager,
|
|
step_id=step_id,
|
|
)
|
|
else:
|
|
# Fallback to existing flow
|
|
response = create(
|
|
llm_config=self.agent_state.llm_config,
|
|
messages=message_sequence,
|
|
user_id=self.agent_state.created_by_id,
|
|
functions=allowed_functions,
|
|
# functions_python=self.functions_python, do we need this?
|
|
function_call=function_call,
|
|
first_message=first_message,
|
|
force_tool_call=force_tool_call,
|
|
stream=stream,
|
|
stream_interface=self.interface,
|
|
put_inner_thoughts_first=put_inner_thoughts_first,
|
|
name=self.agent_state.name,
|
|
telemetry_manager=self.telemetry_manager,
|
|
step_id=step_id,
|
|
actor=self.user,
|
|
)
|
|
log_telemetry(self.logger, "_get_ai_reply create finish")
|
|
|
|
# These bottom two are retryable
|
|
if len(response.choices) == 0 or response.choices[0] is None:
|
|
raise ValueError(f"API call returned an empty message: {response}")
|
|
|
|
if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]:
|
|
if response.choices[0].finish_reason == "length":
|
|
# This is not retryable, hence RuntimeError vs. ValueError
|
|
raise RuntimeError("Finish reason was length (maximum context length)")
|
|
else:
|
|
raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")
|
|
log_telemetry(self.logger, "_handle_ai_response finish")
|
|
|
|
except ValueError as ve:
|
|
if attempt >= empty_response_retry_limit:
|
|
warnings.warn(f"Retry limit reached. Final error: {ve}")
|
|
log_telemetry(self.logger, "_handle_ai_response finish ValueError")
|
|
raise Exception(f"Retries exhausted and no valid response received. Final error: {ve}")
|
|
else:
|
|
delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
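# e.g. with backoff_factor=0.5 the retry delays are 0.5s, 1.0s, 2.0s, ..., capped at max_delay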
|
|
warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...")
|
|
time.sleep(delay)
|
|
continue
|
|
|
|
except Exception as e:
|
|
# For non-retryable errors, exit immediately
|
|
log_telemetry(self.logger, "_handle_ai_response finish generic Exception")
|
|
raise e
|
|
|
|
# check if we are going over the context window: this allows for artificial constraints
|
|
if response.usage.total_tokens > self.agent_state.llm_config.context_window:
|
|
# trigger summarization
|
|
log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace")
|
|
self.summarize_messages_inplace()
|
|
|
|
# return the response
|
|
return response
|
|
|
|
log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
|
|
raise Exception("Retries exhausted and no valid response received.")
|
|
|
|
@trace_method
|
|
def _handle_ai_response(
|
|
self,
|
|
response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function?
|
|
override_tool_call_id: bool = False,
|
|
# If we are streaming, we needed to create a Message ID ahead of time,
|
|
# and now we want to use it in the creation of the Message object
|
|
# TODO figure out a cleaner way to do this
|
|
response_message_id: Optional[str] = None,
|
|
group_id: Optional[str] = None,
|
|
) -> Tuple[List[Message], bool, bool]:
|
|
"""Handles parsing and function execution"""
|
|
log_telemetry(self.logger, "_handle_ai_response start")
|
|
# Hacky failsafe for now to make sure we didn't implement the streaming Message ID creation incorrectly
|
|
if response_message_id is not None:
|
|
assert response_message_id.startswith("message-"), response_message_id
|
|
|
|
messages = [] # append these to the history when done
|
|
function_name = None
|
|
function_args = {}
|
|
chunk_index = 0
|
|
|
|
# Step 2: check if LLM wanted to call a function
|
|
if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0):
|
|
if response_message.function_call:
|
|
raise DeprecationWarning(response_message)
|
|
if response_message.tool_calls is not None and len(response_message.tool_calls) > 1:
|
|
# raise NotImplementedError(f">1 tool call not supported")
|
|
# TODO eventually support sequential tool calling
|
|
self.logger.warning(f">1 tool call not supported, using index=0 only\n{response_message.tool_calls}")
|
|
response_message.tool_calls = [response_message.tool_calls[0]]
|
|
assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0
|
|
|
|
# generate UUID for tool call
|
|
if override_tool_call_id or response_message.function_call:
|
|
warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming")
|
|
tool_call_id = get_tool_call_id() # needs to be a string for JSON
|
|
response_message.tool_calls[0].id = tool_call_id
|
|
else:
|
|
tool_call_id = response_message.tool_calls[0].id
|
|
assert tool_call_id is not None # should be defined
|
|
|
|
# only necessary to add the tool_call_id to a function call (antipattern)
|
|
# response_message_dict = response_message.model_dump()
|
|
# response_message_dict["tool_call_id"] = tool_call_id
|
|
|
|
# role: assistant (requesting tool call, set tool call ID)
|
|
messages.append(
|
|
# NOTE: we're recreating the message here
|
|
# TODO should probably just overwrite the fields?
|
|
Message.dict_to_message(
|
|
id=response_message_id,
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict=response_message.model_dump(),
|
|
name=self.agent_state.name,
|
|
group_id=group_id,
|
|
)
|
|
) # extend conversation with assistant's reply
|
|
self.logger.debug(f"Function call message: {messages[-1]}")
|
|
|
|
nonnull_content = False
|
|
if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content:
|
|
# The content is then treated as internal monologue, not chat
|
|
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index)
|
|
chunk_index += 1
|
|
# Flag to avoid printing a duplicate if inner thoughts get popped from the function call
|
|
nonnull_content = True
|
|
|
|
# Step 3: call the function
|
|
# Note: the JSON response may not always be valid; be sure to handle errors
|
|
function_call = (
|
|
response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function
|
|
)
|
|
function_name = function_call.name
|
|
self.logger.debug(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
|
|
|
|
# Failure case 1: function name is wrong (not in agent_state.tools)
|
|
target_letta_tool = None
|
|
for t in self.agent_state.tools:
|
|
if t.name == function_name:
|
|
# This force refreshes the target_letta_tool from the database
|
|
# We only do this on name match to confirm that the agent state contains a specific tool with the right name
|
|
target_letta_tool = ToolManager().get_tool_by_name(tool_name=function_name, actor=self.user)
|
|
break
|
|
|
|
if not target_letta_tool:
|
|
error_msg = f"No function named {function_name}"
|
|
function_response = "None" # more like "never ran?"
|
|
messages = self._handle_function_error_response(
|
|
error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id
|
|
)
|
|
return messages, False, True # force a heartbeat to allow agent to handle error
|
|
|
|
# Failure case 2: function name is OK, but function args are bad JSON
|
|
try:
|
|
raw_function_args = function_call.arguments
|
|
function_args = parse_json(raw_function_args)
|
|
if not isinstance(function_args, dict):
|
|
raise ValueError(f"Function arguments are not a dictionary: {function_args} (raw={raw_function_args})")
|
|
except Exception as e:
|
|
print(e)
|
|
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
|
|
function_response = "None" # more like "never ran?"
|
|
messages = self._handle_function_error_response(
|
|
error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id
|
|
)
|
|
return messages, False, True # force a heartbeat to allow agent to handle error
|
|
|
|
# Check if inner thoughts are in the function call arguments (apparently possible if you are using Azure)
|
|
if "inner_thoughts" in function_args:
|
|
response_message.content = function_args.pop("inner_thoughts")
|
|
# The content is then treated as internal monologue, not chat
|
|
if response_message.content and not nonnull_content:
|
|
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index)
|
|
chunk_index += 1
|
|
|
|
# (Still parsing function args)
|
|
# Handle requests for immediate heartbeat
|
|
heartbeat_request = function_args.pop("request_heartbeat", None)
|
|
|
|
# Edge case: heartbeat_request is returned as a stringified boolean, we will attempt to parse:
|
|
if isinstance(heartbeat_request, str) and heartbeat_request.lower().strip() == "true":
|
|
heartbeat_request = True
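# any other non-bool value (e.g. the string "false") falls through and is coerced to False below, with a warning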
|
|
|
|
if heartbeat_request is None:
|
|
heartbeat_request = False
|
|
|
|
if not isinstance(heartbeat_request, bool):
|
|
self.logger.warning(
|
|
f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}"
|
|
)
|
|
heartbeat_request = False
|
|
|
|
# Failure case 3: function failed during execution
|
|
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
|
|
# this is because the function/tool role message is only created once the function/tool has executed/returned
|
|
|
|
# handle cases where we return a json message
|
|
if "message" in function_args:
|
|
function_args["message"] = str(function_args.get("message", ""))
|
|
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index)
|
|
chunk_index = 0 # reset chunk index after assistant message
|
|
try:
|
|
# handle tool execution (sandbox) and state updates
|
|
log_telemetry(
|
|
self.logger, "_handle_ai_response execute tool start", function_name=function_name, function_args=function_args
|
|
)
|
|
log_event(
|
|
"tool_call_initiated",
|
|
attributes={
|
|
"function_name": function_name,
|
|
"target_letta_tool": target_letta_tool.model_dump(),
|
|
**{f"function_args.{k}": v for k, v in function_args.items()},
|
|
},
|
|
)
|
|
|
|
tool_execution_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
|
|
function_response = tool_execution_result.func_return
|
|
|
|
log_event(
|
|
"tool_call_ended",
|
|
attributes={
|
|
"function_response": function_response,
|
|
"tool_execution_result": tool_execution_result.model_dump(),
|
|
},
|
|
)
|
|
log_telemetry(
|
|
self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
|
|
)
|
|
|
|
if tool_execution_result and tool_execution_result.status == "error":
|
|
tool_return = ToolReturn(
|
|
status=tool_execution_result.status, stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr
|
|
)
|
|
messages = self._handle_function_error_response(
|
|
function_response,
|
|
tool_call_id,
|
|
function_name,
|
|
function_args,
|
|
function_response,
|
|
messages,
|
|
[tool_return],
|
|
group_id=group_id,
|
|
)
|
|
return messages, False, True # force a heartbeat to allow agent to handle error
|
|
|
|
# handle truncation
|
|
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
|
|
# with certain functions we rely on the paging mechanism to handle overflow
|
|
truncate = False
|
|
else:
|
|
# but by default, we add a truncation safeguard to prevent bad functions from
|
|
# overflowing the agent context window
|
|
truncate = True
|
|
|
|
# get the function response limit
|
|
return_char_limit = target_letta_tool.return_char_limit
|
|
function_response_string = validate_function_response(
|
|
function_response, return_char_limit=return_char_limit, truncate=truncate
|
|
)
|
|
function_args.pop("self", None)
|
|
function_response = package_function_response(True, function_response_string)
|
|
function_failed = False
|
|
except Exception as e:
|
|
function_args.pop("self", None)
|
|
# error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
|
|
# Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise)
|
|
error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
|
|
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
|
|
self.logger.error(error_msg_user)
|
|
messages = self._handle_function_error_response(
|
|
error_msg,
|
|
tool_call_id,
|
|
function_name,
|
|
function_args,
|
|
function_response,
|
|
messages,
|
|
[ToolReturn(status="error", stderr=[error_msg_user])],
|
|
include_function_failed_message=True,
|
|
group_id=group_id,
|
|
)
|
|
return messages, False, True # force a heartbeat to allow agent to handle error
|
|
|
|
# Step 4: check if function response is an error
|
|
if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
|
|
error_msg = function_response_string
|
|
tool_return = ToolReturn(
|
|
status=tool_execution_result.status,
|
|
stdout=tool_execution_result.stdout,
|
|
stderr=tool_execution_result.stderr,
|
|
)
|
|
messages = self._handle_function_error_response(
|
|
error_msg,
|
|
tool_call_id,
|
|
function_name,
|
|
function_args,
|
|
function_response,
|
|
messages,
|
|
[tool_return],
|
|
include_function_failed_message=True,
|
|
group_id=group_id,
|
|
)
|
|
return messages, False, True # force a heartbeat to allow agent to handle error
|
|
|
|
# If no failures happened along the way: ...
|
|
# Step 5: send the info on the function call and function response to GPT
|
|
tool_return = ToolReturn(
|
|
status=tool_execution_result.status,
|
|
stdout=tool_execution_result.stdout,
|
|
stderr=tool_execution_result.stderr,
|
|
)
|
|
messages.append(
|
|
Message(
|
|
agent_id=self.agent_state.id,
|
|
# Base info OpenAI-style
|
|
model=self.model,
|
|
role="tool",
|
|
name=function_name, # NOTE: when role is 'tool', the 'name' is the function name, not agent name
|
|
content=[TextContent(text=function_response)],
|
|
tool_call_id=tool_call_id,
|
|
# Letta extras
|
|
tool_returns=[tool_return],
|
|
group_id=group_id,
|
|
)
|
|
) # extend conversation with function response
|
|
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index)
|
|
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=chunk_index)
|
|
chunk_index += 1
|
|
self.last_function_response = function_response
|
|
|
|
else:
|
|
# Standard non-function reply
|
|
messages.append(
|
|
Message.dict_to_message(
|
|
id=response_message_id,
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict=response_message.model_dump(),
|
|
name=self.agent_state.name,
|
|
group_id=group_id,
|
|
)
|
|
) # extend conversation with assistant's reply
|
|
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index)
|
|
chunk_index += 1
|
|
heartbeat_request = False
|
|
function_failed = False
|
|
|
|
# rebuild memory
|
|
# TODO: @charles please check this
|
|
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
|
|
|
|
# Update ToolRulesSolver state with last called function
|
|
self.tool_rules_solver.register_tool_call(function_name)
|
|
# Update heartbeat request according to provided tool rules
|
|
if self.tool_rules_solver.has_children_tools(function_name):
|
|
heartbeat_request = True
|
|
elif self.tool_rules_solver.is_terminal_tool(function_name):
|
|
heartbeat_request = False
|
|
|
|
# if continue tool rule, then must request a heartbeat
|
|
# TODO: don't even include heartbeats in the args
|
|
if self.tool_rules_solver.is_continue_tool(function_name):
|
|
heartbeat_request = True
|
|
|
|
log_telemetry(self.logger, "_handle_ai_response finish")
|
|
return messages, heartbeat_request, function_failed
|
|
|
|
@trace_method
|
|
def step(
|
|
self,
|
|
input_messages: List[MessageCreate],
|
|
# additional args
|
|
chaining: bool = True,
|
|
max_chaining_steps: Optional[int] = None,
|
|
put_inner_thoughts_first: bool = True,
|
|
**kwargs,
|
|
) -> LettaUsageStatistics:
|
|
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
|
|
# Defensively clear the tool rules solver history
|
|
# Usually this would be extraneous as the Agent loop is re-loaded on every message send
|
|
# But just to be safe
|
|
self.tool_rules_solver.clear_tool_history()
|
|
|
|
# Convert MessageCreate objects to Message objects
|
|
next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id)
|
|
counter = 0
|
|
total_usage = UsageStatistics()
|
|
step_count = 0
|
|
function_failed = False
|
|
steps_messages = []
|
|
while True:
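# Each iteration runs one inner_step (at most one LLM call); the loop keeps chaining on token warnings,
# failed functions, and heartbeat requests, and breaks when chaining is disabled, the max step count is hit,
# or the agent yields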
|
|
kwargs["first_message"] = False
|
|
kwargs["step_count"] = step_count
|
|
kwargs["last_function_failed"] = function_failed
|
|
step_response = self.inner_step(
|
|
messages=next_input_messages,
|
|
put_inner_thoughts_first=put_inner_thoughts_first,
|
|
**kwargs,
|
|
)
|
|
|
|
heartbeat_request = step_response.heartbeat_request
|
|
function_failed = step_response.function_failed
|
|
token_warning = step_response.in_context_memory_warning
|
|
usage = step_response.usage
|
|
steps_messages.append(step_response.messages)
|
|
|
|
step_count += 1
|
|
total_usage += usage
|
|
counter += 1
|
|
self.interface.step_complete()
|
|
|
|
# logger.debug("Saving agent state")
|
|
# save updated state
|
|
save_agent(self)
|
|
|
|
# Chain stops
|
|
if not chaining:
|
|
self.logger.info("No chaining, stopping after one step")
|
|
break
|
|
elif max_chaining_steps is not None and counter > max_chaining_steps:
|
|
self.logger.info(f"Hit max chaining steps, stopping after {counter} steps")
|
|
break
|
|
# Chain handlers
|
|
elif token_warning and summarizer_settings.send_memory_warning_message:
|
|
assert self.agent_state.created_by_id is not None
|
|
next_input_messages = [
|
|
Message.dict_to_message(
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict={
|
|
"role": "user", # TODO: change to system?
|
|
"content": get_token_limit_warning(),
|
|
},
|
|
),
|
|
]
|
|
continue # always chain
|
|
elif function_failed:
|
|
assert self.agent_state.created_by_id is not None
|
|
next_input_messages = [
|
|
Message.dict_to_message(
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict={
|
|
"role": "user", # TODO: change to system?
|
|
"content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
|
|
},
|
|
)
|
|
]
|
|
continue # always chain
|
|
elif heartbeat_request:
|
|
assert self.agent_state.created_by_id is not None
|
|
next_input_messages = [
|
|
Message.dict_to_message(
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict={
|
|
"role": "user", # TODO: change to system?
|
|
"content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
|
|
},
|
|
)
|
|
]
|
|
continue # always chain
|
|
# Letta no-op / yield
|
|
else:
|
|
break
|
|
|
|
if self.agent_state.message_buffer_autoclear:
|
|
self.logger.info("Autoclearing message buffer")
|
|
self.agent_state = self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user)
|
|
|
|
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count, steps_messages=steps_messages)
|
|
|
|
def inner_step(
|
|
self,
|
|
messages: List[Message],
|
|
first_message: bool = False,
|
|
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
|
|
skip_verify: bool = False,
|
|
stream: bool = False, # TODO move to config?
|
|
step_count: Optional[int] = None,
|
|
metadata: Optional[dict] = None,
|
|
summarize_attempt_count: int = 0,
|
|
last_function_failed: bool = False,
|
|
put_inner_thoughts_first: bool = True,
|
|
) -> AgentStepResponse:
|
|
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
|
try:
|
|
|
|
# Extract job_id from metadata if present
|
|
job_id = metadata.get("job_id") if metadata else None
|
|
|
|
# Declare step_id for the given step to be used as the step is processing.
|
|
step_id = generate_step_id()
|
|
|
|
# Step 0: update core memory
|
|
# only pulling latest block data if shared memory is being used
|
|
current_persisted_memory = Memory(
|
|
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()],
|
|
prompt_template=get_prompt_template_for_agent_type(self.agent_state.agent_type),
|
|
) # read blocks from DB
|
|
self.update_memory_if_changed(current_persisted_memory)
|
|
|
|
# Step 1: add user message
|
|
if not all(isinstance(m, Message) for m in messages):
|
|
raise ValueError(f"messages should be a list of Message, got {[type(m) for m in messages]}")
|
|
|
|
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
|
input_message_sequence = in_context_messages + messages
|
|
|
|
if (
|
|
len(input_message_sequence) > 1
|
|
and input_message_sequence[-1].role != "user"
|
|
and input_message_sequence[-1].group_id is None
|
|
):
|
|
self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
|
|
|
|
# Step 2: send the conversation and available functions to the LLM
|
|
response = self._get_ai_reply(
|
|
message_sequence=input_message_sequence,
|
|
first_message=first_message,
|
|
stream=stream,
|
|
step_count=step_count,
|
|
last_function_failed=last_function_failed,
|
|
put_inner_thoughts_first=put_inner_thoughts_first,
|
|
step_id=step_id,
|
|
)
|
|
if not response:
|
|
# EDGE CASE: Function call failed AND there are no tools left for the agent to call -> return early
|
|
return AgentStepResponse(
|
|
messages=input_message_sequence,
|
|
heartbeat_request=False,
|
|
function_failed=False, # NOTE: this is different from other function fails. We force to return early
|
|
in_context_memory_warning=False,
|
|
usage=UsageStatistics(),
|
|
)
|
|
|
|
# Step 3: check if LLM wanted to call a function
|
|
# (if yes) Step 4: call the function
|
|
# (if yes) Step 5: send the info on the function call and function response to LLM
|
|
response_message = response.choices[0].message
|
|
|
|
response_message.model_copy() # TODO why are we copying here?
|
|
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(
|
|
response_message,
|
|
# TODO this is kind of hacky, find a better way to handle this
|
|
# the only time we set up message creation ahead of time is when streaming is on
|
|
response_message_id=response.id if stream else None,
|
|
group_id=input_message_sequence[-1].group_id,
|
|
)
|
|
|
|
# Step 6: extend the message history
|
|
if len(messages) > 0:
|
|
all_new_messages = messages + all_response_messages
|
|
else:
|
|
all_new_messages = all_response_messages
|
|
|
|
if self.save_last_response:
|
|
self.last_response_messages = all_response_messages
|
|
|
|
# Check the memory pressure and potentially issue a memory pressure warning
|
|
current_total_tokens = response.usage.total_tokens
|
|
active_memory_warning = False
|
|
|
|
# We can't do summarize logic properly if context_window is undefined
|
|
if self.agent_state.llm_config.context_window is None:
|
|
# Fallback if for some reason context_window is missing, just set to the default
|
|
print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}")
|
|
print(f"{self.agent_state}")
|
|
self.agent_state.llm_config.context_window = (
|
|
LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
|
|
)
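# Memory pressure check: once the last response's total token usage exceeds
# memory_warning_threshold * context_window, warn (once per period) and flag it for the outer step() loop to handle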
|
|
|
|
if current_total_tokens > summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window):
|
|
logger.warning(
|
|
f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
|
|
)
|
|
|
|
log_event(
|
|
name="memory_pressure_warning",
|
|
attributes={
|
|
"current_total_tokens": current_total_tokens,
|
|
"context_window_limit": self.agent_state.llm_config.context_window,
|
|
},
|
|
)
|
|
# Only deliver the alert if we haven't already (this period)
|
|
if not self.agent_alerted_about_memory_pressure:
|
|
active_memory_warning = True
|
|
self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this
|
|
|
|
else:
|
|
logger.info(
|
|
f"last response total_tokens ({current_total_tokens}) < {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
|
|
)
|
|
|
|
# Log step - this must happen before messages are persisted
|
|
step = self.step_manager.log_step(
|
|
actor=self.user,
|
|
agent_id=self.agent_state.id,
|
|
provider_name=self.agent_state.llm_config.model_endpoint_type,
|
|
provider_category=self.agent_state.llm_config.provider_category or "base",
|
|
model=self.agent_state.llm_config.model,
|
|
model_endpoint=self.agent_state.llm_config.model_endpoint,
|
|
context_window_limit=self.agent_state.llm_config.context_window,
|
|
usage=response.usage,
|
|
provider_id=self.provider_manager.get_provider_id_from_name(
|
|
self.agent_state.llm_config.provider_name,
|
|
actor=self.user,
|
|
),
|
|
job_id=job_id,
|
|
step_id=step_id,
|
|
)
|
|
for message in all_new_messages:
|
|
message.step_id = step.id
|
|
|
|
# Persisting into Messages
|
|
self.agent_state = self.agent_manager.append_to_in_context_messages(
|
|
all_new_messages, agent_id=self.agent_state.id, actor=self.user
|
|
)
|
|
if job_id:
|
|
for message in all_new_messages:
|
|
self.job_manager.add_message_to_job(
|
|
job_id=job_id,
|
|
message_id=message.id,
|
|
actor=self.user,
|
|
)
|
|
|
|
return AgentStepResponse(
|
|
messages=all_new_messages,
|
|
heartbeat_request=heartbeat_request,
|
|
function_failed=function_failed,
|
|
in_context_memory_warning=active_memory_warning,
|
|
usage=response.usage,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"step() failed\nmessages = {messages}\nerror = {e}")
|
|
|
|
# If we got a context alert, try trimming the messages length, then try again
|
|
if is_context_overflow_error(e):
|
|
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
|
|
|
# TODO: this is a patch to resolve immediate issues, should be removed once the summarizer is fixed
|
|
if self.agent_state.message_buffer_autoclear:
|
|
# no calling the summarizer in this case
|
|
logger.error(
|
|
f"step() failed with an exception that looks like a context window overflow, but message buffer is set to autoclear, so skipping: '{str(e)}'"
|
|
)
|
|
raise e
|
|
|
|
if summarize_attempt_count <= summarizer_settings.max_summarizer_retries:
|
|
logger.warning(
|
|
f"context window exceeded with limit {self.agent_state.llm_config.context_window}, attempting to summarize ({summarize_attempt_count}/{summarizer_settings.max_summarizer_retries}"
|
|
)
|
|
# A separate API call to run a summarizer
|
|
self.summarize_messages_inplace()
|
|
|
|
# Try step again
|
|
return self.inner_step(
|
|
messages=messages,
|
|
first_message=first_message,
|
|
first_message_retry_limit=first_message_retry_limit,
|
|
skip_verify=skip_verify,
|
|
stream=stream,
|
|
metadata=metadata,
|
|
summarize_attempt_count=summarize_attempt_count + 1,
|
|
)
|
|
else:
|
|
err_msg = f"Ran summarizer {summarize_attempt_count - 1} times for agent id={self.agent_state.id}, but messages are still overflowing the context window."
|
|
token_counts = get_token_counts_for_messages(in_context_messages)
|
|
logger.error(err_msg)
|
|
logger.error(f"num_in_context_messages: {len(self.agent_state.message_ids)}")
|
|
logger.error(f"token_counts: {token_counts}")
|
|
raise ContextWindowExceededError(
|
|
err_msg,
|
|
details={
|
|
"num_in_context_messages": len(self.agent_state.message_ids),
|
|
"in_context_messages_text": [m.content for m in in_context_messages],
|
|
"token_counts": token_counts,
|
|
},
|
|
)
|
|
|
|
else:
|
|
logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
|
|
traceback.print_exc()
|
|
raise e
|
|
|
|
def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
|
|
"""Takes a basic user message string, turns it into a stringified JSON with extra metadata, then sends it to the agent
|
|
|
|
Example:
|
|
-> user_message_str = 'hi'
|
|
-> {'message': 'hi', 'type': 'user_message', ...}
|
|
-> json.dumps(...)
|
|
-> agent.step(messages=[Message(role='user', text=...)])
|
|
"""
|
|
# Wrap with metadata, dumps to JSON
|
|
assert user_message_str and isinstance(
|
|
user_message_str, str
|
|
), f"user_message_str should be a non-empty string, got {type(user_message_str)}"
|
|
user_message_json_str = package_user_message(user_message_str)
|
|
|
|
# Validate JSON via save/load
|
|
user_message = validate_json(user_message_json_str)
|
|
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)
|
|
|
|
# Turn into a dict
|
|
openai_message_dict = {"role": "user", "content": cleaned_user_message_text, "name": name}
|
|
|
|
# Create the associated Message object (in the database)
|
|
assert self.agent_state.created_by_id is not None, "User ID is not set"
|
|
user_message = Message.dict_to_message(
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict=openai_message_dict,
|
|
# created_at=timestamp,
|
|
)
|
|
|
|
return self.inner_step(messages=[user_message], **kwargs)
|
|
|
|
def summarize_messages_inplace(self):
|
|
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
|
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
|
in_context_messages_openai_no_system = in_context_messages_openai[1:]
|
|
token_counts = get_token_counts_for_messages(in_context_messages)
|
|
logger.info(f"System message token count={token_counts[0]}")
|
|
logger.info(f"token_counts_no_system={token_counts[1:]}")
|
|
|
|
if in_context_messages_openai[0]["role"] != "system":
|
|
raise RuntimeError(f"in_context_messages_openai[0] should be system (instead got {in_context_messages_openai[0]})")
|
|
|
|
# If at this point there's nothing to summarize, throw an error
|
|
if len(in_context_messages_openai_no_system) == 0:
|
|
raise ContextWindowExceededError(
|
|
"Not enough messages to compress for summarization",
|
|
details={
|
|
"num_candidate_messages": len(in_context_messages_openai_no_system),
|
|
"num_total_messages": len(in_context_messages_openai),
|
|
},
|
|
)
|
|
|
|
cutoff = calculate_summarizer_cutoff(in_context_messages=in_context_messages, token_counts=token_counts, logger=logger)
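# cutoff is an index into in_context_messages: messages [1, cutoff) get summarized and trimmed below,
# while the system message at index 0 is always kept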
|
|
message_sequence_to_summarize = in_context_messages[1:cutoff] # do NOT get rid of the system message
|
|
logger.info(f"Attempting to summarize {len(message_sequence_to_summarize)} messages of {len(in_context_messages)}")
|
|
|
|
# We can't do summarize logic properly if context_window is undefined
|
|
if self.agent_state.llm_config.context_window is None:
|
|
# Fallback if for some reason context_window is missing, just set to the default
|
|
logger.warning(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}")
|
|
self.agent_state.llm_config.context_window = (
|
|
LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
|
|
)
|
|
|
|
summary = summarize_messages(
|
|
agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize, actor=self.user
|
|
)
|
|
logger.info(f"Got summary: {summary}")
|
|
|
|
# Metadata that's useful for the agent to see
|
|
all_time_message_count = self.message_manager.size(agent_id=self.agent_state.id, actor=self.user)
|
|
remaining_message_count = 1 + len(in_context_messages) - cutoff # System + remaining
|
|
hidden_message_count = all_time_message_count - remaining_message_count
|
|
summary_message_count = len(message_sequence_to_summarize)
|
|
summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
|
|
logger.info(f"Packaged into message: {summary_message}")
|
|
|
|
prior_len = len(in_context_messages_openai)
|
|
self.agent_state = self.agent_manager.trim_older_in_context_messages(num=cutoff, agent_id=self.agent_state.id, actor=self.user)
|
|
packed_summary_message = {"role": "user", "content": summary_message}
|
|
# Prepend the summary
|
|
self.agent_state = self.agent_manager.prepend_to_in_context_messages(
|
|
messages=[
|
|
Message.dict_to_message(
|
|
agent_id=self.agent_state.id,
|
|
model=self.model,
|
|
openai_message_dict=packed_summary_message,
|
|
)
|
|
],
|
|
agent_id=self.agent_state.id,
|
|
actor=self.user,
|
|
)
|
|
|
|
# reset alert
|
|
self.agent_alerted_about_memory_pressure = False
|
|
curr_in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
|
|
|
current_token_count = sum(get_token_counts_for_messages(curr_in_context_messages))
|
|
logger.info(f"Ran summarizer, messages length {prior_len} -> {len(curr_in_context_messages)}")
|
|
logger.info(f"Summarizer brought down total token count from {sum(token_counts)} -> {current_token_count}")
|
|
log_event(
|
|
name="summarization",
|
|
attributes={
|
|
"prior_length": prior_len,
|
|
"current_length": len(curr_in_context_messages),
|
|
"prior_token_count": sum(token_counts),
|
|
"current_token_count": current_token_count,
|
|
"context_window_limit": self.agent_state.llm_config.context_window,
|
|
},
|
|
)
|
|
|
|
def add_function(self, function_name: str) -> str:
|
|
# TODO: refactor
|
|
raise NotImplementedError
|
|
|
|
def remove_function(self, function_name: str) -> str:
|
|
# TODO: refactor
|
|
raise NotImplementedError
|
|
|
|
def migrate_embedding(self, embedding_config: EmbeddingConfig):
|
|
"""Migrate the agent to a new embedding"""
|
|
# TODO: archival memory
|
|
|
|
# TODO: recall memory
|
|
raise NotImplementedError()
|
|
|
|
def get_context_window(self) -> ContextWindowOverview:
|
|
"""Get the context window of the agent"""
|
|
|
|
system_prompt = self.agent_state.system # TODO is this the current system or the initial system?
|
|
num_tokens_system = count_tokens(system_prompt)
|
|
core_memory = self.agent_state.memory.compile()
|
|
num_tokens_core_memory = count_tokens(core_memory)
|
|
|
|
# Grab the in-context messages
|
|
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
|
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
|
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
|
|
|
# Check if there's a summary message in the message queue
|
|
if (
|
|
len(in_context_messages) > 1
|
|
and in_context_messages[1].role == MessageRole.user
|
|
and in_context_messages[1].content
|
|
and len(in_context_messages[1].content) == 1
|
|
and isinstance(in_context_messages[1].content[0], TextContent)
|
|
# TODO remove hardcoding
|
|
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
|
|
):
|
|
# Summary message exists
|
|
text_content = in_context_messages[1].content[0].text
|
|
assert text_content is not None
|
|
summary_memory = text_content
|
|
num_tokens_summary_memory = count_tokens(text_content)
|
|
# with a summary message, the real messages start at index 2
|
|
num_tokens_messages = (
|
|
num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
|
|
if len(in_context_messages_openai) > 2
|
|
else 0
|
|
)
|
|
|
|
else:
|
|
summary_memory = None
|
|
num_tokens_summary_memory = 0
|
|
# with no summary message, the real messages start at index 1
|
|
num_tokens_messages = (
|
|
num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
|
|
if len(in_context_messages_openai) > 1
|
|
else 0
|
|
)
|
|
|
|
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
|
|
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
|
external_memory_summary = compile_memory_metadata_block(
|
|
memory_edit_timestamp=get_utc_time(),
|
|
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),
|
|
archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id),
|
|
recent_passages=self.agent_manager.list_passages(actor=self.user, agent_id=self.agent_state.id, ascending=False, limit=10),
|
|
)
|
|
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
|
|
|
# tokens taken up by function definitions
|
|
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
|
if agent_state_tool_jsons:
|
|
available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons]
|
|
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
|
|
else:
|
|
available_functions_definitions = []
|
|
num_tokens_available_functions_definitions = 0
|
|
|
|
num_tokens_used_total = (
|
|
num_tokens_system # system prompt
|
|
+ num_tokens_available_functions_definitions # function definitions
|
|
+ num_tokens_core_memory # core memory
|
|
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
|
|
+ num_tokens_summary_memory # summary of ongoing conversation
|
|
+ num_tokens_messages # tokens taken by messages
|
|
)
|
|
assert isinstance(num_tokens_used_total, int)
|
|
|
|
return ContextWindowOverview(
|
|
# context window breakdown (in messages)
|
|
num_messages=len(in_context_messages),
|
|
num_archival_memory=agent_manager_passage_size,
|
|
num_recall_memory=message_manager_size,
|
|
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
|
external_memory_summary=external_memory_summary,
|
|
# top-level information
|
|
context_window_size_max=self.agent_state.llm_config.context_window,
|
|
context_window_size_current=num_tokens_used_total,
|
|
# context window breakdown (in tokens)
|
|
num_tokens_system=num_tokens_system,
|
|
system_prompt=system_prompt,
|
|
num_tokens_core_memory=num_tokens_core_memory,
|
|
core_memory=core_memory,
|
|
num_tokens_summary_memory=num_tokens_summary_memory,
|
|
summary_memory=summary_memory,
|
|
num_tokens_messages=num_tokens_messages,
|
|
messages=in_context_messages,
|
|
# related to functions
|
|
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
|
|
functions_definitions=available_functions_definitions,
|
|
)
|
|
|
|
async def get_context_window_async(self) -> ContextWindowOverview:
|
|
if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION" and model_settings.anthropic_api_key is not None:
|
|
return await self.get_context_window_from_anthropic_async()
|
|
return await self.get_context_window_from_tiktoken_async()
|
|
|
|
async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview:
|
|
"""Get the context window of the agent"""
|
|
# Grab the in-context messages
|
|
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
|
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
|
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
|
|
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
|
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
|
)
|
|
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
|
|
|
# Extract system, memory and external summary
|
|
if (
|
|
len(in_context_messages) > 0
|
|
and in_context_messages[0].role == MessageRole.system
|
|
and in_context_messages[0].content
|
|
and len(in_context_messages[0].content) == 1
|
|
and isinstance(in_context_messages[0].content[0], TextContent)
|
|
):
|
|
system_message = in_context_messages[0].content[0].text
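# Heuristic split of the compiled system message: text before the first "###" is treated as the system prompt,
# the span up to the first "<" after it as the external memory summary, and the remainder as core memory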
|
|
|
|
external_memory_marker_pos = system_message.find("###")
|
|
core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
|
|
if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
|
|
system_prompt = system_message[:external_memory_marker_pos].strip()
|
|
external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
|
|
core_memory = system_message[core_memory_marker_pos:].strip()
|
|
else:
|
|
# if no markers found, put everything in system message
|
|
system_prompt = system_message
|
|
external_memory_summary = ""
|
|
core_memory = ""
|
|
else:
|
|
# if no system message, fall back on agent's system prompt
|
|
system_prompt = self.agent_state.system
|
|
external_memory_summary = ""
|
|
core_memory = ""
|
|
|
|
num_tokens_system = count_tokens(system_prompt)
|
|
num_tokens_core_memory = count_tokens(core_memory)
|
|
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
|
|
|
# Check if there's a summary message in the message queue
|
|
if (
|
|
len(in_context_messages) > 1
|
|
and in_context_messages[1].role == MessageRole.user
|
|
and in_context_messages[1].content
|
|
and len(in_context_messages[1].content) == 1
|
|
and isinstance(in_context_messages[1].content[0], TextContent)
|
|
# TODO remove hardcoding
|
|
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
|
|
):
|
|
# Summary message exists
|
|
text_content = in_context_messages[1].content[0].text
|
|
assert text_content is not None
|
|
summary_memory = text_content
|
|
num_tokens_summary_memory = count_tokens(text_content)
|
|
# with a summary message, the real messages start at index 2
|
|
num_tokens_messages = (
|
|
num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
|
|
if len(in_context_messages_openai) > 2
|
|
else 0
|
|
)
|
|
|
|
else:
|
|
summary_memory = None
|
|
num_tokens_summary_memory = 0
|
|
# with no summary message, the real messages start at index 1
|
|
num_tokens_messages = (
|
|
num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
|
|
if len(in_context_messages_openai) > 1
|
|
else 0
|
|
)
|
|
|
|
# tokens taken up by function definitions
|
|
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
|
if agent_state_tool_jsons:
|
|
available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons]
|
|
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
|
|
else:
|
|
available_functions_definitions = []
|
|
num_tokens_available_functions_definitions = 0
|
|
|
|
num_tokens_used_total = (
|
|
num_tokens_system # system prompt
|
|
+ num_tokens_available_functions_definitions # function definitions
|
|
+ num_tokens_core_memory # core memory
|
|
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
|
|
+ num_tokens_summary_memory # summary of ongoing conversation
|
|
+ num_tokens_messages # tokens taken by messages
|
|
)
|
|
assert isinstance(num_tokens_used_total, int)
|
|
|
|
return ContextWindowOverview(
|
|
# context window breakdown (in messages)
|
|
num_messages=len(in_context_messages),
|
|
num_archival_memory=passage_manager_size,
|
|
num_recall_memory=message_manager_size,
|
|
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
|
external_memory_summary=external_memory_summary,
|
|
# top-level information
|
|
context_window_size_max=self.agent_state.llm_config.context_window,
|
|
context_window_size_current=num_tokens_used_total,
|
|
# context window breakdown (in tokens)
|
|
num_tokens_system=num_tokens_system,
|
|
system_prompt=system_prompt,
|
|
num_tokens_core_memory=num_tokens_core_memory,
|
|
core_memory=core_memory,
|
|
num_tokens_summary_memory=num_tokens_summary_memory,
|
|
summary_memory=summary_memory,
|
|
num_tokens_messages=num_tokens_messages,
|
|
messages=in_context_messages,
|
|
# related to functions
|
|
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
|
|
functions_definitions=available_functions_definitions,
|
|
)
|
|
|
|
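    # NOTE: the marker parsing in both context-window inspectors assumes the compiled system message is
    # laid out as: the raw system prompt, followed by an external-memory summary section whose header
    # begins with "###", followed by the core memory blocks starting at the first "<" after that header.
    # If either marker is missing, the whole message is attributed to the system prompt. Illustrative
    # shape only (not the exact template):
    #
    #     ...system instructions...
    #     ### Memory [...]
    #     ...recall/archival statistics...
    #     <persona>...</persona>
    #     <human>...</human>
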
    async def get_context_window_from_anthropic_async(self) -> ContextWindowOverview:
        """Get the context window of the agent"""
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=self.user)
        model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None

        # Grab the in-context messages
        # conversion of messages to anthropic dict format, which is passed to the token counter
        (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
            self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
            self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
            self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
        )
        in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]

        # Extract system, memory and external summary
        if (
            len(in_context_messages) > 0
            and in_context_messages[0].role == MessageRole.system
            and in_context_messages[0].content
            and len(in_context_messages[0].content) == 1
            and isinstance(in_context_messages[0].content[0], TextContent)
        ):
            system_message = in_context_messages[0].content[0].text

            external_memory_marker_pos = system_message.find("###")
            core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
            if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
                system_prompt = system_message[:external_memory_marker_pos].strip()
                external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
                core_memory = system_message[core_memory_marker_pos:].strip()
            else:
                # if no markers found, put everything in system message
                system_prompt = system_message
                external_memory_summary = None
                core_memory = None
        else:
            # if no system message, fall back on agent's system prompt
            system_prompt = self.agent_state.system
            external_memory_summary = None
            core_memory = None

        num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}])
        num_tokens_core_memory_coroutine = (
            anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}])
            if core_memory
            else asyncio.sleep(0, result=0)
        )
        num_tokens_external_memory_summary_coroutine = (
            anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}])
            if external_memory_summary
            else asyncio.sleep(0, result=0)
        )

        # Check if there's a summary message in the message queue
        if (
            len(in_context_messages) > 1
            and in_context_messages[1].role == MessageRole.user
            and in_context_messages[1].content
            and len(in_context_messages[1].content) == 1
            and isinstance(in_context_messages[1].content[0], TextContent)
            # TODO remove hardcoding
            and "The following is a summary of the previous " in in_context_messages[1].content[0].text
        ):
            # Summary message exists
            text_content = in_context_messages[1].content[0].text
            assert text_content is not None
            summary_memory = text_content
            num_tokens_summary_memory_coroutine = anthropic_client.count_tokens(
                model=model, messages=[{"role": "user", "content": summary_memory}]
            )
            # with a summary message, the real messages start at index 2
            num_tokens_messages_coroutine = (
                anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:])
                if len(in_context_messages_anthropic) > 2
                else asyncio.sleep(0, result=0)
            )

        else:
            summary_memory = None
            num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0)
            # with no summary message, the real messages start at index 1
            num_tokens_messages_coroutine = (
                anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:])
                if len(in_context_messages_anthropic) > 1
                else asyncio.sleep(0, result=0)
            )

        # tokens taken up by function definitions
        if self.agent_state.tools and len(self.agent_state.tools) > 0:
            available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in self.agent_state.tools]
            num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens(
                model=model,
                tools=available_functions_definitions,
            )
        else:
            available_functions_definitions = []
            num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0)

        (
            num_tokens_system,
            num_tokens_core_memory,
            num_tokens_external_memory_summary,
            num_tokens_summary_memory,
            num_tokens_messages,
            num_tokens_available_functions_definitions,
        ) = await asyncio.gather(
            num_tokens_system_coroutine,
            num_tokens_core_memory_coroutine,
            num_tokens_external_memory_summary_coroutine,
            num_tokens_summary_memory_coroutine,
            num_tokens_messages_coroutine,
            num_tokens_available_functions_definitions_coroutine,
        )

        num_tokens_used_total = (
            num_tokens_system  # system prompt
            + num_tokens_available_functions_definitions  # function definitions
            + num_tokens_core_memory  # core memory
            + num_tokens_external_memory_summary  # metadata (statistics) about recall/archival
            + num_tokens_summary_memory  # summary of ongoing conversation
            + num_tokens_messages  # tokens taken by messages
        )
        assert isinstance(num_tokens_used_total, int)

        return ContextWindowOverview(
            # context window breakdown (in messages)
            num_messages=len(in_context_messages),
            num_archival_memory=passage_manager_size,
            num_recall_memory=message_manager_size,
            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
            external_memory_summary=external_memory_summary,
            # top-level information
            context_window_size_max=self.agent_state.llm_config.context_window,
            context_window_size_current=num_tokens_used_total,
            # context window breakdown (in tokens)
            num_tokens_system=num_tokens_system,
            system_prompt=system_prompt,
            num_tokens_core_memory=num_tokens_core_memory,
            core_memory=core_memory,
            num_tokens_summary_memory=num_tokens_summary_memory,
            summary_memory=summary_memory,
            num_tokens_messages=num_tokens_messages,
            messages=in_context_messages,
            # related to functions
            num_tokens_functions_definitions=num_tokens_available_functions_definitions,
            functions_definitions=available_functions_definitions,
        )

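    # NOTE: in the Anthropic variant above, every count comes from the provider's token-counting
    # endpoint and the requests are awaited together with asyncio.gather; asyncio.sleep(0, result=0)
    # acts as a zero-token placeholder coroutine whenever a section is empty, so the gather call can
    # always unpack a uniform tuple of integers.
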
    def count_tokens(self) -> int:
        """Count the tokens in the current context window"""
        context_window_breakdown = self.get_context_window()
        return context_window_breakdown.context_window_size_current

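    # Quick headroom check (hypothetical caller, not part of this module):
    #
    #     used = agent.count_tokens()
    #     remaining = agent.agent_state.llm_config.context_window - used
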
    # TODO: Refactor into separate class vs. large if/elses here
    def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool) -> ToolExecutionResult:
        """
        Execute tool modifications and persist the state of the agent.
        Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
        """
        # TODO: add agent manager here
        orig_memory_str = self.agent_state.memory.compile()

        # TODO: need to have an AgentState object that actually has full access to the block data
        # this is because the sandbox tools need to be able to access block.value to edit this data
        try:
            if target_letta_tool.tool_type == ToolType.LETTA_CORE:
                # base tools are allowed to access the `Agent` object and run on the database
                callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
                function_args["self"] = self  # need to attach self to arg since it's dynamically linked
                function_response = callable_func(**function_args)
            elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
                callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
                function_args["self"] = self  # need to attach self to arg since it's dynamically linked
                function_response = callable_func(**function_args)
            elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE or target_letta_tool.tool_type == ToolType.LETTA_SLEEPTIME_CORE:
                callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
                agent_state_copy = self.agent_state.__deepcopy__()
                function_args["agent_state"] = agent_state_copy  # need to attach agent_state to arg since it's dynamically linked
                function_response = callable_func(**function_args)
                self.ensure_read_only_block_not_modified(
                    new_memory=agent_state_copy.memory
                )  # memory editing tools cannot edit read-only blocks
                self.update_memory_if_changed(agent_state_copy.memory)
            elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
                action_name = generate_composio_action_from_func_name(target_letta_tool.name)
                # Get entity ID from the agent_state
                entity_id = None
                for env_var in self.agent_state.tool_exec_environment_variables:
                    if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
                        entity_id = env_var.value
                # Get composio_api_key
                composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger)
                function_response = execute_composio_action(
                    action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
                )
            elif target_letta_tool.tool_type == ToolType.EXTERNAL_MCP:
                # Get the server name from the tool tag
                # TODO make a property instead?
                server_name = target_letta_tool.tags[0].split(":")[1]

                # Get the MCPClient from the server's handle
                # TODO these don't get raised properly
                if not self.mcp_clients:
                    raise ValueError("No MCP client available to use")
                if server_name not in self.mcp_clients:
                    raise ValueError(f"Unknown MCP server name: {server_name}")
                mcp_client = self.mcp_clients[server_name]
                if not isinstance(mcp_client, BaseMCPClient):
                    raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")

                # Check that tool exists
                available_tools = mcp_client.list_tools()
                available_tool_names = [t.name for t in available_tools]
                if function_name not in available_tool_names:
                    raise ValueError(
                        f"{function_name} is not available in MCP server {server_name}. Please check your `~/.letta/mcp_config.json` file."
                    )

                function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
                return ToolExecutionResult(
                    status="error" if is_error else "success",
                    func_return=function_response,
                )
            else:
                try:
                    # Parse the source code to extract function annotations
                    annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
                    # Coerce the function arguments to the correct types based on the annotations
                    function_args = coerce_dict_args_by_annotations(function_args, annotations)
                except ValueError as e:
                    self.logger.debug(f"Error coercing function arguments: {e}")

                # execute tool in a sandbox
                # TODO: allow agent_state to specify which sandbox to execute tools in
                # TODO: This is only temporary, can remove after we publish a pip package with this object
                agent_state_copy = self.agent_state.__deepcopy__()
                agent_state_copy.tools = []
                agent_state_copy.tool_rules = []

                tool_execution_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
                    agent_state=agent_state_copy
                )
                assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
                if tool_execution_result.agent_state is not None:
                    self.update_memory_if_changed(tool_execution_result.agent_state.memory)
                return tool_execution_result
        except Exception as e:
            # Need to catch the error here, or else truncation won't happen
            # TODO: modify to function execution error
            function_response = get_friendly_error_msg(
                function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
            )
            return ToolExecutionResult(
                status="error",
                func_return=function_response,
                stderr=[traceback.format_exc()],
            )

        return ToolExecutionResult(
            status="success",
            func_return=function_response,
        )


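# NOTE: execute_tool_and_persist_state dispatches on ToolType: Letta core / multi-agent / memory tools are
# resolved from their modules and run in-process (with `self` or a deep-copied agent_state injected as an
# argument), Composio and MCP tools are forwarded to their respective backends, and any other tool runs in
# a ToolExecutionSandbox against a stripped-down copy of the agent state. Every path reports back through
# a ToolExecutionResult (status, func_return, optional stderr).
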
def save_agent(agent: Agent):
    """Save agent to metadata store"""
    agent_state = agent.agent_state
    assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"

    # TODO: move this to agent manager
    # TODO: Completely strip out metadata
    # convert to persisted model
    agent_manager = AgentManager()
    update_agent = UpdateAgent(
        name=agent_state.name,
        tool_ids=[t.id for t in agent_state.tools],
        source_ids=[s.id for s in agent_state.sources],
        block_ids=[b.id for b in agent_state.memory.blocks],
        tags=agent_state.tags,
        system=agent_state.system,
        tool_rules=agent_state.tool_rules,
        llm_config=agent_state.llm_config,
        embedding_config=agent_state.embedding_config,
        message_ids=agent_state.message_ids,
        description=agent_state.description,
        metadata=agent_state.metadata,
        # TODO: Add this back in later
        # tool_exec_environment_variables=agent_state.get_agent_env_vars_as_dict(),
    )
    agent_manager.update_agent(agent_id=agent_state.id, agent_update=update_agent, actor=agent.user)


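# NOTE: save_agent persists only the fields mapped onto UpdateAgent above; per the inline TODO, tool
# execution environment variables are currently not round-tripped through this path.
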
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
    """If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
    try:
        user_message_json = dict(json_loads(user_message_text))
        # Special handling for AutoGen messages with 'name' field
        # Treat 'name' as a special field
        # If it exists in the input message, elevate it to the 'message' level
        name = user_message_json.pop("name", None)
        clean_message = json_dumps(user_message_json)
        return clean_message, name

    except Exception as e:
        print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
        raise e


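# For example (illustrative values; exact whitespace depends on json_dumps):
#
#     strip_name_field_from_user_message('{"message": "hello", "name": "alice"}')
#     # -> ('{"message": "hello"}', 'alice')
#     strip_name_field_from_user_message('{"message": "hello"}')
#     # -> ('{"message": "hello"}', None)
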
def validate_json(user_message_text: str) -> str:
    """Make sure that the user input message is valid JSON"""
    try:
        user_message_json = dict(json_loads(user_message_text))
        user_message_json_val = json_dumps(user_message_json)
        return user_message_json_val
    except Exception as e:
        print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
        raise e

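
# For example (illustrative; exact whitespace depends on json_dumps):
#
#     validate_json('{"message": "hi"}')   # -> '{"message": "hi"}'
#     validate_json("not json")            # prints a CLI warning, then re-raises the parse error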