chore: bump version 0.7.21 (#2653)

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
cthomas 2025-05-21 16:33:29 -07:00 committed by GitHub
parent 97e454124d
commit c0efe8ad0c
83 changed files with 4774 additions and 1734 deletions


@ -0,0 +1,31 @@
"""add provider category to steps
Revision ID: 6c53224a7a58
Revises: cc8dc340836d
Create Date: 2025-05-21 10:09:43.761669
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "6c53224a7a58"
down_revision: Union[str, None] = "cc8dc340836d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("steps", sa.Column("provider_category", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("steps", "provider_category")
# ### end Alembic commands ###
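For context, a revision like this one is applied and reverted through Alembic's normal upgrade/downgrade flow. A minimal sketch using the Alembic Python API (it assumes a conventional alembic.ini at the repository root; the equivalent CLI calls are "alembic upgrade head" and "alembic downgrade -1"):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "6c53224a7a58")      # adds steps.provider_category
# command.downgrade(cfg, "cc8dc340836d")  # drops the column again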


@ -0,0 +1,50 @@
"""add support for request and response jsons from llm providers
Revision ID: cc8dc340836d
Revises: 220856bbf43b
Create Date: 2025-05-19 14:25:41.999676
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "cc8dc340836d"
down_revision: Union[str, None] = "220856bbf43b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"provider_traces",
sa.Column("id", sa.String(), nullable=False),
sa.Column("request_json", sa.JSON(), nullable=False),
sa.Column("response_json", sa.JSON(), nullable=False),
sa.Column("step_id", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_step_id", "provider_traces", ["step_id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_step_id", table_name="provider_traces")
op.drop_table("provider_traces")
# ### end Alembic commands ###
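Once the table exists, traces can be read back with plain SQLAlchemy Core. A hedged sketch (the table definition below is hand-written from the columns in this migration; the real ORM model is not part of this diff):

import sqlalchemy as sa

metadata = sa.MetaData()
provider_traces = sa.Table(
    "provider_traces",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("request_json", sa.JSON, nullable=False),
    sa.Column("response_json", sa.JSON, nullable=False),
    sa.Column("step_id", sa.String),
    sa.Column("organization_id", sa.String, nullable=False),
)

def traces_for_step(conn, step_id: str):
    # The ix_step_id index created above keeps this lookup cheap.
    stmt = sa.select(provider_traces).where(provider_traces.c.step_id == step_id)
    return conn.execute(stmt).all()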


@ -1,4 +1,4 @@
__version__ = "0.7.20"
__version__ = "0.7.21"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client


@ -1,4 +1,6 @@
import asyncio
import json
import os
import time
import traceback
import warnings
@ -7,6 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.agents.helpers import generate_step_id
from letta.constants import (
CLI_WARNING_PREFIX,
COMPOSIO_ENTITY_ENV_VAR_KEY,
@ -16,6 +19,7 @@ from letta.constants import (
LETTA_CORE_TOOL_MODULE_NAME,
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
LLM_MAX_TOKENS,
READ_ONLY_BLOCK_EDIT_ERROR,
REQ_HEARTBEAT_MESSAGE,
SEND_MESSAGE_TOOL_NAME,
)
@ -41,7 +45,7 @@ from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, MessageCreate, ToolReturn
@ -61,9 +65,10 @@ from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.settings import summarizer_settings
from letta.settings import settings, summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
from letta.tracing import log_event, trace_method
@ -141,6 +146,7 @@ class Agent(BaseAgent):
self.agent_manager = AgentManager()
self.job_manager = JobManager()
self.step_manager = StepManager()
self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()
# State needed for heartbeat pausing
@ -298,6 +304,7 @@ class Agent(BaseAgent):
step_count: Optional[int] = None,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
step_id: Optional[str] = None,
) -> ChatCompletionResponse | None:
"""Get response from LLM API with robust retry mechanism."""
log_telemetry(self.logger, "_get_ai_reply start")
@ -347,8 +354,9 @@ class Agent(BaseAgent):
messages=message_sequence,
llm_config=self.agent_state.llm_config,
tools=allowed_functions,
stream=stream,
force_tool_call=force_tool_call,
telemetry_manager=self.telemetry_manager,
step_id=step_id,
)
else:
# Fallback to existing flow
@ -365,6 +373,9 @@ class Agent(BaseAgent):
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
name=self.agent_state.name,
telemetry_manager=self.telemetry_manager,
step_id=step_id,
actor=self.user,
)
log_telemetry(self.logger, "_get_ai_reply create finish")
@ -840,6 +851,9 @@ class Agent(BaseAgent):
# Extract job_id from metadata if present
job_id = metadata.get("job_id") if metadata else None
# Declare step_id for the given step to be used as the step is processing.
step_id = generate_step_id()
# Step 0: update core memory
# only pulling latest block data if shared memory is being used
current_persisted_memory = Memory(
@ -870,6 +884,7 @@ class Agent(BaseAgent):
step_count=step_count,
last_function_failed=last_function_failed,
put_inner_thoughts_first=put_inner_thoughts_first,
step_id=step_id,
)
if not response:
# EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
@ -944,6 +959,7 @@ class Agent(BaseAgent):
actor=self.user,
agent_id=self.agent_state.id,
provider_name=self.agent_state.llm_config.model_endpoint_type,
provider_category=self.agent_state.llm_config.provider_category or "base",
model=self.agent_state.llm_config.model,
model_endpoint=self.agent_state.llm_config.model_endpoint,
context_window_limit=self.agent_state.llm_config.context_window,
@ -953,6 +969,7 @@ class Agent(BaseAgent):
actor=self.user,
),
job_id=job_id,
step_id=step_id,
)
for message in all_new_messages:
message.step_id = step.id
@ -1255,6 +1272,276 @@ class Agent(BaseAgent):
functions_definitions=available_functions_definitions,
)
async def get_context_window_async(self) -> ContextWindowOverview:
if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION":
return await self.get_context_window_from_anthropic_async()
return await self.get_context_window_from_tiktoken_async()
async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview:
"""Get the context window of the agent"""
# Grab the in-context messages
# conversion of messages to OpenAI dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
# Extract system, memory and external summary
if (
len(in_context_messages) > 0
and in_context_messages[0].role == MessageRole.system
and in_context_messages[0].content
and len(in_context_messages[0].content) == 1
and isinstance(in_context_messages[0].content[0], TextContent)
):
system_message = in_context_messages[0].content[0].text
external_memory_marker_pos = system_message.find("###")
core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
system_prompt = system_message[:external_memory_marker_pos].strip()
external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
core_memory = system_message[core_memory_marker_pos:].strip()
else:
# if no markers found, put everything in system message
system_prompt = system_message
external_memory_summary = ""
core_memory = ""
else:
# if no system message, fall back on agent's system prompt
system_prompt = self.agent_state.system
external_memory_summary = ""
core_memory = ""
num_tokens_system = count_tokens(system_prompt)
num_tokens_core_memory = count_tokens(core_memory)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
# Check if there's a summary message in the message queue
if (
len(in_context_messages) > 1
and in_context_messages[1].role == MessageRole.user
and in_context_messages[1].content
and len(in_context_messages[1].content) == 1
and isinstance(in_context_messages[1].content[0], TextContent)
# TODO remove hardcoding
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
):
# Summary message exists
text_content = in_context_messages[1].content[0].text
assert text_content is not None
summary_memory = text_content
num_tokens_summary_memory = count_tokens(text_content)
# with a summary message, the real messages start at index 2
num_tokens_messages = (
num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
if len(in_context_messages_openai) > 2
else 0
)
else:
summary_memory = None
num_tokens_summary_memory = 0
# with no summary message, the real messages start at index 1
num_tokens_messages = (
num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
if len(in_context_messages_openai) > 1
else 0
)
# tokens taken up by function definitions
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
if agent_state_tool_jsons:
available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons]
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
else:
available_functions_definitions = []
num_tokens_available_functions_definitions = 0
num_tokens_used_total = (
num_tokens_system # system prompt
+ num_tokens_available_functions_definitions # function definitions
+ num_tokens_core_memory # core memory
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
+ num_tokens_summary_memory # summary of ongoing conversation
+ num_tokens_messages # tokens taken by messages
)
assert isinstance(num_tokens_used_total, int)
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(in_context_messages),
num_archival_memory=passage_manager_size,
num_recall_memory=message_manager_size,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
external_memory_summary=external_memory_summary,
# top-level information
context_window_size_max=self.agent_state.llm_config.context_window,
context_window_size_current=num_tokens_used_total,
# context window breakdown (in tokens)
num_tokens_system=num_tokens_system,
system_prompt=system_prompt,
num_tokens_core_memory=num_tokens_core_memory,
core_memory=core_memory,
num_tokens_summary_memory=num_tokens_summary_memory,
summary_memory=summary_memory,
num_tokens_messages=num_tokens_messages,
messages=in_context_messages,
# related to functions
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
functions_definitions=available_functions_definitions,
)
async def get_context_window_from_anthropic_async(self) -> ContextWindowOverview:
"""Get the context window of the agent"""
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=self.user)
model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None
# Grab the in-context messages
# conversion of messages to anthropic dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
)
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
# Extract system, memory and external summary
if (
len(in_context_messages) > 0
and in_context_messages[0].role == MessageRole.system
and in_context_messages[0].content
and len(in_context_messages[0].content) == 1
and isinstance(in_context_messages[0].content[0], TextContent)
):
system_message = in_context_messages[0].content[0].text
external_memory_marker_pos = system_message.find("###")
core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
system_prompt = system_message[:external_memory_marker_pos].strip()
external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
core_memory = system_message[core_memory_marker_pos:].strip()
else:
# if no markers found, put everything in system message
system_prompt = system_message
external_memory_summary = None
core_memory = None
else:
# if no system message, fall back on agent's system prompt
system_prompt = self.agent_state.system
external_memory_summary = None
core_memory = None
num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}])
num_tokens_core_memory_coroutine = (
anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}])
if core_memory
else asyncio.sleep(0, result=0)
)
num_tokens_external_memory_summary_coroutine = (
anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}])
if external_memory_summary
else asyncio.sleep(0, result=0)
)
# Check if there's a summary message in the message queue
if (
len(in_context_messages) > 1
and in_context_messages[1].role == MessageRole.user
and in_context_messages[1].content
and len(in_context_messages[1].content) == 1
and isinstance(in_context_messages[1].content[0], TextContent)
# TODO remove hardcoding
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
):
# Summary message exists
text_content = in_context_messages[1].content[0].text
assert text_content is not None
summary_memory = text_content
num_tokens_summary_memory_coroutine = anthropic_client.count_tokens(
model=model, messages=[{"role": "user", "content": summary_memory}]
)
# with a summary message, the real messages start at index 2
num_tokens_messages_coroutine = (
anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:])
if len(in_context_messages_anthropic) > 2
else asyncio.sleep(0, result=0)
)
else:
summary_memory = None
num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0)
# with no summary message, the real messages start at index 1
num_tokens_messages_coroutine = (
anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:])
if len(in_context_messages_anthropic) > 1
else asyncio.sleep(0, result=0)
)
# tokens taken up by function definitions
if self.agent_state.tools and len(self.agent_state.tools) > 0:
available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in self.agent_state.tools]
num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens(
model=model,
tools=available_functions_definitions,
)
else:
available_functions_definitions = []
num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0)
(
num_tokens_system,
num_tokens_core_memory,
num_tokens_external_memory_summary,
num_tokens_summary_memory,
num_tokens_messages,
num_tokens_available_functions_definitions,
) = await asyncio.gather(
num_tokens_system_coroutine,
num_tokens_core_memory_coroutine,
num_tokens_external_memory_summary_coroutine,
num_tokens_summary_memory_coroutine,
num_tokens_messages_coroutine,
num_tokens_available_functions_definitions_coroutine,
)
num_tokens_used_total = (
num_tokens_system # system prompt
+ num_tokens_available_functions_definitions # function definitions
+ num_tokens_core_memory # core memory
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
+ num_tokens_summary_memory # summary of ongoing conversation
+ num_tokens_messages # tokens taken by messages
)
assert isinstance(num_tokens_used_total, int)
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(in_context_messages),
num_archival_memory=passage_manager_size,
num_recall_memory=message_manager_size,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
external_memory_summary=external_memory_summary,
# top-level information
context_window_size_max=self.agent_state.llm_config.context_window,
context_window_size_current=num_tokens_used_total,
# context window breakdown (in tokens)
num_tokens_system=num_tokens_system,
system_prompt=system_prompt,
num_tokens_core_memory=num_tokens_core_memory,
core_memory=core_memory,
num_tokens_summary_memory=num_tokens_summary_memory,
summary_memory=summary_memory,
num_tokens_messages=num_tokens_messages,
messages=in_context_messages,
# related to functions
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
functions_definitions=available_functions_definitions,
)
def count_tokens(self) -> int:
"""Count the tokens in the current context window"""
context_window_breakdown = self.get_context_window()
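Aside: the Anthropic variant above mixes real token-count API calls with constant zeros by wrapping the constants in asyncio.sleep(0, result=0), so every branch stays awaitable inside a single asyncio.gather. A self-contained illustration of that pattern (the counting functions here are invented for the example):

import asyncio

async def count_remote(text: str) -> int:
    await asyncio.sleep(0)  # stand-in for a provider token-count API call
    return len(text.split())

async def count_all(system_prompt: str, core_memory: str | None) -> int:
    totals = await asyncio.gather(
        count_remote(system_prompt),
        # Empty sections cost zero tokens but still need to be awaitable:
        count_remote(core_memory) if core_memory else asyncio.sleep(0, result=0),
    )
    return sum(totals)

print(asyncio.run(count_all("You are a helpful agent.", None)))  # prints 5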


@ -72,61 +72,6 @@ class BaseAgent(ABC):
return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages]
def _rebuild_memory(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None, # storing these calculations is specific to the voice agent
num_archival_memories: int | None = None,
) -> List[Message]:
try:
# Refresh Memory
# TODO: This only happens for the summary block (voice?)
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
# [DB Call] size of messages and archival memories
num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# [DB Call] Update Messages
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
except:
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise
async def _rebuild_memory_async(
self,
in_context_messages: List[Message],


@ -1,3 +1,4 @@
import uuid
import xml.etree.ElementTree as ET
from typing import List, Tuple
@ -150,3 +151,7 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
context = sum_el.text or ""
return messages, context
def generate_step_id():
return f"step-{uuid.uuid4()}"


@ -8,8 +8,9 @@ from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
@ -24,7 +25,8 @@ from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@ -32,10 +34,11 @@ from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
from letta.tracing import log_event, trace_method
from letta.tracing import log_event, trace_method, tracer
logger = get_logger(__name__)
@ -50,6 +53,8 @@ class LettaAgent(BaseAgent):
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
):
super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)
@ -57,6 +62,8 @@ class LettaAgent(BaseAgent):
# Summarizer settings
self.block_manager = block_manager
self.passage_manager = passage_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
self.response_messages: List[Message] = []
self.last_function_response = None
@ -67,17 +74,19 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages, usage = await self._step(
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
)
_, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
return _create_letta_response(
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
)
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message], CompletionUsage]:
@trace_method
async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
@ -89,23 +98,81 @@ class LettaAgent(BaseAgent):
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
response = await self._get_ai_reply(
step_id = generate_step_id()
in_context_messages = await self._rebuild_memory_async(
current_in_context_messages + new_in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
)
log_event("agent.stream_no_tokens.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
in_context_messages=in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=False,
# TODO: also pass in reasoning content
# TODO: pass in reasoning content
)
log_event("agent.stream_no_tokens.llm_request.created") # [2^]
try:
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
except Exception as e:
raise llm_client.handle_llm_error(e)
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
)
]
else:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
persisted_messages, should_continue = await self._handle_ai_response(
tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
# stream step
# TODO: improve TTFT
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
letta_messages = Message.to_letta_messages_from_list(
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
)
for message in letta_messages:
yield f"data: {message.model_dump_json()}\n\n"
# update usage
# TODO: add run_id
@ -122,17 +189,125 @@ class LettaAgent(BaseAgent):
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
# Return back usage
yield f"data: {usage.model_dump_json()}\n\n"
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message], CompletionUsage]:
"""
Carries out an invocation of the agent loop. In each step, the agent
1. Rebuilds its memory
2. Generates a request for the LLM
3. Fetches a response from the LLM
4. Processes the response
"""
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=self.actor,
)
usage = LettaUsageStatistics()
for _ in range(max_steps):
step_id = generate_step_id()
in_context_messages = await self._rebuild_memory_async(
current_in_context_messages + new_in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
)
log_event("agent.step.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
# TODO: pass in reasoning content
)
log_event("agent.step.llm_request.created") # [2^]
try:
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
except Exception as e:
raise llm_client.handle_llm_error(e)
log_event("agent.step.llm_response.received") # [3^]
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not response.choices[0].message.tool_calls:
# TODO: make into a real error
raise ValueError("No tool calls found in response, model must make a tool call")
tool_call = response.choices[0].message.tool_calls[0]
if response.choices[0].message.reasoning_content:
reasoning = [
ReasoningContent(
reasoning=response.choices[0].message.reasoning_content,
is_native=True,
signature=response.choices[0].message.reasoning_content_signature,
)
]
else:
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
persisted_messages, should_continue = await self._handle_ai_response(
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
log_event("agent.step.llm_response.processed") # [4^]
# Log LLM Trace
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
if not should_continue:
break
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
return current_in_context_messages, new_in_context_messages, usage
@trace_method
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
1. Rebuilds its memory
2. Generates a request for the LLM
3. Fetches a response from the LLM
4. Processes the response
"""
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
agent_state = await self.agent_manager.get_agent_by_id_async(
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
)
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
@ -145,13 +320,29 @@ class LettaAgent(BaseAgent):
usage = LettaUsageStatistics()
for _ in range(max_steps):
stream = await self._get_ai_reply(
step_id = generate_step_id()
in_context_messages = await self._rebuild_memory_async(
current_in_context_messages + new_in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
)
log_event("agent.step.messages.refreshed") # [1^]
request_data = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
in_context_messages=in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=True,
)
log_event("agent.stream.llm_request.created") # [2^]
try:
stream = await llm_client.stream_async(request_data, agent_state.llm_config)
except Exception as e:
raise llm_client.handle_llm_error(e)
log_event("agent.stream.llm_response.received") # [3^]
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
if agent_state.llm_config.model_endpoint_type == "anthropic":
@ -164,7 +355,23 @@ class LettaAgent(BaseAgent):
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
else:
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
first_chunk, ttft_span = True, None
if request_start_timestamp_ns is not None:
ttft_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
ttft_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
async for chunk in interface.process(stream):
# Measure time to first token
if first_chunk and ttft_span is not None:
now = get_utc_timestamp_ns()
ttft_ns = now - request_start_timestamp_ns
ttft_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
ttft_span.end()
first_chunk = False
yield f"data: {chunk.model_dump_json()}\n\n"
# update usage
@ -180,13 +387,46 @@ class LettaAgent(BaseAgent):
tool_call,
agent_state,
tool_rules_solver,
UsageStatistics(
completion_tokens=interface.output_tokens,
prompt_tokens=interface.input_tokens,
total_tokens=interface.input_tokens + interface.output_tokens,
),
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
pre_computed_tool_message_id=interface.letta_tool_message_id,
step_id=step_id,
)
self.response_messages.extend(persisted_messages)
new_in_context_messages.extend(persisted_messages)
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
# log_event("agent.stream.llm_response.processed") # [4^]
# Log LLM Trace
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
await self.telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json={
"content": {
"tool_call": tool_call.model_dump_json(),
"reasoning": [content.model_dump_json() for content in reasoning_content],
},
"id": interface.message_id,
"model": interface.model,
"role": "assistant",
# "stop_reason": "",
# "stop_sequence": None,
"type": "message",
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
},
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
if not use_assistant_message or should_continue:
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
yield f"data: {tool_return.model_dump_json()}\n\n"
@ -209,28 +449,20 @@ class LettaAgent(BaseAgent):
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@trace_method
# When raising an error this doesn't show up
async def _get_ai_reply(
async def _create_llm_request_data_async(
self,
llm_client: LLMClientBase,
in_context_messages: List[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
if settings.experimental_enable_async_db_engine:
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
self.num_archival_memories = self.num_archival_memories or (
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
else:
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
logger.info("Skipping memory rebuild")
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
self.num_archival_memories = self.num_archival_memories or (
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
tools = [
t
@ -243,8 +475,8 @@ class LettaAgent(BaseAgent):
ToolType.LETTA_MULTI_AGENT_CORE,
ToolType.LETTA_SLEEPTIME_CORE,
ToolType.LETTA_VOICE_SLEEPTIME_CORE,
ToolType.LETTA_BUILTIN,
}
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
]
@ -264,15 +496,7 @@ class LettaAgent(BaseAgent):
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
response = await llm_client.send_llm_request_async(
messages=in_context_messages,
llm_config=agent_state.llm_config,
tools=allowed_tools,
force_tool_call=force_tool_call,
stream=stream,
)
return response
return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)
@trace_method
async def _handle_ai_response(
@ -280,9 +504,11 @@ class LettaAgent(BaseAgent):
tool_call: ToolCall,
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
usage: UsageStatistics,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
step_id: str | None = None,
) -> Tuple[List[Message], bool]:
"""
Now that streaming is done, handle the final AI response.
@ -294,8 +520,11 @@ class LettaAgent(BaseAgent):
try:
tool_args = json.loads(tool_call_args_str)
assert isinstance(tool_args, dict), "tool_args must be a dict"
except json.JSONDecodeError:
tool_args = {}
except AssertionError:
tool_args = json.loads(tool_args)
# Get request heartbeats and coerce to bool
request_heartbeat = tool_args.pop("request_heartbeat", False)
@ -329,7 +558,25 @@ class LettaAgent(BaseAgent):
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
continue_stepping = True
# 5. Persist to DB
# 5a. Persist Steps to DB
# Following agent loop to persist this before messages
# TODO (cliandy): determine what should match old loop w/provider_id, job_id
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
logged_step = await self.step_manager.log_step_async(
actor=self.actor,
agent_id=agent_state.id,
provider_name=agent_state.llm_config.model_endpoint_type,
provider_category=agent_state.llm_config.provider_category or "base",
model=agent_state.llm_config.model,
model_endpoint=agent_state.llm_config.model_endpoint,
context_window_limit=agent_state.llm_config.context_window,
usage=usage,
provider_id=None,
job_id=None,
step_id=step_id,
)
# 5b. Persist Messages to DB
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
@ -343,6 +590,7 @@ class LettaAgent(BaseAgent):
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
pre_computed_tool_message_id=pre_computed_tool_message_id,
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
)
persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
self.last_function_response = function_response
@ -361,20 +609,21 @@ class LettaAgent(BaseAgent):
# TODO: This temp. Move this logic and code to executors
try:
if target_tool.name == "send_message_to_agents_matching_tags" and target_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
log_event(name="start_send_message_to_agents_matching_tags", attributes=tool_args)
results = await self._send_message_to_agents_matching_tags(**tool_args)
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
return json.dumps(results), True
else:
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
# TODO: Integrate sandbox result
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
tool_execution_result = await tool_execution_manager.execute_tool_async(
function_name=tool_name, function_args=tool_args, tool=target_tool
)
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
return tool_execution_result.func_return, True
tool_execution_manager = ToolExecutionManager(
agent_state=agent_state,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
)
# TODO: Integrate sandbox result
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
tool_execution_result = await tool_execution_manager.execute_tool_async(
function_name=tool_name, function_args=tool_args, tool=target_tool
)
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
return tool_execution_result.func_return, True
except Exception as e:
return f"Failed to call tool. Error: {e}", False
@ -430,6 +679,7 @@ class LettaAgent(BaseAgent):
results = await asyncio.gather(*tasks)
return results
@trace_method
async def _load_last_function_response_async(self):
"""Load the last function response from message history"""
in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_id, actor=self.actor)
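To make the new telemetry path concrete: each step now generates a step_id up front, and the raw request/response JSON is handed to create_provider_trace_async, which lands in the provider_traces table added by migration cc8dc340836d above. TelemetryManager itself is not shown in this diff, so the stand-in below is purely illustrative; only the ProviderTraceCreate field names are taken from the call sites.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ProviderTraceCreate:  # field names mirror the call sites in the diff
    request_json: dict
    response_json: dict
    step_id: Optional[str]
    organization_id: str


class InMemoryTelemetryManager:
    # Illustrative stand-in; the real TelemetryManager persists to provider_traces,
    # and NoopTelemetryManager (used when settings.llm_api_logging is off) drops the payload.
    def __init__(self) -> None:
        self.traces: list[ProviderTraceCreate] = []

    async def create_provider_trace_async(self, actor: Any, provider_trace_create: ProviderTraceCreate) -> ProviderTraceCreate:
        self.traces.append(provider_trace_create)
        return provider_trace_create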


@ -145,7 +145,7 @@ class LettaAgentBatch(BaseAgent):
agent_mapping = {
agent_state.id: agent_state
for agent_state in await self.agent_manager.get_agents_by_ids_async(
agent_ids=[request.agent_id for request in batch_requests], actor=self.actor
agent_ids=[request.agent_id for request in batch_requests], include_relationships=["tools", "memory"], actor=self.actor
)
}
@ -267,64 +267,121 @@ class LettaAgentBatch(BaseAgent):
@trace_method
async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
# NOTE: We only continue for items with successful results
"""
Collect context for resuming operations from completed batch items.
Args:
llm_batch_id: The ID of the batch to collect context for
Returns:
_ResumeContext object containing all necessary data for resumption
"""
# Fetch only completed batch items
batch_items = await self.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_id, request_status=JobStatus.completed)
agent_ids = []
provider_results = {}
request_status_updates: List[RequestStatusUpdateInfo] = []
# Exit early if no items to process
if not batch_items:
return _ResumeContext(
batch_items=[],
agent_ids=[],
agent_state_map={},
provider_results={},
tool_call_name_map={},
tool_call_args_map={},
should_continue_map={},
request_status_updates=[],
)
for item in batch_items:
aid = item.agent_id
agent_ids.append(aid)
provider_results[aid] = item.batch_request_result.result
# Extract agent IDs and organize items by agent ID
agent_ids = [item.agent_id for item in batch_items]
batch_item_map = {item.agent_id: item for item in batch_items}
agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids, actor=self.actor)
# Collect provider results
provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items}
# Fetch agent states in a single call
agent_states = await self.agent_manager.get_agents_by_ids_async(
agent_ids=agent_ids, include_relationships=["tools", "memory"], actor=self.actor
)
agent_state_map = {agent.id: agent for agent in agent_states}
name_map, args_map, cont_map = {}, {}, {}
for aid in agent_ids:
# status bookkeeping
pr = provider_results[aid]
status = (
JobStatus.completed
if isinstance(pr, BetaMessageBatchSucceededResult)
else (
JobStatus.failed
if isinstance(pr, BetaMessageBatchErroredResult)
else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired
)
)
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
# translate provider-specific response → OpenAI-style tool call (unchanged)
llm_client = LLMClient.create(
provider_type=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=self.actor,
)
tool_call = (
llm_client.convert_response_to_chat_completion(
response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config
)
.choices[0]
.message.tool_calls[0]
)
name, args, cont = self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
name_map[aid], args_map[aid], cont_map[aid] = name, args, cont
# Process each agent's results
tool_call_results = self._process_agent_results(
agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id
)
return _ResumeContext(
batch_items=batch_items,
agent_ids=agent_ids,
agent_state_map=agent_state_map,
provider_results=provider_results,
tool_call_name_map=name_map,
tool_call_args_map=args_map,
should_continue_map=cont_map,
request_status_updates=request_status_updates,
tool_call_name_map=tool_call_results.name_map,
tool_call_args_map=tool_call_results.args_map,
should_continue_map=tool_call_results.cont_map,
request_status_updates=tool_call_results.status_updates,
)
def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
"""
Process the results for each agent, extracting tool calls and determining continuation status.
Returns:
A namedtuple containing name_map, args_map, cont_map, and status_updates
"""
from collections import namedtuple
ToolCallResults = namedtuple("ToolCallResults", ["name_map", "args_map", "cont_map", "status_updates"])
name_map, args_map, cont_map = {}, {}, {}
request_status_updates = []
for aid in agent_ids:
item = batch_item_map[aid]
result = provider_results[aid]
# Determine job status based on result type
status = self._determine_job_status(result)
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
# Process tool calls
name, args, cont = self._extract_tool_call_from_result(item, result)
name_map[aid], args_map[aid], cont_map[aid] = name, args, cont
return ToolCallResults(name_map, args_map, cont_map, request_status_updates)
def _determine_job_status(self, result):
"""Determine job status based on result type"""
if isinstance(result, BetaMessageBatchSucceededResult):
return JobStatus.completed
elif isinstance(result, BetaMessageBatchErroredResult):
return JobStatus.failed
elif isinstance(result, BetaMessageBatchCanceledResult):
return JobStatus.cancelled
else:
return JobStatus.expired
def _extract_tool_call_from_result(self, item, result):
"""Extract tool call information from a result"""
llm_client = LLMClient.create(
provider_type=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor=self.actor,
)
# If result isn't a successful type, we can't extract a tool call
if not isinstance(result, BetaMessageBatchSucceededResult):
return None, None, False
tool_call = (
llm_client.convert_response_to_chat_completion(
response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
)
.choices[0]
.message.tool_calls[0]
)
return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
if updates:
self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
@ -556,16 +613,6 @@ class LettaAgentBatch(BaseAgent):
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
return in_context_messages
# TODO: Make this a bulk function
def _rebuild_memory(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None,
num_archival_memories: int | None = None,
) -> List[Message]:
return super()._rebuild_memory(in_context_messages, agent_state)
# Not used in batch.
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
raise NotImplementedError


@ -154,7 +154,7 @@ class VoiceAgent(BaseAgent):
# TODO: Define max steps here
for _ in range(max_steps):
# Rebuild memory each loop
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
openai_messages.extend(in_memory_message_history)
@ -292,14 +292,14 @@ class VoiceAgent(BaseAgent):
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
def _rebuild_memory(
async def _rebuild_memory_async(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None,
num_archival_memories: int | None = None,
) -> List[Message]:
return super()._rebuild_memory(
return await super()._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
@ -438,7 +438,7 @@ class VoiceAgent(BaseAgent):
if start_date and end_date and start_date > end_date:
start_date, end_date = end_date, start_date
archival_results = self.agent_manager.list_passages(
archival_results = await self.agent_manager.list_passages_async(
actor=self.actor,
agent_id=self.agent_id,
query_text=archival_query,
@ -457,7 +457,7 @@ class VoiceAgent(BaseAgent):
keyword_results = {}
if convo_keyword_queries:
for keyword in convo_keyword_queries:
messages = self.message_manager.list_messages_for_agent(
messages = await self.message_manager.list_messages_for_agent_async(
agent_id=self.agent_id,
actor=self.actor,
query_text=keyword,


@ -2773,11 +2773,8 @@ class LocalClient(AbstractClient):
# humans / personas
def get_block_id(self, name: str, label: str) -> str:
block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True)
if not block:
return None
return block[0].id
def get_block_id(self, name: str, label: str) -> str | None:
return None
def create_human(self, name: str, text: str):
"""
@ -2812,7 +2809,7 @@ class LocalClient(AbstractClient):
Returns:
humans (List[Human]): List of human blocks
"""
return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True)
return []
def list_personas(self) -> List[Persona]:
"""
@ -2821,7 +2818,7 @@ class LocalClient(AbstractClient):
Returns:
personas (List[Persona]): List of persona blocks
"""
return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True)
return []
def update_human(self, human_id: str, text: str):
"""
@ -2879,7 +2876,7 @@ class LocalClient(AbstractClient):
assert id, f"Human ID must be provided"
return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump())
def get_persona_id(self, name: str) -> str:
def get_persona_id(self, name: str) -> str | None:
"""
Get the ID of a persona block template
@ -2889,12 +2886,9 @@ class LocalClient(AbstractClient):
Returns:
id (str): ID of the persona block
"""
persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True)
if not persona:
return None
return persona[0].id
return None
def get_human_id(self, name: str) -> str:
def get_human_id(self, name: str) -> str | None:
"""
Get the ID of a human block template
@ -2904,10 +2898,7 @@ class LocalClient(AbstractClient):
Returns:
id (str): ID of the human block
"""
human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True)
if not human:
return None
return human[0].id
return None
def delete_persona(self, id: str):
"""
@ -3381,7 +3372,7 @@ class LocalClient(AbstractClient):
Returns:
blocks (List[Block]): List of blocks
"""
return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only)
return []
def create_block(
self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False


@ -19,6 +19,7 @@ MCP_TOOL_TAG_NAME_PREFIX = "mcp" # full format, mcp:server_name
LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"
LETTA_MULTI_AGENT_TOOL_MODULE_NAME = "letta.functions.function_sets.multi_agent"
LETTA_VOICE_TOOL_MODULE_NAME = "letta.functions.function_sets.voice"
LETTA_BUILTIN_TOOL_MODULE_NAME = "letta.functions.function_sets.builtin"
# String in the error message for when the context window is too large
@ -83,9 +84,19 @@ BASE_VOICE_SLEEPTIME_TOOLS = [
]
# Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
# Built in tools
BUILTIN_TOOLS = ["run_code", "web_search"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(
BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS
BASE_TOOLS
+ BASE_MEMORY_TOOLS
+ MULTI_AGENT_TOOLS
+ BASE_SLEEPTIME_TOOLS
+ BASE_VOICE_SLEEPTIME_TOOLS
+ BASE_VOICE_SLEEPTIME_CHAT_TOOLS
+ BUILTIN_TOOLS
)
# The name of the tool used to send message to the user
@ -179,6 +190,45 @@ LLM_MAX_TOKENS = {
"gpt-3.5-turbo-0613": 4096, # legacy
"gpt-3.5-turbo-16k-0613": 16385, # legacy
"gpt-3.5-turbo-0301": 4096, # legacy
"gemini-1.0-pro-vision-latest": 12288,
"gemini-pro-vision": 12288,
"gemini-1.5-pro-latest": 2000000,
"gemini-1.5-pro-001": 2000000,
"gemini-1.5-pro-002": 2000000,
"gemini-1.5-pro": 2000000,
"gemini-1.5-flash-latest": 1000000,
"gemini-1.5-flash-001": 1000000,
"gemini-1.5-flash-001-tuning": 16384,
"gemini-1.5-flash": 1000000,
"gemini-1.5-flash-002": 1000000,
"gemini-1.5-flash-8b": 1000000,
"gemini-1.5-flash-8b-001": 1000000,
"gemini-1.5-flash-8b-latest": 1000000,
"gemini-1.5-flash-8b-exp-0827": 1000000,
"gemini-1.5-flash-8b-exp-0924": 1000000,
"gemini-2.5-pro-exp-03-25": 1048576,
"gemini-2.5-pro-preview-03-25": 1048576,
"gemini-2.5-flash-preview-04-17": 1048576,
"gemini-2.5-flash-preview-05-20": 1048576,
"gemini-2.5-flash-preview-04-17-thinking": 1048576,
"gemini-2.5-pro-preview-05-06": 1048576,
"gemini-2.0-flash-exp": 1048576,
"gemini-2.0-flash": 1048576,
"gemini-2.0-flash-001": 1048576,
"gemini-2.0-flash-exp-image-generation": 1048576,
"gemini-2.0-flash-lite-001": 1048576,
"gemini-2.0-flash-lite": 1048576,
"gemini-2.0-flash-preview-image-generation": 32768,
"gemini-2.0-flash-lite-preview-02-05": 1048576,
"gemini-2.0-flash-lite-preview": 1048576,
"gemini-2.0-pro-exp": 1048576,
"gemini-2.0-pro-exp-02-05": 1048576,
"gemini-exp-1206": 1048576,
"gemini-2.0-flash-thinking-exp-01-21": 1048576,
"gemini-2.0-flash-thinking-exp": 1048576,
"gemini-2.0-flash-thinking-exp-1219": 1048576,
"gemini-2.5-flash-preview-tts": 32768,
"gemini-2.5-pro-preview-tts": 65536,
}
# The error message that Letta will receive
# MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
@ -230,3 +280,7 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
MAX_FILENAME_LENGTH = 255
RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
WEB_SEARCH_CLIP_CONTENT = False
WEB_SEARCH_INCLUDE_SCORE = False
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
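A quick sketch of how these constants are typically consumed (the 8192 fallback below is illustrative, not something this file defines):

from letta.constants import BUILTIN_TOOLS, LETTA_TOOL_SET, LLM_MAX_TOKENS

# The built-in tools are now part of the full Letta tool set
assert all(name in LETTA_TOOL_SET for name in BUILTIN_TOOLS)

# Context-window lookup with an assumed default for unknown models
context_window = LLM_MAX_TOKENS.get("gemini-2.0-flash", 8192)
print(context_window)  # -> 1048576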

View File

@ -0,0 +1,27 @@
from typing import Literal
async def web_search(query: str) -> str:
"""
Search the web for information.
Args:
query (str): The query to search the web for.
Returns:
str: The search results.
"""
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
"""
Run code in a sandbox. Supports Python, JavaScript, TypeScript, R, and Java.
Args:
code (str): The code to run.
language (Literal["python", "js", "ts", "r", "java"]): The language of the code.
Returns:
str: The output of the code, the stdout, the stderr, and error traces (if any).
"""
raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
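These are schema-only stubs: their signatures and docstrings describe the run_code and web_search tools registered via BUILTIN_TOOLS, while actual execution happens in the newer agent architecture. A minimal sketch of what calling the stubs directly does:

import asyncio

from letta.functions.function_sets.builtin import run_code, web_search

try:
    run_code("print(1 + 1)", language="python")
except NotImplementedError as err:
    # The stub always raises; real execution is handled elsewhere
    print(f"run_code stub: {err}")

try:
    asyncio.run(web_search("letta agents"))
except NotImplementedError as err:
    print(f"web_search stub: {err}")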

View File

@ -190,7 +190,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
prior_messages = []
if self.group.sleeptime_agent_frequency:
try:
prior_messages = self.message_manager.list_messages_for_agent(
prior_messages = await self.message_manager.list_messages_for_agent_async(
agent_id=foreground_agent_id,
actor=self.actor,
after=last_processed_message_id,

View File

@ -1,3 +1,4 @@
import json
from datetime import datetime, timezone
from enum import Enum
from typing import AsyncGenerator, List, Union
@ -74,6 +75,7 @@ class AnthropicStreamingInterface:
# usage trackers
self.input_tokens = 0
self.output_tokens = 0
self.model = None
# reasoning object trackers
self.reasoning_messages = []
@ -88,7 +90,13 @@ class AnthropicStreamingInterface:
def get_tool_call_object(self) -> ToolCall:
"""Useful for agent loop"""
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))
# hack for tool rules
tool_input = json.loads(self.accumulated_tool_call_args)
if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
arguments = str(json.dumps(tool_input["function"]["arguments"], indent=2))
else:
arguments = self.accumulated_tool_call_args
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
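To illustrate the unwrap above, here is a standalone sketch of the two shapes the accumulated tool-call JSON can take. The wrapped shape (a toolu_-prefixed id with a nested function) is an assumption inferred from the check, and the field values are made up:

import json

def unwrap_tool_args(accumulated: str) -> str:
    tool_input = json.loads(accumulated)
    if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
        # Wrapped form produced under certain tool rules: keep only the inner arguments
        return json.dumps(tool_input["function"]["arguments"], indent=2)
    return accumulated

plain = '{"message": "hi", "inner_thoughts": "greet the user"}'
assert unwrap_tool_args(plain) == plain  # ordinary arguments pass through untouched

wrapped = json.dumps({"id": "toolu_01ABC", "function": {"name": "send_message", "arguments": {"message": "hi"}}})
print(unwrap_tool_args(wrapped))  # only the inner {"message": "hi"} survives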
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
"""
@ -311,6 +319,7 @@ class AnthropicStreamingInterface:
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
self.model = event.message.model
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):

View File

@ -40,6 +40,9 @@ class OpenAIStreamingInterface:
self.letta_assistant_message_id = Message.generate_id()
self.letta_tool_message_id = Message.generate_id()
self.message_id = None
self.model = None
# token counters
self.input_tokens = 0
self.output_tokens = 0
@ -69,10 +72,14 @@ class OpenAIStreamingInterface:
prev_message_type = None
message_index = 0
async for chunk in stream:
if not self.model or not self.message_id:
self.model = chunk.model
self.message_id = chunk.id
# track usage
if chunk.usage:
self.input_tokens += len(chunk.usage.prompt_tokens)
self.output_tokens += len(chunk.usage.completion_tokens)
self.input_tokens += chunk.usage.prompt_tokens
self.output_tokens += chunk.usage.completion_tokens
if chunk.choices:
choice = chunk.choices[0]

View File

@ -134,13 +134,13 @@ def anthropic_check_valid_api_key(api_key: Union[str, None]) -> None:
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
for model_dict in anthropic_get_model_list(api_key=api_key):
if model_dict["name"] == model:
return model_dict["context_window"]
raise ValueError(f"Can't find model '{model}' in Anthropic model list")
def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
def anthropic_get_model_list(api_key: Optional[str]) -> dict:
"""https://docs.anthropic.com/claude/docs/models-overview"""
# NOTE: currently there is no GET /models, so we need to hardcode
@ -159,6 +159,25 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
return models_json["data"]
async def anthropic_get_model_list_async(api_key: Optional[str]) -> dict:
"""https://docs.anthropic.com/claude/docs/models-overview"""
# NOTE: currently there is no GET /models, so we need to hardcode
# return MODEL_LIST
if api_key:
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.AsyncAnthropic()
else:
raise ValueError("No API key provided")
models = await anthropic_client.models.list()
models_json = models.model_dump()
assert "data" in models_json, f"Anthropic model query response missing 'data' field: {models_json}"
return models_json["data"]
def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
"""See: https://docs.anthropic.com/claude/docs/tool-use

View File

@ -35,6 +35,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings
from letta.tracing import trace_method
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
@ -120,8 +121,16 @@ class AnthropicClient(LLMClientBase):
override_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
if async_client:
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
return (
anthropic.AsyncAnthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
if override_key
else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
)
return (
anthropic.Anthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
if override_key
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
)
@trace_method
def build_request_data(
@ -239,6 +248,24 @@ class AnthropicClient(LLMClientBase):
return data
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[Tool] = None) -> int:
client = anthropic.AsyncAnthropic()
if messages and len(messages) == 0:
messages = None
if tools and len(tools) > 0:
anthropic_tools = convert_tools_to_anthropic_format(tools)
else:
anthropic_tools = None
result = await client.beta.messages.count_tokens(
model=model or "claude-3-7-sonnet-20250219",
messages=messages or [{"role": "user", "content": "hi"}],
tools=anthropic_tools or [],
)
token_count = result.input_tokens
if messages is None:
token_count -= 8
return token_count
def handle_llm_error(self, e: Exception) -> Exception:
if isinstance(e, anthropic.APIConnectionError):
logger.warning(f"[Anthropic] API connection error: {e.__cause__}")
@ -369,11 +396,11 @@ class AnthropicClient(LLMClientBase):
content = strip_xml_tags(string=content_part.text, tag="thinking")
if content_part.type == "tool_use":
# hack for tool rules
input = json.loads(json.dumps(content_part.input))
if "id" in input and input["id"].startswith("toolu_") and "function" in input:
arguments = str(input["function"]["arguments"])
tool_input = json.loads(json.dumps(content_part.input))
if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
arguments = str(tool_input["function"]["arguments"])
else:
arguments = json.dumps(content_part.input, indent=2)
arguments = json.dumps(tool_input, indent=2)
tool_calls = [
ToolCall(
id=content_part.id,

View File

@ -1,422 +1,21 @@
import json
import uuid
from typing import List, Optional, Tuple
import requests
import httpx
from google import genai
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig
from letta.constants import NON_USER_MSG_PREFIX
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
from letta.llm_api.helpers import make_post_request
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.settings import model_settings
from letta.utils import get_tool_call_id
logger = get_logger(__name__)
class GoogleAIClient(LLMClientBase):
class GoogleAIClient(GoogleVertexClient):
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
api_key = None
if llm_config.provider_category == ProviderCategory.byok:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
if not api_key:
api_key = model_settings.gemini_api_key
# print("[google_ai request]", json.dumps(request_data, indent=2))
url, headers = get_gemini_endpoint_and_headers(
base_url=str(llm_config.model_endpoint),
model=llm_config.model,
api_key=str(api_key),
key_in_header=True,
generate_content=True,
)
return make_post_request(url, headers, request_data)
def build_request_data(
self,
messages: List[PydanticMessage],
llm_config: LLMConfig,
tools: List[dict],
force_tool_call: Optional[str] = None,
) -> dict:
"""
Constructs a request object in the expected data format for this client.
"""
if tools:
tools = [{"type": "function", "function": f} for f in tools]
tool_objs = [Tool(**t) for t in tools]
tool_names = [t.function.name for t in tool_objs]
# Convert to the exact payload style Google expects
tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
else:
tool_names = []
contents = self.add_dummy_model_messages(
[m.to_google_ai_dict() for m in messages],
)
request_data = {
"contents": contents,
"tools": tools,
"generation_config": {
"temperature": llm_config.temperature,
"max_output_tokens": llm_config.max_tokens,
},
}
# write tool config
tool_config = ToolConfig(
function_calling_config=FunctionCallingConfig(
# ANY mode forces the model to predict only function calls
mode=FunctionCallingConfigMode.ANY,
# Provide the list of tools (though empty should also work, it seems not to)
allowed_function_names=tool_names,
)
)
request_data["tool_config"] = tool_config.model_dump()
return request_data
def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Converts custom response format from llm client into an OpenAI
ChatCompletionsResponse object.
Example Input:
{
"candidates": [
{
"content": {
"parts": [
{
"text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
}
]
}
}
],
"usageMetadata": {
"promptTokenCount": 9,
"candidatesTokenCount": 27,
"totalTokenCount": 36
}
}
"""
# print("[google_ai response]", json.dumps(response_data, indent=2))
try:
choices = []
index = 0
for candidate in response_data["candidates"]:
content = candidate["content"]
if "role" not in content or not content["role"]:
# This means the response is malformed like MALFORMED_FUNCTION_CALL
# NOTE: must be a ValueError to trigger a retry
raise ValueError(f"Error in response data from LLM: {response_data}")
role = content["role"]
assert role == "model", f"Unknown role in response: {role}"
parts = content["parts"]
# NOTE: we aren't properly supporting multi-parts here anyway (we're just appending choices),
# so let's disable it for now
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
# To patch this, if we have multiple parts we can take the last one
if len(parts) > 1:
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
parts = [parts[-1]]
# TODO support parts / multimodal
# TODO support parallel tool calling natively
# TODO Alternative here is to throw away everything else except for the first part
for response_message in parts:
# Convert the actual message style to OpenAI style
if "functionCall" in response_message and response_message["functionCall"] is not None:
function_call = response_message["functionCall"]
assert isinstance(function_call, dict), function_call
function_name = function_call["name"]
assert isinstance(function_name, str), function_name
function_args = function_call["args"]
assert isinstance(function_args, dict), function_args
# NOTE: this also involves stripping the inner monologue out of the function
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
assert (
INNER_THOUGHTS_KWARG_VERTEX in function_args
), f"Couldn't find inner thoughts in function args:\n{function_call}"
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
else:
inner_thoughts = None
# Google AI API doesn't generate tool call IDs
openai_response_message = Message(
role="assistant", # NOTE: "model" -> "assistant"
content=inner_thoughts,
tool_calls=[
ToolCall(
id=get_tool_call_id(),
type="function",
function=FunctionCall(
name=function_name,
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
),
)
],
)
else:
# Inner thoughts are the content by default
inner_thoughts = response_message["text"]
# Google AI API doesn't generate tool call IDs
openai_response_message = Message(
role="assistant", # NOTE: "model" -> "assistant"
content=inner_thoughts,
)
# Google AI API uses different finish reason strings than OpenAI
# OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
# see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
# Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
# see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
finish_reason = candidate["finishReason"]
if finish_reason == "STOP":
openai_finish_reason = (
"function_call"
if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
else "stop"
)
elif finish_reason == "MAX_TOKENS":
openai_finish_reason = "length"
elif finish_reason == "SAFETY":
openai_finish_reason = "content_filter"
elif finish_reason == "RECITATION":
openai_finish_reason = "content_filter"
else:
raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")
choices.append(
Choice(
finish_reason=openai_finish_reason,
index=index,
message=openai_response_message,
)
)
index += 1
# if len(choices) > 1:
# raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")
# NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
# "usageMetadata": {
# "promptTokenCount": 9,
# "candidatesTokenCount": 27,
# "totalTokenCount": 36
# }
if "usageMetadata" in response_data:
usage_data = response_data["usageMetadata"]
if "promptTokenCount" not in usage_data:
raise ValueError(f"promptTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
if "totalTokenCount" not in usage_data:
raise ValueError(f"totalTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
if "candidatesTokenCount" not in usage_data:
raise ValueError(f"candidatesTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
prompt_tokens = usage_data["promptTokenCount"]
completion_tokens = usage_data["candidatesTokenCount"]
total_tokens = usage_data["totalTokenCount"]
usage = UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
else:
# Count it ourselves
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
total_tokens = prompt_tokens + completion_tokens
usage = UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
response_id = str(uuid.uuid4())
return ChatCompletionResponse(
id=response_id,
choices=choices,
model=llm_config.model, # NOTE: Google API doesn't pass back model in the response
created=get_utc_time_int(),
usage=usage,
)
except KeyError as e:
raise e
def _clean_google_ai_schema_properties(self, schema_part: dict):
"""Recursively clean schema parts to remove unsupported Google AI keywords."""
if not isinstance(schema_part, dict):
return
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
# * Only a subset of the OpenAPI schema is supported.
# * Supported parameter types in Python are limited.
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties"]
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
for key_to_remove in keys_to_remove_at_this_level:
logger.warning(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
del schema_part[key_to_remove]
if schema_part.get("type") == "string" and "format" in schema_part:
allowed_formats = ["enum", "date-time"]
if schema_part["format"] not in allowed_formats:
logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
del schema_part["format"]
# Check properties within the current level
if "properties" in schema_part and isinstance(schema_part["properties"], dict):
for prop_name, prop_schema in schema_part["properties"].items():
self._clean_google_ai_schema_properties(prop_schema)
# Check items within arrays
if "items" in schema_part and isinstance(schema_part["items"], dict):
self._clean_google_ai_schema_properties(schema_part["items"])
# Check within anyOf, allOf, oneOf lists
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema_part and isinstance(schema_part[key], list):
for item_schema in schema_part[key]:
self._clean_google_ai_schema_properties(item_schema)
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
"""
OpenAI style:
"tools": [{
"type": "function",
"function": {
"name": "find_movies",
"description": "find ....",
"parameters": {
"type": "object",
"properties": {
PARAM: {
"type": PARAM_TYPE, # eg "string"
"description": PARAM_DESCRIPTION,
},
...
},
"required": List[str],
}
}
}
]
Google AI style:
"tools": [{
"functionDeclarations": [{
"name": "find_movies",
"description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
"parameters": {
"type": "OBJECT",
"properties": {
"location": {
"type": "STRING",
"description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
},
"description": {
"type": "STRING",
"description": "Any kind of description including category or genre, title words, attributes, etc."
}
},
"required": ["description"]
}
}, {
"name": "find_theaters",
...
"""
function_list = [
dict(
name=t.function.name,
description=t.function.description,
parameters=t.function.parameters, # TODO need to unpack
)
for t in tools
]
# Add inner thoughts if needed
for func in function_list:
# Note: Google AI API used to have weird casing requirements, but not any more
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
if "parameters" in func and isinstance(func["parameters"], dict):
self._clean_google_ai_schema_properties(func["parameters"])
# Add inner thoughts
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX
func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
"type": "string",
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
}
func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)
return [{"functionDeclarations": function_list}]
def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]:
"""Google AI API requires all function call returns are immediately followed by a 'model' role message.
In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
so there is no natural follow-up 'model' role message.
To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
with role == 'model' that is placed in between the function output
(role == 'tool') and the user message (role == 'user').
"""
dummy_yield_message = {
"role": "model",
"parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
}
messages_with_padding = []
for i, message in enumerate(messages):
messages_with_padding.append(message)
# Check if the current message role is 'tool' and the next message role is 'user'
if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
messages_with_padding.append(dummy_yield_message)
return messages_with_padding
def _get_client(self):
return genai.Client(api_key=model_settings.gemini_api_key)
def get_gemini_endpoint_and_headers(
@ -464,20 +63,24 @@ def google_ai_check_valid_api_key(api_key: str):
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
"""Synchronous version to get model list from Google AI API using httpx."""
import httpx
from letta.utils import printd
url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
with httpx.Client() as client:
response = client.get(url, headers=headers)
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
response_data = response.json() # convert to dict from string
# Grab the models out
model_list = response["models"]
return model_list
# Grab the models out
model_list = response_data["models"]
return model_list
except requests.exceptions.HTTPError as http_err:
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
printd(f"Got HTTPError, exception={http_err}")
# Print the HTTP status code
@ -486,8 +89,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =
print(f"Message: {http_err.response.text}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
except httpx.RequestError as req_err:
# Handle other httpx-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
@ -497,22 +100,74 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =
raise e
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
async def google_ai_get_model_list_async(
base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> List[dict]:
"""Asynchronous version to get model list from Google AI API using httpx."""
from letta.utils import printd
url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
# Determine if we need to close the client at the end
close_client = False
if client is None:
client = httpx.AsyncClient()
close_client = True
try:
response = await client.get(url, headers=headers)
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
response_data = response.json() # convert to dict from string
# Grab the models out
model_list = response_data["models"]
return model_list
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
printd(f"Got HTTPError, exception={http_err}")
# Print the HTTP status code
print(f"HTTP Error: {http_err.response.status_code}")
# Print the response content (error message from server)
print(f"Message: {http_err.response.text}")
raise http_err
except httpx.RequestError as req_err:
# Handle other httpx-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
except Exception as e:
# Handle other potential errors
printd(f"Got unknown Exception, exception={e}")
raise e
finally:
# Close the client if we created it
if close_client:
await client.aclose()
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
"""Synchronous version to get model details from Google AI API using httpx."""
import httpx
from letta.utils import printd
url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
try:
response = requests.get(url, headers=headers)
printd(f"response = {response}")
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")
with httpx.Client() as client:
response = client.get(url, headers=headers)
printd(f"response = {response}")
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
response_data = response.json() # convert to dict from string
printd(f"response.json = {response_data}")
# Grab the models out
return response
# Return the model details
return response_data
except requests.exceptions.HTTPError as http_err:
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
printd(f"Got HTTPError, exception={http_err}")
# Print the HTTP status code
@ -521,8 +176,8 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_
print(f"Message: {http_err.response.text}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
except httpx.RequestError as req_err:
# Handle other httpx-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
@ -532,8 +187,66 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_
raise e
async def google_ai_get_model_details_async(
base_url: str, api_key: str, model: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
) -> dict:
"""Asynchronous version to get model details from Google AI API using httpx."""
import httpx
from letta.utils import printd
url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
# Determine if we need to close the client at the end
close_client = False
if client is None:
client = httpx.AsyncClient()
close_client = True
try:
response = await client.get(url, headers=headers)
printd(f"response = {response}")
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
response_data = response.json() # convert to dict from string
printd(f"response.json = {response_data}")
# Return the model details
return response_data
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
printd(f"Got HTTPError, exception={http_err}")
# Print the HTTP status code
print(f"HTTP Error: {http_err.response.status_code}")
# Print the response content (error message from server)
print(f"Message: {http_err.response.text}")
raise http_err
except httpx.RequestError as req_err:
# Handle other httpx-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
except Exception as e:
# Handle other potential errors
printd(f"Got unknown Exception, exception={e}")
raise e
finally:
# Close the client if we created it
if close_client:
await client.aclose()
def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
# TODO should this be:
# return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
return int(model_details["inputTokenLimit"])
async def google_ai_get_model_context_window_async(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
model_details = await google_ai_get_model_details_async(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
# TODO should this be:
# return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
return int(model_details["inputTokenLimit"])
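A sketch of the async variants in use, reusing a single httpx.AsyncClient across calls (the purpose of the optional client parameter). The Gemini base URL and the GEMINI_API_KEY environment variable are assumptions for the example:

import asyncio
import os

import httpx

from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async, google_ai_get_model_list_async

async def main():
    base_url = "https://generativelanguage.googleapis.com"  # assumed Gemini endpoint
    api_key = os.environ["GEMINI_API_KEY"]
    async with httpx.AsyncClient() as client:
        models = await google_ai_get_model_list_async(base_url, api_key, client=client)
        print(f"found {len(models)} models")
    window = await google_ai_get_model_context_window_async(base_url, api_key, "gemini-2.0-flash")
    print(window)

asyncio.run(main())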

View File

@ -5,14 +5,16 @@ from typing import List, Optional
from google import genai
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.settings import model_settings, settings
from letta.utils import get_tool_call_id
@ -20,18 +22,21 @@ from letta.utils import get_tool_call_id
logger = get_logger(__name__)
class GoogleVertexClient(GoogleAIClient):
class GoogleVertexClient(LLMClientBase):
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
client = genai.Client(
def _get_client(self):
return genai.Client(
vertexai=True,
project=model_settings.google_cloud_project,
location=model_settings.google_cloud_location,
http_options={"api_version": "v1"},
)
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
client = self._get_client()
response = client.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
@ -43,12 +48,7 @@ class GoogleVertexClient(GoogleAIClient):
"""
Performs underlying request to llm and returns raw response.
"""
client = genai.Client(
vertexai=True,
project=model_settings.google_cloud_project,
location=model_settings.google_cloud_location,
http_options={"api_version": "v1"},
)
client = self._get_client()
response = await client.aio.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
@ -56,6 +56,139 @@ class GoogleVertexClient(GoogleAIClient):
)
return response.model_dump()
def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]:
"""Google AI API requires all function call returns are immediately followed by a 'model' role message.
In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
so there is no natural follow-up 'model' role message.
To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
with role == 'model' that is placed in between the function output
(role == 'tool') and the user message (role == 'user').
"""
dummy_yield_message = {
"role": "model",
"parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
}
messages_with_padding = []
for i, message in enumerate(messages):
messages_with_padding.append(message)
# Check if the current message role is 'tool' and the next message role is 'user'
if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
messages_with_padding.append(dummy_yield_message)
return messages_with_padding
def _clean_google_ai_schema_properties(self, schema_part: dict):
"""Recursively clean schema parts to remove unsupported Google AI keywords."""
if not isinstance(schema_part, dict):
return
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
# * Only a subset of the OpenAPI schema is supported.
# * Supported parameter types in Python are limited.
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties"]
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
for key_to_remove in keys_to_remove_at_this_level:
logger.warning(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
del schema_part[key_to_remove]
if schema_part.get("type") == "string" and "format" in schema_part:
allowed_formats = ["enum", "date-time"]
if schema_part["format"] not in allowed_formats:
logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
del schema_part["format"]
# Check properties within the current level
if "properties" in schema_part and isinstance(schema_part["properties"], dict):
for prop_name, prop_schema in schema_part["properties"].items():
self._clean_google_ai_schema_properties(prop_schema)
# Check items within arrays
if "items" in schema_part and isinstance(schema_part["items"], dict):
self._clean_google_ai_schema_properties(schema_part["items"])
# Check within anyOf, allOf, oneOf lists
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema_part and isinstance(schema_part[key], list):
for item_schema in schema_part[key]:
self._clean_google_ai_schema_properties(item_schema)
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
"""
OpenAI style:
"tools": [{
"type": "function",
"function": {
"name": "find_movies",
"description": "find ....",
"parameters": {
"type": "object",
"properties": {
PARAM: {
"type": PARAM_TYPE, # eg "string"
"description": PARAM_DESCRIPTION,
},
...
},
"required": List[str],
}
}
}
]
Google AI style:
"tools": [{
"functionDeclarations": [{
"name": "find_movies",
"description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
"parameters": {
"type": "OBJECT",
"properties": {
"location": {
"type": "STRING",
"description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
},
"description": {
"type": "STRING",
"description": "Any kind of description including category or genre, title words, attributes, etc."
}
},
"required": ["description"]
}
}, {
"name": "find_theaters",
...
"""
function_list = [
dict(
name=t.function.name,
description=t.function.description,
parameters=t.function.parameters, # TODO need to unpack
)
for t in tools
]
# Add inner thoughts if needed
for func in function_list:
# Note: Google AI API used to have weird casing requirements, but not any more
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
if "parameters" in func and isinstance(func["parameters"], dict):
self._clean_google_ai_schema_properties(func["parameters"])
# Add inner thoughts
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX
func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
"type": "string",
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
}
func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)
return [{"functionDeclarations": function_list}]
def build_request_data(
self,
messages: List[PydanticMessage],
@ -66,11 +199,29 @@ class GoogleVertexClient(GoogleAIClient):
"""
Constructs a request object in the expected data format for this client.
"""
request_data = super().build_request_data(messages, llm_config, tools, force_tool_call)
request_data["config"] = request_data.pop("generation_config")
request_data["config"]["tools"] = request_data.pop("tools")
tool_names = [t["name"] for t in tools] if tools else []
if tools:
tool_objs = [Tool(type="function", function=t) for t in tools]
tool_names = [t.function.name for t in tool_objs]
# Convert to the exact payload style Google expects
formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
else:
formatted_tools = []
tool_names = []
contents = self.add_dummy_model_messages(
[m.to_google_ai_dict() for m in messages],
)
request_data = {
"contents": contents,
"config": {
"temperature": llm_config.temperature,
"max_output_tokens": llm_config.max_tokens,
"tools": formatted_tools,
},
}
if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental:
request_data["config"]["response_mime_type"] = "application/json"
request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0])
@ -89,11 +240,11 @@ class GoogleVertexClient(GoogleAIClient):
# Add thinking_config
# If enable_reasoner is False, set thinking_budget to 0
# Otherwise, use the value from max_reasoning_tokens
thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens
thinking_config = ThinkingConfig(
thinking_budget=thinking_budget,
)
request_data["config"]["thinking_config"] = thinking_config.model_dump()
if llm_config.enable_reasoner:
thinking_config = ThinkingConfig(
thinking_budget=llm_config.max_reasoning_tokens,
)
request_data["config"]["thinking_config"] = thinking_config.model_dump()
return request_data
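A small worked example of the inner-thoughts injection described in the docstring above, applied to the illustrative find_movies declaration (values are made up):

from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX

declaration = {
    "name": "find_movies",
    "description": "Find movie titles currently playing in theaters.",
    "parameters": {
        "type": "object",
        "properties": {"description": {"type": "string", "description": "Genre, title words, etc."}},
        "required": ["description"],
    },
}

# What the put_inner_thoughts_in_kwargs branch adds to each declaration
declaration["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
    "type": "string",
    "description": INNER_THOUGHTS_KWARG_DESCRIPTION,
}
declaration["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)

google_tools = [{"functionDeclarations": [declaration]}]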

View File

@ -20,15 +20,19 @@ from letta.llm_api.openai import (
build_openai_chat_completions_request,
openai_chat_completions_process_stream,
openai_chat_completions_request,
prepare_openai_payload,
)
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.tracing import log_event, trace_method
@ -142,6 +146,9 @@ def create(
model_settings: Optional[dict] = None, # TODO: eventually pass from server
put_inner_thoughts_first: bool = True,
name: Optional[str] = None,
telemetry_manager: Optional[TelemetryManager] = None,
step_id: Optional[str] = None,
actor: Optional[User] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from letta.utils import printd
@ -233,6 +240,16 @@ def create(
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_end()
telemetry_manager.create_provider_trace(
actor=actor,
provider_trace_create=ProviderTraceCreate(
request_json=prepare_openai_payload(data),
response_json=response.model_json_schema(),
step_id=step_id,
organization_id=actor.organization_id,
),
)
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
@ -407,6 +424,16 @@ def create(
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
telemetry_manager.create_provider_trace(
actor=actor,
provider_trace_create=ProviderTraceCreate(
request_json=chat_completion_request.model_json_schema(),
response_json=response.model_json_schema(),
step_id=step_id,
organization_id=actor.organization_id,
),
)
return response
# elif llm_config.model_endpoint_type == "cohere":

View File

@ -51,7 +51,7 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.openai:
case ProviderType.openai | ProviderType.together:
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(

View File

@ -9,7 +9,9 @@ from letta.errors import LLMError
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.tracing import log_event
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.tracing import log_event, trace_method
if TYPE_CHECKING:
from letta.orm import User
@ -31,13 +33,15 @@ class LLMClientBase:
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming
@trace_method
def send_llm_request(
self,
messages: List[Message],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # TODO: change to Tool object
stream: bool = False,
force_tool_call: Optional[str] = None,
telemetry_manager: Optional["TelemetryManager"] = None,
step_id: Optional[str] = None,
) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
"""
Issues a request to the downstream model endpoint and parses response.
@ -48,37 +52,51 @@ class LLMClientBase:
try:
log_event(name="llm_request_sent", attributes=request_data)
if stream:
return self.stream(request_data, llm_config)
else:
response_data = self.request(request_data, llm_config)
response_data = self.request(request_data, llm_config)
if step_id and telemetry_manager:
telemetry_manager.create_provider_trace(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
@trace_method
async def send_llm_request_async(
self,
request_data: dict,
messages: List[Message],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # TODO: change to Tool object
stream: bool = False,
force_tool_call: Optional[str] = None,
telemetry_manager: "TelemetryManager | None" = None,
step_id: str | None = None,
) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
"""
Issues a request to the downstream model endpoint.
If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)
try:
log_event(name="llm_request_sent", attributes=request_data)
if stream:
return await self.stream_async(request_data, llm_config)
else:
response_data = await self.request_async(request_data, llm_config)
response_data = await self.request_async(request_data, llm_config)
await telemetry_manager.create_provider_trace_async(
actor=self.actor,
provider_trace_create=ProviderTraceCreate(
request_json=request_data,
response_json=response_data,
step_id=step_id,
organization_id=self.actor.organization_id,
),
)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
@ -133,13 +151,6 @@ class LLMClientBase:
"""
raise NotImplementedError
@abstractmethod
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
"""
Performs underlying streaming request to llm and returns raw response.
"""
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
@abstractmethod
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""

View File

@ -1,6 +1,7 @@
import warnings
from typing import Generator, List, Optional, Union
import httpx
import requests
from openai import OpenAI
@ -110,6 +111,62 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
raise e
async def openai_get_model_list_async(
url: str,
api_key: Optional[str] = None,
fix_url: bool = False,
extra_params: Optional[dict] = None,
client: Optional["httpx.AsyncClient"] = None,
) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from letta.utils import printd
# In some cases we may want to double-check the URL and do basic correction
if fix_url and not url.endswith("/v1"):
url = smart_urljoin(url, "v1")
url = smart_urljoin(url, "models")
headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
printd(f"Sending request to {url}")
# Use provided client or create a new one
close_client = False
if client is None:
client = httpx.AsyncClient()
close_client = True
try:
response = await client.get(url, headers=headers, params=extra_params)
response.raise_for_status()
result = response.json()
printd(f"response = {result}")
return result
except httpx.HTTPStatusError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
error_response = None
try:
error_response = http_err.response.json()
except:
error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
printd(f"Got HTTPError, exception={http_err}, response={error_response}")
raise http_err
except httpx.RequestError as req_err:
# Handle other httpx-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
except Exception as e:
# Handle other potential errors
printd(f"Got unknown Exception, exception={e}")
raise e
finally:
if close_client:
await client.aclose()
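A sketch of the async lister in use with the OpenRouter tools filter that the provider code passes via extra_params; the URL and key handling are assumptions for the example:

import asyncio
import os

import httpx

from letta.llm_api.openai import openai_get_model_list_async

async def main():
    async with httpx.AsyncClient() as client:
        response = await openai_get_model_list_async(
            "https://openrouter.ai/api/v1",
            api_key=os.environ.get("OPENROUTER_API_KEY"),
            extra_params={"supported_parameters": "tools"},  # only models with tool calling
            client=client,
        )
    # OpenAI-compatible servers return {"data": [...]}; some (e.g. TogetherAI) return a bare list
    models = response.get("data", response) if isinstance(response, dict) else response
    print(len(models))

asyncio.run(main())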
def build_openai_chat_completions_request(
llm_config: LLMConfig,
messages: List[_Message],

View File

@ -2,7 +2,7 @@ import os
from typing import List, Optional
import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@ -22,7 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@ -113,6 +113,8 @@ class OpenAIClient(LLMClientBase):
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
if llm_config.model_endpoint_type == ProviderType.together:
api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
if not api_key:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@ -254,20 +256,14 @@ class OpenAIClient(LLMClientBase):
return chat_completion_response
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
"""
Performs underlying streaming request to OpenAI and returns the stream iterator.
"""
client = OpenAI(**self._prepare_client_kwargs(llm_config))
response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
return response_stream
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data, stream=True, stream_options={"include_usage": True}
)
return response_stream
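With stream_options={"include_usage": True}, OpenAI appends a final chunk whose choices list is empty and whose usage field carries integer token counts, which is what the streaming interface above now reads directly instead of wrapping them in len(). A minimal consumption sketch against the raw SDK (model name and prompt are illustrative):

import asyncio

from openai import AsyncOpenAI

async def main():
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    input_tokens = output_tokens = 0
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")
        if chunk.usage:  # only set on the final usage chunk
            input_tokens += chunk.usage.prompt_tokens
            output_tokens += chunk.usage.completion_tokens
    print(f"\nprompt={input_tokens} completion={output_tokens}")

asyncio.run(main())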
def handle_llm_error(self, e: Exception) -> Exception:

View File

@ -93,7 +93,6 @@ def summarize_messages(
response = llm_client.send_llm_request(
messages=message_sequence,
llm_config=llm_config_no_inner_thoughts,
stream=False,
)
else:
response = create(

View File

@ -19,6 +19,7 @@ from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.provider import Provider
from letta.orm.provider_trace import ProviderTrace
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents

View File

@ -8,6 +8,7 @@ class ToolType(str, Enum):
LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
LETTA_SLEEPTIME_CORE = "letta_sleeptime_core"
LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core"
LETTA_BUILTIN = "letta_builtin"
EXTERNAL_COMPOSIO = "external_composio"
EXTERNAL_LANGCHAIN = "external_langchain"
# TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote?

View File

@ -0,0 +1,26 @@
import uuid
from sqlalchemy import JSON, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
class ProviderTrace(SqlalchemyBase, OrganizationMixin):
"""Defines data model for storing provider trace information"""
__tablename__ = "provider_traces"
__pydantic_model__ = PydanticProviderTrace
__table_args__ = (Index("ix_step_id", "step_id"),)
id: Mapped[str] = mapped_column(
primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}"
)
request_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider request")
response_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider response")
step_id: Mapped[str] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with")
# Relationships
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")

View File

@ -35,6 +35,7 @@ class Step(SqlalchemyBase):
)
agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The ID of the agent that performed this step.")
provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
provider_category: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The category of the provider used for this step.")
model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.")
context_window_limit: Mapped[Optional[int]] = mapped_column(

View File

@ -0,0 +1,43 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.letta_base import OrmMetadataBase
class BaseProviderTrace(OrmMetadataBase):
__id_prefix__ = "provider_trace"
class ProviderTraceCreate(BaseModel):
"""Request to create a provider trace"""
request_json: dict[str, Any] = Field(..., description="JSON content of the provider request")
response_json: dict[str, Any] = Field(..., description="JSON content of the provider response")
step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")
organization_id: str = Field(..., description="The unique identifier of the organization.")
class ProviderTrace(BaseProviderTrace):
"""
Letta's internal representation of a provider trace.
Attributes:
id (str): The unique identifier of the provider trace.
request_json (Dict[str, Any]): JSON content of the provider request.
response_json (Dict[str, Any]): JSON content of the provider response.
step_id (str): ID of the step that this trace is associated with.
organization_id (str): The unique identifier of the organization.
created_at (datetime): The timestamp when the object was created.
"""
id: str = BaseProviderTrace.generate_id_field()
request_json: Dict[str, Any] = Field(..., description="JSON content of the provider request")
response_json: Dict[str, Any] = Field(..., description="JSON content of the provider response")
step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")
organization_id: str = Field(..., description="The unique identifier of the organization.")
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
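A minimal sketch of building a creation payload from a request/response pair, mirroring how the LLM clients populate it (the IDs are placeholders):

from letta.schemas.provider_trace import ProviderTraceCreate

trace = ProviderTraceCreate(
    request_json={"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]},
    response_json={"id": "chatcmpl-123", "choices": []},
    step_id="step-00000000-0000-0000-0000-000000000000",
    organization_id="org-00000000-0000-0000-0000-000000000000",
)
print(trace.model_dump_json(indent=2))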

View File

@ -47,12 +47,21 @@ class Provider(ProviderBase):
def list_llm_models(self) -> List[LLMConfig]:
return []
async def list_llm_models_async(self) -> List[LLMConfig]:
return []
def list_embedding_models(self) -> List[EmbeddingConfig]:
return []
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
return []
def get_model_context_window(self, model_name: str) -> Optional[int]:
raise NotImplementedError
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
raise NotImplementedError
def provider_tag(self) -> str:
"""String representation of the provider for display purposes"""
raise NotImplementedError
@ -140,6 +149,19 @@ class LettaProvider(Provider):
)
]
async def list_llm_models_async(self) -> List[LLMConfig]:
return [
LLMConfig(
model="letta-free", # NOTE: renamed
model_endpoint_type="openai",
model_endpoint=LETTA_MODEL_ENDPOINT,
context_window=8192,
handle=self.get_handle("letta-free"),
provider_name=self.name,
provider_category=self.provider_category,
)
]
def list_embedding_models(self):
return [
EmbeddingConfig(
@ -189,9 +211,40 @@ class OpenAIProvider(Provider):
return data
async def _get_models_async(self) -> List[dict]:
from letta.llm_api.openai import openai_get_model_list_async
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
# See: https://openrouter.ai/docs/requests
if "openrouter.ai" in self.base_url:
extra_params = {"supported_parameters": "tools"}
elif "nebius.com" in self.base_url:
# Similar hardcoded support for Nebius
extra_params = {"verbose": True}
else:
extra_params = None
response = await openai_get_model_list_async(
self.base_url,
api_key=self.api_key,
extra_params=extra_params,
# fix_url=True, # NOTE: make sure together ends with /v1
)
if "data" in response:
data = response["data"]
else:
# TogetherAI's response is missing the 'data' field
data = response
return data
def list_llm_models(self) -> List[LLMConfig]:
data = self._get_models()
return self._list_llm_models(data)
async def list_llm_models_async(self) -> List[LLMConfig]:
data = await self._get_models_async()
return self._list_llm_models(data)
def _list_llm_models(self, data) -> List[LLMConfig]:
configs = []
for model in data:
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
@ -279,7 +332,6 @@ class OpenAIProvider(Provider):
return configs
def list_embedding_models(self) -> List[EmbeddingConfig]:
if self.base_url == "https://api.openai.com/v1":
# TODO: actually automatically list models for OpenAI
return [
@ -312,55 +364,92 @@ class OpenAIProvider(Provider):
else:
# Actually attempt to list
data = self._get_models()
return self._list_embedding_models(data)
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
if self.base_url == "https://api.openai.com/v1":
# TODO: actually automatically list models for OpenAI
return [
EmbeddingConfig(
embedding_model="text-embedding-ada-002",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=1536,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
),
EmbeddingConfig(
embedding_model="text-embedding-3-small",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=2000,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-3-small", is_embedding=True),
),
EmbeddingConfig(
embedding_model="text-embedding-3-large",
embedding_endpoint_type="openai",
embedding_endpoint=self.base_url,
embedding_dim=2000,
embedding_chunk_size=300,
handle=self.get_handle("text-embedding-3-large", is_embedding=True),
),
]
else:
# Actually attempt to list
data = await self._get_models_async()
return self._list_embedding_models(data)
def _list_embedding_models(self, data) -> List[EmbeddingConfig]:
configs = []
for model in data:
assert "id" in model, f"Model missing 'id' field: {model}"
model_name = model["id"]
if "nebius.com" in self.base_url:
# Nebius includes the type, which we can use to filter for embedding models
try:
model_type = model["architecture"]["modality"]
if model_type not in ["text->embedding"]:
# print(f"Skipping model w/ modality {model_type}:\n{model}")
continue
except KeyError:
print(f"Couldn't access architecture type field, skipping model:\n{model}")
continue
elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
# TogetherAI includes the type, which we can use to filter for embedding models
if "type" in model and model["type"] not in ["embedding"]:
# print(f"Skipping model w/ modality {model_type}:\n{model}")
continue
else:
# For other providers we should skip by default, since we don't want to assume embeddings are supported
continue
if "context_length" in model:
# Context length is returned in Nebius as "context_length"
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)
# We need the context length for embeddings too
if not context_window_size:
continue
configs.append(
EmbeddingConfig(
embedding_model=model_name,
embedding_endpoint_type=self.provider_type,
embedding_endpoint=self.base_url,
embedding_dim=context_window_size,
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=self.get_handle(model, is_embedding=True),
)
)
return configs
def get_model_context_window_size(self, model_name: str):
if model_name in LLM_MAX_TOKENS:
@ -647,26 +736,19 @@ class AnthropicProvider(Provider):
anthropic_check_valid_api_key(self.api_key)
def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
from letta.llm_api.anthropic import anthropic_get_model_list
models = anthropic_get_model_list(self.base_url, api_key=self.api_key)
models = anthropic_get_model_list(api_key=self.api_key)
return self._list_llm_models(models)
"""
Example response:
{
"data": [
{
"type": "model",
"id": "claude-3-5-sonnet-20241022",
"display_name": "Claude 3.5 Sonnet (New)",
"created_at": "2024-10-22T00:00:00Z"
}
],
"has_more": true,
"first_id": "<string>",
"last_id": "<string>"
}
"""
async def list_llm_models_async(self) -> List[LLMConfig]:
from letta.llm_api.anthropic import anthropic_get_model_list_async
models = await anthropic_get_model_list_async(api_key=self.api_key)
return self._list_llm_models(models)
def _list_llm_models(self, models) -> List[LLMConfig]:
from letta.llm_api.anthropic import MODEL_LIST
configs = []
for model in models:
@ -724,9 +806,6 @@ class AnthropicProvider(Provider):
)
return configs
def list_embedding_models(self) -> List[EmbeddingConfig]:
return []
class MistralProvider(Provider):
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
@ -948,14 +1027,24 @@ class TogetherProvider(OpenAIProvider):
def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list
models = openai_get_model_list(self.base_url, api_key=self.api_key)
return self._list_llm_models(models)
async def list_llm_models_async(self) -> List[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list_async
models = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
return self._list_llm_models(models)
def _list_llm_models(self, models) -> List[LLMConfig]:
# TogetherAI's response is missing the 'data' field
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
if "data" in models:
data = models["data"]
else:
data = models
configs = []
for model in data:
@ -1057,7 +1146,6 @@ class GoogleAIProvider(Provider):
from letta.llm_api.google_ai_client import google_ai_get_model_list
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
# filter by 'generateContent' models
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
model_options = [str(m["name"]) for m in model_options]
@ -1081,6 +1169,42 @@ class GoogleAIProvider(Provider):
provider_category=self.provider_category,
)
)
return configs
async def list_llm_models_async(self):
import asyncio
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
# Get and filter the model list
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
model_options = [str(m["name"]) for m in model_options]
# filter by model names
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
# Add support for all gemini models
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
# Prepare tasks for context window lookups in parallel
async def create_config(model):
context_window = await self.get_model_context_window_async(model)
return LLMConfig(
model=model,
model_endpoint_type="google_ai",
model_endpoint=self.base_url,
context_window=context_window,
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
provider_category=self.provider_category,
)
# Execute all config creation tasks concurrently
configs = await asyncio.gather(*[create_config(model) for model in model_options])
return configs
def list_embedding_models(self):
@ -1088,6 +1212,16 @@ class GoogleAIProvider(Provider):
# TODO: use base_url instead
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
return self._list_embedding_models(model_options)
async def list_embedding_models_async(self):
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
# TODO: use base_url instead
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
return self._list_embedding_models(model_options)
def _list_embedding_models(self, model_options):
# filter by 'embedContent' models
model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
model_options = [str(m["name"]) for m in model_options]
@ -1110,7 +1244,18 @@ class GoogleAIProvider(Provider):
def get_model_context_window(self, model_name: str) -> Optional[int]:
from letta.llm_api.google_ai_client import google_ai_get_model_context_window
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
if model_name in LLM_MAX_TOKENS:
return LLM_MAX_TOKENS[model_name]
else:
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async
if model_name in LLM_MAX_TOKENS:
return LLM_MAX_TOKENS[model_name]
else:
return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
class GoogleVertexProvider(Provider):

View File

@ -20,6 +20,7 @@ class Step(StepBase):
)
agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.")
provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
provider_category: Optional[str] = Field(None, description="The category of the provider used for this step.")
model: Optional[str] = Field(None, description="The name of the model used for this step.")
model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.")
context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.")

View File

@ -5,6 +5,7 @@ from pydantic import Field, model_validator
from letta.constants import (
COMPOSIO_TOOL_TAG_NAME,
FUNCTION_RETURN_CHAR_LIMIT,
LETTA_BUILTIN_TOOL_MODULE_NAME,
LETTA_CORE_TOOL_MODULE_NAME,
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
LETTA_VOICE_TOOL_MODULE_NAME,
@ -104,6 +105,9 @@ class Tool(BaseTool):
elif self.tool_type in {ToolType.LETTA_VOICE_SLEEPTIME_CORE}:
# If it's letta voice tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_VOICE_TOOL_MODULE_NAME, function_name=self.name)
elif self.tool_type in {ToolType.LETTA_BUILTIN}:
# If it's a letta builtin tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_BUILTIN_TOOL_MODULE_NAME, function_name=self.name)
# At this point, we need to validate that at least json_schema is populated
if not self.json_schema:

View File

@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Generator
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import Engine, create_engine
from sqlalchemy import Engine, NullPool, QueuePool, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
@ -14,6 +14,8 @@ from letta.config import LettaConfig
from letta.log import get_logger
from letta.settings import settings
logger = get_logger(__name__)
def print_sqlite_schema_error():
"""Print a formatted error message for SQLite schema issues"""
@ -76,16 +78,7 @@ class DatabaseRegistry:
self.config.archival_storage_type = "postgres"
self.config.archival_storage_uri = settings.letta_pg_uri_no_default
engine = create_engine(
settings.letta_pg_uri,
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
# connect_args={"client_encoding": "utf8"},
)
engine = create_engine(settings.letta_pg_uri, **self._build_sqlalchemy_engine_args(is_async=False))
self._engines["default"] = engine
# SQLite engine
@ -125,14 +118,7 @@ class DatabaseRegistry:
async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri
async_pg_uri = async_pg_uri.replace("sslmode=", "ssl=")
async_engine = create_async_engine(
async_pg_uri,
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
)
async_engine = create_async_engine(async_pg_uri, **self._build_sqlalchemy_engine_args(is_async=True))
self._async_engines["default"] = async_engine
@ -146,6 +132,38 @@ class DatabaseRegistry:
# TODO (cliandy): unclear around async sqlite support in sqlalchemy, we will not currently support this
self._initialized["async"] = False
def _build_sqlalchemy_engine_args(self, *, is_async: bool) -> dict:
"""Prepare keyword arguments for create_engine / create_async_engine."""
use_null_pool = settings.disable_sqlalchemy_pooling
if use_null_pool:
logger.info("Disabling pooling on SqlAlchemy")
pool_cls = NullPool
else:
logger.info("Enabling pooling on SqlAlchemy")
pool_cls = QueuePool if not is_async else None
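# (Async engines are left to SQLAlchemy's default async-compatible pool; plain QueuePool is not asyncio-safe.)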
base_args = {
"echo": settings.pg_echo,
"pool_pre_ping": settings.pool_pre_ping,
}
if pool_cls:
base_args["poolclass"] = pool_cls
if not use_null_pool and not is_async:
base_args.update(
{
"pool_size": settings.pg_pool_size,
"max_overflow": settings.pg_max_overflow,
"pool_timeout": settings.pg_pool_timeout,
"pool_recycle": settings.pg_pool_recycle,
"pool_use_lifo": settings.pool_use_lifo,
}
)
return base_args
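# Example (illustrative): with pooling enabled and is_async=False this returns roughly
# {"echo": ..., "pool_pre_ping": ..., "poolclass": QueuePool, "pool_size": ..., "max_overflow": ...,
#  "pool_timeout": ..., "pool_recycle": ..., "pool_use_lifo": ...};
# with disable_sqlalchemy_pooling set, only {"echo": ..., "pool_pre_ping": ..., "poolclass": NullPool} is passed.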
def _wrap_sqlite_engine(self, engine: Engine) -> None:
"""Wrap SQLite engine with error handling."""
original_connect = engine.connect

View File

@ -13,6 +13,7 @@ from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_c
from letta.server.rest_api.routers.v1.sources import router as sources_router
from letta.server.rest_api.routers.v1.steps import router as steps_router
from letta.server.rest_api.routers.v1.tags import router as tags_router
from letta.server.rest_api.routers.v1.telemetry import router as telemetry_router
from letta.server.rest_api.routers.v1.tools import router as tools_router
from letta.server.rest_api.routers.v1.voice import router as voice_router
@ -31,6 +32,7 @@ ROUTERS = [
runs_router,
steps_router,
tags_router,
telemetry_router,
messages_router,
voice_router,
embeddings_router,

View File

@ -33,6 +33,7 @@ from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.telemetry_manager import NoopTelemetryManager
from letta.settings import settings
# These can be forward refs, but because Fastapi needs them at runtime they must be imported normally
@ -106,14 +107,15 @@ async def list_agents(
@router.get("/count", response_model=int, operation_id="count_agents")
def count_agents(
async def count_agents(
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Get the count of all agents associated with a given user.
"""
return server.agent_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.size_async(actor=actor)
class IndentedORJSONResponse(Response):
@ -124,7 +126,7 @@ class IndentedORJSONResponse(Response):
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
def export_agent_serialized(
async def export_agent_serialized(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@ -135,7 +137,7 @@ def export_agent_serialized(
"""
Export the serialized JSON representation of an agent, formatted with indentation.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor)
@ -200,7 +202,7 @@ async def import_agent_serialized(
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
def retrieve_agent_context_window(
async def retrieve_agent_context_window(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -208,9 +210,12 @@ def retrieve_agent_context_window(
"""
Retrieve the context window of a specific agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor)
except Exception as e:
traceback.print_exc()
raise e
class CreateAgentRequest(CreateAgent):
@ -341,7 +346,7 @@ async def retrieve_agent(
"""
Get the state of the agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
@ -367,7 +372,7 @@ def delete_agent(
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="list_agent_sources")
def list_agent_sources(
async def list_agent_sources(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -375,8 +380,8 @@ def list_agent_sources(
"""
Get the sources associated with an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.list_attached_sources_async(agent_id=agent_id, actor=actor)
# TODO: remove? can also get with agent blocks
@ -424,14 +429,14 @@ async def list_blocks(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
return agent.memory.blocks
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.patch("/{agent_id}/core-memory/blocks/{block_label}", response_model=Block, operation_id="modify_core_memory_block")
def modify_block(
async def modify_block(
agent_id: str,
block_label: str,
block_update: BlockUpdate = Body(...),
@ -441,10 +446,11 @@ def modify_block(
"""
Updates a core memory block of an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
block = await server.agent_manager.modify_block_by_label_async(
agent_id=agent_id, block_label=block_label, block_update=block_update, actor=actor
)
# This should also trigger a system prompt change in the agent
server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False)
@ -481,7 +487,7 @@ def detach_block(
@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages")
def list_passages(
async def list_passages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."),
@ -496,11 +502,11 @@ def list_passages(
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.get_agent_archival(
user_id=actor.id,
return await server.get_agent_archival_async(
agent_id=agent_id,
actor=actor,
after=after,
before=before,
query_text=search,
@ -564,7 +570,7 @@ AgentMessagesResponse = Annotated[
@router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_messages")
def list_messages(
async def list_messages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
@ -579,10 +585,9 @@ def list_messages(
"""
Retrieve message history for an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.get_agent_recall(
user_id=actor.id,
return await server.get_agent_recall_async(
agent_id=agent_id,
after=after,
before=before,
@ -593,6 +598,7 @@ def list_messages(
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
actor=actor,
)
@ -634,7 +640,7 @@ async def send_message(
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "google_vertex", "google_ai"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
if agent_eligible and feature_enabled and model_compatible:
experimental_agent = LettaAgent(
@ -644,6 +650,8 @@ async def send_message(
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
@ -692,7 +700,8 @@ async def send_message_streaming(
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
if agent_eligible and feature_enabled and model_compatible and request.stream_tokens:
experimental_agent = LettaAgent(
@ -702,14 +711,28 @@ async def send_message_streaming(
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
result = StreamingResponse(
experimental_agent.step_stream(
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
),
media_type="text/event-stream",
)
if request.stream_tokens and model_compatible_token_streaming:
result = StreamingResponseWithStatusCode(
experimental_agent.step_stream(
input_messages=request.messages,
max_steps=10,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
),
media_type="text/event-stream",
)
else:
result = StreamingResponseWithStatusCode(
experimental_agent.step_stream_no_tokens(
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
),
media_type="text/event-stream",
)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,

View File

@ -99,7 +99,7 @@ def retrieve_block(
@router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block")
def list_agents_for_block(
async def list_agents_for_block(
block_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
@ -108,9 +108,9 @@ def list_agents_for_block(
Retrieves all agents associated with the specified block.
Raises a 404 if the block does not exist.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=block_id, actor=actor)
return agents
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found")

View File

@ -13,7 +13,7 @@ router = APIRouter(prefix="/identities", tags=["identities"])
@router.get("/", tags=["identities"], response_model=List[Identity], operation_id="list_identities")
def list_identities(
async def list_identities(
name: Optional[str] = Query(None),
project_id: Optional[str] = Query(None),
identifier_key: Optional[str] = Query(None),
@ -28,9 +28,9 @@ def list_identities(
Get a list of all identities in the database
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
identities = server.identity_manager.list_identities(
identities = await server.identity_manager.list_identities_async(
name=name,
project_id=project_id,
identifier_key=identifier_key,
@ -50,7 +50,7 @@ def list_identities(
@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities")
def count_identities(
async def count_identities(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
@ -58,7 +58,8 @@ def count_identities(
Get count of all identities for a user
"""
try:
return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.size_async(actor=actor)
except NoResultFound:
return 0
except HTTPException:
@ -68,28 +69,28 @@ def count_identities(
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
def retrieve_identity(
async def retrieve_identity(
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/", tags=["identities"], response_model=Identity, operation_id="create_identity")
def create_identity(
async def create_identity(
identity: IdentityCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
):
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.create_identity(identity=identity, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.create_identity_async(identity=identity, actor=actor)
except HTTPException:
raise
except UniqueConstraintViolationError:
@ -105,15 +106,15 @@ def create_identity(
@router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity")
def upsert_identity(
async def upsert_identity(
identity: IdentityUpsert = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
):
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
@ -123,36 +124,33 @@ def upsert_identity(
@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity")
def modify_identity(
async def modify_identity(
identity_id: str,
identity: IdentityUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
import traceback
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=f"{e}")
@router.put("/{identity_id}/properties", tags=["identities"], operation_id="upsert_identity_properties")
def upsert_identity_properties(
async def upsert_identity_properties(
identity_id: str,
properties: List[IdentityProperty] = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.upsert_identity_properties(identity_id=identity_id, properties=properties, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor)
except HTTPException:
raise
except NoResultFound as e:
@ -162,7 +160,7 @@ def upsert_identity_properties(
@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity")
def delete_identity(
async def delete_identity(
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -171,8 +169,8 @@ def delete_identity(
Delete an identity by its identifier key
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor)
except HTTPException:
raise
except NoResultFound as e:

View File

@ -33,16 +33,16 @@ def list_jobs(
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
def list_active_jobs(
async def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all active jobs.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running])
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")

View File

@ -14,30 +14,35 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
def list_llm_models(
async def list_llm_models(
provider_category: Optional[List[ProviderCategory]] = Query(None),
provider_name: Optional[str] = Query(None),
provider_type: Optional[ProviderType] = Query(None),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"),
# Extract user_id from header, default to None if not present
):
"""List available LLM models using the asynchronous implementation for improved performance"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
models = server.list_llm_models(
models = await server.list_llm_models_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
# print(models)
return models
@router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
def list_embedding_models(
async def list_embedding_models(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"),
# Extract user_id from header, default to None if not present
):
"""List available embedding models using the asynchronous implementation for improved performance"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
models = server.list_embedding_models(actor=actor)
# print(models)
models = await server.list_embedding_models_async(actor=actor)
return models

View File

@ -100,15 +100,15 @@ def delete_sandbox_config(
@router.get("/", response_model=List[PydanticSandboxConfig])
def list_sandbox_configs(
async def list_sandbox_configs(
limit: int = Query(1000, description="Number of results to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"),
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.list_sandbox_configs_async(actor, limit=limit, after=after, sandbox_type=sandbox_type)
@router.post("/local/recreate-venv", response_model=PydanticSandboxConfig)
@ -190,12 +190,12 @@ def delete_sandbox_env_var(
@router.get("/{sandbox_config_id}/environment-variable", response_model=List[PydanticEnvVar])
def list_sandbox_env_vars(
async def list_sandbox_env_vars(
sandbox_config_id: str,
limit: int = Query(1000, description="Number of results to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.list_sandbox_env_vars_async(sandbox_config_id, actor, limit=limit, after=after)

View File

@ -12,7 +12,7 @@ router = APIRouter(prefix="/tags", tags=["tag", "admin"])
@router.get("/", tags=["admin"], response_model=List[str], operation_id="list_tags")
def list_tags(
async def list_tags(
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_letta_server),
@ -22,6 +22,6 @@ def list_tags(
"""
Get a list of all tags in the database
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
tags = await server.agent_manager.list_tags_async(actor=actor, after=after, limit=limit, query_text=query_text)
return tags

View File

@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends, Header
from letta.schemas.provider_trace import ProviderTrace
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace")
async def retrieve_provider_trace_by_step_id(
step_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
return await server.telemetry_manager.get_provider_trace_by_step_id_async(
step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
)
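# Illustrative client-side sketch (hypothetical host and ids; assumes this router is mounted under the v1 API):
#   import httpx
#   resp = httpx.get("http://localhost:8283/v1/telemetry/step-123", headers={"user_id": "user-123"})
#   trace = resp.json()  # ProviderTrace fields: request_json, response_json, step_id, organization_id, ...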

View File

@ -59,7 +59,7 @@ def count_tools(
@router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool")
def retrieve_tool(
async def retrieve_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -67,8 +67,8 @@ def retrieve_tool(
"""
Get a tool by ID
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
@ -196,15 +196,15 @@ def modify_tool(
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
def upsert_base_tools(
async def upsert_base_tools(
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Upsert base tools
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.tool_manager.upsert_base_tools(actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.tool_manager.upsert_base_tools_async(actor=actor)
@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source")

View File

@ -0,0 +1,105 @@
# Alternative implementation of StreamingResponse that allows for effectively
# streaming HTTP trailers, as we cannot set codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
import json
from collections.abc import AsyncIterator
from fastapi.responses import StreamingResponse
from starlette.types import Send
from letta.log import get_logger
logger = get_logger(__name__)
class StreamingResponseWithStatusCode(StreamingResponse):
"""
Variation of StreamingResponse that can dynamically decide the HTTP status code,
based on the return value of the content iterator (parameter `content`).
Expects the content to yield either just str content as per the original `StreamingResponse`
or else tuples of (`content`: `str`, `status_code`: `int`).
"""
body_iterator: AsyncIterator[str | bytes]
response_started: bool = False
async def stream_response(self, send: Send) -> None:
more_body = True
try:
first_chunk = await self.body_iterator.__anext__()
if isinstance(first_chunk, tuple):
first_chunk_content, self.status_code = first_chunk
else:
first_chunk_content = first_chunk
if isinstance(first_chunk_content, str):
first_chunk_content = first_chunk_content.encode(self.charset)
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
self.response_started = True
await send(
{
"type": "http.response.body",
"body": first_chunk_content,
"more_body": more_body,
}
)
async for chunk in self.body_iterator:
if isinstance(chunk, tuple):
content, status_code = chunk
if status_code // 100 != 2:
# An error occurred mid-stream
if not isinstance(content, bytes):
content = content.encode(self.charset)
more_body = False
await send(
{
"type": "http.response.body",
"body": content,
"more_body": more_body,
}
)
return
else:
content = chunk
if isinstance(content, str):
content = content.encode(self.charset)
more_body = True
await send(
{
"type": "http.response.body",
"body": content,
"more_body": more_body,
}
)
except Exception:
logger.exception("unhandled_streaming_error")
more_body = False
error_resp = {"error": {"message": "Internal Server Error"}}
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
if not self.response_started:
await send(
{
"type": "http.response.start",
"status": 500,
"headers": self.raw_headers,
}
)
await send(
{
"type": "http.response.body",
"body": error_event,
"more_body": more_body,
}
)
if more_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})

View File

@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response(
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
llm_batch_item_id: Optional[str] = None,
step_id: str | None = None,
) -> List[Message]:
messages = []
@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response(
)
messages.append(heartbeat_system_message)
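# Stamp each message with the step that produced it, so it can be joined back to the Step record (and its provider trace)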
for message in messages:
message.step_id = step_id
return messages

View File

@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
@ -213,6 +214,7 @@ class SyncServer(Server):
self.identity_manager = IdentityManager()
self.group_manager = GroupManager()
self.batch_manager = LLMBatchManager()
self.telemetry_manager = TelemetryManager()
# A reusable httpx client
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
@ -1000,6 +1002,30 @@ class SyncServer(Server):
)
return records
async def get_agent_archival_async(
self,
agent_id: str,
actor: User,
after: Optional[str] = None,
before: Optional[str] = None,
limit: Optional[int] = 100,
order_by: Optional[str] = "created_at",
reverse: Optional[bool] = False,
query_text: Optional[str] = None,
ascending: Optional[bool] = True,
) -> List[Passage]:
# iterate over records
records = await self.agent_manager.list_passages_async(
actor=actor,
agent_id=agent_id,
after=after,
query_text=query_text,
before=before,
ascending=ascending,
limit=limit,
)
return records
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
# Get the agent object (loaded in memory)
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@ -1070,6 +1096,44 @@ class SyncServer(Server):
return records
async def get_agent_recall_async(
self,
agent_id: str,
actor: User,
after: Optional[str] = None,
before: Optional[str] = None,
limit: Optional[int] = 100,
group_id: Optional[str] = None,
reverse: Optional[bool] = False,
return_message_object: bool = True,
use_assistant_message: bool = True,
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
) -> Union[List[Message], List[LettaMessage]]:
records = await self.message_manager.list_messages_for_agent_async(
agent_id=agent_id,
actor=actor,
after=after,
before=before,
limit=limit,
ascending=not reverse,
group_id=group_id,
)
if not return_message_object:
records = Message.to_letta_messages_from_list(
messages=records,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse,
)
if reverse:
records = records[::-1]
return records
def get_server_config(self, include_defaults: bool = False) -> dict:
"""Return the base config"""
@ -1301,6 +1365,48 @@ class SyncServer(Server):
return llm_models
@trace_method
async def list_llm_models_async(
self,
actor: User,
provider_category: Optional[List[ProviderCategory]] = None,
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""Asynchronously list available models with maximum concurrency"""
import asyncio
providers = self.get_enabled_providers(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
)
async def get_provider_models(provider):
try:
return await provider.list_llm_models_async()
except Exception as e:
import traceback
traceback.print_exc()
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
# Flatten the results
llm_models = []
for models in provider_results:
llm_models.extend(models)
# Get local configs - if this is potentially slow, consider making it async too
local_configs = self.get_local_llm_configs()
llm_models.extend(local_configs)
return llm_models
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
"""List available embedding models"""
embedding_models = []
@ -1311,6 +1417,35 @@ class SyncServer(Server):
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
"""Asynchronously list available embedding models with maximum concurrency"""
import asyncio
# Get all eligible providers first
providers = self.get_enabled_providers(actor=actor)
# Fetch embedding models from each provider concurrently
async def get_provider_embedding_models(provider):
try:
# All providers now have list_embedding_models_async
return await provider.list_embedding_models_async()
except Exception as e:
import traceback
traceback.print_exc()
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return []
# Execute all provider model listing tasks concurrently
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
# Flatten the results
embedding_models = []
for models in provider_results:
embedding_models.extend(models)
return embedding_models
def get_enabled_providers(
self,
actor: User,
@ -1482,6 +1617,10 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.get_context_window()
async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview:
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return await letta_agent.get_context_window_async()
def run_tool_from_source(
self,
actor: User,
@ -1615,7 +1754,7 @@ class SyncServer(Server):
server_name=server_name,
command=server_params_raw["command"],
args=server_params_raw.get("args", []),
env=server_params_raw.get("env", {})
env=server_params_raw.get("env", {}),
)
mcp_server_list[server_name] = server_params
except Exception as e:

View File

@ -892,7 +892,7 @@ class AgentManager:
List[PydanticAgentState]: The filtered list of matching agents.
"""
async with db_registry.async_session() as session:
query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id)
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Apply filters
@ -961,6 +961,16 @@ class AgentManager:
with db_registry.session() as session:
return AgentModel.size(db_session=session, actor=actor)
async def size_async(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of agents for the given user.
"""
async with db_registry.async_session() as session:
return await AgentModel.size_async(db_session=session, actor=actor)
@enforce_types
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
@ -969,18 +979,32 @@ class AgentManager:
return agent.to_pydantic()
@enforce_types
async def get_agent_by_id_async(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
async def get_agent_by_id_async(
self,
agent_id: str,
actor: PydanticUser,
include_relationships: Optional[List[str]] = None,
) -> PydanticAgentState:
"""Fetch an agent by its ID."""
async with db_registry.async_session() as session:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
return agent.to_pydantic()
return await agent.to_pydantic_async(include_relationships=include_relationships)
@enforce_types
async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]:
async def get_agents_by_ids_async(
self,
agent_ids: list[str],
actor: PydanticUser,
include_relationships: Optional[List[str]] = None,
) -> list[PydanticAgentState]:
"""Fetch a list of agents by their IDs."""
async with db_registry.async_session() as session:
agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor)
return [await agent.to_pydantic_async() for agent in agents]
agents = await AgentModel.read_multiple_async(
db_session=session,
identifiers=agent_ids,
actor=actor,
)
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
@enforce_types
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
@ -1191,7 +1215,7 @@ class AgentManager:
@enforce_types
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
@enforce_types
@ -1199,6 +1223,11 @@ class AgentManager:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
@enforce_types
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
# TODO: This is duplicated below
# TODO: This is legacy code and should be cleaned up
# TODO: A lot of the memory "compilation" should be offset to a separate class
@ -1267,10 +1296,81 @@ class AgentManager:
else:
return agent_state
@enforce_types
async def rebuild_system_prompt_async(
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True
) -> PydanticAgentState:
"""Rebuilds the system message with the latest memory object and any shared memory block updates
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
"""
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
curr_system_message = await self.get_system_message_async(
agent_id=agent_id, actor=actor
) # this is the system + memory bank, not just the system prompt
curr_system_message_openai = curr_system_message.to_openai_dict()
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be somewhat out of date
curr_memory_str = agent_state.memory.compile()
if curr_memory_str in curr_system_message_openai["content"] and not force:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
)
return agent_state
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
if update_timestamp:
memory_edit_timestamp = get_utc_time()
else:
# NOTE: a bit of a hack - we pull the timestamp from the message created_at
memory_edit_timestamp = curr_system_message.created_at
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id)
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
if len(diff) > 0: # there was a diff
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Swap the system message out (only if there is a diff)
message = PydanticMessage.dict_to_message(
agent_id=agent_id,
model=agent_state.llm_config.model,
openai_message_dict={"role": "system", "content": new_system_message_str},
)
message = await self.message_manager.update_message_by_id_async(
message_id=curr_system_message.id,
message_update=MessageUpdate(**message.model_dump()),
actor=actor,
)
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
else:
return agent_state
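# A minimal standalone sketch of the rebuild decision above: skip the rebuild
# when the compiled memory string is already embedded in the current system
# message, otherwise compute a unified diff and hand back the new text. Only
# the standard library is used; the helper name is illustrative, not a Letta API.
import difflib

def maybe_rebuild(curr_system: str, memory_str: str, new_system: str, force: bool = False):
    if memory_str in curr_system and not force:
        return None  # memory unchanged -> keep the existing system message
    diff = "".join(difflib.unified_diff(curr_system.splitlines(keepends=True),
                                        new_system.splitlines(keepends=True)))
    return new_system if diff else None

# Example: the memory block changed, so a rebuild is produced.
print(maybe_rebuild("system\n<memory v1>\n", "<memory v2>", "system\n<memory v2>\n") is not None)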
@enforce_types
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
@enforce_types
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
@ -1382,17 +1482,6 @@ class AgentManager:
return agent_state
@enforce_types
def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
block_ids = [b.id for b in agent_state.memory.blocks]
if not block_ids:
return agent_state
agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(
block_ids=[b.id for b in agent_state.memory.blocks], actor=actor
)
return agent_state
@enforce_types
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
block_ids = [b.id for b in agent_state.memory.blocks]
@ -1482,6 +1571,25 @@ class AgentManager:
# Use the lazy-loaded relationship to get sources
return [source.to_pydantic() for source in agent.sources]
@enforce_types
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
"""
Lists all sources attached to an agent.
Args:
agent_id: ID of the agent to list sources for
actor: User performing the action
Returns:
List[PydanticSource]: List of sources attached to the agent
"""
async with db_registry.async_session() as session:
# Verify agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Use the lazy-loaded relationship to get sources
return [source.to_pydantic() for source in agent.sources]
@enforce_types
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
@ -1527,6 +1635,33 @@ class AgentManager:
return block.to_pydantic()
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
@enforce_types
async def modify_block_by_label_async(
self,
agent_id: str,
block_label: str,
block_update: BlockUpdate,
actor: PydanticUser,
) -> PydanticBlock:
"""Gets a block attached to an agent by its label."""
async with db_registry.async_session() as session:
block = None
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
for candidate in agent.core_memory:
if candidate.label == block_label:
block = candidate
break
if not block:
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(block, key, value)
await block.update_async(session, actor=actor)
return block.to_pydantic()
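# A standalone sketch of the "find by label, apply only the provided fields,
# then persist" flow above, using plain dataclasses instead of the ORM and
# Pydantic models (all names here are illustrative).
from dataclasses import dataclass

@dataclass
class SimpleBlock:
    label: str
    value: str = ""

def modify_block_by_label(blocks, label, updates):
    target = next((b for b in blocks if b.label == label), None)
    if target is None:
        raise LookupError(f"No block with label '{label}'")
    for key, val in updates.items():  # only the explicitly provided fields
        setattr(target, key, val)
    return target

core_memory = [SimpleBlock("persona", "helpful"), SimpleBlock("human", "unknown user")]
print(modify_block_by_label(core_memory, "persona", {"value": "concise"}))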
@enforce_types
def update_block_with_label(
self,
@ -1848,6 +1983,65 @@ class AgentManager:
return [p.to_pydantic() for p in passages]
@enforce_types
async def list_passages_async(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
before: Optional[str] = None,
after: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
async with db_registry.async_session() as session:
main_query = self._build_passage_query(
actor=actor,
agent_id=agent_id,
file_id=file_id,
query_text=query_text,
start_date=start_date,
end_date=end_date,
before=before,
after=after,
source_id=source_id,
embed_query=embed_query,
ascending=ascending,
embedding_config=embedding_config,
agent_only=agent_only,
)
# Add limit
if limit:
main_query = main_query.limit(limit)
# Execute query
result = await session.execute(main_query)
passages = []
for row in result:
data = dict(row._mapping)
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)
return [p.to_pydantic() for p in passages]
@enforce_types
def passage_size(
self,
@ -2010,3 +2204,42 @@ class AgentManager:
query = query.order_by(AgentsTags.tag).limit(limit)
results = [tag[0] for tag in query.all()]
return results
@enforce_types
async def list_tags_async(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
) -> List[str]:
"""
Get all tags a user has created, ordered alphabetically.
Args:
actor: User performing the action.
after: Cursor for forward pagination.
limit: Maximum number of tags to return.
query_text: Query text to filter tags by.
Returns:
List[str]: List of all tags.
"""
async with db_registry.async_session() as session:
# Build the query using select() for async SQLAlchemy
query = (
select(AgentsTags.tag)
.join(AgentModel, AgentModel.id == AgentsTags.agent_id)
.where(AgentModel.organization_id == actor.organization_id)
.distinct()
)
if query_text:
query = query.where(AgentsTags.tag.ilike(f"%{query_text}%"))
if after:
query = query.where(AgentsTags.tag > after)
query = query.order_by(AgentsTags.tag).limit(limit)
# Execute the query asynchronously
result = await session.execute(query)
# Extract the tag values from the result
results = [row[0] for row in result.all()]
return results
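# A standalone sketch of the keyset ("after" cursor) pagination used for tags:
# each page keeps only tags strictly greater than the cursor and is ordered
# alphabetically, so pages can be walked without offsets. Plain lists stand in
# for the AgentsTags query.
from typing import List, Optional

def page_tags(all_tags: List[str], after: Optional[str] = None, limit: int = 3) -> List[str]:
    tags = sorted(set(all_tags))
    if after is not None:
        tags = [t for t in tags if t > after]
    return tags[:limit]

tags = ["prod", "beta", "alpha", "prod", "internal"]
first = page_tags(tags)                     # ['alpha', 'beta', 'internal']
second = page_tags(tags, after=first[-1])   # ['prod']
print(first, second)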

View File

@ -1,3 +1,4 @@
import asyncio
from typing import Dict, List, Optional
from sqlalchemy import select
@ -82,39 +83,64 @@ class BlockManager:
return block.to_pydantic()
@enforce_types
def get_blocks(
async def get_blocks_async(
self,
actor: PydanticUser,
label: Optional[str] = None,
is_template: Optional[bool] = None,
template_name: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
identity_id: Optional[str] = None,
id: Optional[str] = None,
after: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
limit: Optional[int] = 50,
) -> List[PydanticBlock]:
"""Retrieve blocks based on various optional filters."""
with db_registry.session() as session:
# Prepare filters
filters = {"organization_id": actor.organization_id}
if label:
filters["label"] = label
if is_template is not None:
filters["is_template"] = is_template
if template_name:
filters["template_name"] = template_name
if id:
filters["id"] = id
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
from sqlalchemy import select
from sqlalchemy.orm import noload
blocks = BlockModel.list(
db_session=session,
after=after,
limit=limit,
identifier_keys=identifier_keys,
identity_id=identity_id,
**filters,
)
from letta.orm.sqlalchemy_base import AccessType
async with db_registry.async_session() as session:
# Start with a basic query
query = select(BlockModel)
# Explicitly avoid loading relationships
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
# Apply access control
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Add filters
query = query.where(BlockModel.organization_id == actor.organization_id)
if label:
query = query.where(BlockModel.label == label)
if is_template is not None:
query = query.where(BlockModel.is_template == is_template)
if template_name:
query = query.where(BlockModel.template_name == template_name)
if identifier_keys:
query = (
query.join(BlockModel.identities)
.filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
.distinct(BlockModel.id)
)
if identity_id:
query = (
query.join(BlockModel.identities)
.filter(BlockModel.identities.property.mapper.class_.id == identity_id)
.distinct(BlockModel.id)
)
# Add limit
if limit:
query = query.limit(limit)
# Execute the query
result = await session.execute(query)
blocks = result.scalars().all()
return [block.to_pydantic() for block in blocks]
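# A standalone sketch of how the optional-filter SELECT above is composed:
# start from select(...) and add .where() clauses only for the filters that
# were actually provided. A Core table is used here so the statement can be
# printed; the table and column names are illustrative.
from typing import Optional
import sqlalchemy as sa

metadata = sa.MetaData()
block_table = sa.Table(
    "block", metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("organization_id", sa.String),
    sa.Column("label", sa.String),
    sa.Column("is_template", sa.Boolean),
)

def build_block_query(org_id: str, label: Optional[str] = None,
                      is_template: Optional[bool] = None, limit: int = 50):
    query = sa.select(block_table).where(block_table.c.organization_id == org_id)
    if label:
        query = query.where(block_table.c.label == label)
    if is_template is not None:
        query = query.where(block_table.c.is_template == is_template)
    return query.limit(limit)

print(build_block_query("org-123", label="persona"))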
@ -190,15 +216,6 @@ class BlockManager:
except NoResultFound:
return None
@enforce_types
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids."""
with db_registry.session() as session:
blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)]
# backwards compatibility. previous implementation added None for every block not found.
blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
return blocks
@enforce_types
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
@ -247,16 +264,14 @@ class BlockManager:
return pydantic_blocks
@enforce_types
def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
"""
Retrieve all agents associated with a given block.
"""
with db_registry.session() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
async with db_registry.async_session() as session:
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
agents_orm = block.agents
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
return agents_pydantic
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
@enforce_types
def size(

View File

@ -0,0 +1,10 @@
def singleton(cls):
"""Decorator to make a class a Singleton class."""
instances = {}
def get_instance(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
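# A quick usage example of the decorator above: every construction of the same
# class hands back the first instance that was created.
@singleton
class Config:
    def __init__(self, name: str):
        self.name = name

a = Config("first")
b = Config("second")   # constructor args are ignored after the first call
print(a is b, a.name)  # True first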

View File

@ -1,6 +1,7 @@
from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
@ -17,7 +18,7 @@ from letta.utils import enforce_types
class IdentityManager:
@enforce_types
def list_identities(
async def list_identities_async(
self,
name: Optional[str] = None,
project_id: Optional[str] = None,
@ -28,7 +29,7 @@ class IdentityManager:
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> list[PydanticIdentity]:
with db_registry.session() as session:
async with db_registry.async_session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
@ -36,7 +37,7 @@ class IdentityManager:
filters["identifier_key"] = identifier_key
if identity_type:
filters["identity_type"] = identity_type
identities = IdentityModel.list(
identities = await IdentityModel.list_async(
db_session=session,
query_text=name,
before=before,
@ -47,17 +48,17 @@ class IdentityManager:
return [identity.to_pydantic() for identity in identities]
@enforce_types
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()
@enforce_types
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
with db_registry.session() as session:
async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
self._process_relationship(
await self._process_relationship_async(
session=session,
identity=new_identity,
relationship_name="agents",
@ -65,7 +66,7 @@ class IdentityManager:
item_ids=identity.agent_ids,
allow_partial=False,
)
self._process_relationship(
await self._process_relationship_async(
session=session,
identity=new_identity,
relationship_name="blocks",
@ -73,13 +74,13 @@ class IdentityManager:
item_ids=identity.block_ids,
allow_partial=False,
)
new_identity.create(session, actor=actor)
await new_identity.create_async(session, actor=actor)
return new_identity.to_pydantic()
@enforce_types
def upsert_identity(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
with db_registry.session() as session:
existing_identity = IdentityModel.read(
async def upsert_identity_async(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
existing_identity = await IdentityModel.read_async(
db_session=session,
identifier_key=identity.identifier_key,
project_id=identity.project_id,
@ -88,7 +89,7 @@ class IdentityManager:
)
if existing_identity is None:
return self.create_identity(identity=IdentityCreate(**identity.model_dump()), actor=actor)
return await self.create_identity_async(identity=IdentityCreate(**identity.model_dump()), actor=actor)
else:
identity_update = IdentityUpdate(
name=identity.name,
@ -97,25 +98,27 @@ class IdentityManager:
agent_ids=identity.agent_ids,
properties=identity.properties,
)
return self._update_identity(
return await self._update_identity_async(
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
)
@enforce_types
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
with db_registry.session() as session:
async def update_identity_async(
self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
) -> PydanticIdentity:
async with db_registry.async_session() as session:
try:
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail="Identity not found")
if existing_identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
return self._update_identity(
return await self._update_identity_async(
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
)
def _update_identity(
async def _update_identity_async(
self,
session: Session,
existing_identity: IdentityModel,
@ -139,7 +142,7 @@ class IdentityManager:
existing_identity.properties = list(new_properties.values())
if identity.agent_ids is not None:
self._process_relationship(
await self._process_relationship_async(
session=session,
identity=existing_identity,
relationship_name="agents",
@ -149,7 +152,7 @@ class IdentityManager:
replace=replace,
)
if identity.block_ids is not None:
self._process_relationship(
await self._process_relationship_async(
session=session,
identity=existing_identity,
relationship_name="blocks",
@ -158,16 +161,18 @@ class IdentityManager:
allow_partial=False,
replace=replace,
)
existing_identity.update(session, actor=actor)
await existing_identity.update_async(session, actor=actor)
return existing_identity.to_pydantic()
@enforce_types
def upsert_identity_properties(self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser) -> PydanticIdentity:
with db_registry.session() as session:
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
async def upsert_identity_properties_async(
self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
) -> PydanticIdentity:
async with db_registry.async_session() as session:
existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
if existing_identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
return self._update_identity(
return await self._update_identity_async(
session=session,
existing_identity=existing_identity,
identity=IdentityUpdate(properties=properties),
@ -176,28 +181,28 @@ class IdentityManager:
)
@enforce_types
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id)
async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
if identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
if identity.organization_id != actor.organization_id:
raise HTTPException(status_code=403, detail="Forbidden")
session.delete(identity)
session.commit()
await session.delete(identity)
await session.commit()
@enforce_types
def size(
async def size_async(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of identities for the given user.
"""
with db_registry.session() as session:
return IdentityModel.size(db_session=session, actor=actor)
async with db_registry.async_session() as session:
return await IdentityModel.size_async(db_session=session, actor=actor)
def _process_relationship(
async def _process_relationship_async(
self,
session: Session,
identity: PydanticIdentity,
@ -214,7 +219,7 @@ class IdentityManager:
return
# Retrieve models for the provided IDs
found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
found_items = (await session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()
# Validate all items are found if allow_partial is False
if not allow_partial and len(found_items) != len(item_ids):

View File

@ -150,6 +150,35 @@ class JobManager:
)
return [job.to_pydantic() for job in jobs]
@enforce_types
async def list_jobs_async(
self,
actor: PydanticUser,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
statuses: Optional[List[JobStatus]] = None,
job_type: JobType = JobType.JOB,
ascending: bool = True,
) -> List[PydanticJob]:
"""List all jobs with optional pagination and status filter."""
async with db_registry.async_session() as session:
filter_kwargs = {"user_id": actor.id, "job_type": job_type}
# Add status filter if provided
if statuses:
filter_kwargs["status"] = statuses
jobs = await JobModel.list_async(
db_session=session,
before=before,
after=after,
limit=limit,
ascending=ascending,
**filter_kwargs,
)
return [job.to_pydantic() for job in jobs]
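# A standalone sketch of the optional status filter above: a list of statuses
# narrows the results, otherwise every job of the requested type comes back.
# Plain dicts and an Enum stand in for the ORM query.
from enum import Enum
from typing import List, Optional

class SimpleJobStatus(str, Enum):
    created = "created"
    running = "running"
    completed = "completed"

jobs = [{"id": "job-1", "status": SimpleJobStatus.running},
        {"id": "job-2", "status": SimpleJobStatus.completed}]

def list_jobs(statuses: Optional[List[SimpleJobStatus]] = None):
    return [j for j in jobs if statuses is None or j["status"] in statuses]

print(list_jobs([SimpleJobStatus.completed]))  # only job-2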
@enforce_types
def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Delete a job by its ID."""

View File

@ -31,6 +31,16 @@ class MessageManager:
except NoResultFound:
return None
@enforce_types
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
async with db_registry.async_session() as session:
try:
message = await MessageModel.read_async(db_session=session, identifier=message_id, actor=actor)
return message.to_pydantic()
except NoResultFound:
return None
@enforce_types
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
"""Fetch messages by ID and return them in the requested order."""
@ -426,6 +436,107 @@ class MessageManager:
results = query.all()
return [msg.to_pydantic() for msg in results]
@enforce_types
async def list_messages_for_agent_async(
self,
agent_id: str,
actor: PydanticUser,
after: Optional[str] = None,
before: Optional[str] = None,
query_text: Optional[str] = None,
roles: Optional[Sequence[MessageRole]] = None,
limit: Optional[int] = 50,
ascending: bool = True,
group_id: Optional[str] = None,
) -> List[PydanticMessage]:
"""
Most performant query to list messages for an agent by directly querying the Message table.
This function filters by the agent_id (leveraging the index on messages.agent_id)
and applies pagination using sequence_id as the cursor.
If query_text is provided, it will filter messages whose text content partially matches the query.
If roles are provided, it will filter messages by the specified roles.
Args:
agent_id: The ID of the agent whose messages are queried.
actor: The user performing the action (used for permission checks).
after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
query_text: Optional string to partially match the message text content.
roles: Optional sequence of MessageRoles to filter messages by.
limit: Maximum number of messages to return.
ascending: If True, sort by sequence_id ascending; if False, sort descending.
group_id: Optional group ID to filter messages by group_id.
Returns:
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
Raises:
NoResultFound: If the provided after/before message IDs do not exist.
"""
async with db_registry.async_session() as session:
# Permission check: raise if the agent doesn't exist or actor is not allowed.
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Build a query that directly filters the Message table by agent_id.
query = select(MessageModel).where(MessageModel.agent_id == agent_id)
# If group_id is provided, filter messages by group_id.
if group_id:
query = query.where(MessageModel.group_id == group_id)
# If query_text is provided, filter messages using subquery + json_array_elements.
if query_text:
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
query = query.where(
exists(
select(1)
.select_from(content_element)
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
.params(query_text=f"%{query_text}%")
)
)
# If role(s) are provided, filter messages by those roles.
if roles:
role_values = [r.value for r in roles]
query = query.where(MessageModel.role.in_(role_values))
# Apply 'after' pagination if specified.
if after:
after_query = select(MessageModel.sequence_id).where(MessageModel.id == after)
after_result = await session.execute(after_query)
after_ref = after_result.one_or_none()
if not after_ref:
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
# Filter out any messages with a sequence_id <= after_ref.sequence_id
query = query.where(MessageModel.sequence_id > after_ref.sequence_id)
# Apply 'before' pagination if specified.
if before:
before_query = select(MessageModel.sequence_id).where(MessageModel.id == before)
before_result = await session.execute(before_query)
before_ref = before_result.one_or_none()
if not before_ref:
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
# Filter out any messages with a sequence_id >= before_ref.sequence_id
query = query.where(MessageModel.sequence_id < before_ref.sequence_id)
# Apply ordering based on the ascending flag.
if ascending:
query = query.order_by(MessageModel.sequence_id.asc())
else:
query = query.order_by(MessageModel.sequence_id.desc())
# Limit the number of results.
query = query.limit(limit)
# Execute and convert each Message to its Pydantic representation.
result = await session.execute(query)
results = result.scalars().all()
return [msg.to_pydantic() for msg in results]
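# A standalone sketch of the sequence_id cursor logic documented above:
# "after" keeps rows with a strictly greater sequence_id, "before" keeps
# strictly smaller ones, and the sort follows the ascending flag. Plain dicts
# stand in for the messages table.
from typing import Optional

messages = [{"id": f"msg-{i}", "sequence_id": i} for i in range(1, 6)]

def page(after_id: Optional[str] = None, before_id: Optional[str] = None,
         limit: int = 2, ascending: bool = True):
    seq = {m["id"]: m["sequence_id"] for m in messages}
    rows = messages
    if after_id:
        rows = [m for m in rows if m["sequence_id"] > seq[after_id]]
    if before_id:
        rows = [m for m in rows if m["sequence_id"] < seq[before_id]]
    rows = sorted(rows, key=lambda m: m["sequence_id"], reverse=not ascending)
    return rows[:limit]

print(page(after_id="msg-2"))                     # msg-3, msg-4
print(page(before_id="msg-4", ascending=False))   # msg-3, msg-2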
@enforce_types
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
"""

View File

@ -122,6 +122,23 @@ class SandboxConfigManager:
sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs)
return [sandbox.to_pydantic() for sandbox in sandboxes]
@enforce_types
async def list_sandbox_configs_async(
self,
actor: PydanticUser,
after: Optional[str] = None,
limit: Optional[int] = 50,
sandbox_type: Optional[SandboxType] = None,
) -> List[PydanticSandboxConfig]:
"""List all sandbox configurations with optional pagination."""
kwargs = {"organization_id": actor.organization_id}
if sandbox_type:
kwargs.update({"type": sandbox_type})
async with db_registry.async_session() as session:
sandboxes = await SandboxConfigModel.list_async(db_session=session, after=after, limit=limit, **kwargs)
return [sandbox.to_pydantic() for sandbox in sandboxes]
@enforce_types
def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox configuration by its ID."""
@ -224,6 +241,25 @@ class SandboxConfigManager:
)
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
async def list_sandbox_env_vars_async(
self,
sandbox_config_id: str,
actor: PydanticUser,
after: Optional[str] = None,
limit: Optional[int] = 50,
) -> List[PydanticEnvVar]:
"""List all sandbox environment variables with optional pagination."""
async with db_registry.async_session() as session:
env_vars = await SandboxEnvVarModel.list_async(
db_session=session,
after=after,
limit=limit,
organization_id=actor.organization_id,
sandbox_config_id=sandbox_config_id,
)
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
def list_sandbox_env_vars_by_key(
self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50

View File

@ -2,6 +2,7 @@ from datetime import datetime
from typing import List, Literal, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from letta.orm.errors import NoResultFound
@ -12,6 +13,7 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.tracing import get_trace_id
from letta.utils import enforce_types
@ -57,12 +59,14 @@ class StepManager:
actor: PydanticUser,
agent_id: str,
provider_name: str,
provider_category: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
step_data = {
"origin": None,
@ -70,6 +74,7 @@ class StepManager:
"agent_id": agent_id,
"provider_id": provider_id,
"provider_name": provider_name,
"provider_category": provider_category,
"model": model,
"model_endpoint": model_endpoint,
"context_window_limit": context_window_limit,
@ -81,6 +86,8 @@ class StepManager:
"tid": None,
"trace_id": get_trace_id(), # Get the current trace ID
}
if step_id:
step_data["id"] = step_id
with db_registry.session() as session:
if job_id:
self._verify_job_access(session, job_id, actor, access=["write"])
@ -88,6 +95,48 @@ class StepManager:
new_step.create(session)
return new_step.to_pydantic()
@enforce_types
async def log_step_async(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
provider_category: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
step_data = {
"origin": None,
"organization_id": actor.organization_id,
"agent_id": agent_id,
"provider_id": provider_id,
"provider_name": provider_name,
"provider_category": provider_category,
"model": model,
"model_endpoint": model_endpoint,
"context_window_limit": context_window_limit,
"completion_tokens": usage.completion_tokens,
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
"job_id": job_id,
"tags": [],
"tid": None,
"trace_id": get_trace_id(), # Get the current trace ID
}
if step_id:
step_data["id"] = step_id
async with db_registry.async_session() as session:
if job_id:
await self._verify_job_access_async(session, job_id, actor, access=["write"])
new_step = StepModel(**step_data)
await new_step.create_async(session)
return new_step.to_pydantic()
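# A standalone sketch of how a step record is assembled from usage statistics
# before being persisted; the field names mirror the dict built above, but the
# UsageStats class and the sample values are stand-ins, not Letta schemas.
from dataclasses import dataclass
from typing import Optional

@dataclass
class UsageStats:
    prompt_tokens: int
    completion_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.prompt_tokens + self.completion_tokens

def build_step_data(agent_id: str, provider_name: str, provider_category: str,
                    model: str, usage: UsageStats, step_id: Optional[str] = None) -> dict:
    data = {
        "agent_id": agent_id,
        "provider_name": provider_name,
        "provider_category": provider_category,  # column added by this change
        "model": model,
        "prompt_tokens": usage.prompt_tokens,
        "completion_tokens": usage.completion_tokens,
        "total_tokens": usage.total_tokens,
    }
    if step_id:
        data["id"] = step_id  # allow the caller to pin the step id
    return data

print(build_step_data("agent-1", "openai", "base", "gpt-4o-mini", UsageStats(120, 40)))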
@enforce_types
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
with db_registry.session() as session:
@ -147,3 +196,100 @@ class StepManager:
if not job:
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
return job
async def _verify_job_access_async(
self,
session: AsyncSession,
job_id: str,
actor: PydanticUser,
access: List[Literal["read", "write", "delete"]] = ["read"],
) -> JobModel:
"""
Verify that a job exists and the user has the required access asynchronously.
Args:
session: The async database session
job_id: The ID of the job to verify
actor: The user making the request
Returns:
The job if it exists and the user has access
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
job_query = select(JobModel).where(JobModel.id == job_id)
job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
result = await session.execute(job_query)
job = result.scalar_one_or_none()
if not job:
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
return job
@singleton
class NoopStepManager(StepManager):
"""
Noop implementation of StepManager.
Temporarily used for migrations, but allows for different implementations in the future.
Will not allow for writes, but will still allow for reads.
"""
@enforce_types
def log_step(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
provider_category: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
return
@enforce_types
async def log_step_async(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
provider_category: str,
model: str,
model_endpoint: Optional[str],
context_window_limit: int,
usage: UsageStatistics,
provider_id: Optional[str] = None,
job_id: Optional[str] = None,
step_id: Optional[str] = None,
) -> PydanticStep:
step_data = {
"origin": None,
"organization_id": actor.organization_id,
"agent_id": agent_id,
"provider_id": provider_id,
"provider_name": provider_name,
"provider_category": provider_category,
"model": model,
"model_endpoint": model_endpoint,
"context_window_limit": context_window_limit,
"completion_tokens": usage.completion_tokens,
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens,
"job_id": job_id,
"tags": [],
"tid": None,
"trace_id": get_trace_id(), # Get the current trace ID
}
if step_id:
step_data["id"] = step_id
async with db_registry.async_session() as session:
if job_id:
await self._verify_job_access_async(session, job_id, actor, access=["write"])
new_step = StepModel(**step_data)
await new_step.create_async(session)
return new_step.to_pydantic()

View File

@ -0,0 +1,58 @@
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.utils import enforce_types
class TelemetryManager:
@enforce_types
async def get_provider_trace_by_step_id_async(
self,
step_id: str,
actor: PydanticUser,
) -> PydanticProviderTrace:
async with db_registry.async_session() as session:
provider_trace = await ProviderTraceModel.read_async(db_session=session, step_id=step_id, actor=actor)
return provider_trace.to_pydantic()
@enforce_types
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
async with db_registry.async_session() as session:
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
if provider_trace_create.request_json:
request_json_str = json_dumps(provider_trace_create.request_json)
provider_trace.request_json = json_loads(request_json_str)
if provider_trace_create.response_json:
response_json_str = json_dumps(provider_trace_create.response_json)
provider_trace.response_json = json_loads(response_json_str)
await provider_trace.create_async(session, actor=actor)
return provider_trace.to_pydantic()
@enforce_types
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
with db_registry.session() as session:
provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
provider_trace.create(session, actor=actor)
return provider_trace.to_pydantic()
@singleton
class NoopTelemetryManager(TelemetryManager):
"""
Noop implementation of TelemetryManager.
"""
async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
return
async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticProviderTrace:
return
def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
return
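# A standalone sketch of the request/response normalization above: the payload
# is round-tripped through a JSON dump/load so only JSON-serializable content
# ends up stored with the trace. The standard-library json module stands in
# for Letta's json_helpers here.
import json

def normalize_trace_payload(payload: dict) -> dict:
    return json.loads(json.dumps(payload, default=str))

trace = {
    "step_id": "step-123",
    "request_json": normalize_trace_payload({"model": "gpt-4o-mini", "messages": []}),
    "response_json": normalize_trace_payload({"id": "resp-1", "usage": {"total_tokens": 42}}),
}
print(trace["request_json"])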

View File

@ -8,9 +8,14 @@ from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.tool_executor import (
ExternalComposioToolExecutor,
ExternalMCPToolExecutor,
LettaBuiltinToolExecutor,
LettaCoreToolExecutor,
LettaMultiAgentToolExecutor,
SandboxToolExecutor,
@ -28,15 +33,30 @@ class ToolExecutorFactory:
ToolType.LETTA_MEMORY_CORE: LettaCoreToolExecutor,
ToolType.LETTA_SLEEPTIME_CORE: LettaCoreToolExecutor,
ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
ToolType.LETTA_BUILTIN: LettaBuiltinToolExecutor,
ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
}
@classmethod
def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
def get_executor(
cls,
tool_type: ToolType,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
) -> ToolExecutor:
"""Get the appropriate executor for the given tool type."""
executor_class = cls._executor_map.get(tool_type, SandboxToolExecutor)
return executor_class()
return executor_class(
message_manager=message_manager,
agent_manager=agent_manager,
block_manager=block_manager,
passage_manager=passage_manager,
actor=actor,
)
class ToolExecutionManager:
@ -44,11 +64,19 @@ class ToolExecutionManager:
def __init__(
self,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
agent_state: AgentState,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
self.message_manager = message_manager
self.agent_manager = agent_manager
self.block_manager = block_manager
self.passage_manager = passage_manager
self.agent_state = agent_state
self.logger = get_logger(__name__)
self.actor = actor
@ -68,7 +96,14 @@ class ToolExecutionManager:
Tuple containing the function response and sandbox run result (if applicable)
"""
try:
executor = ToolExecutorFactory.get_executor(tool.tool_type)
executor = ToolExecutorFactory.get_executor(
tool.tool_type,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
)
return executor.execute(
function_name,
function_args,
@ -98,9 +133,18 @@ class ToolExecutionManager:
Execute a tool asynchronously and persist any state changes.
"""
try:
executor = ToolExecutorFactory.get_executor(tool.tool_type)
executor = ToolExecutorFactory.get_executor(
tool.tool_type,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
)
# TODO: Extend this async model to composio
if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)):
if isinstance(
executor, (SandboxToolExecutor, ExternalComposioToolExecutor, LettaBuiltinToolExecutor, LettaMultiAgentToolExecutor)
):
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
else:
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
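# A standalone sketch of the factory dispatch used above: a map from tool type
# to executor class with a sandbox executor as the fallback, and the shared
# managers forwarded into whichever class is selected. All names are
# illustrative, not the Letta classes themselves.
from enum import Enum

class ToolKind(str, Enum):
    core = "core"
    builtin = "builtin"
    custom = "custom"

class BaseExecutor:
    def __init__(self, **managers):
        self.managers = managers

class CoreExecutor(BaseExecutor): ...
class BuiltinExecutor(BaseExecutor): ...
class FallbackSandboxExecutor(BaseExecutor): ...  # default for unknown tool types

EXECUTOR_MAP = {ToolKind.core: CoreExecutor, ToolKind.builtin: BuiltinExecutor}

def get_executor(kind: ToolKind, **managers) -> BaseExecutor:
    return EXECUTOR_MAP.get(kind, FallbackSandboxExecutor)(**managers)

print(type(get_executor(ToolKind.custom, message_manager=object())).__name__)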

View File

@ -73,6 +73,7 @@ class ToolExecutionSandbox:
self.force_recreate = force_recreate
self.force_recreate_venv = force_recreate_venv
@trace_method
def run(
self,
agent_state: Optional[AgentState] = None,
@ -321,6 +322,7 @@ class ToolExecutionSandbox:
# e2b sandbox specific functions
@trace_method
def run_e2b_sandbox(
self,
agent_state: Optional[AgentState] = None,
@ -352,10 +354,22 @@ class ToolExecutionSandbox:
if additional_env_vars:
env_vars.update(additional_env_vars)
code = self.generate_execution_script(agent_state=agent_state)
log_event(
"e2b_execution_started",
{"tool": self.tool_name, "sandbox_id": sbx.sandbox_id, "code": code, "env_vars": env_vars},
)
execution = sbx.run_code(code, envs=env_vars)
if execution.results:
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
log_event(
"e2b_execution_succeeded",
{
"tool": self.tool_name,
"sandbox_id": sbx.sandbox_id,
"func_return": func_return,
},
)
elif execution.error:
logger.error(f"Executing tool {self.tool_name} raised a {execution.error.name} with message: \n{execution.error.value}")
logger.error(f"Traceback from e2b sandbox: \n{execution.error.traceback}")
@ -363,7 +377,25 @@ class ToolExecutionSandbox:
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
)
execution.logs.stderr.append(execution.error.traceback)
log_event(
"e2b_execution_failed",
{
"tool": self.tool_name,
"sandbox_id": sbx.sandbox_id,
"error_type": execution.error.name,
"error_message": execution.error.value,
"func_return": func_return,
},
)
else:
log_event(
"e2b_execution_empty",
{
"tool": self.tool_name,
"sandbox_id": sbx.sandbox_id,
"status": "no_results_no_error",
},
)
raise ValueError(f"Tool {self.tool_name} returned execution with None")
return ToolExecutionResult(
@ -395,16 +427,31 @@ class ToolExecutionSandbox:
return None
@trace_method
def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox":
from e2b_code_interpreter import Sandbox
state_hash = sandbox_config.fingerprint()
e2b_config = sandbox_config.get_e2b_config()
log_event(
"e2b_sandbox_create_started",
{
"sandbox_fingerprint": state_hash,
"e2b_config": e2b_config.model_dump(),
},
)
if e2b_config.template:
sbx = Sandbox(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash})
else:
# no template
sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"}))
log_event(
"e2b_sandbox_create_finished",
{
"sandbox_id": sbx.sandbox_id,
"sandbox_fingerprint": state_hash,
},
)
# install pip requirements
if e2b_config.pip_requirements:

View File

@ -1,35 +1,64 @@
import asyncio
import json
import math
import traceback
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from textwrap import shorten
from typing import Any, Dict, List, Literal, Optional
from letta.constants import (
COMPOSIO_ENTITY_ENV_VAR_KEY,
CORE_MEMORY_LINE_NUMBER_WARNING,
READ_ONLY_BLOCK_EDIT_ERROR,
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
WEB_SEARCH_CLIP_CONTENT,
WEB_SEARCH_INCLUDE_SCORE,
WEB_SEARCH_SEPARATOR,
)
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import MessageCreate
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
from letta.settings import tool_settings
from letta.tracing import trace_method
from letta.utils import get_friendly_error_msg
logger = get_logger(__name__)
class ToolExecutor(ABC):
"""Abstract base class for tool executors."""
def __init__(
self,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
):
self.message_manager = message_manager
self.agent_manager = agent_manager
self.block_manager = block_manager
self.passage_manager = passage_manager
self.actor = actor
@abstractmethod
def execute(
self,
@ -493,17 +522,113 @@ class LettaCoreToolExecutor(ToolExecutor):
class LettaMultiAgentToolExecutor(ToolExecutor):
"""Executor for LETTA multi-agent core tools."""
# TODO: Implement
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> ToolExecutionResult:
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
# function_response = callable_func(**function_args)
# return ToolExecutionResult(func_return=function_response)
async def execute(
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
function_map = {
"send_message_to_agent_and_wait_for_reply": self.send_message_to_agent_and_wait_for_reply,
"send_message_to_agent_async": self.send_message_to_agent_async,
"send_message_to_agents_matching_tags": self.send_message_to_agents_matching_tags,
}
if function_name not in function_map:
raise ValueError(f"Unknown function: {function_name}")
# Execute the appropriate function
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
function_response = await function_map[function_name](agent_state, **function_args_copy)
return ToolExecutionResult(
status="success",
func_return=function_response,
)
async def send_message_to_agent_and_wait_for_reply(self, agent_state: AgentState, message: str, other_agent_id: str) -> str:
augmented_message = (
f"[Incoming message from agent with ID '{agent_state.id}' - to reply to this message, "
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
f"{message}"
)
return str(await self._process_agent(agent_id=other_agent_id, message=augmented_message))
async def send_message_to_agent_async(self, agent_state: AgentState, message: str, other_agent_id: str) -> str:
# 1) Build the prefixed system message
prefixed = (
f"[Incoming message from agent with ID '{agent_state.id}' - "
f"to reply to this message, make sure to use the "
f"'send_message_to_agent_async' tool, or the agent will not receive your message] "
f"{message}"
)
task = asyncio.create_task(self._process_agent(agent_id=other_agent_id, message=prefixed))
task.add_done_callback(lambda t: (logger.error(f"Async send_message task failed: {t.exception()}") if t.exception() else None))
return "Successfully sent message"
async def send_message_to_agents_matching_tags(
self, agent_state: AgentState, message: str, match_all: List[str], match_some: List[str]
) -> str:
# Find matching agents
matching_agents = self.agent_manager.list_agents_matching_tags(actor=self.actor, match_all=match_all, match_some=match_some)
if not matching_agents:
return str([])
augmented_message = (
"[Incoming message from external Letta agent - to reply to this message, "
"make sure to use the 'send_message' at the end, and the system will notify "
"the sender of your response] "
f"{message}"
)
tasks = [
asyncio.create_task(self._process_agent(agent_id=agent_state.id, message=augmented_message)) for agent_state in matching_agents
]
results = await asyncio.gather(*tasks)
return str(results)
async def _process_agent(self, agent_id: str, message: str) -> Dict[str, Any]:
from letta.agents.letta_agent import LettaAgent
try:
letta_agent = LettaAgent(
agent_id=agent_id,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
)
letta_response = await letta_agent.step([MessageCreate(role=MessageRole.system, content=[TextContent(text=message)])])
messages = letta_response.messages
send_message_content = [message.content for message in messages if isinstance(message, AssistantMessage)]
return {
"agent_id": agent_id,
"response": send_message_content if send_message_content else ["<no response>"],
}
except Exception as e:
return {
"agent_id": agent_id,
"error": str(e),
"type": type(e).__name__,
}
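# A standalone sketch of the fan-out in send_message_to_agents_matching_tags:
# one task per matching agent, gathered concurrently, with per-agent errors
# captured instead of failing the whole batch. Only asyncio is used; the agent
# ids and the fake step are illustrative.
import asyncio

async def process_agent(agent_id: str, message: str) -> dict:
    try:
        await asyncio.sleep(0)  # stand-in for a real agent step
        return {"agent_id": agent_id, "response": [f"ack: {message}"]}
    except Exception as e:  # mirror the broad per-agent catch above
        return {"agent_id": agent_id, "error": str(e), "type": type(e).__name__}

async def broadcast(agent_ids, message: str):
    tasks = [asyncio.create_task(process_agent(a, message)) for a in agent_ids]
    return await asyncio.gather(*tasks)

print(asyncio.run(broadcast(["agent-1", "agent-2"], "status check?")))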
class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools."""
@trace_method
async def execute(
self,
function_name: str,
@ -595,6 +720,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
class SandboxToolExecutor(ToolExecutor):
"""Executor for sandboxed tools."""
@trace_method
async def execute(
self,
function_name: str,
@ -674,3 +800,106 @@ class SandboxToolExecutor(ToolExecutor):
func_return=error_message,
stderr=[stderr],
)
class LettaBuiltinToolExecutor(ToolExecutor):
"""Executor for built in Letta tools."""
@trace_method
async def execute(
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
function_map = {"run_code": self.run_code, "web_search": self.web_search}
if function_name not in function_map:
raise ValueError(f"Unknown function: {function_name}")
# Execute the appropriate function
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
function_response = await function_map[function_name](**function_args_copy)
return ToolExecutionResult(
status="success",
func_return=function_response,
)
async def run_code(self, code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
from e2b_code_interpreter import AsyncSandbox
if tool_settings.e2b_api_key is None:
raise ValueError("E2B_API_KEY is not set")
sbx = await AsyncSandbox.create(api_key=tool_settings.e2b_api_key)
params = {"code": code}
if language != "python":
# Leave empty for python
params["language"] = language
res = self._llm_friendly_result(await sbx.run_code(**params))
return json.dumps(res, ensure_ascii=False)
def _llm_friendly_result(self, res):
out = {
"results": [r.text if hasattr(r, "text") else str(r) for r in res.results],
"logs": {
"stdout": getattr(res.logs, "stdout", []),
"stderr": getattr(res.logs, "stderr", []),
},
}
err = getattr(res, "error", None)
if err is not None:
out["error"] = err
return out
async def web_search(agent_state: "AgentState", query: str) -> str:
"""
Search the web for information.
Args:
query (str): The query to search the web for.
Returns:
str: The search results.
"""
try:
from tavily import AsyncTavilyClient
except ImportError:
raise ImportError("tavily is not installed in the tool execution environment")
# Check if the API key exists
if tool_settings.tavily_api_key is None:
raise ValueError("TAVILY_API_KEY is not set")
# Instantiate client and search
tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key)
search_results = await tavily_client.search(query=query, auto_parameters=True)
results = search_results.get("results", [])
if not results:
return "No search results found."
# ---- format for the LLM -------------------------------------------------
formatted_blocks = []
for idx, item in enumerate(results, start=1):
title = item.get("title") or "Untitled"
url = item.get("url") or "Unknown URL"
# keep each content snippet reasonably short so you don't blow up the context
content = (
shorten(item.get("content", "").strip(), width=600, placeholder="")
if WEB_SEARCH_CLIP_CONTENT
else item.get("content", "").strip()
)
score = item.get("score")
if WEB_SEARCH_INCLUDE_SCORE:
block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Relevance score: {score:.4f}\n" f"Content: {content}\n"
else:
block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Content: {content}\n"
formatted_blocks.append(block)
return WEB_SEARCH_SEPARATOR.join(formatted_blocks)
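# A standalone sketch of the snippet clipping above: textwrap.shorten keeps
# each result under a fixed width so a large page of search results does not
# blow up the model's context. The sample result is made up; no Tavily call
# is made.
from textwrap import shorten

sample_results = [
    {"title": "Example", "url": "https://example.com", "content": "word " * 500, "score": 0.91},
]
blocks = []
for idx, item in enumerate(sample_results, start=1):
    content = shorten(item["content"].strip(), width=600, placeholder=" ...")
    blocks.append(f"\nRESULT {idx}:\nTitle: {item['title']}\nURL: {item['url']}\nContent: {content}\n")
print("\n---\n".join(blocks))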

View File

@ -1,3 +1,4 @@
import asyncio
import importlib
import warnings
from typing import List, Optional
@ -9,6 +10,7 @@ from letta.constants import (
BASE_TOOLS,
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
BASE_VOICE_SLEEPTIME_TOOLS,
BUILTIN_TOOLS,
LETTA_TOOL_SET,
MCP_TOOL_TAG_NAME_PREFIX,
MULTI_AGENT_TOOLS,
@ -59,6 +61,32 @@ class ToolManager:
return tool
@enforce_types
async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
if tool_id:
# Put to dict and remove fields that should not be reset
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
# If there's anything to update
if update_data:
# In case we want to update the tool type
# Useful if we are shuffling around base tools
updated_tool_type = None
if "tool_type" in update_data:
updated_tool_type = update_data.get("tool_type")
tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
else:
printd(
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
)
tool = await self.get_tool_by_id_async(tool_id, actor=actor)
else:
tool = await self.create_tool_async(pydantic_tool, actor=actor)
return tool
@enforce_types
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
@ -96,6 +124,21 @@ class ToolManager:
tool.create(session, actor=actor) # Re-raise other database-related errors
return tool.to_pydantic()
@enforce_types
async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
async with db_registry.async_session() as session:
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
# Auto-generate description if not provided
if pydantic_tool.description is None:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
tool_data = pydantic_tool.model_dump(to_orm=True)
tool = ToolModel(**tool_data)
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
return tool.to_pydantic()
@enforce_types
def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
@ -105,6 +148,15 @@ class ToolManager:
# Convert the SQLAlchemy Tool object to PydanticTool
return tool.to_pydantic()
@enforce_types
async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
async with db_registry.async_session() as session:
# Retrieve tool by id using the Tool model's read method
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
# Convert the SQLAlchemy Tool object to PydanticTool
return tool.to_pydantic()
@enforce_types
def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
@ -135,6 +187,16 @@ class ToolManager:
except NoResultFound:
return None
@enforce_types
async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
async with db_registry.async_session() as session:
tool = await ToolModel.read_async(db_session=session, name=tool_name, actor=actor)
return tool.id
except NoResultFound:
return None
@enforce_types
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination."""
@ -204,6 +266,35 @@ class ToolManager:
# Save the updated tool to the database
return tool.update(db_session=session, actor=actor).to_pydantic()
@enforce_types
async def update_tool_by_id_async(
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""
async with db_registry.async_session() as session:
# Fetch the tool by ID
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
# Update tool attributes with only the fields that were explicitly set
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
for key, value in update_data.items():
setattr(tool, key, value)
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
pydantic_tool = tool.to_pydantic()
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)
tool.json_schema = new_schema
tool.name = new_schema["name"]
if updated_tool_type:
tool.tool_type = updated_tool_type
# Save the updated tool to the database
tool = await tool.update_async(db_session=session, actor=actor)
return tool.to_pydantic()
@enforce_types
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
"""Delete a tool by its ID."""
@ -218,7 +309,7 @@ class ToolManager:
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py and multi_agent.py"""
functions_to_schema = {}
module_names = ["base", "multi_agent", "voice"]
module_names = ["base", "multi_agent", "voice", "builtin"]
for module_name in module_names:
full_module_name = f"letta.functions.function_sets.{module_name}"
@ -254,6 +345,9 @@ class ToolManager:
elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE
tags = [tool_type.value]
elif name in BUILTIN_TOOLS:
tool_type = ToolType.LETTA_BUILTIN
tags = [tool_type.value]
else:
raise ValueError(
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}"
@ -275,3 +369,68 @@ class ToolManager:
# TODO: Delete any base tools that are stale
return tools
@enforce_types
async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py and multi_agent.py"""
functions_to_schema = {}
module_names = ["base", "multi_agent", "voice", "builtin"]
for module_name in module_names:
full_module_name = f"letta.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
raise e
try:
# Load the function set
functions_to_schema.update(load_function_set(module))
except ValueError as e:
err = f"Error loading function set '{module_name}': {e}"
warnings.warn(err)
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
if name in LETTA_TOOL_SET:
if name in BASE_TOOLS:
tool_type = ToolType.LETTA_CORE
tags = [tool_type.value]
elif name in BASE_MEMORY_TOOLS:
tool_type = ToolType.LETTA_MEMORY_CORE
tags = [tool_type.value]
elif name in MULTI_AGENT_TOOLS:
tool_type = ToolType.LETTA_MULTI_AGENT_CORE
tags = [tool_type.value]
elif name in BASE_SLEEPTIME_TOOLS:
tool_type = ToolType.LETTA_SLEEPTIME_CORE
tags = [tool_type.value]
elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE
tags = [tool_type.value]
elif name in BUILTIN_TOOLS:
tool_type = ToolType.LETTA_BUILTIN
tags = [tool_type.value]
else:
raise ValueError(
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}"
)
# create tool
tools.append(
self.create_or_update_tool_async(
PydanticTool(
name=name,
tags=tags,
source_type="python",
tool_type=tool_type,
return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT,
),
actor=actor,
)
)
# TODO: Delete any base tools that are stale
return await asyncio.gather(*tools)
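Not part of the commit, just a usage sketch: the async variant is awaited from an existing event loop and fans the per-tool upserts out through asyncio.gather, so the base, multi-agent, voice, and builtin tools are created or updated concurrently. The helper below is hypothetical caller code and assumes a ToolManager instance and a PydanticUser actor are already available.
from letta.services.tool_manager import ToolManager
async def refresh_base_tools(actor):
    # hypothetical caller code: upsert every base/builtin tool concurrently
    manager = ToolManager()
    tools = await manager.upsert_base_tools_async(actor=actor)
    return sorted(tool.name for tool in tools)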

View File

@ -6,6 +6,7 @@ from letta.schemas.sandbox_config import SandboxConfig, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.tracing import log_event, trace_method
from letta.utils import get_friendly_error_msg
logger = get_logger(__name__)
@ -27,6 +28,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
self.force_recreate = force_recreate
@trace_method
async def run(
self,
agent_state: Optional[AgentState] = None,
@ -44,6 +46,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
return result
@trace_method
async def run_e2b_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> ToolExecutionResult:
@ -81,10 +84,21 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
env_vars.update(additional_env_vars)
code = self.generate_execution_script(agent_state=agent_state)
log_event(
"e2b_execution_started",
{"tool": self.tool_name, "sandbox_id": e2b_sandbox.sandbox_id, "code": code, "env_vars": env_vars},
)
execution = await e2b_sandbox.run_code(code, envs=env_vars)
if execution.results:
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
log_event(
"e2b_execution_succeeded",
{
"tool": self.tool_name,
"sandbox_id": e2b_sandbox.sandbox_id,
"func_return": func_return,
},
)
elif execution.error:
logger.error(f"Executing tool {self.tool_name} raised a {execution.error.name} with message: \n{execution.error.value}")
logger.error(f"Traceback from e2b sandbox: \n{execution.error.traceback}")
@ -92,7 +106,25 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
)
execution.logs.stderr.append(execution.error.traceback)
log_event(
"e2b_execution_failed",
{
"tool": self.tool_name,
"sandbox_id": e2b_sandbox.sandbox_id,
"error_type": execution.error.name,
"error_message": execution.error.value,
"func_return": func_return,
},
)
else:
log_event(
"e2b_execution_empty",
{
"tool": self.tool_name,
"sandbox_id": e2b_sandbox.sandbox_id,
"status": "no_results_no_error",
},
)
raise ValueError(f"Tool {self.tool_name} returned execution with None")
return ToolExecutionResult(
@ -110,24 +142,54 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
exception_class = builtins_dict.get(e2b_execution.error.name, Exception)
return exception_class(e2b_execution.error.value)
@trace_method
async def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox":
from e2b_code_interpreter import AsyncSandbox
state_hash = sandbox_config.fingerprint()
e2b_config = sandbox_config.get_e2b_config()
log_event(
"e2b_sandbox_create_started",
{
"sandbox_fingerprint": state_hash,
"e2b_config": e2b_config.model_dump(),
},
)
if e2b_config.template:
sbx = await AsyncSandbox.create(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash})
else:
# no template
sbx = await AsyncSandbox.create(
metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"})
)
# install pip requirements
log_event(
"e2b_sandbox_create_finished",
{
"sandbox_id": sbx.sandbox_id,
"sandbox_fingerprint": state_hash,
},
)
if e2b_config.pip_requirements:
for package in e2b_config.pip_requirements:
log_event(
"e2b_pip_install_started",
{
"sandbox_id": sbx.sandbox_id,
"package": package,
},
)
await sbx.commands.run(f"pip install {package}")
log_event(
"e2b_pip_install_finished",
{
"sandbox_id": sbx.sandbox_id,
"package": package,
},
)
return sbx
async def list_running_e2b_sandboxes(self):

View File

@ -15,6 +15,9 @@ class ToolSettings(BaseSettings):
e2b_api_key: Optional[str] = None
e2b_sandbox_template_id: Optional[str] = None # Updated manually
# Tavily search
tavily_api_key: Optional[str] = None
# Local Sandbox configurations
tool_exec_dir: Optional[str] = None
tool_sandbox_timeout: float = 180
@ -95,6 +98,7 @@ class ModelSettings(BaseSettings):
# anthropic
anthropic_api_key: Optional[str] = None
anthropic_max_retries: int = 3
# ollama
ollama_base_url: Optional[str] = None
@ -175,11 +179,14 @@ class Settings(BaseSettings):
pg_host: Optional[str] = None
pg_port: Optional[int] = None
pg_uri: Optional[str] = default_pg_uri # option to specify full uri
pg_pool_size: int = 80 # Concurrent connections
pg_max_overflow: int = 30 # Overflow limit
pg_pool_size: int = 25 # Concurrent connections
pg_max_overflow: int = 10 # Overflow limit
pg_pool_timeout: int = 30 # Seconds to wait for a connection
pg_pool_recycle: int = 1800 # When to recycle connections
pg_echo: bool = False # Logging
pool_pre_ping: bool = True # Pre ping to check for dead connections
pool_use_lifo: bool = True
disable_sqlalchemy_pooling: bool = False
# multi agent settings
multi_agent_send_message_max_retries: int = 3
@ -190,6 +197,7 @@ class Settings(BaseSettings):
verbose_telemetry_logging: bool = False
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317"
disable_tracing: bool = False
llm_api_logging: bool = True
# uvicorn settings
uvicorn_workers: int = 1
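Not part of the commit: a minimal sketch of overriding the revised pool defaults and the new llm_api_logging flag at runtime. It assumes the Settings class reads "letta_"-prefixed environment variables (the usual pydantic BaseSettings pattern); the prefix is an assumption, not something this diff confirms.
import os
# hypothetical overrides; set before letta.settings is imported
os.environ["LETTA_PG_POOL_SIZE"] = "25"        # concurrent connections
os.environ["LETTA_PG_MAX_OVERFLOW"] = "10"     # overflow limit
os.environ["LETTA_LLM_API_LOGGING"] = "false"  # skip provider request/response capture
from letta.settings import settings
print(settings.pg_pool_size, settings.pg_max_overflow, settings.llm_api_logging)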

View File

@ -19,11 +19,11 @@ from opentelemetry.trace import Status, StatusCode
tracer = trace.get_tracer(__name__)
_is_tracing_initialized = False
_excluded_v1_endpoints_regex: List[str] = [
"^GET /v1/agents/(?P<agent_id>[^/]+)/messages$",
"^GET /v1/agents/(?P<agent_id>[^/]+)/context$",
"^GET /v1/agents/(?P<agent_id>[^/]+)/archival-memory$",
"^GET /v1/agents/(?P<agent_id>[^/]+)/sources$",
r"^POST /v1/voice-beta/.*/chat/completions$",
# "^GET /v1/agents/(?P<agent_id>[^/]+)/messages$",
# "^GET /v1/agents/(?P<agent_id>[^/]+)/context$",
# "^GET /v1/agents/(?P<agent_id>[^/]+)/archival-memory$",
# "^GET /v1/agents/(?P<agent_id>[^/]+)/sources$",
# r"^POST /v1/voice-beta/.*/chat/completions$",
]

poetry.lock generated
View File

@ -2123,15 +2123,15 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
[[package]]
name = "google-genai"
version = "1.10.0"
version = "1.15.0"
description = "GenAI Python SDK"
optional = true
python-versions = ">=3.9"
groups = ["main"]
markers = "extra == \"google\""
files = [
{file = "google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f"},
{file = "google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f"},
{file = "google_genai-1.15.0-py3-none-any.whl", hash = "sha256:6d7f149cc735038b680722bed495004720514c234e2a445ab2f27967955071dd"},
{file = "google_genai-1.15.0.tar.gz", hash = "sha256:118bb26960d6343cd64f1aeb5c2b02144a36ad06716d0d1eb1fa3e0904db51f1"},
]
[package.dependencies]
@ -6658,6 +6658,23 @@ files = [
{file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"},
]
[[package]]
name = "tavily-python"
version = "0.7.2"
description = "Python wrapper for the Tavily API"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
{file = "tavily_python-0.7.2-py3-none-any.whl", hash = "sha256:0d7cc8b1a2f95ac10cf722094c3b5807aade67cc7750f7ca605edef7455d4c62"},
{file = "tavily_python-0.7.2.tar.gz", hash = "sha256:34f713002887df2b5e6b8d7db7bc64ae107395bdb5f53611e80a89dac9cbdf19"},
]
[package.dependencies]
httpx = "*"
requests = "*"
tiktoken = ">=0.5.1"
[[package]]
name = "tenacity"
version = "9.1.2"
@ -7570,4 +7587,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.1"
python-versions = "<3.14,>=3.10"
content-hash = "19eee9b3cd3d270cb748183bc332dd69706bb0bd3150c62e73e61ed437a40c78"
content-hash = "837f6a25033a01cca117f4c61bcf973bc6ccfcda442615bbf4af038061bf88ce"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.20"
version = "0.7.21"
packages = [
{include = "letta"},
]
@ -79,7 +79,7 @@ opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0"
opentelemetry-instrumentation-requests = "0.51b0"
opentelemetry-exporter-otlp = "1.30.0"
google-genai = {version = "^1.1.0", optional = true}
google-genai = {version = "^1.15.0", optional = true}
faker = "^36.1.0"
colorama = "^0.4.6"
marshmallow-sqlalchemy = "^1.4.1"
@ -91,6 +91,7 @@ apscheduler = "^3.11.0"
aiomultiprocess = "^0.9.1"
matplotlib = "^3.10.1"
asyncpg = "^0.30.0"
tavily-python = "^0.7.2"
[tool.poetry.extras]
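Not part of the commit: the new tavily-python dependency pairs with the tavily_api_key setting added above and presumably backs the builtin web_search tool. Below is a minimal sketch of the client the package ships, assuming the standard TavilyClient API; the actual wiring inside letta.functions.function_sets.builtin is not shown in this diff.
import os
from tavily import TavilyClient
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
response = client.search("latest news about San Francisco", max_results=3)
for result in response.get("results", []):
    print(result["title"], result["url"])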

View File

@ -1,5 +1,5 @@
{
"model": "gemini-2.5-pro-exp-03-25",
"model": "gemini-2.5-pro-preview-05-06",
"model_endpoint_type": "google_vertex",
"model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1",
"context_window": 1048576,

View File

@ -0,0 +1,7 @@
{
"context_window": 16000,
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"model_endpoint_type": "together",
"model_endpoint": "https://api.together.ai/v1",
"model_wrapper": "chatml"
}

View File

@ -63,6 +63,7 @@ def check_composio_key_set():
yield
# --- Tool Fixtures ---
@pytest.fixture
def weather_tool_func():
def get_weather(location: str) -> str:
@ -110,6 +111,23 @@ def print_tool_func():
yield print_tool
@pytest.fixture
def roll_dice_tool_func():
def roll_dice():
"""
Rolls a 6-sided die.
Returns:
str: The roll result.
"""
import time
time.sleep(1)
return "Rolled a 10!"
yield roll_dice
@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
return BetaMessageBatch(

View File

@ -7,7 +7,7 @@ from unittest.mock import AsyncMock
import pytest
from anthropic.types import BetaErrorResponse, BetaRateLimitError
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import (
BetaMessageBatch,
BetaMessageBatchErroredResult,
@ -53,7 +53,7 @@ def _run_server():
start_server(debug=True)
@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
@ -255,148 +255,7 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
# -----------------------------
# End-to-End Test
# -----------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_polling_simple_real_batch(default_user, server):
# --- Step 1: Prepare test data ---
# Create batch responses with different statuses
# NOTE: This is a REAL batch id!
# For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ
batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended")
# Create test agents
agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0")
agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d")
agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56")
# --- Step 2: Create batch jobs ---
job_a = await create_test_llm_batch_job_async(server, batch_a_resp, default_user)
# --- Step 3: Create batch items ---
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user)
item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user)
print("HI")
print(agent_a.id)
print(agent_b.id)
print(agent_c.id)
print("BYE")
# --- Step 4: Run the polling job ---
await poll_running_llm_batches(server)
# --- Step 5: Verify batch job status updates ---
updated_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user)
assert updated_job_a.status == JobStatus.completed
# Both jobs should have been polled
assert updated_job_a.last_polled_at is not None
assert updated_job_a.latest_polling_response is not None
# --- Step 7: Verify batch item status updates ---
# Item A should be marked as completed with a successful result
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
assert updated_item_a.request_status == JobStatus.completed
assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse(
custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0",
result=BetaMessageBatchSucceededResult(
message=BetaMessage(
id="msg_01T1iSejDS5qENRqqEZauMHy",
content=[
BetaToolUseBlock(
id="toolu_01GKUYVWcajjTaE1stxZZHcG",
input={
"inner_thoughts": "First login detected. Time to make a great first impression!",
"message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?",
"request_heartbeat": False,
},
name="send_message",
type="tool_use",
)
],
model="claude-3-5-haiku-20241022",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
type="message",
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94),
),
type="succeeded",
),
)
# Item B should be marked as completed with a successful result
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
assert updated_item_b.request_status == JobStatus.completed
assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse(
custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d",
result=BetaMessageBatchSucceededResult(
message=BetaMessage(
id="msg_01N2ZfxpbjdoeofpufUFPCMS",
content=[
BetaTextBlock(
citations=None, text="<thinking>User first login detected. Initializing persona.</thinking>", type="text"
),
BetaToolUseBlock(
id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C",
input={
"label": "persona",
"content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.",
"request_heartbeat": True,
},
name="core_memory_append",
type="tool_use",
),
],
model="claude-3-opus-20240229",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
type="message",
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153),
),
type="succeeded",
),
)
# Item C should be marked as failed with an error result
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
assert updated_item_c.request_status == JobStatus.completed
assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse(
custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56",
result=BetaMessageBatchSucceededResult(
message=BetaMessage(
id="msg_01RL2g4aBgbZPeaMEokm6HZm",
content=[
BetaTextBlock(
citations=None,
text="First time meeting this user. I should introduce myself and establish a friendly connection.</thinking>",
type="text",
),
BetaToolUseBlock(
id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ",
input={
"message": "Hey there! I'm Letta. Really nice to meet you! I love getting to know new people - what brings you here today?",
"request_heartbeat": False,
},
name="send_message",
type="tool_use",
),
],
model="claude-3-5-sonnet-20241022",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
type="message",
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111),
),
type="succeeded",
),
)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_polling_mixed_batch_jobs(default_user, server):
"""
End-to-end test for polling batch jobs with mixed statuses and idempotency.

View File

@ -0,0 +1,206 @@
import json
import os
import threading
import time
import uuid
from typing import List
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate
from letta_client.types import ToolReturnMessage
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
from letta.settings import settings
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
temp = settings.use_experimental
settings.use_experimental = True
yield url
settings.use_experimental = temp
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
yield client_instance
@pytest.fixture(scope="module")
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
The agent is named 'supervisor' and is configured with the send_message, run_code, and web_search tools.
"""
client.tools.upsert_base_tools()
send_message_tool = client.tools.list(name="send_message")[0]
run_code_tool = client.tools.list(name="run_code")[0]
web_search_tool = client.tools.list(name="web_search")[0]
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=False,
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
)
yield agent_state_instance
client.agents.delete(agent_state_instance.id)
# ------------------------------
# Helper Functions and Constants
# ------------------------------
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
filename = os.path.join(llm_config_dir, filename)
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
return llm_config
USER_MESSAGE_OTID = str(uuid.uuid4())
all_configs = [
"openai-gpt-4o-mini.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
TEST_LANGUAGES = ["Python", "Javascript", "Typescript"]
EXPECTED_INTEGER_PARTITION_OUTPUT = "190569292"
# Reference implementation in Python, to embed in the user prompt
REFERENCE_CODE = """\
def reference_partition(n):
partitions = [1] + [0] * (n + 1)
for k in range(1, n + 1):
for i in range(k, n + 1):
partitions[i] += partitions[i - k]
return partitions[n]
"""
def reference_partition(n: int) -> int:
# Same logic, used to compute expected result in the test
partitions = [1] + [0] * (n + 1)
for k in range(1, n + 1):
for i in range(k, n + 1):
partitions[i] += partitions[i - k]
return partitions[n]
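# Sanity check, not in the original test file: the helper reproduces the
# EXPECTED_INTEGER_PARTITION_OUTPUT constant defined above, i.e. p(100) = 190569292.
assert str(reference_partition(100)) == EXPECTED_INTEGER_PARTITION_OUTPUT == "190569292"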
# ------------------------------
# Test Cases
# ------------------------------
@pytest.mark.parametrize("language", TEST_LANGUAGES, ids=TEST_LANGUAGES)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
def test_run_code(
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
language: str,
) -> None:
"""
Sends a reference Python implementation, asks the model to translate & run it
in different languages, and verifies the exact partition(100) result.
"""
expected = str(reference_partition(100))
user_message = MessageCreate(
role="user",
content=(
"Here is a Python reference implementation:\n\n"
f"{REFERENCE_CODE}\n"
f"Please translate and execute this code in {language} to compute p(100), "
"and return **only** the result with no extra formatting."
),
otid=USER_MESSAGE_OTID,
)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[user_message],
)
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
assert tool_returns, f"No ToolReturnMessage found for language: {language}"
returns = [m.tool_return for m in tool_returns]
assert any(expected in ret for ret in returns), (
f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}"
)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
def test_web_search(
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
) -> None:
user_message = MessageCreate(
role="user",
content=("Use the web search tool to find the latest news about San Francisco."),
otid=USER_MESSAGE_OTID,
)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[user_message],
)
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
assert tool_returns, "No ToolReturnMessage found"
returns = [m.tool_return for m in tool_returns]
expected = "RESULT 1:"
assert any(expected in ret for ret in returns), f"Expected to find '{expected}' in tool_return, " f"but got {returns!r}"

View File

@ -67,9 +67,14 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_
actor=default_user,
)
tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool(
function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
)
tool_execution_result = await ToolExecutionManager(
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
agent_state=agent_state,
actor=default_user,
).execute_tool(function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis)
# Small check, it should return something at least
assert len(tool_execution_result.func_return.keys()) > 10

View File

@ -1,56 +1,120 @@
import json
import os
import threading
import time
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta
from letta import LocalClient, create_client
from letta.config import LettaConfig
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.tool import Tool
from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from letta.settings import settings
from tests.helpers.utils import retry_until_success
from tests.utils import wait_for_incoming_message
@pytest.fixture(scope="function")
def client():
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
yield client
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
temp = settings.use_experimental
settings.use_experimental = True
yield url
settings.use_experimental = temp
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
return server
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
"""
Creates and returns a synchronous Letta REST client for testing.
"""
client_instance = Letta(base_url=server_url)
client_instance.tools.upsert_base_tools()
yield client_instance
@pytest.fixture(autouse=True)
def remove_stale_agents(client):
stale_agents = AgentManager().list_agents(actor=client.user, limit=300)
stale_agents = client.agents.list(limit=300)
for agent in stale_agents:
client.delete_agent(agent_id=agent.id)
client.agents.delete(agent_id=agent.id)
@pytest.fixture(scope="function")
def agent_obj(client: LocalClient):
def agent_obj(client):
"""Create a test agent that we can call functions on"""
send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply")
agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id])
send_message_to_agent_tool = client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0]
agent_state_instance = client.agents.create(
include_base_tools=True,
tool_ids=[send_message_to_agent_tool.id],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
yield agent_state_instance
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
yield agent_obj
# client.delete_agent(agent_obj.agent_state.id)
client.agents.delete(agent_state_instance.id)
@pytest.fixture(scope="function")
def other_agent_obj(client: LocalClient):
def other_agent_obj(client):
"""Create another test agent that we can call functions on"""
agent_state = client.create_agent(include_multi_agent_tools=False)
agent_state_instance = client.agents.create(
include_base_tools=True,
include_multi_agent_tools=False,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
other_agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
yield other_agent_obj
yield agent_state_instance
client.delete_agent(other_agent_obj.agent_state.id)
client.agents.delete(agent_state_instance.id)
@pytest.fixture
@ -77,48 +141,68 @@ def roll_dice_tool(client):
tool.json_schema = derived_json_schema
tool.name = derived_name
tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user)
tool = client.tools.upsert_from_function(func=roll_dice)
# Yield the created tool
yield tool
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
secret_word = "banana"
actor = server.user_manager.get_user_or_default()
# Encourage the agent to send a message to the other agent_obj with the secret string
client.send_message(
agent_id=agent_obj.agent_state.id,
role="user",
message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!",
client.agents.messages.create(
agent_id=agent_obj.id,
messages=[
{
"role": "user",
"content": f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!",
}
],
)
# Conversation search the other agent
messages = client.get_messages(other_agent_obj.agent_state.id)
messages = server.get_agent_recall(
user_id=actor.id,
agent_id=other_agent_obj.id,
reverse=True,
return_message_object=False,
)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}")
print(f"\n\n {other_agent_obj.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word in m.content
break
# Search the sender agent for the response from another agent
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
in_context_messages = AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor)
found = False
target_snippet = f"{other_agent_obj.agent_state.id} said:"
target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': ["
for m in in_context_messages:
if target_snippet in m.content[0].text:
found = True
break
print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.content[0].text for m in in_context_messages[1:]])}")
joined = "\n".join([m.content[0].text for m in in_context_messages[1:]])
print(f"In context messages of the sender agent (without system):\n\n{joined}")
if not found:
raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?")
response = client.agents.messages.create(
agent_id=agent_obj.id,
messages=[
{
"role": "user",
"content": "So what did the other agent say?",
}
],
)
print(response.messages)
@ -127,39 +211,50 @@ def test_send_message_to_agents_with_tags_simple(client):
worker_tags_123 = ["worker", "user-123"]
worker_tags_456 = ["worker", "user-456"]
# Clean up first from possibly failed tests
prev_worker_agents = client.server.agent_manager.list_agents(
client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True
)
for agent in prev_worker_agents:
client.delete_agent(agent.id)
secret_word = "banana"
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(name="manager_agent", tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
manager_agent_state = client.agents.create(
name="manager_agent",
tool_ids=[send_message_to_agents_matching_tags_tool_id],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
# Create 3 non-matching worker agents (These should NOT get the message)
worker_agents_123 = []
for idx in range(2):
worker_agent_state = client.create_agent(name=f"not_worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_123)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents_123.append(worker_agent)
worker_agent_state = client.agents.create(
name=f"not_worker_{idx}",
include_multi_agent_tools=False,
tags=worker_tags_123,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
worker_agents_123.append(worker_agent_state)
# Create 3 worker agents that should get the message
worker_agents_456 = []
for idx in range(2):
worker_agent_state = client.create_agent(name=f"worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_456)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents_456.append(worker_agent)
worker_agent_state = client.agents.create(
name=f"worker_{idx}",
include_multi_agent_tools=False,
tags=worker_tags_456,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
worker_agents_456.append(worker_agent_state)
# Encourage the manager to send a message to the other agent_obj with the secret string
response = client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",
message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
}
],
)
for m in response.messages:
@ -172,62 +267,70 @@ def test_send_message_to_agents_with_tags_simple(client):
break
# Conversation search the worker agents
for agent in worker_agents_456:
messages = client.get_messages(agent.agent_state.id)
for agent_state in worker_agents_456:
messages = client.agents.messages.list(agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word in m.content
break
# Ensure it's NOT in the non matching worker agents
for agent in worker_agents_123:
messages = client.get_messages(agent.agent_state.id)
for agent_state in worker_agents_123:
messages = client.agents.messages.list(agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word not in m.content
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": "So what did the other agent say?",
}
],
)
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
# Clean up agents
client.delete_agent(manager_agent_state.id)
for agent in worker_agents_456 + worker_agents_123:
client.delete_agent(agent.agent_state.id)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
worker_tags = ["dice-rollers"]
# Clean up first from possibly failed tests
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
for agent in prev_worker_agents:
client.delete_agent(agent.id)
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
manager_agent_state = client.agents.create(
tool_ids=[send_message_to_agents_matching_tags_tool_id],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
# Create 3 worker agents
worker_agents = []
worker_tags = ["dice-rollers"]
for _ in range(2):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id])
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
worker_agent_state = client.agents.create(
include_multi_agent_tools=False,
tags=worker_tags,
tool_ids=[roll_dice_tool.id],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
worker_agents.append(worker_agent_state)
# Encourage the manager to send a message to the other agent_obj with the secret string
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
response = client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",
message=broadcast_message,
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": broadcast_message,
}
],
)
for m in response.messages:
@ -240,47 +343,65 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
break
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
response = client.agents.messages.create(
agent_id=manager_agent_state.id,
messages=[
{
"role": "user",
"content": "So what did the other agent say?",
}
],
)
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
# Clean up agents
client.delete_agent(manager_agent_state.id)
for agent in worker_agents:
client.delete_agent(agent.agent_state.id)
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
# @retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_agents_async_simple(client):
"""
Test two agents with multi-agent tools sending messages back and forth asynchronously.
The chain is started by prompting one of the agents.
"""
# Cleanup from potentially failed previous runs
existing_agents = client.server.agent_manager.list_agents(client.user)
for agent in existing_agents:
client.delete_agent(agent.id)
# Create two agents with multi-agent tools
send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async")
memory_a = ChatMemory(
human="Chad - I'm interested in hearing poem.",
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
send_message_to_agent_async_tool_id = client.tools.list(name="send_message_to_agent_async")[0].id
charles_state = client.agents.create(
name="charles",
tool_ids=[send_message_to_agent_async_tool_id],
memory_blocks=[
{
"label": "human",
"value": "Chad - I'm interested in hearing poem.",
},
{
"label": "persona",
"value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
},
],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id])
charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user)
memory_b = ChatMemory(
human="No human - you are to only communicate with the other AI agent.",
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
sarah_state = client.agents.create(
name="sarah",
tool_ids=[send_message_to_agent_async_tool_id],
memory_blocks=[
{
"label": "human",
"value": "No human - you are to only communicate with the other AI agent.",
},
{
"label": "persona",
"value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
},
],
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
)
sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id])
# Start the count chain with Agent1
initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me."
client.send_message(
agent_id=charles.agent_state.id,
role="user",
message=initial_prompt,
client.agents.messages.create(
agent_id=charles_state.id,
messages=[{"role": "user", "content": initial_prompt}],
)
found_in_charles = wait_for_incoming_message(

View File

@ -135,6 +135,7 @@ all_configs = [
"gemini-1.5-pro.json",
"gemini-2.5-flash-vertex.json",
"gemini-2.5-pro-vertex.json",
"together-qwen-2.5-72b-instruct.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
@ -170,6 +171,10 @@ def assert_greeting_with_assistant_message_response(
if streaming:
assert isinstance(messages[index], LettaUsageStatistics)
assert messages[index].prompt_tokens > 0
assert messages[index].completion_tokens > 0
assert messages[index].total_tokens > 0
assert messages[index].step_count > 0
def assert_greeting_without_assistant_message_response(
@ -636,6 +641,33 @@ async def test_streaming_tool_call_async_client(
assert_tool_call_response(messages, streaming=True)
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_step_streaming_greeting_with_assistant_message(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_GREETING,
stream_tokens=False,
)
messages = []
for message in response:
messages.append(message)
assert_greeting_with_assistant_message_response(messages, streaming=True)
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,

View File

@ -6,7 +6,7 @@ from sqlalchemy import delete
from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound
from letta.schemas.agent import CreateAgent
@ -39,6 +39,7 @@ def org_id(server):
# cleanup
with db_registry.session() as session:
session.execute(delete(ProviderTrace))
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
@ -54,7 +55,7 @@ def actor(server, org_id):
server.user_manager.delete_user_by_id(user.id)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="module")
async def test_sleeptime_group_chat(server, actor):
# 0. Refresh base tools
server.tool_manager.upsert_base_tools(actor=actor)
@ -105,7 +106,7 @@ async def test_sleeptime_group_chat(server, actor):
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert main_agent.id in [agent.id for agent in agents]
@ -169,7 +170,7 @@ async def test_sleeptime_group_chat(server, actor):
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="module")
async def test_sleeptime_group_chat_v2(server, actor):
# 0. Refresh base tools
server.tool_manager.upsert_base_tools(actor=actor)
@ -220,7 +221,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert main_agent.id in [agent.id for agent in agents]
@ -292,7 +293,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="module")
async def test_sleeptime_removes_redundant_information(server, actor):
# 1. set up sleep-time agent as in test_sleeptime_group_chat
server.tool_manager.upsert_base_tools(actor=actor)
@ -360,7 +361,7 @@ async def test_sleeptime_removes_redundant_information(server, actor):
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="module")
async def test_sleeptime_edit(server, actor):
sleeptime_agent = server.create_agent(
request=CreateAgent(

View File

@ -1,10 +1,10 @@
import os
import threading
import subprocess
import sys
from unittest.mock import MagicMock
import pytest
from dotenv import load_dotenv
from letta_client import AsyncLetta
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
@ -35,7 +35,7 @@ from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import get_persona_text
from tests.utils import wait_for_server
from tests.utils import create_tool_from_func, wait_for_server
MESSAGE_TRANSCRIPTS = [
"user: Hey, Ive been thinking about planning a road trip up the California coast next month.",
@ -92,17 +92,6 @@ You're a memory-recall helper for an AI that can only keep the last 4 messages
# --- Server Management --- #
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
return server
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
@ -111,31 +100,66 @@ def _run_server():
start_server(debug=True)
@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
"""
Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI,
so its event loop and DB pool stay completely isolated from pytest-asyncio.
"""
url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283")
# Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
wait_for_server(url) # Allow server startup time
# Build the command to launch uvicorn on your FastAPI app
cmd = [
sys.executable,
"-m",
"uvicorn",
"letta.server.rest_api.app:app",
"--host",
"127.0.0.1",
"--port",
"8283",
]
# If you need TLS or reload settings from start_server(), you can add
# "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well.
return url
server_proc = subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
# wait until the HTTP port is accepting connections
wait_for_server(url)
yield url
# Teardown: kill the subprocess if we started it
server_proc.terminate()
server_proc.wait(timeout=10)
else:
yield url
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
actor = server.user_manager.get_user_or_default()
server.tool_manager.upsert_base_tools(actor=actor)
return server
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = AsyncLetta(base_url=server_url)
yield client
@pytest.fixture(scope="function")
async def roll_dice_tool(client):
@pytest.fixture
async def roll_dice_tool(server):
def roll_dice():
"""
Rolls a 6-sided die.
@ -145,13 +169,13 @@ async def roll_dice_tool(client):
"""
return "Rolled a 10!"
tool = await client.tools.upsert_from_function(func=roll_dice)
# Yield the created tool
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=roll_dice), actor=actor)
yield tool
@pytest.fixture(scope="function")
async def weather_tool(client):
@pytest.fixture
async def weather_tool(server):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
@ -176,22 +200,20 @@ async def weather_tool(client):
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
tool = await client.tools.upsert_from_function(func=get_weather)
# Yield the created tool
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
yield tool
@pytest.fixture(scope="function")
@pytest.fixture
def composio_gmail_get_profile_tool(default_user):
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
yield tool
@pytest.fixture(scope="function")
@pytest.fixture
def voice_agent(server, actor):
server.tool_manager.upsert_base_tools(actor=actor)
main_agent = server.create_agent(
request=CreateAgent(
agent_type=AgentType.voice_convo_agent,
@ -268,9 +290,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
# --- Tests --- #
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"])
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor):
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, server_url, group_id, actor):
request = _get_chat_request("How are you?")
server.tool_manager.upsert_base_tools(actor=actor)
@ -303,10 +325,10 @@ async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, serv
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint):
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint, server_url):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(message)
@ -318,9 +340,9 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor):
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url):
server.group_manager.modify_group(
group_id=group_id,
group_update=GroupUpdate(
@ -350,8 +372,8 @@ async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, g
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio(loop_scope="session")
async def test_summarization(disable_e2b_api_key, voice_agent):
@pytest.mark.asyncio(loop_scope="module")
async def test_summarization(disable_e2b_api_key, voice_agent, server_url):
agent_manager = AgentManager()
user_manager = UserManager()
actor = user_manager.get_default_user()
@ -422,8 +444,8 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
summarizer.fire_and_forget.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
@pytest.mark.asyncio(loop_scope="module")
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent, server_url):
"""Tests chat completion streaming using the Async OpenAI client."""
agent_manager = AgentManager()
tool_manager = ToolManager()
@ -488,8 +510,8 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
assert not missing, f"Did not see calls to: {', '.join(missing)}"
@pytest.mark.asyncio(loop_scope="session")
async def test_init_voice_convo_agent(voice_agent, server, actor):
@pytest.mark.asyncio(loop_scope="module")
async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
assert voice_agent.enable_sleeptime == True
main_agent_tools = [tool.name for tool in voice_agent.tools]
@ -511,7 +533,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
# 3. Verify shared blocks
sleeptime_agent_id = group.agent_ids[0]
shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor)
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
assert len(agents) == 2
assert sleeptime_agent_id in [agent.id for agent in agents]
assert voice_agent.id in [agent.id for agent in agents]

View File

@ -1,12 +1,15 @@
import difflib
import json
import os
import threading
import time
from datetime import datetime, timezone
from io import BytesIO
from typing import Any, Dict, List, Mapping
import pytest
from fastapi.testclient import TestClient
import requests
from dotenv import load_dotenv
from rich.console import Console
from rich.syntax import Syntax
@ -23,11 +26,51 @@ from letta.schemas.message import MessageCreate
from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.app import app
from letta.server.server import SyncServer
console = Console()
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
def _clear_tables():
from letta.server.db import db_context
@ -38,12 +81,6 @@ def _clear_tables():
session.commit()
@pytest.fixture
def fastapi_client():
"""Fixture to create a FastAPI test client."""
return TestClient(app)
@pytest.fixture(autouse=True)
def clear_tables():
_clear_tables()
@ -57,14 +94,14 @@ def local_client():
yield client
@pytest.fixture(scope="module")
@pytest.fixture
def server():
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=False)
return server
yield server
@pytest.fixture
@ -562,14 +599,17 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server
@pytest.mark.parametrize("append_copy_suffix", [True, False])
@pytest.mark.parametrize("project_id", ["project-12345", None])
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
def test_agent_download_upload_flow(server, server_url, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
"""
Test the full E2E serialization and deserialization flow using FastAPI endpoints.
"""
agent_id = serialize_test_agent.id
# Step 1: Download the serialized agent
response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id})
response = requests.get(
f"{server_url}/v1/agents/{agent_id}/export",
headers={"user_id": default_user.id},
)
assert response.status_code == 200, f"Download failed: {response.text}"
# Ensure response matches expected schema
@ -580,10 +620,14 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
# Step 2: Upload the serialized agent as a copy
agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
files = {"file": ("agent.json", agent_bytes, "application/json")}
upload_response = fastapi_client.post(
"/v1/agents/import",
upload_response = requests.post(
f"{server_url}/v1/agents/import",
headers={"user_id": other_user.id},
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
params={
"append_copy_suffix": append_copy_suffix,
"override_existing_tools": False,
"project_id": project_id,
},
files=files,
)
assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
@ -613,16 +657,16 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
"memgpt_agent_with_convo.af",
],
)
def test_upload_agentfile_from_disk(server, disable_e2b_api_key, fastapi_client, other_user, filename):
def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, other_user, filename):
"""
Test uploading each .af file from the test_agent_files directory via FastAPI.
Test uploading each .af file from the test_agent_files directory via a live FastAPI server.
"""
file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", filename)
with open(file_path, "rb") as f:
files = {"file": (filename, f, "application/json")}
response = fastapi_client.post(
"/v1/agents/import",
response = requests.post(
f"{server_url}/v1/agents/import",
headers={"user_id": other_user.id},
params={"append_copy_suffix": True, "override_existing_tools": False},
files=files,

View File

@ -1,5 +1,3 @@
import os
import threading
from datetime import datetime, timezone
from typing import Tuple
from unittest.mock import AsyncMock, patch
@ -14,15 +12,13 @@ from anthropic.types.beta.messages import (
BetaMessageBatchRequestCounts,
BetaMessageBatchSucceededResult,
)
from dotenv import load_dotenv
from letta_client import Letta
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.config import LettaConfig
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.agent import AgentState, AgentStepState, CreateAgent
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.letta_message_content import TextContent
@ -31,10 +27,10 @@ from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from tests.utils import wait_for_server
from tests.utils import create_tool_from_func
# --------------------------------------------------------------------------- #
# Test Constants
# Test Constants / Helpers
# --------------------------------------------------------------------------- #
# Model identifiers used in tests
@ -54,7 +50,7 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
@pytest.fixture(scope="function")
def weather_tool(client):
def weather_tool(server):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
@ -79,13 +75,14 @@ def weather_tool(client):
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
tool = client.tools.upsert_from_function(func=get_weather)
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def rethink_tool(client):
def rethink_tool(server):
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
"""
Re-evaluate the memory in block_name, integrating new and updated facts.
@ -101,28 +98,33 @@ def rethink_tool(client):
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
return None
tool = client.tools.upsert_from_function(func=rethink_memory)
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor)
# Yield the created tool
yield tool
@pytest.fixture
def agents(client, weather_tool):
def agents(server, weather_tool):
"""
Create three test agents with different models.
Returns:
Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models
"""
actor = server.user_manager.get_user_or_default()
def create_agent(suffix, model_name):
return client.agents.create(
name=f"test_agent_{suffix}",
include_base_tools=True,
model=model_name,
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[weather_tool.id],
return server.create_agent(
CreateAgent(
name=f"test_agent_{suffix}",
include_base_tools=True,
model=model_name,
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[weather_tool.id],
),
actor=actor,
)
return (
@ -290,32 +292,6 @@ def clear_batch_tables():
session.commit()
def run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""
Ensures a server is running and returns its base URL.
Uses environment variable if available, otherwise starts a server
in a background thread.
"""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
wait_for_server(url)
return url
@pytest.fixture(scope="module")
def server():
"""
@ -324,14 +300,11 @@ def server():
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
return SyncServer()
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client connected to the test server."""
return Letta(base_url=server_url)
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture
@ -368,23 +341,27 @@ class MockAsyncIterable:
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio(loop_scope="session")
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
@pytest.mark.asyncio(loop_scope="module")
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
target_block_label = "human"
new_memory = "banana"
agent = client.agents.create(
name=f"test_agent_rethink",
include_base_tools=True,
model=MODELS["sonnet"],
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[rethink_tool.id],
memory_blocks=[
{
"label": target_block_label,
"value": "Name: Matt",
},
],
actor = server.user_manager.get_user_or_default()
agent = await server.create_agent_async(
request=CreateAgent(
name=f"test_agent_rethink",
include_base_tools=True,
model=MODELS["sonnet"],
tags=["test_agents"],
embedding="letta/letta-free",
tool_ids=[rethink_tool.id],
memory_blocks=[
{
"label": target_block_label,
"value": "Name: Matt",
},
],
),
actor=actor,
)
agents = [agent]
batch_requests = [
@ -444,13 +421,13 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
await poll_running_llm_batches(server)
# Check that the tool has been executed correctly
agent = client.agents.retrieve(agent_id=agent.id)
agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor)
for block in agent.memory.blocks:
if block.label == target_block_label:
assert block.value == new_memory
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_partial_error_from_anthropic_batch(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@ -610,7 +587,7 @@ async def test_partial_error_from_anthropic_batch(
assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_resume_step_some_stop(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@ -773,7 +750,7 @@ def _assert_descending_order(messages):
return True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_resume_step_after_request_all_continue(
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@ -911,7 +888,7 @@ async def test_resume_step_after_request_all_continue(
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_step_until_request_prepares_and_submits_batch_correctly(
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
):

File diff suppressed because it is too large

View File

@ -2,7 +2,7 @@ import pytest
from sqlalchemy import delete
from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import (
@ -38,6 +38,7 @@ def org_id(server):
# cleanup
with db_registry.session() as session:
session.execute(delete(ProviderTrace))
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()

View File

@ -0,0 +1,205 @@
import asyncio
import json
import os
import threading
import time
import uuid
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from letta.agents.letta_agent import LettaAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(5) # Allow server startup time
return url
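If the fixed five-second sleep ever proves flaky, the polling approach used by the serialization conftest above could be reused here; a hedged sketch, where the helper name, the 30-second timeout, and the /v1/health route are carried over from that fixture as assumptions rather than taken from this file:
def _wait_until_up(url: str, timeout_seconds: int = 30) -> None:
    """Poll the health route until the server answers or the deadline passes."""
    import requests  # local import; this module does not import requests itself

    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        try:
            if requests.get(url + "/v1/health").status_code < 500:
                return
        except requests.exceptions.RequestException:
            pass
        time.sleep(0.1)
    raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")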
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = Letta(base_url=server_url)
yield client
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="function")
def roll_dice_tool(client, roll_dice_tool_func):
print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func)
yield print_tool
@pytest.fixture(scope="function")
def weather_tool(client, weather_tool_func):
weather_tool = client.tools.upsert_from_function(func=weather_tool_func)
yield weather_tool
@pytest.fixture(scope="function")
def print_tool(client, print_tool_func):
print_tool = client.tools.upsert_from_function(func=print_tool_func)
yield print_tool
@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.agents.create(
name=f"test_compl_{str(uuid.uuid4())[5:]}",
tool_ids=[roll_dice_tool.id, weather_tool.id],
include_base_tools=True,
memory_blocks=[
{
"label": "human",
"value": "Name: Matt",
},
{
"label": "persona",
"value": "Friendly agent",
},
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
yield agent_state
client.agents.delete(agent_state.id)
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step(message, agent_state, default_user):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=TelemetryManager(),
actor=default_user,
)
response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
tool_step = response.messages[0].step_id
reply_step = response.messages[-1].step_id
tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
assert tool_telemetry.request_json
assert reply_telemetry.request_json
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=TelemetryManager(),
actor=default_user,
)
stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])])
result = StreamingResponseWithStatusCode(
stream,
media_type="text/event-stream",
)
message_id = None
async def test_send(message) -> None:
nonlocal message_id
if "body" in message and not message_id:
body = message["body"].decode("utf-8").split("data:")
message_id = json.loads(body[1])["id"]
await result.stream_response(send=test_send)
messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user)
step_ids = set((message.step_id for message in messages))
for step_id in step_ids:
telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user)
assert telemetry_data.request_json
assert telemetry_data.response_json
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_step(client, agent_state, default_user, message, event_loop):
client.agents.messages.create(agent_id=agent_state.id, messages=[])
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[MessageCreate(role="user", content=[TextContent(text=message)])],
)
tool_step = response.messages[0].step_id
reply_step = response.messages[-1].step_id
tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
assert tool_telemetry.request_json
assert reply_telemetry.request_json
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_noop_provider_trace(message, agent_state, default_user, event_loop):
experimental_agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
step_manager=StepManager(),
telemetry_manager=NoopTelemetryManager(),
actor=default_user,
)
response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
tool_step = response.messages[0].step_id
reply_step = response.messages[-1].step_id
tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
assert tool_telemetry is None
assert reply_telemetry is None
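To close the loop on what these assertions exercise, here is a minimal sketch of reading a stored trace back out by step id; the helper name and the pretty-printing are illustrative assumptions, while get_provider_trace_by_step_id_async and the request_json/response_json fields come from the tests above.
async def dump_provider_trace(step_id: str, actor) -> None:
    """Fetch the persisted provider trace for one step and print both payloads."""
    trace = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=step_id, actor=actor)
    if trace is None:  # e.g. the step ran under NoopTelemetryManager
        print(f"No trace recorded for step {step_id}")
        return
    print(json.dumps(trace.request_json, indent=2))
    print(json.dumps(trace.response_json, indent=2))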

View File

@ -1,15 +1,12 @@
import os
import pytest
from letta.schemas.providers import (
AnthropicBedrockProvider,
AnthropicProvider,
AzureProvider,
DeepSeekProvider,
GoogleAIProvider,
GoogleVertexProvider,
GroqProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
TogetherProvider,
)
@ -17,11 +14,9 @@ from letta.settings import model_settings
def test_openai():
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None
provider = OpenAIProvider(
name="openai",
api_key=api_key,
api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base,
)
models = provider.list_llm_models()
@ -33,34 +28,54 @@ def test_openai():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_deepseek():
api_key = os.getenv("DEEPSEEK_API_KEY")
assert api_key is not None
provider = DeepSeekProvider(
name="deepseek",
api_key=api_key,
@pytest.mark.asyncio
async def test_openai_async():
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base,
)
models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = await provider.list_embedding_models_async()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_deepseek():
provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(
name="anthropic",
api_key=api_key,
api_key=model_settings.anthropic_api_key,
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
@pytest.mark.asyncio
async def test_anthropic_async():
provider = AnthropicProvider(
name="anthropic",
api_key=model_settings.anthropic_api_key,
)
models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_groq():
provider = GroqProvider(
name="groq",
api_key=os.getenv("GROQ_API_KEY"),
api_key=model_settings.groq_api_key,
)
models = provider.list_llm_models()
assert len(models) > 0
@ -70,8 +85,9 @@ def test_groq():
def test_azure():
provider = AzureProvider(
name="azure",
api_key=os.getenv("AZURE_API_KEY"),
base_url=os.getenv("AZURE_BASE_URL"),
api_key=model_settings.azure_api_key,
base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version,
)
models = provider.list_llm_models()
assert len(models) > 0
@ -82,26 +98,24 @@ def test_azure():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_ollama():
base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None
provider = OllamaProvider(
name="ollama",
base_url=base_url,
default_prompt_formatter=model_settings.default_prompt_formatter,
api_key=None,
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# def test_ollama():
# provider = OllamaProvider(
# name="ollama",
# base_url=model_settings.ollama_base_url,
# api_key=None,
# default_prompt_formatter=model_settings.default_prompt_formatter,
# )
# models = provider.list_llm_models()
# assert len(models) > 0
# assert models[0].handle == f"{provider.name}/{models[0].model}"
#
# embedding_models = provider.list_embedding_models()
# assert len(embedding_models) > 0
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_googleai():
api_key = os.getenv("GEMINI_API_KEY")
api_key = model_settings.gemini_api_key
assert api_key is not None
provider = GoogleAIProvider(
name="google_ai",
@ -116,11 +130,28 @@ def test_googleai():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
@pytest.mark.asyncio
async def test_googleai_async():
api_key = model_settings.gemini_api_key
assert api_key is not None
provider = GoogleAIProvider(
name="google_ai",
api_key=api_key,
)
models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = await provider.list_embedding_models_async()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_google_vertex():
provider = GoogleVertexProvider(
name="google_vertex",
google_cloud_project=os.getenv("GCP_PROJECT_ID"),
google_cloud_location=os.getenv("GCP_REGION"),
google_cloud_project=model_settings.google_cloud_project,
google_cloud_location=model_settings.google_cloud_location,
)
models = provider.list_llm_models()
assert len(models) > 0
@ -131,50 +162,57 @@ def test_google_vertex():
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_mistral():
provider = MistralProvider(
name="mistral",
api_key=os.getenv("MISTRAL_API_KEY"),
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_together():
provider = TogetherProvider(
name="together",
api_key=os.getenv("TOGETHER_API_KEY"),
default_prompt_formatter="chatml",
api_key=model_settings.together_api_key,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# TODO: We don't have embedding models on together for CI
# embedding_models = provider.list_embedding_models()
# assert len(embedding_models) > 0
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_anthropic_bedrock():
from letta.settings import model_settings
provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
models = provider.list_llm_models()
@pytest.mark.asyncio
async def test_together_async():
provider = TogetherProvider(
name="together",
api_key=model_settings.together_api_key,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# TODO: We don't have embedding models on together for CI
# embedding_models = provider.list_embedding_models()
# assert len(embedding_models) > 0
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
# def test_anthropic_bedrock():
# from letta.settings import model_settings
#
# provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
# models = provider.list_llm_models()
# assert len(models) > 0
# assert models[0].handle == f"{provider.name}/{models[0].model}"
#
# embedding_models = provider.list_embedding_models()
# assert len(embedding_models) > 0
# assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_custom_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(
name="custom_anthropic",
api_key=api_key,
api_key=model_settings.anthropic_api_key,
)
models = provider.list_llm_models()
assert len(models) > 0

View File

@ -11,7 +11,7 @@ from sqlalchemy import delete
import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
@ -286,6 +286,7 @@ def org_id(server):
# cleanup
with db_registry.session() as session:
session.execute(delete(ProviderTrace))
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
@ -565,7 +566,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
server.agent_manager.delete_agent(agent_state.id, actor=another_user)
def test_read_local_llm_configs(server: SyncServer, user: User):
@pytest.mark.asyncio
async def test_read_local_llm_configs(server: SyncServer, user: User):
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
clean_up_dir = False
if not os.path.exists(configs_base_dir):
@ -588,7 +590,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):
# Call list_llm_models
assert os.path.exists(configs_base_dir)
llm_models = server.list_llm_models(actor=user)
llm_models = await server.list_llm_models_async(actor=user)
# Assert that the config is in the returned models
assert any(
@ -935,7 +937,7 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = user
# create agent
@ -1223,7 +1225,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
assert len(agent_state.tools) == len(base_tools) - 2
def test_messages_with_provider_override(server: SyncServer, user_id: str):
@pytest.mark.asyncio
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id)
provider = server.provider_manager.create_provider(
request=ProviderCreate(
@ -1233,10 +1236,10 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
),
actor=actor,
)
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])
assert provider.name in [model.provider_name for model in models]
models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
assert provider.name not in [model.provider_name for model in models]
agent = server.create_agent(
@ -1302,11 +1305,12 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
assert total_tokens == usage.total_tokens
def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
models = server.list_llm_models(actor=user)
@pytest.mark.asyncio
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
models = await server.list_llm_models_async(actor=user)
model_handles = [model.handle for model in models]
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
embeddings = server.list_embedding_models(actor=user)
embeddings = await server.list_embedding_models_async(actor=user)
embedding_handles = [embedding.handle for embedding in embeddings]
assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"

View File

@ -4,15 +4,16 @@ import string
import time
from datetime import datetime, timezone
from importlib import util
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Optional, Tuple
import requests
from letta_client import Letta, SystemMessage
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector
from letta.schemas.enums import MessageRole
from letta.functions.functions import parse_source_code
from letta.schemas.file import FileMetadata
from letta.schemas.message import Message
from letta.schemas.tool import Tool
from letta.settings import TestSettings
from .constants import TIMEOUT
@ -152,7 +153,7 @@ def with_qdrant_storage(storage: list[str]):
def wait_for_incoming_message(
client,
client: Letta,
agent_id: str,
substring: str = "[Incoming message from agent with ID",
max_wait_seconds: float = 10.0,
@ -166,13 +167,13 @@ def wait_for_incoming_message(
deadline = time.time() + max_wait_seconds
while time.time() < deadline:
messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id, actor=client.user)
messages = client.agents.messages.list(agent_id)[1:]
# Check for the system message containing `substring`
def get_message_text(message: Message) -> str:
return message.content[0].text if message.content and len(message.content) == 1 else ""
def get_message_text(message: SystemMessage) -> str:
return message.content if message.content else ""
if any(message.role == MessageRole.system and substring in get_message_text(message) for message in messages):
if any(isinstance(message, SystemMessage) and substring in get_message_text(message) for message in messages):
return True
time.sleep(sleep_interval)
@ -199,3 +200,21 @@ def wait_for_server(url, timeout=30, interval=0.5):
def random_string(length: int) -> str:
return "".join(random.choices(string.ascii_letters + string.digits, k=length))
def create_tool_from_func(
func,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
):
source_code = parse_source_code(func)
source_type = "python"
if not tags:
tags = []
return Tool(
source_type=source_type,
source_code=source_code,
tags=tags,
description=description,
)
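A hedged usage sketch of this helper, mirroring the weather_tool fixture in the batch tests: the fixture name and the echo function are illustrative assumptions (and pytest plus create_tool_from_func are assumed to be imported where the fixture lives), while create_or_update_tool and get_user_or_default are the calls already used above.
@pytest.fixture(scope="function")
def echo_tool(server):
    def echo(text: str) -> str:
        """Return the input text unchanged."""
        return text

    actor = server.user_manager.get_user_or_default()
    # create_tool_from_func turns the function into a Tool schema; the tool manager persists it.
    yield server.tool_manager.create_or_update_tool(create_tool_from_func(func=echo), actor=actor)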