Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

chore: release 0.7.1 (#2583)
Commit 435b754286
@@ -0,0 +1,31 @@
"""add support for structured_outputs in agents

Revision ID: 28b8765bdd0a
Revises: a3c7d62e08ca
Create Date: 2025-04-18 11:43:47.701786

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "28b8765bdd0a"
down_revision: Union[str, None] = "a3c7d62e08ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("agents", sa.Column("response_format", sa.JSON(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("agents", "response_format")
    # ### end Alembic commands ###
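Editor's note (not part of the commit): the migration above adds a nullable JSON column that backs the new per-agent response format, applied with the usual `alembic upgrade head`. As a rough sketch of what ends up stored in that column, using the schema classes referenced elsewhere in this release; the exact constructor fields of JsonSchemaResponseFormat are an assumption, not shown in this diff:

# Sketch only: JsonSchemaResponseFormat's constructor fields are assumed here.
from letta.schemas.response_format import JsonSchemaResponseFormat

rf = JsonSchemaResponseFormat(
    json_schema={
        "name": "reply",
        "schema": {
            "type": "object",
            "properties": {"message": {"type": "string"}},
            "required": ["message"],
        },
    },
)

# A JSON-safe dict like this is what the new `agents.response_format` column holds.
stored = rf.model_dump(mode="json")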
@@ -1,11 +1,10 @@
import time

from letta_client import AgentType, GroupUpdateManagerConfig, Letta, ManagerType, SleeptimeManagerUpdate
from letta_client.types import agent_type
from letta_client import Letta

client = Letta(base_url="http://localhost:8283")

# delete all sources
# delete all sources
for source in client.sources.list():
    print(f"Deleting source {source.name}")
    client.sources.delete(source.id)
@@ -21,19 +20,19 @@ agent = client.agents.create(
)
print(f"Created agent id {agent.id}")

# get the group
# get the group
group_id = agent.multi_agent_group.id
current_frequence = agent.multi_agent_group.sleeptime_agent_frequency
print(f"Group id: {group_id}, frequency: {current_frequence}")

# create a source
# create a source
source_name = "employee_handbook"
source = client.sources.create(
    name=source_name,
    description="Provides reference information for the employee handbook",
    description="Provides reference information for the employee handbook",
    embedding="openai/text-embedding-ada-002"  # must match agent
)
# attach the source to the agent
# attach the source to the agent
client.agents.sources.attach(
    source_id=source.id,
    agent_id=agent.id
@@ -52,7 +51,7 @@ print("Agent blocks", [b.label for b in client.agents.blocks.list(agent_id=agent
block = client.agents.blocks.retrieve(agent_id=agent.id, block_label="employee_handbook")


# get attached agents
# get attached agents
agents = client.blocks.agents.list(block_id=block.id)
for agent in agents:
    print(f"Agent id {agent.id}", agent.agent_type)
@@ -63,14 +62,10 @@ for agent in agents:
while job.status != "completed":
    job = client.jobs.retrieve(job.id)

# count passages
# count passages
passages = client.agents.passages.list(agent_id=agent.id)
print(f"Passages {len(passages)}")
for passage in passages:
    print(passage.text)

time.sleep(2)
@@ -1,4 +1,4 @@
__version__ = "0.7.0"
__version__ = "0.7.1"

# import clients
from letta.client.client import LocalClient, RESTClient, create_client
letta/agent.py
@@ -3,7 +3,7 @@ import time
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

from openai.types.beta.function_tool import FunctionTool as OpenAITool

@@ -17,6 +17,7 @@ from letta.constants import (
    LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
    LLM_MAX_TOKENS,
    REQ_HEARTBEAT_MESSAGE,
    SEND_MESSAGE_TOOL_NAME,
)
from letta.errors import ContextWindowExceededError
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
@@ -27,6 +28,7 @@ from letta.helpers import ToolRulesSolver
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
from letta.interface import AgentInterface
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
@@ -42,12 +44,13 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, ToolReturn
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.response_format import ResponseFormatType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
@@ -78,7 +81,7 @@ class BaseAgent(ABC):
    @abstractmethod
    def step(
        self,
        messages: Union[Message, List[Message]],
        input_messages: List[MessageCreate],
    ) -> LettaUsageStatistics:
        """
        Top-level event message handler for the agent.
@@ -255,6 +258,28 @@ class Agent(BaseAgent):
        # Return updated messages
        return messages

    def _runtime_override_tool_json_schema(
        self,
        functions_list: List[Dict | None],
    ) -> List[Dict | None]:
        """Override the tool JSON schema at runtime for a particular tool if conditions are met."""

        # Currently just injects `send_message` with a `response_format` if provided to the agent.
        if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text:
            for func in functions_list:
                if func["name"] == SEND_MESSAGE_TOOL_NAME:
                    if self.agent_state.response_format.type == ResponseFormatType.json_schema:
                        func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"]
                    if self.agent_state.response_format.type == ResponseFormatType.json_object:
                        func["parameters"]["properties"]["message"] = {
                            "type": "object",
                            "description": "Message contents. All unicode (including emojis) are supported.",
                            "additionalProperties": True,
                            "properties": {},
                        }
                    break
        return functions_list

    @trace_method
    def _get_ai_reply(
        self,
@@ -268,27 +293,26 @@ class Agent(BaseAgent):
        step_count: Optional[int] = None,
        last_function_failed: bool = False,
        put_inner_thoughts_first: bool = True,
    ) -> ChatCompletionResponse:
    ) -> ChatCompletionResponse | None:
        """Get response from LLM API with robust retry mechanism."""
        log_telemetry(self.logger, "_get_ai_reply start")
        available_tools = set([t.name for t in self.agent_state.tools])
        allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
            available_tools=available_tools, last_function_response=self.last_function_response
        )
        agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

        allowed_functions = (
            agent_state_tool_jsons
            if not allowed_tool_names
            else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
        )
        # Get allowed tools or allow all if none are allowed
        allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
            available_tools=available_tools, last_function_response=self.last_function_response
        ) or list(available_tools)

        # Don't allow a tool to be called if it failed last time
        if last_function_failed and self.tool_rules_solver.tool_call_history:
            allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]]
            if not allowed_functions:
            allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]]
            if not allowed_tool_names:
                return None

        allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
        allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)

        # For the first message, force the initial tool if one is specified
        force_tool_call = None
        if (
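Editor's note (not part of the commit): to illustrate the override above, when the agent has a json_schema response format the `message` parameter of the `send_message` tool schema is replaced with the user-supplied schema, so the model must emit structured JSON instead of free text. A simplified sketch of that effect (the surrounding tool-schema shape is abbreviated):

# Simplified sketch of what _runtime_override_tool_json_schema does to the send_message schema.
send_message_schema = {
    "name": "send_message",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string"}},
        "required": ["message"],
    },
}

user_schema = {"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}

# After the override, the assistant's `message` argument must match the user schema.
send_message_schema["parameters"]["properties"]["message"] = user_schema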
@@ -418,7 +442,7 @@ class Agent(BaseAgent):
            tool_call_id = response_message.tool_calls[0].id
            assert tool_call_id is not None  # should be defined

            # only necessary to add the tool_cal_id to a function call (antipattern)
            # only necessary to add the tool_call_id to a function call (antipattern)
            # response_message_dict = response_message.model_dump()
            # response_message_dict["tool_call_id"] = tool_call_id
@@ -513,6 +537,10 @@ class Agent(BaseAgent):
            # Failure case 3: function failed during execution
            # NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
            # this is because the function/tool role message is only created once the function/tool has executed/returned

            # handle cases where we return a json message
            if "message" in function_args:
                function_args["message"] = str(function_args.get("message", ""))
            self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
            self.chunk_index += 1
            try:
@@ -529,22 +557,23 @@ class Agent(BaseAgent):
                    },
                )

                function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
                tool_execution_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
                function_response = tool_execution_result.func_return

                log_event(
                    "tool_call_ended",
                    attributes={
                        "function_response": function_response,
                        "sandbox_run_result": sandbox_run_result.model_dump() if sandbox_run_result else None,
                        "tool_execution_result": tool_execution_result.model_dump(),
                    },
                )
                log_telemetry(
                    self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
                )

                if sandbox_run_result and sandbox_run_result.status == "error":
                if tool_execution_result and tool_execution_result.status == "error":
                    tool_return = ToolReturn(
                        status=sandbox_run_result.status, stdout=sandbox_run_result.stdout, stderr=sandbox_run_result.stderr
                        status=tool_execution_result.status, stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr
                    )
                    messages = self._handle_function_error_response(
                        function_response,
@@ -598,14 +627,10 @@ class Agent(BaseAgent):
            # Step 4: check if function response is an error
            if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
                error_msg = function_response_string
                tool_return = (
                    ToolReturn(
                        status=sandbox_run_result.status,
                        stdout=sandbox_run_result.stdout,
                        stderr=sandbox_run_result.stderr,
                    )
                    if sandbox_run_result
                    else None
                tool_return = ToolReturn(
                    status=tool_execution_result.status,
                    stdout=tool_execution_result.stdout,
                    stderr=tool_execution_result.stderr,
                )
                messages = self._handle_function_error_response(
                    error_msg,
@@ -622,14 +647,10 @@ class Agent(BaseAgent):

            # If no failures happened along the way: ...
            # Step 5: send the info on the function call and function response to GPT
            tool_return = (
                ToolReturn(
                    status=sandbox_run_result.status,
                    stdout=sandbox_run_result.stdout,
                    stderr=sandbox_run_result.stderr,
                )
                if sandbox_run_result
                else None
            tool_return = ToolReturn(
                status=tool_execution_result.status,
                stdout=tool_execution_result.stdout,
                stderr=tool_execution_result.stderr,
            )
            messages.append(
                Message(
@@ -641,7 +662,7 @@ class Agent(BaseAgent):
                    content=[TextContent(text=function_response)],
                    tool_call_id=tool_call_id,
                    # Letta extras
                    tool_returns=[tool_return] if sandbox_run_result else None,
                    tool_returns=[tool_return],
                    group_id=group_id,
                )
            )  # extend conversation with function response
@@ -691,7 +712,7 @@ class Agent(BaseAgent):
    @trace_method
    def step(
        self,
        messages: Union[Message, List[Message]],
        input_messages: List[MessageCreate],
        # additional args
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
@@ -704,7 +725,9 @@ class Agent(BaseAgent):
        # But just to be safe
        self.tool_rules_solver.clear_tool_history()

        next_input_message = messages if isinstance(messages, list) else [messages]
        # Convert MessageCreate objects to Message objects
        message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
        next_input_messages = message_objects
        counter = 0
        total_usage = UsageStatistics()
        step_count = 0
@@ -715,7 +738,7 @@ class Agent(BaseAgent):
            kwargs["step_count"] = step_count
            kwargs["last_function_failed"] = function_failed
            step_response = self.inner_step(
                messages=next_input_message,
                messages=next_input_messages,
                put_inner_thoughts_first=put_inner_thoughts_first,
                **kwargs,
            )
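Editor's note (not part of the commit): the public entry point now takes MessageCreate objects rather than pre-built Message rows, converting them internally via prepare_input_message_create. A minimal sketch of the new calling convention (agent construction is elided; the string role value mirrors usage elsewhere in this diff):

from letta.schemas.message import MessageCreate

# `agent` is an instance of letta.agent.Agent (construction elided in this sketch)
usage = agent.step(
    input_messages=[MessageCreate(role="user", content="What's in the employee handbook?")],
    chaining=True,
)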
@@ -745,36 +768,42 @@ class Agent(BaseAgent):
            # Chain handlers
            elif token_warning and summarizer_settings.send_memory_warning_message:
                assert self.agent_state.created_by_id is not None
                next_input_message = Message.dict_to_message(
                    agent_id=self.agent_state.id,
                    model=self.model,
                    openai_message_dict={
                        "role": "user",  # TODO: change to system?
                        "content": get_token_limit_warning(),
                    },
                )
                next_input_messages = [
                    Message.dict_to_message(
                        agent_id=self.agent_state.id,
                        model=self.model,
                        openai_message_dict={
                            "role": "user",  # TODO: change to system?
                            "content": get_token_limit_warning(),
                        },
                    ),
                ]
                continue  # always chain
            elif function_failed:
                assert self.agent_state.created_by_id is not None
                next_input_message = Message.dict_to_message(
                    agent_id=self.agent_state.id,
                    model=self.model,
                    openai_message_dict={
                        "role": "user",  # TODO: change to system?
                        "content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
                    },
                )
                next_input_messages = [
                    Message.dict_to_message(
                        agent_id=self.agent_state.id,
                        model=self.model,
                        openai_message_dict={
                            "role": "user",  # TODO: change to system?
                            "content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
                        },
                    )
                ]
                continue  # always chain
            elif heartbeat_request:
                assert self.agent_state.created_by_id is not None
                next_input_message = Message.dict_to_message(
                    agent_id=self.agent_state.id,
                    model=self.model,
                    openai_message_dict={
                        "role": "user",  # TODO: change to system?
                        "content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
                    },
                )
                next_input_messages = [
                    Message.dict_to_message(
                        agent_id=self.agent_state.id,
                        model=self.model,
                        openai_message_dict={
                            "role": "user",  # TODO: change to system?
                            "content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
                        },
                    )
                ]
                continue  # always chain
            # Letta no-op / yield
            else:
@@ -788,7 +817,7 @@ class Agent(BaseAgent):

    def inner_step(
        self,
        messages: Union[Message, List[Message]],
        messages: List[Message],
        first_message: bool = False,
        first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
        skip_verify: bool = False,
@@ -814,11 +843,8 @@ class Agent(BaseAgent):
            self.update_memory_if_changed(current_persisted_memory)

            # Step 1: add user message
            if isinstance(messages, Message):
                messages = [messages]

            if not all(isinstance(m, Message) for m in messages):
                raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}")
                raise ValueError(f"messages should be a list of Message, got {[type(m) for m in messages]}")

            in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
            input_message_sequence = in_context_messages + messages
@@ -1229,9 +1255,7 @@ class Agent(BaseAgent):
        return context_window_breakdown.context_window_size_current

    # TODO: Refactor into separate class v.s. large if/elses here
    def execute_tool_and_persist_state(
        self, function_name: str, function_args: dict, target_letta_tool: Tool
    ) -> tuple[Any, Optional[SandboxRunResult]]:
    def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool) -> ToolExecutionResult:
        """
        Execute tool modifications and persist the state of the agent.
        Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
@@ -1293,8 +1317,10 @@ class Agent(BaseAgent):
                )

                function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
                sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
                return function_response, sandbox_run_result
                return ToolExecutionResult(
                    status="error" if is_error else "success",
                    func_return=function_response,
                )
            else:
                try:
                    # Parse the source code to extract function annotations
@@ -1311,23 +1337,29 @@ class Agent(BaseAgent):
                    agent_state_copy.tools = []
                    agent_state_copy.tool_rules = []

                    sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
                    tool_execution_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
                        agent_state=agent_state_copy
                    )
                    function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
                    assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
                    if updated_agent_state is not None:
                        self.update_memory_if_changed(updated_agent_state.memory)
                    return function_response, sandbox_run_result
                    if tool_execution_result.agent_state is not None:
                        self.update_memory_if_changed(tool_execution_result.agent_state.memory)
                    return tool_execution_result
                except Exception as e:
                    # Need to catch error here, or else trunction wont happen
                    # TODO: modify to function execution error
                    function_response = get_friendly_error_msg(
                        function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
                    )
                    return function_response, SandboxRunResult(status="error")
                    return ToolExecutionResult(
                        status="error",
                        func_return=function_response,
                        stderr=[traceback.format_exc()],
                    )

        return function_response, None
        return ToolExecutionResult(
            status="success",
            func_return=function_response,
        )


def save_agent(agent: Agent):
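Editor's note (not part of the commit): with the refactor above, execute_tool_and_persist_state returns a single ToolExecutionResult instead of a (response, SandboxRunResult) tuple. A minimal sketch of how a caller consumes it, mirroring _handle_ai_response earlier in this diff; `agent`, `name`, `args`, and `tool` stand in for the real call-site variables:

# Sketch of consuming the new return type.
result = agent.execute_tool_and_persist_state(name, args, tool)

function_response = result.func_return
if result.status == "error":
    # stdout/stderr ride along on the result for error reporting
    print(result.stderr)
print(result.model_dump())  # JSON-friendly view, e.g. for log_event attributes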
@@ -324,11 +324,11 @@ class LettaAgent(BaseAgent):
            tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
            # TODO: Integrate sandbox result
            log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
            function_response, _ = await tool_execution_manager.execute_tool_async(
            tool_execution_result = await tool_execution_manager.execute_tool_async(
                function_name=tool_name, function_args=tool_args, tool=target_tool
            )
            log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
            return function_response, True
            return tool_execution_result.func_return, True
        except Exception as e:
            return f"Failed to call tool. Error: {e}", False
@@ -37,6 +37,7 @@ from letta.services.passage_manager import PassageManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import tool_settings
from letta.tracing import log_event, trace_method
from letta.utils import united_diff

logger = get_logger(__name__)
@@ -82,12 +83,12 @@ async def execute_tool_wrapper(params: ToolExecutionParams):
            sandbox_config=params.sbx_config,
            sandbox_env_vars=params.sbx_env_vars,
        )
        result, _ = await mgr.execute_tool_async(
        tool_execution_result = await mgr.execute_tool_async(
            function_name=params.tool_call_name,
            function_args=params.tool_args,
            tool=target_tool,
        )
        return params.agent_id, (result, True)
        return params.agent_id, (tool_execution_result.func_return, True)
    except Exception as e:
        return params.agent_id, (f"Failed to call tool. Error: {e}", False)
@@ -120,55 +121,54 @@ class LettaAgentBatch:
        self.actor = actor
        self.max_steps = max_steps

    @trace_method
    async def step_until_request(
        self,
        batch_requests: List[LettaBatchRequest],
        letta_batch_job_id: str,
        agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
    ) -> LettaBatchResponse:
        # Basic checks
        log_event(name="validate_inputs")
        if not batch_requests:
            raise ValueError("Empty list of batch_requests passed in!")
        if agent_step_state_mapping is None:
            agent_step_state_mapping = {}

        log_event(name="load_and_prepare_agents")
        agent_messages_mapping: Dict[str, List[Message]] = {}
        agent_tools_mapping: Dict[str, List[dict]] = {}
        agent_states = []

        for batch_request in batch_requests:
            agent_id = batch_request.agent_id
            agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
            agent_states.append(agent_state)

            agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
                agent_state=agent_state, input_messages=batch_request.messages
            )

            # TODO: Think about a cleaner way to do this?
            if agent_id not in agent_step_state_mapping:
                agent_step_state_mapping[agent_id] = AgentStepState(
                    step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
                )

            agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(
                agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver
            )
            agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver)

        # TODO: This is a hack, this is because LLM client expects a LLM config
        # TODO: But that doesn't really work in batch land
        # TODO: @caren will factor this out
        log_event(name="init_llm_client")
        llm_client = LLMClient.create(
            llm_config=agent_states[0].llm_config,
            put_inner_thoughts_first=True,
        )
        agent_llm_config_mapping = {agent_state.id: agent_state.llm_config for agent_state in agent_states}
        agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}

        log_event(name="send_llm_batch_request")
        batch_response = await llm_client.send_llm_batch_request_async(
            agent_messages_mapping=agent_messages_mapping,
            agent_tools_mapping=agent_tools_mapping,
            agent_llm_config_mapping=agent_llm_config_mapping,
        )

        # Write the response into the jobs table, where it will get picked up by the next cron run
        log_event(name="persist_llm_batch_job")
        llm_batch_job = self.batch_manager.create_llm_batch_job(
            llm_provider=ProviderType.anthropic,  # TODO: Expand to more providers
            create_batch_response=batch_response,
@@ -177,24 +177,26 @@ class LettaAgentBatch:
            letta_batch_job_id=letta_batch_job_id,
        )

        # Create batch items in bulk for all agents
        log_event(name="prepare_batch_items")
        batch_items = []
        for agent_state in agent_states:
            agent_step_state = agent_step_state_mapping.get(agent_state.id)
            batch_item = LLMBatchItem(
                llm_batch_id=llm_batch_job.id,
                agent_id=agent_state.id,
                llm_config=agent_state.llm_config,
                request_status=JobStatus.created,
                step_status=AgentStepStatus.paused,
                step_state=agent_step_state,
        for state in agent_states:
            step_state = agent_step_state_mapping[state.id]
            batch_items.append(
                LLMBatchItem(
                    llm_batch_id=llm_batch_job.id,
                    agent_id=state.id,
                    llm_config=state.llm_config,
                    request_status=JobStatus.created,
                    step_status=AgentStepStatus.paused,
                    step_state=step_state,
                )
            )
            batch_items.append(batch_item)

        # Create all batch items at once using the bulk operation
        if batch_items:
            log_event(name="bulk_create_batch_items")
            self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)

        log_event(name="return_batch_response")
        return LettaBatchResponse(
            letta_batch_id=llm_batch_job.letta_batch_job_id,
            last_llm_batch_id=llm_batch_job.id,
@@ -204,27 +206,27 @@ class LettaAgentBatch:
            created_at=llm_batch_job.created_at,
        )

    @trace_method
    async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse:
        # 1. gather everything we need
        log_event(name="load_context")
        llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor)
        ctx = await self._collect_resume_context(llm_batch_id)

        # 2. persist request-level status updates
        log_event(name="update_statuses")
        self._update_request_statuses(ctx.request_status_updates)

        # 3. run the tools in parallel
        log_event(name="exec_tools")
        exec_results = await self._execute_tools(ctx)

        # 4. create + save assistant/tool messages
        log_event(name="persist_messages")
        msg_map = self._persist_tool_messages(exec_results, ctx)

        # 5. mark steps complete
        log_event(name="mark_steps_done")
        self._mark_steps_complete(llm_batch_id, ctx.agent_ids)

        # 6. build next-round requests / step-state map
        log_event(name="prepare_next")
        next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
        if len(next_reqs) == 0:
            # mark batch job as completed
            self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor)
            return LettaBatchResponse(
                letta_batch_id=llm_batch_job.letta_batch_job_id,
@@ -235,15 +237,16 @@ class LettaAgentBatch:
                created_at=llm_batch_job.created_at,
            )

        # 7. recurse into the normal stepping pipeline
        return await self.step_until_request(
            batch_requests=next_reqs,
            letta_batch_job_id=letta_batch_id,
            agent_step_state_mapping=next_step_state,
        )

    @trace_method
    async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
        batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id)
        # NOTE: We only continue for items with successful results
        batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id, request_status=JobStatus.completed)

        agent_ids, agent_state_map = [], {}
        provider_results, name_map, args_map, cont_map = {}, {}, {}, {}
@@ -300,6 +303,7 @@ class LettaAgentBatch:
        env = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(cfg.id, actor=self.actor, limit=100)
        return cfg, env

    @trace_method
    async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
        sbx_cfg, sbx_env = self._build_sandbox()
        params = [
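Editor's note (not part of the commit): the two entry points above form a poll-style loop. step_until_request submits one provider-side batch for all agents and persists job state; a later invocation of resume_step_after_request picks up completed items, runs tools, and either completes the Letta batch job or recurses into the next round. A hedged sketch of driving this from an async context; the driver function itself and its trigger (normally a cron/poller) are assumptions:

import asyncio

async def drive_batch(batch: "LettaAgentBatch", requests, letta_batch_job_id: str):
    # First round: send all agents' prompts as one provider-side batch.
    resp = await batch.step_until_request(
        batch_requests=requests,
        letta_batch_job_id=letta_batch_job_id,
    )
    # Later, once the provider batch finishes: execute tools, persist messages,
    # and either finish the job or start the next round.
    resp = await batch.resume_step_after_request(
        letta_batch_id=letta_batch_job_id,
        llm_batch_id=resp.last_llm_batch_id,
    )
    return resp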
@@ -32,6 +32,7 @@ from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.run import Run
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.source import Source, SourceCreate, SourceUpdate
@@ -100,6 +101,7 @@ class AbstractClient(object):
        message_ids: Optional[List[str]] = None,
        memory: Optional[Memory] = None,
        tags: Optional[List[str]] = None,
        response_format: Optional[ResponseFormatUnion] = None,
    ):
        raise NotImplementedError

@@ -553,6 +555,7 @@ class RESTClient(AbstractClient):
        initial_message_sequence: Optional[List[Message]] = None,
        tags: Optional[List[str]] = None,
        message_buffer_autoclear: bool = False,
        response_format: Optional[ResponseFormatUnion] = None,
    ) -> AgentState:
        """Create an agent

@@ -615,6 +618,7 @@ class RESTClient(AbstractClient):
            "include_base_tools": include_base_tools,
            "message_buffer_autoclear": message_buffer_autoclear,
            "include_multi_agent_tools": include_multi_agent_tools,
            "response_format": response_format,
        }

        # Only add name if it's not None
@@ -653,6 +657,7 @@ class RESTClient(AbstractClient):
        embedding_config: Optional[EmbeddingConfig] = None,
        message_ids: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        response_format: Optional[ResponseFormatUnion] = None,
    ) -> AgentState:
        """
        Update an existing agent
@@ -682,6 +687,7 @@ class RESTClient(AbstractClient):
            llm_config=llm_config,
            embedding_config=embedding_config,
            message_ids=message_ids,
            response_format=response_format,
        )
        response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}", json=request.model_dump(), headers=self.headers)
        if response.status_code != 200:
@@ -2425,6 +2431,7 @@ class LocalClient(AbstractClient):
        llm_config: Optional[LLMConfig] = None,
        embedding_config: Optional[EmbeddingConfig] = None,
        message_ids: Optional[List[str]] = None,
        response_format: Optional[ResponseFormatUnion] = None,
    ):
        """
        Update an existing agent
@@ -2458,6 +2465,7 @@ class LocalClient(AbstractClient):
            llm_config=llm_config,
            embedding_config=embedding_config,
            message_ids=message_ids,
            response_format=response_format,
        ),
        actor=self.user,
    )
@@ -2661,7 +2669,7 @@ class LocalClient(AbstractClient):
            response (LettaResponse): Response from the agent
        """
        self.interface.clear()
        usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages)
        usage = self.server.send_messages(actor=self.user, agent_id=agent_id, input_messages=messages)

        # format messages
        return LettaResponse(messages=messages, usage=usage)
@@ -2703,7 +2711,7 @@ class LocalClient(AbstractClient):
        usage = self.server.send_messages(
            actor=self.user,
            agent_id=agent_id,
            messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
            input_messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
        )

        ## TODO: need to make sure date/timestamp is propely passed
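Editor's note (not part of the commit): on the client side, response_format is now threaded through agent creation and update. A hedged sketch of setting it after the fact; the update_agent method name, create_client() defaults, and the no-argument JsonObjectResponseFormat constructor are assumptions based on the surrounding code, not shown verbatim in this diff:

from letta.client.client import create_client
from letta.schemas.response_format import JsonObjectResponseFormat

client = create_client()  # LocalClient by default (assumption)

client.update_agent(
    agent_id="agent-00000000-0000-0000-0000-000000000000",  # placeholder id
    response_format=JsonObjectResponseFormat(),              # constructor defaults assumed
)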
@@ -47,13 +47,14 @@ DEFAULT_PERSONA = "sam_pov"
DEFAULT_HUMAN = "basic"
DEFAULT_PRESET = "memgpt_chat"

SEND_MESSAGE_TOOL_NAME = "send_message"
# Base tools that cannot be edited, as they access agent state directly
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
BASE_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_insert", "archival_memory_search"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Base tools if the memgpt agent has enable_sleeptime on
BASE_SLEEPTIME_CHAT_TOOLS = ["send_message", "conversation_search", "archival_memory_search"]
BASE_SLEEPTIME_CHAT_TOOLS = [SEND_MESSAGE_TOOL_NAME, "conversation_search", "archival_memory_search"]
# Base memory tools for sleeptime agent
BASE_SLEEPTIME_TOOLS = [
    "memory_replace",
@@ -72,7 +73,7 @@ LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_S
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent)
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL = SEND_MESSAGE_TOOL_NAME
DEFAULT_MESSAGE_TOOL_KWARG = "message"

PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
@@ -9,7 +9,6 @@ from letta.functions.helpers import (
    extract_send_message_from_steps_messages,
    fire_and_forget_send_to_agent,
)
from letta.helpers.message_helper import prepare_input_message_create
from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate
from letta.server.rest_api.utils import get_letta_server
@@ -109,11 +108,10 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:

    # Prepare the message
    messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
    input_messages = [prepare_input_message_create(m, agent_id) for m in messages]

    # Run .step() and return the response
    usage_stats = agent.step(
        messages=input_messages,
        input_messages=messages,
        chaining=True,
        max_chaining_steps=None,
        stream=False,
@@ -352,7 +352,7 @@ async def send_message_to_agent_no_stream(
    server: "SyncServer",
    agent_id: str,
    actor: User,
    messages: Union[List[Message], List[MessageCreate]],
    messages: List[MessageCreate],
    metadata: Optional[dict] = None,
) -> LettaResponse:
    """
@@ -368,7 +368,7 @@ async def send_message_to_agent_no_stream(
        server.send_messages,
        actor=actor,
        agent_id=agent_id,
        messages=messages,
        input_messages=messages,
        interface=interface,
        metadata=metadata,
    )
@@ -478,7 +478,7 @@ def fire_and_forget_send_to_agent(
        await server.send_message_to_agent(
            agent_id=other_agent_id,
            actor=sender_agent.user,
            messages=messages,
            input_messages=messages,
            stream_steps=False,
            stream_tokens=False,
            use_assistant_message=True,
@@ -35,7 +35,7 @@ class DynamicMultiAgent(Agent):

    def step(
        self,
        messages: List[MessageCreate],
        input_messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
@@ -43,27 +43,43 @@ class DynamicMultiAgent(Agent):
    ) -> LettaUsageStatistics:
        total_usage = UsageStatistics()
        step_count = 0
        speaker_id = None

        # Load settings
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        agents = {}
        # Load agents and initialize chat history with indexing
        agents = {self.agent_state.id: self.load_manager_agent()}
        message_index = {self.agent_state.id: 0}
        agents[self.agent_state.id] = self.load_manager_agent()
        chat_history: List[MessageCreate] = []
        for agent_id in self.agent_ids:
            agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
            message_index[agent_id] = 0

        chat_history: List[Message] = []
        new_messages = messages
        speaker_id = None
        # Prepare new messages
        new_messages = []
        for message in input_messages:
            if isinstance(message.content, str):
                message.content = [TextContent(text=message.content)]
            message.group_id = self.group_id
            new_messages.append(message)

        try:
            for _ in range(self.max_turns):
                # Prepare manager message
                agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id]
                manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options)
                manager_message = self.ask_manager_to_choose_participant_message(
                    manager_agent_id=self.agent_state.id,
                    new_messages=new_messages,
                    chat_history=chat_history,
                    agent_id_options=agent_id_options,
                )

                # Perform manager step
                manager_agent = agents[self.agent_state.id]
                usage_stats = manager_agent.step(
                    messages=[manager_message],
                    input_messages=[manager_message],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
@@ -71,42 +87,27 @@ class DynamicMultiAgent(Agent):
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # Parse manager response
                responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
                assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
                for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
                    if name.lower() in assistant_message.content.lower():
                        speaker_id = agent_id

                # sum usage
                # Sum usage
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

                # initialize input messages
                for message in chat_history[message_index[speaker_id] :]:
                    message.id = Message.generate_id()
                    message.agent_id = speaker_id
                # Update chat history
                chat_history.extend(new_messages)

                for message in new_messages:
                    chat_history.append(
                        Message(
                            agent_id=speaker_id,
                            role=message.role,
                            content=[TextContent(text=message.content)],
                            name=message.name,
                            model=None,
                            tool_calls=None,
                            tool_call_id=None,
                            group_id=self.group_id,
                            otid=message.otid,
                        )
                    )

                # load agent and perform step
                # Perform participant step
                participant_agent = agents[speaker_id]
                usage_stats = participant_agent.step(
                    messages=chat_history[message_index[speaker_id] :],
                    input_messages=chat_history[message_index[speaker_id] :],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
@@ -115,54 +116,54 @@ class DynamicMultiAgent(Agent):
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # parse new messages for next step
                # Parse participant response
                responses = Message.to_letta_messages_from_list(
                    participant_agent.last_response_messages,
                )

                assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
                new_messages = [
                    MessageCreate(
                        role="system",
                        content=message.content,
                        content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
                        name=participant_agent.agent_state.name,
                        otid=message.otid,
                        sender_id=participant_agent.agent_state.id,
                        group_id=self.group_id,
                    )
                    for message in assistant_messages
                ]

                # Update message index
                message_index[speaker_id] = len(chat_history) + len(new_messages)

                # sum usage
                # Sum usage
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

                # check for termination token
                # Check for termination token
                if any(self.termination_token in message.content for message in new_messages):
                    break

            # persist remaining chat history
            for message in new_messages:
                chat_history.append(
                    Message(
                        agent_id=agent_id,
                        role=message.role,
                        content=[TextContent(text=message.content)],
                        name=message.name,
                        model=None,
                        tool_calls=None,
                        tool_call_id=None,
                        group_id=self.group_id,
                    )
                )
            # Persist remaining chat history
            chat_history.extend(new_messages)
            for agent_id, index in message_index.items():
                if agent_id == speaker_id:
                    continue
                messages_to_persist = []
                for message in chat_history[index:]:
                    message.id = Message.generate_id()
                    message.agent_id = agent_id
                self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
                    message_to_persist = Message(
                        role=message.role,
                        content=message.content,
                        name=message.name,
                        otid=message.otid,
                        sender_id=message.sender_id,
                        group_id=message.group_id,
                        agent_id=agent_id,
                    )
                    messages_to_persist.append(message_to_persist)
                self.message_manager.create_many_messages(messages_to_persist, actor=self.user)

        except Exception as e:
            raise e
@@ -249,10 +250,11 @@ class DynamicMultiAgent(Agent):

    def ask_manager_to_choose_participant_message(
        self,
        manager_agent_id: str,
        new_messages: List[MessageCreate],
        chat_history: List[Message],
        agent_id_options: List[str],
    ) -> Message:
    ) -> MessageCreate:
        text_chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
        for message in new_messages:
            text_chat_history.append(f"{message.name or 'user'}: {message.content}")
@@ -264,14 +266,11 @@ class DynamicMultiAgent(Agent):
            "respond to the messages yourself, your task is only to decide the "
            f"next speaker, not to participate. \nChat history:\n{context_messages}"
        )
        return Message(
            agent_id=self.agent_state.id,
        return MessageCreate(
            role="user",
            content=[TextContent(text=message_text)],
            name=None,
            model=None,
            tool_calls=None,
            tool_call_id=None,
            group_id=self.group_id,
            otid=Message.generate_otid(),
            sender_id=manager_agent_id,
            group_id=self.group_id,
        )
@@ -29,7 +29,7 @@ class RoundRobinMultiAgent(Agent):

    def step(
        self,
        messages: List[MessageCreate],
        input_messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
@@ -37,46 +37,39 @@ class RoundRobinMultiAgent(Agent):
    ) -> LettaUsageStatistics:
        total_usage = UsageStatistics()
        step_count = 0
        speaker_id = None

        # Load settings
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        agents = {}
        # Load agents and initialize chat history with indexing
        agents, message_index = {}, {}
        chat_history: List[MessageCreate] = []
        for agent_id in self.agent_ids:
            agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
            message_index[agent_id] = 0

        # Prepare new messages
        new_messages = []
        for message in input_messages:
            if isinstance(message.content, str):
                message.content = [TextContent(text=message.content)]
            message.group_id = self.group_id
            new_messages.append(message)

        message_index = {agent_id: 0 for agent_id in self.agent_ids}
        chat_history: List[Message] = []
        new_messages = messages
        speaker_id = None
        try:
            for i in range(self.max_turns):
                # Select speaker
                speaker_id = self.agent_ids[i % len(self.agent_ids)]
                # initialize input messages
                start_index = message_index[speaker_id] if speaker_id in message_index else 0
                for message in chat_history[start_index:]:
                    message.id = Message.generate_id()
                    message.agent_id = speaker_id

                for message in new_messages:
                    chat_history.append(
                        Message(
                            agent_id=speaker_id,
                            role=message.role,
                            content=[TextContent(text=message.content)],
                            name=message.name,
                            model=None,
                            tool_calls=None,
                            tool_call_id=None,
                            group_id=self.group_id,
                            otid=message.otid,
                        )
                    )
                # Update chat history
                chat_history.extend(new_messages)

                # load agent and perform step
                # Perform participant step
                participant_agent = agents[speaker_id]
                usage_stats = participant_agent.step(
                    messages=chat_history[start_index:],
                    input_messages=chat_history[message_index[speaker_id] :],
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    stream=token_streaming,
@@ -85,47 +78,48 @@ class RoundRobinMultiAgent(Agent):
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )

                # parse new messages for next step
                # Parse participant response
                responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages)
                assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
                new_messages = [
                    MessageCreate(
                        role="system",
                        content=message.content,
                        name=message.name,
                        content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
                        name=participant_agent.agent_state.name,
                        otid=message.otid,
                        sender_id=participant_agent.agent_state.id,
                        group_id=self.group_id,
                    )
                    for message in assistant_messages
                ]

                # Update message index
                message_index[speaker_id] = len(chat_history) + len(new_messages)

                # sum usage
                # Sum usage
                total_usage.prompt_tokens += usage_stats.prompt_tokens
                total_usage.completion_tokens += usage_stats.completion_tokens
                total_usage.total_tokens += usage_stats.total_tokens
                step_count += 1

            # persist remaining chat history
            for message in new_messages:
                chat_history.append(
                    Message(
                        agent_id=agent_id,
                        role=message.role,
                        content=[TextContent(text=message.content)],
                        name=message.name,
                        model=None,
                        tool_calls=None,
                        tool_call_id=None,
                        group_id=self.group_id,
                    )
                )
            # Persist remaining chat history
            chat_history.extend(new_messages)
            for agent_id, index in message_index.items():
                if agent_id == speaker_id:
                    continue
                messages_to_persist = []
                for message in chat_history[index:]:
                    message.id = Message.generate_id()
                    message.agent_id = agent_id
                self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
                    message_to_persist = Message(
                        role=message.role,
                        content=message.content,
                        name=message.name,
                        otid=message.otid,
                        sender_id=message.sender_id,
                        group_id=self.group_id,
                        agent_id=agent_id,
                    )
                    messages_to_persist.append(message_to_persist)
                self.message_manager.create_many_messages(messages_to_persist, actor=self.user)

        except Exception as e:
            raise e
@@ -143,8 +143,21 @@ class SleeptimeMultiAgent(Agent):
                    group_id=self.group_id,
                )
            ]

            # Convert Message objects to MessageCreate objects
            message_creates = [
                MessageCreate(
                    role=m.role,
                    content=m.content[0].text if m.content and len(m.content) == 1 else m.content,
                    name=m.name,
                    otid=m.otid,
                    sender_id=m.sender_id,
                )
                for m in participant_agent_messages
            ]

            result = participant_agent.step(
                messages=participant_agent_messages,
                input_messages=message_creates,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
@@ -173,7 +186,7 @@ class SleeptimeMultiAgent(Agent):

    def step(
        self,
        messages: List[MessageCreate],
        input_messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
@@ -181,33 +194,28 @@ class SleeptimeMultiAgent(Agent):
    ) -> LettaUsageStatistics:
        run_ids = []

        # Load settings
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        messages = [
            Message(
                id=Message.generate_id(),
                agent_id=self.agent_state.id,
                role=message.role,
                content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
                name=message.name,
                model=None,
                tool_calls=None,
                tool_call_id=None,
                group_id=self.group_id,
                otid=message.otid,
            )
            for message in messages
        ]
        # Prepare new messages
        new_messages = []
        for message in input_messages:
            if isinstance(message.content, str):
                message.content = [TextContent(text=message.content)]
            message.group_id = self.group_id
            new_messages.append(message)

        try:
            # Load main agent
            main_agent = Agent(
                agent_state=self.agent_state,
                interface=self.interface,
                user=self.user,
            )
            # Perform main agent step
            usage_stats = main_agent.step(
                messages=messages,
                input_messages=new_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
@@ -216,10 +224,12 @@ class SleeptimeMultiAgent(Agent):
                put_inner_thoughts_first=put_inner_thoughts_first,
            )

            # Update turns counter
            turns_counter = None
            if self.sleeptime_agent_frequency is not None and self.sleeptime_agent_frequency > 0:
                turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user)

            # Perform participant steps
            if self.sleeptime_agent_frequency is None or (
                turns_counter is not None and turns_counter % self.sleeptime_agent_frequency == 0
            ):
@@ -9,7 +9,7 @@ from letta.interface import AgentInterface
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
@@ -37,17 +37,18 @@ class SupervisorMultiAgent(Agent):

    def step(
        self,
        messages: List[MessageCreate],
        input_messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        **kwargs,
    ) -> LettaUsageStatistics:
        # Load settings
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        # add multi agent tool
        # Prepare supervisor agent
        if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
            multi_agent_tool = Tool(
                name=send_message_to_all_agents_in_group.__name__,
@@ -64,7 +65,6 @@ class SupervisorMultiAgent(Agent):
            )
            self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)

        # override tool rules
        old_tool_rules = self.agent_state.tool_rules
        self.agent_state.tool_rules = [
            InitToolRule(
@@ -79,24 +79,25 @@ class SupervisorMultiAgent(Agent):
            ),
        ]

        supervisor_messages = [
            Message(
                agent_id=self.agent_state.id,
                role="user",
                content=[TextContent(text=message.content)],
                name=None,
                model=None,
                tool_calls=None,
                tool_call_id=None,
                group_id=self.group_id,
                otid=message.otid,
            )
            for message in messages
        ]
        # Prepare new messages
        new_messages = []
        for message in input_messages:
            if isinstance(message.content, str):
                message.content = [TextContent(text=message.content)]
            message.group_id = self.group_id
            new_messages.append(message)

        try:
            supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
            # Load supervisor agent
            supervisor_agent = Agent(
                agent_state=self.agent_state,
                interface=self.interface,
                user=self.user,
            )

            # Perform supervisor step
            usage_stats = supervisor_agent.step(
                messages=supervisor_messages,
                input_messages=new_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
@ -22,6 +22,13 @@ from letta.schemas.letta_message_content import (
)
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import ToolReturn
from letta.schemas.response_format import (
    JsonObjectResponseFormat,
    JsonSchemaResponseFormat,
    ResponseFormatType,
    ResponseFormatUnion,
    TextResponseFormat,
)
from letta.schemas.tool_rule import (
    ChildToolRule,
    ConditionalToolRule,

@ -371,3 +378,25 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
        return None

    return AgentStepState(**data)


# --------------------------
# Response Format Serialization
# --------------------------


def serialize_response_format(response_format: Optional[ResponseFormatUnion]) -> Optional[Dict[str, Any]]:
    if not response_format:
        return None
    return response_format.model_dump(mode="json")


def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormatUnion]:
    if not data:
        return None
    if data["type"] == ResponseFormatType.text:
        return TextResponseFormat(**data)
    if data["type"] == ResponseFormatType.json_schema:
        return JsonSchemaResponseFormat(**data)
    if data["type"] == ResponseFormatType.json_object:
        return JsonObjectResponseFormat(**data)
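A quick round-trip sketch of the two new converters, assuming they are importable from letta.helpers.converters as the hunk above indicates:

from letta.helpers.converters import deserialize_response_format, serialize_response_format
from letta.schemas.response_format import JsonObjectResponseFormat

fmt = JsonObjectResponseFormat()
payload = serialize_response_format(fmt)         # e.g. {"type": "json_object"}
restored = deserialize_response_format(payload)  # back to a JsonObjectResponseFormat instance
assert isinstance(restored, JsonObjectResponseFormat)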
@ -40,4 +40,5 @@ def prepare_input_message_create(
        tool_call_id=None,
        otid=message.otid,
        sender_id=message.sender_id,
        group_id=message.group_id,
    )
@ -160,12 +160,12 @@ def execute_external_tool(
    else:
        agent_state_copy = None

    sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
    function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
    tool_execution_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
    function_response, updated_agent_state = tool_execution_result.func_return, tool_execution_result.agent_state
    # TODO: Bring this back
    # if allow_agent_state_modifications and updated_agent_state is not None:
    # self.update_memory_if_changed(updated_agent_state.memory)
    return function_response, sandbox_run_result
    return function_response, tool_execution_result
except Exception as e:
    # Need to catch error here, or else trunction wont happen
    # TODO: modify to function execution error
@ -5,7 +5,7 @@ from sqlalchemy import JSON, Boolean, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.block import Block
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ResponseFormatColumn, ToolRulesColumn
from letta.orm.identity import Identity
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization

@ -15,6 +15,7 @@ from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.tool_rule import ToolRule

if TYPE_CHECKING:

@ -48,6 +49,11 @@ class Agent(SqlalchemyBase, OrganizationMixin):
    # This is dangerously flexible with the JSON type
    message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")

    # Response Format
    response_format: Mapped[Optional[ResponseFormatUnion]] = mapped_column(
        ResponseFormatColumn, nullable=True, doc="The response format for the agent."
    )

    # Metadata and configs
    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
    llm_config: Mapped[Optional[LLMConfig]] = mapped_column(

@ -168,6 +174,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
            "multi_agent_group": None,
            "tool_exec_environment_variables": [],
            "enable_sleeptime": None,
            "response_format": self.response_format,
        }

        # Optional fields: only included if requested
@ -9,6 +9,7 @@ from letta.helpers.converters import (
    deserialize_llm_config,
    deserialize_message_content,
    deserialize_poll_batch_response,
    deserialize_response_format,
    deserialize_tool_calls,
    deserialize_tool_returns,
    deserialize_tool_rules,

@ -20,6 +21,7 @@ from letta.helpers.converters import (
    serialize_llm_config,
    serialize_message_content,
    serialize_poll_batch_response,
    serialize_response_format,
    serialize_tool_calls,
    serialize_tool_returns,
    serialize_tool_rules,

@ -168,3 +170,16 @@ class AgentStepStateColumn(TypeDecorator):

    def process_result_value(self, value, dialect):
        return deserialize_agent_step_state(value)


class ResponseFormatColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a list of ToolRules as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return serialize_response_format(value)

    def process_result_value(self, value, dialect):
        return deserialize_response_format(value)
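A hand-driven sketch of how SQLAlchemy exercises the new column type: process_bind_param runs on write and process_result_value on read. The dialect argument is unused by these implementations, so passing None is fine for illustration.

from letta.orm.custom_columns import ResponseFormatColumn
from letta.schemas.response_format import TextResponseFormat

col = ResponseFormatColumn()
stored = col.process_bind_param(TextResponseFormat(), dialect=None)  # -> {"type": "text"}
loaded = col.process_result_value(stored, dialect=None)              # -> TextResponseFormat instance
assert loaded == TextResponseFormat()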
@ -14,6 +14,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ToolRule

@ -66,6 +67,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
    # llm information
    llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
    embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
    response_format: Optional[ResponseFormatUnion] = Field(
        None, description="The response format used by the agent when returning from `send_message`."
    )

    # This is an object representing the in-process state of a running `Agent`
    # Field in this object can be theoretically edited by tools, and will be persisted by the ORM

@ -180,6 +184,7 @@ class CreateAgent(BaseModel, validate_assignment=True):  #
        description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
    )
    enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")

    @field_validator("name")
    @classmethod

@ -259,6 +264,7 @@ class UpdateAgent(BaseModel):
        None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
    )
    enable_sleeptime: Optional[bool] = Field(None, description="If set to True, memory management will move to a background agent thread.")
    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the agent.")

    class Config:
        extra = "ignore"  # Ignores extra fields
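A hypothetical end-to-end usage sketch of the new field: the payload shape follows the schemas above, but the exact client call signature is assumed rather than taken from this diff.

from letta_client import Letta

client = Letta(base_url="http://localhost:8283")
agent = client.agents.create(
    model="openai/gpt-4o-mini",                # assumed model handle
    embedding="openai/text-embedding-ada-002",
    response_format={"type": "json_object"},   # validated into JsonObjectResponseFormat server-side
)
print(agent.response_format)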
@ -82,6 +82,7 @@ class MessageCreate(BaseModel):
    name: Optional[str] = Field(None, description="The name of the participant.")
    otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
    sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
    group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")

    def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
        data = super().model_dump(**kwargs)
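A small sketch of the new group_id field on MessageCreate; the id values are made up.

from letta.schemas.message import MessageCreate

msg = MessageCreate(role="user", content="status update please", group_id="group-1234", sender_id="agent-5678")
print(msg.model_dump()["group_id"])  # "group-1234"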
letta/schemas/response_format.py (new file, 78 lines)

@ -0,0 +1,78 @@
from enum import Enum
from typing import Annotated, Any, Dict, Literal, Union

from pydantic import BaseModel, Field, validator


class ResponseFormatType(str, Enum):
    """Enum defining the possible response format types."""

    text = "text"
    json_schema = "json_schema"
    json_object = "json_object"


class ResponseFormat(BaseModel):
    """Base class for all response formats."""

    type: ResponseFormatType = Field(
        ...,
        description="The type of the response format.",
        # why use this?
        example=ResponseFormatType.text,
    )


# ---------------------
# Response Format Types
# ---------------------

# SQLAlchemy type for database mapping
ResponseFormatDict = Dict[str, Any]


class TextResponseFormat(ResponseFormat):
    """Response format for plain text responses."""

    type: Literal[ResponseFormatType.text] = Field(
        ResponseFormatType.text,
        description="The type of the response format.",
    )


class JsonSchemaResponseFormat(ResponseFormat):
    """Response format for JSON schema-based responses."""

    type: Literal[ResponseFormatType.json_schema] = Field(
        ResponseFormatType.json_schema,
        description="The type of the response format.",
    )
    json_schema: Dict[str, Any] = Field(
        ...,
        description="The JSON schema of the response.",
    )

    @validator("json_schema")
    def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Validate that the provided schema is a valid JSON schema."""
        if not isinstance(v, dict):
            raise ValueError("JSON schema must be a dictionary")
        if "schema" not in v:
            raise ValueError("JSON schema should include a $schema property")
        return v


class JsonObjectResponseFormat(ResponseFormat):
    """Response format for JSON object responses."""

    type: Literal[ResponseFormatType.json_object] = Field(
        ResponseFormatType.json_object,
        description="The type of the response format.",
    )


# Pydantic type for validation
ResponseFormatUnion = Annotated[
    Union[TextResponseFormat | JsonSchemaResponseFormat | JsonObjectResponseFormat],
    Field(discriminator="type"),
]
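A construction sketch for the schemas above. Note the validator as written requires a "schema" key inside json_schema; the payload below is made up to satisfy it, and re-parsing through the discriminated union is shown with pydantic's TypeAdapter.

from pydantic import TypeAdapter

from letta.schemas.response_format import JsonSchemaResponseFormat, ResponseFormatUnion

fmt = JsonSchemaResponseFormat(
    json_schema={
        "name": "reply",
        "schema": {"type": "object", "properties": {"text": {"type": "string"}}},
    }
)
parsed = TypeAdapter(ResponseFormatUnion).validate_python(fmt.model_dump(mode="json"))
assert isinstance(parsed, JsonSchemaResponseFormat)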
letta/schemas/tool_execution_result.py (new file, 14 lines)

@ -0,0 +1,14 @@
from typing import Any, List, Literal, Optional

from pydantic import BaseModel, Field

from letta.schemas.agent import AgentState


class ToolExecutionResult(BaseModel):
    status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
    func_return: Optional[Any] = Field(None, description="The function return object")
    agent_state: Optional[AgentState] = Field(None, description="The agent state")
    stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
    stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
    sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
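A minimal sketch of the wrapper that the executors and sandbox now return in place of SandboxRunResult; the field values here are illustrative.

from letta.schemas.tool_execution_result import ToolExecutionResult

ok = ToolExecutionResult(status="success", func_return={"answer": 42}, stdout=["ran fine"])
err = ToolExecutionResult(status="error", func_return="ValueError: boom", stderr=["Traceback (most recent call last): ..."])
print(ok.status, err.status)  # success error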
@ -1240,10 +1240,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
    and function_call.function.name == self.assistant_message_tool_name
    and self.assistant_message_tool_kwarg in func_args
):
    # Coerce content to `str` in cases where it's a JSON due to `response_format` being a JSON
    processed_chunk = AssistantMessage(
        id=msg_obj.id,
        date=msg_obj.created_at,
        content=func_args[self.assistant_message_tool_kwarg],
        content=str(func_args[self.assistant_message_tool_kwarg]),
        name=msg_obj.name,
        otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
    )
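Why the str(...) coercion matters: with a JSON response_format the parsed send_message argument can be a dict rather than a string, while AssistantMessage.content expects text. A standalone illustration (values made up):

func_args = {"message": {"text": "hello", "mood": "cheerful"}}
content = str(func_args["message"])  # coerced exactly as in the hunk above
print(content)  # "{'text': 'hello', 'mood': 'cheerful'}"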
@ -111,7 +111,7 @@ async def send_message_to_agent_chat_completions(
        server.send_messages,
        actor=actor,
        agent_id=letta_agent.agent_state.id,
        messages=messages,
        input_messages=messages,
        interface=streaming_interface,
        put_inner_thoughts_first=False,
    )
@ -412,7 +412,7 @@ def list_blocks(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
return agent.memory.blocks
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -640,7 +640,7 @@ async def send_message(
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
input_messages=request.messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
# Support for AssistantMessage
|
||||
@ -703,7 +703,7 @@ async def send_message_streaming(
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
input_messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
@ -730,7 +730,7 @@ async def process_message_background(
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=messages,
|
||||
input_messages=messages,
|
||||
stream_steps=False, # NOTE(matt)
|
||||
stream_tokens=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
|
@ -128,7 +128,7 @@ async def send_group_message(
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
input_messages=request.messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
# Support for AssistantMessage
|
||||
@ -167,7 +167,7 @@ async def send_group_message_streaming(
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
input_messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
|
@ -7,7 +7,7 @@ from starlette.requests import Request
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.job import BatchJob, JobStatus, JobType
|
||||
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
|
||||
from letta.schemas.letta_request import CreateBatch
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
@ -43,18 +43,18 @@ async def create_messages_batch(
|
||||
if length > max_bytes:
|
||||
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
|
||||
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.running,
|
||||
metadata={
|
||||
"job_type": "batch_messages",
|
||||
},
|
||||
callback_url=str(payload.callback_url),
|
||||
)
|
||||
|
||||
# Create a new job
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.created,
|
||||
metadata={
|
||||
"job_type": "batch_messages",
|
||||
},
|
||||
callback_url=str(payload.callback_url),
|
||||
)
|
||||
try:
|
||||
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
|
||||
|
||||
# create the batch runner
|
||||
batch_runner = LettaAgentBatch(
|
||||
@ -67,14 +67,17 @@ async def create_messages_batch(
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
)
|
||||
llm_batch_job = await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
|
||||
await batch_runner.step_until_request(batch_requests=payload.requests, letta_batch_job_id=batch_job.id)
|
||||
|
||||
# TODO: update run metadata
|
||||
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print("Error creating batch job", e)
|
||||
traceback.print_exc()
|
||||
|
||||
# mark job as failed
|
||||
server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
|
||||
raise
|
||||
return batch_job
|
||||
|
||||
@ -125,8 +128,19 @@ async def cancel_batch_run(
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
|
||||
job.status = JobStatus.cancelled
|
||||
server.job_manager.update_job_by_id(job_id=job, job=job)
|
||||
# TODO: actually cancel it
|
||||
job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
|
||||
|
||||
# Get related llm batch jobs
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
|
||||
for llm_batch_job in llm_batch_jobs:
|
||||
if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
|
||||
# TODO: Extend to providers beyond anthropic
|
||||
# TODO: For now, we only support anthropic
|
||||
# Cancel the job
|
||||
anthropic_batch_id = llm_batch_job.create_batch_response.id
|
||||
await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
|
||||
|
||||
# Update all the batch_job statuses
|
||||
server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
@ -28,7 +28,6 @@ from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerCo
|
||||
from letta.groups.helpers import load_multi_agent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.message_helper import prepare_input_message_create
|
||||
|
||||
# TODO use custom interface
|
||||
from letta.interface import AgentInterface # abstract
|
||||
@ -148,7 +147,7 @@ class Server(object):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_messages(self, user_id: str, agent_id: str, messages: Union[MessageCreate, List[Message]]) -> None:
|
||||
def send_messages(self, user_id: str, agent_id: str, input_messages: List[MessageCreate]) -> None:
|
||||
"""Send a list of messages to the agent"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -372,19 +371,13 @@ class SyncServer(Server):
|
||||
self,
|
||||
actor: User,
|
||||
agent_id: str,
|
||||
input_messages: Union[Message, List[Message]],
|
||||
input_messages: List[MessageCreate],
|
||||
interface: Union[AgentInterface, None] = None, # needed to getting responses
|
||||
put_inner_thoughts_first: bool = True,
|
||||
# timestamp: Optional[datetime],
|
||||
) -> LettaUsageStatistics:
|
||||
"""Send the input message through the agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
# Input validation
|
||||
if isinstance(input_messages, Message):
|
||||
input_messages = [input_messages]
|
||||
if not all(isinstance(m, Message) for m in input_messages):
|
||||
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
|
||||
|
||||
logger.debug(f"Got input messages: {input_messages}")
|
||||
letta_agent = None
|
||||
try:
|
||||
@ -400,8 +393,9 @@ class SyncServer(Server):
|
||||
metadata = interface.metadata if hasattr(interface, "metadata") else None
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
usage_stats = letta_agent.step(
|
||||
messages=input_messages,
|
||||
input_messages=input_messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
@ -572,23 +566,14 @@ class SyncServer(Server):
|
||||
)
|
||||
|
||||
# NOTE: eventually deprecate and only allow passing Message types
|
||||
# Convert to a Message object
|
||||
if timestamp:
|
||||
message = Message(
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
content=[TextContent(text=packaged_user_message)],
|
||||
created_at=timestamp,
|
||||
)
|
||||
else:
|
||||
message = Message(
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
content=[TextContent(text=packaged_user_message)],
|
||||
)
|
||||
message = MessageCreate(
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
content=[TextContent(text=packaged_user_message)],
|
||||
)
|
||||
|
||||
# Run the agent state forward
|
||||
usage = self._step(actor=actor, agent_id=agent_id, input_messages=message)
|
||||
usage = self._step(actor=actor, agent_id=agent_id, input_messages=[message])
|
||||
return usage
|
||||
|
||||
def system_message(
|
||||
@ -660,23 +645,14 @@ class SyncServer(Server):
|
||||
self,
|
||||
actor: User,
|
||||
agent_id: str,
|
||||
messages: Union[List[MessageCreate], List[Message]],
|
||||
input_messages: List[MessageCreate],
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for responses
|
||||
metadata: Optional[dict] = None, # Pass through metadata to interface
|
||||
put_inner_thoughts_first: bool = True,
|
||||
) -> LettaUsageStatistics:
|
||||
"""Send a list of messages to the agent.
|
||||
|
||||
If messages are of type MessageCreate, convert them to Message objects before sending.
|
||||
"""
|
||||
if all(isinstance(m, MessageCreate) for m in messages):
|
||||
message_objects = [prepare_input_message_create(m, agent_id, wrap_user_message, wrap_system_message) for m in messages]
|
||||
elif all(isinstance(m, Message) for m in messages):
|
||||
message_objects = messages
|
||||
else:
|
||||
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(m) for m in messages]}")
|
||||
"""Send a list of messages to the agent."""
|
||||
|
||||
# Store metadata in interface if provided
|
||||
if metadata and hasattr(interface, "metadata"):
|
||||
@ -686,7 +662,7 @@ class SyncServer(Server):
|
||||
return self._step(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
input_messages=message_objects,
|
||||
input_messages=input_messages,
|
||||
interface=interface,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
@ -703,8 +679,6 @@ class SyncServer(Server):
|
||||
@trace_method
|
||||
def get_cached_llm_config(self, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
print(self._llm_config_cache)
|
||||
print("KEY", key)
|
||||
if key not in self._llm_config_cache:
|
||||
self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
|
||||
return self._llm_config_cache[key]
|
||||
@ -1019,12 +993,8 @@ class SyncServer(Server):
|
||||
agent = self.load_agent(agent_id=sleeptime_agent.id, actor=actor)
|
||||
for passage in self.list_data_source_passages(source_id=source.id, user_id=actor.id):
|
||||
agent.step(
|
||||
messages=[
|
||||
Message(
|
||||
role="user",
|
||||
content=[TextContent(text=passage.text)],
|
||||
agent_id=sleeptime_agent.id,
|
||||
),
|
||||
input_messages=[
|
||||
MessageCreate(role="user", content=passage.text),
|
||||
]
|
||||
)
|
||||
self.agent_manager.delete_agent(agent_id=sleeptime_agent.id, actor=actor)
|
||||
@ -1182,7 +1152,6 @@ class SyncServer(Server):
|
||||
provider = self.get_provider_from_name(provider_name)
|
||||
|
||||
llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
|
||||
print("LLM CONFIGS", llm_configs)
|
||||
if not llm_configs:
|
||||
llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
|
||||
if not llm_configs:
|
||||
@ -1195,8 +1164,6 @@ class SyncServer(Server):
|
||||
if not llm_configs:
|
||||
raise e
|
||||
|
||||
print("CONFIGS", llm_configs)
|
||||
|
||||
if len(llm_configs) == 1:
|
||||
llm_config = llm_configs[0]
|
||||
elif len(llm_configs) > 1:
|
||||
@ -1343,17 +1310,17 @@ class SyncServer(Server):
|
||||
|
||||
# Next, attempt to run the tool with the sandbox
|
||||
try:
|
||||
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state, additional_env_vars=tool_env_vars
|
||||
)
|
||||
return ToolReturnMessage(
|
||||
id="null",
|
||||
tool_call_id="null",
|
||||
date=get_utc_time(),
|
||||
status=sandbox_run_result.status,
|
||||
tool_return=str(sandbox_run_result.func_return),
|
||||
stdout=sandbox_run_result.stdout,
|
||||
stderr=sandbox_run_result.stderr,
|
||||
status=tool_execution_result.status,
|
||||
tool_return=str(tool_execution_result.func_return),
|
||||
stdout=tool_execution_result.stdout,
|
||||
stderr=tool_execution_result.stderr,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -1567,7 +1534,7 @@ class SyncServer(Server):
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
# role: MessageRole,
|
||||
messages: Union[List[Message], List[MessageCreate]],
|
||||
input_messages: List[MessageCreate],
|
||||
stream_steps: bool,
|
||||
stream_tokens: bool,
|
||||
# related to whether or not we return `LettaMessage`s or `Message`s
|
||||
@ -1647,7 +1614,7 @@ class SyncServer(Server):
|
||||
self.send_messages,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
messages=messages,
|
||||
input_messages=input_messages,
|
||||
interface=streaming_interface,
|
||||
metadata=metadata,
|
||||
)
|
||||
@ -1701,7 +1668,7 @@ class SyncServer(Server):
|
||||
self,
|
||||
group_id: str,
|
||||
actor: User,
|
||||
messages: Union[List[Message], List[MessageCreate]],
|
||||
input_messages: Union[List[Message], List[MessageCreate]],
|
||||
stream_steps: bool,
|
||||
stream_tokens: bool,
|
||||
chat_completion_mode: bool = False,
|
||||
@ -1751,7 +1718,7 @@ class SyncServer(Server):
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
letta_multi_agent.step,
|
||||
messages=messages,
|
||||
input_messages=input_messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
)
|
||||
|
@ -364,6 +364,7 @@ class AgentManager:
            "base_template_id": agent_update.base_template_id,
            "message_buffer_autoclear": agent_update.message_buffer_autoclear,
            "enable_sleeptime": agent_update.enable_sleeptime,
            "response_format": agent_update.response_format,
        }
        for col, val in scalar_updates.items():
            if val is not None:
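The update loop above only copies fields that were actually provided; a standalone illustration of that pattern with made-up values:

scalar_updates = {"enable_sleeptime": None, "response_format": {"type": "text"}}
for col, val in scalar_updates.items():
    if val is not None:
        print(f"would set agent.{col} = {val!r}")  # only response_format is applied here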
@ -291,9 +291,7 @@ class LLMBatchManager:
|
||||
return [item.to_pydantic() for item in results]
|
||||
|
||||
def bulk_update_llm_batch_items(
|
||||
self,
|
||||
llm_batch_id_agent_id_pairs: List[Tuple[str, str]],
|
||||
field_updates: List[Dict[str, Any]],
|
||||
self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Efficiently update multiple LLMBatchItem rows by (llm_batch_id, agent_id) pairs.
|
||||
@ -301,30 +299,43 @@ class LLMBatchManager:
|
||||
Args:
|
||||
llm_batch_id_agent_id_pairs: List of (llm_batch_id, agent_id) tuples identifying items to update
|
||||
field_updates: List of dictionaries containing the fields to update for each item
|
||||
strict: Whether to error if any of the requested keys don't exist (default True).
|
||||
If False, missing pairs are skipped.
|
||||
"""
|
||||
if not llm_batch_id_agent_id_pairs or not field_updates:
|
||||
return
|
||||
|
||||
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
|
||||
raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length")
|
||||
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Lookup primary keys
|
||||
# Lookup primary keys for all requested (batch_id, agent_id) pairs
|
||||
items = (
|
||||
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
|
||||
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
|
||||
.all()
|
||||
)
|
||||
pair_to_pk = {(b, a): id for id, b, a in items}
|
||||
pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items}
|
||||
|
||||
if strict:
|
||||
requested = set(llm_batch_id_agent_id_pairs)
|
||||
found = set(pair_to_pk.keys())
|
||||
missing = requested - found
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"Cannot bulk-update batch items: no records for the following " f"(llm_batch_id, agent_id) pairs: {missing}"
|
||||
)
|
||||
|
||||
# Build mappings, skipping any missing when strict=False
|
||||
mappings = []
|
||||
for (llm_batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
|
||||
pk_id = pair_to_pk.get((llm_batch_id, agent_id))
|
||||
if not pk_id:
|
||||
for (batch_id, agent_id), fields in zip(llm_batch_id_agent_id_pairs, field_updates):
|
||||
pk = pair_to_pk.get((batch_id, agent_id))
|
||||
if pk is None:
|
||||
# skip missing in non-strict mode
|
||||
continue
|
||||
|
||||
update_fields = fields.copy()
|
||||
update_fields["id"] = pk_id
|
||||
update_fields["id"] = pk
|
||||
mappings.append(update_fields)
|
||||
|
||||
if mappings:
|
||||
@ -332,10 +343,7 @@ class LLMBatchManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_batch_llm_items_results_by_agent(
|
||||
self,
|
||||
updates: List[ItemUpdateInfo],
|
||||
) -> None:
|
||||
def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update request status and batch results for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [
|
||||
@ -346,29 +354,23 @@ class LLMBatchManager:
|
||||
for update in updates
|
||||
]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_llm_batch_items_step_status_by_agent(
|
||||
self,
|
||||
updates: List[StepStatusUpdateInfo],
|
||||
) -> None:
|
||||
def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update step status for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [{"step_status": update.step_status} for update in updates]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_llm_batch_items_request_status_by_agent(
|
||||
self,
|
||||
updates: List[RequestStatusUpdateInfo],
|
||||
) -> None:
|
||||
def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None:
|
||||
"""Update request status for multiple batch items."""
|
||||
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
|
||||
field_updates = [{"request_status": update.request_status} for update in updates]
|
||||
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates)
|
||||
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
|
||||
|
||||
@enforce_types
|
||||
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
|
||||
|
@ -1,16 +1,17 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.tool_executor.tool_executor import (
|
||||
ExternalComposioToolExecutor,
|
||||
ExternalMCPToolExecutor,
|
||||
LettaCoreToolExecutor,
|
||||
LettaMemoryToolExecutor,
|
||||
LettaMultiAgentToolExecutor,
|
||||
SandboxToolExecutor,
|
||||
ToolExecutor,
|
||||
@ -24,8 +25,9 @@ class ToolExecutorFactory:

    _executor_map: Dict[ToolType, Type[ToolExecutor]] = {
        ToolType.LETTA_CORE: LettaCoreToolExecutor,
        ToolType.LETTA_MEMORY_CORE: LettaCoreToolExecutor,
        ToolType.LETTA_SLEEPTIME_CORE: LettaCoreToolExecutor,
        ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
        ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor,
        ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
        ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
    }

@ -33,13 +35,8 @@ class ToolExecutorFactory:
    @classmethod
    def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
        """Get the appropriate executor for the given tool type."""
        executor_class = cls._executor_map.get(tool_type)

        if executor_class:
            return executor_class()

        # Default to sandbox executor for unknown types
        return SandboxToolExecutor()
        executor_class = cls._executor_map.get(tool_type, SandboxToolExecutor)
        return executor_class()


class ToolExecutionManager:
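The refactor above collapses the if/else fallback into dict.get's default argument; a generic illustration of the pattern with stand-in classes:

class SandboxExecutor: ...
class CoreExecutor: ...

executor_map = {"letta_core": CoreExecutor}
executor_cls = executor_map.get("unknown_tool_type", SandboxExecutor)  # falls back to the sandbox executor
print(executor_cls().__class__.__name__)  # SandboxExecutor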
@ -58,7 +55,7 @@ class ToolExecutionManager:
|
||||
self.sandbox_config = sandbox_config
|
||||
self.sandbox_env_vars = sandbox_env_vars
|
||||
|
||||
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute a tool and persist any state changes.
|
||||
|
||||
@ -71,36 +68,17 @@ class ToolExecutionManager:
|
||||
Tuple containing the function response and sandbox run result (if applicable)
|
||||
"""
|
||||
try:
|
||||
# Get the appropriate executor for this tool type
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
|
||||
# Execute the tool
|
||||
return executor.execute(
|
||||
function_name, function_args, self.agent_state, tool, self.actor, self.sandbox_config, self.sandbox_env_vars
|
||||
function_name,
|
||||
function_args,
|
||||
self.agent_state,
|
||||
tool,
|
||||
self.actor,
|
||||
self.sandbox_config,
|
||||
self.sandbox_env_vars,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
|
||||
@trace_method
|
||||
async def execute_tool_async(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""
|
||||
Execute a tool asynchronously and persist any state changes.
|
||||
"""
|
||||
try:
|
||||
# Get the appropriate executor for this tool type
|
||||
# TODO: Extend this async model to composio
|
||||
|
||||
if tool.tool_type == ToolType.CUSTOM:
|
||||
executor = SandboxToolExecutor()
|
||||
result_tuple = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
else:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
result_tuple = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
return result_tuple
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(
|
||||
@ -108,4 +86,35 @@ class ToolExecutionManager:
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def execute_tool_async(self, function_name: str, function_args: dict, tool: Tool) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute a tool asynchronously and persist any state changes.
|
||||
"""
|
||||
try:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
# TODO: Extend this async model to composio
|
||||
if isinstance(executor, SandboxToolExecutor):
|
||||
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
else:
|
||||
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(
|
||||
function_name=function_name,
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
|
@ -13,8 +13,9 @@ from typing import Any, Dict, Optional
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.helpers.tool_execution_helper import (
|
||||
add_imports_and_pydantic_schemas_for_args,
|
||||
@ -72,7 +73,11 @@ class ToolExecutionSandbox:
|
||||
self.force_recreate = force_recreate
|
||||
self.force_recreate_venv = force_recreate_venv
|
||||
|
||||
def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
|
||||
def run(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment.
|
||||
|
||||
@ -81,7 +86,7 @@ class ToolExecutionSandbox:
|
||||
additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
|
||||
ToolExecutionResult: Object containing tool execution outcome (e.g. status, response)
|
||||
"""
|
||||
if tool_settings.e2b_api_key and not self.privileged_tools:
|
||||
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
|
||||
@ -115,7 +120,7 @@ class ToolExecutionSandbox:
|
||||
@trace_method
|
||||
def run_local_dir_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
|
||||
local_configs = sbx_config.get_local_config()
|
||||
|
||||
@ -162,7 +167,12 @@ class ToolExecutionSandbox:
|
||||
os.remove(temp_file_path)
|
||||
|
||||
@trace_method
|
||||
def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
||||
def run_local_dir_sandbox_venv(
|
||||
self,
|
||||
sbx_config: SandboxConfig,
|
||||
env: Dict[str, str],
|
||||
temp_file_path: str,
|
||||
) -> ToolExecutionResult:
|
||||
local_configs = sbx_config.get_local_config()
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
|
||||
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
|
||||
@ -205,12 +215,12 @@ class ToolExecutionSandbox:
|
||||
func_result, stdout = self.parse_out_function_results_markers(result.stdout)
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=[stdout] if stdout else [],
|
||||
stderr=[result.stderr] if result.stderr else [],
|
||||
status="success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@ -221,12 +231,12 @@ class ToolExecutionSandbox:
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=func_return,
|
||||
agent_state=None,
|
||||
stdout=[e.stdout] if e.stdout else [],
|
||||
stderr=[e.stderr] if e.stderr else [],
|
||||
status="error",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@ -238,7 +248,12 @@ class ToolExecutionSandbox:
|
||||
raise e
|
||||
|
||||
@trace_method
|
||||
def run_local_dir_sandbox_directly(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
||||
def run_local_dir_sandbox_directly(
|
||||
self,
|
||||
sbx_config: SandboxConfig,
|
||||
env: Dict[str, str],
|
||||
temp_file_path: str,
|
||||
) -> ToolExecutionResult:
|
||||
status = "success"
|
||||
func_return, agent_state, stderr = None, None, None
|
||||
|
||||
@ -288,12 +303,12 @@ class ToolExecutionSandbox:
|
||||
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
|
||||
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status=status,
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=stdout_output,
|
||||
stderr=stderr_output,
|
||||
status=status,
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@ -307,7 +322,11 @@ class ToolExecutionSandbox:
|
||||
|
||||
# e2b sandbox specific functions
|
||||
|
||||
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
|
||||
def run_e2b_sandbox(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> ToolExecutionResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
||||
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
|
||||
if not sbx or self.force_recreate:
|
||||
@ -348,12 +367,12 @@ class ToolExecutionSandbox:
|
||||
else:
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status="error" if execution.error else "success",
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=execution.logs.stdout,
|
||||
stderr=execution.logs.stderr,
|
||||
status="error" if execution.error else "success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@ -535,7 +554,7 @@ class ToolExecutionSandbox:
|
||||
Generate the code string to call the function.
|
||||
|
||||
Args:
|
||||
inject_agent_state (bool): Whether to inject the axgent's state as an input into the tool
|
||||
inject_agent_state (bool): Whether to inject the agent's state as an input into the tool
|
||||
|
||||
Returns:
|
||||
str: Generated code string for calling the tool
|
||||
|
@ -1,15 +1,17 @@
|
||||
import math
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
@ -33,7 +35,7 @@ class ToolExecutor(ABC):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
"""Execute the tool and return the result."""
|
||||
|
||||
|
||||
@ -49,13 +51,19 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
"send_message": self.send_message,
|
||||
"conversation_search": self.conversation_search,
|
||||
"archival_memory_search": self.archival_memory_search,
|
||||
"archival_memory_insert": self.archival_memory_insert,
|
||||
"core_memory_append": self.core_memory_append,
|
||||
"core_memory_replace": self.core_memory_replace,
|
||||
"memory_replace": self.memory_replace,
|
||||
"memory_insert": self.memory_insert,
|
||||
"memory_rethink": self.memory_rethink,
|
||||
"memory_finish_edits": self.memory_finish_edits,
|
||||
}
|
||||
|
||||
if function_name not in function_map:
|
||||
@ -64,7 +72,10 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
# Execute the appropriate function
|
||||
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
|
||||
function_response = function_map[function_name](agent_state, actor, **function_args_copy)
|
||||
return function_response, None
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
def send_message(self, agent_state: AgentState, actor: User, message: str) -> Optional[str]:
|
||||
"""
|
||||
@ -181,51 +192,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
AgentManager().rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True)
|
||||
return None
|
||||
|
||||
|
||||
class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA multi-agent core tools."""
|
||||
|
||||
# TODO: Implement
|
||||
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[
|
||||
# Any, Optional[SandboxRunResult]]:
|
||||
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
# function_response = callable_func(**function_args)
|
||||
# return function_response, None
|
||||
|
||||
|
||||
class LettaMemoryToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA memory core tools with direct implementation."""
|
||||
|
||||
def execute(
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
"core_memory_append": self.core_memory_append,
|
||||
"core_memory_replace": self.core_memory_replace,
|
||||
}
|
||||
|
||||
if function_name not in function_map:
|
||||
raise ValueError(f"Unknown function: {function_name}")
|
||||
|
||||
# Execute the appropriate function with the copied state
|
||||
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
|
||||
function_response = function_map[function_name](agent_state, **function_args_copy)
|
||||
|
||||
# Update memory if changed
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
return function_response, None
|
||||
|
||||
def core_memory_append(self, agent_state: "AgentState", label: str, content: str) -> Optional[str]:
|
||||
def core_memory_append(self, agent_state: "AgentState", actor: User, label: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
@ -239,9 +206,17 @@ class LettaMemoryToolExecutor(ToolExecutor):
|
||||
current_value = str(agent_state.memory.get_block(label).value)
|
||||
new_value = current_value + "\n" + str(content)
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
return None
|
||||
|
||||
def core_memory_replace(self, agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]:
|
||||
def core_memory_replace(
|
||||
self,
|
||||
agent_state: "AgentState",
|
||||
actor: User,
|
||||
label: str,
|
||||
old_content: str,
|
||||
new_content: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
@ -258,8 +233,253 @@ class LettaMemoryToolExecutor(ToolExecutor):
|
||||
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
|
||||
new_value = current_value.replace(str(old_content), str(new_content))
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
return None
|
||||
|
||||
def memory_replace(
|
||||
agent_state: "AgentState",
|
||||
actor: User,
|
||||
label: str,
|
||||
old_str: str,
|
||||
new_str: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
The memory_replace command allows you to replace a specific string in a memory
|
||||
block with a new string. This is used for making precise edits.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited, identified by its label.
|
||||
old_str (str): The text to replace (must match exactly, including whitespace
|
||||
and indentation).
|
||||
new_str (Optional[str]): The new text to insert in place of the old text.
|
||||
Omit this argument to delete the old_str.
|
||||
|
||||
Returns:
|
||||
str: The success message
|
||||
"""
|
||||
import re
|
||||
|
||||
if bool(re.search(r"\nLine \d+: ", old_str)):
|
||||
raise ValueError(
|
||||
"old_str contains a line number prefix, which is not allowed. "
|
||||
"Do not include line numbers when calling memory tools (line "
|
||||
"numbers are for display purposes only)."
|
||||
)
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in old_str:
|
||||
raise ValueError(
|
||||
"old_str contains a line number warning, which is not allowed. "
|
||||
"Do not include line number information when calling memory tools "
|
||||
"(line numbers are for display purposes only)."
|
||||
)
|
||||
if bool(re.search(r"\nLine \d+: ", new_str)):
|
||||
raise ValueError(
|
||||
"new_str contains a line number prefix, which is not allowed. "
|
||||
"Do not include line numbers when calling memory tools (line "
|
||||
"numbers are for display purposes only)."
|
||||
)
|
||||
|
||||
old_str = str(old_str).expandtabs()
|
||||
new_str = str(new_str).expandtabs()
|
||||
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
|
||||
|
||||
# Check if old_str is unique in the block
|
||||
occurences = current_value.count(old_str)
|
||||
if occurences == 0:
|
||||
raise ValueError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear " f"verbatim in memory block with label `{label}`."
|
||||
)
|
||||
elif occurences > 1:
|
||||
content_value_lines = current_value.split("\n")
|
||||
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
|
||||
raise ValueError(
|
||||
f"No replacement was performed. Multiple occurrences of "
|
||||
f"old_str `{old_str}` in lines {lines}. Please ensure it is unique."
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_value = current_value.replace(str(old_str), str(new_str))
|
||||
|
||||
# Write the new content to the block
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
SNIPPET_LINES = 3
|
||||
replacement_line = current_value.split(old_str)[0].count("\n")
|
||||
start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||
snippet = "\n".join(new_value.split("\n")[start_line : end_line + 1])
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The core memory block with label `{label}` has been edited. "
|
||||
# success_msg += self._make_output(
|
||||
# snippet, f"a snippet of {path}", start_line + 1
|
||||
# )
|
||||
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
|
||||
success_msg += (
|
||||
"Review the changes and make sure they are as expected (correct indentation, "
|
||||
"no duplicate lines, etc). Edit the memory block again if necessary."
|
||||
)
|
||||
|
||||
# return None
|
||||
return success_msg
|
||||
|
||||
def memory_insert(
|
||||
agent_state: "AgentState",
|
||||
actor: User,
|
||||
label: str,
|
||||
new_str: str,
|
||||
insert_line: int = -1,
|
||||
) -> str:
|
||||
"""
|
||||
The memory_insert command allows you to insert text at a specific location
|
||||
in a memory block.
|
||||
|
||||
Args:
|
||||
label (str): Section of the memory to be edited, identified by its label.
|
||||
new_str (str): The text to insert.
|
||||
insert_line (int): The line number after which to insert the text (0 for
|
||||
beginning of file). Defaults to -1 (end of the file).
|
||||
|
||||
Returns:
|
||||
str: The success message
|
||||
"""
|
||||
import re
|
||||
|
||||
if bool(re.search(r"\nLine \d+: ", new_str)):
|
||||
raise ValueError(
|
||||
"new_str contains a line number prefix, which is not allowed. Do not "
|
||||
"include line numbers when calling memory tools (line numbers are for "
|
||||
"display purposes only)."
|
||||
)
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in new_str:
|
||||
raise ValueError(
|
||||
"new_str contains a line number warning, which is not allowed. Do not "
|
||||
"include line number information when calling memory tools (line numbers "
|
||||
"are for display purposes only)."
|
||||
)
|
||||
|
||||
current_value = str(agent_state.memory.get_block(label).value).expandtabs()
|
||||
new_str = str(new_str).expandtabs()
|
||||
current_value_lines = current_value.split("\n")
|
||||
n_lines = len(current_value_lines)
|
||||
|
||||
# The docstring and the error message below both document -1 as "append to the end",
# so normalize it to n_lines before range checking
if insert_line == -1:
insert_line = n_lines
# Check if we're in range, from 0 (pre-line), to 1 (first line), to n_lines (last line)
if insert_line < 0 or insert_line > n_lines:
raise ValueError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within "
f"the range of lines of the memory block: {[0, n_lines]}, or -1 to "
f"append to the end of the memory block."
)
|
||||
|
||||
# Insert the new string as a line
|
||||
SNIPPET_LINES = 3
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_value_lines = current_value_lines[:insert_line] + new_str_lines + current_value_lines[insert_line:]
|
||||
snippet_lines = (
|
||||
current_value_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ current_value_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
# Collate into the new value to update
|
||||
new_value = "\n".join(new_value_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
# Write into the block
|
||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The core memory block with label `{label}` has been edited. "
|
||||
# success_msg += self._make_output(
|
||||
# snippet,
|
||||
# "a snippet of the edited file",
|
||||
# max(1, insert_line - SNIPPET_LINES + 1),
|
||||
# )
|
||||
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
|
||||
success_msg += (
|
||||
"Review the changes and make sure they are as expected (correct indentation, "
|
||||
"no duplicate lines, etc). Edit the memory block again if necessary."
|
||||
)
|
||||
|
||||
return success_msg
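# Hedged usage sketch (illustrative only): with the same assumed `agent_state` and `actor`,
# appending a line to the end of the "human" block passes insert_line=-1 (documented as
# "end of the block"), while insert_line=0 prepends and any value in [0, n_lines] inserts
# after that line.
#
# msg = memory_insert(
#     agent_state=agent_state,
#     actor=actor,
#     label="human",
#     new_str="Prefers to be called Matt",
#     insert_line=-1,
# )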
|
||||
|
||||
def memory_rethink(agent_state: "AgentState", actor: User, label: str, new_memory: str) -> str:
|
||||
"""
|
||||
The memory_rethink command allows you to completely rewrite the contents of a
|
||||
memory block. Use this tool to make large sweeping changes (e.g. when you want
|
||||
to condense or reorganize the memory blocks), do NOT use this tool to make small
|
||||
precise edits (e.g. add or remove a line, replace a specific string, etc).
|
||||
|
||||
Args:
|
||||
label (str): The memory block to be rewritten, identified by its label.
|
||||
new_memory (str): The new memory contents with information integrated from
|
||||
existing memory blocks and the conversation context.
|
||||
|
||||
Returns:
|
||||
str: The success message
|
||||
"""
|
||||
import re
|
||||
|
||||
if bool(re.search(r"\nLine \d+: ", new_memory)):
|
||||
raise ValueError(
|
||||
"new_memory contains a line number prefix, which is not allowed. Do not "
|
||||
"include line numbers when calling memory tools (line numbers are for "
|
||||
"display purposes only)."
|
||||
)
|
||||
if CORE_MEMORY_LINE_NUMBER_WARNING in new_memory:
|
||||
raise ValueError(
|
||||
"new_memory contains a line number warning, which is not allowed. Do not "
|
||||
"include line number information when calling memory tools (line numbers "
|
||||
"are for display purposes only)."
|
||||
)
|
||||
|
||||
if agent_state.memory.get_block(label) is None:
|
||||
agent_state.memory.create_block(label=label, value=new_memory)
|
||||
|
||||
agent_state.memory.update_block_value(label=label, value=new_memory)
|
||||
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The core memory block with label `{label}` has been edited. "
|
||||
# success_msg += self._make_output(
|
||||
# snippet, f"a snippet of {path}", start_line + 1
|
||||
# )
|
||||
# success_msg += f"A snippet of core memory block `{label}`:\n{snippet}\n"
|
||||
success_msg += (
|
||||
"Review the changes and make sure they are as expected (correct indentation, "
|
||||
"no duplicate lines, etc). Edit the memory block again if necessary."
|
||||
)
|
||||
|
||||
# return None
|
||||
return success_msg
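# Hedged usage sketch (illustrative only): memory_rethink rewrites a block wholesale, so it
# suits consolidation passes rather than targeted string edits, e.g.:
#
# msg = memory_rethink(
#     agent_state=agent_state,
#     actor=actor,
#     label="human",
#     new_memory="Name: Matthew. Works at Letta. Prefers concise replies.",
# )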
|
||||
|
||||
def memory_finish_edits(agent_state: "AgentState") -> None:
|
||||
"""
|
||||
Call the memory_finish_edits command when you are finished making edits
|
||||
(integrating all new information) into the memory blocks. This function
|
||||
is called when the agent is done rethinking the memory.
|
||||
|
||||
Returns:
|
||||
None: Always returns None, as this function does not produce a response.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA multi-agent core tools."""
|
||||
|
||||
# TODO: Implement
|
||||
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> ToolExecutionResult:
|
||||
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
# function_response = callable_func(**function_args)
|
||||
# return ToolExecutionResult(func_return=function_response)
|
||||
|
||||
|
||||
class ExternalComposioToolExecutor(ToolExecutor):
|
||||
"""Executor for external Composio tools."""
|
||||
@ -273,7 +493,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
action_name = generate_composio_action_from_func_name(tool.name)
|
||||
|
||||
# Get entity ID from the agent_state
|
||||
@ -287,7 +507,10 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
|
||||
)
|
||||
|
||||
return function_response, None
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
def _get_entity_id(self, agent_state: AgentState) -> Optional[str]:
|
||||
"""Extract the entity ID from environment variables."""
|
||||
@ -302,8 +525,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
|
||||
# TODO: Implement
|
||||
#
|
||||
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> Tuple[
|
||||
# Any, Optional[SandboxRunResult]]:
|
||||
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> ToolExecutionResult:
|
||||
# # Get the server name from the tool tag
|
||||
# server_name = self._extract_server_name(tool)
|
||||
#
|
||||
@ -316,8 +538,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
# # Execute the tool
|
||||
# function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
|
||||
#
|
||||
# sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
|
||||
# return function_response, sandbox_run_result
|
||||
# return ToolExecutionResult(
|
||||
# status="error" if is_error else "success",
|
||||
# func_return=function_response,
|
||||
# )
|
||||
#
|
||||
# def _extract_server_name(self, tool: Tool) -> str:
|
||||
# """Extract server name from tool tags."""
|
||||
@ -360,7 +584,7 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
# Store original memory state
|
||||
orig_memory_str = agent_state.memory.compile()
|
||||
@ -381,21 +605,19 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
|
||||
sandbox_run_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
# Verify memory integrity
|
||||
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
|
||||
# Update agent memory if needed
|
||||
if updated_agent_state is not None:
|
||||
AgentManager().update_memory_if_changed(agent_state.id, updated_agent_state.memory, actor)
|
||||
if tool_execution_result.agent_state is not None:
|
||||
AgentManager().update_memory_if_changed(agent_state.id, tool_execution_result.agent_state.memory, actor)
|
||||
|
||||
return function_response, sandbox_run_result
|
||||
return tool_execution_result
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_execution_error(e, function_name)
|
||||
return self._handle_execution_error(e, function_name, traceback.format_exc())
|
||||
|
||||
def _prepare_function_args(self, function_args: dict, tool: Tool, function_name: str) -> dict:
|
||||
"""Prepare function arguments with proper type coercion."""
|
||||
@ -417,9 +639,18 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
agent_state_copy.tool_rules = []
|
||||
return agent_state_copy
|
||||
|
||||
def _handle_execution_error(self, exception: Exception, function_name: str) -> Tuple[str, SandboxRunResult]:
|
||||
def _handle_execution_error(
|
||||
self,
|
||||
exception: Exception,
|
||||
function_name: str,
|
||||
stderr: str,
|
||||
) -> ToolExecutionResult:
|
||||
"""Handle tool execution errors."""
|
||||
error_message = get_friendly_error_msg(
|
||||
function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception)
|
||||
)
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[stderr],
|
||||
)
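# Hedged sketch of how a caller might consume the ToolExecutionResult that this diff
# introduces in place of SandboxRunResult. Field names (status, func_return, stdout, stderr,
# agent_state) are taken from the constructor calls above; everything else here is assumed.
#
# result = await sandbox.run(agent_state=agent_state_copy)
# if result.status == "error":
#     print("tool failed:", result.stderr)
# else:
#     print("tool returned:", result.func_return)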
|
||||
|
@ -7,8 +7,9 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
@ -64,7 +65,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment asynchronously.
|
||||
Must be implemented by subclasses.
|
||||
|
@ -2,8 +2,9 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@ -30,7 +31,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment asynchronously,
|
||||
*always* using a subprocess for execution.
|
||||
@ -45,7 +46,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
|
||||
async def run_e2b_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
if self.provided_sandbox_config:
|
||||
sbx_config = self.provided_sandbox_config
|
||||
else:
|
||||
@ -94,7 +95,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
else:
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=execution.logs.stdout,
|
||||
|
@ -5,8 +5,9 @@ import tempfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.helpers.tool_execution_helper import (
|
||||
create_venv_for_local_sandbox,
|
||||
find_python_executable,
|
||||
@ -39,7 +40,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment asynchronously,
|
||||
*always* using a subprocess for execution.
|
||||
@ -53,7 +54,11 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
return result
|
||||
|
||||
@trace_method
|
||||
async def run_local_dir_sandbox(self, agent_state: Optional[AgentState], additional_env_vars: Optional[Dict]) -> SandboxRunResult:
|
||||
async def run_local_dir_sandbox(
|
||||
self,
|
||||
agent_state: Optional[AgentState],
|
||||
additional_env_vars: Optional[Dict],
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Unified asynchronous method to run the tool in a local sandbox environment,
|
||||
always via subprocess for multi-core parallelism.
|
||||
@ -156,7 +161,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
@trace_method
|
||||
async def _execute_tool_subprocess(
|
||||
self, sbx_config, python_executable: str, temp_file_path: str, env: Dict[str, str], cwd: str
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute user code in a subprocess, always capturing stdout and stderr.
|
||||
We parse special markers to extract the pickled result string.
|
||||
@ -189,7 +194,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
func_result, stdout_text = self.parse_out_function_results_markers(stdout)
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=[stdout_text] if stdout_text else [],
|
||||
@ -209,7 +214,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=None,
|
||||
stdout=[],
|
||||
|
poetry.lock (generated, 453 lines changed; diff suppressed because it is too large)
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.7.0"
|
||||
version = "0.7.1"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
|
@ -1,80 +1,140 @@
|
||||
# import time
|
||||
#
|
||||
# import pytest
|
||||
# from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(scope="module")
|
||||
# def client():
|
||||
# return Letta(base_url="http://localhost:8283")
|
||||
#
|
||||
#
|
||||
# def test_create_batch(client: Letta):
|
||||
#
|
||||
# # create agents
|
||||
# agent1 = client.agents.create(
|
||||
# name="agent1",
|
||||
# memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
|
||||
# model="anthropic/claude-3-7-sonnet-20250219",
|
||||
# embedding="letta/letta-free",
|
||||
# )
|
||||
# agent2 = client.agents.create(
|
||||
# name="agent2",
|
||||
# memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
|
||||
# model="anthropic/claude-3-7-sonnet-20250219",
|
||||
# embedding="letta/letta-free",
|
||||
# )
|
||||
#
|
||||
# # create a run
|
||||
# run = client.messages.batches.create(
|
||||
# requests=[
|
||||
# LettaBatchRequest(
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=[
|
||||
# TextContent(
|
||||
# text="text",
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
# ],
|
||||
# agent_id=agent1.id,
|
||||
# ),
|
||||
# LettaBatchRequest(
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=[
|
||||
# TextContent(
|
||||
# text="text",
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
# ],
|
||||
# agent_id=agent2.id,
|
||||
# ),
|
||||
# ]
|
||||
# )
|
||||
# assert run is not None
|
||||
#
|
||||
# # list batches
|
||||
# batches = client.messages.batches.list()
|
||||
# assert len(batches) > 0, f"Expected 1 batch, got {len(batches)}"
|
||||
#
|
||||
# # check run status
|
||||
# while True:
|
||||
# run = client.messages.batches.retrieve(batch_id=run.id)
|
||||
# if run.status == "completed":
|
||||
# break
|
||||
# print("Waiting for run to complete...", run.status)
|
||||
# time.sleep(1)
|
||||
#
|
||||
# # get the batch results
|
||||
# results = client.messages.batches.retrieve(
|
||||
# run_id=run.id,
|
||||
# )
|
||||
# assert results is not None
|
||||
# print(results)
|
||||
#
|
||||
# # cancel a run
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_batch_tables():
|
||||
"""Clear batch-related tables before each test."""
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables):
|
||||
if table.name in {"jobs", "llm_batch_job", "llm_batch_items"}:
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
|
||||
def run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server_url():
|
||||
"""
|
||||
Ensures a server is running and returns its base URL.
|
||||
|
||||
Uses environment variable if available, otherwise starts a server
|
||||
in a background thread.
|
||||
"""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5) # Give server time to start
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""
|
||||
Creates a SyncServer instance for testing.
|
||||
|
||||
Loads and saves config to ensure proper initialization.
|
||||
"""
|
||||
config = LettaConfig.load()
|
||||
config.save()
|
||||
return SyncServer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client connected to the test server."""
|
||||
return Letta(base_url=server_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch(client: Letta, server: SyncServer):
|
||||
|
||||
# create agents
|
||||
agent1 = client.agents.create(
|
||||
name="agent1_batch",
|
||||
memory_blocks=[{"label": "persona", "value": "you are agent 1"}],
|
||||
model="anthropic/claude-3-7-sonnet-20250219",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
agent2 = client.agents.create(
|
||||
name="agent2_batch",
|
||||
memory_blocks=[{"label": "persona", "value": "you are agent 2"}],
|
||||
model="anthropic/claude-3-7-sonnet-20250219",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
# create a run
|
||||
run = client.batches.create(
|
||||
requests=[
|
||||
LettaBatchRequest(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
TextContent(
|
||||
text="hi",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
agent_id=agent1.id,
|
||||
),
|
||||
LettaBatchRequest(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
TextContent(
|
||||
text="hi",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
agent_id=agent2.id,
|
||||
),
|
||||
]
|
||||
)
|
||||
assert run is not None
|
||||
|
||||
# list batches
|
||||
batches = client.batches.list()
|
||||
assert len(batches) == 1, f"Expected 1 batch, got {len(batches)}"
|
||||
assert batches[0].status == JobStatus.running
|
||||
|
||||
# Poll it once
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# get the batch results
|
||||
results = client.batches.retrieve(
|
||||
batch_id=run.id,
|
||||
)
|
||||
assert results is not None
|
||||
|
||||
# cancel
|
||||
client.batches.cancel(batch_id=run.id)
|
||||
batch_job = client.batches.retrieve(
|
||||
batch_id=run.id,
|
||||
)
|
||||
assert batch_job.status == JobStatus.cancelled
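# Hedged note (not part of the diff): as written, this test appears to need a running Letta
# server backed by a database plus Anthropic credentials for the batches API; the exact
# environment is an assumption, e.g. something along the lines of:
#
#   ANTHROPIC_API_KEY=... LETTA_SERVER_URL=http://localhost:8283 pytest -k test_create_batch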
|
||||
|
@ -67,9 +67,9 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
|
||||
tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
|
||||
function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
|
||||
)
|
||||
|
||||
# Small check, it should return something at least
|
||||
assert len(function_response.keys()) > 10
|
||||
assert len(tool_execution_result.func_return.keys()) > 10
|
||||
|
tests/integration_test_send_message_schema.py (new file, 192 lines)
@ -0,0 +1,192 @@
|
||||
# TODO (cliandy): Tested in SDK
|
||||
# TODO (cliandy): Comment out after merge
|
||||
|
||||
# import os
|
||||
# import threading
|
||||
# import time
|
||||
|
||||
# import pytest
|
||||
# from dotenv import load_dotenv
|
||||
# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool
|
||||
|
||||
# from letta.schemas.agent import AgentState
|
||||
# from typing import List, Any, Dict
|
||||
|
||||
# # ------------------------------
|
||||
# # Fixtures
|
||||
# # ------------------------------
|
||||
|
||||
|
||||
# @pytest.fixture(scope="module")
|
||||
# def server_url() -> str:
|
||||
# """
|
||||
# Provides the URL for the Letta server.
|
||||
# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
|
||||
# will start the Letta server in a background thread and return the default URL.
|
||||
# """
|
||||
|
||||
# def _run_server() -> None:
|
||||
# """Starts the Letta server in a background thread."""
|
||||
# load_dotenv() # Load environment variables from .env file
|
||||
# from letta.server.rest_api.app import start_server
|
||||
|
||||
# start_server(debug=True)
|
||||
|
||||
# # Retrieve server URL from environment, or default to localhost
|
||||
# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# # If no environment variable is set, start the server in a background thread
|
||||
# if not os.getenv("LETTA_SERVER_URL"):
|
||||
# thread = threading.Thread(target=_run_server, daemon=True)
|
||||
# thread.start()
|
||||
# time.sleep(5) # Allow time for the server to start
|
||||
|
||||
# return url
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def client(server_url: str) -> Letta:
|
||||
# """
|
||||
# Creates and returns a synchronous Letta REST client for testing.
|
||||
# """
|
||||
# client_instance = Letta(base_url=server_url)
|
||||
# yield client_instance
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def async_client(server_url: str) -> AsyncLetta:
|
||||
# """
|
||||
# Creates and returns an asynchronous Letta REST client for testing.
|
||||
# """
|
||||
# async_client_instance = AsyncLetta(base_url=server_url)
|
||||
# yield async_client_instance
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def roll_dice_tool(client: Letta) -> Tool:
|
||||
# """
|
||||
# Registers a simple roll dice tool with the provided client.
|
||||
|
||||
# The tool simulates rolling a six-sided die but returns a fixed result.
|
||||
# """
|
||||
|
||||
# def roll_dice() -> str:
|
||||
# """
|
||||
# Simulates rolling a die.
|
||||
|
||||
# Returns:
|
||||
# str: The roll result.
|
||||
# """
|
||||
# # Note: The result here is intentionally incorrect for demonstration purposes.
|
||||
# return "Rolled a 10!"
|
||||
|
||||
# tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
# yield tool
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
# """
|
||||
# Creates and returns an agent state for testing with a pre-configured agent.
|
||||
# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
# """
|
||||
# agent_state_instance = client.agents.create(
|
||||
# name="supervisor",
|
||||
# include_base_tools=True,
|
||||
# tool_ids=[roll_dice_tool.id],
|
||||
# model="openai/gpt-4o",
|
||||
# embedding="letta/letta-free",
|
||||
# tags=["supervisor"],
|
||||
# include_base_tool_rules=True,
|
||||
|
||||
# )
|
||||
# yield agent_state_instance
|
||||
|
||||
|
||||
# # Goal is to test that when an Agent is created with a `response_format`, that the response
|
||||
# # of `send_message` is in the correct format. This will be done by modifying the agent's
|
||||
# # `send_message` tool so that it returns a format based on what is passed in.
|
||||
# #
|
||||
# # `response_format` is an optional field
|
||||
# # if `response_format.type` is `text`, then the schema does not change
|
||||
# # if `response_format.type` is `json_object`, then the schema is a dict
|
||||
# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema
|
||||
|
||||
|
||||
# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}]
|
||||
|
||||
# # ------------------------------
|
||||
# # Test Cases
|
||||
# # ------------------------------
|
||||
|
||||
# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None:
|
||||
# """Test client send_message with response_format='json_object'."""
|
||||
# client.agents.modify(agent.id, response_format={"type": "text"})
|
||||
|
||||
# response = client.agents.messages.create_stream(
|
||||
# agent_id=agent.id,
|
||||
# messages=USER_MESSAGE,
|
||||
# )
|
||||
# messages = list(response)
|
||||
# assert isinstance(messages[-1], AssistantMessage)
|
||||
# assert isinstance(messages[-1].content, str)
|
||||
|
||||
|
||||
# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None:
|
||||
# """Test client send_message with response_format='json_object'."""
|
||||
# client.agents.modify(agent.id, response_format={"type": "json_object"})
|
||||
|
||||
# response = client.agents.messages.create_stream(
|
||||
# agent_id=agent.id,
|
||||
# messages=USER_MESSAGE,
|
||||
# )
|
||||
# messages = list(response)
|
||||
# assert isinstance(messages[-1], AssistantMessage)
|
||||
# assert isinstance(messages[-1].content, dict)
|
||||
|
||||
|
||||
# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None:
|
||||
# """Test client send_message with response_format='json_schema' and a valid schema."""
|
||||
# client.agents.modify(agent.id, response_format={
|
||||
# "type": "json_schema",
|
||||
# "json_schema": {
|
||||
# "name": "reasoning_schema",
|
||||
# "schema": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "steps": {
|
||||
# "type": "array",
|
||||
# "items": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "explanation": { "type": "string" },
|
||||
# "output": { "type": "string" }
|
||||
# },
|
||||
# "required": ["explanation", "output"],
|
||||
# "additionalProperties": False
|
||||
# }
|
||||
# },
|
||||
# "final_answer": { "type": "string" }
|
||||
# },
|
||||
# "required": ["steps", "final_answer"],
|
||||
# "additionalProperties": True
|
||||
# },
|
||||
# "strict": True
|
||||
# }
|
||||
# })
|
||||
# response = client.agents.messages.create_stream(
|
||||
# agent_id=agent.id,
|
||||
# messages=USER_MESSAGE,
|
||||
# )
|
||||
# messages = list(response)
|
||||
|
||||
# assert isinstance(messages[-1], AssistantMessage)
|
||||
# assert isinstance(messages[-1].content, dict)
|
||||
|
||||
|
||||
# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None:
|
||||
# # """Test client send_message with an invalid json_schema (should error or fallback)."""
|
||||
# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}}
|
||||
# # client.agents.modify(agent.id, response_format="json_schema")
|
||||
# # result: Any = client.agents.send_message(agent.id, "Test invalid schema")
|
||||
# # assert result is None or "error" in str(result).lower()
|
@ -132,7 +132,7 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
response = await server.send_message_to_agent(
|
||||
agent_id=main_agent.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=text,
|
||||
@ -206,7 +206,7 @@ async def test_sleeptime_removes_redundant_information(server, actor):
|
||||
_ = await server.send_message_to_agent(
|
||||
agent_id=main_agent.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=test_message,
|
||||
@ -270,7 +270,7 @@ async def test_sleeptime_edit(server, actor):
|
||||
_ = await server.send_message_to_agent(
|
||||
agent_id=sleeptime_agent.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Messi has now moved to playing for Inter Miami",
|
||||
|
@ -454,7 +454,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
append_copy_suffix = False
|
||||
server.send_messages(
|
||||
actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="hello")]
|
||||
actor=default_user, agent_id=serialize_test_agent.id, input_messages=[MessageCreate(role=MessageRole.user, content="hello")]
|
||||
)
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
@ -470,10 +470,12 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
|
||||
|
||||
# Make sure both agents can receive messages after
|
||||
server.send_messages(
|
||||
actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
|
||||
actor=default_user,
|
||||
agent_id=serialize_test_agent.id,
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="and hello again")],
|
||||
)
|
||||
server.send_messages(
|
||||
actor=other_user, agent_id=agent_copy.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
|
||||
actor=other_user, agent_id=agent_copy.id, input_messages=[MessageCreate(role=MessageRole.user, content="and hello again")]
|
||||
)
|
||||
|
||||
|
||||
@ -483,7 +485,7 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s
|
||||
server.send_messages(
|
||||
actor=default_user,
|
||||
agent_id=serialize_test_agent.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in San Francisco?")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in San Francisco?")],
|
||||
)
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
@ -501,12 +503,12 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s
|
||||
original_agent_response = server.send_messages(
|
||||
actor=default_user,
|
||||
agent_id=serialize_test_agent.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
|
||||
)
|
||||
copy_agent_response = server.send_messages(
|
||||
actor=other_user,
|
||||
agent_id=agent_copy.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like in Seattle?")],
|
||||
)
|
||||
|
||||
assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0
|
||||
@ -519,12 +521,12 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server
|
||||
server.send_messages(
|
||||
actor=default_user,
|
||||
agent_id=serialize_test_agent.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")],
|
||||
)
|
||||
server.send_messages(
|
||||
actor=default_user,
|
||||
agent_id=serialize_test_agent.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")],
|
||||
)
|
||||
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
@ -543,12 +545,12 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server
|
||||
original_agent_response = server.send_messages(
|
||||
actor=default_user,
|
||||
agent_id=serialize_test_agent.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="Hi")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="Hi")],
|
||||
)
|
||||
copy_agent_response = server.send_messages(
|
||||
actor=other_user,
|
||||
agent_id=agent_copy.id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="Hi")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="Hi")],
|
||||
)
|
||||
|
||||
assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0
|
||||
@ -635,5 +637,5 @@ def test_upload_agentfile_from_disk(server, disable_e2b_api_key, fastapi_client,
|
||||
server.send_messages(
|
||||
actor=other_user,
|
||||
agent_id=copied_agent_id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
|
||||
input_messages=[MessageCreate(role=MessageRole.user, content="Hello there!")],
|
||||
)
|
||||
|
@ -1,11 +1,3 @@
|
||||
"""
|
||||
Tests for LettaAgentBatch.step_until_request functionality.
|
||||
|
||||
This module tests the batch processing capabilities of LettaAgentBatch,
|
||||
specifically the step_until_request method which prepares agent requests
|
||||
for batch processing.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@ -92,6 +84,28 @@ def weather_tool(client):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def rethink_tool(client):
|
||||
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
|
||||
"""
|
||||
Re-evaluate the memory in block_name, integrating new and updated facts.
|
||||
Replace outdated information with the most likely truths, avoiding redundancy with original memories.
|
||||
Ensure consistency with other memory blocks.
|
||||
|
||||
Args:
|
||||
new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
|
||||
target_block_label (str): The name of the block to write to.
|
||||
Returns:
|
||||
str: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
|
||||
return None
|
||||
|
||||
tool = client.tools.upsert_from_function(func=rethink_memory)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agents(client, weather_tool):
|
||||
"""
|
||||
@ -146,7 +160,7 @@ def step_state_map(agents):
|
||||
Returns:
|
||||
Dict[str, AgentStepState]: Mapping of agent IDs to step states
|
||||
"""
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="get_weather")])
|
||||
return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
|
||||
|
||||
|
||||
@ -173,26 +187,7 @@ def create_batch_response(batch_id: str, processing_status: str = "in_progress")
|
||||
)
|
||||
|
||||
|
||||
def create_successful_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
|
||||
"""Create a dummy successful batch response."""
|
||||
return BetaMessageBatchIndividualResponse(
|
||||
custom_id=custom_id,
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
type="succeeded",
|
||||
message=BetaMessage(
|
||||
id="msg_abc123",
|
||||
role="assistant",
|
||||
type="message",
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
content=[{"type": "text", "text": "hi!"}],
|
||||
usage={"input_tokens": 5, "output_tokens": 7},
|
||||
stop_reason="end_turn",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
|
||||
def create_get_weather_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse:
|
||||
"""Create a dummy successful batch response with a tool call after user asks about weather."""
|
||||
return BetaMessageBatchIndividualResponse(
|
||||
custom_id=custom_id,
|
||||
@ -223,6 +218,39 @@ def create_complete_tool_response(custom_id: str, model: str, request_heartbeat:
|
||||
)
|
||||
|
||||
|
||||
def create_rethink_tool_response(
|
||||
custom_id: str, model: str, request_heartbeat: bool, new_memory: str, target_block_label: str
|
||||
) -> BetaMessageBatchIndividualResponse:
|
||||
"""Create a dummy successful batch response with a tool call after user asks about weather."""
|
||||
return BetaMessageBatchIndividualResponse(
|
||||
custom_id=custom_id,
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
type="succeeded",
|
||||
message=BetaMessage(
|
||||
id="msg_abc123",
|
||||
role="assistant",
|
||||
type="message",
|
||||
model=model,
|
||||
content=[
|
||||
{"type": "text", "text": "Let me rethink my memory."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tu_01234567890123456789012345",
|
||||
"name": "rethink_memory",
|
||||
"input": {
|
||||
"new_memory": new_memory,
|
||||
"target_block_label": target_block_label,
|
||||
"request_heartbeat": request_heartbeat,
|
||||
},
|
||||
},
|
||||
],
|
||||
usage={"input_tokens": 7, "output_tokens": 17},
|
||||
stop_reason="end_turn",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
|
||||
"""Create a dummy failed batch response with a rate limit error."""
|
||||
return BetaMessageBatchIndividualResponse(
|
||||
@ -340,6 +368,357 @@ class MockAsyncIterable:
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
|
||||
target_block_label = "human"
|
||||
new_memory = "banana"
|
||||
agent = client.agents.create(
|
||||
name=f"test_agent_rethink",
|
||||
include_base_tools=True,
|
||||
model=MODELS["sonnet"],
|
||||
tags=["test_agents"],
|
||||
embedding="letta/letta-free",
|
||||
tool_ids=[rethink_tool.id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": target_block_label,
|
||||
"value": "Name: Matt",
|
||||
},
|
||||
],
|
||||
)
|
||||
agents = [agent]
|
||||
batch_requests = [
|
||||
LettaBatchRequest(agent_id=agent.id, messages=[MessageCreate(role="user", content=[TextContent(text=f"Rethink memory.")])])
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
batch_id=anthropic_batch_id,
|
||||
)
|
||||
|
||||
# 1. Invoke `step_until_request`
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
# Create batch runner
|
||||
batch_runner = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Run the method under test
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="rethink_memory")])
|
||||
step_state_map = {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
|
||||
pre_resume_response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
|
||||
mock_items = [
|
||||
create_rethink_tool_response(
|
||||
custom_id=agent.id,
|
||||
model=agent.llm_config.model,
|
||||
request_heartbeat=False,
|
||||
new_memory=new_memory,
|
||||
target_block_label=target_block_label,
|
||||
)
|
||||
for agent in agents
|
||||
]
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy())
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# Check that the tool has been executed correctly
|
||||
agent = client.agents.retrieve(agent_id=agent.id)
|
||||
for block in agent.memory.blocks:
|
||||
if block.label == target_block_label:
|
||||
assert block.value == new_memory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_error_from_anthropic_batch(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
batch_id=anthropic_batch_id,
|
||||
)
|
||||
|
||||
# 1. Invoke `step_until_request`
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
# Create batch runner
|
||||
batch_runner = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Run the method under test
|
||||
pre_resume_response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
|
||||
agents_failed = agents[:1]
|
||||
agents_continue = agents[1:]
|
||||
# Create failed response for one agent
|
||||
mock_items = [create_failed_response(custom_id=agent.id) for agent in agents_failed]
|
||||
mock_items.extend(
|
||||
[
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
|
||||
for agent in agents_continue
|
||||
]
|
||||
)
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
|
||||
new_batch_responses = await poll_running_llm_batches(server)
|
||||
|
||||
# Verify database records were updated correctly
|
||||
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
|
||||
|
||||
# Verify job properties
|
||||
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
|
||||
# Verify only one new batch response
|
||||
assert len(new_batch_responses) == 1
|
||||
post_resume_response = new_batch_responses[0]
|
||||
|
||||
assert (
|
||||
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
|
||||
), "resume_step_after_request is expected to have the same letta_batch_id"
|
||||
assert (
|
||||
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
|
||||
), "resume_step_after_request is expected to have different llm_batch_id."
|
||||
assert post_resume_response.status == JobStatus.running
|
||||
# NOTE: We only expect 2 agents to continue (succeeded ones)
|
||||
assert post_resume_response.agent_count == 2
|
||||
|
||||
# New batch‑items should exist, initialised in (created, paused) state
|
||||
new_items = server.batch_manager.list_llm_batch_items(
|
||||
llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user
|
||||
)
|
||||
assert len(new_items) == 2, f"Expected 2 new batch items, got {len(new_items)}"
# Assert that the continuing agents are exactly the agents in the new items
assert {i.agent_id for i in new_items} == {a.id for a in agents_continue}
|
||||
assert {i.request_status for i in new_items} == {JobStatus.created}
|
||||
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
|
||||
|
||||
# Confirm that tool_rules_solver state was preserved correctly
|
||||
# Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history
|
||||
assert all(
|
||||
"get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
|
||||
), "Expected 'get_weather' in tool_call_history for all new_items"
|
||||
# Assert that each new item's step_number was incremented to 1
|
||||
assert all(
|
||||
item.step_state.step_number == 1 for item in new_items
|
||||
), "Expected step_number to be incremented to 1 for all new_items"
|
||||
|
||||
# Old items must have been flipped to completed / finished earlier
|
||||
# (sanity – we already asserted this above, but we keep it close for clarity)
|
||||
old_items = server.batch_manager.list_llm_batch_items(
|
||||
llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user
|
||||
)
|
||||
for item in old_items:
|
||||
if item.agent_id == agents_failed[0].id:
|
||||
assert item.request_status == JobStatus.failed
|
||||
assert item.step_status == AgentStepStatus.paused
|
||||
else:
|
||||
assert item.request_status == JobStatus.completed
|
||||
assert item.step_status == AgentStepStatus.completed
|
||||
|
||||
# Tool‑call side‑effects – each agent gets at least 2 extra messages
|
||||
for agent in agents:
|
||||
before = msg_counts_before[agent.id] # captured just before resume
|
||||
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
|
||||
|
||||
if agent.id == agents_failed[0].id:
|
||||
assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure"
|
||||
else:
|
||||
assert after - before >= 2, (
|
||||
f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted."
|
||||
)
|
||||
|
||||
# Check that agent states have been properly modified to have extended in-context messages
|
||||
for agent in agents:
|
||||
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
|
||||
if refreshed_agent.id == agents_failed[0].id:
|
||||
assert (
|
||||
len(refreshed_agent.message_ids) == 4
|
||||
), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
else:
|
||||
assert (
|
||||
len(refreshed_agent.message_ids) == 6
|
||||
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_some_stop(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
batch_id=anthropic_batch_id,
|
||||
)
|
||||
|
||||
# 1. Invoke `step_until_request`
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
# Create batch runner
|
||||
batch_runner = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Run the method under test
|
||||
pre_resume_response = await batch_runner.step_until_request(
|
||||
batch_requests=batch_requests,
|
||||
agent_step_state_mapping=step_state_map,
|
||||
letta_batch_job_id=batch_job.id,
|
||||
)
|
||||
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
|
||||
llm_batch_job = llm_batch_jobs[0]
|
||||
|
||||
# 2. Invoke the polling job and mock responses from Anthropic
|
||||
mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.letta_batch_id, processing_status="ended"))
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
|
||||
agents_continue = agents[:1]
|
||||
agents_finish = agents[1:]
|
||||
mock_items = [
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True)
|
||||
for agent in agents_continue
|
||||
]
|
||||
mock_items.extend(
|
||||
[
|
||||
create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=False)
|
||||
for agent in agents_finish
|
||||
]
|
||||
)
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
|
||||
new_batch_responses = await poll_running_llm_batches(server)

            # Verify database records were updated correctly
            llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)

            # Verify job properties
            assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"

            # Verify batch items
            items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
            assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
            assert all([item.request_status == JobStatus.completed for item in items])

            # Verify only one new batch response
            assert len(new_batch_responses) == 1
            post_resume_response = new_batch_responses[0]

            assert (
                post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
            ), "resume_step_after_request is expected to have the same letta_batch_id"
            assert (
                post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
            ), "resume_step_after_request is expected to have different llm_batch_id."
            assert post_resume_response.status == JobStatus.running
            # NOTE: We only expect 1 agent to continue
            assert post_resume_response.agent_count == 1

            # New batch‑items should exist, initialised in (created, paused) state
            new_items = server.batch_manager.list_llm_batch_items(
                llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user
            )
            assert len(new_items) == 1, f"Expected 1 new batch item, got {len(new_items)}"
            # Assert that the continuing agent is in the only item
            assert new_items[0].agent_id == agents_continue[0].id
            assert {i.request_status for i in new_items} == {JobStatus.created}
            assert {i.step_status for i in new_items} == {AgentStepStatus.paused}

            # Confirm that tool_rules_solver state was preserved correctly
            # Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history
            assert all(
                "get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
            ), "Expected 'get_weather' in tool_call_history for all new_items"
            # Assert that each new item's step_number was incremented to 1
            assert all(
                item.step_state.step_number == 1 for item in new_items
            ), "Expected step_number to be incremented to 1 for all new_items"

            # Old items must have been flipped to completed / finished earlier
            # (sanity – we already asserted this above, but we keep it close for clarity)
            old_items = server.batch_manager.list_llm_batch_items(
                llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user
            )
            assert {i.request_status for i in old_items} == {JobStatus.completed}
            assert {i.step_status for i in old_items} == {AgentStepStatus.completed}

            # Tool‑call side‑effects – each agent gets at least 2 extra messages
            for agent in agents:
                before = msg_counts_before[agent.id]  # captured just before resume
                after = server.message_manager.size(actor=default_user, agent_id=agent.id)
                assert after - before >= 2, (
                    f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted."
                )

            # Check that agent states have been properly modified to have extended in-context messages
            for agent in agents:
                refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
                assert (
                    len(refreshed_agent.message_ids) == 6
                ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"


@pytest.mark.asyncio
async def test_resume_step_after_request_all_continue(
    disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
@ -384,7 +763,7 @@ async def test_resume_step_after_request_all_continue(

    with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve):
        mock_items = [
            create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
            create_get_weather_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents
        ]

        # Create the mock for results
@ -460,7 +839,7 @@ async def test_resume_step_after_request_all_continue(
        refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
        assert (
            len(refreshed_agent.message_ids) == 6
        ), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
        ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"


@pytest.mark.asyncio
@ -518,7 +897,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
    # Verify tool configuration
    for agent_id, tools in agent_tools_mapping.items():
        available_tools = {tool["name"] for tool in tools}
        assert available_tools == {"send_message"}, f"Expected only send_message tool, got {available_tools}"
assert available_tools == {"get_weather"}, f"Expected only send_message tool, got {available_tools}"

    # Verify model assignments
    for agent_id, expected_model in expected_models.items():
@ -1,5 +1,6 @@
import os
import random
import re
import string
import time
from datetime import datetime, timedelta, timezone
@ -5211,6 +5212,47 @@ def test_bulk_update_batch_items_request_status_by_agent(
    assert updated.request_status == JobStatus.expired


def test_bulk_update_nonexistent_items_should_error(
    server,
    default_user,
    dummy_beta_message_batch,
    dummy_successful_response,
    letta_batch_job,
):
    # Create a batch job
    batch = server.batch_manager.create_llm_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
        letta_batch_job_id=letta_batch_job.id,
    )

    nonexistent_pairs = [(batch.id, "nonexistent-agent-id")]
    nonexistent_updates = [{"request_status": JobStatus.expired}]
    expected_err_msg = (
        f"Cannot bulk-update batch items: no records for the following "
        f"(llm_batch_id, agent_id) pairs: {{('{batch.id}', 'nonexistent-agent-id')}}"
    )
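
    # With the default strict behaviour, every bulk-update entry point below should raise for
    # unknown (llm_batch_id, agent_id) pairs; the strict=False variants are exercised in
    # test_bulk_update_nonexistent_items.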
    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)

    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
            [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
        )

    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
            [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
        )

    with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
        server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
            [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
        )


def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job):
    # Create a batch job
    batch = server.batch_manager.create_llm_batch_job(
@ -5227,22 +5269,22 @@ def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_
    nonexistent_updates = [{"request_status": JobStatus.expired}]

    # This should not raise an error, just silently skip non-existent items
    server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
    server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False)

    # Test with higher-level methods
    # Results by agent
    server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
        [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
        [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False
    )

    # Step status by agent
    server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
        [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
        [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False
    )

    # Request status by agent
    server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
        [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
        [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False
    )
@ -158,7 +158,7 @@ async def test_empty_group(server, actor):
    await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
        input_messages=[
            MessageCreate(
                role="user",
                content="what is everyone up to for the holidays?",
@ -246,7 +246,7 @@ async def test_round_robin(server, actor, participant_agents):
    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
        input_messages=[
            MessageCreate(
                role="user",
                content="what is everyone up to for the holidays?",
@ -301,7 +301,7 @@ async def test_round_robin(server, actor, participant_agents):
    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
        input_messages=[
            MessageCreate(
                role="user",
                content="what is everyone up to for the holidays?",
@ -367,7 +367,7 @@ async def test_supervisor(server, actor, participant_agents):
    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
        input_messages=[
            MessageCreate(
                role="user",
                content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
@ -449,7 +449,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen
    response = await server.send_group_message_to_agent(
        group_id=group.id,
        actor=actor,
        messages=[
        input_messages=[
            MessageCreate(role="user", content="what is everyone up to for the holidays?"),
        ],
        stream_steps=False,