From c0efe8ad0cc91ef2181e3d6158dc46891c08ca2a Mon Sep 17 00:00:00 2001
From: cthomas
Date: Wed, 21 May 2025 16:33:29 -0700
Subject: [PATCH] chore: bump version 0.7.21 (#2653)

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin
Co-authored-by: Sarah Wooders
Co-authored-by: jnjpng
Co-authored-by: Matthew Zhou
---
 ...224a7a58_add_provider_category_to_steps.py | 31 +
 ...d_add_support_for_request_and_response_.py | 50 ++
 letta/__init__.py | 2 +-
 letta/agent.py | 293 ++++-
 letta/agents/base_agent.py | 55 --
 letta/agents/helpers.py | 5 +
 letta/agents/letta_agent.py | 378 ++++++++++--
 letta/agents/letta_agent_batch.py | 157 +++--
 letta/agents/voice_agent.py | 10 +-
 letta/client/client.py | 27 +-
 letta/constants.py | 56 +-
 letta/functions/function_sets/builtin.py | 27 +
 letta/groups/sleeptime_multi_agent_v2.py | 2 +-
 .../anthropic_streaming_interface.py | 11 +-
 .../interfaces/openai_streaming_interface.py | 11 +-
 letta/llm_api/anthropic.py | 23 +-
 letta/llm_api/anthropic_client.py | 39 +-
 letta/llm_api/google_ai_client.py | 565 +++++-------------
 letta/llm_api/google_vertex_client.py | 195 +++++-
 letta/llm_api/llm_api_tools.py | 27 +
 letta/llm_api/llm_client.py | 2 +-
 letta/llm_api/llm_client_base.py | 53 +-
 letta/llm_api/openai.py | 57 ++
 letta/llm_api/openai_client.py | 18 +-
 letta/memory.py | 1 -
 letta/orm/__init__.py | 1 +
 letta/orm/enums.py | 1 +
 letta/orm/provider_trace.py | 26 +
 letta/orm/step.py | 1 +
 letta/schemas/provider_trace.py | 43 ++
 letta/schemas/providers.py | 279 ++++---
 letta/schemas/step.py | 1 +
 letta/schemas/tool.py | 4 +
 letta/server/db.py | 56 +-
 letta/server/rest_api/routers/v1/__init__.py | 2 +
 letta/server/rest_api/routers/v1/agents.py | 89 ++-
 letta/server/rest_api/routers/v1/blocks.py | 6 +-
 .../server/rest_api/routers/v1/identities.py | 50 +-
 letta/server/rest_api/routers/v1/jobs.py | 6 +-
 letta/server/rest_api/routers/v1/llms.py | 21 +-
 .../rest_api/routers/v1/sandbox_configs.py | 12 +-
 letta/server/rest_api/routers/v1/tags.py | 6 +-
 letta/server/rest_api/routers/v1/telemetry.py | 18 +
 letta/server/rest_api/routers/v1/tools.py | 12 +-
 letta/server/rest_api/streaming_response.py | 105 ++++
 letta/server/rest_api/utils.py | 4 +
 letta/server/server.py | 141 ++++-
 letta/services/agent_manager.py | 269 ++++++++-
 letta/services/block_manager.py | 93 +--
 letta/services/helpers/noop_helper.py | 10 +
 letta/services/identity_manager.py | 81 +--
 letta/services/job_manager.py | 29 +
 letta/services/message_manager.py | 111 ++++
 letta/services/sandbox_config_manager.py | 36 ++
 letta/services/step_manager.py | 146 +++++
 letta/services/telemetry_manager.py | 58 ++
 .../tool_executor/tool_execution_manager.py | 54 +-
 .../tool_executor/tool_execution_sandbox.py | 47 ++
 letta/services/tool_executor/tool_executor.py | 243 +++++++-
 letta/services/tool_manager.py | 161 ++++-
 letta/services/tool_sandbox/e2b_sandbox.py | 68 ++-
 letta/settings.py | 12 +-
 letta/tracing.py | 10 +-
 poetry.lock | 25 +-
 pyproject.toml | 5 +-
 .../gemini-2.5-pro-vertex.json | 2 +-
 .../together-qwen-2.5-72b-instruct.json | 7 +
 tests/conftest.py | 18 +
 tests/integration_test_batch_api_cron_jobs.py | 147 +----
 tests/integration_test_builtin_tools.py | 206 +++++++
 tests/integration_test_composio.py | 11 +-
 tests/integration_test_multi_agent.py | 343 +++++++----
 tests/integration_test_send_message.py | 32 +
 tests/integration_test_sleeptime_agent.py | 15 +-
 tests/integration_test_voice_agent.py | 128 ++--
 tests/test_agent_serialization.py | 82 ++-
 tests/test_letta_agent_batch.py | 119 ++--
 tests/test_managers.py | 528 ++++++++--------
 tests/test_multi_agent.py | 3 +-
 tests/test_provider_trace.py | 205 +++++++
 tests/test_providers.py | 166 +++--
 tests/test_server.py | 24 +-
 tests/utils.py | 35 +-
 83 files changed, 4774 insertions(+), 1734 deletions(-)
 create mode 100644 alembic/versions/6c53224a7a58_add_provider_category_to_steps.py
 create mode 100644 alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py
 create mode 100644 letta/functions/function_sets/builtin.py
 create mode 100644 letta/orm/provider_trace.py
 create mode 100644 letta/schemas/provider_trace.py
 create mode 100644 letta/server/rest_api/routers/v1/telemetry.py
 create mode 100644 letta/server/rest_api/streaming_response.py
 create mode 100644 letta/services/helpers/noop_helper.py
 create mode 100644 letta/services/telemetry_manager.py
 create mode 100644 tests/configs/llm_model_configs/together-qwen-2.5-72b-instruct.json
 create mode 100644 tests/integration_test_builtin_tools.py
 create mode 100644 tests/test_provider_trace.py

diff --git a/alembic/versions/6c53224a7a58_add_provider_category_to_steps.py b/alembic/versions/6c53224a7a58_add_provider_category_to_steps.py
new file mode 100644
index 000000000..891f427f4
--- /dev/null
+++ b/alembic/versions/6c53224a7a58_add_provider_category_to_steps.py
@@ -0,0 +1,31 @@
+"""add provider category to steps
+
+Revision ID: 6c53224a7a58
+Revises: cc8dc340836d
+Create Date: 2025-05-21 10:09:43.761669
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "6c53224a7a58"
+down_revision: Union[str, None] = "cc8dc340836d"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("steps", sa.Column("provider_category", sa.String(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("steps", "provider_category")
+    # ### end Alembic commands ###
diff --git a/alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py b/alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py
new file mode 100644
index 000000000..7ce2c0dcc
--- /dev/null
+++ b/alembic/versions/cc8dc340836d_add_support_for_request_and_response_.py
@@ -0,0 +1,50 @@
+"""add support for request and response jsons from llm providers
+
+Revision ID: cc8dc340836d
+Revises: 220856bbf43b
+Create Date: 2025-05-19 14:25:41.999676
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "cc8dc340836d"
+down_revision: Union[str, None] = "220856bbf43b"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "provider_traces",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("request_json", sa.JSON(), nullable=False),
+        sa.Column("response_json", sa.JSON(), nullable=False),
+        sa.Column("step_id", sa.String(), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
+        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
+        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
+        sa.Column("_created_by_id", sa.String(), nullable=True),
+        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
+        sa.Column("organization_id", sa.String(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["organization_id"],
+            ["organizations.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index("ix_step_id", "provider_traces", ["step_id"], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index("ix_step_id", table_name="provider_traces")
+    op.drop_table("provider_traces")
+    # ### end Alembic commands ###
diff --git a/letta/__init__.py b/letta/__init__.py
index e52fd5c07..772a17a9f 100644
--- a/letta/__init__.py
+++ b/letta/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.7.20"
+__version__ = "0.7.21"
 
 # import clients
 from letta.client.client import LocalClient, RESTClient, create_client
diff --git a/letta/agent.py b/letta/agent.py
index 8ca22f319..ded0e99d8 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -1,4 +1,6 @@
+import asyncio
 import json
+import os
 import time
 import traceback
 import warnings
@@ -7,6 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 from openai.types.beta.function_tool import FunctionTool as OpenAITool
 
+from letta.agents.helpers import generate_step_id
 from letta.constants import (
     CLI_WARNING_PREFIX,
     COMPOSIO_ENTITY_ENV_VAR_KEY,
@@ -16,6 +19,7 @@ from letta.constants import (
     LETTA_CORE_TOOL_MODULE_NAME,
     LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
     LLM_MAX_TOKENS,
+    READ_ONLY_BLOCK_EDIT_ERROR,
     REQ_HEARTBEAT_MESSAGE,
     SEND_MESSAGE_TOOL_NAME,
 )
@@ -41,7 +45,7 @@ from letta.orm.enums import ToolType
 from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
 from letta.schemas.block import BlockUpdate
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import MessageRole
+from letta.schemas.enums import MessageRole, ProviderType
 from letta.schemas.letta_message_content import TextContent
 from letta.schemas.memory import ContextWindowOverview, Memory
 from letta.schemas.message import Message, MessageCreate, ToolReturn
@@ -61,9 +65,10 @@ from letta.services.message_manager import MessageManager
 from letta.services.passage_manager import PassageManager
 from letta.services.provider_manager import ProviderManager
 from letta.services.step_manager import StepManager
+from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
 from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
 from letta.services.tool_manager import ToolManager
-from letta.settings import summarizer_settings
+from letta.settings import settings, summarizer_settings
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
 from letta.tracing import log_event,
trace_method @@ -141,6 +146,7 @@ class Agent(BaseAgent): self.agent_manager = AgentManager() self.job_manager = JobManager() self.step_manager = StepManager() + self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager() # State needed for heartbeat pausing @@ -298,6 +304,7 @@ class Agent(BaseAgent): step_count: Optional[int] = None, last_function_failed: bool = False, put_inner_thoughts_first: bool = True, + step_id: Optional[str] = None, ) -> ChatCompletionResponse | None: """Get response from LLM API with robust retry mechanism.""" log_telemetry(self.logger, "_get_ai_reply start") @@ -347,8 +354,9 @@ class Agent(BaseAgent): messages=message_sequence, llm_config=self.agent_state.llm_config, tools=allowed_functions, - stream=stream, force_tool_call=force_tool_call, + telemetry_manager=self.telemetry_manager, + step_id=step_id, ) else: # Fallback to existing flow @@ -365,6 +373,9 @@ class Agent(BaseAgent): stream_interface=self.interface, put_inner_thoughts_first=put_inner_thoughts_first, name=self.agent_state.name, + telemetry_manager=self.telemetry_manager, + step_id=step_id, + actor=self.user, ) log_telemetry(self.logger, "_get_ai_reply create finish") @@ -840,6 +851,9 @@ class Agent(BaseAgent): # Extract job_id from metadata if present job_id = metadata.get("job_id") if metadata else None + # Declare step_id for the given step to be used as the step is processing. + step_id = generate_step_id() + # Step 0: update core memory # only pulling latest block data if shared memory is being used current_persisted_memory = Memory( @@ -870,6 +884,7 @@ class Agent(BaseAgent): step_count=step_count, last_function_failed=last_function_failed, put_inner_thoughts_first=put_inner_thoughts_first, + step_id=step_id, ) if not response: # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early @@ -944,6 +959,7 @@ class Agent(BaseAgent): actor=self.user, agent_id=self.agent_state.id, provider_name=self.agent_state.llm_config.model_endpoint_type, + provider_category=self.agent_state.llm_config.provider_category or "base", model=self.agent_state.llm_config.model, model_endpoint=self.agent_state.llm_config.model_endpoint, context_window_limit=self.agent_state.llm_config.context_window, @@ -953,6 +969,7 @@ class Agent(BaseAgent): actor=self.user, ), job_id=job_id, + step_id=step_id, ) for message in all_new_messages: message.step_id = step.id @@ -1255,6 +1272,276 @@ class Agent(BaseAgent): functions_definitions=available_functions_definitions, ) + async def get_context_window_async(self) -> ContextWindowOverview: + if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION": + return await self.get_context_window_from_anthropic_async() + return await self.get_context_window_from_tiktoken_async() + + async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview: + """Get the context window of the agent""" + # Grab the in-context messages + # conversion of messages to OpenAI dict format, which is passed to the token counter + (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( + self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user), + self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + ) + in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + + # Extract system, memory and external summary + if ( + 
len(in_context_messages) > 0 + and in_context_messages[0].role == MessageRole.system + and in_context_messages[0].content + and len(in_context_messages[0].content) == 1 + and isinstance(in_context_messages[0].content[0], TextContent) + ): + system_message = in_context_messages[0].content[0].text + + external_memory_marker_pos = system_message.find("###") + core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) + if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: + system_prompt = system_message[:external_memory_marker_pos].strip() + external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() + core_memory = system_message[core_memory_marker_pos:].strip() + else: + # if no markers found, put everything in system message + system_prompt = system_message + external_memory_summary = "" + core_memory = "" + else: + # if no system message, fall back on agent's system prompt + system_prompt = self.agent_state.system + external_memory_summary = "" + core_memory = "" + + num_tokens_system = count_tokens(system_prompt) + num_tokens_core_memory = count_tokens(core_memory) + num_tokens_external_memory_summary = count_tokens(external_memory_summary) + + # Check if there's a summary message in the message queue + if ( + len(in_context_messages) > 1 + and in_context_messages[1].role == MessageRole.user + and in_context_messages[1].content + and len(in_context_messages[1].content) == 1 + and isinstance(in_context_messages[1].content[0], TextContent) + # TODO remove hardcoding + and "The following is a summary of the previous " in in_context_messages[1].content[0].text + ): + # Summary message exists + text_content = in_context_messages[1].content[0].text + assert text_content is not None + summary_memory = text_content + num_tokens_summary_memory = count_tokens(text_content) + # with a summary message, the real messages start at index 2 + num_tokens_messages = ( + num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model) + if len(in_context_messages_openai) > 2 + else 0 + ) + + else: + summary_memory = None + num_tokens_summary_memory = 0 + # with no summary message, the real messages start at index 1 + num_tokens_messages = ( + num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model) + if len(in_context_messages_openai) > 1 + else 0 + ) + + # tokens taken up by function definitions + agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] + if agent_state_tool_jsons: + available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons] + num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model) + else: + available_functions_definitions = [] + num_tokens_available_functions_definitions = 0 + + num_tokens_used_total = ( + num_tokens_system # system prompt + + num_tokens_available_functions_definitions # function definitions + + num_tokens_core_memory # core memory + + num_tokens_external_memory_summary # metadata (statistics) about recall/archival + + num_tokens_summary_memory # summary of ongoing conversation + + num_tokens_messages # tokens taken by messages + ) + assert isinstance(num_tokens_used_total, int) + + return ContextWindowOverview( + # context window breakdown (in messages) + num_messages=len(in_context_messages), + num_archival_memory=passage_manager_size, + num_recall_memory=message_manager_size, + 
num_tokens_external_memory_summary=num_tokens_external_memory_summary, + external_memory_summary=external_memory_summary, + # top-level information + context_window_size_max=self.agent_state.llm_config.context_window, + context_window_size_current=num_tokens_used_total, + # context window breakdown (in tokens) + num_tokens_system=num_tokens_system, + system_prompt=system_prompt, + num_tokens_core_memory=num_tokens_core_memory, + core_memory=core_memory, + num_tokens_summary_memory=num_tokens_summary_memory, + summary_memory=summary_memory, + num_tokens_messages=num_tokens_messages, + messages=in_context_messages, + # related to functions + num_tokens_functions_definitions=num_tokens_available_functions_definitions, + functions_definitions=available_functions_definitions, + ) + + async def get_context_window_from_anthropic_async(self) -> ContextWindowOverview: + """Get the context window of the agent""" + anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=self.user) + model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None + + # Grab the in-context messages + # conversion of messages to anthropic dict format, which is passed to the token counter + (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( + self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user), + self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + ) + in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages] + + # Extract system, memory and external summary + if ( + len(in_context_messages) > 0 + and in_context_messages[0].role == MessageRole.system + and in_context_messages[0].content + and len(in_context_messages[0].content) == 1 + and isinstance(in_context_messages[0].content[0], TextContent) + ): + system_message = in_context_messages[0].content[0].text + + external_memory_marker_pos = system_message.find("###") + core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) + if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: + system_prompt = system_message[:external_memory_marker_pos].strip() + external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() + core_memory = system_message[core_memory_marker_pos:].strip() + else: + # if no markers found, put everything in system message + system_prompt = system_message + external_memory_summary = None + core_memory = None + else: + # if no system message, fall back on agent's system prompt + system_prompt = self.agent_state.system + external_memory_summary = None + core_memory = None + + num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}]) + num_tokens_core_memory_coroutine = ( + anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}]) + if core_memory + else asyncio.sleep(0, result=0) + ) + num_tokens_external_memory_summary_coroutine = ( + anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}]) + if external_memory_summary + else asyncio.sleep(0, result=0) + ) + + # Check if there's a summary message in the message queue + if ( + len(in_context_messages) > 1 + and in_context_messages[1].role == MessageRole.user + and 
in_context_messages[1].content + and len(in_context_messages[1].content) == 1 + and isinstance(in_context_messages[1].content[0], TextContent) + # TODO remove hardcoding + and "The following is a summary of the previous " in in_context_messages[1].content[0].text + ): + # Summary message exists + text_content = in_context_messages[1].content[0].text + assert text_content is not None + summary_memory = text_content + num_tokens_summary_memory_coroutine = anthropic_client.count_tokens( + model=model, messages=[{"role": "user", "content": summary_memory}] + ) + # with a summary message, the real messages start at index 2 + num_tokens_messages_coroutine = ( + anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:]) + if len(in_context_messages_anthropic) > 2 + else asyncio.sleep(0, result=0) + ) + + else: + summary_memory = None + num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0) + # with no summary message, the real messages start at index 1 + num_tokens_messages_coroutine = ( + anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:]) + if len(in_context_messages_anthropic) > 1 + else asyncio.sleep(0, result=0) + ) + + # tokens taken up by function definitions + if self.agent_state.tools and len(self.agent_state.tools) > 0: + available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in self.agent_state.tools] + num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens( + model=model, + tools=available_functions_definitions, + ) + else: + available_functions_definitions = [] + num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0) + + ( + num_tokens_system, + num_tokens_core_memory, + num_tokens_external_memory_summary, + num_tokens_summary_memory, + num_tokens_messages, + num_tokens_available_functions_definitions, + ) = await asyncio.gather( + num_tokens_system_coroutine, + num_tokens_core_memory_coroutine, + num_tokens_external_memory_summary_coroutine, + num_tokens_summary_memory_coroutine, + num_tokens_messages_coroutine, + num_tokens_available_functions_definitions_coroutine, + ) + + num_tokens_used_total = ( + num_tokens_system # system prompt + + num_tokens_available_functions_definitions # function definitions + + num_tokens_core_memory # core memory + + num_tokens_external_memory_summary # metadata (statistics) about recall/archival + + num_tokens_summary_memory # summary of ongoing conversation + + num_tokens_messages # tokens taken by messages + ) + assert isinstance(num_tokens_used_total, int) + + return ContextWindowOverview( + # context window breakdown (in messages) + num_messages=len(in_context_messages), + num_archival_memory=passage_manager_size, + num_recall_memory=message_manager_size, + num_tokens_external_memory_summary=num_tokens_external_memory_summary, + external_memory_summary=external_memory_summary, + # top-level information + context_window_size_max=self.agent_state.llm_config.context_window, + context_window_size_current=num_tokens_used_total, + # context window breakdown (in tokens) + num_tokens_system=num_tokens_system, + system_prompt=system_prompt, + num_tokens_core_memory=num_tokens_core_memory, + core_memory=core_memory, + num_tokens_summary_memory=num_tokens_summary_memory, + summary_memory=summary_memory, + num_tokens_messages=num_tokens_messages, + messages=in_context_messages, + # related to functions + num_tokens_functions_definitions=num_tokens_available_functions_definitions, + 
functions_definitions=available_functions_definitions, + ) + def count_tokens(self) -> int: """Count the tokens in the current context window""" context_window_breakdown = self.get_context_window() diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 018d6300f..a349366dc 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -72,61 +72,6 @@ class BaseAgent(ABC): return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages] - def _rebuild_memory( - self, - in_context_messages: List[Message], - agent_state: AgentState, - num_messages: int | None = None, # storing these calculations is specific to the voice agent - num_archival_memories: int | None = None, - ) -> List[Message]: - try: - # Refresh Memory - # TODO: This only happens for the summary block (voice?) - # [DB Call] loading blocks (modifies: agent_state.memory.blocks) - self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) - - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages - - memory_edit_timestamp = get_utc_time() - - # [DB Call] size of messages and archival memories - num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) - - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=num_messages, - archival_memory_size=num_archival_memories, - ) - - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - # [DB Call] Update Messages - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor - ) - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] - - else: - return in_context_messages - except: - logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") - raise - async def _rebuild_memory_async( self, in_context_messages: List[Message], diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 5578d1fb6..3a525e7a3 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -1,3 +1,4 @@ +import uuid import xml.etree.ElementTree as ET from typing import List, Tuple @@ -150,3 +151,7 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]: context = sum_el.text or "" return messages, context + + +def generate_step_id(): + return f"step-{uuid.uuid4()}" diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 78bc5c629..4afa51857 100644 --- a/letta/agents/letta_agent.py +++ 
b/letta/agents/letta_agent.py @@ -8,8 +8,9 @@ from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from letta.agents.base_agent import BaseAgent -from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async +from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id from letta.helpers import ToolRulesSolver +from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.helpers.tool_execution_helper import enable_strict_mode from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface @@ -24,7 +25,8 @@ from letta.schemas.letta_message import AssistantMessage from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate -from letta.schemas.openai.chat_completion_response import ToolCall +from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics +from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.utils import create_letta_messages_from_llm_response @@ -32,10 +34,11 @@ from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.step_manager import NoopStepManager, StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager -from letta.settings import settings from letta.system import package_function_response -from letta.tracing import log_event, trace_method +from letta.tracing import log_event, trace_method, tracer logger = get_logger(__name__) @@ -50,6 +53,8 @@ class LettaAgent(BaseAgent): block_manager: BlockManager, passage_manager: PassageManager, actor: User, + step_manager: StepManager = NoopStepManager(), + telemetry_manager: TelemetryManager = NoopTelemetryManager(), ): super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor) @@ -57,6 +62,8 @@ class LettaAgent(BaseAgent): # Summarizer settings self.block_manager = block_manager self.passage_manager = passage_manager + self.step_manager = step_manager + self.telemetry_manager = telemetry_manager self.response_messages: List[Message] = [] self.last_function_response = None @@ -67,17 +74,19 @@ class LettaAgent(BaseAgent): @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse: - agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) - current_in_context_messages, new_in_context_messages, usage = await self._step( - agent_state=agent_state, input_messages=input_messages, max_steps=max_steps + agent_state = await self.agent_manager.get_agent_by_id_async( + agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor ) + _, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, 
max_steps=max_steps) return _create_letta_response( new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage ) - async def _step( - self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10 - ) -> Tuple[List[Message], List[Message], CompletionUsage]: + @trace_method + async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True): + agent_state = await self.agent_manager.get_agent_by_id_async( + agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor + ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) @@ -89,23 +98,81 @@ class LettaAgent(BaseAgent): ) usage = LettaUsageStatistics() for _ in range(max_steps): - response = await self._get_ai_reply( + step_id = generate_step_id() + + in_context_messages = await self._rebuild_memory_async( + current_in_context_messages + new_in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + ) + log_event("agent.stream_no_tokens.messages.refreshed") # [1^] + + request_data = await self._create_llm_request_data_async( llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, + in_context_messages=in_context_messages, agent_state=agent_state, tool_rules_solver=tool_rules_solver, - stream=False, - # TODO: also pass in reasoning content + # TODO: pass in reasoning content ) + log_event("agent.stream_no_tokens.llm_request.created") # [2^] + try: + response_data = await llm_client.request_async(request_data, agent_state.llm_config) + except Exception as e: + raise llm_client.handle_llm_error(e) + log_event("agent.stream_no_tokens.llm_response.received") # [3^] + + response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) + + # update usage + # TODO: add run_id + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + + if not response.choices[0].message.tool_calls: + # TODO: make into a real error + raise ValueError("No tool calls found in response, model must make a tool call") tool_call = response.choices[0].message.tool_calls[0] - reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons + if response.choices[0].message.reasoning_content: + reasoning = [ + ReasoningContent( + reasoning=response.choices[0].message.reasoning_content, + is_native=True, + signature=response.choices[0].message.reasoning_content_signature, + ) + ] + else: + reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons persisted_messages, should_continue = await self._handle_ai_response( - tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning + tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + log_event("agent.stream_no_tokens.llm_response.processed") # [4^] + + # Log LLM Trace + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + 
request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + + # stream step + # TODO: improve TTFT + filter_user_messages = [m for m in persisted_messages if m.role != "user"] + letta_messages = Message.to_letta_messages_from_list( + filter_user_messages, use_assistant_message=use_assistant_message, reverse=False + ) + for message in letta_messages: + yield f"data: {message.model_dump_json()}\n\n" # update usage # TODO: add run_id @@ -122,17 +189,125 @@ class LettaAgent(BaseAgent): message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) + # Return back usage + yield f"data: {usage.model_dump_json()}\n\n" + + async def _step( + self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10 + ) -> Tuple[List[Message], List[Message], CompletionUsage]: + """ + Carries out an invocation of the agent loop. In each step, the agent + 1. Rebuilds its memory + 2. Generates a request for the LLM + 3. Fetches a response from the LLM + 4. Processes the response + """ + current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( + input_messages, agent_state, self.message_manager, self.actor + ) + tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) + llm_client = LLMClient.create( + provider_type=agent_state.llm_config.model_endpoint_type, + put_inner_thoughts_first=True, + actor=self.actor, + ) + usage = LettaUsageStatistics() + for _ in range(max_steps): + step_id = generate_step_id() + + in_context_messages = await self._rebuild_memory_async( + current_in_context_messages + new_in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + ) + log_event("agent.step.messages.refreshed") # [1^] + + request_data = await self._create_llm_request_data_async( + llm_client=llm_client, + in_context_messages=in_context_messages, + agent_state=agent_state, + tool_rules_solver=tool_rules_solver, + # TODO: pass in reasoning content + ) + log_event("agent.step.llm_request.created") # [2^] + + try: + response_data = await llm_client.request_async(request_data, agent_state.llm_config) + except Exception as e: + raise llm_client.handle_llm_error(e) + log_event("agent.step.llm_response.received") # [3^] + + response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config) + + # TODO: add run_id + usage.step_count += 1 + usage.completion_tokens += response.usage.completion_tokens + usage.prompt_tokens += response.usage.prompt_tokens + usage.total_tokens += response.usage.total_tokens + + if not response.choices[0].message.tool_calls: + # TODO: make into a real error + raise ValueError("No tool calls found in response, model must make a tool call") + tool_call = response.choices[0].message.tool_calls[0] + if response.choices[0].message.reasoning_content: + reasoning = [ + ReasoningContent( + reasoning=response.choices[0].message.reasoning_content, + is_native=True, + signature=response.choices[0].message.reasoning_content_signature, + ) + ] + else: + reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons + + persisted_messages, should_continue = await self._handle_ai_response( + tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, 
step_id=step_id + ) + self.response_messages.extend(persisted_messages) + new_in_context_messages.extend(persisted_messages) + log_event("agent.step.llm_response.processed") # [4^] + + # Log LLM Trace + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json=response_data, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + + if not should_continue: + break + + # Extend the in context message ids + if not agent_state.message_buffer_autoclear: + message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] + self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) + return current_in_context_messages, new_in_context_messages, usage @trace_method async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False + self, + input_messages: List[MessageCreate], + max_steps: int = 10, + use_assistant_message: bool = True, + request_start_timestamp_ns: Optional[int] = None, ) -> AsyncGenerator[str, None]: """ - Main streaming loop that yields partial tokens. - Whenever we detect a tool call, we yield from _handle_ai_response as well. + Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens. + Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent + 1. Rebuilds its memory + 2. Generates a request for the LLM + 3. Fetches a response from the LLM + 4. Processes the response """ - agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) + agent_state = await self.agent_manager.get_agent_by_id_async( + agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor + ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) @@ -145,13 +320,29 @@ class LettaAgent(BaseAgent): usage = LettaUsageStatistics() for _ in range(max_steps): - stream = await self._get_ai_reply( + step_id = generate_step_id() + in_context_messages = await self._rebuild_memory_async( + current_in_context_messages + new_in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + ) + log_event("agent.step.messages.refreshed") # [1^] + + request_data = await self._create_llm_request_data_async( llm_client=llm_client, - in_context_messages=current_in_context_messages + new_in_context_messages, + in_context_messages=in_context_messages, agent_state=agent_state, tool_rules_solver=tool_rules_solver, - stream=True, ) + log_event("agent.stream.llm_request.created") # [2^] + + try: + stream = await llm_client.stream_async(request_data, agent_state.llm_config) + except Exception as e: + raise llm_client.handle_llm_error(e) + log_event("agent.stream.llm_response.received") # [3^] + # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED if agent_state.llm_config.model_endpoint_type == "anthropic": @@ -164,7 +355,23 @@ class LettaAgent(BaseAgent): use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, ) + else: + raise ValueError(f"Streaming not supported for {agent_state.llm_config}") + + first_chunk, 
ttft_span = True, None + if request_start_timestamp_ns is not None: + ttft_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns) + ttft_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None}) + async for chunk in interface.process(stream): + # Measure time to first token + if first_chunk and ttft_span is not None: + now = get_utc_timestamp_ns() + ttft_ns = now - request_start_timestamp_ns + ttft_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000}) + ttft_span.end() + first_chunk = False + yield f"data: {chunk.model_dump_json()}\n\n" # update usage @@ -180,13 +387,46 @@ class LettaAgent(BaseAgent): tool_call, agent_state, tool_rules_solver, + UsageStatistics( + completion_tokens=interface.output_tokens, + prompt_tokens=interface.input_tokens, + total_tokens=interface.input_tokens + interface.output_tokens, + ), reasoning_content=reasoning_content, pre_computed_assistant_message_id=interface.letta_assistant_message_id, pre_computed_tool_message_id=interface.letta_tool_message_id, + step_id=step_id, ) self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + # TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream + # log_event("agent.stream.llm_response.processed") # [4^] + + # Log LLM Trace + # TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema. + await self.telemetry_manager.create_provider_trace_async( + actor=self.actor, + provider_trace_create=ProviderTraceCreate( + request_json=request_data, + response_json={ + "content": { + "tool_call": tool_call.model_dump_json(), + "reasoning": [content.model_dump_json() for content in reasoning_content], + }, + "id": interface.message_id, + "model": interface.model, + "role": "assistant", + # "stop_reason": "", + # "stop_sequence": None, + "type": "message", + "usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens}, + }, + step_id=step_id, + organization_id=self.actor.organization_id, + ), + ) + if not use_assistant_message or should_continue: tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0] yield f"data: {tool_return.model_dump_json()}\n\n" @@ -209,28 +449,20 @@ class LettaAgent(BaseAgent): yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n" @trace_method - # When raising an error this doesn't show up - async def _get_ai_reply( + async def _create_llm_request_data_async( self, llm_client: LLMClientBase, in_context_messages: List[Message], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, - stream: bool, ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: - if settings.experimental_enable_async_db_engine: - self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)) - self.num_archival_memories = self.num_archival_memories or ( - await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) - ) - in_context_messages = await self._rebuild_memory_async( - in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories - ) - else: - if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex": - logger.info("Skipping memory rebuild") - else: - in_context_messages = 
self._rebuild_memory(in_context_messages, agent_state) + self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)) + self.num_archival_memories = self.num_archival_memories or ( + await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) + ) + in_context_messages = await self._rebuild_memory_async( + in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories + ) tools = [ t @@ -243,8 +475,8 @@ class LettaAgent(BaseAgent): ToolType.LETTA_MULTI_AGENT_CORE, ToolType.LETTA_SLEEPTIME_CORE, ToolType.LETTA_VOICE_SLEEPTIME_CORE, + ToolType.LETTA_BUILTIN, } - or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags") or (t.tool_type == ToolType.EXTERNAL_COMPOSIO) ] @@ -264,15 +496,7 @@ class LettaAgent(BaseAgent): allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] - response = await llm_client.send_llm_request_async( - messages=in_context_messages, - llm_config=agent_state.llm_config, - tools=allowed_tools, - force_tool_call=force_tool_call, - stream=stream, - ) - - return response + return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call) @trace_method async def _handle_ai_response( @@ -280,9 +504,11 @@ class LettaAgent(BaseAgent): tool_call: ToolCall, agent_state: AgentState, tool_rules_solver: ToolRulesSolver, + usage: UsageStatistics, reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, pre_computed_tool_message_id: Optional[str] = None, + step_id: str | None = None, ) -> Tuple[List[Message], bool]: """ Now that streaming is done, handle the final AI response. @@ -294,8 +520,11 @@ class LettaAgent(BaseAgent): try: tool_args = json.loads(tool_call_args_str) + assert isinstance(tool_args, dict), "tool_args must be a dict" except json.JSONDecodeError: tool_args = {} + except AssertionError: + tool_args = json.loads(tool_args) # Get request heartbeats and coerce to bool request_heartbeat = tool_args.pop("request_heartbeat", False) @@ -329,7 +558,25 @@ class LettaAgent(BaseAgent): elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name): continue_stepping = True - # 5. Persist to DB + # 5a. Persist Steps to DB + # Following agent loop to persist this before messages + # TODO (cliandy): determine what should match old loop w/provider_id, job_id + # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same. + logged_step = await self.step_manager.log_step_async( + actor=self.actor, + agent_id=agent_state.id, + provider_name=agent_state.llm_config.model_endpoint_type, + provider_category=agent_state.llm_config.provider_category or "base", + model=agent_state.llm_config.model, + model_endpoint=agent_state.llm_config.model_endpoint, + context_window_limit=agent_state.llm_config.context_window, + usage=usage, + provider_id=None, + job_id=None, + step_id=step_id, + ) + + # 5b. 
Persist Messages to DB tool_call_messages = create_letta_messages_from_llm_response( agent_id=agent_state.id, model=agent_state.llm_config.model, @@ -343,6 +590,7 @@ class LettaAgent(BaseAgent): reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, pre_computed_tool_message_id=pre_computed_tool_message_id, + step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops ) persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor) self.last_function_response = function_response @@ -361,20 +609,21 @@ class LettaAgent(BaseAgent): # TODO: This temp. Move this logic and code to executors try: - if target_tool.name == "send_message_to_agents_matching_tags" and target_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE: - log_event(name="start_send_message_to_agents_matching_tags", attributes=tool_args) - results = await self._send_message_to_agents_matching_tags(**tool_args) - log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args) - return json.dumps(results), True - else: - tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor) - # TODO: Integrate sandbox result - log_event(name=f"start_{tool_name}_execution", attributes=tool_args) - tool_execution_result = await tool_execution_manager.execute_tool_async( - function_name=tool_name, function_args=tool_args, tool=target_tool - ) - log_event(name=f"finish_{tool_name}_execution", attributes=tool_args) - return tool_execution_result.func_return, True + tool_execution_manager = ToolExecutionManager( + agent_state=agent_state, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) + # TODO: Integrate sandbox result + log_event(name=f"start_{tool_name}_execution", attributes=tool_args) + tool_execution_result = await tool_execution_manager.execute_tool_async( + function_name=tool_name, function_args=tool_args, tool=target_tool + ) + log_event(name=f"finish_{tool_name}_execution", attributes=tool_args) + return tool_execution_result.func_return, True except Exception as e: return f"Failed to call tool. Error: {e}", False @@ -430,6 +679,7 @@ class LettaAgent(BaseAgent): results = await asyncio.gather(*tasks) return results + @trace_method async def _load_last_function_response_async(self): """Load the last function response from message history""" in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_id, actor=self.actor) diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 46800bccf..e2355ab54 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -145,7 +145,7 @@ class LettaAgentBatch(BaseAgent): agent_mapping = { agent_state.id: agent_state for agent_state in await self.agent_manager.get_agents_by_ids_async( - agent_ids=[request.agent_id for request in batch_requests], actor=self.actor + agent_ids=[request.agent_id for request in batch_requests], include_relationships=["tools", "memory"], actor=self.actor ) } @@ -267,64 +267,121 @@ class LettaAgentBatch(BaseAgent): @trace_method async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext: - # NOTE: We only continue for items with successful results + """ + Collect context for resuming operations from completed batch items. 
+ + Args: + llm_batch_id: The ID of the batch to collect context for + + Returns: + _ResumeContext object containing all necessary data for resumption + """ + # Fetch only completed batch items batch_items = await self.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_id, request_status=JobStatus.completed) - agent_ids = [] - provider_results = {} - request_status_updates: List[RequestStatusUpdateInfo] = [] + # Exit early if no items to process + if not batch_items: + return _ResumeContext( + batch_items=[], + agent_ids=[], + agent_state_map={}, + provider_results={}, + tool_call_name_map={}, + tool_call_args_map={}, + should_continue_map={}, + request_status_updates=[], + ) - for item in batch_items: - aid = item.agent_id - agent_ids.append(aid) - provider_results[aid] = item.batch_request_result.result + # Extract agent IDs and organize items by agent ID + agent_ids = [item.agent_id for item in batch_items] + batch_item_map = {item.agent_id: item for item in batch_items} - agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids, actor=self.actor) + # Collect provider results + provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items} + + # Fetch agent states in a single call + agent_states = await self.agent_manager.get_agents_by_ids_async( + agent_ids=agent_ids, include_relationships=["tools", "memory"], actor=self.actor + ) agent_state_map = {agent.id: agent for agent in agent_states} - name_map, args_map, cont_map = {}, {}, {} - for aid in agent_ids: - # status bookkeeping - pr = provider_results[aid] - status = ( - JobStatus.completed - if isinstance(pr, BetaMessageBatchSucceededResult) - else ( - JobStatus.failed - if isinstance(pr, BetaMessageBatchErroredResult) - else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired - ) - ) - request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) - - # translate provider‑specific response → OpenAI‑style tool call (unchanged) - llm_client = LLMClient.create( - provider_type=item.llm_config.model_endpoint_type, - put_inner_thoughts_first=True, - actor=self.actor, - ) - tool_call = ( - llm_client.convert_response_to_chat_completion( - response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config - ) - .choices[0] - .message.tool_calls[0] - ) - - name, args, cont = self._extract_tool_call_and_decide_continue(tool_call, item.step_state) - name_map[aid], args_map[aid], cont_map[aid] = name, args, cont + # Process each agent's results + tool_call_results = self._process_agent_results( + agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id + ) return _ResumeContext( batch_items=batch_items, agent_ids=agent_ids, agent_state_map=agent_state_map, provider_results=provider_results, - tool_call_name_map=name_map, - tool_call_args_map=args_map, - should_continue_map=cont_map, - request_status_updates=request_status_updates, + tool_call_name_map=tool_call_results.name_map, + tool_call_args_map=tool_call_results.args_map, + should_continue_map=tool_call_results.cont_map, + request_status_updates=tool_call_results.status_updates, ) + def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id): + """ + Process the results for each agent, extracting tool calls and determining continuation status. 
+ + Returns: + A namedtuple containing name_map, args_map, cont_map, and status_updates + """ + from collections import namedtuple + + ToolCallResults = namedtuple("ToolCallResults", ["name_map", "args_map", "cont_map", "status_updates"]) + + name_map, args_map, cont_map = {}, {}, {} + request_status_updates = [] + + for aid in agent_ids: + item = batch_item_map[aid] + result = provider_results[aid] + + # Determine job status based on result type + status = self._determine_job_status(result) + request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) + + # Process tool calls + name, args, cont = self._extract_tool_call_from_result(item, result) + name_map[aid], args_map[aid], cont_map[aid] = name, args, cont + + return ToolCallResults(name_map, args_map, cont_map, request_status_updates) + + def _determine_job_status(self, result): + """Determine job status based on result type""" + if isinstance(result, BetaMessageBatchSucceededResult): + return JobStatus.completed + elif isinstance(result, BetaMessageBatchErroredResult): + return JobStatus.failed + elif isinstance(result, BetaMessageBatchCanceledResult): + return JobStatus.cancelled + else: + return JobStatus.expired + + def _extract_tool_call_from_result(self, item, result): + """Extract tool call information from a result""" + llm_client = LLMClient.create( + provider_type=item.llm_config.model_endpoint_type, + put_inner_thoughts_first=True, + actor=self.actor, + ) + + # If result isn't a successful type, we can't extract a tool call + if not isinstance(result, BetaMessageBatchSucceededResult): + return None, None, False + + tool_call = ( + llm_client.convert_response_to_chat_completion( + response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config + ) + .choices[0] + .message.tool_calls[0] + ) + + return self._extract_tool_call_and_decide_continue(tool_call, item.step_state) + def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None: if updates: self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates) @@ -556,16 +613,6 @@ class LettaAgentBatch(BaseAgent): in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state) return in_context_messages - # TODO: Make this a bullk function - def _rebuild_memory( - self, - in_context_messages: List[Message], - agent_state: AgentState, - num_messages: int | None = None, - num_archival_memories: int | None = None, - ) -> List[Message]: - return super()._rebuild_memory(in_context_messages, agent_state) - # Not used in batch. 
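
The batch-resume refactor above factors status bookkeeping out of `_collect_resume_context` into `_determine_job_status` and a namedtuple-returning `_process_agent_results`. A minimal sketch of that dispatch in isolation — the result classes below are hypothetical stand-ins for the anthropic SDK's `BetaMessageBatch*Result` types, and the `JobStatus` enum is reduced to the variants used here:

from collections import namedtuple
from enum import Enum


class JobStatus(str, Enum):
    # Stand-in for letta.schemas.enums.JobStatus; only the terminal variants.
    completed = "completed"
    failed = "failed"
    cancelled = "cancelled"
    expired = "expired"


class BatchSucceededResult:  # hypothetical stand-in
    pass


class BatchErroredResult:  # hypothetical stand-in
    pass


class BatchCanceledResult:  # hypothetical stand-in
    pass


def determine_job_status(result) -> JobStatus:
    # Mirrors the isinstance chain in _determine_job_status: any result that
    # is not a recognized success/error/cancel variant is treated as expired.
    if isinstance(result, BatchSucceededResult):
        return JobStatus.completed
    if isinstance(result, BatchErroredResult):
        return JobStatus.failed
    if isinstance(result, BatchCanceledResult):
        return JobStatus.cancelled
    return JobStatus.expired


ToolCallResults = namedtuple("ToolCallResults", ["name_map", "args_map", "cont_map", "status_updates"])

provider_results = {"agent-1": BatchSucceededResult(), "agent-2": BatchErroredResult()}
statuses = {aid: determine_job_status(result) for aid, result in provider_results.items()}
results = ToolCallResults(name_map={}, args_map={}, cont_map={}, status_updates=statuses)
assert results.status_updates == {"agent-1": JobStatus.completed, "agent-2": JobStatus.failed}

Falling through to `expired` mirrors the inline conditional this refactor replaces, so unrecognized result types degrade to a terminal status instead of raising.
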
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: raise NotImplementedError diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 1d0ab88c6..5451dc6ce 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -154,7 +154,7 @@ class VoiceAgent(BaseAgent): # TODO: Define max steps here for _ in range(max_steps): # Rebuild memory each loop - in_context_messages = self._rebuild_memory(in_context_messages, agent_state) + in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state) openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True) openai_messages.extend(in_memory_message_history) @@ -292,14 +292,14 @@ class VoiceAgent(BaseAgent): agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor ) - def _rebuild_memory( + async def _rebuild_memory_async( self, in_context_messages: List[Message], agent_state: AgentState, num_messages: int | None = None, num_archival_memories: int | None = None, ) -> List[Message]: - return super()._rebuild_memory( + return await super()._rebuild_memory_async( in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories ) @@ -438,7 +438,7 @@ class VoiceAgent(BaseAgent): if start_date and end_date and start_date > end_date: start_date, end_date = end_date, start_date - archival_results = self.agent_manager.list_passages( + archival_results = await self.agent_manager.list_passages_async( actor=self.actor, agent_id=self.agent_id, query_text=archival_query, @@ -457,7 +457,7 @@ class VoiceAgent(BaseAgent): keyword_results = {} if convo_keyword_queries: for keyword in convo_keyword_queries: - messages = self.message_manager.list_messages_for_agent( + messages = await self.message_manager.list_messages_for_agent_async( agent_id=self.agent_id, actor=self.actor, query_text=keyword, diff --git a/letta/client/client.py b/letta/client/client.py index 802ca451e..90e39400b 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2773,11 +2773,8 @@ class LocalClient(AbstractClient): # humans / personas - def get_block_id(self, name: str, label: str) -> str: - block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True) - if not block: - return None - return block[0].id + def get_block_id(self, name: str, label: str) -> str | None: + return None def create_human(self, name: str, text: str): """ @@ -2812,7 +2809,7 @@ class LocalClient(AbstractClient): Returns: humans (List[Human]): List of human blocks """ - return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True) + return [] def list_personas(self) -> List[Persona]: """ @@ -2821,7 +2818,7 @@ class LocalClient(AbstractClient): Returns: personas (List[Persona]): List of persona blocks """ - return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True) + return [] def update_human(self, human_id: str, text: str): """ @@ -2879,7 +2876,7 @@ class LocalClient(AbstractClient): assert id, f"Human ID must be provided" return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) - def get_persona_id(self, name: str) -> str: + def get_persona_id(self, name: str) -> str | None: """ Get the ID of a persona block template @@ -2889,12 +2886,9 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the persona block 
""" - persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True) - if not persona: - return None - return persona[0].id + return None - def get_human_id(self, name: str) -> str: + def get_human_id(self, name: str) -> str | None: """ Get the ID of a human block template @@ -2904,10 +2898,7 @@ class LocalClient(AbstractClient): Returns: id (str): ID of the human block """ - human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True) - if not human: - return None - return human[0].id + return None def delete_persona(self, id: str): """ @@ -3381,7 +3372,7 @@ class LocalClient(AbstractClient): Returns: blocks (List[Block]): List of blocks """ - return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only) + return [] def create_block( self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False diff --git a/letta/constants.py b/letta/constants.py index 1068c6141..1a13c6685 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -19,6 +19,7 @@ MCP_TOOL_TAG_NAME_PREFIX = "mcp" # full format, mcp:server_name LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base" LETTA_MULTI_AGENT_TOOL_MODULE_NAME = "letta.functions.function_sets.multi_agent" LETTA_VOICE_TOOL_MODULE_NAME = "letta.functions.function_sets.voice" +LETTA_BUILTIN_TOOL_MODULE_NAME = "letta.functions.function_sets.builtin" # String in the error message for when the context window is too large @@ -83,9 +84,19 @@ BASE_VOICE_SLEEPTIME_TOOLS = [ ] # Multi agent tools MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"] + +# Built in tools +BUILTIN_TOOLS = ["run_code", "web_search"] + # Set of all built-in Letta tools LETTA_TOOL_SET = set( - BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS + BASE_TOOLS + + BASE_MEMORY_TOOLS + + MULTI_AGENT_TOOLS + + BASE_SLEEPTIME_TOOLS + + BASE_VOICE_SLEEPTIME_TOOLS + + BASE_VOICE_SLEEPTIME_CHAT_TOOLS + + BUILTIN_TOOLS ) # The name of the tool used to send message to the user @@ -179,6 +190,45 @@ LLM_MAX_TOKENS = { "gpt-3.5-turbo-0613": 4096, # legacy "gpt-3.5-turbo-16k-0613": 16385, # legacy "gpt-3.5-turbo-0301": 4096, # legacy + "gemini-1.0-pro-vision-latest": 12288, + "gemini-pro-vision": 12288, + "gemini-1.5-pro-latest": 2000000, + "gemini-1.5-pro-001": 2000000, + "gemini-1.5-pro-002": 2000000, + "gemini-1.5-pro": 2000000, + "gemini-1.5-flash-latest": 1000000, + "gemini-1.5-flash-001": 1000000, + "gemini-1.5-flash-001-tuning": 16384, + "gemini-1.5-flash": 1000000, + "gemini-1.5-flash-002": 1000000, + "gemini-1.5-flash-8b": 1000000, + "gemini-1.5-flash-8b-001": 1000000, + "gemini-1.5-flash-8b-latest": 1000000, + "gemini-1.5-flash-8b-exp-0827": 1000000, + "gemini-1.5-flash-8b-exp-0924": 1000000, + "gemini-2.5-pro-exp-03-25": 1048576, + "gemini-2.5-pro-preview-03-25": 1048576, + "gemini-2.5-flash-preview-04-17": 1048576, + "gemini-2.5-flash-preview-05-20": 1048576, + "gemini-2.5-flash-preview-04-17-thinking": 1048576, + "gemini-2.5-pro-preview-05-06": 1048576, + "gemini-2.0-flash-exp": 1048576, + "gemini-2.0-flash": 1048576, + "gemini-2.0-flash-001": 1048576, + "gemini-2.0-flash-exp-image-generation": 1048576, + "gemini-2.0-flash-lite-001": 1048576, + "gemini-2.0-flash-lite": 1048576, + 
"gemini-2.0-flash-preview-image-generation": 32768, + "gemini-2.0-flash-lite-preview-02-05": 1048576, + "gemini-2.0-flash-lite-preview": 1048576, + "gemini-2.0-pro-exp": 1048576, + "gemini-2.0-pro-exp-02-05": 1048576, + "gemini-exp-1206": 1048576, + "gemini-2.0-flash-thinking-exp-01-21": 1048576, + "gemini-2.0-flash-thinking-exp": 1048576, + "gemini-2.0-flash-thinking-exp-1219": 1048576, + "gemini-2.5-flash-preview-tts": 32768, + "gemini-2.5-pro-preview-tts": 65536, } # The error message that Letta will receive # MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed." @@ -230,3 +280,7 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5 MAX_FILENAME_LENGTH = 255 RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"} + +WEB_SEARCH_CLIP_CONTENT = False +WEB_SEARCH_INCLUDE_SCORE = False +WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n" diff --git a/letta/functions/function_sets/builtin.py b/letta/functions/function_sets/builtin.py new file mode 100644 index 000000000..c8d695682 --- /dev/null +++ b/letta/functions/function_sets/builtin.py @@ -0,0 +1,27 @@ +from typing import Literal + + +async def web_search(query: str) -> str: + """ + Search the web for information. + Args: + query (str): The query to search the web for. + Returns: + str: The search results. + """ + + raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.") + + +def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str: + """ + Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java. + + Args: + code (str): The code to run. + language (Literal["python", "js", "ts", "r", "java"]): The language of the code. + Returns: + str: The output of the code, the stdout, the stderr, and error traces (if any). + """ + + raise NotImplementedError("This is only available on the latest agent architecture. 
Please contact the Letta team.") diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index e2910e5b0..9cd2cede1 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -190,7 +190,7 @@ class SleeptimeMultiAgentV2(BaseAgent): prior_messages = [] if self.group.sleeptime_agent_frequency: try: - prior_messages = self.message_manager.list_messages_for_agent( + prior_messages = await self.message_manager.list_messages_for_agent_async( agent_id=foreground_agent_id, actor=self.actor, after=last_processed_message_id, diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index d86435388..1a8aa2201 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -1,3 +1,4 @@ +import json from datetime import datetime, timezone from enum import Enum from typing import AsyncGenerator, List, Union @@ -74,6 +75,7 @@ class AnthropicStreamingInterface: # usage trackers self.input_tokens = 0 self.output_tokens = 0 + self.model = None # reasoning object trackers self.reasoning_messages = [] @@ -88,7 +90,13 @@ class AnthropicStreamingInterface: def get_tool_call_object(self) -> ToolCall: """Useful for agent loop""" - return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name)) + # hack for tool rules + tool_input = json.loads(self.accumulated_tool_call_args) + if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input: + arguments = str(json.dumps(tool_input["function"]["arguments"], indent=2)) + else: + arguments = self.accumulated_tool_call_args + return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name)) def _check_inner_thoughts_complete(self, combined_args: str) -> bool: """ @@ -311,6 +319,7 @@ class AnthropicStreamingInterface: self.message_id = event.message.id self.input_tokens += event.message.usage.input_tokens self.output_tokens += event.message.usage.output_tokens + self.model = event.message.model elif isinstance(event, BetaRawMessageDeltaEvent): self.output_tokens += event.usage.output_tokens elif isinstance(event, BetaRawMessageStopEvent): diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 168d0521e..eea1b3b2c 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -40,6 +40,9 @@ class OpenAIStreamingInterface: self.letta_assistant_message_id = Message.generate_id() self.letta_tool_message_id = Message.generate_id() + self.message_id = None + self.model = None + # token counters self.input_tokens = 0 self.output_tokens = 0 @@ -69,10 +72,14 @@ class OpenAIStreamingInterface: prev_message_type = None message_index = 0 async for chunk in stream: + if not self.model or not self.message_id: + self.model = chunk.model + self.message_id = chunk.id + # track usage if chunk.usage: - self.input_tokens += len(chunk.usage.prompt_tokens) - self.output_tokens += len(chunk.usage.completion_tokens) + self.input_tokens += chunk.usage.prompt_tokens + self.output_tokens += chunk.usage.completion_tokens if chunk.choices: choice = chunk.choices[0] diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 89329d01d..fadc652d6 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -134,13 +134,13 @@ 
def anthropic_check_valid_api_key(api_key: Union[str, None]) -> None: def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int: - for model_dict in anthropic_get_model_list(url=url, api_key=api_key): + for model_dict in anthropic_get_model_list(api_key=api_key): if model_dict["name"] == model: return model_dict["context_window"] raise ValueError(f"Can't find model '{model}' in Anthropic model list") -def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict: +def anthropic_get_model_list(api_key: Optional[str]) -> dict: """https://docs.anthropic.com/claude/docs/models-overview""" # NOTE: currently there is no GET /models, so we need to hardcode @@ -159,6 +159,25 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict: return models_json["data"] +async def anthropic_get_model_list_async(api_key: Optional[str]) -> dict: + """https://docs.anthropic.com/claude/docs/models-overview""" + + # NOTE: currently there is no GET /models, so we need to hardcode + # return MODEL_LIST + + if api_key: + anthropic_client = anthropic.AsyncAnthropic(api_key=api_key) + elif model_settings.anthropic_api_key: + anthropic_client = anthropic.AsyncAnthropic() + else: + raise ValueError("No API key provided") + + models = await anthropic_client.models.list() + models_json = models.model_dump() + assert "data" in models_json, f"Anthropic model query response missing 'data' field: {models_json}" + return models_json["data"] + + def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]: """See: https://docs.anthropic.com/claude/docs/tool-use diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index f26d58eba..f7509b037 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -35,6 +35,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.services.provider_manager import ProviderManager +from letta.settings import model_settings from letta.tracing import trace_method DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence." 
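Worth noting on the `OpenAIStreamingInterface` hunk a few hunks up: `usage.prompt_tokens` and `usage.completion_tokens` are plain ints on the chunk's usage object, so the removed `len(...)` calls raised `TypeError` as soon as a chunk carried usage. A minimal sketch of the corrected accumulation, with stand-in dataclasses rather than the real openai SDK types:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Usage:  # stand-in for the SDK usage object
        prompt_tokens: int
        completion_tokens: int


    @dataclass
    class Chunk:  # stand-in for ChatCompletionChunk
        id: str
        model: str
        usage: Optional[Usage] = None


    def accumulate(chunks):
        model = message_id = None
        input_tokens = output_tokens = 0
        for chunk in chunks:
            if not model or not message_id:
                model, message_id = chunk.model, chunk.id
            if chunk.usage:  # typically only the final chunk carries usage
                input_tokens += chunk.usage.prompt_tokens  # ints, so += rather than len()
                output_tokens += chunk.usage.completion_tokens
        return model, message_id, input_tokens, output_tokens


    stream = [Chunk("chatcmpl-1", "gpt-4o-mini"), Chunk("chatcmpl-1", "gpt-4o-mini", Usage(12, 34))]
    assert accumulate(stream) == ("gpt-4o-mini", "chatcmpl-1", 12, 34)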
@@ -120,8 +121,16 @@ class AnthropicClient(LLMClientBase):
             override_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)

         if async_client:
-            return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
-        return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
+            return (
+                anthropic.AsyncAnthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
+                if override_key
+                else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
+            )
+        return (
+            anthropic.Anthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
+            if override_key
+            else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
+        )

     @trace_method
     def build_request_data(
@@ -239,6 +248,24 @@ class AnthropicClient(LLMClientBase):

         return data

+    async def count_tokens(self, messages: Optional[List[dict]] = None, model: Optional[str] = None, tools: Optional[List[Tool]] = None) -> int:
+        client = anthropic.AsyncAnthropic()
+        if not messages:
+            messages = None
+        if tools and len(tools) > 0:
+            anthropic_tools = convert_tools_to_anthropic_format(tools)
+        else:
+            anthropic_tools = None
+        result = await client.beta.messages.count_tokens(
+            model=model or "claude-3-7-sonnet-20250219",
+            messages=messages or [{"role": "user", "content": "hi"}],
+            tools=anthropic_tools or [],
+        )
+        token_count = result.input_tokens
+        if messages is None:
+            token_count -= 8
+        return token_count
+
     def handle_llm_error(self, e: Exception) -> Exception:
         if isinstance(e, anthropic.APIConnectionError):
             logger.warning(f"[Anthropic] API connection error: {e.__cause__}")
@@ -369,11 +396,11 @@ class AnthropicClient(LLMClientBase):
                 content = strip_xml_tags(string=content_part.text, tag="thinking")
             if content_part.type == "tool_use":
                 # hack for tool rules
-                input = json.loads(json.dumps(content_part.input))
-                if "id" in input and input["id"].startswith("toolu_") and "function" in input:
-                    arguments = str(input["function"]["arguments"])
+                tool_input = json.loads(json.dumps(content_part.input))
+                if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
+                    arguments = str(tool_input["function"]["arguments"])
                 else:
-                    arguments = json.dumps(content_part.input, indent=2)
+                    arguments = json.dumps(tool_input, indent=2)
                 tool_calls = [
                     ToolCall(
                         id=content_part.id,
diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py
index f056a64bd..47671398c 100644
--- a/letta/llm_api/google_ai_client.py
+++ b/letta/llm_api/google_ai_client.py
@@ -1,422 +1,21 @@
-import json
-import uuid
 from typing import List, Optional, Tuple

-import requests
+import httpx
 from google import genai
-from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig

-from letta.constants import NON_USER_MSG_PREFIX
 from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
-from letta.helpers.datetime_helpers import get_utc_time_int
-from letta.helpers.json_helpers import json_dumps
 from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
-from letta.llm_api.helpers import make_post_request
-from letta.llm_api.llm_client_base import LLMClientBase
-from letta.local_llm.json_parser import clean_json_string_extra_backslash
-from letta.local_llm.utils import count_tokens
+from letta.llm_api.google_vertex_client import GoogleVertexClient
 from letta.log import get_logger
-from letta.schemas.enums import ProviderCategory
-from
letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message as PydanticMessage -from letta.schemas.openai.chat_completion_request import Tool -from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics from letta.settings import model_settings -from letta.utils import get_tool_call_id logger = get_logger(__name__) -class GoogleAIClient(LLMClientBase): +class GoogleAIClient(GoogleVertexClient): - def request(self, request_data: dict, llm_config: LLMConfig) -> dict: - """ - Performs underlying request to llm and returns raw response. - """ - api_key = None - if llm_config.provider_category == ProviderCategory.byok: - from letta.services.provider_manager import ProviderManager - - api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) - - if not api_key: - api_key = model_settings.gemini_api_key - - # print("[google_ai request]", json.dumps(request_data, indent=2)) - url, headers = get_gemini_endpoint_and_headers( - base_url=str(llm_config.model_endpoint), - model=llm_config.model, - api_key=str(api_key), - key_in_header=True, - generate_content=True, - ) - return make_post_request(url, headers, request_data) - - def build_request_data( - self, - messages: List[PydanticMessage], - llm_config: LLMConfig, - tools: List[dict], - force_tool_call: Optional[str] = None, - ) -> dict: - """ - Constructs a request object in the expected data format for this client. - """ - if tools: - tools = [{"type": "function", "function": f} for f in tools] - tool_objs = [Tool(**t) for t in tools] - tool_names = [t.function.name for t in tool_objs] - # Convert to the exact payload style Google expects - tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config) - else: - tool_names = [] - - contents = self.add_dummy_model_messages( - [m.to_google_ai_dict() for m in messages], - ) - - request_data = { - "contents": contents, - "tools": tools, - "generation_config": { - "temperature": llm_config.temperature, - "max_output_tokens": llm_config.max_tokens, - }, - } - - # write tool config - tool_config = ToolConfig( - function_calling_config=FunctionCallingConfig( - # ANY mode forces the model to predict only function calls - mode=FunctionCallingConfigMode.ANY, - # Provide the list of tools (though empty should also work, it seems not to) - allowed_function_names=tool_names, - ) - ) - request_data["tool_config"] = tool_config.model_dump() - return request_data - - def convert_response_to_chat_completion( - self, - response_data: dict, - input_messages: List[PydanticMessage], - llm_config: LLMConfig, - ) -> ChatCompletionResponse: - """ - Converts custom response format from llm client into an OpenAI - ChatCompletionsResponse object. - - Example Input: - { - "candidates": [ - { - "content": { - "parts": [ - { - "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14." 
- } - ] - } - } - ], - "usageMetadata": { - "promptTokenCount": 9, - "candidatesTokenCount": 27, - "totalTokenCount": 36 - } - } - """ - # print("[google_ai response]", json.dumps(response_data, indent=2)) - - try: - choices = [] - index = 0 - for candidate in response_data["candidates"]: - content = candidate["content"] - - if "role" not in content or not content["role"]: - # This means the response is malformed like MALFORMED_FUNCTION_CALL - # NOTE: must be a ValueError to trigger a retry - raise ValueError(f"Error in response data from LLM: {response_data}") - role = content["role"] - assert role == "model", f"Unknown role in response: {role}" - - parts = content["parts"] - - # NOTE: we aren't properly supported multi-parts here anyways (we're just appending choices), - # so let's disable it for now - - # NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text - # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'} - # To patch this, if we have multiple parts we can take the last one - if len(parts) > 1: - logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}") - parts = [parts[-1]] - - # TODO support parts / multimodal - # TODO support parallel tool calling natively - # TODO Alternative here is to throw away everything else except for the first part - for response_message in parts: - # Convert the actual message style to OpenAI style - if "functionCall" in response_message and response_message["functionCall"] is not None: - function_call = response_message["functionCall"] - assert isinstance(function_call, dict), function_call - function_name = function_call["name"] - assert isinstance(function_name, str), function_name - function_args = function_call["args"] - assert isinstance(function_args, dict), function_args - - # NOTE: this also involves stripping the inner monologue out of the function - if llm_config.put_inner_thoughts_in_kwargs: - from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX - - assert ( - INNER_THOUGHTS_KWARG_VERTEX in function_args - ), f"Couldn't find inner thoughts in function args:\n{function_call}" - inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX) - assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}" - else: - inner_thoughts = None - - # Google AI API doesn't generate tool call IDs - openai_response_message = Message( - role="assistant", # NOTE: "model" -> "assistant" - content=inner_thoughts, - tool_calls=[ - ToolCall( - id=get_tool_call_id(), - type="function", - function=FunctionCall( - name=function_name, - arguments=clean_json_string_extra_backslash(json_dumps(function_args)), - ), - ) - ], - ) - - else: - - # Inner thoughts are the content by default - inner_thoughts = response_message["text"] - - # Google AI API doesn't generate tool call IDs - openai_response_message = Message( - role="assistant", # NOTE: "model" -> "assistant" - content=inner_thoughts, - ) - - # Google AI 
API uses different finish reason strings than OpenAI - # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null - # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api - # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER - # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason - finish_reason = candidate["finishReason"] - if finish_reason == "STOP": - openai_finish_reason = ( - "function_call" - if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0 - else "stop" - ) - elif finish_reason == "MAX_TOKENS": - openai_finish_reason = "length" - elif finish_reason == "SAFETY": - openai_finish_reason = "content_filter" - elif finish_reason == "RECITATION": - openai_finish_reason = "content_filter" - else: - raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}") - - choices.append( - Choice( - finish_reason=openai_finish_reason, - index=index, - message=openai_response_message, - ) - ) - index += 1 - - # if len(choices) > 1: - # raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})") - - # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist? - # "usageMetadata": { - # "promptTokenCount": 9, - # "candidatesTokenCount": 27, - # "totalTokenCount": 36 - # } - if "usageMetadata" in response_data: - usage_data = response_data["usageMetadata"] - if "promptTokenCount" not in usage_data: - raise ValueError(f"promptTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}") - if "totalTokenCount" not in usage_data: - raise ValueError(f"totalTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}") - if "candidatesTokenCount" not in usage_data: - raise ValueError(f"candidatesTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}") - - prompt_tokens = usage_data["promptTokenCount"] - completion_tokens = usage_data["candidatesTokenCount"] - total_tokens = usage_data["totalTokenCount"] - - usage = UsageStatistics( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - ) - else: - # Count it ourselves - assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required" - prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation - completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate - total_tokens = prompt_tokens + completion_tokens - usage = UsageStatistics( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - ) - - response_id = str(uuid.uuid4()) - return ChatCompletionResponse( - id=response_id, - choices=choices, - model=llm_config.model, # NOTE: Google API doesn't pass back model in the response - created=get_utc_time_int(), - usage=usage, - ) - except KeyError as e: - raise e - - def _clean_google_ai_schema_properties(self, schema_part: dict): - """Recursively clean schema parts to remove unsupported Google AI keywords.""" - if not isinstance(schema_part, dict): - return - - # Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations - # * Only a subset of the OpenAPI schema is supported. - # * Supported parameter types in Python are limited. 
- unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties"] - keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part] - for key_to_remove in keys_to_remove_at_this_level: - logger.warning(f"Removing unsupported keyword '{key_to_remove}' from schema part.") - del schema_part[key_to_remove] - - if schema_part.get("type") == "string" and "format" in schema_part: - allowed_formats = ["enum", "date-time"] - if schema_part["format"] not in allowed_formats: - logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}") - del schema_part["format"] - - # Check properties within the current level - if "properties" in schema_part and isinstance(schema_part["properties"], dict): - for prop_name, prop_schema in schema_part["properties"].items(): - self._clean_google_ai_schema_properties(prop_schema) - - # Check items within arrays - if "items" in schema_part and isinstance(schema_part["items"], dict): - self._clean_google_ai_schema_properties(schema_part["items"]) - - # Check within anyOf, allOf, oneOf lists - for key in ["anyOf", "allOf", "oneOf"]: - if key in schema_part and isinstance(schema_part[key], list): - for item_schema in schema_part[key]: - self._clean_google_ai_schema_properties(item_schema) - - def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]: - """ - OpenAI style: - "tools": [{ - "type": "function", - "function": { - "name": "find_movies", - "description": "find ....", - "parameters": { - "type": "object", - "properties": { - PARAM: { - "type": PARAM_TYPE, # eg "string" - "description": PARAM_DESCRIPTION, - }, - ... - }, - "required": List[str], - } - } - } - ] - - Google AI style: - "tools": [{ - "functionDeclarations": [{ - "name": "find_movies", - "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.", - "parameters": { - "type": "OBJECT", - "properties": { - "location": { - "type": "STRING", - "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616" - }, - "description": { - "type": "STRING", - "description": "Any kind of description including category or genre, title words, attributes, etc." - } - }, - "required": ["description"] - } - }, { - "name": "find_theaters", - ... - """ - function_list = [ - dict( - name=t.function.name, - description=t.function.description, - parameters=t.function.parameters, # TODO need to unpack - ) - for t in tools - ] - - # Add inner thoughts if needed - for func in function_list: - # Note: Google AI API used to have weird casing requirements, but not any more - - # Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned - if "parameters" in func and isinstance(func["parameters"], dict): - self._clean_google_ai_schema_properties(func["parameters"]) - - # Add inner thoughts - if llm_config.put_inner_thoughts_in_kwargs: - from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX - - func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = { - "type": "string", - "description": INNER_THOUGHTS_KWARG_DESCRIPTION, - } - func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX) - - return [{"functionDeclarations": function_list}] - - def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]: - """Google AI API requires all function call returns are immediately followed by a 'model' role message. 
- - In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user, - so there is no natural follow-up 'model' role message. - - To satisfy the Google AI API restrictions, we can add a dummy 'yield' message - with role == 'model' that is placed in-betweeen and function output - (role == 'tool') and user message (role == 'user'). - """ - dummy_yield_message = { - "role": "model", - "parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}], - } - messages_with_padding = [] - for i, message in enumerate(messages): - messages_with_padding.append(message) - # Check if the current message role is 'tool' and the next message role is 'user' - if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"): - messages_with_padding.append(dummy_yield_message) - - return messages_with_padding + def _get_client(self): + return genai.Client(api_key=model_settings.gemini_api_key) def get_gemini_endpoint_and_headers( @@ -464,20 +63,24 @@ def google_ai_check_valid_api_key(api_key: str): def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]: + """Synchronous version to get model list from Google AI API using httpx.""" + import httpx + from letta.utils import printd url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header) try: - response = requests.get(url, headers=headers) - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string + with httpx.Client() as client: + response = client.get(url, headers=headers) + response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status + response_data = response.json() # convert to dict from string - # Grab the models out - model_list = response["models"] - return model_list + # Grab the models out + model_list = response_data["models"] + return model_list - except requests.exceptions.HTTPError as http_err: + except httpx.HTTPStatusError as http_err: # Handle HTTP errors (e.g., response 4XX, 5XX) printd(f"Got HTTPError, exception={http_err}") # Print the HTTP status code @@ -486,8 +89,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = print(f"Message: {http_err.response.text}") raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) + except httpx.RequestError as req_err: + # Handle other httpx-related errors (e.g., connection error) printd(f"Got RequestException, exception={req_err}") raise req_err @@ -497,22 +100,74 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = raise e -def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]: +async def google_ai_get_model_list_async( + base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None +) -> List[dict]: + """Asynchronous version to get model list from Google AI API using httpx.""" + from letta.utils import printd + + url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header) + + # Determine if we need to close the client at the end + close_client = False + if client is None: + client = httpx.AsyncClient() + close_client = True + + try: + response = await client.get(url, headers=headers) + response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status + response_data = 
response.json() # convert to dict from string + + # Grab the models out + model_list = response_data["models"] + return model_list + + except httpx.HTTPStatusError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}") + # Print the HTTP status code + print(f"HTTP Error: {http_err.response.status_code}") + # Print the response content (error message from server) + print(f"Message: {http_err.response.text}") + raise http_err + + except httpx.RequestError as req_err: + # Handle other httpx-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + + finally: + # Close the client if we created it + if close_client: + await client.aclose() + + +def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict: + """Synchronous version to get model details from Google AI API using httpx.""" + import httpx + from letta.utils import printd url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header) try: - response = requests.get(url, headers=headers) - printd(f"response = {response}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string - printd(f"response.json = {response}") + with httpx.Client() as client: + response = client.get(url, headers=headers) + printd(f"response = {response}") + response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status + response_data = response.json() # convert to dict from string + printd(f"response.json = {response_data}") - # Grab the models out - return response + # Return the model details + return response_data - except requests.exceptions.HTTPError as http_err: + except httpx.HTTPStatusError as http_err: # Handle HTTP errors (e.g., response 4XX, 5XX) printd(f"Got HTTPError, exception={http_err}") # Print the HTTP status code @@ -521,8 +176,8 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_ print(f"Message: {http_err.response.text}") raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) + except httpx.RequestError as req_err: + # Handle other httpx-related errors (e.g., connection error) printd(f"Got RequestException, exception={req_err}") raise req_err @@ -532,8 +187,66 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_ raise e +async def google_ai_get_model_details_async( + base_url: str, api_key: str, model: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None +) -> dict: + """Asynchronous version to get model details from Google AI API using httpx.""" + import httpx + + from letta.utils import printd + + url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header) + + # Determine if we need to close the client at the end + close_client = False + if client is None: + client = httpx.AsyncClient() + close_client = True + + try: + response = await client.get(url, headers=headers) + printd(f"response = {response}") + response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status + response_data = response.json() # convert to dict from string + printd(f"response.json = {response_data}") + + # Return the model details + return response_data + + except httpx.HTTPStatusError 
as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}") + # Print the HTTP status code + print(f"HTTP Error: {http_err.response.status_code}") + # Print the response content (error message from server) + print(f"Message: {http_err.response.text}") + raise http_err + + except httpx.RequestError as req_err: + # Handle other httpx-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + + finally: + # Close the client if we created it + if close_client: + await client.aclose() + + def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int: model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header) # TODO should this be: # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"] return int(model_details["inputTokenLimit"]) + + +async def google_ai_get_model_context_window_async(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int: + model_details = await google_ai_get_model_details_async(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header) + # TODO should this be: + # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"] + return int(model_details["inputTokenLimit"]) diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 7319f7fc6..2874b62a9 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -5,14 +5,16 @@ from typing import List, Optional from google import genai from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig +from letta.constants import NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps, json_loads -from letta.llm_api.google_ai_client import GoogleAIClient +from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens from letta.log import get_logger from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage +from letta.schemas.openai.chat_completion_request import Tool from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics from letta.settings import model_settings, settings from letta.utils import get_tool_call_id @@ -20,18 +22,21 @@ from letta.utils import get_tool_call_id logger = get_logger(__name__) -class GoogleVertexClient(GoogleAIClient): +class GoogleVertexClient(LLMClientBase): - def request(self, request_data: dict, llm_config: LLMConfig) -> dict: - """ - Performs underlying request to llm and returns raw response. - """ - client = genai.Client( + def _get_client(self): + return genai.Client( vertexai=True, project=model_settings.google_cloud_project, location=model_settings.google_cloud_location, http_options={"api_version": "v1"}, ) + + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: + """ + Performs underlying request to llm and returns raw response. 
+ """ + client = self._get_client() response = client.models.generate_content( model=llm_config.model, contents=request_data["contents"], @@ -43,12 +48,7 @@ class GoogleVertexClient(GoogleAIClient): """ Performs underlying request to llm and returns raw response. """ - client = genai.Client( - vertexai=True, - project=model_settings.google_cloud_project, - location=model_settings.google_cloud_location, - http_options={"api_version": "v1"}, - ) + client = self._get_client() response = await client.aio.models.generate_content( model=llm_config.model, contents=request_data["contents"], @@ -56,6 +56,139 @@ class GoogleVertexClient(GoogleAIClient): ) return response.model_dump() + def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]: + """Google AI API requires all function call returns are immediately followed by a 'model' role message. + + In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user, + so there is no natural follow-up 'model' role message. + + To satisfy the Google AI API restrictions, we can add a dummy 'yield' message + with role == 'model' that is placed in-betweeen and function output + (role == 'tool') and user message (role == 'user'). + """ + dummy_yield_message = { + "role": "model", + "parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}], + } + messages_with_padding = [] + for i, message in enumerate(messages): + messages_with_padding.append(message) + # Check if the current message role is 'tool' and the next message role is 'user' + if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"): + messages_with_padding.append(dummy_yield_message) + + return messages_with_padding + + def _clean_google_ai_schema_properties(self, schema_part: dict): + """Recursively clean schema parts to remove unsupported Google AI keywords.""" + if not isinstance(schema_part, dict): + return + + # Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations + # * Only a subset of the OpenAPI schema is supported. + # * Supported parameter types in Python are limited. + unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties"] + keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part] + for key_to_remove in keys_to_remove_at_this_level: + logger.warning(f"Removing unsupported keyword '{key_to_remove}' from schema part.") + del schema_part[key_to_remove] + + if schema_part.get("type") == "string" and "format" in schema_part: + allowed_formats = ["enum", "date-time"] + if schema_part["format"] not in allowed_formats: + logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. 
Allowed: {allowed_formats}") + del schema_part["format"] + + # Check properties within the current level + if "properties" in schema_part and isinstance(schema_part["properties"], dict): + for prop_name, prop_schema in schema_part["properties"].items(): + self._clean_google_ai_schema_properties(prop_schema) + + # Check items within arrays + if "items" in schema_part and isinstance(schema_part["items"], dict): + self._clean_google_ai_schema_properties(schema_part["items"]) + + # Check within anyOf, allOf, oneOf lists + for key in ["anyOf", "allOf", "oneOf"]: + if key in schema_part and isinstance(schema_part[key], list): + for item_schema in schema_part[key]: + self._clean_google_ai_schema_properties(item_schema) + + def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]: + """ + OpenAI style: + "tools": [{ + "type": "function", + "function": { + "name": "find_movies", + "description": "find ....", + "parameters": { + "type": "object", + "properties": { + PARAM: { + "type": PARAM_TYPE, # eg "string" + "description": PARAM_DESCRIPTION, + }, + ... + }, + "required": List[str], + } + } + } + ] + + Google AI style: + "tools": [{ + "functionDeclarations": [{ + "name": "find_movies", + "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.", + "parameters": { + "type": "OBJECT", + "properties": { + "location": { + "type": "STRING", + "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616" + }, + "description": { + "type": "STRING", + "description": "Any kind of description including category or genre, title words, attributes, etc." + } + }, + "required": ["description"] + } + }, { + "name": "find_theaters", + ... + """ + function_list = [ + dict( + name=t.function.name, + description=t.function.description, + parameters=t.function.parameters, # TODO need to unpack + ) + for t in tools + ] + + # Add inner thoughts if needed + for func in function_list: + # Note: Google AI API used to have weird casing requirements, but not any more + + # Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned + if "parameters" in func and isinstance(func["parameters"], dict): + self._clean_google_ai_schema_properties(func["parameters"]) + + # Add inner thoughts + if llm_config.put_inner_thoughts_in_kwargs: + from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX + + func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = { + "type": "string", + "description": INNER_THOUGHTS_KWARG_DESCRIPTION, + } + func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX) + + return [{"functionDeclarations": function_list}] + def build_request_data( self, messages: List[PydanticMessage], @@ -66,11 +199,29 @@ class GoogleVertexClient(GoogleAIClient): """ Constructs a request object in the expected data format for this client. 
""" - request_data = super().build_request_data(messages, llm_config, tools, force_tool_call) - request_data["config"] = request_data.pop("generation_config") - request_data["config"]["tools"] = request_data.pop("tools") - tool_names = [t["name"] for t in tools] if tools else [] + if tools: + tool_objs = [Tool(type="function", function=t) for t in tools] + tool_names = [t.function.name for t in tool_objs] + # Convert to the exact payload style Google expects + formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config) + else: + formatted_tools = [] + tool_names = [] + + contents = self.add_dummy_model_messages( + [m.to_google_ai_dict() for m in messages], + ) + + request_data = { + "contents": contents, + "config": { + "temperature": llm_config.temperature, + "max_output_tokens": llm_config.max_tokens, + "tools": formatted_tools, + }, + } + if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental: request_data["config"]["response_mime_type"] = "application/json" request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0]) @@ -89,11 +240,11 @@ class GoogleVertexClient(GoogleAIClient): # Add thinking_config # If enable_reasoner is False, set thinking_budget to 0 # Otherwise, use the value from max_reasoning_tokens - thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens - thinking_config = ThinkingConfig( - thinking_budget=thinking_budget, - ) - request_data["config"]["thinking_config"] = thinking_config.model_dump() + if llm_config.enable_reasoner: + thinking_config = ThinkingConfig( + thinking_budget=llm_config.max_reasoning_tokens, + ) + request_data["config"]["thinking_config"] = thinking_config.model_dump() return request_data diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index d86abc9b3..a1af262f2 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -20,15 +20,19 @@ from letta.llm_api.openai import ( build_openai_chat_completions_request, openai_chat_completions_process_stream, openai_chat_completions_request, + prepare_openai_payload, ) from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from letta.orm.user import User from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from letta.schemas.provider_trace import ProviderTraceCreate +from letta.services.telemetry_manager import TelemetryManager from letta.settings import ModelSettings from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface from letta.tracing import log_event, trace_method @@ -142,6 +146,9 @@ def create( model_settings: Optional[dict] = None, # TODO: eventually pass from server put_inner_thoughts_first: bool = True, name: Optional[str] = None, + telemetry_manager: Optional[TelemetryManager] = None, + step_id: Optional[str] = None, + actor: Optional[User] = None, ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" from letta.utils import printd @@ -233,6 +240,16 @@ def create( if isinstance(stream_interface, 
AgentChunkStreamingInterface): stream_interface.stream_end() + telemetry_manager.create_provider_trace( + actor=actor, + provider_trace_create=ProviderTraceCreate( + request_json=prepare_openai_payload(data), + response_json=response.model_json_schema(), + step_id=step_id, + organization_id=actor.organization_id, + ), + ) + if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) @@ -407,6 +424,16 @@ def create( if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + telemetry_manager.create_provider_trace( + actor=actor, + provider_trace_create=ProviderTraceCreate( + request_json=chat_completion_request.model_json_schema(), + response_json=response.model_json_schema(), + step_id=step_id, + organization_id=actor.organization_id, + ), + ) + return response # elif llm_config.model_endpoint_type == "cohere": diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 63adbcc2b..7372b68aa 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -51,7 +51,7 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) - case ProviderType.openai: + case ProviderType.openai | ProviderType.together: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index f56601ee8..6374a85c9 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -9,7 +9,9 @@ from letta.errors import LLMError from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionResponse -from letta.tracing import log_event +from letta.schemas.provider_trace import ProviderTraceCreate +from letta.services.telemetry_manager import TelemetryManager +from letta.tracing import log_event, trace_method if TYPE_CHECKING: from letta.orm import User @@ -31,13 +33,15 @@ class LLMClientBase: self.put_inner_thoughts_first = put_inner_thoughts_first self.use_tool_naming = use_tool_naming + @trace_method def send_llm_request( self, messages: List[Message], llm_config: LLMConfig, tools: Optional[List[dict]] = None, # TODO: change to Tool object - stream: bool = False, force_tool_call: Optional[str] = None, + telemetry_manager: Optional["TelemetryManager"] = None, + step_id: Optional[str] = None, ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]: """ Issues a request to the downstream model endpoint and parses response. 
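The `telemetry_manager` / `step_id` parameters threaded through `create(...)` and `send_llm_request(...)` above reduce to one pattern: capture the raw request and response payloads around the provider call, but only when both a step to attach them to and a recorder are present. A minimal sketch of that guard, with a hypothetical `record_trace` callback standing in for `TelemetryManager.create_provider_trace` (which takes richer arguments):

    from typing import Callable, Optional


    def request_with_trace(
        request_fn: Callable[[dict], dict],
        request_data: dict,
        record_trace: Optional[Callable[[dict, dict, str], None]] = None,  # hypothetical recorder
        step_id: Optional[str] = None,
    ) -> dict:
        response_data = request_fn(request_data)
        # Persist the raw payloads only when there is a step to attach them to
        # and a recorder to receive them, mirroring the guard in send_llm_request.
        if step_id and record_trace:
            record_trace(request_data, response_data, step_id)
        return response_data


    traces = []
    out = request_with_trace(
        lambda req: {"echo": req},
        {"model": "test-model", "messages": []},
        record_trace=lambda req, resp, sid: traces.append((sid, req, resp)),
        step_id="step-123",
    )
    assert out["echo"]["model"] == "test-model" and traces[0][0] == "step-123"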
@@ -48,37 +52,50 @@ class LLMClientBase:
         try:
             log_event(name="llm_request_sent", attributes=request_data)
-            if stream:
-                return self.stream(request_data, llm_config)
-            else:
-                response_data = self.request(request_data, llm_config)
+            response_data = self.request(request_data, llm_config)
+            if step_id and telemetry_manager:
+                telemetry_manager.create_provider_trace(
+                    actor=self.actor,
+                    provider_trace_create=ProviderTraceCreate(
+                        request_json=request_data,
+                        response_json=response_data,
+                        step_id=step_id,
+                        organization_id=self.actor.organization_id,
+                    ),
+                )
             log_event(name="llm_response_received", attributes=response_data)
         except Exception as e:
             raise self.handle_llm_error(e)

         return self.convert_response_to_chat_completion(response_data, messages, llm_config)

+    @trace_method
     async def send_llm_request_async(
         self,
+        request_data: dict,
         messages: List[Message],
         llm_config: LLMConfig,
-        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
-        stream: bool = False,
-        force_tool_call: Optional[str] = None,
+        telemetry_manager: "TelemetryManager | None" = None,
+        step_id: str | None = None,
     ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
         """
-        Issues a request to the downstream model endpoint.
-        If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
-        Otherwise returns a ChatCompletionResponse.
+        Issues a request to the downstream model endpoint and parses the response.
         """
-        request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)
         try:
             log_event(name="llm_request_sent", attributes=request_data)
-            if stream:
-                return await self.stream_async(request_data, llm_config)
-            else:
-                response_data = await self.request_async(request_data, llm_config)
+            response_data = await self.request_async(request_data, llm_config)
+            if step_id and telemetry_manager:
+                await telemetry_manager.create_provider_trace_async(
+                    actor=self.actor,
+                    provider_trace_create=ProviderTraceCreate(
+                        request_json=request_data,
+                        response_json=response_data,
+                        step_id=step_id,
+                        organization_id=self.actor.organization_id,
+                    ),
+                )
+            log_event(name="llm_response_received", attributes=response_data)
         except Exception as e:
             raise self.handle_llm_error(e)

@@ -133,13 +150,6 @@ class LLMClientBase:
         """
         raise NotImplementedError

-    @abstractmethod
-    def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
-        """
-        Performs underlying streaming request to llm and returns raw response.
- """ - raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}") - @abstractmethod async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: """ diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index f08462f7a..5e7be70df 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -1,6 +1,7 @@ import warnings from typing import Generator, List, Optional, Union +import httpx import requests from openai import OpenAI @@ -110,6 +111,62 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool raise e +async def openai_get_model_list_async( + url: str, + api_key: Optional[str] = None, + fix_url: bool = False, + extra_params: Optional[dict] = None, + client: Optional["httpx.AsyncClient"] = None, +) -> dict: + """https://platform.openai.com/docs/api-reference/models/list""" + from letta.utils import printd + + # In some cases we may want to double-check the URL and do basic correction + if fix_url and not url.endswith("/v1"): + url = smart_urljoin(url, "v1") + + url = smart_urljoin(url, "models") + + headers = {"Content-Type": "application/json"} + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + printd(f"Sending request to {url}") + + # Use provided client or create a new one + close_client = False + if client is None: + client = httpx.AsyncClient() + close_client = True + + try: + response = await client.get(url, headers=headers, params=extra_params) + response.raise_for_status() + result = response.json() + printd(f"response = {result}") + return result + except httpx.HTTPStatusError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + error_response = None + try: + error_response = http_err.response.json() + except: + error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text} + printd(f"Got HTTPError, exception={http_err}, response={error_response}") + raise http_err + except httpx.RequestError as req_err: + # Handle other httpx-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + finally: + if close_client: + await client.aclose() + + def build_openai_chat_completions_request( llm_config: LLMConfig, messages: List[_Message], diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 61089bbf1..e6ac37a22 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -2,7 +2,7 @@ import os from typing import List, Optional import openai -from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream +from openai import AsyncOpenAI, AsyncStream, OpenAI from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -22,7 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.log import get_logger -from letta.schemas.enums import ProviderCategory +from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request 
import ChatCompletionRequest @@ -113,6 +113,8 @@ class OpenAIClient(LLMClientBase): from letta.services.provider_manager import ProviderManager api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) + if llm_config.model_endpoint_type == ProviderType.together: + api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") if not api_key: api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") @@ -254,20 +256,14 @@ class OpenAIClient(LLMClientBase): return chat_completion_response - def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]: - """ - Performs underlying streaming request to OpenAI and returns the stream iterator. - """ - client = OpenAI(**self._prepare_client_kwargs(llm_config)) - response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True) - return response_stream - async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: """ Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator. """ client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config)) - response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True) + response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( + **request_data, stream=True, stream_options={"include_usage": True} + ) return response_stream def handle_llm_error(self, e: Exception) -> Exception: diff --git a/letta/memory.py b/letta/memory.py index 939e0874e..818f45ca5 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -93,7 +93,6 @@ def summarize_messages( response = llm_client.send_llm_request( messages=message_sequence, llm_config=llm_config_no_inner_thoughts, - stream=False, ) else: response = create( diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 348cd19e4..de395e28c 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -19,6 +19,7 @@ from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.passage import AgentPassage, BasePassage, SourcePassage from letta.orm.provider import Provider +from letta.orm.provider_trace import ProviderTrace from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import SourcesAgents diff --git a/letta/orm/enums.py b/letta/orm/enums.py index 784a5e56d..124339970 100644 --- a/letta/orm/enums.py +++ b/letta/orm/enums.py @@ -8,6 +8,7 @@ class ToolType(str, Enum): LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core" LETTA_SLEEPTIME_CORE = "letta_sleeptime_core" LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core" + LETTA_BUILTIN = "letta_builtin" EXTERNAL_COMPOSIO = "external_composio" EXTERNAL_LANGCHAIN = "external_langchain" # TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote? 
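For context on the `stream_options={"include_usage": True}` change above: with that flag set, the OpenAI streaming API emits one extra final chunk whose `choices` list is empty and whose `usage` field carries the token counts, which are otherwise unavailable for streamed completions. A minimal consumer sketch using the openai SDK; the model name and prompt are placeholders:

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    usage = None
    async for chunk in stream:
        if chunk.choices:  # normal delta chunks
            print(chunk.choices[0].delta.content or "", end="")
        if chunk.usage is not None:  # final usage-only chunk
            usage = chunk.usage
    if usage:
        print(f"\nprompt={usage.prompt_tokens}, completion={usage.completion_tokens}")


asyncio.run(main())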
diff --git a/letta/orm/provider_trace.py b/letta/orm/provider_trace.py new file mode 100644 index 000000000..69b7df14d --- /dev/null +++ b/letta/orm/provider_trace.py @@ -0,0 +1,26 @@ +import uuid + +from sqlalchemy import JSON, Index, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace + + +class ProviderTrace(SqlalchemyBase, OrganizationMixin): + """Defines data model for storing provider trace information""" + + __tablename__ = "provider_traces" + __pydantic_model__ = PydanticProviderTrace + __table_args__ = (Index("ix_step_id", "step_id"),) + + id: Mapped[str] = mapped_column( + primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}" + ) + request_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider request") + response_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider response") + step_id: Mapped[str] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with") + + # Relationships + organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") diff --git a/letta/orm/step.py b/letta/orm/step.py index ce7b82442..bd03d935a 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -35,6 +35,7 @@ class Step(SqlalchemyBase): ) agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The ID of the agent that performed this step.") provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.") + provider_category: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The category of the provider used for this step.") model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.") context_window_limit: Mapped[Optional[int]] = mapped_column( diff --git a/letta/schemas/provider_trace.py b/letta/schemas/provider_trace.py new file mode 100644 index 000000000..bcc151def --- /dev/null +++ b/letta/schemas/provider_trace.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.letta_base import OrmMetadataBase + + +class BaseProviderTrace(OrmMetadataBase): + __id_prefix__ = "provider_trace" + + +class ProviderTraceCreate(BaseModel): + """Request to create a provider trace""" + + request_json: dict[str, Any] = Field(..., description="JSON content of the provider request") + response_json: dict[str, Any] = Field(..., description="JSON content of the provider response") + step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with") + organization_id: str = Field(..., description="The unique identifier of the organization.") + + +class ProviderTrace(BaseProviderTrace): + """ + Letta's internal representation of a provider trace. + + Attributes: + id (str): The unique identifier of the provider trace. + request_json (Dict[str, Any]): JSON content of the provider request. + response_json (Dict[str, Any]): JSON content of the provider response.
+ step_id (str): ID of the step that this trace is associated with. + organization_id (str): The unique identifier of the organization. + created_at (datetime): The timestamp when the object was created. + """ + + id: str = BaseProviderTrace.generate_id_field() + request_json: Dict[str, Any] = Field(..., description="JSON content of the provider request") + response_json: Dict[str, Any] = Field(..., description="JSON content of the provider response") + step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with") + organization_id: str = Field(..., description="The unique identifier of the organization.") + created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index b55d92672..9f17737a4 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -47,12 +47,21 @@ class Provider(ProviderBase): def list_llm_models(self) -> List[LLMConfig]: return [] + async def list_llm_models_async(self) -> List[LLMConfig]: + return [] + def list_embedding_models(self) -> List[EmbeddingConfig]: return [] + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + return [] + def get_model_context_window(self, model_name: str) -> Optional[int]: raise NotImplementedError + async def get_model_context_window_async(self, model_name: str) -> Optional[int]: + raise NotImplementedError + def provider_tag(self) -> str: """String representation of the provider for display purposes""" raise NotImplementedError @@ -140,6 +149,19 @@ class LettaProvider(Provider): ) ] + async def list_llm_models_async(self) -> List[LLMConfig]: + return [ + LLMConfig( + model="letta-free", # NOTE: renamed + model_endpoint_type="openai", + model_endpoint=LETTA_MODEL_ENDPOINT, + context_window=8192, + handle=self.get_handle("letta-free"), + provider_name=self.name, + provider_category=self.provider_category, + ) + ] + def list_embedding_models(self): return [ EmbeddingConfig( @@ -189,9 +211,40 @@ class OpenAIProvider(Provider): return data + async def _get_models_async(self) -> List[dict]: + from letta.llm_api.openai import openai_get_model_list_async + + # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... 
+ # See: https://openrouter.ai/docs/requests + extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None + + # Similar hardcoded support for Nebius; guarded so it does not clobber the OpenRouter params above + if extra_params is None and "nebius.com" in self.base_url: + extra_params = {"verbose": True} + + response = await openai_get_model_list_async( + self.base_url, + api_key=self.api_key, + extra_params=extra_params, + # fix_url=True, # NOTE: make sure together ends with /v1 + ) + + if "data" in response: + data = response["data"] + else: + # TogetherAI's response is missing the 'data' field + data = response + + return data + def list_llm_models(self) -> List[LLMConfig]: data = self._get_models() + return self._list_llm_models(data) + async def list_llm_models_async(self) -> List[LLMConfig]: + data = await self._get_models_async() + return self._list_llm_models(data) + + def _list_llm_models(self, data) -> List[LLMConfig]: configs = [] for model in data: assert "id" in model, f"OpenAI model missing 'id' field: {model}" @@ -279,7 +332,6 @@ class OpenAIProvider(Provider): return configs def list_embedding_models(self) -> List[EmbeddingConfig]: - if self.base_url == "https://api.openai.com/v1": # TODO: actually automatically list models for OpenAI return [ @@ -312,55 +364,92 @@ class OpenAIProvider(Provider): else: # Actually attempt to list data = self._get_models() + return self._list_embedding_models(data) - configs = [] - for model in data: - assert "id" in model, f"Model missing 'id' field: {model}" - model_name = model["id"] + async def list_embedding_models_async(self) -> List[EmbeddingConfig]: + if self.base_url == "https://api.openai.com/v1": + # TODO: actually automatically list models for OpenAI + return [ + EmbeddingConfig( + embedding_model="text-embedding-ada-002", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=1536, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-ada-002", is_embedding=True), + ), + EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-small", is_embedding=True), + ), + EmbeddingConfig( + embedding_model="text-embedding-3-large", + embedding_endpoint_type="openai", + embedding_endpoint=self.base_url, + embedding_dim=2000, + embedding_chunk_size=300, + handle=self.get_handle("text-embedding-3-large", is_embedding=True), + ), + ] - if "context_length" in model: - # Context length is returned in Nebius as "context_length" - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) + else: + # Actually attempt to list + data = await self._get_models_async() + return self._list_embedding_models(data) - # We need the context length for embeddings too - if not context_window_size: - continue + def _list_embedding_models(self, data) -> List[EmbeddingConfig]: + configs = [] + for model in data: + assert "id" in model, f"Model missing 'id' field: {model}" + model_name = model["id"] - if "nebius.com" in self.base_url: - # Nebius includes the type, which we can use to filter for embedidng models - try: - model_type = model["architecture"]["modality"] - if model_type not in ["text->embedding"]: - # print(f"Skipping model w/ modality {model_type}:\n{model}") - continue - except KeyError: - print(f"Couldn't access architecture type field, skipping model:\n{model}") - continue + if "context_length" in model: + # Context length is
returned in Nebius as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) - elif "together.ai" in self.base_url or "together.xyz" in self.base_url: - # TogetherAI includes the type, which we can use to filter for embedding models - if "type" in model and model["type"] not in ["embedding"]: + # We need the context length for embeddings too + if not context_window_size: + continue + + if "nebius.com" in self.base_url: + # Nebius includes the type, which we can use to filter for embedding models + try: + model_type = model["architecture"]["modality"] + if model_type not in ["text->embedding"]: # print(f"Skipping model w/ modality {model_type}:\n{model}") continue - - else: - # For other providers we should skip by default, since we don't want to assume embeddings are supported + except KeyError: + print(f"Couldn't access architecture type field, skipping model:\n{model}") continue - configs.append( - EmbeddingConfig( - embedding_model=model_name, - embedding_endpoint_type=self.provider_type, - embedding_endpoint=self.base_url, - embedding_dim=context_window_size, - embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, - handle=self.get_handle(model, is_embedding=True), - ) - ) + elif "together.ai" in self.base_url or "together.xyz" in self.base_url: + # TogetherAI includes the type, which we can use to filter for embedding models + if "type" in model and model["type"] not in ["embedding"]: + # print(f"Skipping model w/ modality {model_type}:\n{model}") + continue - return configs + else: + # For other providers we should skip by default, since we don't want to assume embeddings are supported + continue + + configs.append( + EmbeddingConfig( + embedding_model=model_name, + embedding_endpoint_type=self.provider_type, + embedding_endpoint=self.base_url, + embedding_dim=context_window_size, + embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=self.get_handle(model_name, is_embedding=True), + ) + ) + + return configs def get_model_context_window_size(self, model_name: str): if model_name in LLM_MAX_TOKENS: @@ -647,26 +736,19 @@ class AnthropicProvider(Provider): anthropic_check_valid_api_key(self.api_key) def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list + from letta.llm_api.anthropic import anthropic_get_model_list - models = anthropic_get_model_list(self.base_url, api_key=self.api_key) + models = anthropic_get_model_list(api_key=self.api_key) + return self._list_llm_models(models) - """ - Example response: - { - "data": [ - { - "type": "model", - "id": "claude-3-5-sonnet-20241022", - "display_name": "Claude 3.5 Sonnet (New)", - "created_at": "2024-10-22T00:00:00Z" - } - ], - "has_more": true, - "first_id": "", - "last_id": "" - } - """ + async def list_llm_models_async(self) -> List[LLMConfig]: + from letta.llm_api.anthropic import anthropic_get_model_list_async + + models = await anthropic_get_model_list_async(api_key=self.api_key) + return self._list_llm_models(models) + + def _list_llm_models(self, models) -> List[LLMConfig]: + from letta.llm_api.anthropic import MODEL_LIST configs = [] for model in models: @@ -724,9 +806,6 @@ class AnthropicProvider(Provider): ) return configs - def list_embedding_models(self) -> List[EmbeddingConfig]: - return [] - class MistralProvider(Provider): provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") @@ -948,14 +1027,24 @@ class
TogetherProvider(OpenAIProvider): def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.openai import openai_get_model_list - response = openai_get_model_list(self.base_url, api_key=self.api_key) + models = openai_get_model_list(self.base_url, api_key=self.api_key) + return self._list_llm_models(models) + + async def list_llm_models_async(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + models = await openai_get_model_list_async(self.base_url, api_key=self.api_key) + return self._list_llm_models(models) + + def _list_llm_models(self, models) -> List[LLMConfig]: # TogetherAI's response is missing the 'data' field # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}" - if "data" in response: - data = response["data"] + if "data" in models: + data = models["data"] else: - data = response + data = models configs = [] for model in data: @@ -1057,7 +1146,6 @@ class GoogleAIProvider(Provider): from letta.llm_api.google_ai_client import google_ai_get_model_list model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) - # filter by 'generateContent' models model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] @@ -1081,6 +1169,42 @@ class GoogleAIProvider(Provider): provider_category=self.provider_category, ) ) + + return configs + + async def list_llm_models_async(self): + import asyncio + + from letta.llm_api.google_ai_client import google_ai_get_model_list_async + + # Get and filter the model list + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] + model_options = [str(m["name"]) for m in model_options] + + # filter by model names + model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] + + # Add support for all gemini models + model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] + + # Prepare tasks for context window lookups in parallel + async def create_config(model): + context_window = await self.get_model_context_window_async(model) + return LLMConfig( + model=model, + model_endpoint_type="google_ai", + model_endpoint=self.base_url, + context_window=context_window, + handle=self.get_handle(model), + max_tokens=8192, + provider_name=self.name, + provider_category=self.provider_category, + ) + + # Execute all config creation tasks concurrently + configs = await asyncio.gather(*[create_config(model) for model in model_options]) + return configs def list_embedding_models(self): @@ -1088,6 +1212,16 @@ class GoogleAIProvider(Provider): # TODO: use base_url instead model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key) + return self._list_embedding_models(model_options) + + async def list_embedding_models_async(self): + from letta.llm_api.google_ai_client import google_ai_get_model_list_async + + # TODO: use base_url instead + model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key) + return self._list_embedding_models(model_options) + + def _list_embedding_models(self, model_options): # filter by 'embedContent' models model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options]
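GoogleAIProvider.list_llm_models_async above fans out one context-window lookup per model and awaits them all with asyncio.gather, so total latency tracks the slowest lookup rather than the sum. A self-contained sketch of that fan-out pattern, with a sleep standing in for the real HTTP call (all names here are illustrative):

import asyncio


async def get_context_window(model: str) -> int:
    await asyncio.sleep(0.01)  # stand-in for a network round trip
    return 1_048_576 if "pro" in model else 32_768


async def build_configs(models: list[str]) -> list[dict]:
    async def create_config(model: str) -> dict:
        return {"model": model, "context_window": await get_context_window(model)}

    # One coroutine per model, executed concurrently; results keep input order.
    return await asyncio.gather(*[create_config(m) for m in models])


print(asyncio.run(build_configs(["gemini-1.5-pro", "gemini-1.5-flash"])))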
@@ -1110,7 +1244,18 @@ class GoogleAIProvider(Provider): def get_model_context_window(self, model_name: str) -> Optional[int]: from letta.llm_api.google_ai_client import google_ai_get_model_context_window - return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + return google_ai_get_model_context_window(self.base_url, self.api_key, model_name) + + async def get_model_context_window_async(self, model_name: str) -> Optional[int]: + from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async + + if model_name in LLM_MAX_TOKENS: + return LLM_MAX_TOKENS[model_name] + else: + return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name) class GoogleVertexProvider(Provider): diff --git a/letta/schemas/step.py b/letta/schemas/step.py index d25d8b684..2e0604d85 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -20,6 +20,7 @@ class Step(StepBase): ) agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.") provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.") + provider_category: Optional[str] = Field(None, description="The category of the provider used for this step.") model: Optional[str] = Field(None, description="The name of the model used for this step.") model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.") context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.") diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 6c8f9bd3e..ccc376d64 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -5,6 +5,7 @@ from pydantic import Field, model_validator from letta.constants import ( COMPOSIO_TOOL_TAG_NAME, FUNCTION_RETURN_CHAR_LIMIT, + LETTA_BUILTIN_TOOL_MODULE_NAME, LETTA_CORE_TOOL_MODULE_NAME, LETTA_MULTI_AGENT_TOOL_MODULE_NAME, LETTA_VOICE_TOOL_MODULE_NAME, @@ -104,6 +105,9 @@ class Tool(BaseTool): elif self.tool_type in {ToolType.LETTA_VOICE_SLEEPTIME_CORE}: # If it's letta voice tool, we generate the json_schema on the fly here self.json_schema = get_json_schema_from_module(module_name=LETTA_VOICE_TOOL_MODULE_NAME, function_name=self.name) + elif self.tool_type in {ToolType.LETTA_BUILTIN}: + # If it's a letta builtin tool, we generate the json_schema on the fly here + self.json_schema = get_json_schema_from_module(module_name=LETTA_BUILTIN_TOOL_MODULE_NAME, function_name=self.name) # At this point, we need to validate that at least json_schema is populated if not self.json_schema: diff --git a/letta/server/db.py b/letta/server/db.py index 32dbb13ef..fe9abcff3 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Generator from rich.console import Console from rich.panel import Panel from rich.text import Text -from sqlalchemy import Engine, create_engine +from sqlalchemy import Engine, NullPool, QueuePool, create_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import sessionmaker @@ -14,6 +14,8 @@ from letta.config import LettaConfig from letta.log import get_logger from letta.settings import settings +logger = get_logger(__name__) + def print_sqlite_schema_error(): """Print a formatted error message for SQLite schema issues""" @@ -76,16 +78,7 @@ class DatabaseRegistry:
self.config.archival_storage_type = "postgres" self.config.archival_storage_uri = settings.letta_pg_uri_no_default - engine = create_engine( - settings.letta_pg_uri, - # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8", - pool_size=settings.pg_pool_size, - max_overflow=settings.pg_max_overflow, - pool_timeout=settings.pg_pool_timeout, - pool_recycle=settings.pg_pool_recycle, - echo=settings.pg_echo, - # connect_args={"client_encoding": "utf8"}, - ) + engine = create_engine(settings.letta_pg_uri, **self._build_sqlalchemy_engine_args(is_async=False)) self._engines["default"] = engine # SQLite engine @@ -125,14 +118,7 @@ class DatabaseRegistry: async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri async_pg_uri = async_pg_uri.replace("sslmode=", "ssl=") - async_engine = create_async_engine( - async_pg_uri, - pool_size=settings.pg_pool_size, - max_overflow=settings.pg_max_overflow, - pool_timeout=settings.pg_pool_timeout, - pool_recycle=settings.pg_pool_recycle, - echo=settings.pg_echo, - ) + async_engine = create_async_engine(async_pg_uri, **self._build_sqlalchemy_engine_args(is_async=True)) self._async_engines["default"] = async_engine @@ -146,6 +132,38 @@ class DatabaseRegistry: # TODO (cliandy): unclear around async sqlite support in sqlalchemy, we will not currently support this self._initialized["async"] = False + def _build_sqlalchemy_engine_args(self, *, is_async: bool) -> dict: + """Prepare keyword arguments for create_engine / create_async_engine.""" + use_null_pool = settings.disable_sqlalchemy_pooling + + if use_null_pool: + logger.info("Disabling pooling on SqlAlchemy") + pool_cls = NullPool + else: + logger.info("Enabling pooling on SqlAlchemy") + pool_cls = QueuePool if not is_async else None + + base_args = { + "echo": settings.pg_echo, + "pool_pre_ping": settings.pool_pre_ping, + } + + if pool_cls: + base_args["poolclass"] = pool_cls + + if not use_null_pool and not is_async: + base_args.update( + { + "pool_size": settings.pg_pool_size, + "max_overflow": settings.pg_max_overflow, + "pool_timeout": settings.pg_pool_timeout, + "pool_recycle": settings.pg_pool_recycle, + "pool_use_lifo": settings.pool_use_lifo, + } + ) + + return base_args + def _wrap_sqlite_engine(self, engine: Engine) -> None: """Wrap SQLite engine with error handling.""" original_connect = engine.connect diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 666aeedc8..4607f8f9b 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -13,6 +13,7 @@ from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_c from letta.server.rest_api.routers.v1.sources import router as sources_router from letta.server.rest_api.routers.v1.steps import router as steps_router from letta.server.rest_api.routers.v1.tags import router as tags_router +from letta.server.rest_api.routers.v1.telemetry import router as telemetry_router from letta.server.rest_api.routers.v1.tools import router as tools_router from letta.server.rest_api.routers.v1.voice import router as voice_router @@ -31,6 +32,7 @@ ROUTERS = [ runs_router, steps_router, tags_router, + telemetry_router, messages_router, voice_router, embeddings_router, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 96f153f3e..11b21b950 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ 
-33,6 +33,7 @@ from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer +from letta.services.telemetry_manager import NoopTelemetryManager from letta.settings import settings # These can be forward refs, but because Fastapi needs them at runtime they must be imported normally @@ -106,14 +107,15 @@ async def list_agents( @router.get("/count", response_model=int, operation_id="count_agents") -def count_agents( +async def count_agents( server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Get the count of all agents associated with a given user. """ - return server.agent_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.size_async(actor=actor) class IndentedORJSONResponse(Response): @@ -124,7 +126,7 @@ class IndentedORJSONResponse(Response): @router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized") -def export_agent_serialized( +async def export_agent_serialized( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -135,7 +137,7 @@ def export_agent_serialized( """ Export the serialized JSON representation of an agent, formatted with indentation. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor) @@ -200,7 +202,7 @@ async def import_agent_serialized( @router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window") -def retrieve_agent_context_window( +async def retrieve_agent_context_window( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -208,9 +210,12 @@ def retrieve_agent_context_window( """ Retrieve the context window of a specific agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - - return server.get_agent_context_window(agent_id=agent_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + try: + return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor) + except Exception as e: + traceback.print_exc() + raise e class CreateAgentRequest(CreateAgent): @@ -341,7 +346,7 @@ async def retrieve_agent( """ Get the state of the agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) @@ -367,7 +372,7 @@ def delete_agent( @router.get("/{agent_id}/sources", response_model=List[Source], operation_id="list_agent_sources") -def list_agent_sources( +async def list_agent_sources( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -375,8 +380,8 @@ def list_agent_sources( """ Get the sources associated with an agent.
""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.list_attached_sources_async(agent_id=agent_id, actor=actor) # TODO: remove? can also get with agent blocks @@ -424,14 +429,14 @@ async def list_blocks( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) + agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor) return agent.memory.blocks except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @router.patch("/{agent_id}/core-memory/blocks/{block_label}", response_model=Block, operation_id="modify_core_memory_block") -def modify_block( +async def modify_block( agent_id: str, block_label: str, block_update: BlockUpdate = Body(...), @@ -441,10 +446,11 @@ def modify_block( """ Updates a core memory block of an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) - block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor) + block = await server.agent_manager.modify_block_by_label_async( + agent_id=agent_id, block_label=block_label, block_update=block_update, actor=actor + ) # This should also trigger a system prompt change in the agent server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False) @@ -481,7 +487,7 @@ def detach_block( @router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages") -def list_passages( +async def list_passages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."), @@ -496,11 +502,11 @@ def list_passages( """ Retrieve the memories in an agent's archival memory store (paginated query). """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.get_agent_archival( - user_id=actor.id, + return await server.get_agent_archival_async( agent_id=agent_id, + actor=actor, after=after, before=before, query_text=search, @@ -564,7 +570,7 @@ AgentMessagesResponse = Annotated[ @router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_messages") -def list_messages( +async def list_messages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."), @@ -579,10 +585,9 @@ def list_messages( """ Retrieve message history for an agent. 
""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.get_agent_recall( - user_id=actor.id, + return await server.get_agent_recall_async( agent_id=agent_id, after=after, before=before, @@ -593,6 +598,7 @@ def list_messages( use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, + actor=actor, ) @@ -634,7 +640,7 @@ async def send_message( agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "google_vertex", "google_ai"] + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] if agent_eligible and feature_enabled and model_compatible: experimental_agent = LettaAgent( @@ -644,6 +650,8 @@ async def send_message( block_manager=server.block_manager, passage_manager=server.passage_manager, actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), ) result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message) @@ -692,7 +700,8 @@ async def send_message_streaming( agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"] + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] + model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"] if agent_eligible and feature_enabled and model_compatible and request.stream_tokens: experimental_agent = LettaAgent( @@ -702,14 +711,28 @@ async def send_message_streaming( block_manager=server.block_manager, passage_manager=server.passage_manager, actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), ) + from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode - result = StreamingResponse( - experimental_agent.step_stream( - request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens - ), - media_type="text/event-stream", - ) + if request.stream_tokens and model_compatible_token_streaming: + result = StreamingResponseWithStatusCode( + experimental_agent.step_stream( + input_messages=request.messages, + max_steps=10, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + ), + media_type="text/event-stream", + ) + else: + result = StreamingResponseWithStatusCode( + experimental_agent.step_stream_no_tokens( + request.messages, max_steps=10, use_assistant_message=request.use_assistant_message + ), + media_type="text/event-stream", + ) else: result = await 
server.send_message_to_agent( agent_id=agent_id, diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index c95069062..bf669f432 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -99,7 +99,7 @@ def retrieve_block( @router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block") -def list_agents_for_block( +async def list_agents_for_block( block_id: str, server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -108,9 +108,9 @@ def list_agents_for_block( Retrieves all agents associated with the specified block. Raises a 404 if the block does not exist. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=block_id, actor=actor) return agents except NoResultFound: raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found") diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index dd48fd4e7..16cdbb26e 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -13,7 +13,7 @@ router = APIRouter(prefix="/identities", tags=["identities"]) @router.get("/", tags=["identities"], response_model=List[Identity], operation_id="list_identities") -def list_identities( +async def list_identities( name: Optional[str] = Query(None), project_id: Optional[str] = Query(None), identifier_key: Optional[str] = Query(None), @@ -28,9 +28,9 @@ def list_identities( Get a list of all identities in the database """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - identities = server.identity_manager.list_identities( + identities = await server.identity_manager.list_identities_async( name=name, project_id=project_id, identifier_key=identifier_key, @@ -50,7 +50,7 @@ def list_identities( @router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities") -def count_identities( +async def count_identities( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -58,7 +58,8 @@ def count_identities( Get count of all identities for a user """ try: - return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.size_async(actor=actor) except NoResultFound: return 0 except HTTPException: @@ -68,28 +69,28 @@ def count_identities( @router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity") -def retrieve_identity( +async def retrieve_identity( identity_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.identity_manager.get_identity(identity_id=identity_id, actor=actor) + actor = await 
server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @router.post("/", tags=["identities"], response_model=Identity, operation_id="create_identity") -def create_identity( +async def create_identity( identity: IdentityCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware ): try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.identity_manager.create_identity(identity=identity, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.create_identity_async(identity=identity, actor=actor) except HTTPException: raise except UniqueConstraintViolationError: @@ -105,15 +106,15 @@ def create_identity( @router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity") -def upsert_identity( +async def upsert_identity( identity: IdentityUpsert = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware ): try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.identity_manager.upsert_identity(identity=identity, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor) except HTTPException: raise except NoResultFound as e: @@ -123,36 +124,33 @@ def upsert_identity( @router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity") -def modify_identity( +async def modify_identity( identity_id: str, identity: IdentityUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor) except HTTPException: raise except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - import traceback - - print(traceback.format_exc()) raise HTTPException(status_code=500, detail=f"{e}") @router.put("/{identity_id}/properties", tags=["identities"], operation_id="upsert_identity_properties") -def upsert_identity_properties( +async def upsert_identity_properties( identity_id: str, properties: List[IdentityProperty] = Body(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return 
server.identity_manager.upsert_identity_properties(identity_id=identity_id, properties=properties, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor) except HTTPException: raise except NoResultFound as e: @@ -162,7 +160,7 @@ def upsert_identity_properties( @router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity") -def delete_identity( +async def delete_identity( identity_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -171,8 +169,8 @@ def delete_identity( Delete an identity by its identifier key """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) - server.identity_manager.delete_identity(identity_id=identity_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor) except HTTPException: raise except NoResultFound as e: diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 8adbdd2de..9c0cba4e5 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -33,16 +33,16 @@ def list_jobs( @router.get("/active", response_model=List[Job], operation_id="list_active_jobs") -def list_active_jobs( +async def list_active_jobs( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all active jobs. 
""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running]) + return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running]) @router.get("/{job_id}", response_model=Job, operation_id="retrieve_job") diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 450f86086..485563821 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -14,30 +14,35 @@ router = APIRouter(prefix="/models", tags=["models", "llms"]) @router.get("/", response_model=List[LLMConfig], operation_id="list_models") -def list_llm_models( +async def list_llm_models( provider_category: Optional[List[ProviderCategory]] = Query(None), provider_name: Optional[str] = Query(None), provider_type: Optional[ProviderType] = Query(None), server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), + # Extract user_id from header, default to None if not present ): + """List available LLM models using the asynchronous implementation for improved performance""" actor = server.user_manager.get_user_or_default(user_id=actor_id) - models = server.list_llm_models( + + models = await server.list_llm_models_async( provider_category=provider_category, provider_name=provider_name, provider_type=provider_type, actor=actor, ) - # print(models) + return models @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models") -def list_embedding_models( +async def list_embedding_models( server: "SyncServer" = Depends(get_letta_server), - actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), + # Extract user_id from header, default to None if not present ): + """List available embedding models using the asynchronous implementation for improved performance""" actor = server.user_manager.get_user_or_default(user_id=actor_id) - models = server.list_embedding_models(actor=actor) - # print(models) + models = await server.list_embedding_models_async(actor=actor) + return models diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index 6ef76a5b4..505e08a3d 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -100,15 +100,15 @@ def delete_sandbox_config( @router.get("/", response_model=List[PydanticSandboxConfig]) -def list_sandbox_configs( +async def list_sandbox_configs( limit: int = Query(1000, description="Number of results to return"), after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"), server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type) + actor = await 
server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.list_sandbox_configs_async(actor, limit=limit, after=after, sandbox_type=sandbox_type) @router.post("/local/recreate-venv", response_model=PydanticSandboxConfig) @@ -190,12 +190,12 @@ def delete_sandbox_env_var( @router.get("/{sandbox_config_id}/environment-variable", response_model=List[PydanticEnvVar]) -def list_sandbox_env_vars( +async def list_sandbox_env_vars( sandbox_config_id: str, limit: int = Query(1000, description="Number of results to return"), after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.list_sandbox_env_vars_async(sandbox_config_id, actor, limit=limit, after=after) diff --git a/letta/server/rest_api/routers/v1/tags.py b/letta/server/rest_api/routers/v1/tags.py index dab01771d..4ffae32ef 100644 --- a/letta/server/rest_api/routers/v1/tags.py +++ b/letta/server/rest_api/routers/v1/tags.py @@ -12,7 +12,7 @@ router = APIRouter(prefix="/tags", tags=["tag", "admin"]) @router.get("/", tags=["admin"], response_model=List[str], operation_id="list_tags") -def list_tags( +async def list_tags( after: Optional[str] = Query(None), limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), @@ -22,6 +22,6 @@ def list_tags( """ Get a list of all tags in the database """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + tags = await server.agent_manager.list_tags_async(actor=actor, after=after, limit=limit, query_text=query_text) return tags diff --git a/letta/server/rest_api/routers/v1/telemetry.py b/letta/server/rest_api/routers/v1/telemetry.py new file mode 100644 index 000000000..75e8de957 --- /dev/null +++ b/letta/server/rest_api/routers/v1/telemetry.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter, Depends, Header + +from letta.schemas.provider_trace import ProviderTrace +from letta.server.rest_api.utils import get_letta_server +from letta.server.server import SyncServer + +router = APIRouter(prefix="/telemetry", tags=["telemetry"]) + + +@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace") +async def retrieve_provider_trace_by_step_id( + step_id: str, + server: SyncServer = Depends(get_letta_server), + actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + return await server.telemetry_manager.get_provider_trace_by_step_id_async( + step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + ) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index ce8acc46e..ad4536c5f 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -59,7 +59,7 @@ def count_tools( @router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool") -def retrieve_tool( +async def 
retrieve_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -67,8 +67,8 @@ def retrieve_tool( """ Get a tool by ID """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor) if tool is None: # return 404 error raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.") @@ -196,15 +196,15 @@ def modify_tool( @router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools") -def upsert_base_tools( +async def upsert_base_tools( server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Upsert base tools """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.tool_manager.upsert_base_tools(actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.tool_manager.upsert_base_tools_async(actor=actor) @router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source") diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py new file mode 100644 index 000000000..13d57e877 --- /dev/null +++ b/letta/server/rest_api/streaming_response.py @@ -0,0 +1,105 @@ +# Alternative implementation of StreamingResponse that allows for effectively +# streaming HTTP trailers, as we cannot set codes after the initial response. +# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361 + +import json +from collections.abc import AsyncIterator + +from fastapi.responses import StreamingResponse +from starlette.types import Send + +from letta.log import get_logger + +logger = get_logger(__name__) + + +class StreamingResponseWithStatusCode(StreamingResponse): + """ + Variation of StreamingResponse that can dynamically decide the HTTP status code, + based on the return value of the content iterator (parameter `content`). + Expects the content to yield either just str content as per the original `StreamingResponse` + or else tuples of (`content`: `str`, `status_code`: `int`).
+ """ + + body_iterator: AsyncIterator[str | bytes] + response_started: bool = False + + async def stream_response(self, send: Send) -> None: + more_body = True + try: + first_chunk = await self.body_iterator.__anext__() + if isinstance(first_chunk, tuple): + first_chunk_content, self.status_code = first_chunk + else: + first_chunk_content = first_chunk + if isinstance(first_chunk_content, str): + first_chunk_content = first_chunk_content.encode(self.charset) + + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + self.response_started = True + await send( + { + "type": "http.response.body", + "body": first_chunk_content, + "more_body": more_body, + } + ) + + async for chunk in self.body_iterator: + if isinstance(chunk, tuple): + content, status_code = chunk + if status_code // 100 != 2: + # An error occurred mid-stream + if not isinstance(content, bytes): + content = content.encode(self.charset) + more_body = False + await send( + { + "type": "http.response.body", + "body": content, + "more_body": more_body, + } + ) + return + else: + content = chunk + + if isinstance(content, str): + content = content.encode(self.charset) + more_body = True + await send( + { + "type": "http.response.body", + "body": content, + "more_body": more_body, + } + ) + + except Exception: + logger.exception("unhandled_streaming_error") + more_body = False + error_resp = {"error": {"message": "Internal Server Error"}} + error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset) + if not self.response_started: + await send( + { + "type": "http.response.start", + "status": 500, + "headers": self.raw_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": error_event, + "more_body": more_body, + } + ) + if more_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index e025a2dd9..d04806e35 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response( pre_computed_assistant_message_id: Optional[str] = None, pre_computed_tool_message_id: Optional[str] = None, llm_batch_item_id: Optional[str] = None, + step_id: str | None = None, ) -> List[Message]: messages = [] @@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response( ) messages.append(heartbeat_system_message) + for message in messages: + message.step_id = step_id + return messages diff --git a/letta/server/server.py b/letta/server/server.py index 4392bf49e..1fb519484 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.step_manager import StepManager +from letta.services.telemetry_manager import TelemetryManager from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager @@ -213,6 +214,7 @@ class SyncServer(Server): self.identity_manager = IdentityManager() self.group_manager = GroupManager() self.batch_manager = LLMBatchManager() + self.telemetry_manager = TelemetryManager() # A resusable httpx client timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0) @@ -1000,6 +1002,30 @@ 
class SyncServer(Server): ) return records + async def get_agent_archival_async( + self, + agent_id: str, + actor: User, + after: Optional[str] = None, + before: Optional[str] = None, + limit: Optional[int] = 100, + order_by: Optional[str] = "created_at", + reverse: Optional[bool] = False, + query_text: Optional[str] = None, + ascending: Optional[bool] = True, + ) -> List[Passage]: + # iterate over records + records = await self.agent_manager.list_passages_async( + actor=actor, + agent_id=agent_id, + after=after, + query_text=query_text, + before=before, + ascending=ascending, + limit=limit, + ) + return records + def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: # Get the agent object (loaded in memory) agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) @@ -1070,6 +1096,44 @@ class SyncServer(Server): return records + async def get_agent_recall_async( + self, + agent_id: str, + actor: User, + after: Optional[str] = None, + before: Optional[str] = None, + limit: Optional[int] = 100, + group_id: Optional[str] = None, + reverse: Optional[bool] = False, + return_message_object: bool = True, + use_assistant_message: bool = True, + assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL, + assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, + ) -> Union[List[Message], List[LettaMessage]]: + records = await self.message_manager.list_messages_for_agent_async( + agent_id=agent_id, + actor=actor, + after=after, + before=before, + limit=limit, + ascending=not reverse, + group_id=group_id, + ) + + if not return_message_object: + records = Message.to_letta_messages_from_list( + messages=records, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + reverse=reverse, + ) + + if reverse: + records = records[::-1] + + return records + def get_server_config(self, include_defaults: bool = False) -> dict: """Return the base config""" @@ -1301,6 +1365,48 @@ class SyncServer(Server): return llm_models + @trace_method + async def list_llm_models_async( + self, + actor: User, + provider_category: Optional[List[ProviderCategory]] = None, + provider_name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + ) -> List[LLMConfig]: + """Asynchronously list available models with maximum concurrency""" + import asyncio + + providers = self.get_enabled_providers( + provider_category=provider_category, + provider_name=provider_name, + provider_type=provider_type, + actor=actor, + ) + + async def get_provider_models(provider): + try: + return await provider.list_llm_models_async() + except Exception as e: + import traceback + + traceback.print_exc() + warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}") + return [] + + # Execute all provider model listing tasks concurrently + provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers]) + + # Flatten the results + llm_models = [] + for models in provider_results: + llm_models.extend(models) + + # Get local configs - if this is potentially slow, consider making it async too + local_configs = self.get_local_llm_configs() + llm_models.extend(local_configs) + + return llm_models + def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]: """List available embedding models""" embedding_models = [] @@ -1311,6 +1417,35 @@ class SyncServer(Server): 
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models + async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]: + """Asynchronously list available embedding models with maximum concurrency""" + import asyncio + + # Get all eligible providers first + providers = self.get_enabled_providers(actor=actor) + + # Fetch embedding models from each provider concurrently + async def get_provider_embedding_models(provider): + try: + # All providers now have list_embedding_models_async + return await provider.list_embedding_models_async() + except Exception as e: + import traceback + + traceback.print_exc() + warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") + return [] + + # Execute all provider model listing tasks concurrently + provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers]) + + # Flatten the results + embedding_models = [] + for models in provider_results: + embedding_models.extend(models) + + return embedding_models + def get_enabled_providers( self, actor: User, @@ -1482,6 +1617,10 @@ class SyncServer(Server): letta_agent = self.load_agent(agent_id=agent_id, actor=actor) return letta_agent.get_context_window() + async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview: + letta_agent = self.load_agent(agent_id=agent_id, actor=actor) + return await letta_agent.get_context_window_async() + def run_tool_from_source( self, actor: User, @@ -1615,7 +1754,7 @@ class SyncServer(Server): server_name=server_name, command=server_params_raw["command"], args=server_params_raw.get("args", []), - env=server_params_raw.get("env", {}) + env=server_params_raw.get("env", {}), ) mcp_server_list[server_name] = server_params except Exception as e: diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 91cdffce3..915413e51 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -892,7 +892,7 @@ class AgentManager: List[PydanticAgentState]: The filtered list of matching agents. """ async with db_registry.async_session() as session: - query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id) + query = select(AgentModel) query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) # Apply filters @@ -961,6 +961,16 @@ class AgentManager: with db_registry.session() as session: return AgentModel.size(db_session=session, actor=actor) + async def size_async( + self, + actor: PydanticUser, + ) -> int: + """ + Get the total count of agents for the given user. 
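Example (hypothetical usage; the manager and actor variable names are illustrative):

    total_agents = await agent_manager.size_async(actor=user)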
+ """ + async with db_registry.async_session() as session: + return await AgentModel.size_async(db_session=session, actor=actor) + @enforce_types def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" @@ -969,18 +979,32 @@ class AgentManager: return agent.to_pydantic() @enforce_types - async def get_agent_by_id_async(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: + async def get_agent_by_id_async( + self, + agent_id: str, + actor: PydanticUser, + include_relationships: Optional[List[str]] = None, + ) -> PydanticAgentState: """Fetch an agent by its ID.""" async with db_registry.async_session() as session: agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) - return agent.to_pydantic() + return await agent.to_pydantic_async(include_relationships=include_relationships) @enforce_types - async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]: + async def get_agents_by_ids_async( + self, + agent_ids: list[str], + actor: PydanticUser, + include_relationships: Optional[List[str]] = None, + ) -> list[PydanticAgentState]: """Fetch a list of agents by their IDs.""" async with db_registry.async_session() as session: - agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor) - return [await agent.to_pydantic_async() for agent in agents] + agents = await AgentModel.read_multiple_async( + db_session=session, + identifiers=agent_ids, + actor=actor, + ) + return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) @enforce_types def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: @@ -1191,7 +1215,7 @@ class AgentManager: @enforce_types async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: - agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor) @enforce_types @@ -1199,6 +1223,11 @@ class AgentManager: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor) + @enforce_types + async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: + agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) + return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor) + # TODO: This is duplicated below # TODO: This is legacy code and should be cleaned up # TODO: A lot of the memory "compilation" should be offset to a separate class @@ -1267,10 +1296,81 @@ class AgentManager: else: return agent_state + @enforce_types + async def rebuild_system_prompt_async( + self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True + ) -> PydanticAgentState: + """Rebuilds the system message with the latest memory object and any shared memory block updates + + Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object + + Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages + """ + agent_state = 
await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor) + + curr_system_message = await self.get_system_message_async( + agent_id=agent_id, actor=actor + ) # this is the system + memory bank, not just the system prompt + curr_system_message_openai = curr_system_message.to_openai_dict() + + # note: we only update the system prompt if the core memory is changed + # this means that the archival/recall memory statistics may be somewhat out of date + curr_memory_str = agent_state.memory.compile() + if curr_memory_str in curr_system_message_openai["content"] and not force: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.debug( + f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild" + ) + return agent_state + + # If the memory didn't update, we probably don't want to update the timestamp inside the system message + # For example, if we're doing a system prompt swap, this should probably be False + if update_timestamp: + memory_edit_timestamp = get_utc_time() + else: + # NOTE: a bit of a hack - we pull the timestamp from the message's created_at + memory_edit_timestamp = curr_system_message.created_at + + num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id) + num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id) + + # update memory (TODO: potentially update recall/archival stats separately) + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10), + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, + ) + + diff = united_diff(curr_system_message_openai["content"], new_system_message_str) + if len(diff) > 0: # there was a diff + logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") + + # Swap the system message out (only if there is a diff) + message = PydanticMessage.dict_to_message( + agent_id=agent_id, + model=agent_state.llm_config.model, + openai_message_dict={"role": "system", "content": new_system_message_str}, + ) + message = await self.message_manager.update_message_by_id_async( + message_id=curr_system_message.id, + message_update=MessageUpdate(**message.model_dump()), + actor=actor, + ) + return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor) + else: + return agent_state + @enforce_types def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) + @enforce_types + async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: + return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) + @enforce_types def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids @@ -1382,17 +1482,6 @@ class AgentManager: return agent_state - @enforce_types - def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: -
block_ids = [b.id for b in agent_state.memory.blocks] - if not block_ids: - return agent_state - - agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids( - block_ids=[b.id for b in agent_state.memory.blocks], actor=actor - ) - return agent_state - @enforce_types async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: block_ids = [b.id for b in agent_state.memory.blocks] @@ -1482,6 +1571,25 @@ class AgentManager: # Use the lazy-loaded relationship to get sources return [source.to_pydantic() for source in agent.sources] + @enforce_types + async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: + """ + Lists all sources attached to an agent. + + Args: + agent_id: ID of the agent to list sources for + actor: User performing the action + + Returns: + List[PydanticSource]: List of sources attached to the agent + """ + async with db_registry.async_session() as session: + # Verify agent exists and user has permission to access it + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Use the lazy-loaded relationship to get sources + return [source.to_pydantic() for source in agent.sources] + @enforce_types def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: """ @@ -1527,6 +1635,33 @@ class AgentManager: return block.to_pydantic() raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") + @enforce_types + async def modify_block_by_label_async( + self, + agent_id: str, + block_label: str, + block_update: BlockUpdate, + actor: PydanticUser, + ) -> PydanticBlock: + """Modifies a block attached to an agent, looked up by its label.""" + async with db_registry.async_session() as session: + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + block = next((b for b in agent.core_memory if b.label == block_label), None) + if not block: + raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") + + update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + for key, value in update_data.items(): + setattr(block, key, value) + + await block.update_async(session, actor=actor) + return block.to_pydantic() + @enforce_types def update_block_with_label( self, @@ -1848,6 +1983,65 @@ class AgentManager: return [p.to_pydantic() for p in passages] + @enforce_types + async def list_passages_async( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + before: Optional[str] = None, + after: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False, + ) -> List[PydanticPassage]: + """Lists all passages attached to an agent.""" + async with db_registry.async_session() as session: + main_query = self._build_passage_query( + actor=actor, + agent_id=agent_id, + file_id=file_id, + query_text=query_text, + start_date=start_date, + end_date=end_date, + before=before, + after=after, + source_id=source_id, + embed_query=embed_query, + ascending=ascending, + embedding_config=embedding_config, + agent_only=agent_only, + ) + + # Add limit + if limit: + main_query =
main_query.limit(limit) + + # Execute query + result = await session.execute(main_query) + + passages = [] + for row in result: + data = dict(row._mapping) + if data["agent_id"] is not None: + # This is an AgentPassage - remove source fields + data.pop("source_id", None) + data.pop("file_id", None) + passage = AgentPassage(**data) + else: + # This is a SourcePassage - remove agent field + data.pop("agent_id", None) + passage = SourcePassage(**data) + passages.append(passage) + + return [p.to_pydantic() for p in passages] + @enforce_types def passage_size( self, @@ -2010,3 +2204,42 @@ class AgentManager: query = query.order_by(AgentsTags.tag).limit(limit) results = [tag[0] for tag in query.all()] return results + + @enforce_types + async def list_tags_async( + self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None + ) -> List[str]: + """ + Get all tags a user has created, ordered alphabetically. + + Args: + actor: User performing the action. + after: Cursor for forward pagination. + limit: Maximum number of tags to return. + query_text: Text to filter tags by. + + Returns: + List[str]: List of all tags. + """ + async with db_registry.async_session() as session: + # Build the query using select() for async SQLAlchemy + query = ( + select(AgentsTags.tag) + .join(AgentModel, AgentModel.id == AgentsTags.agent_id) + .where(AgentModel.organization_id == actor.organization_id) + .distinct() + ) + + if query_text: + query = query.where(AgentsTags.tag.ilike(f"%{query_text}%")) + + if after: + query = query.where(AgentsTags.tag > after) + + query = query.order_by(AgentsTags.tag).limit(limit) + + # Execute the query asynchronously + result = await session.execute(query) + # Extract the tag values from the result + results = [row[0] for row in result.all()] + return results diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 0d4e67da2..2d568e34b 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -1,3 +1,4 @@ +import asyncio from typing import Dict, List, Optional from sqlalchemy import select @@ -82,39 +83,64 @@ class BlockManager: return block.to_pydantic() @enforce_types - def get_blocks( + async def get_blocks_async( self, actor: PydanticUser, label: Optional[str] = None, is_template: Optional[bool] = None, template_name: Optional[str] = None, - identifier_keys: Optional[List[str]] = None, identity_id: Optional[str] = None, - id: Optional[str] = None, - after: Optional[str] = None, + identifier_keys: Optional[List[str]] = None, limit: Optional[int] = 50, ) -> List[PydanticBlock]: - """Retrieve blocks based on various optional filters.""" - with db_registry.session() as session: - # Prepare filters - filters = {"organization_id": actor.organization_id} - if label: - filters["label"] = label - if is_template is not None: - filters["is_template"] = is_template - if template_name: - filters["template_name"] = template_name - if id: - filters["id"] = id + """Async version of get_blocks method.
Retrieve blocks based on various optional filters.""" + from sqlalchemy import select + from sqlalchemy.orm import noload - blocks = BlockModel.list( - db_session=session, - after=after, - limit=limit, - identifier_keys=identifier_keys, - identity_id=identity_id, - **filters, - ) + from letta.orm.sqlalchemy_base import AccessType + + async with db_registry.async_session() as session: + # Start with a basic query + query = select(BlockModel) + + # Explicitly avoid loading relationships + query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups)) + + # Apply access control + query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) + + # Add filters + query = query.where(BlockModel.organization_id == actor.organization_id) + if label: + query = query.where(BlockModel.label == label) + + if is_template is not None: + query = query.where(BlockModel.is_template == is_template) + + if template_name: + query = query.where(BlockModel.template_name == template_name) + + if identifier_keys: + query = ( + query.join(BlockModel.identities) + .filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) + .distinct(BlockModel.id) + ) + + if identity_id: + query = ( + query.join(BlockModel.identities) + .filter(BlockModel.identities.property.mapper.class_.id == identity_id) + .distinct(BlockModel.id) + ) + + # Add limit + if limit: + query = query.limit(limit) + + # Execute the query + result = await session.execute(query) + blocks = result.scalars().all() return [block.to_pydantic() for block in blocks] @@ -190,15 +216,6 @@ class BlockManager: except NoResultFound: return None - @enforce_types - def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: - """Retrieve blocks by their ids.""" - with db_registry.session() as session: - blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)] - # backwards compatibility. previous implementation added None for every block not found. - blocks.extend([None for _ in range(len(block_ids) - len(blocks))]) - return blocks - @enforce_types async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: """Retrieve blocks by their ids without loading unnecessary relationships. Async implementation.""" @@ -247,16 +264,14 @@ class BlockManager: return pydantic_blocks @enforce_types - def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]: + async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]: """ Retrieve all agents associated with a given block. 
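Example (hypothetical usage; the block id is illustrative). The ORM rows are converted concurrently, via one asyncio.gather over agent.to_pydantic_async():

    agents = await block_manager.get_agents_for_block_async(block_id="block-123", actor=user)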
""" - with db_registry.session() as session: - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + async with db_registry.async_session() as session: + block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) agents_orm = block.agents - agents_pydantic = [agent.to_pydantic() for agent in agents_orm] - - return agents_pydantic + return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) @enforce_types def size( diff --git a/letta/services/helpers/noop_helper.py b/letta/services/helpers/noop_helper.py new file mode 100644 index 000000000..7f32e628b --- /dev/null +++ b/letta/services/helpers/noop_helper.py @@ -0,0 +1,10 @@ +def singleton(cls): + """Decorator to make a class a Singleton class.""" + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 3ca057930..590cedeeb 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -1,6 +1,7 @@ from typing import List, Optional from fastapi import HTTPException +from sqlalchemy import select from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session @@ -17,7 +18,7 @@ from letta.utils import enforce_types class IdentityManager: @enforce_types - def list_identities( + async def list_identities_async( self, name: Optional[str] = None, project_id: Optional[str] = None, @@ -28,7 +29,7 @@ class IdentityManager: limit: Optional[int] = 50, actor: PydanticUser = None, ) -> list[PydanticIdentity]: - with db_registry.session() as session: + async with db_registry.async_session() as session: filters = {"organization_id": actor.organization_id} if project_id: filters["project_id"] = project_id @@ -36,7 +37,7 @@ class IdentityManager: filters["identifier_key"] = identifier_key if identity_type: filters["identity_type"] = identity_type - identities = IdentityModel.list( + identities = await IdentityModel.list_async( db_session=session, query_text=name, before=before, @@ -47,17 +48,17 @@ class IdentityManager: return [identity.to_pydantic() for identity in identities] @enforce_types - def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: - with db_registry.session() as session: - identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) + async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: + async with db_registry.async_session() as session: + identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) return identity.to_pydantic() @enforce_types - def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: - with db_registry.session() as session: + async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: + async with db_registry.async_session() as session: new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True)) new_identity.organization_id = actor.organization_id - self._process_relationship( + await self._process_relationship_async( session=session, identity=new_identity, relationship_name="agents", @@ -65,7 +66,7 @@ class IdentityManager: item_ids=identity.agent_ids, allow_partial=False, ) - self._process_relationship( + await 
self._process_relationship_async( session=session, identity=new_identity, relationship_name="blocks", @@ -73,13 +74,13 @@ class IdentityManager: item_ids=identity.block_ids, allow_partial=False, ) - new_identity.create(session, actor=actor) + await new_identity.create_async(session, actor=actor) return new_identity.to_pydantic() @enforce_types - def upsert_identity(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity: - with db_registry.session() as session: - existing_identity = IdentityModel.read( + async def upsert_identity_async(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity: + async with db_registry.async_session() as session: + existing_identity = await IdentityModel.read_async( db_session=session, identifier_key=identity.identifier_key, project_id=identity.project_id, @@ -88,7 +89,7 @@ class IdentityManager: ) if existing_identity is None: - return self.create_identity(identity=IdentityCreate(**identity.model_dump()), actor=actor) + return await self.create_identity_async(identity=IdentityCreate(**identity.model_dump()), actor=actor) else: identity_update = IdentityUpdate( name=identity.name, @@ -97,25 +98,27 @@ class IdentityManager: agent_ids=identity.agent_ids, properties=identity.properties, ) - return self._update_identity( + return await self._update_identity_async( session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True ) @enforce_types - def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity: - with db_registry.session() as session: + async def update_identity_async( + self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False + ) -> PydanticIdentity: + async with db_registry.async_session() as session: try: - existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) + existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) except NoResultFound: raise HTTPException(status_code=404, detail="Identity not found") if existing_identity.organization_id != actor.organization_id: raise HTTPException(status_code=403, detail="Forbidden") - return self._update_identity( + return await self._update_identity_async( session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace ) - def _update_identity( + async def _update_identity_async( self, session: Session, existing_identity: IdentityModel, @@ -139,7 +142,7 @@ class IdentityManager: existing_identity.properties = list(new_properties.values()) if identity.agent_ids is not None: - self._process_relationship( + await self._process_relationship_async( session=session, identity=existing_identity, relationship_name="agents", @@ -149,7 +152,7 @@ class IdentityManager: replace=replace, ) if identity.block_ids is not None: - self._process_relationship( + await self._process_relationship_async( session=session, identity=existing_identity, relationship_name="blocks", @@ -158,16 +161,18 @@ class IdentityManager: allow_partial=False, replace=replace, ) - existing_identity.update(session, actor=actor) + await existing_identity.update_async(session, actor=actor) return existing_identity.to_pydantic() @enforce_types - def upsert_identity_properties(self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser) -> PydanticIdentity: - with db_registry.session() as session: - existing_identity = 
IdentityModel.read(db_session=session, identifier=identity_id, actor=actor) + async def upsert_identity_properties_async( + self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser + ) -> PydanticIdentity: + async with db_registry.async_session() as session: + existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) if existing_identity is None: raise HTTPException(status_code=404, detail="Identity not found") - return self._update_identity( + return await self._update_identity_async( session=session, existing_identity=existing_identity, identity=IdentityUpdate(properties=properties), @@ -176,28 +181,28 @@ class IdentityManager: ) @enforce_types - def delete_identity(self, identity_id: str, actor: PydanticUser) -> None: - with db_registry.session() as session: - identity = IdentityModel.read(db_session=session, identifier=identity_id) + async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None: + async with db_registry.async_session() as session: + identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) if identity is None: raise HTTPException(status_code=404, detail="Identity not found") if identity.organization_id != actor.organization_id: raise HTTPException(status_code=403, detail="Forbidden") - session.delete(identity) - session.commit() + await session.delete(identity) + await session.commit() @enforce_types - def size( + async def size_async( self, actor: PydanticUser, ) -> int: """ Get the total count of identities for the given user. """ - with db_registry.session() as session: - return IdentityModel.size(db_session=session, actor=actor) + async with db_registry.async_session() as session: + return await IdentityModel.size_async(db_session=session, actor=actor) - def _process_relationship( + async def _process_relationship_async( self, session: Session, identity: PydanticIdentity, @@ -214,7 +219,7 @@ class IdentityManager: return # Retrieve models for the provided IDs - found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all() + found_items = (await session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all() # Validate all items are found if allow_partial is False if not allow_partial and len(found_items) != len(item_ids): diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index d279ac90d..d3c7ca590 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -150,6 +150,35 @@ class JobManager: ) return [job.to_pydantic() for job in jobs] + @enforce_types + async def list_jobs_async( + self, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + statuses: Optional[List[JobStatus]] = None, + job_type: JobType = JobType.JOB, + ascending: bool = True, + ) -> List[PydanticJob]: + """List all jobs with optional pagination and status filter.""" + async with db_registry.async_session() as session: + filter_kwargs = {"user_id": actor.id, "job_type": job_type} + + # Add status filter if provided + if statuses: + filter_kwargs["status"] = statuses + + jobs = await JobModel.list_async( + db_session=session, + before=before, + after=after, + limit=limit, + ascending=ascending, + **filter_kwargs, + ) + return [job.to_pydantic() for job in jobs] + @enforce_types def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Delete a job by its ID.""" diff --git 
a/letta/services/message_manager.py b/letta/services/message_manager.py index 2cc13f3f0..91351db3f 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -31,6 +31,16 @@ class MessageManager: except NoResultFound: return None + @enforce_types + async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: + """Fetch a message by ID.""" + async with db_registry.async_session() as session: + try: + message = await MessageModel.read_async(db_session=session, identifier=message_id, actor=actor) + return message.to_pydantic() + except NoResultFound: + return None + @enforce_types def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: """Fetch messages by ID and return them in the requested order.""" @@ -426,6 +436,107 @@ class MessageManager: results = query.all() return [msg.to_pydantic() for msg in results] + @enforce_types + async def list_messages_for_agent_async( + self, + agent_id: str, + actor: PydanticUser, + after: Optional[str] = None, + before: Optional[str] = None, + query_text: Optional[str] = None, + roles: Optional[Sequence[MessageRole]] = None, + limit: Optional[int] = 50, + ascending: bool = True, + group_id: Optional[str] = None, + ) -> List[PydanticMessage]: + """ + Most performant query to list messages for an agent by directly querying the Message table. + + This function filters by the agent_id (leveraging the index on messages.agent_id) + and applies pagination using sequence_id as the cursor. + If query_text is provided, it will filter messages whose text content partially matches the query. + If roles are provided, it will filter messages to the specified roles. + + Args: + agent_id: The ID of the agent whose messages are queried. + actor: The user performing the action (used for permission checks). + after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned. + before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned. + query_text: Optional string to partially match the message text content. + roles: Optional sequence of MessageRoles to filter messages by. + limit: Maximum number of messages to return. + ascending: If True, sort by sequence_id ascending; if False, sort descending. + group_id: Optional group ID to filter messages by group_id. + + Returns: + List[PydanticMessage]: A list of messages (converted via .to_pydantic()). + + Raises: + NoResultFound: If the provided after/before message IDs do not exist. + """ + + async with db_registry.async_session() as session: + # Permission check: raise if the agent doesn't exist or actor is not allowed. + await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Build a query that directly filters the Message table by agent_id. + query = select(MessageModel).where(MessageModel.agent_id == agent_id) + + # If group_id is provided, filter messages by group_id. + if group_id: + query = query.where(MessageModel.group_id == group_id) + + # If query_text is provided, filter messages using subquery + json_array_elements.
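Roughly the predicate this builds, rendered as SQL (a sketch assuming a Postgres-style JSON content column; the bound value is illustrative). The subquery below is wrapped in an EXISTS(...) evaluated per messages row:

    SELECT 1
    FROM json_array_elements(messages.content) AS content_element
    WHERE content_element->>'type' = 'text'
      AND content_element->>'text' ILIKE '%hello%'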
+ if query_text: + content_element = func.json_array_elements(MessageModel.content).alias("content_element") + query = query.where( + exists( + select(1) + .select_from(content_element) + .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text")) + .params(query_text=f"%{query_text}%") + ) + ) + + # If role(s) are provided, filter messages by those roles. + if roles: + role_values = [r.value for r in roles] + query = query.where(MessageModel.role.in_(role_values)) + + # Apply 'after' pagination if specified. + if after: + after_query = select(MessageModel.sequence_id).where(MessageModel.id == after) + after_result = await session.execute(after_query) + after_ref = after_result.one_or_none() + if not after_ref: + raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.") + # Filter out any messages with a sequence_id <= after_ref.sequence_id + query = query.where(MessageModel.sequence_id > after_ref.sequence_id) + + # Apply 'before' pagination if specified. + if before: + before_query = select(MessageModel.sequence_id).where(MessageModel.id == before) + before_result = await session.execute(before_query) + before_ref = before_result.one_or_none() + if not before_ref: + raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.") + # Filter out any messages with a sequence_id >= before_ref.sequence_id + query = query.where(MessageModel.sequence_id < before_ref.sequence_id) + + # Apply ordering based on the ascending flag. + if ascending: + query = query.order_by(MessageModel.sequence_id.asc()) + else: + query = query.order_by(MessageModel.sequence_id.desc()) + + # Limit the number of results. + query = query.limit(limit) + + # Execute and convert each Message to its Pydantic representation. 
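For the `after` cursor with ascending order, the query built above compiles to roughly the following shape (a sketch; parameter names are illustrative). Because pagination is keyset-based on sequence_id rather than OFFSET-based, fetching a deep page costs about the same as fetching the first one:

    SELECT * FROM messages
    WHERE messages.agent_id = :agent_id
      AND messages.sequence_id > :after_sequence_id
    ORDER BY messages.sequence_id ASC
    LIMIT :limit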
+ result = await session.execute(query) + results = result.scalars().all() + return [msg.to_pydantic() for msg in results] + @enforce_types def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int: """ diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 5b25b25ea..0f55a0bc5 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -122,6 +122,23 @@ class SandboxConfigManager: sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs) return [sandbox.to_pydantic() for sandbox in sandboxes] + @enforce_types + async def list_sandbox_configs_async( + self, + actor: PydanticUser, + after: Optional[str] = None, + limit: Optional[int] = 50, + sandbox_type: Optional[SandboxType] = None, + ) -> List[PydanticSandboxConfig]: + """List all sandbox configurations with optional pagination.""" + kwargs = {"organization_id": actor.organization_id} + if sandbox_type: + kwargs.update({"type": sandbox_type}) + + async with db_registry.async_session() as session: + sandboxes = await SandboxConfigModel.list_async(db_session=session, after=after, limit=limit, **kwargs) + return [sandbox.to_pydantic() for sandbox in sandboxes] + @enforce_types def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: """Retrieve a sandbox configuration by its ID.""" @@ -224,6 +241,25 @@ class SandboxConfigManager: ) return [env_var.to_pydantic() for env_var in env_vars] + @enforce_types + async def list_sandbox_env_vars_async( + self, + sandbox_config_id: str, + actor: PydanticUser, + after: Optional[str] = None, + limit: Optional[int] = 50, + ) -> List[PydanticEnvVar]: + """List all sandbox environment variables with optional pagination.""" + async with db_registry.async_session() as session: + env_vars = await SandboxEnvVarModel.list_async( + db_session=session, + after=after, + limit=limit, + organization_id=actor.organization_id, + sandbox_config_id=sandbox_config_id, + ) + return [env_var.to_pydantic() for env_var in env_vars] + @enforce_types def list_sandbox_env_vars_by_key( self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index cf34915d5..8ee052218 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import List, Literal, Optional from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from letta.orm.errors import NoResultFound @@ -12,6 +13,7 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.step import Step as PydanticStep from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.services.helpers.noop_helper import singleton from letta.tracing import get_trace_id from letta.utils import enforce_types @@ -57,12 +59,14 @@ class StepManager: actor: PydanticUser, agent_id: str, provider_name: str, + provider_category: str, model: str, model_endpoint: Optional[str], context_window_limit: int, usage: UsageStatistics, provider_id: Optional[str] = None, job_id: Optional[str] = None, + step_id: Optional[str] = None, ) -> PydanticStep: step_data = { "origin": None, @@ -70,6 +74,7 @@ class StepManager: "agent_id": agent_id, "provider_id": provider_id, 
"provider_name": provider_name, + "provider_category": provider_category, "model": model, "model_endpoint": model_endpoint, "context_window_limit": context_window_limit, @@ -81,6 +86,8 @@ class StepManager: "tid": None, "trace_id": get_trace_id(), # Get the current trace ID } + if step_id: + step_data["id"] = step_id with db_registry.session() as session: if job_id: self._verify_job_access(session, job_id, actor, access=["write"]) @@ -88,6 +95,48 @@ class StepManager: new_step.create(session) return new_step.to_pydantic() + @enforce_types + async def log_step_async( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + provider_category: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + step_id: Optional[str] = None, + ) -> PydanticStep: + step_data = { + "origin": None, + "organization_id": actor.organization_id, + "agent_id": agent_id, + "provider_id": provider_id, + "provider_name": provider_name, + "provider_category": provider_category, + "model": model, + "model_endpoint": model_endpoint, + "context_window_limit": context_window_limit, + "completion_tokens": usage.completion_tokens, + "prompt_tokens": usage.prompt_tokens, + "total_tokens": usage.total_tokens, + "job_id": job_id, + "tags": [], + "tid": None, + "trace_id": get_trace_id(), # Get the current trace ID + } + if step_id: + step_data["id"] = step_id + async with db_registry.async_session() as session: + if job_id: + await self._verify_job_access_async(session, job_id, actor, access=["write"]) + new_step = StepModel(**step_data) + await new_step.create_async(session) + return new_step.to_pydantic() + @enforce_types def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep: with db_registry.session() as session: @@ -147,3 +196,100 @@ class StepManager: if not job: raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") return job + + async def _verify_job_access_async( + self, + session: AsyncSession, + job_id: str, + actor: PydanticUser, + access: List[Literal["read", "write", "delete"]] = ["read"], + ) -> JobModel: + """ + Verify that a job exists and the user has the required access asynchronously. + + Args: + session: The async database session + job_id: The ID of the job to verify + actor: The user making the request + + Returns: + The job if it exists and the user has access + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + job_query = select(JobModel).where(JobModel.id == job_id) + job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER) + result = await session.execute(job_query) + job = result.scalar_one_or_none() + if not job: + raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") + return job + + +@singleton +class NoopStepManager(StepManager): + """ + Noop implementation of StepManager. + Temporarily used for migrations, but allows for different implementations in the future. + Will not allow for writes, but will still allow for reads. 
+ """ + + @enforce_types + def log_step( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + provider_category: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + step_id: Optional[str] = None, + ) -> PydanticStep: + return + + @enforce_types + async def log_step_async( + self, + actor: PydanticUser, + agent_id: str, + provider_name: str, + provider_category: str, + model: str, + model_endpoint: Optional[str], + context_window_limit: int, + usage: UsageStatistics, + provider_id: Optional[str] = None, + job_id: Optional[str] = None, + step_id: Optional[str] = None, + ) -> PydanticStep: + step_data = { + "origin": None, + "organization_id": actor.organization_id, + "agent_id": agent_id, + "provider_id": provider_id, + "provider_name": provider_name, + "provider_category": provider_category, + "model": model, + "model_endpoint": model_endpoint, + "context_window_limit": context_window_limit, + "completion_tokens": usage.completion_tokens, + "prompt_tokens": usage.prompt_tokens, + "total_tokens": usage.total_tokens, + "job_id": job_id, + "tags": [], + "tid": None, + "trace_id": get_trace_id(), # Get the current trace ID + } + if step_id: + step_data["id"] = step_id + async with db_registry.async_session() as session: + if job_id: + await self._verify_job_access_async(session, job_id, actor, access=["write"]) + new_step = StepModel(**step_data) + await new_step.create_async(session) + return new_step.to_pydantic() diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py new file mode 100644 index 000000000..a57474b12 --- /dev/null +++ b/letta/services/telemetry_manager.py @@ -0,0 +1,58 @@ +from letta.helpers.json_helpers import json_dumps, json_loads +from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel +from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace +from letta.schemas.provider_trace import ProviderTraceCreate +from letta.schemas.step import Step as PydanticStep +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.services.helpers.noop_helper import singleton +from letta.utils import enforce_types + + +class TelemetryManager: + @enforce_types + async def get_provider_trace_by_step_id_async( + self, + step_id: str, + actor: PydanticUser, + ) -> PydanticProviderTrace: + async with db_registry.async_session() as session: + provider_trace = await ProviderTraceModel.read_async(db_session=session, step_id=step_id, actor=actor) + return provider_trace.to_pydantic() + + @enforce_types + async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + async with db_registry.async_session() as session: + provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) + if provider_trace_create.request_json: + request_json_str = json_dumps(provider_trace_create.request_json) + provider_trace.request_json = json_loads(request_json_str) + + if provider_trace_create.response_json: + response_json_str = json_dumps(provider_trace_create.response_json) + provider_trace.response_json = json_loads(response_json_str) + await provider_trace.create_async(session, actor=actor) + return provider_trace.to_pydantic() + + @enforce_types + def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + 
with db_registry.session() as session: + provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) + provider_trace.create(session, actor=actor) + return provider_trace.to_pydantic() + + +@singleton +class NoopTelemetryManager(TelemetryManager): + """ + Noop implementation of TelemetryManager. + """ + + async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + return + + async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep: + return + + def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: + return diff --git a/letta/services/tool_executor/tool_execution_manager.py b/letta/services/tool_executor/tool_execution_manager.py index 6ba8679c3..4c3786214 100644 --- a/letta/services/tool_executor/tool_execution_manager.py +++ b/letta/services/tool_executor/tool_execution_manager.py @@ -8,9 +8,14 @@ from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User +from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager +from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager from letta.services.tool_executor.tool_executor import ( ExternalComposioToolExecutor, ExternalMCPToolExecutor, + LettaBuiltinToolExecutor, LettaCoreToolExecutor, LettaMultiAgentToolExecutor, SandboxToolExecutor, @@ -28,15 +33,30 @@ class ToolExecutorFactory: ToolType.LETTA_MEMORY_CORE: LettaCoreToolExecutor, ToolType.LETTA_SLEEPTIME_CORE: LettaCoreToolExecutor, ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor, + ToolType.LETTA_BUILTIN: LettaBuiltinToolExecutor, ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor, ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor, } @classmethod - def get_executor(cls, tool_type: ToolType) -> ToolExecutor: + def get_executor( + cls, + tool_type: ToolType, + message_manager: MessageManager, + agent_manager: AgentManager, + block_manager: BlockManager, + passage_manager: PassageManager, + actor: User, + ) -> ToolExecutor: """Get the appropriate executor for the given tool type.""" executor_class = cls._executor_map.get(tool_type, SandboxToolExecutor) - return executor_class() + return executor_class( + message_manager=message_manager, + agent_manager=agent_manager, + block_manager=block_manager, + passage_manager=passage_manager, + actor=actor, + ) class ToolExecutionManager: @@ -44,11 +64,19 @@ class ToolExecutionManager: def __init__( self, + message_manager: MessageManager, + agent_manager: AgentManager, + block_manager: BlockManager, + passage_manager: PassageManager, agent_state: AgentState, actor: User, sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ): + self.message_manager = message_manager + self.agent_manager = agent_manager + self.block_manager = block_manager + self.passage_manager = passage_manager self.agent_state = agent_state self.logger = get_logger(__name__) self.actor = actor @@ -68,7 +96,14 @@ class ToolExecutionManager: Tuple containing the function response and sandbox run result (if applicable) """ try: - executor = ToolExecutorFactory.get_executor(tool.tool_type) + executor = ToolExecutorFactory.get_executor( + tool.tool_type, + 
message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) return executor.execute( function_name, function_args, @@ -98,9 +133,18 @@ class ToolExecutionManager: Execute a tool asynchronously and persist any state changes. """ try: - executor = ToolExecutorFactory.get_executor(tool.tool_type) + executor = ToolExecutorFactory.get_executor( + tool.tool_type, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) # TODO: Extend this async model to composio - if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)): + if isinstance( + executor, (SandboxToolExecutor, ExternalComposioToolExecutor, LettaBuiltinToolExecutor, LettaMultiAgentToolExecutor) + ): result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor) else: result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor) diff --git a/letta/services/tool_executor/tool_execution_sandbox.py b/letta/services/tool_executor/tool_execution_sandbox.py index 537b3f0b2..e466cd9e0 100644 --- a/letta/services/tool_executor/tool_execution_sandbox.py +++ b/letta/services/tool_executor/tool_execution_sandbox.py @@ -73,6 +73,7 @@ class ToolExecutionSandbox: self.force_recreate = force_recreate self.force_recreate_venv = force_recreate_venv + @trace_method def run( self, agent_state: Optional[AgentState] = None, @@ -321,6 +322,7 @@ class ToolExecutionSandbox: # e2b sandbox specific functions + @trace_method def run_e2b_sandbox( self, agent_state: Optional[AgentState] = None, @@ -352,10 +354,22 @@ class ToolExecutionSandbox: if additional_env_vars: env_vars.update(additional_env_vars) code = self.generate_execution_script(agent_state=agent_state) + log_event( + "e2b_execution_started", + {"tool": self.tool_name, "sandbox_id": sbx.sandbox_id, "code": code, "env_vars": env_vars}, + ) execution = sbx.run_code(code, envs=env_vars) if execution.results: func_return, agent_state = self.parse_best_effort(execution.results[0].text) + log_event( + "e2b_execution_succeeded", + { + "tool": self.tool_name, + "sandbox_id": sbx.sandbox_id, + "func_return": func_return, + }, + ) elif execution.error: logger.error(f"Executing tool {self.tool_name} raised a {execution.error.name} with message: \n{execution.error.value}") logger.error(f"Traceback from e2b sandbox: \n{execution.error.traceback}") @@ -363,7 +377,25 @@ class ToolExecutionSandbox: function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value ) execution.logs.stderr.append(execution.error.traceback) + log_event( + "e2b_execution_failed", + { + "tool": self.tool_name, + "sandbox_id": sbx.sandbox_id, + "error_type": execution.error.name, + "error_message": execution.error.value, + "func_return": func_return, + }, + ) else: + log_event( + "e2b_execution_empty", + { + "tool": self.tool_name, + "sandbox_id": sbx.sandbox_id, + "status": "no_results_no_error", + }, + ) raise ValueError(f"Tool {self.tool_name} returned execution with None") return ToolExecutionResult( @@ -395,16 +427,31 @@ class ToolExecutionSandbox: return None + @trace_method def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox": from e2b_code_interpreter import Sandbox state_hash = sandbox_config.fingerprint() e2b_config = 
sandbox_config.get_e2b_config() + log_event( + "e2b_sandbox_create_started", + { + "sandbox_fingerprint": state_hash, + "e2b_config": e2b_config.model_dump(), + }, + ) if e2b_config.template: sbx = Sandbox(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}) else: # no template sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"})) + log_event( + "e2b_sandbox_create_finished", + { + "sandbox_id": sbx.sandbox_id, + "sandbox_fingerprint": state_hash, + }, + ) # install pip requirements if e2b_config.pip_requirements: diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index 9424520cc..51fda3d7c 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -1,35 +1,64 @@ +import asyncio +import json import math import traceback from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from textwrap import shorten +from typing import Any, Dict, List, Literal, Optional from letta.constants import ( COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, READ_ONLY_BLOCK_EDIT_ERROR, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, + WEB_SEARCH_CLIP_CONTENT, + WEB_SEARCH_INCLUDE_SCORE, + WEB_SEARCH_SEPARATOR, ) from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.json_helpers import json_dumps +from letta.log import get_logger from letta.schemas.agent import AgentState +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import AssistantMessage +from letta.schemas.letta_message_content import TextContent +from letta.schemas.message import MessageCreate from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal from letta.settings import tool_settings +from letta.tracing import trace_method from letta.utils import get_friendly_error_msg +logger = get_logger(__name__) + class ToolExecutor(ABC): """Abstract base class for tool executors.""" + def __init__( + self, + message_manager: MessageManager, + agent_manager: AgentManager, + block_manager: BlockManager, + passage_manager: PassageManager, + actor: User, + ): + self.message_manager = message_manager + self.agent_manager = agent_manager + self.block_manager = block_manager + self.passage_manager = passage_manager + self.actor = actor + @abstractmethod def execute( self, @@ -493,17 +522,113 @@ class LettaCoreToolExecutor(ToolExecutor): class LettaMultiAgentToolExecutor(ToolExecutor): """Executor for LETTA multi-agent core tools.""" - # TODO: Implement - # def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> ToolExecutionResult: - # callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, 
function_name) - # function_args["self"] = agent # need to attach self to arg since it's dynamically linked - # function_response = callable_func(**function_args) - # return ToolExecutionResult(func_return=function_response) + async def execute( + self, + function_name: str, + function_args: dict, + agent_state: AgentState, + tool: Tool, + actor: User, + sandbox_config: Optional[SandboxConfig] = None, + sandbox_env_vars: Optional[Dict[str, Any]] = None, + ) -> ToolExecutionResult: + function_map = { + "send_message_to_agent_and_wait_for_reply": self.send_message_to_agent_and_wait_for_reply, + "send_message_to_agent_async": self.send_message_to_agent_async, + "send_message_to_agents_matching_tags": self.send_message_to_agents_matching_tags, + } + + if function_name not in function_map: + raise ValueError(f"Unknown function: {function_name}") + + # Execute the appropriate function + function_args_copy = function_args.copy() # Make a copy to avoid modifying the original + function_response = await function_map[function_name](agent_state, **function_args_copy) + return ToolExecutionResult( + status="success", + func_return=function_response, + ) + + async def send_message_to_agent_and_wait_for_reply(self, agent_state: AgentState, message: str, other_agent_id: str) -> str: + augmented_message = ( + f"[Incoming message from agent with ID '{agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " + f"{message}" + ) + + return str(await self._process_agent(agent_id=other_agent_id, message=augmented_message)) + + async def send_message_to_agent_async(self, agent_state: AgentState, message: str, other_agent_id: str) -> str: + # 1) Build the prefixed system‐message + prefixed = ( + f"[Incoming message from agent with ID '{agent_state.id}' - " + f"to reply to this message, make sure to use the " + f"'send_message_to_agent_async' tool, or the agent will not receive your message] " + f"{message}" + ) + + task = asyncio.create_task(self._process_agent(agent_id=other_agent_id, message=prefixed)) + + task.add_done_callback(lambda t: (logger.error(f"Async send_message task failed: {t.exception()}") if t.exception() else None)) + + return "Successfully sent message" + + async def send_message_to_agents_matching_tags( + self, agent_state: AgentState, message: str, match_all: List[str], match_some: List[str] + ) -> str: + # Find matching agents + matching_agents = self.agent_manager.list_agents_matching_tags(actor=self.actor, match_all=match_all, match_some=match_some) + if not matching_agents: + return str([]) + + augmented_message = ( + "[Incoming message from external Letta agent - to reply to this message, " + "make sure to use the 'send_message' at the end, and the system will notify " + "the sender of your response] " + f"{message}" + ) + + tasks = [ + asyncio.create_task(self._process_agent(agent_id=agent_state.id, message=augmented_message)) for agent_state in matching_agents + ] + results = await asyncio.gather(*tasks) + return str(results) + + async def _process_agent(self, agent_id: str, message: str) -> Dict[str, Any]: + from letta.agents.letta_agent import LettaAgent + + try: + letta_agent = LettaAgent( + agent_id=agent_id, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) + + letta_response = await letta_agent.step([MessageCreate(role=MessageRole.system, 
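+                # cross-agent traffic is injected as a system-role message so the
+                # receiving agent can tell it apart from input typed by its own user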
content=[TextContent(text=message)])])
+            messages = letta_response.messages
+
+            send_message_content = [message.content for message in messages if isinstance(message, AssistantMessage)]
+
+            return {
+                "agent_id": agent_id,
+                "response": send_message_content if send_message_content else [""],
+            }
+
+        except Exception as e:
+            return {
+                "agent_id": agent_id,
+                "error": str(e),
+                "type": type(e).__name__,
+            }
 
 
 class ExternalComposioToolExecutor(ToolExecutor):
     """Executor for external Composio tools."""
 
+    @trace_method
     async def execute(
         self,
         function_name: str,
@@ -595,6 +720,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
 class SandboxToolExecutor(ToolExecutor):
     """Executor for sandboxed tools."""
 
+    @trace_method
     async def execute(
         self,
         function_name: str,
@@ -674,3 +800,106 @@ class SandboxToolExecutor(ToolExecutor):
             func_return=error_message,
             stderr=[stderr],
         )
+
+
+class LettaBuiltinToolExecutor(ToolExecutor):
+    """Executor for built-in Letta tools."""
+
+    @trace_method
+    async def execute(
+        self,
+        function_name: str,
+        function_args: dict,
+        agent_state: AgentState,
+        tool: Tool,
+        actor: User,
+        sandbox_config: Optional[SandboxConfig] = None,
+        sandbox_env_vars: Optional[Dict[str, Any]] = None,
+    ) -> ToolExecutionResult:
+        function_map = {"run_code": self.run_code, "web_search": self.web_search}
+
+        if function_name not in function_map:
+            raise ValueError(f"Unknown function: {function_name}")
+
+        # Execute the appropriate function
+        function_args_copy = function_args.copy()  # Make a copy to avoid modifying the original
+        function_response = await function_map[function_name](**function_args_copy)
+
+        return ToolExecutionResult(
+            status="success",
+            func_return=function_response,
+        )
+
+    async def run_code(self, code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
+        from e2b_code_interpreter import AsyncSandbox
+
+        if tool_settings.e2b_api_key is None:
+            raise ValueError("E2B_API_KEY is not set")
+
+        sbx = await AsyncSandbox.create(api_key=tool_settings.e2b_api_key)
+        params = {"code": code}
+        if language != "python":
+            # Leave empty for python
+            params["language"] = language
+
+        res = self._llm_friendly_result(await sbx.run_code(**params))
+        return json.dumps(res, ensure_ascii=False)
+
+    def _llm_friendly_result(self, res):
+        out = {
+            "results": [r.text if hasattr(r, "text") else str(r) for r in res.results],
+            "logs": {
+                "stdout": getattr(res.logs, "stdout", []),
+                "stderr": getattr(res.logs, "stderr", []),
+            },
+        }
+        err = getattr(res, "error", None)
+        if err is not None:
+            out["error"] = err
+        return out
+
+    async def web_search(self, query: str) -> str:
+        """
+        Search the web for information.
+        Args:
+            query (str): The query to search the web for.
+        Returns:
+            str: The search results.
+        """
+
+        try:
+            from tavily import AsyncTavilyClient
+        except ImportError:
+            raise ImportError("tavily is not installed in the tool execution environment")
+
+        # Check if the API key exists
+        if tool_settings.tavily_api_key is None:
+            raise ValueError("TAVILY_API_KEY is not set")
+
+        # Instantiate client and search
+        tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key)
+        search_results = await tavily_client.search(query=query, auto_parameters=True)
+
+        results = search_results.get("results", [])
+        if not results:
+            return "No search results found."
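+        # below: each hit becomes a numbered block (title / URL / optional relevance
+        # score / snippet); with WEB_SEARCH_CLIP_CONTENT on, snippets are clipped to
+        # ~600 characters so a page of results cannot blow out the context window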
+ + # ---- format for the LLM ------------------------------------------------- + formatted_blocks = [] + for idx, item in enumerate(results, start=1): + title = item.get("title") or "Untitled" + url = item.get("url") or "Unknown URL" + # keep each content snippet reasonably short so you don’t blow up context + content = ( + shorten(item.get("content", "").strip(), width=600, placeholder=" …") + if WEB_SEARCH_CLIP_CONTENT + else item.get("content", "").strip() + ) + score = item.get("score") + if WEB_SEARCH_INCLUDE_SCORE: + block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Relevance score: {score:.4f}\n" f"Content: {content}\n" + else: + block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Content: {content}\n" + formatted_blocks.append(block) + + return WEB_SEARCH_SEPARATOR.join(formatted_blocks) diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index eebff5ea7..9e7bf42f2 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -1,3 +1,4 @@ +import asyncio import importlib import warnings from typing import List, Optional @@ -9,6 +10,7 @@ from letta.constants import ( BASE_TOOLS, BASE_VOICE_SLEEPTIME_CHAT_TOOLS, BASE_VOICE_SLEEPTIME_TOOLS, + BUILTIN_TOOLS, LETTA_TOOL_SET, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS, @@ -59,6 +61,32 @@ class ToolManager: return tool + @enforce_types + async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: + """Create a new tool based on the ToolCreate schema.""" + tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor) + if tool_id: + # Put to dict and remove fields that should not be reset + update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True) + + # If there's anything to update + if update_data: + # In case we want to update the tool type + # Useful if we are shuffling around base tools + updated_tool_type = None + if "tool_type" in update_data: + updated_tool_type = update_data.get("tool_type") + tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type) + else: + printd( + f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update." 
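+                    # nothing was explicitly set on the incoming tool, so just
+                    # re-read the stored tool and return it unchanged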
+ ) + tool = await self.get_tool_by_id_async(tool_id, actor=actor) + else: + tool = await self.create_tool_async(pydantic_tool, actor=actor) + + return tool + @enforce_types def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}} @@ -96,6 +124,21 @@ class ToolManager: tool.create(session, actor=actor) # Re-raise other database-related errors return tool.to_pydantic() + @enforce_types + async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: + """Create a new tool based on the ToolCreate schema.""" + async with db_registry.async_session() as session: + # Set the organization id at the ORM layer + pydantic_tool.organization_id = actor.organization_id + # Auto-generate description if not provided + if pydantic_tool.description is None: + pydantic_tool.description = pydantic_tool.json_schema.get("description", None) + tool_data = pydantic_tool.model_dump(to_orm=True) + + tool = ToolModel(**tool_data) + await tool.create_async(session, actor=actor) # Re-raise other database-related errors + return tool.to_pydantic() + @enforce_types def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool: """Fetch a tool by its ID.""" @@ -105,6 +148,15 @@ class ToolManager: # Convert the SQLAlchemy Tool object to PydanticTool return tool.to_pydantic() + @enforce_types + async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool: + """Fetch a tool by its ID.""" + async with db_registry.async_session() as session: + # Retrieve tool by id using the Tool model's read method + tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) + # Convert the SQLAlchemy Tool object to PydanticTool + return tool.to_pydantic() + @enforce_types def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" @@ -135,6 +187,16 @@ class ToolManager: except NoResultFound: return None + @enforce_types + async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]: + """Retrieve a tool by its name and a user. 
We derive the organization from the user, and retrieve that tool.""" + try: + async with db_registry.async_session() as session: + tool = await ToolModel.read_async(db_session=session, name=tool_name, actor=actor) + return tool.id + except NoResultFound: + return None + @enforce_types async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" @@ -204,6 +266,35 @@ class ToolManager: # Save the updated tool to the database return tool.update(db_session=session, actor=actor).to_pydantic() + @enforce_types + async def update_tool_by_id_async( + self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None + ) -> PydanticTool: + """Update a tool by its ID with the given ToolUpdate object.""" + async with db_registry.async_session() as session: + # Fetch the tool by ID + tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) + + # Update tool attributes with only the fields that were explicitly set + update_data = tool_update.model_dump(to_orm=True, exclude_none=True) + for key, value in update_data.items(): + setattr(tool, key, value) + + # If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema + if "source_code" in update_data.keys() and "json_schema" not in update_data.keys(): + pydantic_tool = tool.to_pydantic() + new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code) + + tool.json_schema = new_schema + tool.name = new_schema["name"] + + if updated_tool_type: + tool.tool_type = updated_tool_type + + # Save the updated tool to the database + tool = await tool.update_async(db_session=session, actor=actor) + return tool.to_pydantic() + @enforce_types def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: """Delete a tool by its ID.""" @@ -218,7 +309,7 @@ class ToolManager: def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: """Add default tools in base.py and multi_agent.py""" functions_to_schema = {} - module_names = ["base", "multi_agent", "voice"] + module_names = ["base", "multi_agent", "voice", "builtin"] for module_name in module_names: full_module_name = f"letta.functions.function_sets.{module_name}" @@ -254,6 +345,9 @@ class ToolManager: elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS: tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE tags = [tool_type.value] + elif name in BUILTIN_TOOLS: + tool_type = ToolType.LETTA_BUILTIN + tags = [tool_type.value] else: raise ValueError( f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}" @@ -275,3 +369,68 @@ class ToolManager: # TODO: Delete any base tools that are stale return tools + + @enforce_types + async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]: + """Add default tools in base.py and multi_agent.py""" + functions_to_schema = {} + module_names = ["base", "multi_agent", "voice", "builtin"] + + for module_name in module_names: + full_module_name = f"letta.functions.function_sets.{module_name}" + try: + module = importlib.import_module(full_module_name) + except Exception as e: + # Handle other general exceptions + raise e + + try: + # Load the function set + functions_to_schema.update(load_function_set(module)) + except ValueError as 
e: + err = f"Error loading function set '{module_name}': {e}" + warnings.warn(err) + + # create tool in db + tools = [] + for name, schema in functions_to_schema.items(): + if name in LETTA_TOOL_SET: + if name in BASE_TOOLS: + tool_type = ToolType.LETTA_CORE + tags = [tool_type.value] + elif name in BASE_MEMORY_TOOLS: + tool_type = ToolType.LETTA_MEMORY_CORE + tags = [tool_type.value] + elif name in MULTI_AGENT_TOOLS: + tool_type = ToolType.LETTA_MULTI_AGENT_CORE + tags = [tool_type.value] + elif name in BASE_SLEEPTIME_TOOLS: + tool_type = ToolType.LETTA_SLEEPTIME_CORE + tags = [tool_type.value] + elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS: + tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE + tags = [tool_type.value] + elif name in BUILTIN_TOOLS: + tool_type = ToolType.LETTA_BUILTIN + tags = [tool_type.value] + else: + raise ValueError( + f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}" + ) + + # create to tool + tools.append( + self.create_or_update_tool_async( + PydanticTool( + name=name, + tags=tags, + source_type="python", + tool_type=tool_type, + return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT, + ), + actor=actor, + ) + ) + + # TODO: Delete any base tools that are stale + return await asyncio.gather(*tools) diff --git a/letta/services/tool_sandbox/e2b_sandbox.py b/letta/services/tool_sandbox/e2b_sandbox.py index ee1703d5f..2307ea0a1 100644 --- a/letta/services/tool_sandbox/e2b_sandbox.py +++ b/letta/services/tool_sandbox/e2b_sandbox.py @@ -6,6 +6,7 @@ from letta.schemas.sandbox_config import SandboxConfig, SandboxType from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.services.tool_sandbox.base import AsyncToolSandboxBase +from letta.tracing import log_event, trace_method from letta.utils import get_friendly_error_msg logger = get_logger(__name__) @@ -27,6 +28,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars) self.force_recreate = force_recreate + @trace_method async def run( self, agent_state: Optional[AgentState] = None, @@ -44,6 +46,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): return result + @trace_method async def run_e2b_sandbox( self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None ) -> ToolExecutionResult: @@ -81,10 +84,21 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): env_vars.update(additional_env_vars) code = self.generate_execution_script(agent_state=agent_state) + log_event( + "e2b_execution_started", + {"tool": self.tool_name, "sandbox_id": e2b_sandbox.sandbox_id, "code": code, "env_vars": env_vars}, + ) execution = await e2b_sandbox.run_code(code, envs=env_vars) - if execution.results: func_return, agent_state = self.parse_best_effort(execution.results[0].text) + log_event( + "e2b_execution_succeeded", + { + "tool": self.tool_name, + "sandbox_id": e2b_sandbox.sandbox_id, + "func_return": func_return, + }, + ) elif execution.error: logger.error(f"Executing tool {self.tool_name} raised a {execution.error.name} with message: \n{execution.error.value}") logger.error(f"Traceback from e2b sandbox: \n{execution.error.traceback}") @@ -92,7 +106,25 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): function_name=self.tool_name, 
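                 # the sandbox exception is flattened into an LLM-friendly string;
                 # the raw traceback is appended to stderr below for operators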
exception_name=execution.error.name, exception_message=execution.error.value ) execution.logs.stderr.append(execution.error.traceback) + log_event( + "e2b_execution_failed", + { + "tool": self.tool_name, + "sandbox_id": e2b_sandbox.sandbox_id, + "error_type": execution.error.name, + "error_message": execution.error.value, + "func_return": func_return, + }, + ) else: + log_event( + "e2b_execution_empty", + { + "tool": self.tool_name, + "sandbox_id": e2b_sandbox.sandbox_id, + "status": "no_results_no_error", + }, + ) raise ValueError(f"Tool {self.tool_name} returned execution with None") return ToolExecutionResult( @@ -110,24 +142,54 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): exception_class = builtins_dict.get(e2b_execution.error.name, Exception) return exception_class(e2b_execution.error.value) + @trace_method async def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox": from e2b_code_interpreter import AsyncSandbox state_hash = sandbox_config.fingerprint() e2b_config = sandbox_config.get_e2b_config() + log_event( + "e2b_sandbox_create_started", + { + "sandbox_fingerprint": state_hash, + "e2b_config": e2b_config.model_dump(), + }, + ) + if e2b_config.template: sbx = await AsyncSandbox.create(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}) else: - # no template sbx = await AsyncSandbox.create( metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"}) ) - # install pip requirements + log_event( + "e2b_sandbox_create_finished", + { + "sandbox_id": sbx.sandbox_id, + "sandbox_fingerprint": state_hash, + }, + ) + if e2b_config.pip_requirements: for package in e2b_config.pip_requirements: + log_event( + "e2b_pip_install_started", + { + "sandbox_id": sbx.sandbox_id, + "package": package, + }, + ) await sbx.commands.run(f"pip install {package}") + log_event( + "e2b_pip_install_finished", + { + "sandbox_id": sbx.sandbox_id, + "package": package, + }, + ) + return sbx async def list_running_e2b_sandboxes(self): diff --git a/letta/settings.py b/letta/settings.py index 562c5d702..063114328 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -15,6 +15,9 @@ class ToolSettings(BaseSettings): e2b_api_key: Optional[str] = None e2b_sandbox_template_id: Optional[str] = None # Updated manually + # Tavily search + tavily_api_key: Optional[str] = None + # Local Sandbox configurations tool_exec_dir: Optional[str] = None tool_sandbox_timeout: float = 180 @@ -95,6 +98,7 @@ class ModelSettings(BaseSettings): # anthropic anthropic_api_key: Optional[str] = None + anthropic_max_retries: int = 3 # ollama ollama_base_url: Optional[str] = None @@ -175,11 +179,14 @@ class Settings(BaseSettings): pg_host: Optional[str] = None pg_port: Optional[int] = None pg_uri: Optional[str] = default_pg_uri # option to specify full uri - pg_pool_size: int = 80 # Concurrent connections - pg_max_overflow: int = 30 # Overflow limit + pg_pool_size: int = 25 # Concurrent connections + pg_max_overflow: int = 10 # Overflow limit pg_pool_timeout: int = 30 # Seconds to wait for a connection pg_pool_recycle: int = 1800 # When to recycle connections pg_echo: bool = False # Logging + pool_pre_ping: bool = True # Pre ping to check for dead connections + pool_use_lifo: bool = True + disable_sqlalchemy_pooling: bool = False # multi agent settings multi_agent_send_message_max_retries: int = 3 @@ -190,6 +197,7 @@ class Settings(BaseSettings): verbose_telemetry_logging: bool = False 
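     # llm_api_logging (added below) appears to gate persisting raw provider
     # request/response payloads via the new provider_trace tables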
otel_exporter_otlp_endpoint: Optional[str] = None # otel default: "http://localhost:4317" disable_tracing: bool = False + llm_api_logging: bool = True # uvicorn settings uvicorn_workers: int = 1 diff --git a/letta/tracing.py b/letta/tracing.py index b4304a6c6..ec4db848d 100644 --- a/letta/tracing.py +++ b/letta/tracing.py @@ -19,11 +19,11 @@ from opentelemetry.trace import Status, StatusCode tracer = trace.get_tracer(__name__) _is_tracing_initialized = False _excluded_v1_endpoints_regex: List[str] = [ - "^GET /v1/agents/(?P[^/]+)/messages$", - "^GET /v1/agents/(?P[^/]+)/context$", - "^GET /v1/agents/(?P[^/]+)/archival-memory$", - "^GET /v1/agents/(?P[^/]+)/sources$", - r"^POST /v1/voice-beta/.*/chat/completions$", + # "^GET /v1/agents/(?P[^/]+)/messages$", + # "^GET /v1/agents/(?P[^/]+)/context$", + # "^GET /v1/agents/(?P[^/]+)/archival-memory$", + # "^GET /v1/agents/(?P[^/]+)/sources$", + # r"^POST /v1/voice-beta/.*/chat/completions$", ] diff --git a/poetry.lock b/poetry.lock index 6d001a9c8..fb13cea48 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2123,15 +2123,15 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] [[package]] name = "google-genai" -version = "1.10.0" +version = "1.15.0" description = "GenAI Python SDK" optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"google\"" files = [ - {file = "google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f"}, - {file = "google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f"}, + {file = "google_genai-1.15.0-py3-none-any.whl", hash = "sha256:6d7f149cc735038b680722bed495004720514c234e2a445ab2f27967955071dd"}, + {file = "google_genai-1.15.0.tar.gz", hash = "sha256:118bb26960d6343cd64f1aeb5c2b02144a36ad06716d0d1eb1fa3e0904db51f1"}, ] [package.dependencies] @@ -6658,6 +6658,23 @@ files = [ {file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"}, ] +[[package]] +name = "tavily-python" +version = "0.7.2" +description = "Python wrapper for the Tavily API" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "tavily_python-0.7.2-py3-none-any.whl", hash = "sha256:0d7cc8b1a2f95ac10cf722094c3b5807aade67cc7750f7ca605edef7455d4c62"}, + {file = "tavily_python-0.7.2.tar.gz", hash = "sha256:34f713002887df2b5e6b8d7db7bc64ae107395bdb5f53611e80a89dac9cbdf19"}, +] + +[package.dependencies] +httpx = "*" +requests = "*" +tiktoken = ">=0.5.1" + [[package]] name = "tenacity" version = "9.1.2" @@ -7570,4 +7587,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.1" python-versions = "<3.14,>=3.10" -content-hash = "19eee9b3cd3d270cb748183bc332dd69706bb0bd3150c62e73e61ed437a40c78" +content-hash = "837f6a25033a01cca117f4c61bcf973bc6ccfcda442615bbf4af038061bf88ce" diff --git a/pyproject.toml b/pyproject.toml index df2c1a93b..76a30ad29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.7.20" +version = "0.7.21" packages = [ {include = "letta"}, ] @@ -79,7 +79,7 @@ opentelemetry-api = "1.30.0" opentelemetry-sdk = "1.30.0" opentelemetry-instrumentation-requests = "0.51b0" opentelemetry-exporter-otlp = "1.30.0" -google-genai = {version = "^1.1.0", optional = true} +google-genai = {version = "^1.15.0", optional = true} faker = "^36.1.0" colorama = "^0.4.6" marshmallow-sqlalchemy = "^1.4.1" @@ -91,6 +91,7 @@ apscheduler = "^3.11.0" aiomultiprocess = "^0.9.1" matplotlib = "^3.10.1" 
asyncpg = "^0.30.0" +tavily-python = "^0.7.2" [tool.poetry.extras] diff --git a/tests/configs/llm_model_configs/gemini-2.5-pro-vertex.json b/tests/configs/llm_model_configs/gemini-2.5-pro-vertex.json index 0cf5d3b04..9967e64fa 100644 --- a/tests/configs/llm_model_configs/gemini-2.5-pro-vertex.json +++ b/tests/configs/llm_model_configs/gemini-2.5-pro-vertex.json @@ -1,5 +1,5 @@ { - "model": "gemini-2.5-pro-exp-03-25", + "model": "gemini-2.5-pro-preview-05-06", "model_endpoint_type": "google_vertex", "model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1", "context_window": 1048576, diff --git a/tests/configs/llm_model_configs/together-qwen-2.5-72b-instruct.json b/tests/configs/llm_model_configs/together-qwen-2.5-72b-instruct.json new file mode 100644 index 000000000..18dd9774f --- /dev/null +++ b/tests/configs/llm_model_configs/together-qwen-2.5-72b-instruct.json @@ -0,0 +1,7 @@ +{ + "context_window": 16000, + "model": "Qwen/Qwen2.5-72B-Instruct-Turbo", + "model_endpoint_type": "together", + "model_endpoint": "https://api.together.ai/v1", + "model_wrapper": "chatml" +} diff --git a/tests/conftest.py b/tests/conftest.py index e44d2fecc..cb25bb85b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,7 @@ def check_composio_key_set(): yield +# --- Tool Fixtures --- @pytest.fixture def weather_tool_func(): def get_weather(location: str) -> str: @@ -110,6 +111,23 @@ def print_tool_func(): yield print_tool +@pytest.fixture +def roll_dice_tool_func(): + def roll_dice(): + """ + Rolls a 6 sided die. + + Returns: + str: The roll result. + """ + import time + + time.sleep(1) + return "Rolled a 10!" + + yield roll_dice + + @pytest.fixture def dummy_beta_message_batch() -> BetaMessageBatch: return BetaMessageBatch( diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 786bad136..406d06cde 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock import pytest from anthropic.types import BetaErrorResponse, BetaRateLimitError -from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage +from anthropic.types.beta import BetaMessage from anthropic.types.beta.messages import ( BetaMessageBatch, BetaMessageBatchErroredResult, @@ -53,7 +53,7 @@ def _run_server(): start_server(debug=True) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server_url(): """Ensures a server is running and returns its base URL.""" url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") @@ -255,148 +255,7 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_ # ----------------------------- # End-to-End Test # ----------------------------- -@pytest.mark.asyncio(loop_scope="session") -async def test_polling_simple_real_batch(default_user, server): - # --- Step 1: Prepare test data --- - # Create batch responses with different statuses - # NOTE: This is a REAL batch id! 
- # For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ - batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended") - - # Create test agents - agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0") - agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d") - agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56") - - # --- Step 2: Create batch jobs --- - job_a = await create_test_llm_batch_job_async(server, batch_a_resp, default_user) - - # --- Step 3: Create batch items --- - item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) - item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user) - item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user) - - print("HI") - print(agent_a.id) - print(agent_b.id) - print(agent_c.id) - print("BYE") - - # --- Step 4: Run the polling job --- - await poll_running_llm_batches(server) - - # --- Step 5: Verify batch job status updates --- - updated_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user) - - assert updated_job_a.status == JobStatus.completed - - # Both jobs should have been polled - assert updated_job_a.last_polled_at is not None - assert updated_job_a.latest_polling_response is not None - - # --- Step 7: Verify batch item status updates --- - # Item A should be marked as completed with a successful result - updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) - assert updated_item_a.request_status == JobStatus.completed - assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse( - custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0", - result=BetaMessageBatchSucceededResult( - message=BetaMessage( - id="msg_01T1iSejDS5qENRqqEZauMHy", - content=[ - BetaToolUseBlock( - id="toolu_01GKUYVWcajjTaE1stxZZHcG", - input={ - "inner_thoughts": "First login detected. Time to make a great first impression!", - "message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?", - "request_heartbeat": False, - }, - name="send_message", - type="tool_use", - ) - ], - model="claude-3-5-haiku-20241022", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94), - ), - type="succeeded", - ), - ) - - # Item B should be marked as completed with a successful result - updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) - assert updated_item_b.request_status == JobStatus.completed - assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse( - custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d", - result=BetaMessageBatchSucceededResult( - message=BetaMessage( - id="msg_01N2ZfxpbjdoeofpufUFPCMS", - content=[ - BetaTextBlock( - citations=None, text="User first login detected. Initializing persona.", type="text" - ), - BetaToolUseBlock( - id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C", - input={ - "label": "persona", - "content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. 
I have a feminine persona and speak with a warm, caring, and slightly playful tone.", - "request_heartbeat": True, - }, - name="core_memory_append", - type="tool_use", - ), - ], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153), - ), - type="succeeded", - ), - ) - - # Item C should be marked as failed with an error result - updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) - assert updated_item_c.request_status == JobStatus.completed - assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse( - custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56", - result=BetaMessageBatchSucceededResult( - message=BetaMessage( - id="msg_01RL2g4aBgbZPeaMEokm6HZm", - content=[ - BetaTextBlock( - citations=None, - text="First time meeting this user. I should introduce myself and establish a friendly connection.", - type="text", - ), - BetaToolUseBlock( - id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ", - input={ - "message": "Hey there! I'm Letta. Really nice to meet you! I love getting to know new people - what brings you here today?", - "request_heartbeat": False, - }, - name="send_message", - type="tool_use", - ), - ], - model="claude-3-5-sonnet-20241022", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111), - ), - type="succeeded", - ), - ) - - -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_polling_mixed_batch_jobs(default_user, server): """ End-to-end test for polling batch jobs with mixed statuses and idempotency. diff --git a/tests/integration_test_builtin_tools.py b/tests/integration_test_builtin_tools.py new file mode 100644 index 000000000..402fd54e2 --- /dev/null +++ b/tests/integration_test_builtin_tools.py @@ -0,0 +1,206 @@ +import json +import os +import threading +import time +import uuid +from typing import List + +import pytest +import requests +from dotenv import load_dotenv +from letta_client import Letta, MessageCreate +from letta_client.types import ToolReturnMessage + +from letta.schemas.agent import AgentState +from letta.schemas.llm_config import LLMConfig +from letta.settings import settings + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. 
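+    Experimental features are enabled for the duration of the module and
+    restored on teardown.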
+ """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + temp = settings.use_experimental + settings.use_experimental = True + yield url + settings.use_experimental = temp + + +@pytest.fixture(scope="module") +def client(server_url: str) -> Letta: + """ + Creates and returns a synchronous Letta REST client for testing. + """ + client_instance = Letta(base_url=server_url) + yield client_instance + + +@pytest.fixture(scope="module") +def agent_state(client: Letta) -> AgentState: + """ + Creates and returns an agent state for testing with a pre-configured agent. + The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. + """ + client.tools.upsert_base_tools() + + send_message_tool = client.tools.list(name="send_message")[0] + run_code_tool = client.tools.list(name="run_code")[0] + web_search_tool = client.tools.list(name="web_search")[0] + agent_state_instance = client.agents.create( + name="supervisor", + include_base_tools=False, + tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id], + model="openai/gpt-4o", + embedding="letta/letta-free", + tags=["supervisor"], + ) + yield agent_state_instance + + client.agents.delete(agent_state_instance.id) + + +# ------------------------------ +# Helper Functions and Constants +# ------------------------------ + + +def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig: + filename = os.path.join(llm_config_dir, filename) + config_data = json.load(open(filename, "r")) + llm_config = LLMConfig(**config_data) + return llm_config + + +USER_MESSAGE_OTID = str(uuid.uuid4()) +all_configs = [ + "openai-gpt-4o-mini.json", +] +requested = os.getenv("LLM_CONFIG_FILE") +filenames = [requested] if requested else all_configs +TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] + +TEST_LANGUAGES = ["Python", "Javascript", "Typescript"] +EXPECTED_INTEGER_PARTITION_OUTPUT = "190569292" + + +# Reference implementation in Python, to embed in the user prompt +REFERENCE_CODE = """\ +def reference_partition(n): + partitions = [1] + [0] * (n + 1) + for k in range(1, n + 1): + for i in range(k, n + 1): + partitions[i] += partitions[i - k] + return partitions[n] +""" + + +def reference_partition(n: int) -> int: + # Same logic, used to compute expected result in the test + partitions = [1] + [0] * (n + 1) + for k in range(1, n + 1): + for i in range(k, n + 1): + partitions[i] += partitions[i - k] + return partitions[n] + + +# ------------------------------ +# Test Cases +# ------------------------------ + + +@pytest.mark.parametrize("language", TEST_LANGUAGES, ids=TEST_LANGUAGES) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS]) +def test_run_code( + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, + language: str, +) -> None: + """ + Sends 
a reference Python implementation, asks the model to translate & run it + in different languages, and verifies the exact partition(100) result. + """ + expected = str(reference_partition(100)) + + user_message = MessageCreate( + role="user", + content=( + "Here is a Python reference implementation:\n\n" + f"{REFERENCE_CODE}\n" + f"Please translate and execute this code in {language} to compute p(100), " + "and return **only** the result with no extra formatting." + ), + otid=USER_MESSAGE_OTID, + ) + + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[user_message], + ) + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert tool_returns, f"No ToolReturnMessage found for language: {language}" + + returns = [m.tool_return for m in tool_returns] + assert any(expected in ret for ret in returns), ( + f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}" + ) + + +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS]) +def test_web_search( + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + user_message = MessageCreate( + role="user", + content=("Use the web search tool to find the latest news about San Francisco."), + otid=USER_MESSAGE_OTID, + ) + + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[user_message], + ) + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert tool_returns, "No ToolReturnMessage found" + + returns = [m.tool_return for m in tool_returns] + expected = "RESULT 1:" + assert any(expected in ret for ret in returns), f"Expected to find '{expected}' in tool_return, " f"but got {returns!r}" diff --git a/tests/integration_test_composio.py b/tests/integration_test_composio.py index e1219d1ea..ba700f569 100644 --- a/tests/integration_test_composio.py +++ b/tests/integration_test_composio.py @@ -67,9 +67,14 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_ actor=default_user, ) - tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool( - function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis - ) + tool_execution_result = await ToolExecutionManager( + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + agent_state=agent_state, + actor=default_user, + ).execute_tool(function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis) # Small check, it should return something at least assert len(tool_execution_result.func_return.keys()) > 10 diff --git a/tests/integration_test_multi_agent.py b/tests/integration_test_multi_agent.py index f53f33f19..a4a464b56 100644 --- a/tests/integration_test_multi_agent.py +++ b/tests/integration_test_multi_agent.py @@ -1,56 +1,120 @@ import json +import os +import threading +import time import pytest +import requests +from dotenv import load_dotenv +from letta_client import Letta -from letta import LocalClient, create_client +from letta.config import LettaConfig from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_message import SystemMessage, ToolReturnMessage -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory 
from letta.schemas.tool import Tool +from letta.server.server import SyncServer from letta.services.agent_manager import AgentManager +from letta.settings import settings from tests.helpers.utils import retry_until_success from tests.utils import wait_for_incoming_message -@pytest.fixture(scope="function") -def client(): - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. + """ - yield client + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + temp = settings.use_experimental + settings.use_experimental = True + yield url + settings.use_experimental = temp + + +@pytest.fixture(scope="module") +def server(): + config = LettaConfig.load() + print("CONFIG PATH", config.config_path) + + config.save() + + server = SyncServer() + return server + + +@pytest.fixture(scope="module") +def client(server_url: str) -> Letta: + """ + Creates and returns a synchronous Letta REST client for testing. 
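+    Base tools are upserted once per module so the multi-agent tools are
+    available to every test.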
+ """ + client_instance = Letta(base_url=server_url) + client_instance.tools.upsert_base_tools() + yield client_instance @pytest.fixture(autouse=True) def remove_stale_agents(client): - stale_agents = AgentManager().list_agents(actor=client.user, limit=300) + stale_agents = client.agents.list(limit=300) for agent in stale_agents: - client.delete_agent(agent_id=agent.id) + client.agents.delete(agent_id=agent.id) @pytest.fixture(scope="function") -def agent_obj(client: LocalClient): +def agent_obj(client): """Create a test agent that we can call functions on""" - send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply") - agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id]) + send_message_to_agent_tool = client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0] + agent_state_instance = client.agents.create( + include_base_tools=True, + tool_ids=[send_message_to_agent_tool.id], + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) + yield agent_state_instance - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - yield agent_obj - - # client.delete_agent(agent_obj.agent_state.id) + client.agents.delete(agent_state_instance.id) @pytest.fixture(scope="function") -def other_agent_obj(client: LocalClient): +def other_agent_obj(client): """Create another test agent that we can call functions on""" - agent_state = client.create_agent(include_multi_agent_tools=False) + agent_state_instance = client.agents.create( + include_base_tools=True, + include_multi_agent_tools=False, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) - other_agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - yield other_agent_obj + yield agent_state_instance - client.delete_agent(other_agent_obj.agent_state.id) + client.agents.delete(agent_state_instance.id) @pytest.fixture @@ -77,48 +141,68 @@ def roll_dice_tool(client): tool.json_schema = derived_json_schema tool.name = derived_name - tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user) + tool = client.tools.upsert_from_function(func=roll_dice) # Yield the created tool yield tool @retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_send_message_to_agent(client, agent_obj, other_agent_obj): +def test_send_message_to_agent(client, server, agent_obj, other_agent_obj): secret_word = "banana" + actor = server.user_manager.get_user_or_default() # Encourage the agent to send a message to the other agent_obj with the secret string - client.send_message( - agent_id=agent_obj.agent_state.id, - role="user", - message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!", + client.agents.messages.create( + agent_id=agent_obj.id, + messages=[ + { + "role": "user", + "content": f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!", + } + ], ) # Conversation search the other agent - messages = client.get_messages(other_agent_obj.agent_state.id) + messages = server.get_agent_recall( + user_id=actor.id, + agent_id=other_agent_obj.id, + reverse=True, + return_message_object=False, + ) + # Check for the presence of system message for m in reversed(messages): - print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}") + print(f"\n\n {other_agent_obj.id} -> 
{m.model_dump_json(indent=4)}") if isinstance(m, SystemMessage): assert secret_word in m.content break # Search the sender agent for the response from another agent - in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) + in_context_messages = AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor) found = False - target_snippet = f"{other_agent_obj.agent_state.id} said:" + target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': [" for m in in_context_messages: if target_snippet in m.content[0].text: found = True break - print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.content[0].text for m in in_context_messages[1:]])}") + joined = "\n".join([m.content[0].text for m in in_context_messages[1:]]) + print(f"In context messages of the sender agent (without system):\n\n{joined}") if not found: raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}") # Test that the agent can still receive messages fine - response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?") + response = client.agents.messages.create( + agent_id=agent_obj.id, + messages=[ + { + "role": "user", + "content": "So what did the other agent say?", + } + ], + ) print(response.messages) @@ -127,39 +211,50 @@ def test_send_message_to_agents_with_tags_simple(client): worker_tags_123 = ["worker", "user-123"] worker_tags_456 = ["worker", "user-456"] - # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents( - client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True - ) - for agent in prev_worker_agents: - client.delete_agent(agent.id) - secret_word = "banana" # Create "manager" agent - send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") - manager_agent_state = client.create_agent(name="manager_agent", tool_ids=[send_message_to_agents_matching_tags_tool_id]) - manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) + send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id + manager_agent_state = client.agents.create( + name="manager_agent", + tool_ids=[send_message_to_agents_matching_tags_tool_id], + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) # Create 3 non-matching worker agents (These should NOT get the message) worker_agents_123 = [] for idx in range(2): - worker_agent_state = client.create_agent(name=f"not_worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_123) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents_123.append(worker_agent) + worker_agent_state = client.agents.create( + name=f"not_worker_{idx}", + include_multi_agent_tools=False, + tags=worker_tags_123, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) + worker_agents_123.append(worker_agent_state) # Create 3 worker agents that should get the message worker_agents_456 = [] for idx in range(2): - worker_agent_state = client.create_agent(name=f"worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_456) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents_456.append(worker_agent) + worker_agent_state = client.agents.create( + 
name=f"worker_{idx}", + include_multi_agent_tools=False, + tags=worker_tags_456, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) + worker_agents_456.append(worker_agent_state) # Encourage the manager to send a message to the other agent_obj with the secret string - response = client.send_message( - agent_id=manager_agent.agent_state.id, - role="user", - message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!", + response = client.agents.messages.create( + agent_id=manager_agent_state.id, + messages=[ + { + "role": "user", + "content": f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!", + } + ], ) for m in response.messages: @@ -172,62 +267,70 @@ def test_send_message_to_agents_with_tags_simple(client): break # Conversation search the worker agents - for agent in worker_agents_456: - messages = client.get_messages(agent.agent_state.id) + for agent_state in worker_agents_456: + messages = client.agents.messages.list(agent_state.id) # Check for the presence of system message for m in reversed(messages): - print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}") + print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}") if isinstance(m, SystemMessage): assert secret_word in m.content break # Ensure it's NOT in the non matching worker agents - for agent in worker_agents_123: - messages = client.get_messages(agent.agent_state.id) + for agent_state in worker_agents_123: + messages = client.agents.messages.list(agent_state.id) # Check for the presence of system message for m in reversed(messages): - print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}") + print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}") if isinstance(m, SystemMessage): assert secret_word not in m.content # Test that the agent can still receive messages fine - response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") + response = client.agents.messages.create( + agent_id=manager_agent_state.id, + messages=[ + { + "role": "user", + "content": "So what did the other agent say?", + } + ], + ) print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) - # Clean up agents - client.delete_agent(manager_agent_state.id) - for agent in worker_agents_456 + worker_agents_123: - client.delete_agent(agent.agent_state.id) - @retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool): - worker_tags = ["dice-rollers"] - - # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True) - for agent in prev_worker_agents: - client.delete_agent(agent.id) - # Create "manager" agent - send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") - manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id]) - manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) + send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id + manager_agent_state = client.agents.create( + tool_ids=[send_message_to_agents_matching_tags_tool_id], + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) # Create 3 worker 
agents worker_agents = [] worker_tags = ["dice-rollers"] for _ in range(2): - worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id]) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) + worker_agent_state = client.agents.create( + include_multi_agent_tools=False, + tags=worker_tags, + tool_ids=[roll_dice_tool.id], + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ) + worker_agents.append(worker_agent_state) # Encourage the manager to send a message to the other agent_obj with the secret string broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" - response = client.send_message( - agent_id=manager_agent.agent_state.id, - role="user", - message=broadcast_message, + response = client.agents.messages.create( + agent_id=manager_agent_state.id, + messages=[ + { + "role": "user", + "content": broadcast_message, + } + ], ) for m in response.messages: @@ -240,47 +343,65 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too break # Test that the agent can still receive messages fine - response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?") + response = client.agents.messages.create( + agent_id=manager_agent_state.id, + messages=[ + { + "role": "user", + "content": "So what did the other agent say?", + } + ], + ) print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages])) - # Clean up agents - client.delete_agent(manager_agent_state.id) - for agent in worker_agents: - client.delete_agent(agent.agent_state.id) - -@retry_until_success(max_attempts=5, sleep_time_seconds=2) +# @retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_agents_async_simple(client): """ Test two agents with multi-agent tools sending messages back and forth to count to 5. The chain is started by prompting one of the agents. 
""" - # Cleanup from potentially failed previous runs - existing_agents = client.server.agent_manager.list_agents(client.user) - for agent in existing_agents: - client.delete_agent(agent.id) - # Create two agents with multi-agent tools - send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async") - memory_a = ChatMemory( - human="Chad - I'm interested in hearing poem.", - persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).", + send_message_to_agent_async_tool_id = client.tools.list(name="send_message_to_agent_async")[0].id + charles_state = client.agents.create( + name="charles", + tool_ids=[send_message_to_agent_async_tool_id], + memory_blocks=[ + { + "label": "human", + "value": "Chad - I'm interested in hearing poem.", + }, + { + "label": "persona", + "value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).", + }, + ], + model="openai/gpt-4o-mini", + embedding="letta/letta-free", ) - charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id]) - charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user) - memory_b = ChatMemory( - human="No human - you are to only communicate with the other AI agent.", - persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.", + sarah_state = client.agents.create( + name="sarah", + tool_ids=[send_message_to_agent_async_tool_id], + memory_blocks=[ + { + "label": "human", + "value": "No human - you are to only communicate with the other AI agent.", + }, + { + "label": "persona", + "value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.", + }, + ], + model="openai/gpt-4o-mini", + embedding="letta/letta-free", ) - sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id]) # Start the count chain with Agent1 initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me." 
- client.send_message( - agent_id=charles.agent_state.id, - role="user", - message=initial_prompt, + client.agents.messages.create( + agent_id=charles_state.id, + messages=[{"role": "user", "content": initial_prompt}], ) found_in_charles = wait_for_incoming_message( diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index d30a64185..a7cc37f06 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -135,6 +135,7 @@ all_configs = [ "gemini-1.5-pro.json", "gemini-2.5-flash-vertex.json", "gemini-2.5-pro-vertex.json", + "together-qwen-2.5-72b-instruct.json", ] requested = os.getenv("LLM_CONFIG_FILE") filenames = [requested] if requested else all_configs @@ -170,6 +171,10 @@ def assert_greeting_with_assistant_message_response( if streaming: assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 def assert_greeting_without_assistant_message_response( @@ -636,6 +641,33 @@ async def test_streaming_tool_call_async_client( assert_tool_call_response(messages, streaming=True) +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_step_streaming_greeting_with_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests step streaming (stream_tokens=False) with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.create_stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_GREETING, + stream_tokens=False, + ) + messages = [] + for message in response: + messages.append(message) + assert_greeting_with_assistant_message_response(messages, streaming=True) + + @pytest.mark.parametrize( "llm_config", TESTED_LLM_CONFIGS, diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 205ef619c..2d5f9bf27 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -6,7 +6,7 @@ from sqlalchemy import delete from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 -from letta.orm import Provider, Step +from letta.orm import Provider, ProviderTrace, Step from letta.orm.enums import JobType from letta.orm.errors import NoResultFound from letta.schemas.agent import CreateAgent @@ -39,6 +39,7 @@ def org_id(server): # cleanup with db_registry.session() as session: + session.execute(delete(ProviderTrace)) session.execute(delete(Step)) session.execute(delete(Provider)) session.commit() @@ -54,7 +55,7 @@ def actor(server, org_id): server.user_manager.delete_user_by_id(user.id) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="module") async def test_sleeptime_group_chat(server, actor): # 0. Refresh base tools server.tool_manager.upsert_base_tools(actor=actor) @@ -105,7 +106,7 @@ async def test_sleeptime_group_chat(server, actor): # 3. 
Verify shared blocks sleeptime_agent_id = group.agent_ids[0] shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor) - agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor) assert len(agents) == 2 assert sleeptime_agent_id in [agent.id for agent in agents] assert main_agent.id in [agent.id for agent in agents] @@ -169,7 +170,7 @@ async def test_sleeptime_group_chat(server, actor): server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="module") async def test_sleeptime_group_chat_v2(server, actor): # 0. Refresh base tools server.tool_manager.upsert_base_tools(actor=actor) @@ -220,7 +221,7 @@ async def test_sleeptime_group_chat_v2(server, actor): # 3. Verify shared blocks sleeptime_agent_id = group.agent_ids[0] shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor) - agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor) assert len(agents) == 2 assert sleeptime_agent_id in [agent.id for agent in agents] assert main_agent.id in [agent.id for agent in agents] @@ -292,7 +293,7 @@ async def test_sleeptime_group_chat_v2(server, actor): @pytest.mark.skip -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="module") async def test_sleeptime_removes_redundant_information(server, actor): # 1. set up sleep-time agent as in test_sleeptime_group_chat server.tool_manager.upsert_base_tools(actor=actor) @@ -360,7 +361,7 @@ async def test_sleeptime_removes_redundant_information(server, actor): server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="module") async def test_sleeptime_edit(server, actor): sleeptime_agent = server.create_agent( request=CreateAgent( diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 246611dda..ccce79af2 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -1,10 +1,10 @@ import os -import threading +import subprocess +import sys from unittest.mock import MagicMock import pytest from dotenv import load_dotenv -from letta_client import AsyncLetta from openai import AsyncOpenAI from openai.types.chat import ChatCompletionChunk @@ -35,7 +35,7 @@ from letta.services.summarizer.summarizer import Summarizer from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import get_persona_text -from tests.utils import wait_for_server +from tests.utils import create_tool_from_func, wait_for_server MESSAGE_TRANSCRIPTS = [ "user: Hey, I’ve been thinking about planning a road trip up the California coast next month.", @@ -92,17 +92,6 @@ You’re a memory-recall helper for an AI that can only keep the last 4 messages # --- Server Management --- # -@pytest.fixture(scope="module") -def server(): - config = LettaConfig.load() - print("CONFIG PATH", config.config_path) - - config.save() - - server = SyncServer() - return server - - def _run_server(): """Starts the Letta server in a background thread.""" load_dotenv() @@ -111,31 +100,66 @@ def _run_server(): start_server(debug=True) 
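+# NOTE: fixture scopes in this file were narrowed from "session" to "module" to line up with @pytest.mark.asyncio(loop_scope="module") on the tests below; mixing loop scopes can bind async fixtures to a different event loop than the tests that consume them.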
-@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server_url(): - """Ensures a server is running and returns its base URL.""" - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + """ + Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI, + so its event loop and DB pool stay completely isolated from pytest-asyncio. + """ + url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283") + # Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - wait_for_server(url) # Allow server startup time + # Build the command to launch uvicorn on your FastAPI app + cmd = [ + sys.executable, + "-m", + "uvicorn", + "letta.server.rest_api.app:app", + "--host", + "127.0.0.1", + "--port", + "8283", + ] + # If you need TLS or reload settings from start_server(), you can add + # "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well. - return url + server_proc = subprocess.Popen( + cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + # wait until the HTTP port is accepting connections + wait_for_server(url) + + yield url + + # Teardown: kill the subprocess if we started it + server_proc.terminate() + server_proc.wait(timeout=10) + else: + yield url + + +@pytest.fixture(scope="module") +def server(): + config = LettaConfig.load() + print("CONFIG PATH", config.config_path) + + config.save() + + server = SyncServer() + actor = server.user_manager.get_user_or_default() + server.tool_manager.upsert_base_tools(actor=actor) + return server # --- Client Setup --- # -@pytest.fixture(scope="session") -def client(server_url): - """Creates a REST client for testing.""" - client = AsyncLetta(base_url=server_url) - yield client - - -@pytest.fixture(scope="function") -async def roll_dice_tool(client): +@pytest.fixture +async def roll_dice_tool(server): def roll_dice(): """ Rolls a 6 sided die. @@ -145,13 +169,13 @@ async def roll_dice_tool(client): """ return "Rolled a 10!" - tool = await client.tools.upsert_from_function(func=roll_dice) - # Yield the created tool + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=roll_dice), actor=actor) yield tool -@pytest.fixture(scope="function") -async def weather_tool(client): +@pytest.fixture +async def weather_tool(server): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. 
@@ -176,22 +200,20 @@ async def weather_tool(client): else: raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - tool = await client.tools.upsert_from_function(func=get_weather) - # Yield the created tool + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor) yield tool -@pytest.fixture(scope="function") +@pytest.fixture def composio_gmail_get_profile_tool(default_user): tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user) yield tool -@pytest.fixture(scope="function") +@pytest.fixture def voice_agent(server, actor): - server.tool_manager.upsert_base_tools(actor=actor) - main_agent = server.create_agent( request=CreateAgent( agent_type=AgentType.voice_convo_agent, @@ -268,9 +290,9 @@ def _assert_valid_chunk(chunk, idx, chunks): # --- Tests --- # -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"]) -async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor): +async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, server_url, group_id, actor): request = _get_chat_request("How are you?") server.tool_manager.upsert_base_tools(actor=actor) @@ -303,10 +325,10 @@ async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, serv print(chunk.choices[0].delta.content) -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."]) @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) -async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint): +async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint, server_url): """Tests chat completion streaming using the Async OpenAI client.""" request = _get_chat_request(message) @@ -318,9 +340,9 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en print(chunk.choices[0].delta.content) -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) -async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor): +async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url): server.group_manager.modify_group( group_id=group_id, group_update=GroupUpdate( @@ -350,8 +372,8 @@ async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, g print(chunk.choices[0].delta.content) -@pytest.mark.asyncio(loop_scope="session") -async def test_summarization(disable_e2b_api_key, voice_agent): +@pytest.mark.asyncio(loop_scope="module") +async def test_summarization(disable_e2b_api_key, voice_agent, server_url): agent_manager = AgentManager() user_manager = UserManager() actor = user_manager.get_default_user() @@ -422,8 +444,8 @@ async def test_summarization(disable_e2b_api_key, voice_agent): summarizer.fire_and_forget.assert_called_once() -@pytest.mark.asyncio(loop_scope="session") -async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): +@pytest.mark.asyncio(loop_scope="module") +async def 
test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent, server_url): """Tests chat completion streaming using the Async OpenAI client.""" agent_manager = AgentManager() tool_manager = ToolManager() @@ -488,8 +510,8 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): assert not missing, f"Did not see calls to: {', '.join(missing)}" -@pytest.mark.asyncio(loop_scope="session") -async def test_init_voice_convo_agent(voice_agent, server, actor): +@pytest.mark.asyncio(loop_scope="module") +async def test_init_voice_convo_agent(voice_agent, server, actor, server_url): assert voice_agent.enable_sleeptime == True main_agent_tools = [tool.name for tool in voice_agent.tools] @@ -511,7 +533,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor): # 3. Verify shared blocks sleeptime_agent_id = group.agent_ids[0] shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor) - agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor) assert len(agents) == 2 assert sleeptime_agent_id in [agent.id for agent in agents] assert voice_agent.id in [agent.id for agent in agents] diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 7599e02e4..aa02e0dfa 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -1,12 +1,15 @@ import difflib import json import os +import threading +import time from datetime import datetime, timezone from io import BytesIO from typing import Any, Dict, List, Mapping import pytest -from fastapi.testclient import TestClient +import requests +from dotenv import load_dotenv from rich.console import Console from rich.syntax import Syntax @@ -23,11 +26,51 @@ from letta.schemas.message import MessageCreate from letta.schemas.organization import Organization from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema -from letta.server.rest_api.app import app from letta.server.server import SyncServer console = Console() +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. 
+ """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + return url + def _clear_tables(): from letta.server.db import db_context @@ -38,12 +81,6 @@ def _clear_tables(): session.commit() -@pytest.fixture -def fastapi_client(): - """Fixture to create a FastAPI test client.""" - return TestClient(app) - - @pytest.fixture(autouse=True) def clear_tables(): _clear_tables() @@ -57,14 +94,14 @@ def local_client(): yield client -@pytest.fixture(scope="module") +@pytest.fixture def server(): config = LettaConfig.load() config.save() server = SyncServer(init_with_default_org_and_user=False) - return server + yield server @pytest.fixture @@ -562,14 +599,17 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server @pytest.mark.parametrize("append_copy_suffix", [True, False]) @pytest.mark.parametrize("project_id", ["project-12345", None]) -def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id): +def test_agent_download_upload_flow(server, server_url, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id): """ Test the full E2E serialization and deserialization flow using FastAPI endpoints. """ agent_id = serialize_test_agent.id # Step 1: Download the serialized agent - response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id}) + response = requests.get( + f"{server_url}/v1/agents/{agent_id}/export", + headers={"user_id": default_user.id}, + ) assert response.status_code == 200, f"Download failed: {response.text}" # Ensure response matches expected schema @@ -580,10 +620,14 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent # Step 2: Upload the serialized agent as a copy agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8")) files = {"file": ("agent.json", agent_bytes, "application/json")} - upload_response = fastapi_client.post( - "/v1/agents/import", + upload_response = requests.post( + f"{server_url}/v1/agents/import", headers={"user_id": other_user.id}, - params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id}, + params={ + "append_copy_suffix": append_copy_suffix, + "override_existing_tools": False, + "project_id": project_id, + }, files=files, ) assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}" @@ -613,16 +657,16 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent "memgpt_agent_with_convo.af", ], ) -def test_upload_agentfile_from_disk(server, disable_e2b_api_key, fastapi_client, other_user, filename): +def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, other_user, filename): """ - Test uploading each .af file from the test_agent_files directory via FastAPI. 
+ Test uploading each .af file from the test_agent_files directory via a live FastAPI server. """ file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", filename) with open(file_path, "rb") as f: files = {"file": (filename, f, "application/json")} - response = fastapi_client.post( - "/v1/agents/import", + response = requests.post( + f"{server_url}/v1/agents/import", headers={"user_id": other_user.id}, params={"append_copy_suffix": True, "override_existing_tools": False}, files=files, diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 11da3a197..da2a6666c 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,5 +1,3 @@ -import os -import threading from datetime import datetime, timezone from typing import Tuple from unittest.mock import AsyncMock, patch @@ -14,15 +12,13 @@ from anthropic.types.beta.messages import ( BetaMessageBatchRequestCounts, BetaMessageBatchSucceededResult, ) -from dotenv import load_dotenv -from letta_client import Letta from letta.agents.letta_agent_batch import LettaAgentBatch from letta.config import LettaConfig from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base -from letta.schemas.agent import AgentState, AgentStepState +from letta.schemas.agent import AgentState, AgentStepState, CreateAgent from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType from letta.schemas.job import BatchJob from letta.schemas.letta_message_content import TextContent @@ -31,10 +27,10 @@ from letta.schemas.message import MessageCreate from letta.schemas.tool_rule import InitToolRule from letta.server.db import db_context from letta.server.server import SyncServer -from tests.utils import wait_for_server +from tests.utils import create_tool_from_func # --------------------------------------------------------------------------- # -# Test Constants +# Test Constants / Helpers # --------------------------------------------------------------------------- # # Model identifiers used in tests @@ -54,7 +50,7 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"] @pytest.fixture(scope="function") -def weather_tool(client): +def weather_tool(server): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. 
@@ -101,28 +98,33 @@ def rethink_tool(client): agent_state.memory.update_block_value(label=target_block_label, value=new_memory) return None - tool = client.tools.upsert_from_function(func=rethink_memory) + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor) # Yield the created tool yield tool @pytest.fixture -def agents(client, weather_tool): +def agents(server, weather_tool): """ Create three test agents with different models. Returns: Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models """ + actor = server.user_manager.get_user_or_default() def create_agent(suffix, model_name): - return client.agents.create( - name=f"test_agent_{suffix}", - include_base_tools=True, - model=model_name, - tags=["test_agents"], - embedding="letta/letta-free", - tool_ids=[weather_tool.id], + return server.create_agent( + CreateAgent( + name=f"test_agent_{suffix}", + include_base_tools=True, + model=model_name, + tags=["test_agents"], + embedding="letta/letta-free", + tool_ids=[weather_tool.id], + ), + actor=actor, ) return ( @@ -290,32 +292,6 @@ def clear_batch_tables(): session.commit() -def run_server(): - """Starts the Letta server in a background thread.""" - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - -@pytest.fixture(scope="session") -def server_url(): - """ - Ensures a server is running and returns its base URL. - - Uses environment variable if available, otherwise starts a server - in a background thread. - """ - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - wait_for_server(url) - - return url - - @pytest.fixture(scope="module") def server(): """ @@ -324,14 +300,11 @@ def server(): Loads and saves config to ensure proper initialization. 
""" config = LettaConfig.load() + config.save() - return SyncServer() - -@pytest.fixture(scope="session") -def client(server_url): - """Creates a REST client connected to the test server.""" - return Letta(base_url=server_url) + server = SyncServer(init_with_default_org_and_user=True) + yield server @pytest.fixture @@ -368,23 +341,27 @@ class MockAsyncIterable: # --------------------------------------------------------------------------- # -@pytest.mark.asyncio(loop_scope="session") -async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool): +@pytest.mark.asyncio(loop_scope="module") +async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool): target_block_label = "human" new_memory = "banana" - agent = client.agents.create( - name=f"test_agent_rethink", - include_base_tools=True, - model=MODELS["sonnet"], - tags=["test_agents"], - embedding="letta/letta-free", - tool_ids=[rethink_tool.id], - memory_blocks=[ - { - "label": target_block_label, - "value": "Name: Matt", - }, - ], + actor = server.user_manager.get_user_or_default() + agent = await server.create_agent_async( + request=CreateAgent( + name=f"test_agent_rethink", + include_base_tools=True, + model=MODELS["sonnet"], + tags=["test_agents"], + embedding="letta/letta-free", + tool_ids=[rethink_tool.id], + memory_blocks=[ + { + "label": target_block_label, + "value": "Name: Matt", + }, + ], + ), + actor=actor, ) agents = [agent] batch_requests = [ @@ -444,13 +421,13 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv await poll_running_llm_batches(server) # Check that the tool has been executed correctly - agent = client.agents.retrieve(agent_id=agent.id) + agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor) for block in agent.memory.blocks: if block.label == target_block_label: assert block.value == new_memory -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_partial_error_from_anthropic_batch( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -610,7 +587,7 @@ async def test_partial_error_from_anthropic_batch( assert agent_messages[0].role == MessageRole.user, "Expected initial user message" -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_resume_step_some_stop( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -773,7 +750,7 @@ def _assert_descending_order(messages): return True -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_resume_step_after_request_all_continue( disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): @@ -911,7 +888,7 @@ async def test_resume_step_after_request_all_continue( assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" -@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.asyncio(loop_scope="module") async def test_step_until_request_prepares_and_submits_batch_correctly( disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job ): diff --git a/tests/test_managers.py b/tests/test_managers.py index 719f867d0..695dec5db 100644 --- a/tests/test_managers.py +++ 
b/tests/test_managers.py @@ -24,7 +24,9 @@ from letta.constants import ( BASE_TOOLS, BASE_VOICE_SLEEPTIME_CHAT_TOOLS, BASE_VOICE_SLEEPTIME_TOOLS, + BUILTIN_TOOLS, LETTA_TOOL_EXECUTION_DIR, + LETTA_TOOL_SET, MCP_TOOL_TAG_NAME_PREFIX, MULTI_AGENT_TOOLS, ) @@ -69,7 +71,7 @@ from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.tool_rule import InitToolRule from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate -from letta.server.db import db_context +from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager @@ -92,14 +94,14 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI")) @pytest.fixture(autouse=True) -def _clear_tables(): - with db_context() as session: +async def _clear_tables(): + async with db_registry.async_session() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues # If this is the block_history table, skip it if table.name == "block_history": continue - session.execute(table.delete()) # Truncate table - session.commit() + await session.execute(table.delete()) # Truncate table + await session.commit() @pytest.fixture @@ -171,7 +173,7 @@ def default_file(server: SyncServer, default_source, default_user, default_organ @pytest.fixture -def print_tool(server: SyncServer, default_user, default_organization): +async def print_tool(server: SyncServer, default_user, default_organization): """Fixture to create a tool with default settings and clean up after the test.""" def print_tool(message: str): @@ -199,7 +201,7 @@ def print_tool(server: SyncServer, default_user, default_organization): tool.json_schema = derived_json_schema tool.name = derived_name - tool = server.tool_manager.create_tool(tool, actor=default_user) + tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) # Yield the created tool yield tool @@ -237,24 +239,24 @@ def mcp_tool(server, default_user): @pytest.fixture -def default_job(server: SyncServer, default_user): +async def default_job(server: SyncServer, default_user): """Fixture to create and return a default job.""" job_pydantic = PydanticJob( user_id=default_user.id, status=JobStatus.pending, ) - job = server.job_manager.create_job(pydantic_job=job_pydantic, actor=default_user) + job = await server.job_manager.create_job_async(pydantic_job=job_pydantic, actor=default_user) yield job @pytest.fixture -def default_run(server: SyncServer, default_user): +async def default_run(server: SyncServer, default_user): """Fixture to create and return a default run.""" run_pydantic = PydanticRun( user_id=default_user.id, status=JobStatus.pending, ) - run = server.job_manager.create_job(pydantic_job=run_pydantic, actor=default_user) + run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user) yield run @@ -403,7 +405,7 @@ def other_block(server: SyncServer, default_user): @pytest.fixture -def other_tool(server: SyncServer, default_user, default_organization): +async def other_tool(server: SyncServer, default_user, default_organization): def print_other_tool(message: str): """ Args: @@ -428,16 +430,16 @@ def other_tool(server: SyncServer, default_user, default_organization): tool.json_schema = derived_json_schema tool.name = derived_name - tool = server.tool_manager.create_tool(tool, actor=default_user) + tool = await server.tool_manager.create_or_update_tool_async(tool, 
actor=default_user) # Yield the created tool yield tool @pytest.fixture -def sarah_agent(server: SyncServer, default_user, default_organization): +async def sarah_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="sarah_agent", memory_blocks=[], @@ -451,9 +453,9 @@ def sarah_agent(server: SyncServer, default_user, default_organization): @pytest.fixture -def charles_agent(server: SyncServer, default_user, default_organization): +async def charles_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="charles_agent", memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], @@ -467,7 +469,7 @@ def charles_agent(server: SyncServer, default_user, default_organization): @pytest.fixture -def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block): +async def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_tool, default_source, default_block): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] create_agent_request = CreateAgent( system="test system", @@ -486,7 +488,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too message_buffer_autoclear=True, include_base_tools=False, ) - created_agent = server.agent_manager.create_agent( + created_agent = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) @@ -550,9 +552,9 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent @pytest.fixture -def agent_with_tags(server: SyncServer, default_user): +async def agent_with_tags(server: SyncServer, default_user): """Fixture to create agents with specific tags.""" - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent1", tags=["primary_agent", "benefit_1"], @@ -564,7 +566,7 @@ def agent_with_tags(server: SyncServer, default_user): actor=default_user, ) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent2", tags=["primary_agent", "benefit_2"], @@ -576,7 +578,7 @@ def agent_with_tags(server: SyncServer, default_user): actor=default_user, ) - agent3 = server.agent_manager.create_agent( + agent3 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent3", tags=["primary_agent", "benefit_1", "benefit_2"], @@ -656,17 +658,18 @@ async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agen comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user) # Test list agent - list_agents = server.agent_manager.list_agents(actor=default_user) + list_agents = await server.agent_manager.list_agents_async(actor=default_user) assert len(list_agents) == 1 comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user) # Test deleting the agent 
server.agent_manager.delete_agent(get_agent.id, default_user) - list_agents = server.agent_manager.list_agents(actor=default_user) + list_agents = await server.agent_manager.list_agents_async(actor=default_user) assert len(list_agents) == 0 -def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block): +@pytest.mark.asyncio +async def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block, event_loop): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] create_agent_request = CreateAgent( system="test system", @@ -679,12 +682,12 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")], include_base_tools=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) - assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 2 - init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 2 + init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].content[0].text @@ -694,7 +697,8 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use assert create_agent_request.initial_message_sequence[0].content in init_messages[1].content[0].text -def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block): +@pytest.mark.asyncio +async def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block, event_loop): memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] create_agent_request = CreateAgent( system="test system", @@ -706,18 +710,19 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user, description="test_description", include_base_tools=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) - assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 4 - init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + assert await server.message_manager.size_async(agent_id=agent_state.id, actor=default_user) == 4 + init_messages = await server.agent_manager.get_in_context_messages_async(agent_id=agent_state.id, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].content[0].text assert create_agent_request.memory_blocks[0].value in init_messages[0].content[0].text -def test_create_agent_with_json_in_system_message(server: SyncServer, default_user, default_block): +@pytest.mark.asyncio +async def test_create_agent_with_json_in_system_message(server: SyncServer, default_user, default_block, event_loop): system_prompt = ( "You are an expert teaching agent with encyclopedic knowledge. 
" "When you receive a topic, query the external database for more " @@ -734,19 +739,22 @@ def test_create_agent_with_json_in_system_message(server: SyncServer, default_us description="test_description", include_base_tools=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( create_agent_request, actor=default_user, ) assert agent_state is not None system_message_id = agent_state.message_ids[0] - system_message = server.message_manager.get_message_by_id(message_id=system_message_id, actor=default_user) + system_message = await server.message_manager.get_message_by_id_async(message_id=system_message_id, actor=default_user) assert system_prompt in system_message.content[0].text assert default_block.value in system_message.content[0].text server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user) -def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user): +@pytest.mark.asyncio +async def test_update_agent( + server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user, event_loop +): agent, _ = comprehensive_test_agent_fixture update_agent_request = UpdateAgent( name="train_agent", @@ -766,7 +774,7 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe ) last_updated_timestamp = agent.updated_at - updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user) + updated_agent = await server.agent_manager.update_agent_async(agent.id, update_agent_request, actor=default_user) comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user) assert updated_agent.message_ids == update_agent_request.message_ids assert updated_agent.updated_at > last_updated_timestamp @@ -777,12 +785,13 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe # ====================================================================================================================== -def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): # Create an agent using the comprehensive fixture. created_agent, create_agent_request = comprehensive_test_agent_fixture # List agents using an empty list for select_fields. - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=[]) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=[]) # Assert that the agent is returned and basic fields are present. assert len(agents) >= 1 agent = agents[0] @@ -794,12 +803,13 @@ def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_ assert len(agent.tags) == 0 -def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): # Create an agent using the comprehensive fixture. created_agent, create_agent_request = comprehensive_test_agent_fixture # List agents using an empty list for select_fields. 
- agents = server.agent_manager.list_agents(actor=default_user, include_relationships=None) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=None) # Assert that the agent is returned and basic fields are present. assert len(agents) >= 1 agent = agents[0] @@ -811,12 +821,13 @@ def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_a assert len(agent.tags) > 0 -def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Choose a subset of valid relationship fields. valid_fields = ["tools", "tags"] - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=valid_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=valid_fields) assert len(agents) >= 1 agent = agents[0] # Depending on your to_pydantic() implementation, @@ -827,13 +838,14 @@ def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_te assert not agent.memory.blocks -def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Provide field names that are not recognized. invalid_fields = ["foobar", "nonexistent_field"] # The expectation is that these fields are simply ignored. - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=invalid_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=invalid_fields) assert len(agents) >= 1 agent = agents[0] # Verify that standard fields are still present. assert agent.id is not None assert agent.name is not None -def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Provide duplicate valid field names. duplicate_fields = ["tools", "tools", "tags", "tags"] - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=duplicate_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=duplicate_fields) assert len(agents) >= 1 agent = agents[0] # Verify that the agent pydantic representation includes the relationships. assert isinstance(agent.tools, list) assert isinstance(agent.tags, list) -def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user): +@pytest.mark.asyncio +async def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): created_agent, create_agent_request = comprehensive_test_agent_fixture # Mix valid fields with an invalid one. 
mixed_fields = ["tools", "invalid_field"] - agents = server.agent_manager.list_agents(actor=default_user, include_relationships=mixed_fields) + agents = await server.agent_manager.list_agents_async(actor=default_user, include_relationships=mixed_fields) assert len(agents) >= 1 agent = agents[0] # Valid fields should be loaded and accessible. @@ -870,9 +884,10 @@ def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_ assert not hasattr(agent, "invalid_field") -def test_list_agents_ascending(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_agents_ascending(server: SyncServer, default_user, event_loop): # Create two agents with known names - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_oldest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -886,7 +901,7 @@ def test_list_agents_ascending(server: SyncServer, default_user): if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_newest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -897,14 +912,15 @@ def test_list_agents_ascending(server: SyncServer, default_user): actor=default_user, ) - agents = server.agent_manager.list_agents(actor=default_user, ascending=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, ascending=True) names = [agent.name for agent in agents] assert names.index("agent_oldest") < names.index("agent_newest") -def test_list_agents_descending(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_agents_descending(server: SyncServer, default_user, event_loop): # Create two agents with known names - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_oldest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -918,7 +934,7 @@ def test_list_agents_descending(server: SyncServer, default_user): if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent_newest", llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -929,18 +945,19 @@ def test_list_agents_descending(server: SyncServer, default_user): actor=default_user, ) - agents = server.agent_manager.list_agents(actor=default_user, ascending=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, ascending=False) names = [agent.name for agent in agents] assert names.index("agent_newest") < names.index("agent_oldest") -def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_agents_ordering_and_pagination(server: SyncServer, default_user, event_loop): names = ["alpha_agent", "beta_agent", "gamma_agent"] created_agents = [] # Create agents in known order for name in names: - agent = server.agent_manager.create_agent( + agent = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name=name, memory_blocks=[], @@ -957,17 +974,17 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): agent_ids = {agent.name: agent.id for agent in created_agents} # Ascending (oldest to newest) - agents_asc = server.agent_manager.list_agents(actor=default_user, ascending=True) + agents_asc = 
await server.agent_manager.list_agents_async(actor=default_user, ascending=True) asc_names = [agent.name for agent in agents_asc] assert asc_names.index("alpha_agent") < asc_names.index("beta_agent") < asc_names.index("gamma_agent") # Descending (newest to oldest) - agents_desc = server.agent_manager.list_agents(actor=default_user, ascending=False) + agents_desc = await server.agent_manager.list_agents_async(actor=default_user, ascending=False) desc_names = [agent.name for agent in agents_desc] assert desc_names.index("gamma_agent") < desc_names.index("beta_agent") < desc_names.index("alpha_agent") # After: Get agents after alpha_agent in ascending order (should exclude alpha) - after_alpha = server.agent_manager.list_agents(actor=default_user, after=agent_ids["alpha_agent"], ascending=True) + after_alpha = await server.agent_manager.list_agents_async(actor=default_user, after=agent_ids["alpha_agent"], ascending=True) after_names = [a.name for a in after_alpha] assert "alpha_agent" not in after_names assert "beta_agent" in after_names @@ -975,7 +992,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): assert after_names == ["beta_agent", "gamma_agent"] # Before: Get agents before gamma_agent in ascending order (should exclude gamma) - before_gamma = server.agent_manager.list_agents(actor=default_user, before=agent_ids["gamma_agent"], ascending=True) + before_gamma = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["gamma_agent"], ascending=True) before_names = [a.name for a in before_gamma] assert "gamma_agent" not in before_names assert "alpha_agent" in before_names @@ -983,12 +1000,12 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user): assert before_names == ["alpha_agent", "beta_agent"] # After: Get agents after gamma_agent in descending order (should exclude gamma, return beta then alpha) - after_gamma_desc = server.agent_manager.list_agents(actor=default_user, after=agent_ids["gamma_agent"], ascending=False) + after_gamma_desc = await server.agent_manager.list_agents_async(actor=default_user, after=agent_ids["gamma_agent"], ascending=False) after_names_desc = [a.name for a in after_gamma_desc] assert after_names_desc == ["beta_agent", "alpha_agent"] # Before: Get agents before alpha_agent in descending order (should exclude alpha) - before_alpha_desc = server.agent_manager.list_agents(actor=default_user, before=agent_ids["alpha_agent"], ascending=False) + before_alpha_desc = await server.agent_manager.list_agents_async(actor=default_user, before=agent_ids["alpha_agent"], ascending=False) before_names_desc = [a.name for a in before_alpha_desc] assert before_names_desc == ["gamma_agent", "beta_agent"] @@ -1093,10 +1110,11 @@ async def test_attach_source(server: SyncServer, sarah_agent, default_source, de assert len([s for s in agent.sources if s.id == default_source.id]) == 1 -def test_list_attached_source_ids(server: SyncServer, sarah_agent, default_source, other_source, default_user): +@pytest.mark.asyncio +async def test_list_attached_source_ids(server: SyncServer, sarah_agent, default_source, other_source, default_user, event_loop): """Test listing source IDs attached to an agent.""" # Initially should have no sources - sources = server.agent_manager.list_attached_sources(sarah_agent.id, actor=default_user) + sources = await server.agent_manager.list_attached_sources_async(sarah_agent.id, actor=default_user) assert len(sources) == 0 # Attach sources @@ -1104,7 +1122,7 @@ def 
test_list_attached_source_ids(server: SyncServer, sarah_agent, default_sourc server.agent_manager.attach_source(sarah_agent.id, other_source.id, actor=default_user) # List sources and verify - sources = server.agent_manager.list_attached_sources(sarah_agent.id, actor=default_user) + sources = await server.agent_manager.list_attached_sources_async(sarah_agent.id, actor=default_user) assert len(sources) == 2 source_ids = [s.id for s in sources] assert default_source.id in source_ids @@ -1150,10 +1168,11 @@ def test_detach_source_nonexistent_agent(server: SyncServer, default_source, def server.agent_manager.detach_source(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) -def test_list_attached_source_ids_nonexistent_agent(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_attached_source_ids_nonexistent_agent(server: SyncServer, default_user, event_loop): """Test listing sources for a nonexistent agent.""" with pytest.raises(NoResultFound): - server.agent_manager.list_attached_sources(agent_id="nonexistent-agent-id", actor=default_user) + await server.agent_manager.list_attached_sources_async(agent_id="nonexistent-agent-id", actor=default_user) def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user): @@ -1239,74 +1258,85 @@ def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_wi assert len(agents) == 0 # No agent should match -def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test listing agents that have ALL specified tags.""" # Create agents with multiple tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["test", "production", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["test", "development", "gpt4"]), actor=default_user) # Search for agents with all specified tags - agents = server.agent_manager.list_agents(actor=default_user, tags=["test", "gpt4"], match_all_tags=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["test", "gpt4"], match_all_tags=True) assert len(agents) == 2 agent_ids = [a.id for a in agents] assert sarah_agent.id in agent_ids assert charles_agent.id in agent_ids # Search for tags that only sarah_agent has - agents = server.agent_manager.list_agents(actor=default_user, tags=["test", "production"], match_all_tags=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["test", "production"], match_all_tags=True) assert len(agents) == 1 assert agents[0].id == sarah_agent.id -def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_match_any(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test listing agents that have ANY of the specified tags.""" # Create agents with different tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) - 
server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) # Search for agents with any of the specified tags - agents = server.agent_manager.list_agents(actor=default_user, tags=["production", "development"], match_all_tags=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["production", "development"], match_all_tags=False) assert len(agents) == 2 agent_ids = [a.id for a in agents] assert sarah_agent.id in agent_ids assert charles_agent.id in agent_ids # Search for tags where only sarah_agent matches - agents = server.agent_manager.list_agents(actor=default_user, tags=["production", "nonexistent"], match_all_tags=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["production", "nonexistent"], match_all_tags=False) assert len(agents) == 1 assert agents[0].id == sarah_agent.id -def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_no_matches(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test listing agents when no tags match.""" # Create agents with tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(tags=["production", "gpt4"]), actor=default_user) + await server.agent_manager.update_agent_async(charles_agent.id, UpdateAgent(tags=["development", "gpt3"]), actor=default_user) # Search for nonexistent tags - agents = server.agent_manager.list_agents(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=True) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=True) assert len(agents) == 0 - agents = server.agent_manager.list_agents(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=False) + agents = await server.agent_manager.list_agents_async(actor=default_user, tags=["nonexistent1", "nonexistent2"], match_all_tags=False) assert len(agents) == 0 -def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): """Test combining tag search with other filters.""" # Create agents with specific names and tags - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user) - server.agent_manager.update_agent(charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user) + await server.agent_manager.update_agent_async( + sarah_agent.id, UpdateAgent(name="production_agent", tags=["production", "gpt4"]), actor=default_user + ) + await server.agent_manager.update_agent_async( + charles_agent.id, UpdateAgent(name="test_agent", tags=["production", "gpt3"]), actor=default_user + ) # List agents with specific 
tag and name pattern - agents = server.agent_manager.list_agents(actor=default_user, tags=["production"], match_all_tags=True, name="production_agent") + agents = await server.agent_manager.list_agents_async( + actor=default_user, tags=["production"], match_all_tags=True, name="production_agent" + ) assert len(agents) == 1 assert agents[0].id == sarah_agent.id -def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization): +@pytest.mark.asyncio +async def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization, event_loop): """Test pagination when listing agents by tags.""" # Create first agent - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent1", tags=["pagination_test", "tag1"], @@ -1322,7 +1352,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps # Create second agent - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="agent2", tags=["pagination_test", "tag2"], @@ -1335,19 +1365,19 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul ) # Get first page - first_page = server.agent_manager.list_agents(actor=default_user, tags=["pagination_test"], match_all_tags=True, limit=1) + first_page = await server.agent_manager.list_agents_async(actor=default_user, tags=["pagination_test"], match_all_tags=True, limit=1) assert len(first_page) == 1 first_agent_id = first_page[0].id # Get second page using cursor - second_page = server.agent_manager.list_agents( + second_page = await server.agent_manager.list_agents_async( actor=default_user, tags=["pagination_test"], match_all_tags=True, after=first_agent_id, limit=1 ) assert len(second_page) == 1 assert second_page[0].id != first_agent_id # Get previous page using before - prev_page = server.agent_manager.list_agents( + prev_page = await server.agent_manager.list_agents_async( actor=default_user, tags=["pagination_test"], match_all_tags=True, before=second_page[0].id, limit=1 ) assert len(prev_page) == 1 @@ -1360,10 +1390,11 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul assert agent2.id in all_ids -def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization): +@pytest.mark.asyncio +async def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization, event_loop): """Test listing agents with query text filtering and pagination.""" # Create test agents with specific names and descriptions - agent1 = server.agent_manager.create_agent( + agent1 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="Search Agent One", memory_blocks=[], @@ -1375,7 +1406,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def actor=default_user, ) - agent2 = server.agent_manager.create_agent( + agent2 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="Search Agent Two", memory_blocks=[], @@ -1387,7 +1418,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def actor=default_user, ) - agent3 = server.agent_manager.create_agent( + agent3 = await server.agent_manager.create_agent_async( agent_create=CreateAgent( name="Different Agent", memory_blocks=[], @@ -1400,32 +1431,32 @@ def 
test_list_agents_query_text_pagination(server: SyncServer, default_user, def ) # Test query text filtering - search_results = server.agent_manager.list_agents(actor=default_user, query_text="search agent") + search_results = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent") assert len(search_results) == 2 search_agent_ids = {agent.id for agent in search_results} assert agent1.id in search_agent_ids assert agent2.id in search_agent_ids assert agent3.id not in search_agent_ids - different_results = server.agent_manager.list_agents(actor=default_user, query_text="different agent") + different_results = await server.agent_manager.list_agents_async(actor=default_user, query_text="different agent") assert len(different_results) == 1 assert different_results[0].id == agent3.id # Test pagination with query text - first_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", limit=1) + first_page = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent", limit=1) assert len(first_page) == 1 first_agent_id = first_page[0].id # Get second page using cursor - second_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", after=first_agent_id, limit=1) + second_page = await server.agent_manager.list_agents_async(actor=default_user, query_text="search agent", after=first_agent_id, limit=1) assert len(second_page) == 1 assert second_page[0].id != first_agent_id # Test before and after - all_agents = server.agent_manager.list_agents(actor=default_user, query_text="agent") + all_agents = await server.agent_manager.list_agents_async(actor=default_user, query_text="agent") assert len(all_agents) == 3 first_agent, second_agent, third_agent = all_agents - middle_agent = server.agent_manager.list_agents( + middle_agent = await server.agent_manager.list_agents_async( actor=default_user, query_text="search agent", before=third_agent.id, after=first_agent.id ) assert len(middle_agent) == 1 @@ -1449,7 +1480,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau does not fail and clears out message_ids if somehow it's non-empty. """ # Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages). - server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) updated_agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user) assert updated_agent.message_ids == ["ghost-message-id"] @@ -1457,7 +1488,7 @@ async def test_reset_messages_no_messages(server: SyncServer, sarah_agent, defau reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 # Double check that physically no messages exist - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 @pytest.mark.asyncio @@ -1467,7 +1498,7 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent, does not fail and clears out message_ids if somehow it's non-empty. """ # Force a weird scenario: Suppose the message_ids field was set non-empty (without actual messages). 
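# The sync-to-async migration applied throughout this file pairs each manager
# method with an awaitable `_async` twin driven by pytest-asyncio. A minimal,
# self-contained sketch of that pairing, assuming pytest-asyncio is installed;
# FakeMessageManager is purely illustrative and not the real Letta manager:
import asyncio
import pytest

class FakeMessageManager:
    def size(self, agent_id: str) -> int:
        # Stand-in for a synchronous DB count such as MessageManager.size.
        return 0

    async def size_async(self, agent_id: str) -> int:
        # Real managers run the query on an async session; this sketch just
        # pushes the sync path off the event loop.
        return await asyncio.to_thread(self.size, agent_id)

@pytest.mark.asyncio
async def test_size_async_matches_sync():
    manager = FakeMessageManager()
    assert await manager.size_async("agent-1") == manager.size("agent-1")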
- server.agent_manager.update_agent(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) + await server.agent_manager.update_agent_async(sarah_agent.id, UpdateAgent(message_ids=["ghost-message-id"]), actor=default_user) updated_agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user) assert updated_agent.message_ids == ["ghost-message-id"] @@ -1475,7 +1506,7 @@ async def test_reset_messages_default_messages(server: SyncServer, sarah_agent, reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user, add_default_initial_messages=True) assert len(reset_agent.message_ids) == 4 # Double check that physically no messages exist - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 4 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 4 @pytest.mark.asyncio @@ -1508,7 +1539,7 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a agent_before = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, default_user) # This is 4 because creating the message does not necessarily add it to the in context message ids assert len(agent_before.message_ids) == 4 - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 6 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 6 # 2. Reset all messages reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) @@ -1517,10 +1548,11 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a assert len(reset_agent.message_ids) == 1 # 4. Verify the messages are physically removed - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 -def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_user, event_loop): """ Test that calling reset_messages multiple times has no adverse effect. """ @@ -1537,15 +1569,16 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use # First reset reset_agent = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent.message_ids) == 1 - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 # Second reset should do nothing new reset_agent_again = server.agent_manager.reset_messages(agent_id=sarah_agent.id, actor=default_user) assert len(reset_agent_again.message_ids) == 1 - assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1 + assert await server.message_manager.size_async(agent_id=sarah_agent.id, actor=default_user) == 1 -def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_modify_letta_message(server: SyncServer, sarah_agent, default_user, event_loop): """ Test updating a message.
""" @@ -1560,32 +1593,32 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): # user message update_user_message = UpdateUserMessage(content="Hello, Sarah!") - original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user) + original_user_message = await server.message_manager.get_message_by_id_async(message_id=user_message.id, actor=default_user) assert original_user_message.content[0].text != update_user_message.content server.message_manager.update_message_by_letta_message( message_id=user_message.id, letta_message_update=update_user_message, actor=default_user ) - updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user) + updated_user_message = await server.message_manager.get_message_by_id_async(message_id=user_message.id, actor=default_user) assert updated_user_message.content[0].text == update_user_message.content # system message update_system_message = UpdateSystemMessage(content="You are a friendly assistant!") - original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user) + original_system_message = await server.message_manager.get_message_by_id_async(message_id=system_message.id, actor=default_user) assert original_system_message.content[0].text != update_system_message.content server.message_manager.update_message_by_letta_message( message_id=system_message.id, letta_message_update=update_system_message, actor=default_user ) - updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user) + updated_system_message = await server.message_manager.get_message_by_id_async(message_id=system_message.id, actor=default_user) assert updated_system_message.content[0].text == update_system_message.content # reasoning message update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking") - original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user) + original_reasoning_message = await server.message_manager.get_message_by_id_async(message_id=reasoning_message.id, actor=default_user) assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning server.message_manager.update_message_by_letta_message( message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user ) - updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user) + updated_reasoning_message = await server.message_manager.get_message_by_id_async(message_id=reasoning_message.id, actor=default_user) assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning # assistant message @@ -1597,14 +1630,14 @@ def test_modify_letta_message(server: SyncServer, sarah_agent, default_user): return arguments["message"] update_assistant_message = UpdateAssistantMessage(content="I am an agent!") - original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user) + original_assistant_message = await server.message_manager.get_message_by_id_async(message_id=assistant_message.id, actor=default_user) print("ORIGINAL", original_assistant_message.tool_calls) print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0])) assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content 
server.message_manager.update_message_by_letta_message( message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user ) - updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user) + updated_assistant_message = await server.message_manager.get_message_by_id_async(message_id=assistant_message.id, actor=default_user) print("UPDATED", updated_assistant_message.tool_calls) print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0])) assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content @@ -1757,29 +1790,6 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de assert block.label == default_block.label -def test_refresh_memory(server: SyncServer, default_user): - block = server.block_manager.create_or_update_block( - PydanticBlock( - label="test", - value="test", - limit=1000, - ), - actor=default_user, - ) - agent = server.agent_manager.create_agent( - CreateAgent( - name="test", - llm_config=LLMConfig.default_config("gpt-4o-mini"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - include_base_tools=False, - ), - actor=default_user, - ) - assert len(agent.memory.blocks) == 0 - agent = server.agent_manager.refresh_memory(agent_state=agent, actor=default_user) - assert len(agent.memory.blocks) == 0 - - @pytest.mark.asyncio async def test_refresh_memory_async(server: SyncServer, default_user, event_loop): block = server.block_manager.create_or_update_block( @@ -1826,41 +1836,44 @@ async def test_refresh_memory_async(server: SyncServer, default_user, event_loop # ====================================================================================================================== -def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test basic listing functionality of agent passages""" - all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id) + all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id) assert len(all_passages) == 5 # 3 source + 2 agent passages -def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test ordering of agent passages""" # Test ascending order - asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True) + asc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=True) assert len(asc_passages) == 5 for i in range(1, len(asc_passages)): assert asc_passages[i - 1].created_at <= asc_passages[i].created_at # Test descending order - desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False) + desc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=False) assert len(desc_passages) == 5 for i in range(1, len(desc_passages)): assert desc_passages[i - 1].created_at >= desc_passages[i].created_at -def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup): 
+@pytest.mark.asyncio +async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test pagination of agent passages""" # Test limit - limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3) + limited_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=3) assert len(limited_passages) == 3 # Test cursor-based pagination - first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True) + first_page = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True) assert len(first_page) == 2 - second_page = server.agent_manager.list_passages( + second_page = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, after=first_page[-1].id, limit=2, ascending=True ) assert len(second_page) == 2 @@ -1874,14 +1887,14 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent [mid] * | * * | * """ - middle_page = server.agent_manager.list_passages( + middle_page = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=True ) assert len(middle_page) == 2 assert middle_page[0].id == first_page[-1].id assert middle_page[1].id == second_page[0].id - middle_page_desc = server.agent_manager.list_passages( + middle_page_desc = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=False ) assert len(middle_page_desc) == 2 @@ -1889,31 +1902,40 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent assert middle_page_desc[1].id == first_page[-1].id -def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test text search functionality of agent passages""" # Test text search for source passages - source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage") + source_text_passages = await server.agent_manager.list_passages_async( + actor=default_user, agent_id=sarah_agent.id, query_text="Source passage" + ) assert len(source_text_passages) == 3 # Test text search for agent passages - agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage") + agent_text_passages = await server.agent_manager.list_passages_async( + actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage" + ) assert len(agent_text_passages) == 2 -def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, event_loop): """Test agent-only filtering of agent passages""" # Filter to passages created by the agent itself - agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True) + agent_text_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True) assert len(agent_text_passages) == 2
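# The before/after cursor semantics exercised by the pagination test above
# amount to an exclusive window over the created_at ordering, with the
# requested sort direction applied afterwards. A self-contained sketch of
# that contract over items in insertion order (illustrative only, not the
# manager's SQL implementation):
def cursor_window(ids, after=None, before=None, ascending=True):
    # Exclusive window: keep items strictly between the two cursors.
    lo = ids.index(after) + 1 if after is not None else 0
    hi = ids.index(before) if before is not None else len(ids)
    page = ids[lo:hi]
    return page if ascending else list(reversed(page))

ids = ["p1", "p2", "p3", "p4"]
assert cursor_window(ids, after="p1", before="p4") == ["p2", "p3"]
assert cursor_window(ids, after="p1", before="p4", ascending=False) == ["p3", "p2"]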
-def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup): +@pytest.mark.asyncio +async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, event_loop): """Test filtering functionality of agent passages""" # Test source filtering - source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id) + source_filtered = await server.agent_manager.list_passages_async( + actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id + ) assert len(source_filtered) == 3 # Test date filtering @@ -1921,13 +1943,14 @@ def test_agent_list_passages_filtering(server, default_user, sarah_agent, defaul future_date = now + timedelta(days=1) past_date = now - timedelta(days=1) - date_filtered = server.agent_manager.list_passages( + date_filtered = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date ) assert len(date_filtered) == 5 -def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source): +@pytest.mark.asyncio +async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, event_loop): """Test vector search functionality of agent passages""" embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG) @@ -1968,7 +1991,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de query_key = "What's my favorite color?" # Test vector search with all passages - results = server.agent_manager.list_passages( + results = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, query_text=query_key, @@ -1983,7 +2006,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de assert "blue" in results[1].text or "blue" in results[2].text # Test vector search with agent_only=True - agent_only_results = server.agent_manager.list_passages( + agent_only_results = await server.agent_manager.list_passages_async( actor=default_user, agent_id=sarah_agent.id, query_text=query_key, @@ -1998,11 +2021,12 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de assert agent_only_results[1].text == "blue shoes" -def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup): +@pytest.mark.asyncio +async def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup, event_loop): """Test listing passages from a source without specifying an agent.""" # List passages by source_id without agent_id - source_passages = server.agent_manager.list_passages( + source_passages = await server.agent_manager.list_passages_async( actor=default_user, source_id=default_source.id, ) @@ -2136,8 +2160,9 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas assert retrieved.text == source_passage_fixture.text -def test_passage_cascade_deletion( - server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent +@pytest.mark.asyncio +async def test_passage_cascade_deletion( + server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent, event_loop ): """Test that passages are deleted when their parent (agent or source) is deleted.""" # Verify passages exist @@ -2148,7 +2173,7 @@ def 
test_passage_cascade_deletion( # Delete agent and verify its passages are deleted server.agent_manager.delete_agent(sarah_agent.id, default_user) - agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True) + agentic_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True) assert len(agentic_passages) == 0 # Delete source and verify its passages are deleted @@ -2410,22 +2435,15 @@ async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, e assert len(tools) == 0 -def test_upsert_base_tools(server: SyncServer, default_user): - tools = server.tool_manager.upsert_base_tools(actor=default_user) - expected_tool_names = sorted( - set( - BASE_TOOLS - + BASE_MEMORY_TOOLS - + MULTI_AGENT_TOOLS - + BASE_SLEEPTIME_TOOLS - + BASE_VOICE_SLEEPTIME_TOOLS - + BASE_VOICE_SLEEPTIME_CHAT_TOOLS - ) - ) +@pytest.mark.asyncio +async def test_upsert_base_tools(server: SyncServer, default_user, event_loop): + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user) + expected_tool_names = sorted(LETTA_TOOL_SET) + assert sorted([t.name for t in tools]) == expected_tool_names # Call it again to make sure it doesn't create duplicates - tools = server.tool_manager.upsert_base_tools(actor=default_user) + tools = await server.tool_manager.upsert_base_tools_async(actor=default_user) assert sorted([t.name for t in tools]) == expected_tool_names # Confirm that the returned tools have no source_code, but do have a json_schema @@ -2442,6 +2460,8 @@ def test_upsert_base_tools(server: SyncServer, default_user): assert t.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE elif t.name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS: assert t.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE + elif t.name in BUILTIN_TOOLS: + assert t.tool_type == ToolType.LETTA_BUILTIN else: pytest.fail(f"The tool name is unrecognized as a base tool: {t.name}") assert t.source_code is None @@ -2823,7 +2843,8 @@ async def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, assert not (block.id in [b.id for b in agent_state.memory.blocks]) -def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user): +@pytest.mark.asyncio +async def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): # Create a block and attach it to both agents block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user) sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user) @@ -2834,7 +2855,7 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de assert block.id in [b.id for b in charles_agent.memory.blocks] # Get the agents for that block - agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user) + agent_states = await server.block_manager.get_agents_for_block_async(block_id=block.id, actor=default_user) assert len(agent_states) == 2 # Check both agents are in the list @@ -2984,7 +3005,7 @@ def test_checkpoint_creates_history(server: SyncServer, default_user): # Act: checkpoint it block_manager.checkpoint_block(block_id=created_block.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: # Get BlockHistory entries for this block history_entries: List[BlockHistory] = session.query(BlockHistory).filter(BlockHistory.block_id ==
created_block.id).all() assert len(history_entries) == 1, "Exactly one history entry should be created" @@ -3017,7 +3038,7 @@ def test_multiple_checkpoints(server: SyncServer, default_user): # 3) Second checkpoint block_manager.checkpoint_block(block_id=block.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: history_entries = ( session.query(BlockHistory).filter(BlockHistory.block_id == block.id).order_by(BlockHistory.sequence_number.asc()).all() ) @@ -3050,7 +3071,7 @@ def test_checkpoint_with_agent_id(server: SyncServer, default_user, sarah_agent) block_manager.checkpoint_block(block_id=block.id, actor=default_user, agent_id=sarah_agent.id) # Verify - with db_context() as session: + with db_registry.session() as session: hist_entry = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).one() assert hist_entry.actor_type == ActorType.LETTA_AGENT assert hist_entry.actor_id == sarah_agent.id @@ -3071,7 +3092,7 @@ def test_checkpoint_with_no_state_change(server: SyncServer, default_user): # 2) checkpoint again (no changes) block_manager.checkpoint_block(block_id=block.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: all_hist = session.query(BlockHistory).filter(BlockHistory.block_id == block.id).all() assert len(all_hist) == 2 @@ -3083,15 +3104,15 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user): block = block_manager.create_or_update_block(PydanticBlock(label="test_stale_checkpoint", value="hello"), actor=default_user) # session1 loads - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.get(Block, block.id) # version=1 # session2 loads - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.get(Block, block.id) # also version=1 # session1 checkpoint => version=2 - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.merge(block_s1) block_manager.checkpoint_block( block_id=block_s1.id, @@ -3102,7 +3123,7 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user): # session2 tries to checkpoint => sees old version=1 => stale error with pytest.raises(StaleDataError): - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.merge(block_s2) block_manager.checkpoint_block( block_id=block_s2.id, @@ -3133,7 +3154,7 @@ def test_checkpoint_no_future_states(server: SyncServer, default_user): # 3) Another checkpoint (no changes made) => should become seq=3, not delete anything block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: # We expect 3 rows in block_history, none removed history_rows = ( session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all() @@ -3230,7 +3251,7 @@ def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default # 5) Checkpoint => new seq=2, removing the old seq=2 and seq=3 block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) - with db_context() as session: + with db_registry.session() as session: # Let's see which BlockHistory rows remain history_entries = ( session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all() @@ -3346,11 +3367,11 @@ def test_undo_concurrency_stale(server: SyncServer, default_user): # Now block is at seq=2 # session1 preloads the block - with db_context() as s1: 
+ with db_registry.session() as s1: block_s1 = s1.get(Block, block_v1.id) # version=? let's say 2 in memory # session2 also preloads the block - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.get(Block, block_v1.id) # also version=2 # Session1 -> undo to seq=1 @@ -3514,9 +3535,9 @@ def test_redo_concurrency_stale(server: SyncServer, default_user): # but there's a valid row for seq=3 in block_history (the 'v3' state). # 5) Simulate concurrency: two sessions each read the block at seq=2 - with db_context() as s1: + with db_registry.session() as s1: block_s1 = s1.get(Block, block.id) - with db_context() as s2: + with db_registry.session() as s2: block_s2 = s2.get(Block, block.id) # 6) Session1 redoes to seq=3 first -> success @@ -3535,7 +3556,8 @@ def test_redo_concurrency_stale(server: SyncServer, default_user): # ====================================================================================================================== -def test_create_and_upsert_identity(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_create_and_upsert_identity(server: SyncServer, default_user, event_loop): identity_create = IdentityCreate( identifier_key="1234", name="caren", @@ -3546,7 +3568,7 @@ def test_create_and_upsert_identity(server: SyncServer, default_user): ], ) - identity = server.identity_manager.create_identity(identity_create, actor=default_user) + identity = await server.identity_manager.create_identity_async(identity_create, actor=default_user) # Assertions to ensure the created identity matches the expected values assert identity.identifier_key == identity_create.identifier_key @@ -3557,51 +3579,54 @@ def test_create_and_upsert_identity(server: SyncServer, default_user): assert identity.project_id == None with pytest.raises(UniqueConstraintViolationError): - server.identity_manager.create_identity( + await server.identity_manager.create_identity_async( IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user), actor=default_user, ) identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))] - identity = server.identity_manager.upsert_identity(identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user) + identity = await server.identity_manager.upsert_identity_async( + identity=IdentityUpsert(**identity_create.model_dump()), actor=default_user + ) - identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user) + identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user) assert len(identity.properties) == 1 assert identity.properties[0].key == "age" assert identity.properties[0].value == 29 - server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) -def test_get_identities(server, default_user): +@pytest.mark.asyncio +async def test_get_identities(server, default_user): # Create identities to retrieve later - user = server.identity_manager.create_identity( + user = await server.identity_manager.create_identity_async( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user ) - org = server.identity_manager.create_identity( + org = await server.identity_manager.create_identity_async( IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user ) # Retrieve identities 
by different filters - all_identities = server.identity_manager.list_identities(actor=default_user) + all_identities = await server.identity_manager.list_identities_async(actor=default_user) assert len(all_identities) == 2 - user_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user) + user_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.user) assert len(user_identities) == 1 assert user_identities[0].name == user.name - org_identities = server.identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org) + org_identities = await server.identity_manager.list_identities_async(actor=default_user, identity_type=IdentityType.org) assert len(org_identities) == 1 assert org_identities[0].name == org.name - server.identity_manager.delete_identity(identity_id=user.id, actor=default_user) - server.identity_manager.delete_identity(identity_id=org.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=user.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=org.id, actor=default_user) @pytest.mark.asyncio async def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): - identity = server.identity_manager.create_identity( + identity = await server.identity_manager.create_identity_async( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user ) @@ -3610,10 +3635,10 @@ async def test_update_identity(server: SyncServer, sarah_agent, charles_agent, d agent_ids=[sarah_agent.id, charles_agent.id], properties=[IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string)], ) - server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user) + await server.identity_manager.update_identity_async(identity_id=identity.id, identity=update_data, actor=default_user) # Retrieve the updated identity - updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user) + updated_identity = await server.identity_manager.get_identity_async(identity_id=identity.id, actor=default_user) # Assertions to verify the update assert updated_identity.agent_ids.sort() == update_data.agent_ids.sort() @@ -3624,16 +3649,16 @@ async def test_update_identity(server: SyncServer, sarah_agent, charles_agent, d agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=charles_agent.id, actor=default_user) assert identity.id in agent_state.identity_ids - server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) @pytest.mark.asyncio async def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user, event_loop): # Create an identity - identity = server.identity_manager.create_identity( + identity = await server.identity_manager.create_identity_async( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user ) - agent_state = server.agent_manager.update_agent( + agent_state = await server.agent_manager.update_agent_async( agent_id=sarah_agent.id, agent_update=UpdateAgent(identity_ids=[identity.id]), actor=default_user ) @@ -3641,10 +3666,10 @@ async def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent assert 
identity.id in agent_state.identity_ids # Now attempt to delete the identity - server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) # Verify that the identity was deleted - identities = server.identity_manager.list_identities(actor=default_user) + identities = await server.identity_manager.list_identities_async(actor=default_user) assert len(identities) == 0 # Check that block has been detached too @@ -3652,13 +3677,14 @@ async def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent assert not identity.id in agent_state.identity_ids -def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user): - identity = server.identity_manager.create_identity( +@pytest.mark.asyncio +async def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user, event_loop): + identity = await server.identity_manager.create_identity_async( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]), actor=default_user, ) - agent_with_identity = server.create_agent( + agent_with_identity = await server.create_agent_async( CreateAgent( memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4o-mini"), @@ -3679,7 +3705,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_ ) # Get the agents for identity id - agent_states = server.agent_manager.list_agents(identity_id=identity.id, actor=default_user) + agent_states = await server.agent_manager.list_agents_async(identity_id=identity.id, actor=default_user) assert len(agent_states) == 3 # Check all agents are in the list @@ -3690,7 +3716,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_ assert not agent_without_identity.id in agent_state_ids # Get the agents for identifier key - agent_states = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user) + agent_states = await server.agent_manager.list_agents_async(identifier_keys=[identity.identifier_key], actor=default_user) assert len(agent_states) == 3 # Check all agents are in the list @@ -3713,13 +3739,13 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_ assert sarah_agent.id in agent_state_ids assert charles_agent.id in agent_state_ids - server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) @pytest.mark.asyncio async def test_attach_detach_identity_from_block(server: SyncServer, default_block, default_user, event_loop): # Create an identity - identity = server.identity_manager.create_identity( + identity = await server.identity_manager.create_identity_async( IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id]), actor=default_user, ) @@ -3729,10 +3755,10 @@ async def test_attach_detach_identity_from_block(server: SyncServer, default_blo assert len(blocks) == 1 and blocks[0].id == default_block.id # Now attempt to delete the identity - server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) # Verify that the identity was deleted - identities = 
server.identity_manager.list_identities(actor=default_user) + identities = await server.identity_manager.list_identities_async(actor=default_user) assert len(identities) == 0 # Check that block has been detached too @@ -3745,7 +3771,7 @@ async def test_get_set_blocks_for_identities(server: SyncServer, default_block, block_manager = BlockManager() block_with_identity = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user) block_without_identity = block_manager.create_or_update_block(PydanticBlock(label="user", value="Original Content"), actor=default_user) - identity = server.identity_manager.create_identity( + identity = await server.identity_manager.create_identity_async( IdentityCreate( name="caren", identifier_key="1234", identity_type=IdentityType.user, block_ids=[default_block.id, block_with_identity.id] ), @@ -3786,10 +3812,11 @@ async def test_get_set_blocks_for_identities(server: SyncServer, default_block, assert not block_with_identity.id in block_ids assert not block_without_identity.id in block_ids - server.identity_manager.delete_identity(identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) -def test_upsert_properties(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_upsert_properties(server: SyncServer, default_user, event_loop): identity_create = IdentityCreate( identifier_key="1234", name="caren", @@ -3800,21 +3827,21 @@ def test_upsert_properties(server: SyncServer, default_user): ], ) - identity = server.identity_manager.create_identity(identity_create, actor=default_user) + identity = await server.identity_manager.create_identity_async(identity_create, actor=default_user) properties = [ IdentityProperty(key="email", value="caren@gmail.com", type=IdentityPropertyType.string), IdentityProperty(key="age", value="28", type=IdentityPropertyType.string), IdentityProperty(key="test", value=123, type=IdentityPropertyType.number), ] - updated_identity = server.identity_manager.upsert_identity_properties( + updated_identity = await server.identity_manager.upsert_identity_properties_async( identity_id=identity.id, properties=properties, actor=default_user, ) assert updated_identity.properties == properties - server.identity_manager.delete_identity(identity.id, actor=default_user) + await server.identity_manager.delete_identity_async(identity_id=identity.id, actor=default_user) # ====================================================================================================================== @@ -4872,15 +4899,17 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ # ====================================================================================================================== -def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user): +@pytest.mark.asyncio +async def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user, event_loop): """Test adding and retrieving job usage statistics.""" job_manager = server.job_manager step_manager = server.step_manager # Add usage statistics - step_manager.log_step( + await step_manager.log_step_async( agent_id=sarah_agent.id, provider_name="openai", + provider_category="base", model="gpt-4o-mini", model_endpoint="https://api.openai.com/v1", context_window_limit=8192, @@ -4923,15 +4952,17 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u 
assert len(steps) == 0 -def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user): +@pytest.mark.asyncio +async def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user, event_loop): """Test adding multiple usage statistics entries for a job.""" job_manager = server.job_manager step_manager = server.step_manager # Add first usage statistics entry - step_manager.log_step( + await step_manager.log_step_async( agent_id=sarah_agent.id, provider_name="openai", + provider_category="base", model="gpt-4o-mini", model_endpoint="https://api.openai.com/v1", context_window_limit=8192, @@ -4945,9 +4976,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_j ) # Add second usage statistics entry - step_manager.log_step( + await step_manager.log_step_async( agent_id=sarah_agent.id, provider_name="openai", + provider_category="base", model="gpt-4o-mini", model_endpoint="https://api.openai.com/v1", context_window_limit=8192, @@ -4986,14 +5018,16 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user) -def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user, event_loop): """Test adding usage statistics for a nonexistent job.""" step_manager = server.step_manager with pytest.raises(NoResultFound): - step_manager.log_step( + await step_manager.log_step_async( agent_id=sarah_agent.id, provider_name="openai", + provider_category="base", model="gpt-4o-mini", model_endpoint="https://api.openai.com/v1", context_window_limit=8192, diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 150922c4e..f989b4344 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy import delete from letta.config import LettaConfig -from letta.orm import Provider, Step +from letta.orm import Provider, ProviderTrace, Step from letta.schemas.agent import CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.group import ( @@ -38,6 +38,7 @@ def org_id(server): # cleanup with db_registry.session() as session: + session.execute(delete(ProviderTrace)) session.execute(delete(Step)) session.execute(delete(Provider)) session.commit() diff --git a/tests/test_provider_trace.py b/tests/test_provider_trace.py new file mode 100644 index 000000000..43e13a343 --- /dev/null +++ b/tests/test_provider_trace.py @@ -0,0 +1,205 @@ +import asyncio +import json +import os +import threading +import time +import uuid + +import pytest +from dotenv import load_dotenv +from letta_client import Letta + +from letta.agents.letta_agent import LettaAgent +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message_content import TextContent +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode +from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager +from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager +from letta.services.step_manager import StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager 
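# The TelemetryManager / NoopTelemetryManager pair imported above follows the
# null-object pattern: the noop variant records nothing and returns None, so
# callers can disable provider-trace capture without branching (which is what
# test_noop_provider_trace below asserts). A minimal sketch with simplified
# signatures; the real managers take more parameters and return ProviderTrace
# schemas, so the dict here is illustrative only:
from typing import Any, Optional

class TelemetryManagerSketch:
    async def get_provider_trace_by_step_id_async(self, step_id: str, actor: Any) -> Optional[dict]:
        # Real implementation: load the request/response JSON persisted for
        # this step from the provider_traces table.
        return {"step_id": step_id, "request_json": {}, "response_json": {}}

class NoopTelemetryManagerSketch(TelemetryManagerSketch):
    async def get_provider_trace_by_step_id_async(self, step_id: str, actor: Any) -> None:
        # Null object: telemetry disabled, so lookups yield None.
        return None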
+ + +def _run_server(): + """Starts the Letta server in a background thread.""" + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + +@pytest.fixture(scope="session") +def server_url(): + """Ensures a server is running and returns its base URL.""" + url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + time.sleep(5) # Allow server startup time + + return url + + +# # --- Client Setup --- # +@pytest.fixture(scope="session") +def client(server_url): + """Creates a REST client for testing.""" + client = Letta(base_url=server_url) + yield client + + +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="function") +def roll_dice_tool(client, roll_dice_tool_func): + print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func) + yield print_tool + + +@pytest.fixture(scope="function") +def weather_tool(client, weather_tool_func): + weather_tool = client.tools.upsert_from_function(func=weather_tool_func) + yield weather_tool + + +@pytest.fixture(scope="function") +def print_tool(client, print_tool_func): + print_tool = client.tools.upsert_from_function(func=print_tool_func) + yield print_tool + + +@pytest.fixture(scope="function") +def agent_state(client, roll_dice_tool, weather_tool): + """Creates an agent and ensures cleanup after tests.""" + agent_state = client.agents.create( + name=f"test_compl_{str(uuid.uuid4())[5:]}", + tool_ids=[roll_dice_tool.id, weather_tool.id], + include_base_tools=True, + memory_blocks=[ + { + "label": "human", + "value": "Name: Matt", + }, + { + "label": "persona", + "value": "Friendly agent", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + yield agent_state + client.agents.delete(agent_state.id) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_provider_trace_experimental_step(message, agent_state, default_user): + experimental_agent = LettaAgent( + agent_id=agent_state.id, + message_manager=MessageManager(), + agent_manager=AgentManager(), + block_manager=BlockManager(), + passage_manager=PassageManager(), + step_manager=StepManager(), + telemetry_manager=TelemetryManager(), + actor=default_user, + ) + + response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])]) + tool_step = response.messages[0].step_id + reply_step = response.messages[-1].step_id + + tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) + reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + assert tool_telemetry.request_json + assert reply_telemetry.request_json + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop): + experimental_agent = LettaAgent( + agent_id=agent_state.id, + message_manager=MessageManager(), + agent_manager=AgentManager(), + block_manager=BlockManager(), + 
passage_manager=PassageManager(), + step_manager=StepManager(), + telemetry_manager=TelemetryManager(), + actor=default_user, + ) + stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])]) + + result = StreamingResponseWithStatusCode( + stream, + media_type="text/event-stream", + ) + + message_id = None + + async def test_send(message) -> None: + nonlocal message_id + if "body" in message and not message_id: + body = message["body"].decode("utf-8").split("data:") + message_id = json.loads(body[1])["id"] + + await result.stream_response(send=test_send) + + messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user) + step_ids = set((message.step_id for message in messages)) + for step_id in step_ids: + telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user) + assert telemetry_data.request_json + assert telemetry_data.response_json + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_provider_trace_step(client, agent_state, default_user, message, event_loop): + client.agents.messages.create(agent_id=agent_state.id, messages=[]) + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[MessageCreate(role="user", content=[TextContent(text=message)])], + ) + tool_step = response.messages[0].step_id + reply_step = response.messages[-1].step_id + + tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) + reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + assert tool_telemetry.request_json + assert reply_telemetry.request_json + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message", ["Get the weather in San Francisco."]) +async def test_noop_provider_trace(message, agent_state, default_user, event_loop): + experimental_agent = LettaAgent( + agent_id=agent_state.id, + message_manager=MessageManager(), + agent_manager=AgentManager(), + block_manager=BlockManager(), + passage_manager=PassageManager(), + step_manager=StepManager(), + telemetry_manager=NoopTelemetryManager(), + actor=default_user, + ) + + response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])]) + tool_step = response.messages[0].step_id + reply_step = response.messages[-1].step_id + + tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user) + reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user) + assert tool_telemetry is None + assert reply_telemetry is None diff --git a/tests/test_providers.py b/tests/test_providers.py index 2ab6606d7..96010e9a1 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,15 +1,12 @@ -import os +import pytest from letta.schemas.providers import ( - AnthropicBedrockProvider, AnthropicProvider, AzureProvider, DeepSeekProvider, GoogleAIProvider, GoogleVertexProvider, GroqProvider, - MistralProvider, - OllamaProvider, OpenAIProvider, TogetherProvider, ) @@ -17,11 +14,9 @@ from letta.settings import model_settings def test_openai(): - api_key = os.getenv("OPENAI_API_KEY") - assert api_key is not None provider = OpenAIProvider( name="openai", - api_key=api_key, + 
diff --git a/tests/test_providers.py b/tests/test_providers.py
index 2ab6606d7..96010e9a1 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -1,15 +1,12 @@
-import os
+import pytest
 
 from letta.schemas.providers import (
-    AnthropicBedrockProvider,
     AnthropicProvider,
     AzureProvider,
     DeepSeekProvider,
     GoogleAIProvider,
     GoogleVertexProvider,
     GroqProvider,
-    MistralProvider,
-    OllamaProvider,
     OpenAIProvider,
     TogetherProvider,
 )
@@ -17,11 +14,9 @@ from letta.settings import model_settings
 
 
 def test_openai():
-    api_key = os.getenv("OPENAI_API_KEY")
-    assert api_key is not None
     provider = OpenAIProvider(
         name="openai",
-        api_key=api_key,
+        api_key=model_settings.openai_api_key,
         base_url=model_settings.openai_api_base,
     )
     models = provider.list_llm_models()
@@ -33,34 +28,54 @@ def test_openai():
     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
 
 
-def test_deepseek():
-    api_key = os.getenv("DEEPSEEK_API_KEY")
-    assert api_key is not None
-    provider = DeepSeekProvider(
-        name="deepseek",
-        api_key=api_key,
+@pytest.mark.asyncio
+async def test_openai_async():
+    provider = OpenAIProvider(
+        name="openai",
+        api_key=model_settings.openai_api_key,
+        base_url=model_settings.openai_api_base,
     )
+    models = await provider.list_llm_models_async()
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+    embedding_models = await provider.list_embedding_models_async()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+
+
+def test_deepseek():
+    provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
     models = provider.list_llm_models()
     assert len(models) > 0
     assert models[0].handle == f"{provider.name}/{models[0].model}"
 
 
 def test_anthropic():
-    api_key = os.getenv("ANTHROPIC_API_KEY")
-    assert api_key is not None
     provider = AnthropicProvider(
         name="anthropic",
-        api_key=api_key,
+        api_key=model_settings.anthropic_api_key,
     )
     models = provider.list_llm_models()
     assert len(models) > 0
     assert models[0].handle == f"{provider.name}/{models[0].model}"
 
 
+@pytest.mark.asyncio
+async def test_anthropic_async():
+    provider = AnthropicProvider(
+        name="anthropic",
+        api_key=model_settings.anthropic_api_key,
+    )
+    models = await provider.list_llm_models_async()
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+
 def test_groq():
     provider = GroqProvider(
         name="groq",
-        api_key=os.getenv("GROQ_API_KEY"),
+        api_key=model_settings.groq_api_key,
     )
     models = provider.list_llm_models()
     assert len(models) > 0
@@ -70,8 +85,9 @@ def test_azure():
     provider = AzureProvider(
         name="azure",
-        api_key=os.getenv("AZURE_API_KEY"),
-        base_url=os.getenv("AZURE_BASE_URL"),
+        api_key=model_settings.azure_api_key,
+        base_url=model_settings.azure_base_url,
+        api_version=model_settings.azure_api_version,
     )
     models = provider.list_llm_models()
     assert len(models) > 0
@@ -82,26 +98,24 @@ def test_azure():
     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
 
 
-def test_ollama():
-    base_url = os.getenv("OLLAMA_BASE_URL")
-    assert base_url is not None
-    provider = OllamaProvider(
-        name="ollama",
-        base_url=base_url,
-        default_prompt_formatter=model_settings.default_prompt_formatter,
-        api_key=None,
-    )
-    models = provider.list_llm_models()
-    assert len(models) > 0
-    assert models[0].handle == f"{provider.name}/{models[0].model}"
-
-    embedding_models = provider.list_embedding_models()
-    assert len(embedding_models) > 0
-    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
+# def test_ollama():
+#     provider = OllamaProvider(
+#         name="ollama",
+#         base_url=model_settings.ollama_base_url,
+#         api_key=None,
+#         default_prompt_formatter=model_settings.default_prompt_formatter,
+#     )
+#     models = provider.list_llm_models()
+#     assert len(models) > 0
+#     assert models[0].handle == f"{provider.name}/{models[0].model}"
+#
+#     embedding_models = provider.list_embedding_models()
+#     assert len(embedding_models) > 0
+#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
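A pattern running through this file: `os.getenv(...)` plus a manual assert is replaced by fields on `model_settings`, centralizing credential lookup. With pydantic-settings, fields are filled from matching environment variables case-insensitively, which is roughly how such an object can be assembled (a sketch; Letta's actual `ModelSettings` in `letta/settings.py` may name things differently):

```python
from typing import Optional

from pydantic_settings import BaseSettings


class ModelSettingsSketch(BaseSettings):
    # BaseSettings matches env vars to field names case-insensitively,
    # so OPENAI_API_KEY populates openai_api_key, and so on.
    openai_api_key: Optional[str] = None
    anthropic_api_key: Optional[str] = None
    gemini_api_key: Optional[str] = None
    azure_api_key: Optional[str] = None


model_settings_sketch = ModelSettingsSketch()  # reads the environment once, at construction
```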
f"{provider.name}/{embedding_models[0].embedding_model}" def test_googleai(): - api_key = os.getenv("GEMINI_API_KEY") + api_key = model_settings.gemini_api_key assert api_key is not None provider = GoogleAIProvider( name="google_ai", @@ -116,11 +130,28 @@ def test_googleai(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" +@pytest.mark.asyncio +async def test_googleai_async(): + api_key = model_settings.gemini_api_key + assert api_key is not None + provider = GoogleAIProvider( + name="google_ai", + api_key=api_key, + ) + models = await provider.list_llm_models_async() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + embedding_models = await provider.list_embedding_models_async() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + def test_google_vertex(): provider = GoogleVertexProvider( name="google_vertex", - google_cloud_project=os.getenv("GCP_PROJECT_ID"), - google_cloud_location=os.getenv("GCP_REGION"), + google_cloud_project=model_settings.google_cloud_project, + google_cloud_location=model_settings.google_cloud_location, ) models = provider.list_llm_models() assert len(models) > 0 @@ -131,50 +162,57 @@ def test_google_vertex(): assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_mistral(): - provider = MistralProvider( - name="mistral", - api_key=os.getenv("MISTRAL_API_KEY"), - ) - models = provider.list_llm_models() - assert len(models) > 0 - assert models[0].handle == f"{provider.name}/{models[0].model}" - - def test_together(): provider = TogetherProvider( name="together", - api_key=os.getenv("TOGETHER_API_KEY"), - default_prompt_formatter="chatml", + api_key=model_settings.together_api_key, + default_prompt_formatter=model_settings.default_prompt_formatter, ) models = provider.list_llm_models() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + # TODO: We don't have embedding models on together for CI + # embedding_models = provider.list_embedding_models() + # assert len(embedding_models) > 0 + # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" -def test_anthropic_bedrock(): - from letta.settings import model_settings - - provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region) - models = provider.list_llm_models() +@pytest.mark.asyncio +async def test_together_async(): + provider = TogetherProvider( + name="together", + api_key=model_settings.together_api_key, + default_prompt_formatter=model_settings.default_prompt_formatter, + ) + models = await provider.list_llm_models_async() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" - embedding_models = provider.list_embedding_models() - assert len(embedding_models) > 0 - assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + # TODO: We don't have embedding models on together for CI + # embedding_models = provider.list_embedding_models() + # assert len(embedding_models) > 0 + # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + +# TODO: Add back in, difficulty adding this to CI properly, 
+# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
+# def test_anthropic_bedrock():
+#     from letta.settings import model_settings
+#
+#     provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
+#     models = provider.list_llm_models()
+#     assert len(models) > 0
+#     assert models[0].handle == f"{provider.name}/{models[0].model}"
+#
+#     embedding_models = provider.list_embedding_models()
+#     assert len(embedding_models) > 0
+#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
 
 
 def test_custom_anthropic():
-    api_key = os.getenv("ANTHROPIC_API_KEY")
-    assert api_key is not None
     provider = AnthropicProvider(
         name="custom_anthropic",
-        api_key=api_key,
+        api_key=model_settings.anthropic_api_key,
     )
     models = provider.list_llm_models()
     assert len(models) > 0
diff --git a/tests/test_server.py b/tests/test_server.py
index a3932d813..200ff54eb 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -11,7 +11,7 @@ from sqlalchemy import delete
 
 import letta.utils as utils
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
-from letta.orm import Provider, Step
+from letta.orm import Provider, ProviderTrace, Step
 from letta.schemas.block import CreateBlock
 from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
@@ -286,6 +286,7 @@ def org_id(server):
 
     # cleanup
     with db_registry.session() as session:
+        session.execute(delete(ProviderTrace))
         session.execute(delete(Step))
         session.execute(delete(Provider))
         session.commit()
@@ -565,7 +566,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
         server.agent_manager.delete_agent(agent_state.id, actor=another_user)
 
 
-def test_read_local_llm_configs(server: SyncServer, user: User):
+@pytest.mark.asyncio
+async def test_read_local_llm_configs(server: SyncServer, user: User):
     configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
     clean_up_dir = False
     if not os.path.exists(configs_base_dir):
@@ -588,7 +590,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):
 
     # Call list_llm_models
     assert os.path.exists(configs_base_dir)
-    llm_models = server.list_llm_models(actor=user)
+    llm_models = await server.list_llm_models_async(actor=user)
 
     # Assert that the config is in the returned models
     assert any(
@@ -935,7 +937,8 @@ def test_composio_client_simple(server):
     assert len(actions) > 0
 
 
-def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
+@pytest.mark.asyncio
+async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
     """Test that the memory rebuild is generating the correct number of role=system messages"""
     actor = user
     # create agent
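Two details in the test_server.py hunks above are worth noting: the fixture teardown now deletes `ProviderTrace` rows before `Step` and `Provider` — children before parents, presumably because traces reference steps, so foreign-key constraints hold at every point during cleanup — and model listing has moved to the async server API. The deletion-order idea as a standalone sketch (assuming `db_registry` is importable from `letta.server.db`, as the test file's usage suggests):

```python
from sqlalchemy import delete

from letta.orm import Provider, ProviderTrace, Step
from letta.server.db import db_registry


def wipe_provider_tables() -> None:
    """Delete in child-to-parent order so no row ever dangles mid-teardown."""
    with db_registry.session() as session:
        session.execute(delete(ProviderTrace))  # child rows first
        session.execute(delete(Step))
        session.execute(delete(Provider))       # parent rows last
        session.commit()
```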
@@ -1223,7 +1225,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
     assert len(agent_state.tools) == len(base_tools) - 2
 
 
-def test_messages_with_provider_override(server: SyncServer, user_id: str):
+@pytest.mark.asyncio
+async def test_messages_with_provider_override(server: SyncServer, user_id: str):
     actor = server.user_manager.get_user_or_default(user_id)
     provider = server.provider_manager.create_provider(
         request=ProviderCreate(
@@ -1233,10 +1236,10 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
         ),
         actor=actor,
     )
-    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
+    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])
     assert provider.name in [model.provider_name for model in models]
 
-    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
+    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
     assert provider.name not in [model.provider_name for model in models]
 
     agent = server.create_agent(
@@ -1302,11 +1305,12 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     assert total_tokens == usage.total_tokens
 
 
-def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
-    models = server.list_llm_models(actor=user)
+@pytest.mark.asyncio
+async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
+    models = await server.list_llm_models_async(actor=user)
     model_handles = [model.handle for model in models]
     assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
-    embeddings = server.list_embedding_models(actor=user)
+    embeddings = await server.list_embedding_models_async(actor=user)
     embedding_handles = [embedding.handle for embedding in embeddings]
     assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"
diff --git a/tests/utils.py b/tests/utils.py
index 65c3ee2ff..04778d112 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,15 +4,16 @@ import string
 import time
 from datetime import datetime, timezone
 from importlib import util
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import requests
+from letta_client import Letta, SystemMessage
 
 from letta.config import LettaConfig
 from letta.data_sources.connectors import DataConnector
-from letta.schemas.enums import MessageRole
+from letta.functions.functions import parse_source_code
 from letta.schemas.file import FileMetadata
-from letta.schemas.message import Message
+from letta.schemas.tool import Tool
 from letta.settings import TestSettings
 
 from .constants import TIMEOUT
@@ -152,7 +153,7 @@ def with_qdrant_storage(storage: list[str]):
 
 
 def wait_for_incoming_message(
-    client,
+    client: Letta,
     agent_id: str,
     substring: str = "[Incoming message from agent with ID",
     max_wait_seconds: float = 10.0,
@@ -166,13 +167,13 @@ def wait_for_incoming_message(
     deadline = time.time() + max_wait_seconds
 
     while time.time() < deadline:
-        messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id, actor=client.user)
+        messages = client.agents.messages.list(agent_id)[1:]
 
         # Check for the system message containing `substring`
-        def get_message_text(message: Message) -> str:
-            return message.content[0].text if message.content and len(message.content) == 1 else ""
+        def get_message_text(message: SystemMessage) -> str:
+            return message.content if message.content else ""
 
-        if any(message.role == MessageRole.system and substring in get_message_text(message) for message in messages):
+        if any(isinstance(message, SystemMessage) and substring in get_message_text(message) for message in messages):
            return True
 
         time.sleep(sleep_interval)
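The rewritten `wait_for_incoming_message` now polls through the public client API instead of reaching into server internals, which keeps the helper usable against any running server. A typical call site, using only the parameters visible in the diff (ids are placeholders):

```python
from letta_client import Letta

from tests.utils import wait_for_incoming_message


def assert_cross_agent_ping(client: Letta, receiver_agent_id: str) -> None:
    # Poll for up to 10 s for the system message another agent should have triggered.
    assert wait_for_incoming_message(
        client,
        agent_id=receiver_agent_id,
        substring="[Incoming message from agent with ID",
        max_wait_seconds=10.0,
    ), "receiver never saw the cross-agent system message"
```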
@@ -199,3 +200,21 @@ def wait_for_server(url, timeout=30, interval=0.5):
 
 def random_string(length: int) -> str:
     return "".join(random.choices(string.ascii_letters + string.digits, k=length))
+
+
+def create_tool_from_func(
+    func,
+    tags: Optional[List[str]] = None,
+    description: Optional[str] = None,
+):
+    source_code = parse_source_code(func)
+    source_type = "python"
+    if not tags:
+        tags = []
+
+    return Tool(
+        source_type=source_type,
+        source_code=source_code,
+        tags=tags,
+        description=description,
+    )
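The new `create_tool_from_func` helper builds a `Tool` schema object straight from a Python function, without registering it on a server. A quick usage sketch (`roll_dice` is a stand-in function, not part of the patch):

```python
import random


def roll_dice() -> int:
    """Roll a six-sided die."""
    return random.randint(1, 6)


tool = create_tool_from_func(roll_dice, tags=["game"], description="Rolls a six-sided die.")
assert tool.source_type == "python"
assert "def roll_dice" in tool.source_code
```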