chore: release 0.7.1 (#2583)

Sarah Wooders 2025-04-22 20:16:32 -07:00 committed by GitHub
commit 435b754286
47 changed files with 2223 additions and 724 deletions

View File

@ -0,0 +1,31 @@
"""add support for structured_outputs in agents
Revision ID: 28b8765bdd0a
Revises: a3c7d62e08ca
Create Date: 2025-04-18 11:43:47.701786
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "28b8765bdd0a"
down_revision: Union[str, None] = "a3c7d62e08ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("agents", sa.Column("response_format", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("agents", "response_format")
# ### end Alembic commands ###

View File

@ -1,7 +1,6 @@
import time
from letta_client import AgentType, GroupUpdateManagerConfig, Letta, ManagerType, SleeptimeManagerUpdate
from letta_client.types import agent_type
from letta_client import Letta
client = Letta(base_url="http://localhost:8283")
@ -70,7 +69,3 @@ while job.status != "completed":
print(passage.text)
time.sleep(2)

View File

@ -1,4 +1,4 @@
__version__ = "0.7.0"
__version__ = "0.7.1"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client

View File

@ -3,7 +3,7 @@ import time
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from openai.types.beta.function_tool import FunctionTool as OpenAITool
@ -17,6 +17,7 @@ from letta.constants import (
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
LLM_MAX_TOKENS,
REQ_HEARTBEAT_MESSAGE,
SEND_MESSAGE_TOOL_NAME,
)
from letta.errors import ContextWindowExceededError
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
@ -27,6 +28,7 @@ from letta.helpers import ToolRulesSolver
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
from letta.interface import AgentInterface
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
@ -42,12 +44,13 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, ToolReturn
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.response_format import ResponseFormatType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
@ -78,7 +81,7 @@ class BaseAgent(ABC):
@abstractmethod
def step(
self,
messages: Union[Message, List[Message]],
input_messages: List[MessageCreate],
) -> LettaUsageStatistics:
"""
Top-level event message handler for the agent.
@ -255,6 +258,28 @@ class Agent(BaseAgent):
# Return updated messages
return messages
def _runtime_override_tool_json_schema(
self,
functions_list: List[Dict | None],
) -> List[Dict | None]:
"""Override the tool JSON schema at runtime for a particular tool if conditions are met."""
# Currently just injects `send_message` with a `response_format` if provided to the agent.
if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text:
for func in functions_list:
if func["name"] == SEND_MESSAGE_TOOL_NAME:
if self.agent_state.response_format.type == ResponseFormatType.json_schema:
func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"]
if self.agent_state.response_format.type == ResponseFormatType.json_object:
func["parameters"]["properties"]["message"] = {
"type": "object",
"description": "Message contents. All unicode (including emojis) are supported.",
"additionalProperties": True,
"properties": {},
}
break
return functions_list
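For orientation, a minimal sketch (not part of this commit) of what the override does to the `send_message` schema when an agent carries a `json_schema` response format; the `mood_report` schema below is a made-up example:
# Illustrative only: the rewrite applied by _runtime_override_tool_json_schema.
# The `mood_report` schema is hypothetical, not part of the release.
send_message_schema = {
    "name": "send_message",
    "parameters": {
        "type": "object",
        "properties": {
            "message": {"type": "string", "description": "Message contents."},
        },
        "required": ["message"],
    },
}
agent_response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "mood_report",
        "schema": {
            "type": "object",
            "properties": {"mood": {"type": "string"}, "reason": {"type": "string"}},
            "required": ["mood", "reason"],
        },
    },
}
# The override replaces the `message` property with the user-supplied schema,
# so the LLM must emit structured JSON as the argument to send_message.
send_message_schema["parameters"]["properties"]["message"] = agent_response_format["json_schema"]["schema"]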
@trace_method
def _get_ai_reply(
self,
@ -268,27 +293,26 @@ class Agent(BaseAgent):
step_count: Optional[int] = None,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
) -> ChatCompletionResponse:
) -> ChatCompletionResponse | None:
"""Get response from LLM API with robust retry mechanism."""
log_telemetry(self.logger, "_get_ai_reply start")
available_tools = set([t.name for t in self.agent_state.tools])
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
available_tools=available_tools, last_function_response=self.last_function_response
)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
allowed_functions = (
agent_state_tool_jsons
if not allowed_tool_names
else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
)
# Get allowed tools or allow all if none are allowed
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
available_tools=available_tools, last_function_response=self.last_function_response
) or list(available_tools)
# Don't allow a tool to be called if it failed last time
if last_function_failed and self.tool_rules_solver.tool_call_history:
allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]]
if not allowed_functions:
allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]]
if not allowed_tool_names:
return None
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)
# For the first message, force the initial tool if one is specified
force_tool_call = None
if (
@ -418,7 +442,7 @@ class Agent(BaseAgent):
tool_call_id = response_message.tool_calls[0].id
assert tool_call_id is not None # should be defined
# only necessary to add the tool_cal_id to a function call (antipattern)
# only necessary to add the tool_call_id to a function call (antipattern)
# response_message_dict = response_message.model_dump()
# response_message_dict["tool_call_id"] = tool_call_id
@ -513,6 +537,10 @@ class Agent(BaseAgent):
# Failure case 3: function failed during execution
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
# this is because the function/tool role message is only created once the function/tool has executed/returned
# handle cases where we return a json message
if "message" in function_args:
function_args["message"] = str(function_args.get("message", ""))
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
try:
@ -529,22 +557,23 @@ class Agent(BaseAgent):
},
)
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
tool_execution_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
function_response = tool_execution_result.func_return
log_event(
"tool_call_ended",
attributes={
"function_response": function_response,
"sandbox_run_result": sandbox_run_result.model_dump() if sandbox_run_result else None,
"tool_execution_result": tool_execution_result.model_dump(),
},
)
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
)
if sandbox_run_result and sandbox_run_result.status == "error":
if tool_execution_result and tool_execution_result.status == "error":
tool_return = ToolReturn(
status=sandbox_run_result.status, stdout=sandbox_run_result.stdout, stderr=sandbox_run_result.stderr
status=tool_execution_result.status, stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr
)
messages = self._handle_function_error_response(
function_response,
@ -598,14 +627,10 @@ class Agent(BaseAgent):
# Step 4: check if function response is an error
if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
error_msg = function_response_string
tool_return = (
ToolReturn(
status=sandbox_run_result.status,
stdout=sandbox_run_result.stdout,
stderr=sandbox_run_result.stderr,
)
if sandbox_run_result
else None
tool_return = ToolReturn(
status=tool_execution_result.status,
stdout=tool_execution_result.stdout,
stderr=tool_execution_result.stderr,
)
messages = self._handle_function_error_response(
error_msg,
@ -622,14 +647,10 @@ class Agent(BaseAgent):
# If no failures happened along the way: ...
# Step 5: send the info on the function call and function response to GPT
tool_return = (
ToolReturn(
status=sandbox_run_result.status,
stdout=sandbox_run_result.stdout,
stderr=sandbox_run_result.stderr,
)
if sandbox_run_result
else None
tool_return = ToolReturn(
status=tool_execution_result.status,
stdout=tool_execution_result.stdout,
stderr=tool_execution_result.stderr,
)
messages.append(
Message(
@ -641,7 +662,7 @@ class Agent(BaseAgent):
content=[TextContent(text=function_response)],
tool_call_id=tool_call_id,
# Letta extras
tool_returns=[tool_return] if sandbox_run_result else None,
tool_returns=[tool_return],
group_id=group_id,
)
) # extend conversation with function response
@ -691,7 +712,7 @@ class Agent(BaseAgent):
@trace_method
def step(
self,
messages: Union[Message, List[Message]],
input_messages: List[MessageCreate],
# additional args
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
@ -704,7 +725,9 @@ class Agent(BaseAgent):
# But just to be safe
self.tool_rules_solver.clear_tool_history()
next_input_message = messages if isinstance(messages, list) else [messages]
# Convert MessageCreate objects to Message objects
message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
next_input_messages = message_objects
counter = 0
total_usage = UsageStatistics()
step_count = 0
@ -715,7 +738,7 @@ class Agent(BaseAgent):
kwargs["step_count"] = step_count
kwargs["last_function_failed"] = function_failed
step_response = self.inner_step(
messages=next_input_message,
messages=next_input_messages,
put_inner_thoughts_first=put_inner_thoughts_first,
**kwargs,
)
@ -745,36 +768,42 @@ class Agent(BaseAgent):
# Chain handlers
elif token_warning and summarizer_settings.send_memory_warning_message:
assert self.agent_state.created_by_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_token_limit_warning(),
},
)
next_input_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_token_limit_warning(),
},
),
]
continue # always chain
elif function_failed:
assert self.agent_state.created_by_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
next_input_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
]
continue # always chain
elif heartbeat_request:
assert self.agent_state.created_by_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
},
)
next_input_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
},
)
]
continue # always chain
# Letta no-op / yield
else:
@ -788,7 +817,7 @@ class Agent(BaseAgent):
def inner_step(
self,
messages: Union[Message, List[Message]],
messages: List[Message],
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
@ -814,11 +843,8 @@ class Agent(BaseAgent):
self.update_memory_if_changed(current_persisted_memory)
# Step 1: add user message
if isinstance(messages, Message):
messages = [messages]
if not all(isinstance(m, Message) for m in messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}")
raise ValueError(f"messages should be a list of Message, got {[type(m) for m in messages]}")
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
input_message_sequence = in_context_messages + messages
@ -1229,9 +1255,7 @@ class Agent(BaseAgent):
return context_window_breakdown.context_window_size_current
# TODO: Refactor into separate class v.s. large if/elses here
def execute_tool_and_persist_state(
self, function_name: str, function_args: dict, target_letta_tool: Tool
) -> tuple[Any, Optional[SandboxRunResult]]:
def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool) -> ToolExecutionResult:
"""
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
@ -1293,8 +1317,10 @@ class Agent(BaseAgent):
)
function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
return function_response, sandbox_run_result
return ToolExecutionResult(
status="error" if is_error else "success",
func_return=function_response,
)
else:
try:
# Parse the source code to extract function annotations
@ -1311,23 +1337,29 @@ class Agent(BaseAgent):
agent_state_copy.tools = []
agent_state_copy.tool_rules = []
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
tool_execution_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
agent_state=agent_state_copy
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if updated_agent_state is not None:
self.update_memory_if_changed(updated_agent_state.memory)
return function_response, sandbox_run_result
if tool_execution_result.agent_state is not None:
self.update_memory_if_changed(tool_execution_result.agent_state.memory)
return tool_execution_result
except Exception as e:
# Need to catch error here, or else truncation won't happen
# TODO: modify to function execution error
function_response = get_friendly_error_msg(
function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
)
return function_response, SandboxRunResult(status="error")
return ToolExecutionResult(
status="error",
func_return=function_response,
stderr=[traceback.format_exc()],
)
return function_response, None
return ToolExecutionResult(
status="success",
func_return=function_response,
)
def save_agent(agent: Agent):

View File

@ -324,11 +324,11 @@ class LettaAgent(BaseAgent):
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
# TODO: Integrate sandbox result
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
function_response, _ = await tool_execution_manager.execute_tool_async(
tool_execution_result = await tool_execution_manager.execute_tool_async(
function_name=tool_name, function_args=tool_args, tool=target_tool
)
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
return function_response, True
return tool_execution_result.func_return, True
except Exception as e:
return f"Failed to call tool. Error: {e}", False

View File

@ -37,6 +37,7 @@ from letta.services.passage_manager import PassageManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import tool_settings
from letta.tracing import log_event, trace_method
from letta.utils import united_diff
logger = get_logger(__name__)
@ -82,12 +83,12 @@ async def execute_tool_wrapper(params: ToolExecutionParams):
sandbox_config=params.sbx_config,
sandbox_env_vars=params.sbx_env_vars,
)
result, _ = await mgr.execute_tool_async(
tool_execution_result = await mgr.execute_tool_async(
function_name=params.tool_call_name,
function_args=params.tool_args,
tool=target_tool,
)
return params.agent_id, (result, True)
return params.agent_id, (tool_execution_result.func_return, True)
except Exception as e:
return params.agent_id, (f"Failed to call tool. Error: {e}", False)
@ -120,55 +121,54 @@ class LettaAgentBatch:
self.actor = actor
self.max_steps = max_steps
@trace_method
async def step_until_request(
self,
batch_requests: List[LettaBatchRequest],
letta_batch_job_id: str,
agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
) -> LettaBatchResponse:
# Basic checks
log_event(name="validate_inputs")
if not batch_requests:
raise ValueError("Empty list of batch_requests passed in!")
if agent_step_state_mapping is None:
agent_step_state_mapping = {}
log_event(name="load_and_prepare_agents")
agent_messages_mapping: Dict[str, List[Message]] = {}
agent_tools_mapping: Dict[str, List[dict]] = {}
agent_states = []
for batch_request in batch_requests:
agent_id = batch_request.agent_id
agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
agent_states.append(agent_state)
agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
agent_state=agent_state, input_messages=batch_request.messages
)
# TODO: Think about a cleaner way to do this?
if agent_id not in agent_step_state_mapping:
agent_step_state_mapping[agent_id] = AgentStepState(
step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
)
agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(
agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver
)
agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver)
# TODO: This is a hack, this is because LLM client expects a LLM config
# TODO: But that doesn't really work in batch land
# TODO: @caren will factor this out
log_event(name="init_llm_client")
llm_client = LLMClient.create(
llm_config=agent_states[0].llm_config,
put_inner_thoughts_first=True,
)
agent_llm_config_mapping = {agent_state.id: agent_state.llm_config for agent_state in agent_states}
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
log_event(name="send_llm_batch_request")
batch_response = await llm_client.send_llm_batch_request_async(
agent_messages_mapping=agent_messages_mapping,
agent_tools_mapping=agent_tools_mapping,
agent_llm_config_mapping=agent_llm_config_mapping,
)
# Write the response into the jobs table, where it will get picked up by the next cron run
log_event(name="persist_llm_batch_job")
llm_batch_job = self.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic, # TODO: Expand to more providers
create_batch_response=batch_response,
@ -177,24 +177,26 @@ class LettaAgentBatch:
letta_batch_job_id=letta_batch_job_id,
)
# Create batch items in bulk for all agents
log_event(name="prepare_batch_items")
batch_items = []
for agent_state in agent_states:
agent_step_state = agent_step_state_mapping.get(agent_state.id)
batch_item = LLMBatchItem(
llm_batch_id=llm_batch_job.id,
agent_id=agent_state.id,
llm_config=agent_state.llm_config,
request_status=JobStatus.created,
step_status=AgentStepStatus.paused,
step_state=agent_step_state,
for state in agent_states:
step_state = agent_step_state_mapping[state.id]
batch_items.append(
LLMBatchItem(
llm_batch_id=llm_batch_job.id,
agent_id=state.id,
llm_config=state.llm_config,
request_status=JobStatus.created,
step_status=AgentStepStatus.paused,
step_state=step_state,
)
)
batch_items.append(batch_item)
# Create all batch items at once using the bulk operation
if batch_items:
log_event(name="bulk_create_batch_items")
self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
log_event(name="return_batch_response")
return LettaBatchResponse(
letta_batch_id=llm_batch_job.letta_batch_job_id,
last_llm_batch_id=llm_batch_job.id,
@ -204,27 +206,27 @@ class LettaAgentBatch:
created_at=llm_batch_job.created_at,
)
@trace_method
async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse:
# 1. gather everything we need
log_event(name="load_context")
llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor)
ctx = await self._collect_resume_context(llm_batch_id)
# 2. persist request-level status updates
log_event(name="update_statuses")
self._update_request_statuses(ctx.request_status_updates)
# 3. run the tools in parallel
log_event(name="exec_tools")
exec_results = await self._execute_tools(ctx)
# 4. create + save assistant/tool messages
log_event(name="persist_messages")
msg_map = self._persist_tool_messages(exec_results, ctx)
# 5. mark steps complete
log_event(name="mark_steps_done")
self._mark_steps_complete(llm_batch_id, ctx.agent_ids)
# 6. build next-round requests / step-state map
log_event(name="prepare_next")
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
if len(next_reqs) == 0:
# mark batch job as completed
self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor)
return LettaBatchResponse(
letta_batch_id=llm_batch_job.letta_batch_job_id,
@ -235,15 +237,16 @@ class LettaAgentBatch:
created_at=llm_batch_job.created_at,
)
# 7. recurse into the normal stepping pipeline
return await self.step_until_request(
batch_requests=next_reqs,
letta_batch_job_id=letta_batch_id,
agent_step_state_mapping=next_step_state,
)
@trace_method
async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id)
# NOTE: We only continue for items with successful results
batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id, request_status=JobStatus.completed)
agent_ids, agent_state_map = [], {}
provider_results, name_map, args_map, cont_map = {}, {}, {}, {}
@ -300,6 +303,7 @@ class LettaAgentBatch:
env = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(cfg.id, actor=self.actor, limit=100)
return cfg, env
@trace_method
async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
sbx_cfg, sbx_env = self._build_sandbox()
params = [

View File

@ -32,6 +32,7 @@ from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.run import Run
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.source import Source, SourceCreate, SourceUpdate
@ -100,6 +101,7 @@ class AbstractClient(object):
message_ids: Optional[List[str]] = None,
memory: Optional[Memory] = None,
tags: Optional[List[str]] = None,
response_format: Optional[ResponseFormatUnion] = None,
):
raise NotImplementedError
@ -553,6 +555,7 @@ class RESTClient(AbstractClient):
initial_message_sequence: Optional[List[Message]] = None,
tags: Optional[List[str]] = None,
message_buffer_autoclear: bool = False,
response_format: Optional[ResponseFormatUnion] = None,
) -> AgentState:
"""Create an agent
@ -615,6 +618,7 @@ class RESTClient(AbstractClient):
"include_base_tools": include_base_tools,
"message_buffer_autoclear": message_buffer_autoclear,
"include_multi_agent_tools": include_multi_agent_tools,
"response_format": response_format,
}
# Only add name if it's not None
@ -653,6 +657,7 @@ class RESTClient(AbstractClient):
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
response_format: Optional[ResponseFormatUnion] = None,
) -> AgentState:
"""
Update an existing agent
@ -682,6 +687,7 @@ class RESTClient(AbstractClient):
llm_config=llm_config,
embedding_config=embedding_config,
message_ids=message_ids,
response_format=response_format,
)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
@ -2425,6 +2431,7 @@ class LocalClient(AbstractClient):
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
response_format: Optional[ResponseFormatUnion] = None,
):
"""
Update an existing agent
@ -2458,6 +2465,7 @@ class LocalClient(AbstractClient):
llm_config=llm_config,
embedding_config=embedding_config,
message_ids=message_ids,
response_format=response_format,
),
actor=self.user,
)
@ -2661,7 +2669,7 @@ class LocalClient(AbstractClient):
response (LettaResponse): Response from the agent
"""
self.interface.clear()
usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages)
usage = self.server.send_messages(actor=self.user, agent_id=agent_id, input_messages=messages)
# format messages
return LettaResponse(messages=messages, usage=usage)
@ -2703,7 +2711,7 @@ class LocalClient(AbstractClient):
usage = self.server.send_messages(
actor=self.user,
agent_id=agent_id,
messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
input_messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
)
## TODO: need to make sure date/timestamp is properly passed

View File

@ -47,13 +47,14 @@ DEFAULT_PERSONA = "sam_pov"
DEFAULT_HUMAN = "basic"
DEFAULT_PRESET = "memgpt_chat"
SEND_MESSAGE_TOOL_NAME = "send_message"
# Base tools that cannot be edited, as they access agent state directly
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
BASE_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_insert", "archival_memory_search"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Base tools if the memgpt agent has enable_sleeptime on
BASE_SLEEPTIME_CHAT_TOOLS = ["send_message", "conversation_search", "archival_memory_search"]
BASE_SLEEPTIME_CHAT_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_search"]
# Base memory tools for sleeptime agent
BASE_SLEEPTIME_TOOLS = [
"memory_replace",
@ -72,7 +73,7 @@ LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_S
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message the user (send_imessage, send_discord_message, ...)
# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent)
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL = SEND_MESSAGE_TOOL_NAME
DEFAULT_MESSAGE_TOOL_KWARG = "message"
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"

View File

@ -9,7 +9,6 @@ from letta.functions.helpers import (
extract_send_message_from_steps_messages,
fire_and_forget_send_to_agent,
)
from letta.helpers.message_helper import prepare_input_message_create
from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate
from letta.server.rest_api.utils import get_letta_server
@ -109,11 +108,10 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
# Prepare the message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
input_messages = [prepare_input_message_create(m, agent_id) for m in messages]
# Run .step() and return the response
usage_stats = agent.step(
messages=input_messages,
input_messages=messages,
chaining=True,
max_chaining_steps=None,
stream=False,

View File

@ -352,7 +352,7 @@ async def send_message_to_agent_no_stream(
server: "SyncServer",
agent_id: str,
actor: User,
messages: Union[List[Message], List[MessageCreate]],
messages: List[MessageCreate],
metadata: Optional[dict] = None,
) -> LettaResponse:
"""
@ -368,7 +368,7 @@ async def send_message_to_agent_no_stream(
server.send_messages,
actor=actor,
agent_id=agent_id,
messages=messages,
input_messages=messages,
interface=interface,
metadata=metadata,
)
@ -478,7 +478,7 @@ def fire_and_forget_send_to_agent(
await server.send_message_to_agent(
agent_id=other_agent_id,
actor=sender_agent.user,
messages=messages,
input_messages=messages,
stream_steps=False,
stream_tokens=False,
use_assistant_message=True,

View File

@ -35,7 +35,7 @@ class DynamicMultiAgent(Agent):
def step(
self,
messages: List[MessageCreate],
input_messages: List[MessageCreate],
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
put_inner_thoughts_first: bool = True,
@ -43,27 +43,43 @@ class DynamicMultiAgent(Agent):
) -> LettaUsageStatistics:
total_usage = UsageStatistics()
step_count = 0
speaker_id = None
# Load settings
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
agents = {}
# Load agents and initialize chat history with indexing
agents = {self.agent_state.id: self.load_manager_agent()}
message_index = {self.agent_state.id: 0}
agents[self.agent_state.id] = self.load_manager_agent()
chat_history: List[MessageCreate] = []
for agent_id in self.agent_ids:
agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
message_index[agent_id] = 0
chat_history: List[Message] = []
new_messages = messages
speaker_id = None
# Prepare new messages
new_messages = []
for message in input_messages:
if isinstance(message.content, str):
message.content = [TextContent(text=message.content)]
message.group_id = self.group_id
new_messages.append(message)
try:
for _ in range(self.max_turns):
# Prepare manager message
agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id]
manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options)
manager_message = self.ask_manager_to_choose_participant_message(
manager_agent_id=self.agent_state.id,
new_messages=new_messages,
chat_history=chat_history,
agent_id_options=agent_id_options,
)
# Perform manager step
manager_agent = agents[self.agent_state.id]
usage_stats = manager_agent.step(
messages=[manager_message],
input_messages=[manager_message],
chaining=chaining,
max_chaining_steps=max_chaining_steps,
stream=token_streaming,
@ -71,42 +87,27 @@ class DynamicMultiAgent(Agent):
metadata=metadata,
put_inner_thoughts_first=put_inner_thoughts_first,
)
# Parse manager response
responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
if name.lower() in assistant_message.content.lower():
speaker_id = agent_id
# sum usage
# Sum usage
total_usage.prompt_tokens += usage_stats.prompt_tokens
total_usage.completion_tokens += usage_stats.completion_tokens
total_usage.total_tokens += usage_stats.total_tokens
step_count += 1
# initialize input messages
for message in chat_history[message_index[speaker_id] :]:
message.id = Message.generate_id()
message.agent_id = speaker_id
# Update chat history
chat_history.extend(new_messages)
for message in new_messages:
chat_history.append(
Message(
agent_id=speaker_id,
role=message.role,
content=[TextContent(text=message.content)],
name=message.name,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
)
# load agent and perform step
# Perform participant step
participant_agent = agents[speaker_id]
usage_stats = participant_agent.step(
messages=chat_history[message_index[speaker_id] :],
input_messages=chat_history[message_index[speaker_id] :],
chaining=chaining,
max_chaining_steps=max_chaining_steps,
stream=token_streaming,
@ -115,54 +116,54 @@ class DynamicMultiAgent(Agent):
put_inner_thoughts_first=put_inner_thoughts_first,
)
# parse new messages for next step
# Parse participant response
responses = Message.to_letta_messages_from_list(
participant_agent.last_response_messages,
)
assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
new_messages = [
MessageCreate(
role="system",
content=message.content,
content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
name=participant_agent.agent_state.name,
otid=message.otid,
sender_id=participant_agent.agent_state.id,
group_id=self.group_id,
)
for message in assistant_messages
]
# Update message index
message_index[speaker_id] = len(chat_history) + len(new_messages)
# sum usage
# Sum usage
total_usage.prompt_tokens += usage_stats.prompt_tokens
total_usage.completion_tokens += usage_stats.completion_tokens
total_usage.total_tokens += usage_stats.total_tokens
step_count += 1
# check for termination token
# Check for termination token
if any(self.termination_token in message.content for message in new_messages):
break
# persist remaining chat history
for message in new_messages:
chat_history.append(
Message(
agent_id=agent_id,
role=message.role,
content=[TextContent(text=message.content)],
name=message.name,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
)
)
# Persist remaining chat history
chat_history.extend(new_messages)
for agent_id, index in message_index.items():
if agent_id == speaker_id:
continue
messages_to_persist = []
for message in chat_history[index:]:
message.id = Message.generate_id()
message.agent_id = agent_id
self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
message_to_persist = Message(
role=message.role,
content=message.content,
name=message.name,
otid=message.otid,
sender_id=message.sender_id,
group_id=message.group_id,
agent_id=agent_id,
)
messages_to_persist.append(message_to_persist)
self.message_manager.create_many_messages(messages_to_persist, actor=self.user)
except Exception as e:
raise e
@ -249,10 +250,11 @@ class DynamicMultiAgent(Agent):
def ask_manager_to_choose_participant_message(
self,
manager_agent_id: str,
new_messages: List[MessageCreate],
chat_history: List[Message],
agent_id_options: List[str],
) -> Message:
) -> MessageCreate:
text_chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
for message in new_messages:
text_chat_history.append(f"{message.name or 'user'}: {message.content}")
@ -264,14 +266,11 @@ class DynamicMultiAgent(Agent):
"respond to the messages yourself, your task is only to decide the "
f"next speaker, not to participate. \nChat history:\n{context_messages}"
)
return Message(
agent_id=self.agent_state.id,
return MessageCreate(
role="user",
content=[TextContent(text=message_text)],
name=None,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=Message.generate_otid(),
sender_id=manager_agent_id,
group_id=self.group_id,
)

View File

@ -29,7 +29,7 @@ class RoundRobinMultiAgent(Agent):
def step(
self,
messages: List[MessageCreate],
input_messages: List[MessageCreate],
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
put_inner_thoughts_first: bool = True,
@ -37,46 +37,39 @@ class RoundRobinMultiAgent(Agent):
) -> LettaUsageStatistics:
total_usage = UsageStatistics()
step_count = 0
speaker_id = None
# Load settings
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
agents = {}
# Load agents and initialize chat history with indexing
agents, message_index = {}, {}
chat_history: List[MessageCreate] = []
for agent_id in self.agent_ids:
agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
message_index[agent_id] = 0
# Prepare new messages
new_messages = []
for message in input_messages:
if isinstance(message.content, str):
message.content = [TextContent(text=message.content)]
message.group_id = self.group_id
new_messages.append(message)
message_index = {agent_id: 0 for agent_id in self.agent_ids}
chat_history: List[Message] = []
new_messages = messages
speaker_id = None
try:
for i in range(self.max_turns):
# Select speaker
speaker_id = self.agent_ids[i % len(self.agent_ids)]
# initialize input messages
start_index = message_index[speaker_id] if speaker_id in message_index else 0
for message in chat_history[start_index:]:
message.id = Message.generate_id()
message.agent_id = speaker_id
for message in new_messages:
chat_history.append(
Message(
agent_id=speaker_id,
role=message.role,
content=[TextContent(text=message.content)],
name=message.name,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
)
# Update chat history
chat_history.extend(new_messages)
# load agent and perform step
# Perform participant step
participant_agent = agents[speaker_id]
usage_stats = participant_agent.step(
messages=chat_history[start_index:],
input_messages=chat_history[message_index[speaker_id] :],
chaining=chaining,
max_chaining_steps=max_chaining_steps,
stream=token_streaming,
@ -85,47 +78,48 @@ class RoundRobinMultiAgent(Agent):
put_inner_thoughts_first=put_inner_thoughts_first,
)
# parse new messages for next step
# Parse participant response
responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages)
assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
new_messages = [
MessageCreate(
role="system",
content=message.content,
name=message.name,
content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
name=participant_agent.agent_state.name,
otid=message.otid,
sender_id=participant_agent.agent_state.id,
group_id=self.group_id,
)
for message in assistant_messages
]
# Update message index
message_index[speaker_id] = len(chat_history) + len(new_messages)
# sum usage
# Sum usage
total_usage.prompt_tokens += usage_stats.prompt_tokens
total_usage.completion_tokens += usage_stats.completion_tokens
total_usage.total_tokens += usage_stats.total_tokens
step_count += 1
# persist remaining chat history
for message in new_messages:
chat_history.append(
Message(
agent_id=agent_id,
role=message.role,
content=[TextContent(text=message.content)],
name=message.name,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
)
)
# Persist remaining chat history
chat_history.extend(new_messages)
for agent_id, index in message_index.items():
if agent_id == speaker_id:
continue
messages_to_persist = []
for message in chat_history[index:]:
message.id = Message.generate_id()
message.agent_id = agent_id
self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
message_to_persist = Message(
role=message.role,
content=message.content,
name=message.name,
otid=message.otid,
sender_id=message.sender_id,
group_id=self.group_id,
agent_id=agent_id,
)
messages_to_persist.append(message_to_persist)
self.message_manager.create_many_messages(messages_to_persist, actor=self.user)
except Exception as e:
raise e

View File

@ -143,8 +143,21 @@ class SleeptimeMultiAgent(Agent):
group_id=self.group_id,
)
]
# Convert Message objects to MessageCreate objects
message_creates = [
MessageCreate(
role=m.role,
content=m.content[0].text if m.content and len(m.content) == 1 else m.content,
name=m.name,
otid=m.otid,
sender_id=m.sender_id,
)
for m in participant_agent_messages
]
result = participant_agent.step(
messages=participant_agent_messages,
input_messages=message_creates,
chaining=chaining,
max_chaining_steps=max_chaining_steps,
stream=token_streaming,
@ -173,7 +186,7 @@ class SleeptimeMultiAgent(Agent):
def step(
self,
messages: List[MessageCreate],
input_messages: List[MessageCreate],
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
put_inner_thoughts_first: bool = True,
@ -181,33 +194,28 @@ class SleeptimeMultiAgent(Agent):
) -> LettaUsageStatistics:
run_ids = []
# Load settings
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
messages = [
Message(
id=Message.generate_id(),
agent_id=self.agent_state.id,
role=message.role,
content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
name=message.name,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
for message in messages
]
# Prepare new messages
new_messages = []
for message in input_messages:
if isinstance(message.content, str):
message.content = [TextContent(text=message.content)]
message.group_id = self.group_id
new_messages.append(message)
try:
# Load main agent
main_agent = Agent(
agent_state=self.agent_state,
interface=self.interface,
user=self.user,
)
# Perform main agent step
usage_stats = main_agent.step(
messages=messages,
input_messages=new_messages,
chaining=chaining,
max_chaining_steps=max_chaining_steps,
stream=token_streaming,
@ -216,10 +224,12 @@ class SleeptimeMultiAgent(Agent):
put_inner_thoughts_first=put_inner_thoughts_first,
)
# Update turns counter
turns_counter = None
if self.sleeptime_agent_frequency is not None and self.sleeptime_agent_frequency > 0:
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user)
# Perform participant steps
if self.sleeptime_agent_frequency is None or (
turns_counter is not None and turns_counter % self.sleeptime_agent_frequency == 0
):

View File

@ -9,7 +9,7 @@ from letta.interface import AgentInterface
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
@ -37,17 +37,18 @@ class SupervisorMultiAgent(Agent):
def step(
self,
messages: List[MessageCreate],
input_messages: List[MessageCreate],
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
put_inner_thoughts_first: bool = True,
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
**kwargs,
) -> LettaUsageStatistics:
# Load settings
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
# add multi agent tool
# Prepare supervisor agent
if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
multi_agent_tool = Tool(
name=send_message_to_all_agents_in_group.__name__,
@ -64,7 +65,6 @@ class SupervisorMultiAgent(Agent):
)
self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
# override tool rules
old_tool_rules = self.agent_state.tool_rules
self.agent_state.tool_rules = [
InitToolRule(
@ -79,24 +79,25 @@ class SupervisorMultiAgent(Agent):
),
]
supervisor_messages = [
Message(
agent_id=self.agent_state.id,
role="user",
content=[TextContent(text=message.content)],
name=None,
model=None,
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
for message in messages
]
# Prepare new messages
new_messages = []
for message in input_messages:
if isinstance(message.content, str):
message.content = [TextContent(text=message.content)]
message.group_id = self.group_id
new_messages.append(message)
try:
supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
# Load supervisor agent
supervisor_agent = Agent(
agent_state=self.agent_state,
interface=self.interface,
user=self.user,
)
# Perform supervisor step
usage_stats = supervisor_agent.step(
messages=supervisor_messages,
input_messages=new_messages,
chaining=chaining,
max_chaining_steps=max_chaining_steps,
stream=token_streaming,

View File

@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import (
)
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import ToolReturn
from letta.schemas.response_format import (
JsonObjectResponseFormat,
JsonSchemaResponseFormat,
ResponseFormatType,
ResponseFormatUnion,
TextResponseFormat,
)
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
return None
return AgentStepState(**data)
# --------------------------
# Response Format Serialization
# --------------------------
def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]:
if not response_format:
return None
return response_format.model_dump(mode="json")
def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]:
if not data:
return None
if data["type"] == ResponseFormatType.text:
return TextResponseFormat(**data)
if data["type"] == ResponseFormatType.json_schema:
return JsonSchemaResponseFormat(**data)
if data["type"] == ResponseFormatType.json_object:
return JsonObjectResponseFormat(**data)
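A quick round-trip sketch (not part of this commit) of how the ORM column uses these helpers; the schema contents are invented for illustration:
from letta.helpers.converters import deserialize_response_format, serialize_response_format
from letta.schemas.response_format import JsonSchemaResponseFormat
fmt = JsonSchemaResponseFormat(
    json_schema={"name": "mood_report", "schema": {"type": "object", "properties": {}}},
)
stored = serialize_response_format(fmt)         # plain dict, safe to write to the JSON column
restored = deserialize_response_format(stored)  # rebuilt as a JsonSchemaResponseFormat
# The round trip preserves the discriminator and the schema payload.
assert restored.type == fmt.type and restored.json_schema == fmt.json_schema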

View File

@ -40,4 +40,5 @@ def prepare_input_message_create(
tool_call_id=None,
otid=message.otid,
sender_id=message.sender_id,
group_id=message.group_id,
)

View File

@ -160,12 +160,12 @@ def execute_external_tool(
else:
agent_state_copy = None
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
tool_execution_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
function_response, updated_agent_state = tool_execution_result.func_return, tool_execution_result.agent_state
# TODO: Bring this back
# if allow_agent_state_modifications and updated_agent_state is not None:
# self.update_memory_if_changed(updated_agent_state.memory)
return function_response, sandbox_run_result
return function_response, tool_execution_result
except Exception as e:
# Need to catch error here, or else truncation won't happen
# TODO: modify to function execution error

View File

@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
from letta.orm.identity import Identity
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.tool_rule import ToolRule
if TYPE_CHECKING:
@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# This is dangerously flexible with the JSON type
message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")
# Response Format
response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column(
ResponseFormatColumn, nullable=True, doc="The response format for the agent."
)
# Metadata and configs
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"multi_agent_group": None,
"tool_exec_environment_variables": [],
"enable_sleeptime": None,
"response_format": self.response_format,
}
# Optional fields: only included if requested

View File

@ -9,6 +9,7 @@ from letta.helpers.converters import (
deserialize_llm_config,
deserialize_message_content,
deserialize_poll_batch_response,
deserialize_response_format,
deserialize_tool_calls,
deserialize_tool_returns,
deserialize_tool_rules,
@ -20,6 +21,7 @@ from letta.helpers.converters import (
serialize_llm_config,
serialize_message_content,
serialize_poll_batch_response,
serialize_response_format,
serialize_tool_calls,
serialize_tool_returns,
serialize_tool_rules,
@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator):
def process_result_value(self, value, dialect):
return deserialize_agent_step_state(value)
class ResponseFormatColumn(TypeDecorator):
"""Custom SQLAlchemy column type for storing a list of ToolRules as JSON."""
impl = JSON
cache_ok = True
def process_bind_param(self, value, dialect):
return serialize_response_format(value)
def process_result_value(self, value, dialect):
return deserialize_response_format(value)

View File

@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ToolRule
@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
response_format: Optional[ResponseFormatUnion] = Field(
None, description="The response format used by the agent when returning from `send_message`."
)
# This is an object representing the in-process state of a running `Agent`
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
)
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
@field_validator("name")
@classmethod
@ -259,6 +264,7 @@ class UpdateAgent(BaseModel):
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
)
enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")
class Config:
extra = "ignore" # Ignores extra fields

View File

@ -82,6 +82,7 @@ class MessageCreate(BaseModel):
name: Optional[str] = Field(None, description="The name of the participant.")
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)

View File

@ -0,0 +1,78 @@
from enum import Enum
from typing import Annotated, Any, Dict, Literal, Union
from pydantic import BaseModel, Field, validator
class ResponseFormatType(str, Enum):
"""Enum defining the possible response format types."""
text = "text"
json_schema = "json_schema"
json_object = "json_object"
class ResponseFormat(BaseModel):
"""Base class for all response formats."""
type: ResponseFormatType = Field(
...,
description="The type of the response format.",
# why use this?
example=ResponseFormatType.text,
)
# ---------------------
# Response Format Types
# ---------------------
# SQLAlchemy type for database mapping
ResponseFormatDict = Dict[str, Any]
class TextResponseFormat(ResponseFormat):
"""Response format for plain text responses."""
type: Literal[ResponseFormatType.text] = Field(
ResponseFormatType.text,
description="The type of the response format.",
)
class JsonSchemaResponseFormat(ResponseFormat):
"""Response format for JSON schema-based responses."""
type: Literal[ResponseFormatType.json_schema] = Field(
ResponseFormatType.json_schema,
description="The type of the response format.",
)
json_schema: Dict[str, Any] = Field(
...,
description="The JSON schema of the response.",
)
@validator("json_schema")
def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that the provided schema is a valid JSON schema."""
if not isinstance(v, dict):
raise ValueError("JSON schema must be a dictionary")
if "schema" not in v:
raise ValueError("JSON schema should include a $schema property")
return v
class JsonObjectResponseFormat(ResponseFormat):
"""Response format for JSON object responses."""
type: Literal[ResponseFormatType.json_object] = Field(
ResponseFormatType.json_object,
description="The type of the response format.",
)
# Pydantic type for validation
ResponseFormatUnion = Annotated[
Union[TextResponseFormat | JsonSchemaResponseFormat | JsonObjectResponseFormat],
Field(discriminator="type"),
]
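As a usage sketch (not part of this commit), the union can be constructed directly and passed through the `response_format` parameter this release adds to agent create/update requests; the `task_summary` schema is invented for illustration, and the client call is shown as a comment since it assumes an already-configured client:
from letta.schemas.response_format import JsonSchemaResponseFormat, TextResponseFormat
# Default behaviour: send_message returns plain text.
text_format = TextResponseFormat()
# Constrain send_message to a structured payload.
structured_format = JsonSchemaResponseFormat(
    json_schema={
        "name": "task_summary",
        "schema": {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "done": {"type": "boolean"},
            },
            "required": ["title", "done"],
        },
    },
)
# `response_format` is the new optional field on create/update agent requests, e.g.:
# agent_state = client.update_agent(agent_id=agent_state.id, response_format=structured_format)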

View File

@ -0,0 +1,14 @@
from typing import Any, List, Literal, Optional
from pydantic import BaseModel, Field
from letta.schemas.agent import AgentState
class ToolExecutionResult(BaseModel):
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
func_return: Optional[Any] = Field(None, description="The function return object")
agent_state: Optional[AgentState] = Field(None, description="The agent state")
stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
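For orientation (not part of this commit), a minimal sketch of how the error path in `execute_tool_and_persist_state` now packages failures into this object instead of a `(response, SandboxRunResult)` tuple; the error message is a stand-in:
import traceback
from letta.schemas.tool_execution_result import ToolExecutionResult
try:
    raise RuntimeError("tool blew up")  # stand-in for a failing tool invocation
except Exception:
    result = ToolExecutionResult(
        status="error",
        func_return="Error executing tool",  # friendly message surfaced to the LLM
        stderr=[traceback.format_exc()],
    )
# Callers now branch on `result.status` rather than unpacking a tuple.
assert result.status == "error"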

View File

@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
and function_call.function.name == self.assistant_message_tool_name
and self.assistant_message_tool_kwarg in func_args
):
# Coerce content to `str` in cases where it's JSON because a JSON `response_format` is set
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
content=func_args[self.assistant_message_tool_kwarg],
content=str(func_args[self.assistant_message_tool_kwarg]),
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
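The str(...) coercion matters because, when the agent has a JSON response_format, the parsed assistant-message argument can arrive as a dict rather than a string; a rough sketch (the "message" key stands in for assistant_message_tool_kwarg):

func_args = {"message": {"answer": "4", "reasoning": "2 + 2"}}  # parsed from JSON tool arguments
content = str(func_args["message"])  # what the patched line now does
assert isinstance(content, str)
# json.dumps(func_args["message"]) would be a stricter alternative, at the cost of different quoting.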

View File

@ -111,7 +111,7 @@ async def send_message_to_agent_chat_completions(
server.send_messages,
actor=actor,
agent_id=letta_agent.agent_state.id,
messages=messages,
input_messages=messages,
interface=streaming_interface,
put_inner_thoughts_first=False,
)

View File

@ -412,7 +412,7 @@ def list_blocks(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
return agent.memory.blocks
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@ -640,7 +640,7 @@ async def send_message(
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=request.messages,
input_messages=request.messages,
stream_steps=False,
stream_tokens=False,
# Support for AssistantMessage
@ -703,7 +703,7 @@ async def send_message_streaming(
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=request.messages,
input_messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage
@ -730,7 +730,7 @@ async def process_message_background(
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=messages,
input_messages=messages,
stream_steps=False, # NOTE(matt)
stream_tokens=False,
use_assistant_message=use_assistant_message,

View File

@ -128,7 +128,7 @@ async def send_group_message(
result = await server.send_group_message_to_agent(
group_id=group_id,
actor=actor,
messages=request.messages,
input_messages=request.messages,
stream_steps=False,
stream_tokens=False,
# Support for AssistantMessage
@ -167,7 +167,7 @@ async def send_group_message_streaming(
result = await server.send_group_message_to_agent(
group_id=group_id,
actor=actor,
messages=request.messages,
input_messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage

View File

@ -7,7 +7,7 @@ from starlette.requests import Request
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.job import BatchJob, JobStatus, JobType
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
from letta.schemas.letta_request import CreateBatch
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@ -43,18 +43,18 @@ async def create_messages_batch(
if length > max_bytes:
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
batch_job = BatchJob(
user_id=actor.id,
status=JobStatus.running,
metadata={
"job_type": "batch_messages",
},
callback_url=str(payload.callback_url),
)
# Create a new job
batch_job = BatchJob(
user_id=actor.id,
status=JobStatus.created,
metadata={
"job_type": "batch_messages",
},
callback_url=str(payload.callback_url),
)
try:
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
# create the batch runner
batch_runner = LettaAgentBatch(
@ -67,14 +67,17 @@ async def create_messages_batch(
job_manager=server.job_manager,
actor=actor,
)
llm_batch_job = await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
# TODO: update run metadata
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
except Exception:
except Exception as e:
import traceback
print("Error creating batch job", e)
traceback.print_exc()
# mark job as failed
server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
raise
return batch_job
@ -125,8 +128,19 @@ async def cancel_batch_run(
try:
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
job.status = JobStatus.cancelled
server.job_manager.update_job_by_id(job_id=job, job=job)
# TODO: actually cancel it
job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
# Get related llm batch jobs
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
for llm_batch_job in llm_batch_jobs:
if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
# TODO: Extend to providers beyond anthropic
# TODO: For now, we only support anthropic
# Cancel the job
anthropic_batch_id = llm_batch_job.create_batch_response.id
await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
# Update all the batch_job statuses
server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail="Run not found")

View File

@ -28,7 +28,6 @@ from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerCo
from letta.groups.helpers import load_multi_agent
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
# TODO use custom interface
from letta.interface import AgentInterface # abstract
@ -148,7 +147,7 @@ class Server(object):
raise NotImplementedError
@abstractmethod
def send_messages(self, user_id: str, agent_id: str, messages: Union[MessageCreate, List[Message]]) -> None:
def send_messages(self, user_id: str, agent_id: str, input_messages: List[MessageCreate]) -> None:
"""Send a list of messages to the agent"""
raise NotImplementedError
@ -372,19 +371,13 @@ class SyncServer(Server):
self,
actor: User,
agent_id: str,
input_messages: Union[Message, List[Message]],
input_messages: List[MessageCreate],
interface: Union[AgentInterface, None] = None, # needed to getting responses
put_inner_thoughts_first: bool = True,
# timestamp: Optional[datetime],
) -> LettaUsageStatistics:
"""Send the input message through the agent"""
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
# Input validation
if isinstance(input_messages, Message):
input_messages = [input_messages]
if not all(isinstance(m, Message) for m in input_messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
logger.debug(f"Got input messages: {input_messages}")
letta_agent = None
try:
@ -400,8 +393,9 @@ class SyncServer(Server):
metadata = interface.metadata if hasattr(interface, "metadata") else None
else:
metadata = None
usage_stats = letta_agent.step(
messages=input_messages,
input_messages=input_messages,
chaining=self.chaining,
max_chaining_steps=self.max_chaining_steps,
stream=token_streaming,
@ -572,23 +566,14 @@ class SyncServer(Server):
)
# NOTE: eventually deprecate and only allow passing Message types
# Convert to a Message object
if timestamp:
message = Message(
agent_id=agent_id,
role="user",
content=[TextContent(text=packaged_user_message)],
created_at=timestamp,
)
else:
message = Message(
agent_id=agent_id,
role="user",
content=[TextContent(text=packaged_user_message)],
)
message = MessageCreate(
agent_id=agent_id,
role="user",
content=[TextContent(text=packaged_user_message)],
)
# Run the agent state forward
usage = self._step(actor=actor, agent_id=agent_id, input_messages=message)
usage = self._step(actor=actor, agent_id=agent_id, input_messages=[message])
return usage
def system_message(
@ -660,23 +645,14 @@ class SyncServer(Server):
self,
actor: User,
agent_id: str,
messages: Union[List[MessageCreate], List[Message]],
input_messages: List[MessageCreate],
wrap_user_message: bool = True,
wrap_system_message: bool = True,
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for responses
metadata: Optional[dict] = None, # Pass through metadata to interface
put_inner_thoughts_first: bool = True,
) -> LettaUsageStatistics:
"""Send a list of messages to the agent.
If messages are of type MessageCreate, convert them to Message objects before sending.
"""
if all(isinstance(m, MessageCreate) for m in messages):
message_objects = [prepare_input_message_create(m, agent_id, wrap_user_message, wrap_system_message) for m in messages]
elif all(isinstance(m, Message) for m in messages):
message_objects = messages
else:
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(m) for m in messages]}")
"""Send a list of messages to the agent."""
# Store metadata in interface if provided
if metadata and hasattr(interface, "metadata"):
@ -686,7 +662,7 @@ class SyncServer(Server):
return self._step(
actor=actor,
agent_id=agent_id,
input_messages=message_objects,
input_messages=input_messages,
interface=interface,
put_inner_thoughts_first=put_inner_thoughts_first,
)
@ -703,8 +679,6 @@ class SyncServer(Server):
@trace_method
def get_cached_llm_config(self, **kwargs):
key = make_key(**kwargs)
print(self._llm_config_cache)
print("KEY", key)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
return self._llm_config_cache[key]
@ -1019,12 +993,8 @@ class SyncServer(Server):
agent = self.load_agent(agent_id=sleeptime_agent.id, actor=actor)
for passage in self.list_data_source_passages(source_id=source.id, user_id=actor.id):
agent.step(
messages=[
Message(
role="user",
content=[TextContent(text=passage.text)],
agent_id=sleeptime_agent.id,
),
input_messages=[
MessageCreate(role="user", content=passage.text),
]
)
self.agent_manager.delete_agent(agent_id=sleeptime_agent.id, actor=actor)
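Both content shapes used in this diff are accepted by MessageCreate: a plain string, as in the sleeptime loop above, or a list of content parts, as in user_message earlier; for example:

plain = MessageCreate(role="user", content="summarize this passage")
parts = MessageCreate(role="user", content=[TextContent(text="summarize this passage")])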
@ -1182,7 +1152,6 @@ class SyncServer(Server):
provider = self.get_provider_from_name(provider_name)
llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
print("LLM CONFIGS", llm_configs)
if not llm_configs:
llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
if not llm_configs:
@ -1195,8 +1164,6 @@ class SyncServer(Server):
if not llm_configs:
raise e
print("CONFIGS", llm_configs)
if len(llm_configs) == 1:
llm_config = llm_configs[0]
elif len(llm_configs) > 1:
@ -1343,17 +1310,17 @@ class SyncServer(Server):
# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
agent_state=agent_state, additional_env_vars=tool_env_vars
)
return ToolReturnMessage(
id="null",
tool_call_id="null",
date=get_utc_time(),
status=sandbox_run_result.status,
tool_return=str(sandbox_run_result.func_return),
stdout=sandbox_run_result.stdout,
stderr=sandbox_run_result.stderr,
status=tool_execution_result.status,
tool_return=str(tool_execution_result.func_return),
stdout=tool_execution_result.stdout,
stderr=tool_execution_result.stderr,
)
except Exception as e:
@ -1567,7 +1534,7 @@ class SyncServer(Server):
agent_id: str,
actor: User,
# role: MessageRole,
messages: Union[List[Message], List[MessageCreate]],
input_messages: List[MessageCreate],
stream_steps: bool,
stream_tokens: bool,
# related to whether or not we return `LettaMessage`s or `Message`s
@ -1647,7 +1614,7 @@ class SyncServer(Server):
self.send_messages,
actor=actor,
agent_id=agent_id,
messages=messages,
input_messages=input_messages,
interface=streaming_interface,
metadata=metadata,
)
@ -1701,7 +1668,7 @@ class SyncServer(Server):
self,
group_id: str,
actor: User,
messages: Union[List[Message], List[MessageCreate]],
input_messages: Union[List[Message], List[MessageCreate]],
stream_steps: bool,
stream_tokens: bool,
chat_completion_mode: bool = False,
@ -1751,7 +1718,7 @@ class SyncServer(Server):
task = asyncio.create_task(
asyncio.to_thread(
letta_multi_agent.step,
messages=messages,
input_messages=input_messages,
chaining=self.chaining,
max_chaining_steps=self.max_chaining_steps,
)

View File

@ -364,6 +364,7 @@ class AgentManager:
"base_template_id": agent_update.base_template_id,
"message_buffer_autoclear": agent_update.message_buffer_autoclear,
"enable_sleeptime": agent_update.enable_sleeptime,
"response_format": agent_update.response_format,
}
for col, val in scalar_updates.items():
if val is not None:
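With response_format now part of the scalar update path, it can be set from a client the same way the commented-out tests at the end of this diff do; a minimal sketch (the agent id is a placeholder):

from letta_client import Letta

client = Letta(base_url="http://localhost:8283")
agent_id = "agent-..."  # placeholder
client.agents.modify(agent_id, response_format={"type": "json_object"})

Note that because the loop above skips None values, response_format cannot be cleared back to None through this path.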

View File

@ -291,9 +291,7 @@ class LLMBatchManager:
return [item.to_pydantic() for item in results]
def bulk_update_llm_batch_items(
self,
llm_batch_id_agent_id_pairs: List[Tuple[str, str]],
field_updates: List[Dict[str, Any]],
self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True
) -> None:
"""
Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs.
@ -301,30 +299,43 @@ class LLMBatchManager:
Args:
llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update
field_updates: List of dictionaries containing the fields to update for each item
strict: Whether to error if any of the requested keys don't exist (default True).
If False, missing pairs are skipped.
"""
if not llm_batch_id_agent_id_pairs or not field_updates:
return
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length")
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
with self.session_maker() as session:
# Lookup primary keys
# Lookup primary keys for all requested (batch_id, agent_id) pairs
items = (
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
.all()
)
pair_to_pk = {(b, a): id for id, b, a in items}
pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items}
if strict:
requested = set(llm_batch_id_agent_id_pairs)
found = set(pair_to_pk.keys())
missing = requested - found
if missing:
raise ValueError(
f"Cannot bulk-update batch items: no records for the following " f"(llm_batch_id, agent_id) pairs: {missing}"
)
# Build mappings, skipping any missing when strict=False
mappings = []
for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
pk_id = pair_to_pk.get((llm_batch_id, agent_id))
if not pk_id:
for (batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
pk = pair_to_pk.get((batch_id, agent_id))
if pk is None:
# skip missing in non-strict mode
continue
update_fields = fields.copy()
update_fields["id"] = pk_id
update_fields["id"] = pk
mappings.append(update_fields)
if mappings:
@ -332,10 +343,7 @@ class LLMBatchManager:
session.commit()
@enforce_types
def bulk_update_batch_llm_items_results_by_agent(
self,
updates: List[ItemUpdateInfo],
) -> None:
def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
"""Update request status and batch results for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [
@ -346,29 +354,23 @@ class LLMBatchManager:
for update in updates
]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def bulk_update_llm_batch_items_step_status_by_agent(
self,
updates: List[StepStatusUpdateInfo],
) -> None:
def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None:
"""Update step status for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [{"step_status": update.step_status} for update in updates]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def bulk_update_llm_batch_items_request_status_by_agent(
self,
updates: List[RequestStatusUpdateInfo],
) -> None:
def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None:
"""Update request status for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [{"request_status": update.request_status} for update in updates]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
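A behavior sketch for the new strict flag (identifiers and field values are illustrative): with strict=True a missing (llm_batch_id, agent_id) pair raises ValueError before any update is applied; with strict=False that pair is skipped and the remaining rows are updated.

manager = LLMBatchManager()  # assumed no-arg construction, mirroring AgentManager() elsewhere
manager.bulk_update_llm_batch_items(
    llm_batch_id_agent_id_pairs=[("batch-1", "agent-1"), ("batch-1", "agent-gone")],
    field_updates=[{"step_status": "completed"}, {"step_status": "completed"}],
    strict=False,  # the pair with no matching row is skipped instead of raising
)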
@enforce_types
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:

View File

@ -1,16 +1,17 @@
from typing import Any, Dict, Optional, Tuple, Type
import traceback
from typing import Any, Dict, Optional, Type
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.tool_executor.tool_executor import (
ExternalComposioToolExecutor,
ExternalMCPToolExecutor,
LettaCoreToolExecutor,
LettaMemoryToolExecutor,
LettaMultiAgentToolExecutor,
SandboxToolExecutor,
ToolExecutor,
@ -24,8 +25,9 @@ class ToolExecutorFactory:
_executor_map: Dict[ToolType, Type[ToolExecutor]] = {
ToolType.LETTA_CORE: LettaCoreToolExecutor,
ToolType.LETTA_MEMORY_CORE: LettaCoreToolExecutor,
ToolType.LETTA_SLEEPTIME_CORE: LettaCoreToolExecutor,
ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor,
ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
}
@ -33,13 +35,8 @@ class ToolExecutorFactory:
@classmethod
def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
"""Get the appropriate executor for the given tool type."""
executor_class = cls._executor_map.get(tool_type)
if executor_class:
return executor_class()
# Default to sandbox executor for unknown types
return SandboxToolExecutor()
executor_class = cls._executor_map.get(tool_type, SandboxToolExecutor)
return executor_class()
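Because the lookup now falls back via dict.get's default, any tool type absent from _executor_map (for example ToolType.CUSTOM) resolves to the sandbox executor:

executor = ToolExecutorFactory.get_executor(ToolType.CUSTOM)  # not in _executor_map
assert isinstance(executor, SandboxToolExecutor)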
class ToolExecutionManager:
@ -58,7 +55,7 @@ class ToolExecutionManager:
self.sandbox_config = sandbox_config
self.sandbox_env_vars = sandbox_env_vars
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> ToolExecutionResult:
"""
Execute a tool and persist any state changes.
@ -71,36 +68,17 @@ class ToolExecutionManager:
ToolExecutionResult containing the status, function response, and any captured output
"""
try:
# Get the appropriate executor for this tool type
executor = ToolExecutorFactory.get_executor(tool.tool_type)
# Execute the tool
return executor.execute(
function_name, function_args, self.agent_state, tool, self.actor, self.sandbox_config, self.sandbox_env_vars
function_name,
function_args,
self.agent_state,
tool,
self.actor,
self.sandbox_config,
self.sandbox_env_vars,
)
except Exception as e:
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
return error_message, SandboxRunResult(status="error")
@trace_method
async def execute_tool_async(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
"""
Execute a tool asynchronously and persist any state changes.
"""
try:
# Get the appropriate executor for this tool type
# TODO: Extend this async model to composio
if tool.tool_type == ToolType.CUSTOM:
executor = SandboxToolExecutor()
result_tuple = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
else:
executor = ToolExecutorFactory.get_executor(tool.tool_type)
result_tuple = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
return result_tuple
except Exception as e:
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
error_message = get_friendly_error_msg(
@ -108,4 +86,35 @@ class ToolExecutionManager:
exception_name=type(e).__name__,
exception_message=str(e),
)
return error_message, SandboxRunResult(status="error")
return ToolExecutionResult(
status="error",
func_return=error_message,
stderr=[traceback.format_exc()],
)
@trace_method
async def execute_tool_async(self, function_name: str, function_args: dict, tool: Tool) -> ToolExecutionResult:
"""
Execute a tool asynchronously and persist any state changes.
"""
try:
executor = ToolExecutorFactory.get_executor(tool.tool_type)
# TODO: Extend this async model to composio
if isinstance(executor, SandboxToolExecutor):
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
else:
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
return result
except Exception as e:
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
error_message = get_friendly_error_msg(
function_name=function_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
return ToolExecutionResult(
status="error",
func_return=error_message,
stderr=[traceback.format_exc()],
)
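Callers now receive a single ToolExecutionResult instead of a (response, SandboxRunResult) tuple; a hedged sketch of the call pattern, mirroring the Composio test updated later in this diff (agent_state, user, and tool are assumed to already exist):

manager = ToolExecutionManager(agent_state, actor=user)
result = manager.execute_tool(function_name=tool.name, function_args={}, tool=tool)
if result.status == "error":
    print(result.func_return, result.stderr)
else:
    print(result.func_return)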

View File

@ -13,8 +13,9 @@ from typing import Any, Dict, Optional
from letta.functions.helpers import generate_model_from_args_json_schema
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.helpers.tool_execution_helper import (
add_imports_and_pydantic_schemas_for_args,
@ -72,7 +73,11 @@ class ToolExecutionSandbox:
self.force_recreate = force_recreate
self.force_recreate_venv = force_recreate_venv
def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
def run(
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> ToolExecutionResult:
"""
Run the tool in a sandbox environment.
@ -81,7 +86,7 @@ class ToolExecutionSandbox:
additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox
Returns:
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
ToolExecutionResult: Object containing tool execution outcome (e.g. status, response)
"""
if tool_settings.e2b_api_key and not self.privileged_tools:
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
@ -115,7 +120,7 @@ class ToolExecutionSandbox:
@trace_method
def run_local_dir_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> SandboxRunResult:
) -> ToolExecutionResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config()
@ -162,7 +167,12 @@ class ToolExecutionSandbox:
os.remove(temp_file_path)
@trace_method
def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
def run_local_dir_sandbox_venv(
self,
sbx_config: SandboxConfig,
env: Dict[str, str],
temp_file_path: str,
) -> ToolExecutionResult:
local_configs = sbx_config.get_local_config()
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
@ -205,12 +215,12 @@ class ToolExecutionSandbox:
func_result, stdout = self.parse_out_function_results_markers(result.stdout)
func_return, agent_state = self.parse_best_effort(func_result)
return SandboxRunResult(
return ToolExecutionResult(
status="success",
func_return=func_return,
agent_state=agent_state,
stdout=[stdout] if stdout else [],
stderr=[result.stderr] if result.stderr else [],
status="success",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
@ -221,12 +231,12 @@ class ToolExecutionSandbox:
exception_name=type(e).__name__,
exception_message=str(e),
)
return SandboxRunResult(
return ToolExecutionResult(
status="error",
func_return=func_return,
agent_state=None,
stdout=[e.stdout] if e.stdout else [],
stderr=[e.stderr] if e.stderr else [],
status="error",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
@ -238,7 +248,12 @@ class ToolExecutionSandbox:
raise e
@trace_method
def run_local_dir_sandbox_directly(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
def run_local_dir_sandbox_directly(
self,
sbx_config: SandboxConfig,
env: Dict[str, str],
temp_file_path: str,
) -> ToolExecutionResult:
status = "success"
func_return, agent_state, stderr = None, None, None
@ -288,12 +303,12 @@ class ToolExecutionSandbox:
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
return SandboxRunResult(
return ToolExecutionResult(
status=status,
func_return=func_return,
agent_state=agent_state,
stdout=stdout_output,
stderr=stderr_output,
status=status,
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
@ -307,7 +322,11 @@ class ToolExecutionSandbox:
# e2b sandbox specific functions
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
def run_e2b_sandbox(
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> ToolExecutionResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
if not sbx or self.force_recreate:
@ -348,12 +367,12 @@ class ToolExecutionSandbox:
else:
raise ValueError(f"Tool {self.tool_name} returned execution with None")
return SandboxRunResult(
return ToolExecutionResult(
status="error" if execution.error else "success",
func_return=func_return,
agent_state=agent_state,
stdout=execution.logs.stdout,
stderr=execution.logs.stderr,
status="error" if execution.error else "success",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
@ -535,7 +554,7 @@ class ToolExecutionSandbox:
Generate the code string to call the function.
Args:
inject_agent_state (bool): Whether to inject the axgent's state as an input into the tool
inject_agent_state (bool): Whether to inject the agent's state as an input into the tool
Returns:
str: Generated code string for calling the tool

View File

@ -1,15 +1,17 @@
import math
import traceback
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
@ -33,7 +35,7 @@ class ToolExecutor(ABC):
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
) -> ToolExecutionResult:
"""Execute the tool and return the result."""
@ -49,13 +51,19 @@ class LettaCoreToolExecutor(ToolExecutor):
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
) -> ToolExecutionResult:
# Map function names to method calls
function_map = {
"send_message": self.send_message,
"conversation_search": self.conversation_search,
"archival_memory_search": self.archival_memory_search,
"archival_memory_insert": self.archival_memory_insert,
"core_memory_append": self.core_memory_append,
"core_memory_replace": self.core_memory_replace,
"memory_replace": self.memory_replace,
"memory_insert": self.memory_insert,
"memory_rethink": self.memory_rethink,
"memory_finish_edits": self.memory_finish_edits,
}
if function_name not in function_map:
@ -64,7 +72,10 @@ class LettaCoreToolExecutor(ToolExecutor):
# Execute the appropriate function
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
function_response = function_map[function_name](agent_state, actor, **function_args_copy)
return function_response, None
return ToolExecutionResult(
status="success",
func_return=function_response,
)
def send_message(self, agent_state: AgentState, actor: User, message: str) -> Optional[str]:
"""
@ -181,51 +192,7 @@ class LettaCoreToolExecutor(ToolExecutor):
AgentManager().rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True)
return None
class LettaMultiAgentToolExecutor(ToolExecutor):
"""Executor for LETTA multi-agent core tools."""
# TODO: Implement
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[
# Any, Optional[SandboxRunResult]]:
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
# function_response = callable_func(**function_args)
# return function_response, None
class LettaMemoryToolExecutor(ToolExecutor):
"""Executor for LETTA memory core tools with direct implementation."""
def execute(
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
# Map function names to method calls
function_map = {
"core_memory_append": self.core_memory_append,
"core_memory_replace": self.core_memory_replace,
}
if function_name not in function_map:
raise ValueError(f"Unknown function: {function_name}")
# Execute the appropriate function with the copied state
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
function_response = function_map[function_name](agent_state, **function_args_copy)
# Update memory if changed
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
return function_response, None
def core_memory_append(self, agent_state: "AgentState", label: str, content: str) -> Optional[str]:
def core_memory_append(self, agent_state: "AgentState", actor: User, label: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
@ -239,9 +206,17 @@ class LettaMemoryToolExecutor(ToolExecutor):
current_value = str(agent_state.memory.get_block(label).value)
new_value = current_value + "\n" + str(content)
agent_state.memory.update_block_value(label=label, value=new_value)
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
return None
def core_memory_replace(self, agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]:
def core_memory_replace(
self,
agent_state: "AgentState",
actor: User,
label: str,
old_content: str,
new_content: str,
) -> Optional[str]:
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
@ -258,8 +233,253 @@ class LettaMemoryToolExecutor(ToolExecutor):
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
new_value = current_value.replace(str(old_content), str(new_content))
agent_state.memory.update_block_value(label=label, value=new_value)
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
return None
def memory_replace(
agent_state: "AgentState",
actor: User,
label: str,
old_str: str,
new_str: Optional[str] = None,
) -> str:
"""
The memory_replace command allows you to replace a specific string in a memory
block with a new string. This is used for making precise edits.
Args:
label (str): Section of the memory to be edited, identified by its label.
old_str (str): The text to replace (must match exactly, including whitespace
and indentation).
new_str (Optional[str]): The new text to insert in place of the old text.
Omit this argument to delete the old_str.
Returns:
str: The success message
"""
import re
if bool(re.search(r"\nLine \d+: ", old_str)):
raise ValueError(
"old_str contains a line number prefix, which is not allowed. "
"Do not include line numbers when calling memory tools (line "
"numbers are for display purposes only)."
)
if CORE_MEMORY_LINE_NUMBER_WARNING in old_str:
raise ValueError(
"old_str contains a line number warning, which is not allowed. "
"Do not include line number information when calling memory tools "
"(line numbers are for display purposes only)."
)
if bool(re.search(r"\nLine \d+: ", new_str)):
raise ValueError(
"new_str contains a line number prefix, which is not allowed. "
"Do not include line numbers when calling memory tools (line "
"numbers are for display purposes only)."
)
old_str = str(old_str).expandtabs()
new_str = str(new_str).expandtabs()
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
# Check if old_str is unique in the block
occurences = current_value.count(old_str)
if occurences == 0:
raise ValueError(
f"No replacement was performed, old_str `{old_str}` did not appear " f"verbatim in memory block with label `{label}`."
)
elif occurences > 1:
content_value_lines = current_value.split("\n")
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
raise ValueError(
f"No replacement was performed. Multiple occurrences of "
f"old_str `{old_str}` in lines {lines}. Please ensure it is unique."
)
# Replace old_str with new_str
new_value = current_value.replace(str(old_str), str(new_str))
# Write the new content to the block
agent_state.memory.update_block_value(label=label, value=new_value)
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
# Create a snippet of the edited section
SNIPPET_LINES = 3
replacement_line = current_value.split(old_str)[0].count("\n")
start_line = max(0, replacement_line - SNIPPET_LINES)
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
snippet = "\n".join(new_value.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
# success_msg += self._make_output(
# snippet, f"a snippet of {path}", start_line + 1
# )
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
success_msg += (
"Review the changes and make sure they are as expected (correct indentation, "
"no duplicate lines, etc). Edit the memory block again if necessary."
)
# return None
return success_msg
def memory_insert(
agent_state: "AgentState",
actor: User,
label: str,
new_str: str,
insert_line: int = -1,
) -> str:
"""
The memory_insert command allows you to insert text at a specific location
in a memory block.
Args:
label (str): Section of the memory to be edited, identified by its label.
new_str (str): The text to insert.
insert_line (int): The line number after which to insert the text (0 for
the beginning of the memory block). Defaults to -1 (end of the memory block).
Returns:
str: The success message
"""
import re
if bool(re.search(r"\nLine \d+: ", new_str)):
raise ValueError(
"new_str contains a line number prefix, which is not allowed. Do not "
"include line numbers when calling memory tools (line numbers are for "
"display purposes only)."
)
if CORE_MEMORY_LINE_NUMBER_WARNING in new_str:
raise ValueError(
"new_str contains a line number warning, which is not allowed. Do not "
"include line number information when calling memory tools (line numbers "
"are for display purposes only)."
)
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
new_str = str(new_str).expandtabs()
current_value_lines = current_value.split("\n")
n_lines = len(current_value_lines)
# Normalize the default of -1 to mean "append at the end of the memory block"
if insert_line == -1:
    insert_line = n_lines
# Check if we're in range, from 0 (pre-line), to 1 (first line), to n_lines (last line)
if insert_line < 0 or insert_line > n_lines:
raise ValueError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within "
f"the range of lines of the memory block: {[0, n_lines]}, or -1 to "
f"append to the end of the memory block."
)
# Insert the new string as a line
SNIPPET_LINES = 3
new_str_lines = new_str.split("\n")
new_value_lines = current_value_lines[:insert_line] + new_str_lines + current_value_lines[insert_line:]
snippet_lines = (
current_value_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ current_value_lines[insert_line : insert_line + SNIPPET_LINES]
)
# Collate into the new value to update
new_value = "\n".join(new_value_lines)
snippet = "\n".join(snippet_lines)
# Write into the block
agent_state.memory.update_block_value(label=label, value=new_value)
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
# success_msg += self._make_output(
# snippet,
# "a snippet of the edited file",
# max(1, insert_line - SNIPPET_LINES + 1),
# )
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
success_msg += (
"Review the changes and make sure they are as expected (correct indentation, "
"no duplicate lines, etc). Edit the memory block again if necessary."
)
return success_msg
def memory_rethink(agent_state: "AgentState", actor: User, label: str, new_memory: str) -> str:
"""
The memory_rethink command allows you to completely rewrite the contents of a
memory block. Use this tool to make large sweeping changes (e.g. when you want
to condense or reorganize the memory blocks), do NOT use this tool to make small
precise edits (e.g. add or remove a line, replace a specific string, etc).
Args:
label (str): The memory block to be rewritten, identified by its label.
new_memory (str): The new memory contents with information integrated from
existing memory blocks and the conversation context.
Returns:
str: The success message
"""
import re
if bool(re.search(r"\nLine \d+: ", new_memory)):
raise ValueError(
"new_memory contains a line number prefix, which is not allowed. Do not "
"include line numbers when calling memory tools (line numbers are for "
"display purposes only)."
)
if CORE_MEMORY_LINE_NUMBER_WARNING in new_memory:
raise ValueError(
"new_memory contains a line number warning, which is not allowed. Do not "
"include line number information when calling memory tools (line numbers "
"are for display purposes only)."
)
if agent_state.memory.get_block(label) is None:
agent_state.memory.create_block(label=label, value=new_memory)
agent_state.memory.update_block_value(label=label, value=new_memory)
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
# Prepare the success message
success_msg = f"The core memory block with label `{label}` has been edited. "
# success_msg += self._make_output(
# snippet, f"a snippet of {path}", start_line + 1
# )
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
success_msg += (
"Review the changes and make sure they are as expected (correct indentation, "
"no duplicate lines, etc). Edit the memory block again if necessary."
)
# return None
return success_msg
def memory_finish_edits(agent_state: "AgentState") -> None:
"""
Call the memory_finish_edits command when you are finished making edits
(integrating all new information) into the memory blocks. This function
is called when the agent is done rethinking the memory.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
return None
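At runtime the new memory tools are reached through LettaCoreToolExecutor's function_map, which invokes them with (agent_state, actor, **args); an illustrative dispatch, with all argument values hypothetical:

executor = LettaCoreToolExecutor()
result = executor.execute(
    function_name="memory_replace",
    function_args={"label": "persona", "old_str": "Name: Sam", "new_str": "Name: Sarah"},
    agent_state=agent_state,  # assumed AgentState whose "persona" block contains "Name: Sam"
    tool=tool,                # the registered Tool record for memory_replace
    actor=actor,
)
assert result.status == "success"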
class LettaMultiAgentToolExecutor(ToolExecutor):
"""Executor for LETTA multi-agent core tools."""
# TODO: Implement
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> ToolExecutionResult:
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
# function_response = callable_func(**function_args)
# return ToolExecutionResult(func_return=function_response)
class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools."""
@ -273,7 +493,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
) -> ToolExecutionResult:
action_name = generate_composio_action_from_func_name(tool.name)
# Get entity ID from the agent_state
@ -287,7 +507,10 @@ class ExternalComposioToolExecutor(ToolExecutor):
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
)
return function_response, None
return ToolExecutionResult(
status="success",
func_return=function_response,
)
def _get_entity_id(self, agent_state: AgentState) -> Optional[str]:
"""Extract the entity ID from environment variables."""
@ -302,8 +525,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
# TODO: Implement
#
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> Tuple[
# Any, Optional[SandboxRunResult]]:
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> ToolExecutionResult:
# # Get the server name from the tool tag
# server_name = self._extract_server_name(tool)
#
@ -316,8 +538,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
# # Execute the tool
# function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
#
# sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
# return function_response, sandbox_run_result
# return ToolExecutionResult(
# status="error" if is_error else "success",
# func_return=function_response,
# )
#
# def _extract_server_name(self, tool: Tool) -> str:
# """Extract server name from tool tags."""
@ -360,7 +584,7 @@ class SandboxToolExecutor(ToolExecutor):
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
) -> ToolExecutionResult:
# Store original memory state
orig_memory_str = agent_state.memory.compile()
@ -381,21 +605,19 @@ class SandboxToolExecutor(ToolExecutor):
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
)
sandbox_run_result = await sandbox.run(agent_state=agent_state_copy)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
# Verify memory integrity
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
# Update agent memory if needed
if updated_agent_state is not None:
AgentManager().update_memory_if_changed(agent_state.id, updated_agent_state.memory, actor)
if tool_execution_result.agent_state is not None:
AgentManager().update_memory_if_changed(agent_state.id, tool_execution_result.agent_state.memory, actor)
return function_response, sandbox_run_result
return tool_execution_result
except Exception as e:
return self._handle_execution_error(e, function_name)
return self._handle_execution_error(e, function_name, traceback.format_exc())
def _prepare_function_args(self, function_args: dict, tool: Tool, function_name: str) -> dict:
"""Prepare function arguments with proper type coercion."""
@ -417,9 +639,18 @@ class SandboxToolExecutor(ToolExecutor):
agent_state_copy.tool_rules = []
return agent_state_copy
def _handle_execution_error(self, exception: Exception, function_name: str) -> Tuple[str, SandboxRunResult]:
def _handle_execution_error(
self,
exception: Exception,
function_name: str,
stderr: str,
) -> ToolExecutionResult:
"""Handle tool execution errors."""
error_message = get_friendly_error_msg(
function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception)
)
return error_message, SandboxRunResult(status="error")
return ToolExecutionResult(
status="error",
func_return=error_message,
stderr=[stderr],
)

View File

@ -7,8 +7,9 @@ from typing import Any, Dict, Optional, Tuple
from letta.functions.helpers import generate_model_from_args_json_schema
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
@ -64,7 +65,7 @@ class AsyncToolSandboxBase(ABC):
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> SandboxRunResult:
) -> ToolExecutionResult:
"""
Run the tool in a sandbox environment asynchronously.
Must be implemented by subclasses.

View File

@ -2,8 +2,9 @@ from typing import Any, Dict, Optional
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.utils import get_friendly_error_msg
@ -30,7 +31,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> SandboxRunResult:
) -> ToolExecutionResult:
"""
Run the tool in a sandbox environment asynchronously,
*always* using a subprocess for execution.
@ -45,7 +46,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
async def run_e2b_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> SandboxRunResult:
) -> ToolExecutionResult:
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
@ -94,7 +95,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
else:
raise ValueError(f"Tool {self.tool_name} returned execution with None")
return SandboxRunResult(
return ToolExecutionResult(
func_return=func_return,
agent_state=agent_state,
stdout=execution.logs.stdout,

View File

@ -5,8 +5,9 @@ import tempfile
from typing import Any, Dict, Optional, Tuple
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.helpers.tool_execution_helper import (
create_venv_for_local_sandbox,
find_python_executable,
@ -39,7 +40,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
) -> SandboxRunResult:
) -> ToolExecutionResult:
"""
Run the tool in a sandbox environment asynchronously,
*always* using a subprocess for execution.
@ -53,7 +54,11 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
return result
@trace_method
async def run_local_dir_sandbox(self, agent_state: Optional[AgentState], additional_env_vars: Optional[Dict]) -> SandboxRunResult:
async def run_local_dir_sandbox(
self,
agent_state: Optional[AgentState],
additional_env_vars: Optional[Dict],
) -> ToolExecutionResult:
"""
Unified asynchronous method to run the tool in a local sandbox environment,
always via subprocess for multi-core parallelism.
@ -156,7 +161,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
@trace_method
async def _execute_tool_subprocess(
self, sbx_config, python_executable: str, temp_file_path: str, env: Dict[str, str], cwd: str
) -> SandboxRunResult:
) -> ToolExecutionResult:
"""
Execute user code in a subprocess, always capturing stdout and stderr.
We parse special markers to extract the pickled result string.
@ -189,7 +194,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
func_result, stdout_text = self.parse_out_function_results_markers(stdout)
func_return, agent_state = self.parse_best_effort(func_result)
return SandboxRunResult(
return ToolExecutionResult(
func_return=func_return,
agent_state=agent_state,
stdout=[stdout_text] if stdout_text else [],
@ -209,7 +214,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
exception_name=type(e).__name__,
exception_message=str(e),
)
return SandboxRunResult(
return ToolExecutionResult(
func_return=func_return,
agent_state=None,
stdout=[],

453
poetry.lock generated

File diff suppressed because it is too large

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.0"
version = "0.7.1"
packages = [
{include = "letta"},
]

View File

@ -1,80 +1,140 @@
# import time
#
# import pytest
# from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
#
#
# @pytest.fixture(scope="module")
# def client():
# return Letta(base_url="http://localhost:8283")
#
#
# def test_create_batch(client: Letta):
#
# # create agents
# agent1 = client.agents.create(
# name="agent1",
# memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
# model="anthropic/claude-3-7-sonnet-20250219",
# embedding="letta/letta-free",
# )
# agent2 = client.agents.create(
# name="agent2",
# memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
# model="anthropic/claude-3-7-sonnet-20250219",
# embedding="letta/letta-free",
# )
#
# # create a run
# run = client.messages.batches.create(
# requests=[
# LettaBatchRequest(
# messages=[
# MessageCreate(
# role="user",
# content=[
# TextContent(
# text="text",
# )
# ],
# )
# ],
# agent_id=agent1.id,
# ),
# LettaBatchRequest(
# messages=[
# MessageCreate(
# role="user",
# content=[
# TextContent(
# text="text",
# )
# ],
# )
# ],
# agent_id=agent2.id,
# ),
# ]
# )
# assert run is not None
#
# # list batches
# batches = client.messages.batches.list()
# assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}"
#
# # check run status
# while True:
# run = client.messages.batches.retrieve(batch_id=run.id)
# if run.status == "completed":
# break
# print("Waiting for run to complete...", run.status)
# time.sleep(1)
#
# # get the batch results
# results = client.messages.batches.retrieve(
# run_id=run.id,
# )
# assert results is not None
# print(results)
#
# # cancel a run
import os
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
from letta.config import LettaConfig
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.enums import JobStatus
from letta.server.db import db_context
from letta.server.server import SyncServer
@pytest.fixture(autouse=True)
def clear_batch_tables():
"""Clear batch-related tables before each test."""
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables):
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
session.execute(table.delete()) # Truncate table
session.commit()
def run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""
Ensures a server is running and returns its base URL.
Uses environment variable if available, otherwise starts a server
in a background thread.
"""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5) # Give server time to start
return url
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
return SyncServer()
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client connected to the test server."""
return Letta(base_url=server_url)
@pytest.mark.asyncio
async def test_create_batch(client: Letta, server: SyncServer):
# create agents
agent1 = client.agents.create(
name="agent1_batch",
memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
model="anthropic/claude-3-7-sonnet-20250219",
embedding="letta/letta-free",
)
agent2 = client.agents.create(
name="agent2_batch",
memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
model="anthropic/claude-3-7-sonnet-20250219",
embedding="letta/letta-free",
)
# create a run
run = client.batches.create(
requests=[
LettaBatchRequest(
messages=[
MessageCreate(
role="user",
content=[
TextContent(
text="hi",
)
],
)
],
agent_id=agent1.id,
),
LettaBatchRequest(
messages=[
MessageCreate(
role="user",
content=[
TextContent(
text="hi",
)
],
)
],
agent_id=agent2.id,
),
]
)
assert run is not None
# list batches
batches = client.batches.list()
assert len(batches) == 1, f"Expected 1 batch, got {len(batches)}"
assert batches[0].status == JobStatus.running
# Poll it once
await poll_running_llm_batches(server)
# get the batch results
results = client.batches.retrieve(
batch_id=run.id,
)
assert results is not None
# cancel
client.batches.cancel(batch_id=run.id)
batch_job = client.batches.retrieve(
batch_id=run.id,
)
assert batch_job.status == JobStatus.cancelled

View File

@ -67,9 +67,9 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis
actor=default_user,
)
function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
)
# Small check, it should return something at least
assert len(function_response.keys()) > 10
assert len(tool_execution_result.func_return.keys()) > 10

View File

@ -0,0 +1,192 @@
# TODO (cliandy): Tested in SDK
# TODO (cliandy): Comment out after merge
# import os
# import threading
# import time
# import pytest
# from dotenv import load_dotenv
# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool
# from letta.schemas.agent import AgentState
# from typing import List, Any, Dict
# # ------------------------------
# # Fixtures
# # ------------------------------
# @pytest.fixture(scope="module")
# def server_url() -> str:
# """
# Provides the URL for the Letta server.
# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
# will start the Letta server in a background thread and return the default URL.
# """
# def _run_server() -> None:
# """Starts the Letta server in a background thread."""
# load_dotenv() # Load environment variables from .env file
# from letta.server.rest_api.app import start_server
# start_server(debug=True)
# # Retrieve server URL from environment, or default to localhost
# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
# # If no environment variable is set, start the server in a background thread
# if not os.getenv("LETTA_SERVER_URL"):
# thread = threading.Thread(target=_run_server, daemon=True)
# thread.start()
# time.sleep(5) # Allow time for the server to start
# return url
# @pytest.fixture
# def client(server_url: str) -> Letta:
# """
# Creates and returns a synchronous Letta REST client for testing.
# """
# client_instance = Letta(base_url=server_url)
# yield client_instance
# @pytest.fixture
# def async_client(server_url: str) -> AsyncLetta:
# """
# Creates and returns an asynchronous Letta REST client for testing.
# """
# async_client_instance = AsyncLetta(base_url=server_url)
# yield async_client_instance
# @pytest.fixture
# def roll_dice_tool(client: Letta) -> Tool:
# """
# Registers a simple roll dice tool with the provided client.
# The tool simulates rolling a six-sided die but returns a fixed result.
# """
# def roll_dice() -> str:
# """
# Simulates rolling a die.
# Returns:
# str: The roll result.
# """
# # Note: The result here is intentionally incorrect for demonstration purposes.
# return "Rolled a 10!"
# tool = client.tools.upsert_from_function(func=roll_dice)
# yield tool
# @pytest.fixture
# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
# """
# Creates and returns an agent state for testing with a pre-configured agent.
# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
# """
# agent_state_instance = client.agents.create(
# name="supervisor",
# include_base_tools=True,
# tool_ids=[roll_dice_tool.id],
# model="openai/gpt-4o",
# embedding="letta/letta-free",
# tags=["supervisor"],
# include_base_tool_rules=True,
# )
# yield agent_state_instance
# # Goal is to test that when an Agent is created with a `response_format`, the response
# # of `send_message` is in the correct format. This will be done by modifying the agent's
# # `send_message` tool so that it returns a format based on what is passed in.
# #
# # `response_format` is an optional field
# # if `response_format.type` is `text`, then the schema does not change
# # if `response_format.type` is `json_object`, then the schema is a dict
# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema
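# # Illustrative only (hypothetical values, not part of this test file): the last
# # AssistantMessage content for each response_format type is expected to look roughly like:
# #   text        -> "Here is your message."                          (plain str)
# #   json_object -> {"message": "hi"}                                 (arbitrary dict)
# #   json_schema -> {"steps": [{"explanation": "...", "output": "..."}],
# #                   "final_answer": "..."}                           (dict matching the schema below)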
# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}]
# # ------------------------------
# # Test Cases
# # ------------------------------
# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_object'."""
# client.agents.modify(agent.id, response_format={"type": "text"})
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, str)
# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_object'."""
# client.agents.modify(agent.id, response_format={"type": "json_object"})
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, dict)
# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_schema' and a valid schema."""
# client.agents.modify(agent.id, response_format={
# "type": "json_schema",
# "json_schema": {
# "name": "reasoning_schema",
# "schema": {
# "type": "object",
# "properties": {
# "steps": {
# "type": "array",
# "items": {
# "type": "object",
# "properties": {
# "explanation": { "type": "string" },
# "output": { "type": "string" }
# },
# "required": ["explanation", "output"],
# "additionalProperties": False
# }
# },
# "final_answer": { "type": "string" }
# },
# "required": ["steps", "final_answer"],
# "additionalProperties": True
# },
# "strict": True
# }
# })
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, dict)
# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None:
# # """Test client send_message with an invalid json_schema (should error or fallback)."""
# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}}
# # client.agents.modify(agent.id, response_format="json_schema")
# # result: Any = client.agents.send_message(agent.id, "Test invalid schema")
# # assert result is None or "error" in str(result).lower()

View File

@ -132,7 +132,7 @@ async def test_sleeptime_group_chat(server, actor):
response = await server.send_message_to_agent(
agent_id=main_agent.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content=text,
@ -206,7 +206,7 @@ async def test_sleeptime_removes_redundant_information(server, actor):
_ = await server.send_message_to_agent(
agent_id=main_agent.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content=test_message,
@ -270,7 +270,7 @@ async def test_sleeptime_edit(server, actor):
_ = await server.send_message_to_agent(
agent_id=sleeptime_agent.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content="Messi has now moved to playing for Inter Miami",

View File

@ -454,7 +454,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
"""Test deserializing JSON into an Agent instance."""
append_copy_suffix = False
server.send_messages(
actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="hello")]
actor=default_user, agent_id=serialize_test_agent.id, input_messages=[MessageCreate(role=MessageRole.user, content="hello")]
)
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
@ -470,10 +470,12 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
# Make sure both agents can receive messages after
server.send_messages(
actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
actor=default_user,
agent_id=serialize_test_agent.id,
input_messages=[MessageCreate(role=MessageRole.user, content="and hello again")],
)
server.send_messages(
actor=other_user, agent_id=agent_copy.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
actor=other_user, agent_id=agent_copy.id, input_messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
)
@ -483,7 +485,7 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s
server.send_messages(
actor=default_user,
agent_id=serialize_test_agent.id,
messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in San Francisco?")],
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in San Francisco?")],
)
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
@ -501,12 +503,12 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s
original_agent_response = server.send_messages(
actor=default_user,
agent_id=serialize_test_agent.id,
messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
)
copy_agent_response = server.send_messages(
actor=other_user,
agent_id=agent_copy.id,
messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
)
assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0
@ -519,12 +521,12 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server
server.send_messages(
actor=default_user,
agent_id=serialize_test_agent.id,
messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")],
input_messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")],
)
server.send_messages(
actor=default_user,
agent_id=serialize_test_agent.id,
messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")],
input_messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")],
)
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
@ -543,12 +545,12 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server
original_agent_response = server.send_messages(
actor=default_user,
agent_id=serialize_test_agent.id,
messages=[MessageCreate(role=MessageRole.user, content="Hi")],
input_messages=[MessageCreate(role=MessageRole.user, content="Hi")],
)
copy_agent_response = server.send_messages(
actor=other_user,
agent_id=agent_copy.id,
messages=[MessageCreate(role=MessageRole.user, content="Hi")],
input_messages=[MessageCreate(role=MessageRole.user, content="Hi")],
)
assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0
@ -635,5 +637,5 @@ def test_upload_agentfile_from_disk(server, disable_e2b_api_key, fastapi_client,
server.send_messages(
actor=other_user,
agent_id=copied_agent_id,
messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
)

View File

@ -1,11 +1,3 @@
"""
Tests for LettaAgentBatch.step_until_request functionality.
This module tests the batch processing capabilities of LettaAgentBatch,
specifically the step_until_request method which prepares agent requests
for batch processing.
"""
import os
import threading
import time
@ -92,6 +84,28 @@ def weather_tool(client):
yield tool
@pytest.fixture(scope="function")
def rethink_tool(client):
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
"""
Re-evaluate the memory in block_name, integrating new and updated facts.
Replace outdated information with the most likely truths, avoiding redundancy with original memories.
Ensure consistency with other memory blocks.
Args:
new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
target_block_label (str): The name of the block to write to.
Returns:
str: None is always returned as this function does not produce a response.
"""
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
return None
tool = client.tools.upsert_from_function(func=rethink_memory)
# Yield the created tool
yield tool
@pytest.fixture
def agents(client, weather_tool):
"""
@ -146,7 +160,7 @@ def step_state_map(agents):
Returns:
Dict[str, AgentStepState]: Mapping of agent IDs to step states
"""
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="get_weather")])
return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
@ -173,26 +187,7 @@ def create_batch_response(batch_id: str, processing_status: str = "in_progress")
)
def create_successful_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
"""Create a dummy successful batch response."""
return BetaMessageBatchIndividualResponse(
custom_id=custom_id,
result=BetaMessageBatchSucceededResult(
type="succeeded",
message=BetaMessage(
id="msg_abc123",
role="assistant",
type="message",
model="claude-3-5-sonnet-20240620",
content=[{"type": "text", "text": "hi!"}],
usage={"input_tokens": 5, "output_tokens": 7},
stop_reason="end_turn",
),
),
)
def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
def create_get_weather_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
"""Create a dummy successful batch response with a tool call after user asks about weather."""
return BetaMessageBatchIndividualResponse(
custom_id=custom_id,
@ -223,6 +218,39 @@ def create_complete_tool_response(custom_id: str, model: str, request_heartbeat:
)
def create_rethink_tool_response(
custom_id: str, model: str, request_heartbeat: bool, new_memory: str, target_block_label: str
) -> BetaMessageBatchIndividualResponse:
"""Create a dummy successful batch response with a tool call after user asks about weather."""
return BetaMessageBatchIndividualResponse(
custom_id=custom_id,
result=BetaMessageBatchSucceededResult(
type="succeeded",
message=BetaMessage(
id="msg_abc123",
role="assistant",
type="message",
model=model,
content=[
{"type": "text", "text": "Let me rethink my memory."},
{
"type": "tool_use",
"id": "tu_01234567890123456789012345",
"name": "rethink_memory",
"input": {
"new_memory": new_memory,
"target_block_label": target_block_label,
"request_heartbeat": request_heartbeat,
},
},
],
usage={"input_tokens": 7, "output_tokens": 17},
stop_reason="end_turn",
),
),
)
def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
"""Create a dummy failed batch response with a rate limit error."""
return BetaMessageBatchIndividualResponse(
@ -340,6 +368,357 @@ class MockAsyncIterable:
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
target_block_label = "human"
new_memory = "banana"
agent = client.agents.create(
name=f"test_agent_rethink",
include_base_tools=True,
model=MODELS["sonnet"],
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[rethink_tool.id],
memory_blocks=[
{
"label": target_block_label,
"value": "Name: Matt",
},
],
)
agents = [agent]
batch_requests = [
LettaBatchRequest(agent_id=agent.id, messages=[MessageCreate(role="user", content=[TextContent(text="Rethink memory.")])])
for agent in agents
]
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
batch_id=anthropic_batch_id,
)
# 1. Invoke `step_until_request`
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
# Create batch runner
batch_runner = LettaAgentBatch(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
job_manager=server.job_manager,
actor=default_user,
)
# Run the method under test
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="rethink_memory")])
step_state_map = {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
pre_resume_response = await batch_runner.step_until_request(
batch_requests=batch_requests,
agent_step_state_mapping=step_state_map,
letta_batch_job_id=batch_job.id,
)
# 2. Invoke the polling job and mock responses from Anthropic
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
mock_items = [
create_rethink_tool_response(
custom_id=agent.id,
model=agent.llm_config.model,
request_heartbeat=False,
new_memory=new_memory,
target_block_label=target_block_label,
)
for agent in agents
]
# Create the mock for results
mock_results = Mock()
mock_results.return_value = MockAsyncIterable(mock_items.copy())
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
await poll_running_llm_batches(server)
# Check that the tool has been executed correctly
agent = client.agents.retrieve(agent_id=agent.id)
for block in agent.memory.blocks:
if block.label == target_block_label:
assert block.value == new_memory
@pytest.mark.asyncio
async def test_partial_error_from_anthropic_batch(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
batch_id=anthropic_batch_id,
)
# 1. Invoke `step_until_request`
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
# Create batch runner
batch_runner = LettaAgentBatch(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
job_manager=server.job_manager,
actor=default_user,
)
# Run the method under test
pre_resume_response = await batch_runner.step_until_request(
batch_requests=batch_requests,
agent_step_state_mapping=step_state_map,
letta_batch_job_id=batch_job.id,
)
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
llm_batch_job = llm_batch_jobs[0]
# 2. Invoke the polling job and mock responses from Anthropic
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
agents_failed = agents[:1]
agents_continue = agents[1:]
# Create failed response for one agent
mock_items = [create_failed_response(custom_id=agent.id) for agent in agents_failed]
mock_items.extend(
[
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
for agent in agents_continue
]
)
# Create the mock for results
mock_results = Mock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
new_batch_responses = await poll_running_llm_batches(server)
# Verify database records were updated correctly
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
# Verify job properties
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
# Verify batch items
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
# Verify only one new batch response
assert len(new_batch_responses) == 1
post_resume_response = new_batch_responses[0]
assert (
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
), "resume_step_after_request is expected to have the same letta_batch_id"
assert (
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
), "resume_step_after_request is expected to have different llm_batch_id."
assert post_resume_response.status == JobStatus.running
# NOTE: We only expect 2 agents to continue (succeeded ones)
assert post_resume_response.agent_count == 2
# New batch items should exist, initialised in (created, paused) state
new_items = server.batch_manager.list_llm_batch_items(
llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user
)
assert len(new_items) == 2, f"Expected 2 new batch items, got {len(new_items)}"
# Assert that the continuing agents are exactly the agents in the new items
assert {i.agent_id for i in new_items} == {a.id for a in agents_continue}
assert {i.request_status for i in new_items} == {JobStatus.created}
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
# Confirm that tool_rules_solver state was preserved correctly
# Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history
assert all(
"get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
), "Expected 'get_weather' in tool_call_history for all new_items"
# Assert that each new item's step_number was incremented to 1
assert all(
item.step_state.step_number == 1 for item in new_items
), "Expected step_number to be incremented to 1 for all new_items"
# Old items must have been flipped to completed / finished earlier
# (sanity check: we already asserted this above, but we keep it close for clarity)
old_items = server.batch_manager.list_llm_batch_items(
llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user
)
for item in old_items:
if item.agent_id == agents_failed[0].id:
assert item.request_status == JobStatus.failed
assert item.step_status == AgentStepStatus.paused
else:
assert item.request_status == JobStatus.completed
assert item.step_status == AgentStepStatus.completed
# Tool call side effects: each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
if agent.id == agents_failed[0].id:
assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure"
else:
assert after - before >= 2, (
f"Agent {agent.id} should have an assistant toolcall " f"and toolresponse message persisted."
)
# Check that agent states have been properly modified to have extended in-context messages
for agent in agents:
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
if refreshed_agent.id == agents_failed[0].id:
assert (
len(refreshed_agent.message_ids) == 4
), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
else:
assert (
len(refreshed_agent.message_ids) == 6
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
@pytest.mark.asyncio
async def test_resume_step_some_stop(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
anthropic_batch_id = "msgbatch_test_12345"
dummy_batch_response = create_batch_response(
batch_id=anthropic_batch_id,
)
# 1. Invoke `step_until_request`
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
# Create batch runner
batch_runner = LettaAgentBatch(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
batch_manager=server.batch_manager,
sandbox_config_manager=server.sandbox_config_manager,
job_manager=server.job_manager,
actor=default_user,
)
# Run the method under test
pre_resume_response = await batch_runner.step_until_request(
batch_requests=batch_requests,
agent_step_state_mapping=step_state_map,
letta_batch_job_id=batch_job.id,
)
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
llm_batch_job = llm_batch_jobs[0]
# 2. Invoke the polling job and mock responses from Anthropic
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
agents_continue = agents[:1]
agents_finish = agents[1:]
mock_items = [
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
for agent in agents_continue
]
mock_items.extend(
[
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False)
for agent in agents_finish
]
)
# Create the mock for results
mock_results = Mock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
new_batch_responses = await poll_running_llm_batches(server)
# Verify database records were updated correctly
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
# Verify job properties
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
# Verify batch items
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
assert all([item.request_status == JobStatus.completed for item in items])
# Verify only one new batch response
assert len(new_batch_responses) == 1
post_resume_response = new_batch_responses[0]
assert (
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
), "resume_step_after_request is expected to have the same letta_batch_id"
assert (
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
), "resume_step_after_request is expected to have different llm_batch_id."
assert post_resume_response.status == JobStatus.running
# NOTE: We only expect 1 agent to continue
assert post_resume_response.agent_count == 1
# New batch items should exist, initialised in (created, paused) state
new_items = server.batch_manager.list_llm_batch_items(
llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user
)
assert len(new_items) == 1, f"Expected 1 new batch item, got {len(new_items)}"
# Assert that the continuing agent is in the only item
assert new_items[0].agent_id == agents_continue[0].id
assert {i.request_status for i in new_items} == {JobStatus.created}
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
# Confirm that tool_rules_solver state was preserved correctly
# Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history
assert all(
"get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
), "Expected 'get_weather' in tool_call_history for all new_items"
# Assert that each new item's step_number was incremented to 1
assert all(
item.step_state.step_number == 1 for item in new_items
), "Expected step_number to be incremented to 1 for all new_items"
# Old items must have been flipped to completed / finished earlier
# (sanity check: we already asserted this above, but we keep it close for clarity)
old_items = server.batch_manager.list_llm_batch_items(
llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user
)
assert {i.request_status for i in old_items} == {JobStatus.completed}
assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
# Tool call side effects: each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
assert after - before >= 2, (
f"Agent {agent.id} should have an assistant toolcall " f"and toolresponse message persisted."
)
# Check that agent states have been properly modified to have extended in-context messages
for agent in agents:
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
assert (
len(refreshed_agent.message_ids) == 6
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
@pytest.mark.asyncio
async def test_resume_step_after_request_all_continue(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
@ -384,7 +763,7 @@ async def test_resume_step_after_request_all_continue(
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
mock_items = [
create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
]
# Create the mock for results
@ -460,7 +839,7 @@ async def test_resume_step_after_request_all_continue(
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
assert (
len(refreshed_agent.message_ids) == 6
), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
@pytest.mark.asyncio
@ -518,7 +897,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
# Verify tool configuration
for agent_id, tools in agent_tools_mapping.items():
available_tools = {tool["name"] for tool in tools}
assert available_tools == {"send_message"}, f"Expected only send_message tool, got {available_tools}"
assert available_tools == {"get_weather"}, f"Expected only send_message tool, got {available_tools}"
# Verify model assignments
for agent_id, expected_model in expected_models.items():

View File

@ -1,5 +1,6 @@
import os
import random
import re
import string
import time
from datetime import datetime, timedelta, timezone
@ -5211,6 +5212,47 @@ def test_bulk_update_batch_items_request_status_by_agent(
assert updated.request_status == JobStatus.expired
def test_bulk_update_nonexistent_items_should_error(
server,
default_user,
dummy_beta_message_batch,
dummy_successful_response,
letta_batch_job,
):
# Create a batch job
batch = server.batch_manager.create_llm_batch_job(
llm_provider=ProviderType.anthropic,
create_batch_response=dummy_beta_message_batch,
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
nonexistent_pairs = [(batch.id, "nonexistent-agent-id")]
nonexistent_updates = [{"request_status": JobStatus.expired}]
expected_err_msg = (
f"Cannot bulk-update batch items: no records for the following "
f"(llm_batch_id, agent_id) pairs: {{('{batch.id}', 'nonexistent-agent-id')}}"
)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
)
def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job):
# Create a batch job
batch = server.batch_manager.create_llm_batch_job(
@ -5227,22 +5269,22 @@ def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_
nonexistent_updates = [{"request_status": JobStatus.expired}]
# This should not raise an error, just silently skip non-existent items
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False)
# Test with higher-level methods
# Results by agent
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False
)
# Step status by agent
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False
)
# Request status by agent
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False
)

View File

@ -158,7 +158,7 @@ async def test_empty_group(server, actor):
await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content="what is everyone up to for the holidays?",
@ -246,7 +246,7 @@ async def test_round_robin(server, actor, participant_agents):
response = await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content="what is everyone up to for the holidays?",
@ -301,7 +301,7 @@ async def test_round_robin(server, actor, participant_agents):
response = await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content="what is everyone up to for the holidays?",
@ -367,7 +367,7 @@ async def test_supervisor(server, actor, participant_agents):
response = await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(
role="user",
content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
@ -449,7 +449,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen
response = await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
input_messages=[
MessageCreate(role="user", content="what is everyone up to for the holidays?"),
],
stream_steps=False,