Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

chore: bump version 0.7.21 (#2653)

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>

This commit is contained in:
parent 97e454124d
commit c0efe8ad0c
@@ -0,0 +1,31 @@
"""add provider category to steps

Revision ID: 6c53224a7a58
Revises: cc8dc340836d
Create Date: 2025-05-21 10:09:43.761669

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "6c53224a7a58"
down_revision: Union[str, None] = "cc8dc340836d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("steps", sa.Column("provider_category", sa.String(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("steps", "provider_category")
    # ### end Alembic commands ###
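Not part of the diff: as a quick local sketch, the revision above can be applied and rolled back through Alembic's command API; the alembic.ini path below is an assumption about the checkout layout.

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed location of the project's Alembic config
command.upgrade(cfg, "6c53224a7a58")    # adds the nullable steps.provider_category column
command.downgrade(cfg, "cc8dc340836d")  # reverts to the previous revision, dropping the column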
@@ -0,0 +1,50 @@
"""add support for request and response jsons from llm providers

Revision ID: cc8dc340836d
Revises: 220856bbf43b
Create Date: 2025-05-19 14:25:41.999676

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "cc8dc340836d"
down_revision: Union[str, None] = "220856bbf43b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "provider_traces",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("request_json", sa.JSON(), nullable=False),
        sa.Column("response_json", sa.JSON(), nullable=False),
        sa.Column("step_id", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_step_id", "provider_traces", ["step_id"], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("ix_step_id", table_name="provider_traces")
    op.drop_table("provider_traces")
    # ### end Alembic commands ###
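Not part of the diff: a minimal sketch of reading a trace back out of the new provider_traces table with SQLAlchemy Core. Table and column names come from the migration above and the lookup goes through the ix_step_id index; the database DSN and the step id are assumptions for illustration.

import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://letta:letta@localhost/letta")  # assumed DSN
provider_traces = sa.Table("provider_traces", sa.MetaData(), autoload_with=engine)

with engine.connect() as conn:
    row = conn.execute(
        sa.select(provider_traces).where(provider_traces.c.step_id == "step-1234")  # hypothetical step id
    ).first()
    if row is not None:
        print(row.request_json, row.response_json)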
@@ -1,4 +1,4 @@
__version__ = "0.7.20"
__version__ = "0.7.21"

# import clients
from letta.client.client import LocalClient, RESTClient, create_client
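Not part of the diff: a one-line sanity check of the bump; the argument-free create_client() call is an assumption about the client's default constructor.

import letta
from letta.client.client import create_client

assert letta.__version__ == "0.7.21"
client = create_client()  # assumed to fall back to a sensible local default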
letta/agent.py (293 lines changed)
@@ -1,4 +1,6 @@
import asyncio
import json
import os
import time
import traceback
import warnings
@@ -7,6 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union

from openai.types.beta.function_tool import FunctionTool as OpenAITool

from letta.agents.helpers import generate_step_id
from letta.constants import (
    CLI_WARNING_PREFIX,
    COMPOSIO_ENTITY_ENV_VAR_KEY,
@@ -16,6 +19,7 @@ from letta.constants import (
    LETTA_CORE_TOOL_MODULE_NAME,
    LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
    LLM_MAX_TOKENS,
    READ_ONLY_BLOCK_EDIT_ERROR,
    REQ_HEARTBEAT_MESSAGE,
    SEND_MESSAGE_TOOL_NAME,
)
@@ -41,7 +45,7 @@ from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, MessageCreate, ToolReturn
@@ -61,9 +65,10 @@ from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.provider_manager import ProviderManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.settings import summarizer_settings
from letta.settings import settings, summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
from letta.tracing import log_event, trace_method
@@ -141,6 +146,7 @@ class Agent(BaseAgent):
        self.agent_manager = AgentManager()
        self.job_manager = JobManager()
        self.step_manager = StepManager()
        self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager()

        # State needed for heartbeat pausing

@@ -298,6 +304,7 @@ class Agent(BaseAgent):
        step_count: Optional[int] = None,
        last_function_failed: bool = False,
        put_inner_thoughts_first: bool = True,
        step_id: Optional[str] = None,
    ) -> ChatCompletionResponse | None:
        """Get response from LLM API with robust retry mechanism."""
        log_telemetry(self.logger, "_get_ai_reply start")
@@ -347,8 +354,9 @@
                messages=message_sequence,
                llm_config=self.agent_state.llm_config,
                tools=allowed_functions,
                stream=stream,
                force_tool_call=force_tool_call,
                telemetry_manager=self.telemetry_manager,
                step_id=step_id,
            )
        else:
            # Fallback to existing flow
@@ -365,6 +373,9 @@
                stream_interface=self.interface,
                put_inner_thoughts_first=put_inner_thoughts_first,
                name=self.agent_state.name,
                telemetry_manager=self.telemetry_manager,
                step_id=step_id,
                actor=self.user,
            )
        log_telemetry(self.logger, "_get_ai_reply create finish")

@@ -840,6 +851,9 @@
        # Extract job_id from metadata if present
        job_id = metadata.get("job_id") if metadata else None

        # Declare step_id for the given step to be used as the step is processing.
        step_id = generate_step_id()

        # Step 0: update core memory
        # only pulling latest block data if shared memory is being used
        current_persisted_memory = Memory(
@@ -870,6 +884,7 @@
            step_count=step_count,
            last_function_failed=last_function_failed,
            put_inner_thoughts_first=put_inner_thoughts_first,
            step_id=step_id,
        )
        if not response:
            # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
@@ -944,6 +959,7 @@
            actor=self.user,
            agent_id=self.agent_state.id,
            provider_name=self.agent_state.llm_config.model_endpoint_type,
            provider_category=self.agent_state.llm_config.provider_category or "base",
            model=self.agent_state.llm_config.model,
            model_endpoint=self.agent_state.llm_config.model_endpoint,
            context_window_limit=self.agent_state.llm_config.context_window,
@@ -953,6 +969,7 @@
                actor=self.user,
            ),
            job_id=job_id,
            step_id=step_id,
        )
        for message in all_new_messages:
            message.step_id = step.id
@@ -1255,6 +1272,276 @@ class Agent(BaseAgent):
            functions_definitions=available_functions_definitions,
        )

    async def get_context_window_async(self) -> ContextWindowOverview:
        if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION":
            return await self.get_context_window_from_anthropic_async()
        return await self.get_context_window_from_tiktoken_async()

    async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview:
        """Get the context window of the agent"""
        # Grab the in-context messages
        # conversion of messages to OpenAI dict format, which is passed to the token counter
        (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
            self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
            self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
            self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
        )
        in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]

        # Extract system, memory and external summary
        if (
            len(in_context_messages) > 0
            and in_context_messages[0].role == MessageRole.system
            and in_context_messages[0].content
            and len(in_context_messages[0].content) == 1
            and isinstance(in_context_messages[0].content[0], TextContent)
        ):
            system_message = in_context_messages[0].content[0].text

            external_memory_marker_pos = system_message.find("###")
            core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
            if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
                system_prompt = system_message[:external_memory_marker_pos].strip()
                external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
                core_memory = system_message[core_memory_marker_pos:].strip()
            else:
                # if no markers found, put everything in system message
                system_prompt = system_message
                external_memory_summary = ""
                core_memory = ""
        else:
            # if no system message, fall back on agent's system prompt
            system_prompt = self.agent_state.system
            external_memory_summary = ""
            core_memory = ""

        num_tokens_system = count_tokens(system_prompt)
        num_tokens_core_memory = count_tokens(core_memory)
        num_tokens_external_memory_summary = count_tokens(external_memory_summary)

        # Check if there's a summary message in the message queue
        if (
            len(in_context_messages) > 1
            and in_context_messages[1].role == MessageRole.user
            and in_context_messages[1].content
            and len(in_context_messages[1].content) == 1
            and isinstance(in_context_messages[1].content[0], TextContent)
            # TODO remove hardcoding
            and "The following is a summary of the previous " in in_context_messages[1].content[0].text
        ):
            # Summary message exists
            text_content = in_context_messages[1].content[0].text
            assert text_content is not None
            summary_memory = text_content
            num_tokens_summary_memory = count_tokens(text_content)
            # with a summary message, the real messages start at index 2
            num_tokens_messages = (
                num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
                if len(in_context_messages_openai) > 2
                else 0
            )

        else:
            summary_memory = None
            num_tokens_summary_memory = 0
            # with no summary message, the real messages start at index 1
            num_tokens_messages = (
                num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
                if len(in_context_messages_openai) > 1
                else 0
            )

        # tokens taken up by function definitions
        agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
        if agent_state_tool_jsons:
            available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons]
            num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
        else:
            available_functions_definitions = []
            num_tokens_available_functions_definitions = 0

        num_tokens_used_total = (
            num_tokens_system  # system prompt
            + num_tokens_available_functions_definitions  # function definitions
            + num_tokens_core_memory  # core memory
            + num_tokens_external_memory_summary  # metadata (statistics) about recall/archival
            + num_tokens_summary_memory  # summary of ongoing conversation
            + num_tokens_messages  # tokens taken by messages
        )
        assert isinstance(num_tokens_used_total, int)

        return ContextWindowOverview(
            # context window breakdown (in messages)
            num_messages=len(in_context_messages),
            num_archival_memory=passage_manager_size,
            num_recall_memory=message_manager_size,
            num_tokens_external_memory_summary=num_tokens_external_memory_summary,
            external_memory_summary=external_memory_summary,
            # top-level information
            context_window_size_max=self.agent_state.llm_config.context_window,
            context_window_size_current=num_tokens_used_total,
            # context window breakdown (in tokens)
            num_tokens_system=num_tokens_system,
            system_prompt=system_prompt,
            num_tokens_core_memory=num_tokens_core_memory,
            core_memory=core_memory,
            num_tokens_summary_memory=num_tokens_summary_memory,
            summary_memory=summary_memory,
            num_tokens_messages=num_tokens_messages,
            messages=in_context_messages,
            # related to functions
            num_tokens_functions_definitions=num_tokens_available_functions_definitions,
            functions_definitions=available_functions_definitions,
        )

async def get_context_window_from_anthropic_async(self) -> ContextWindowOverview:
|
||||
"""Get the context window of the agent"""
|
||||
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=self.user)
|
||||
model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None
|
||||
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to anthropic dict format, which is passed to the token counter
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.agent_manager.get_in_context_messages_async(agent_id=self.agent_state.id, actor=self.user),
|
||||
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
)
|
||||
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
|
||||
|
||||
# Extract system, memory and external summary
|
||||
if (
|
||||
len(in_context_messages) > 0
|
||||
and in_context_messages[0].role == MessageRole.system
|
||||
and in_context_messages[0].content
|
||||
and len(in_context_messages[0].content) == 1
|
||||
and isinstance(in_context_messages[0].content[0], TextContent)
|
||||
):
|
||||
system_message = in_context_messages[0].content[0].text
|
||||
|
||||
external_memory_marker_pos = system_message.find("###")
|
||||
core_memory_marker_pos = system_message.find("<", external_memory_marker_pos)
|
||||
if external_memory_marker_pos != -1 and core_memory_marker_pos != -1:
|
||||
system_prompt = system_message[:external_memory_marker_pos].strip()
|
||||
external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip()
|
||||
core_memory = system_message[core_memory_marker_pos:].strip()
|
||||
else:
|
||||
# if no markers found, put everything in system message
|
||||
system_prompt = system_message
|
||||
external_memory_summary = None
|
||||
core_memory = None
|
||||
else:
|
||||
# if no system message, fall back on agent's system prompt
|
||||
system_prompt = self.agent_state.system
|
||||
external_memory_summary = None
|
||||
core_memory = None
|
||||
|
||||
num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}])
|
||||
num_tokens_core_memory_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}])
|
||||
if core_memory
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
num_tokens_external_memory_summary_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}])
|
||||
if external_memory_summary
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
|
||||
# Check if there's a summary message in the message queue
|
||||
if (
|
||||
len(in_context_messages) > 1
|
||||
and in_context_messages[1].role == MessageRole.user
|
||||
and in_context_messages[1].content
|
||||
and len(in_context_messages[1].content) == 1
|
||||
and isinstance(in_context_messages[1].content[0], TextContent)
|
||||
# TODO remove hardcoding
|
||||
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
|
||||
):
|
||||
# Summary message exists
|
||||
text_content = in_context_messages[1].content[0].text
|
||||
assert text_content is not None
|
||||
summary_memory = text_content
|
||||
num_tokens_summary_memory_coroutine = anthropic_client.count_tokens(
|
||||
model=model, messages=[{"role": "user", "content": summary_memory}]
|
||||
)
|
||||
# with a summary message, the real messages start at index 2
|
||||
num_tokens_messages_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:])
|
||||
if len(in_context_messages_anthropic) > 2
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
|
||||
else:
|
||||
summary_memory = None
|
||||
num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0)
|
||||
# with no summary message, the real messages start at index 1
|
||||
num_tokens_messages_coroutine = (
|
||||
anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:])
|
||||
if len(in_context_messages_anthropic) > 1
|
||||
else asyncio.sleep(0, result=0)
|
||||
)
|
||||
|
||||
# tokens taken up by function definitions
|
||||
if self.agent_state.tools and len(self.agent_state.tools) > 0:
|
||||
available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in self.agent_state.tools]
|
||||
num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens(
|
||||
model=model,
|
||||
tools=available_functions_definitions,
|
||||
)
|
||||
else:
|
||||
available_functions_definitions = []
|
||||
num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0)
|
||||
|
||||
(
|
||||
num_tokens_system,
|
||||
num_tokens_core_memory,
|
||||
num_tokens_external_memory_summary,
|
||||
num_tokens_summary_memory,
|
||||
num_tokens_messages,
|
||||
num_tokens_available_functions_definitions,
|
||||
) = await asyncio.gather(
|
||||
num_tokens_system_coroutine,
|
||||
num_tokens_core_memory_coroutine,
|
||||
num_tokens_external_memory_summary_coroutine,
|
||||
num_tokens_summary_memory_coroutine,
|
||||
num_tokens_messages_coroutine,
|
||||
num_tokens_available_functions_definitions_coroutine,
|
||||
)
|
||||
|
||||
num_tokens_used_total = (
|
||||
num_tokens_system # system prompt
|
||||
+ num_tokens_available_functions_definitions # function definitions
|
||||
+ num_tokens_core_memory # core memory
|
||||
+ num_tokens_external_memory_summary # metadata (statistics) about recall/archival
|
||||
+ num_tokens_summary_memory # summary of ongoing conversation
|
||||
+ num_tokens_messages # tokens taken by messages
|
||||
)
|
||||
assert isinstance(num_tokens_used_total, int)
|
||||
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(in_context_messages),
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
external_memory_summary=external_memory_summary,
|
||||
# top-level information
|
||||
context_window_size_max=self.agent_state.llm_config.context_window,
|
||||
context_window_size_current=num_tokens_used_total,
|
||||
# context window breakdown (in tokens)
|
||||
num_tokens_system=num_tokens_system,
|
||||
system_prompt=system_prompt,
|
||||
num_tokens_core_memory=num_tokens_core_memory,
|
||||
core_memory=core_memory,
|
||||
num_tokens_summary_memory=num_tokens_summary_memory,
|
||||
summary_memory=summary_memory,
|
||||
num_tokens_messages=num_tokens_messages,
|
||||
messages=in_context_messages,
|
||||
# related to functions
|
||||
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
|
||||
functions_definitions=available_functions_definitions,
|
||||
)
|
||||
|
||||
def count_tokens(self) -> int:
|
||||
"""Count the tokens in the current context window"""
|
||||
context_window_breakdown = self.get_context_window()
|
||||
|
@@ -72,61 +72,6 @@ class BaseAgent(ABC):

        return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages]

    def _rebuild_memory(
        self,
        in_context_messages: List[Message],
        agent_state: AgentState,
        num_messages: int | None = None,  # storing these calculations is specific to the voice agent
        num_archival_memories: int | None = None,
    ) -> List[Message]:
        try:
            # Refresh Memory
            # TODO: This only happens for the summary block (voice?)
            # [DB Call] loading blocks (modifies: agent_state.memory.blocks)
            self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)

            # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
            curr_system_message = in_context_messages[0]
            curr_memory_str = agent_state.memory.compile()
            curr_system_message_text = curr_system_message.content[0].text
            if curr_memory_str in curr_system_message_text:
                # NOTE: could this cause issues if a block is removed? (substring match would still work)
                logger.debug(
                    f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
                )
                return in_context_messages

            memory_edit_timestamp = get_utc_time()

            # [DB Call] size of messages and archival memories
            num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
            num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

            new_system_message_str = compile_system_message(
                system_prompt=agent_state.system,
                in_context_memory=agent_state.memory,
                in_context_memory_last_edit=memory_edit_timestamp,
                previous_message_count=num_messages,
                archival_memory_size=num_archival_memories,
            )

            diff = united_diff(curr_system_message_text, new_system_message_str)
            if len(diff) > 0:
                logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

                # [DB Call] Update Messages
                new_system_message = self.message_manager.update_message_by_id(
                    curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
                )
                # Skip pulling down the agent's memory again to save on a db call
                return [new_system_message] + in_context_messages[1:]

            else:
                return in_context_messages
        except:
            logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
            raise

    async def _rebuild_memory_async(
        self,
        in_context_messages: List[Message],
@@ -1,3 +1,4 @@
import uuid
import xml.etree.ElementTree as ET
from typing import List, Tuple

@@ -150,3 +151,7 @@ def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
    context = sum_el.text or ""

    return messages, context


def generate_step_id():
    return f"step-{uuid.uuid4()}"
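Not part of the diff: step IDs produced by the helper above are plain strings of the form "step-<uuid4>"; the hunks below attach them to persisted messages and provider traces. A trivial usage sketch:

from letta.agents.helpers import generate_step_id

step_id = generate_step_id()
assert step_id.startswith("step-")  # e.g. "step-4f1c2d3e-..." (illustrative value)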
@@ -8,8 +8,9 @@ from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
@@ -24,7 +25,8 @@ from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@@ -32,10 +34,11 @@ from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
from letta.tracing import log_event, trace_method
from letta.tracing import log_event, trace_method, tracer

logger = get_logger(__name__)

@@ -50,6 +53,8 @@ class LettaAgent(BaseAgent):
        block_manager: BlockManager,
        passage_manager: PassageManager,
        actor: User,
        step_manager: StepManager = NoopStepManager(),
        telemetry_manager: TelemetryManager = NoopTelemetryManager(),
    ):
        super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)

@@ -57,6 +62,8 @@ class LettaAgent(BaseAgent):
        # Summarizer settings
        self.block_manager = block_manager
        self.passage_manager = passage_manager
        self.step_manager = step_manager
        self.telemetry_manager = telemetry_manager
        self.response_messages: List[Message] = []

        self.last_function_response = None
@ -67,17 +74,19 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
|
||||
current_in_context_messages, new_in_context_messages, usage = await self._step(
|
||||
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
|
||||
)
|
||||
_, new_in_context_messages, usage = await self._step(agent_state=agent_state, input_messages=input_messages, max_steps=max_steps)
|
||||
return _create_letta_response(
|
||||
new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
|
||||
)
|
||||
|
||||
async def _step(
|
||||
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
|
||||
) -> Tuple[List[Message], List[Message], CompletionUsage]:
|
||||
@trace_method
|
||||
async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
@ -89,23 +98,81 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
usage = LettaUsageStatistics()
|
||||
for _ in range(max_steps):
|
||||
response = await self._get_ai_reply(
|
||||
step_id = generate_step_id()
|
||||
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
current_in_context_messages + new_in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
)
|
||||
log_event("agent.stream_no_tokens.messages.refreshed") # [1^]
|
||||
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
in_context_messages=in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
stream=False,
|
||||
# TODO: also pass in reasoning content
|
||||
# TODO: pass in reasoning content
|
||||
)
|
||||
log_event("agent.stream_no_tokens.llm_request.created") # [2^]
|
||||
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
|
||||
except Exception as e:
|
||||
raise llm_client.handle_llm_error(e)
|
||||
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
||||
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# update usage
|
||||
# TODO: add run_id
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
|
||||
if not response.choices[0].message.tool_calls:
|
||||
# TODO: make into a real error
|
||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
if response.choices[0].message.reasoning_content:
|
||||
reasoning = [
|
||||
ReasoningContent(
|
||||
reasoning=response.choices[0].message.reasoning_content,
|
||||
is_native=True,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
]
|
||||
else:
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
|
||||
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
log_event("agent.stream_no_tokens.llm_response.processed") # [4^]
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
|
||||
# stream step
|
||||
# TODO: improve TTFT
|
||||
filter_user_messages = [m for m in persisted_messages if m.role != "user"]
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
|
||||
)
|
||||
for message in letta_messages:
|
||||
yield f"data: {message.model_dump_json()}\n\n"
|
||||
|
||||
# update usage
|
||||
# TODO: add run_id
|
||||
@ -122,17 +189,125 @@ class LettaAgent(BaseAgent):
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
|
||||
# Return back usage
|
||||
yield f"data: {usage.model_dump_json()}\n\n"
|
||||
|
||||
async def _step(
|
||||
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
|
||||
) -> Tuple[List[Message], List[Message], CompletionUsage]:
|
||||
"""
|
||||
Carries out an invocation of the agent loop. In each step, the agent
|
||||
1. Rebuilds its memory
|
||||
2. Generates a request for the LLM
|
||||
3. Fetches a response from the LLM
|
||||
4. Processes the response
|
||||
"""
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor=self.actor,
|
||||
)
|
||||
usage = LettaUsageStatistics()
|
||||
for _ in range(max_steps):
|
||||
step_id = generate_step_id()
|
||||
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
current_in_context_messages + new_in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
)
|
||||
log_event("agent.step.messages.refreshed") # [1^]
|
||||
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
# TODO: pass in reasoning content
|
||||
)
|
||||
log_event("agent.step.llm_request.created") # [2^]
|
||||
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
|
||||
except Exception as e:
|
||||
raise llm_client.handle_llm_error(e)
|
||||
log_event("agent.step.llm_response.received") # [3^]
|
||||
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, in_context_messages, agent_state.llm_config)
|
||||
|
||||
# TODO: add run_id
|
||||
usage.step_count += 1
|
||||
usage.completion_tokens += response.usage.completion_tokens
|
||||
usage.prompt_tokens += response.usage.prompt_tokens
|
||||
usage.total_tokens += response.usage.total_tokens
|
||||
|
||||
if not response.choices[0].message.tool_calls:
|
||||
# TODO: make into a real error
|
||||
raise ValueError("No tool calls found in response, model must make a tool call")
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
if response.choices[0].message.reasoning_content:
|
||||
reasoning = [
|
||||
ReasoningContent(
|
||||
reasoning=response.choices[0].message.reasoning_content,
|
||||
is_native=True,
|
||||
signature=response.choices[0].message.reasoning_content_signature,
|
||||
)
|
||||
]
|
||||
else:
|
||||
reasoning = [TextContent(text=response.choices[0].message.content)] # reasoning placed into content for legacy reasons
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
log_event("agent.step.llm_response.processed") # [4^]
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
|
||||
return current_in_context_messages, new_in_context_messages, usage
|
||||
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Main streaming loop that yields partial tokens.
|
||||
Whenever we detect a tool call, we yield from _handle_ai_response as well.
|
||||
Carries out an invocation of the agent loop in a streaming fashion that yields partial tokens.
|
||||
Whenever we detect a tool call, we yield from _handle_ai_response as well. At each step, the agent
|
||||
1. Rebuilds its memory
|
||||
2. Generates a request for the LLM
|
||||
3. Fetches a response from the LLM
|
||||
4. Processes the response
|
||||
"""
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
@ -145,13 +320,29 @@ class LettaAgent(BaseAgent):
|
||||
usage = LettaUsageStatistics()
|
||||
|
||||
for _ in range(max_steps):
|
||||
stream = await self._get_ai_reply(
|
||||
step_id = generate_step_id()
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
current_in_context_messages + new_in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
)
|
||||
log_event("agent.step.messages.refreshed") # [1^]
|
||||
|
||||
request_data = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
in_context_messages=in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
stream=True,
|
||||
)
|
||||
log_event("agent.stream.llm_request.created") # [2^]
|
||||
|
||||
try:
|
||||
stream = await llm_client.stream_async(request_data, agent_state.llm_config)
|
||||
except Exception as e:
|
||||
raise llm_client.handle_llm_error(e)
|
||||
log_event("agent.stream.llm_response.received") # [3^]
|
||||
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
if agent_state.llm_config.model_endpoint_type == "anthropic":
|
||||
@ -164,7 +355,23 @@ class LettaAgent(BaseAgent):
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
|
||||
|
||||
first_chunk, ttft_span = True, None
|
||||
if request_start_timestamp_ns is not None:
|
||||
ttft_span = tracer.start_span("time_to_first_token", start_time=request_start_timestamp_ns)
|
||||
ttft_span.set_attributes({f"llm_config.{k}": v for k, v in agent_state.llm_config.model_dump().items() if v is not None})
|
||||
|
||||
async for chunk in interface.process(stream):
|
||||
# Measure time to first token
|
||||
if first_chunk and ttft_span is not None:
|
||||
now = get_utc_timestamp_ns()
|
||||
ttft_ns = now - request_start_timestamp_ns
|
||||
ttft_span.add_event(name="time_to_first_token_ms", attributes={"ttft_ms": ttft_ns // 1_000_000})
|
||||
ttft_span.end()
|
||||
first_chunk = False
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# update usage
|
||||
@ -180,13 +387,46 @@ class LettaAgent(BaseAgent):
|
||||
tool_call,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
UsageStatistics(
|
||||
completion_tokens=interface.output_tokens,
|
||||
prompt_tokens=interface.input_tokens,
|
||||
total_tokens=interface.input_tokens + interface.output_tokens,
|
||||
),
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
|
||||
pre_computed_tool_message_id=interface.letta_tool_message_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
|
||||
# TODO (cliandy): the stream POST request span has ended at this point, we should tie this to the stream
|
||||
# log_event("agent.stream.llm_response.processed") # [4^]
|
||||
|
||||
# Log LLM Trace
|
||||
# TODO (cliandy): we are piecing together the streamed response here. Content here does not match the actual response schema.
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json={
|
||||
"content": {
|
||||
"tool_call": tool_call.model_dump_json(),
|
||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||
},
|
||||
"id": interface.message_id,
|
||||
"model": interface.model,
|
||||
"role": "assistant",
|
||||
# "stop_reason": "",
|
||||
# "stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {"input_tokens": interface.input_tokens, "output_tokens": interface.output_tokens},
|
||||
},
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
|
||||
if not use_assistant_message or should_continue:
|
||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||
yield f"data: {tool_return.model_dump_json()}\n\n"
|
||||
@ -209,28 +449,20 @@ class LettaAgent(BaseAgent):
|
||||
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
||||
|
||||
@trace_method
|
||||
# When raising an error this doesn't show up
|
||||
async def _get_ai_reply(
|
||||
async def _create_llm_request_data_async(
|
||||
self,
|
||||
llm_client: LLMClientBase,
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
stream: bool,
|
||||
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
|
||||
self.num_archival_memories = self.num_archival_memories or (
|
||||
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
else:
|
||||
if settings.experimental_skip_rebuild_memory and agent_state.llm_config.model_endpoint_type == "google_vertex":
|
||||
logger.info("Skipping memory rebuild")
|
||||
else:
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
|
||||
self.num_archival_memories = self.num_archival_memories or (
|
||||
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
|
||||
tools = [
|
||||
t
|
||||
@ -243,8 +475,8 @@ class LettaAgent(BaseAgent):
|
||||
ToolType.LETTA_MULTI_AGENT_CORE,
|
||||
ToolType.LETTA_SLEEPTIME_CORE,
|
||||
ToolType.LETTA_VOICE_SLEEPTIME_CORE,
|
||||
ToolType.LETTA_BUILTIN,
|
||||
}
|
||||
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
|
||||
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
|
||||
]
|
||||
|
||||
@ -264,15 +496,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
|
||||
response = await llm_client.send_llm_request_async(
|
||||
messages=in_context_messages,
|
||||
llm_config=agent_state.llm_config,
|
||||
tools=allowed_tools,
|
||||
force_tool_call=force_tool_call,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return response
|
||||
return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)
|
||||
|
||||
@trace_method
|
||||
async def _handle_ai_response(
|
||||
@ -280,9 +504,11 @@ class LettaAgent(BaseAgent):
|
||||
tool_call: ToolCall,
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
usage: UsageStatistics,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
pre_computed_tool_message_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
) -> Tuple[List[Message], bool]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
@ -294,8 +520,11 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
try:
|
||||
tool_args = json.loads(tool_call_args_str)
|
||||
assert isinstance(tool_args, dict), "tool_args must be a dict"
|
||||
except json.JSONDecodeError:
|
||||
tool_args = {}
|
||||
except AssertionError:
|
||||
tool_args = json.loads(tool_args)
|
||||
|
||||
# Get request heartbeats and coerce to bool
|
||||
request_heartbeat = tool_args.pop("request_heartbeat", False)
|
||||
@ -329,7 +558,25 @@ class LettaAgent(BaseAgent):
|
||||
elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name):
|
||||
continue_stepping = True
|
||||
|
||||
# 5. Persist to DB
|
||||
# 5a. Persist Steps to DB
|
||||
# Following agent loop to persist this before messages
|
||||
# TODO (cliandy): determine what should match old loop w/provider_id, job_id
|
||||
# TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
|
||||
logged_step = await self.step_manager.log_step_async(
|
||||
actor=self.actor,
|
||||
agent_id=agent_state.id,
|
||||
provider_name=agent_state.llm_config.model_endpoint_type,
|
||||
provider_category=agent_state.llm_config.provider_category or "base",
|
||||
model=agent_state.llm_config.model,
|
||||
model_endpoint=agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=agent_state.llm_config.context_window,
|
||||
usage=usage,
|
||||
provider_id=None,
|
||||
job_id=None,
|
||||
step_id=step_id,
|
||||
)
|
||||
|
||||
# 5b. Persist Messages to DB
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
@ -343,6 +590,7 @@ class LettaAgent(BaseAgent):
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
pre_computed_tool_message_id=pre_computed_tool_message_id,
|
||||
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
|
||||
)
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor)
|
||||
self.last_function_response = function_response
|
||||
@ -361,20 +609,21 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# TODO: This temp. Move this logic and code to executors
|
||||
try:
|
||||
if target_tool.name == "send_message_to_agents_matching_tags" and target_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
|
||||
log_event(name="start_send_message_to_agents_matching_tags", attributes=tool_args)
|
||||
results = await self._send_message_to_agents_matching_tags(**tool_args)
|
||||
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
|
||||
return json.dumps(results), True
|
||||
else:
|
||||
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
|
||||
# TODO: Integrate sandbox result
|
||||
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
|
||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||
function_name=tool_name, function_args=tool_args, tool=target_tool
|
||||
)
|
||||
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
|
||||
return tool_execution_result.func_return, True
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
# TODO: Integrate sandbox result
|
||||
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
|
||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||
function_name=tool_name, function_args=tool_args, tool=target_tool
|
||||
)
|
||||
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
|
||||
return tool_execution_result.func_return, True
|
||||
except Exception as e:
|
||||
return f"Failed to call tool. Error: {e}", False
|
||||
|
||||
@ -430,6 +679,7 @@ class LettaAgent(BaseAgent):
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
|
||||
@trace_method
|
||||
async def _load_last_function_response_async(self):
|
||||
"""Load the last function response from message history"""
|
||||
in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_id, actor=self.actor)
|
||||
|
@ -145,7 +145,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
agent_mapping = {
|
||||
agent_state.id: agent_state
|
||||
for agent_state in await self.agent_manager.get_agents_by_ids_async(
|
||||
agent_ids=[request.agent_id for request in batch_requests], actor=self.actor
|
||||
agent_ids=[request.agent_id for request in batch_requests], include_relationships=["tools", "memory"], actor=self.actor
|
||||
)
|
||||
}
|
||||
|
||||
@ -267,64 +267,121 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
@trace_method
|
||||
async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext:
|
||||
# NOTE: We only continue for items with successful results
|
||||
"""
|
||||
Collect context for resuming operations from completed batch items.
|
||||
|
||||
Args:
|
||||
llm_batch_id: The ID of the batch to collect context for
|
||||
|
||||
Returns:
|
||||
_ResumeContext object containing all necessary data for resumption
|
||||
"""
|
||||
# Fetch only completed batch items
|
||||
batch_items = await self.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_id, request_status=JobStatus.completed)
|
||||
|
||||
agent_ids = []
|
||||
provider_results = {}
|
||||
request_status_updates: List[RequestStatusUpdateInfo] = []
|
||||
# Exit early if no items to process
|
||||
if not batch_items:
|
||||
return _ResumeContext(
|
||||
batch_items=[],
|
||||
agent_ids=[],
|
||||
agent_state_map={},
|
||||
provider_results={},
|
||||
tool_call_name_map={},
|
||||
tool_call_args_map={},
|
||||
should_continue_map={},
|
||||
request_status_updates=[],
|
||||
)
|
||||
|
||||
for item in batch_items:
|
||||
aid = item.agent_id
|
||||
agent_ids.append(aid)
|
||||
provider_results[aid] = item.batch_request_result.result
|
||||
# Extract agent IDs and organize items by agent ID
|
||||
agent_ids = [item.agent_id for item in batch_items]
|
||||
batch_item_map = {item.agent_id: item for item in batch_items}
|
||||
|
||||
agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids, actor=self.actor)
|
||||
# Collect provider results
|
||||
provider_results = {item.agent_id: item.batch_request_result.result for item in batch_items}
|
||||
|
||||
# Fetch agent states in a single call
|
||||
agent_states = await self.agent_manager.get_agents_by_ids_async(
|
||||
agent_ids=agent_ids, include_relationships=["tools", "memory"], actor=self.actor
|
||||
)
|
||||
agent_state_map = {agent.id: agent for agent in agent_states}
|
||||
|
||||
name_map, args_map, cont_map = {}, {}, {}
|
||||
for aid in agent_ids:
|
||||
# status bookkeeping
|
||||
pr = provider_results[aid]
|
||||
status = (
|
||||
JobStatus.completed
|
||||
if isinstance(pr, BetaMessageBatchSucceededResult)
|
||||
else (
|
||||
JobStatus.failed
|
||||
if isinstance(pr, BetaMessageBatchErroredResult)
|
||||
else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired
|
||||
)
|
||||
)
|
||||
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
|
||||
|
||||
# translate provider‑specific response → OpenAI‑style tool call (unchanged)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=item.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor=self.actor,
|
||||
)
|
||||
tool_call = (
|
||||
llm_client.convert_response_to_chat_completion(
|
||||
response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config
|
||||
)
|
||||
.choices[0]
|
||||
.message.tool_calls[0]
|
||||
)
|
||||
|
||||
name, args, cont = self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
|
||||
name_map[aid], args_map[aid], cont_map[aid] = name, args, cont
|
||||
# Process each agent's results
|
||||
tool_call_results = self._process_agent_results(
|
||||
agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id
|
||||
)
|
||||
|
||||
return _ResumeContext(
|
||||
batch_items=batch_items,
|
||||
agent_ids=agent_ids,
|
||||
agent_state_map=agent_state_map,
|
||||
provider_results=provider_results,
|
||||
tool_call_name_map=name_map,
|
||||
tool_call_args_map=args_map,
|
||||
should_continue_map=cont_map,
|
||||
request_status_updates=request_status_updates,
|
||||
tool_call_name_map=tool_call_results.name_map,
|
||||
tool_call_args_map=tool_call_results.args_map,
|
||||
should_continue_map=tool_call_results.cont_map,
|
||||
request_status_updates=tool_call_results.status_updates,
|
||||
)
|
||||

    def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
        """
        Process the results for each agent, extracting tool calls and determining continuation status.

        Returns:
            A namedtuple containing name_map, args_map, cont_map, and status_updates
        """
        from collections import namedtuple

        ToolCallResults = namedtuple("ToolCallResults", ["name_map", "args_map", "cont_map", "status_updates"])

        name_map, args_map, cont_map = {}, {}, {}
        request_status_updates = []

        for aid in agent_ids:
            item = batch_item_map[aid]
            result = provider_results[aid]

            # Determine job status based on result type
            status = self._determine_job_status(result)
            request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))

            # Process tool calls
            name, args, cont = self._extract_tool_call_from_result(item, result)
            name_map[aid], args_map[aid], cont_map[aid] = name, args, cont

        return ToolCallResults(name_map, args_map, cont_map, request_status_updates)

    def _determine_job_status(self, result):
        """Determine job status based on result type"""
        if isinstance(result, BetaMessageBatchSucceededResult):
            return JobStatus.completed
        elif isinstance(result, BetaMessageBatchErroredResult):
            return JobStatus.failed
        elif isinstance(result, BetaMessageBatchCanceledResult):
            return JobStatus.cancelled
        else:
            return JobStatus.expired

    def _extract_tool_call_from_result(self, item, result):
        """Extract tool call information from a result"""
        llm_client = LLMClient.create(
            provider_type=item.llm_config.model_endpoint_type,
            put_inner_thoughts_first=True,
            actor=self.actor,
        )

        # If result isn't a successful type, we can't extract a tool call
        if not isinstance(result, BetaMessageBatchSucceededResult):
            return None, None, False

        tool_call = (
            llm_client.convert_response_to_chat_completion(
                response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
            )
            .choices[0]
            .message.tool_calls[0]
        )

        return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)

    def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
        if updates:
            self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
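# Illustrative sketch (not part of this commit): the provider-result -> JobStatus
# mapping that _determine_job_status implements, written as a lookup table. It
# reuses the anthropic batch result types and the JobStatus enum already imported
# by this module; any unrecognized result type (e.g. an expired batch result)
# falls through to JobStatus.expired, mirroring the method above.
_RESULT_STATUS_MAP = {
    BetaMessageBatchSucceededResult: JobStatus.completed,
    BetaMessageBatchErroredResult: JobStatus.failed,
    BetaMessageBatchCanceledResult: JobStatus.cancelled,
}

def _status_for(result) -> JobStatus:
    # Unknown result types are treated as expired, same as _determine_job_status.
    return _RESULT_STATUS_MAP.get(type(result), JobStatus.expired)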
@ -556,16 +613,6 @@ class LettaAgentBatch(BaseAgent):
|
||||
in_context_messages = await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
|
||||
return in_context_messages
|
||||
|
||||
# TODO: Make this a bulk function
|
||||
def _rebuild_memory(
|
||||
self,
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
num_messages: int | None = None,
|
||||
num_archival_memories: int | None = None,
|
||||
) -> List[Message]:
|
||||
return super()._rebuild_memory(in_context_messages, agent_state)
|
||||
|
||||
# Not used in batch.
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
raise NotImplementedError
|
||||
|
@ -154,7 +154,7 @@ class VoiceAgent(BaseAgent):
|
||||
# TODO: Define max steps here
|
||||
for _ in range(max_steps):
|
||||
# Rebuild memory each loop
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
|
||||
openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
|
||||
openai_messages.extend(in_memory_message_history)
|
||||
|
||||
@ -292,14 +292,14 @@ class VoiceAgent(BaseAgent):
|
||||
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
|
||||
)
|
||||
|
||||
def _rebuild_memory(
|
||||
async def _rebuild_memory_async(
|
||||
self,
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
num_messages: int | None = None,
|
||||
num_archival_memories: int | None = None,
|
||||
) -> List[Message]:
|
||||
return super()._rebuild_memory(
|
||||
return await super()._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
|
||||
@ -438,7 +438,7 @@ class VoiceAgent(BaseAgent):
|
||||
if start_date and end_date and start_date > end_date:
|
||||
start_date, end_date = end_date, start_date
|
||||
|
||||
archival_results = self.agent_manager.list_passages(
|
||||
archival_results = await self.agent_manager.list_passages_async(
|
||||
actor=self.actor,
|
||||
agent_id=self.agent_id,
|
||||
query_text=archival_query,
|
||||
@ -457,7 +457,7 @@ class VoiceAgent(BaseAgent):
|
||||
keyword_results = {}
|
||||
if convo_keyword_queries:
|
||||
for keyword in convo_keyword_queries:
|
||||
messages = self.message_manager.list_messages_for_agent(
|
||||
messages = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=self.agent_id,
|
||||
actor=self.actor,
|
||||
query_text=keyword,
|
||||
|
@ -2773,11 +2773,8 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# humans / personas
|
||||
|
||||
def get_block_id(self, name: str, label: str) -> str:
|
||||
block = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label=label, is_template=True)
|
||||
if not block:
|
||||
return None
|
||||
return block[0].id
|
||||
def get_block_id(self, name: str, label: str) -> str | None:
|
||||
return None
|
||||
|
||||
def create_human(self, name: str, text: str):
|
||||
"""
|
||||
@ -2812,7 +2809,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
humans (List[Human]): List of human blocks
|
||||
"""
|
||||
return self.server.block_manager.get_blocks(actor=self.user, label="human", is_template=True)
|
||||
return []
|
||||
|
||||
def list_personas(self) -> List[Persona]:
|
||||
"""
|
||||
@ -2821,7 +2818,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
personas (List[Persona]): List of persona blocks
|
||||
"""
|
||||
return self.server.block_manager.get_blocks(actor=self.user, label="persona", is_template=True)
|
||||
return []
|
||||
|
||||
def update_human(self, human_id: str, text: str):
|
||||
"""
|
||||
@ -2879,7 +2876,7 @@ class LocalClient(AbstractClient):
|
||||
assert id, f"Human ID must be provided"
|
||||
return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump())
|
||||
|
||||
def get_persona_id(self, name: str) -> str:
|
||||
def get_persona_id(self, name: str) -> str | None:
|
||||
"""
|
||||
Get the ID of a persona block template
|
||||
|
||||
@ -2889,12 +2886,9 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
id (str): ID of the persona block
|
||||
"""
|
||||
persona = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="persona", is_template=True)
|
||||
if not persona:
|
||||
return None
|
||||
return persona[0].id
|
||||
return None
|
||||
|
||||
def get_human_id(self, name: str) -> str:
|
||||
def get_human_id(self, name: str) -> str | None:
|
||||
"""
|
||||
Get the ID of a human block template
|
||||
|
||||
@ -2904,10 +2898,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
id (str): ID of the human block
|
||||
"""
|
||||
human = self.server.block_manager.get_blocks(actor=self.user, template_name=name, label="human", is_template=True)
|
||||
if not human:
|
||||
return None
|
||||
return human[0].id
|
||||
return None
|
||||
|
||||
def delete_persona(self, id: str):
|
||||
"""
|
||||
@ -3381,7 +3372,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
blocks (List[Block]): List of blocks
|
||||
"""
|
||||
return self.server.block_manager.get_blocks(actor=self.user, label=label, is_template=templates_only)
|
||||
return []
|
||||
|
||||
def create_block(
|
||||
self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False
|
||||
|
@ -19,6 +19,7 @@ MCP_TOOL_TAG_NAME_PREFIX = "mcp" # full format, mcp:server_name
|
||||
LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME = "letta.functions.function_sets.multi_agent"
|
||||
LETTA_VOICE_TOOL_MODULE_NAME = "letta.functions.function_sets.voice"
|
||||
LETTA_BUILTIN_TOOL_MODULE_NAME = "letta.functions.function_sets.builtin"
|
||||
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
@ -83,9 +84,19 @@ BASE_VOICE_SLEEPTIME_TOOLS = [
|
||||
]
|
||||
# Multi agent tools
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
|
||||
|
||||
# Built in tools
|
||||
BUILTIN_TOOLS = ["run_code", "web_search"]
|
||||
|
||||
# Set of all built-in Letta tools
|
||||
LETTA_TOOL_SET = set(
|
||||
BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS
|
||||
BASE_TOOLS
|
||||
+ BASE_MEMORY_TOOLS
|
||||
+ MULTI_AGENT_TOOLS
|
||||
+ BASE_SLEEPTIME_TOOLS
|
||||
+ BASE_VOICE_SLEEPTIME_TOOLS
|
||||
+ BASE_VOICE_SLEEPTIME_CHAT_TOOLS
|
||||
+ BUILTIN_TOOLS
|
||||
)
|
||||
|
||||
# The name of the tool used to send message to the user
|
||||
@ -179,6 +190,45 @@ LLM_MAX_TOKENS = {
|
||||
"gpt-3.5-turbo-0613": 4096, # legacy
|
||||
"gpt-3.5-turbo-16k-0613": 16385, # legacy
|
||||
"gpt-3.5-turbo-0301": 4096, # legacy
|
||||
"gemini-1.0-pro-vision-latest": 12288,
|
||||
"gemini-pro-vision": 12288,
|
||||
"gemini-1.5-pro-latest": 2000000,
|
||||
"gemini-1.5-pro-001": 2000000,
|
||||
"gemini-1.5-pro-002": 2000000,
|
||||
"gemini-1.5-pro": 2000000,
|
||||
"gemini-1.5-flash-latest": 1000000,
|
||||
"gemini-1.5-flash-001": 1000000,
|
||||
"gemini-1.5-flash-001-tuning": 16384,
|
||||
"gemini-1.5-flash": 1000000,
|
||||
"gemini-1.5-flash-002": 1000000,
|
||||
"gemini-1.5-flash-8b": 1000000,
|
||||
"gemini-1.5-flash-8b-001": 1000000,
|
||||
"gemini-1.5-flash-8b-latest": 1000000,
|
||||
"gemini-1.5-flash-8b-exp-0827": 1000000,
|
||||
"gemini-1.5-flash-8b-exp-0924": 1000000,
|
||||
"gemini-2.5-pro-exp-03-25": 1048576,
|
||||
"gemini-2.5-pro-preview-03-25": 1048576,
|
||||
"gemini-2.5-flash-preview-04-17": 1048576,
|
||||
"gemini-2.5-flash-preview-05-20": 1048576,
|
||||
"gemini-2.5-flash-preview-04-17-thinking": 1048576,
|
||||
"gemini-2.5-pro-preview-05-06": 1048576,
|
||||
"gemini-2.0-flash-exp": 1048576,
|
||||
"gemini-2.0-flash": 1048576,
|
||||
"gemini-2.0-flash-001": 1048576,
|
||||
"gemini-2.0-flash-exp-image-generation": 1048576,
|
||||
"gemini-2.0-flash-lite-001": 1048576,
|
||||
"gemini-2.0-flash-lite": 1048576,
|
||||
"gemini-2.0-flash-preview-image-generation": 32768,
|
||||
"gemini-2.0-flash-lite-preview-02-05": 1048576,
|
||||
"gemini-2.0-flash-lite-preview": 1048576,
|
||||
"gemini-2.0-pro-exp": 1048576,
|
||||
"gemini-2.0-pro-exp-02-05": 1048576,
|
||||
"gemini-exp-1206": 1048576,
|
||||
"gemini-2.0-flash-thinking-exp-01-21": 1048576,
|
||||
"gemini-2.0-flash-thinking-exp": 1048576,
|
||||
"gemini-2.0-flash-thinking-exp-1219": 1048576,
|
||||
"gemini-2.5-flash-preview-tts": 32768,
|
||||
"gemini-2.5-pro-preview-tts": 65536,
|
||||
}
|
||||
# The error message that Letta will receive
|
||||
# MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
|
||||
@ -230,3 +280,7 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
|
||||
|
||||
MAX_FILENAME_LENGTH = 255
|
||||
RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
|
||||
|
||||
WEB_SEARCH_CLIP_CONTENT = False
|
||||
WEB_SEARCH_INCLUDE_SCORE = False
|
||||
WEB_SEARCH_SEPARATOR = "\n" + "-" * 40 + "\n"
|
||||
|
letta/functions/function_sets/builtin.py (new file, 27 lines)
@ -0,0 +1,27 @@
from typing import Literal


async def web_search(query: str) -> str:
    """
    Search the web for information.
    Args:
        query (str): The query to search the web for.
    Returns:
        str: The search results.
    """

    raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")


def run_code(code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
    """
    Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java.

    Args:
        code (str): The code to run.
        language (Literal["python", "js", "ts", "r", "java"]): The language of the code.
    Returns:
        str: The output of the code, the stdout, the stderr, and error traces (if any).
    """

    raise NotImplementedError("This is only available on the latest agent architecture. Please contact the Letta team.")
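# Illustrative sketch (not part of this commit): roughly the OpenAI-style function
# schema a stub like run_code above is expected to advertise to the model. The real
# schema is derived from the signature and docstring by Letta's tool registration;
# the literal dict below is only an approximation for illustration.
RUN_CODE_SCHEMA_EXAMPLE = {
    "name": "run_code",
    "description": "Run code in a sandbox. Supports Python, Javascript, Typescript, R, and Java.",
    "parameters": {
        "type": "object",
        "properties": {
            "code": {"type": "string", "description": "The code to run."},
            "language": {"type": "string", "enum": ["python", "js", "ts", "r", "java"], "description": "The language of the code."},
        },
        "required": ["code", "language"],
    },
}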
@ -190,7 +190,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
prior_messages = []
|
||||
if self.group.sleeptime_agent_frequency:
|
||||
try:
|
||||
prior_messages = self.message_manager.list_messages_for_agent(
|
||||
prior_messages = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=foreground_agent_id,
|
||||
actor=self.actor,
|
||||
after=last_processed_message_id,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Union
|
||||
@ -74,6 +75,7 @@ class AnthropicStreamingInterface:
|
||||
# usage trackers
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.model = None
|
||||
|
||||
# reasoning object trackers
|
||||
self.reasoning_messages = []
|
||||
@ -88,7 +90,13 @@ class AnthropicStreamingInterface:
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Useful for agent loop"""
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))
|
||||
# hack for tool rules
|
||||
tool_input = json.loads(self.accumulated_tool_call_args)
|
||||
if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
|
||||
arguments = str(json.dumps(tool_input["function"]["arguments"], indent=2))
|
||||
else:
|
||||
arguments = self.accumulated_tool_call_args
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
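# Illustrative sketch (not part of this commit): the two shapes of accumulated
# tool-call arguments handled above. When the streamed args themselves look like a
# wrapped tool call (an object with an "id" starting with "toolu_" and a "function"
# field), only the inner function arguments are kept; otherwise the raw JSON string
# is returned unchanged. Values are made-up examples.
wrapped_args = '{"id": "toolu_0123", "function": {"name": "send_message", "arguments": {"message": "hi"}}}'
plain_args = '{"message": "hi"}'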
|
||||
|
||||
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
|
||||
"""
|
||||
@ -311,6 +319,7 @@ class AnthropicStreamingInterface:
|
||||
self.message_id = event.message.id
|
||||
self.input_tokens += event.message.usage.input_tokens
|
||||
self.output_tokens += event.message.usage.output_tokens
|
||||
self.model = event.message.model
|
||||
elif isinstance(event, BetaRawMessageDeltaEvent):
|
||||
self.output_tokens += event.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageStopEvent):
|
||||
|
@ -40,6 +40,9 @@ class OpenAIStreamingInterface:
|
||||
self.letta_assistant_message_id = Message.generate_id()
|
||||
self.letta_tool_message_id = Message.generate_id()
|
||||
|
||||
self.message_id = None
|
||||
self.model = None
|
||||
|
||||
# token counters
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
@ -69,10 +72,14 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
async for chunk in stream:
|
||||
if not self.model or not self.message_id:
|
||||
self.model = chunk.model
|
||||
self.message_id = chunk.id
|
||||
|
||||
# track usage
|
||||
if chunk.usage:
|
||||
self.input_tokens += len(chunk.usage.prompt_tokens)
|
||||
self.output_tokens += len(chunk.usage.completion_tokens)
|
||||
self.input_tokens += chunk.usage.prompt_tokens
|
||||
self.output_tokens += chunk.usage.completion_tokens
|
||||
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
|
@ -134,13 +134,13 @@ def anthropic_check_valid_api_key(api_key: Union[str, None]) -> None:
|
||||
|
||||
|
||||
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
|
||||
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
|
||||
for model_dict in anthropic_get_model_list(api_key=api_key):
|
||||
if model_dict["name"] == model:
|
||||
return model_dict["context_window"]
|
||||
raise ValueError(f"Can't find model '{model}' in Anthropic model list")
|
||||
|
||||
|
||||
def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
|
||||
def anthropic_get_model_list(api_key: Optional[str]) -> dict:
|
||||
"""https://docs.anthropic.com/claude/docs/models-overview"""
|
||||
|
||||
# NOTE: currently there is no GET /models, so we need to hardcode
|
||||
@ -159,6 +159,25 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
|
||||
return models_json["data"]
|
||||
|
||||
|
||||
async def anthropic_get_model_list_async(api_key: Optional[str]) -> dict:
|
||||
"""https://docs.anthropic.com/claude/docs/models-overview"""
|
||||
|
||||
# NOTE: currently there is no GET /models, so we need to hardcode
|
||||
# return MODEL_LIST
|
||||
|
||||
if api_key:
|
||||
anthropic_client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.AsyncAnthropic()
|
||||
else:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
models = await anthropic_client.models.list()
|
||||
models_json = models.model_dump()
|
||||
assert "data" in models_json, f"Anthropic model query response missing 'data' field: {models_json}"
|
||||
return models_json["data"]
|
||||
|
||||
|
||||
def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
||||
"""See: https://docs.anthropic.com/claude/docs/tool-use
|
||||
|
||||
|
@ -35,6 +35,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.settings import model_settings
|
||||
from letta.tracing import trace_method
|
||||
|
||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||
@ -120,8 +121,16 @@ class AnthropicClient(LLMClientBase):
|
||||
override_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
|
||||
|
||||
if async_client:
|
||||
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
|
||||
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
|
||||
return (
|
||||
anthropic.AsyncAnthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
|
||||
if override_key
|
||||
else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
return (
|
||||
anthropic.Anthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
|
||||
if override_key
|
||||
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
@ -239,6 +248,24 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return data
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[Tool] = None) -> int:
|
||||
client = anthropic.AsyncAnthropic()
|
||||
if messages is not None and len(messages) == 0:
|
||||
messages = None
|
||||
if tools and len(tools) > 0:
|
||||
anthropic_tools = convert_tools_to_anthropic_format(tools)
|
||||
else:
|
||||
anthropic_tools = None
|
||||
result = await client.beta.messages.count_tokens(
|
||||
model=model or "claude-3-7-sonnet-20250219",
|
||||
messages=messages or [{"role": "user", "content": "hi"}],
|
||||
tools=anthropic_tools or [],
|
||||
)
|
||||
token_count = result.input_tokens
|
||||
if messages is None:
|
||||
token_count -= 8
|
||||
return token_count
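# Illustrative usage sketch (not part of this commit). `anthropic_llm_client` is
# assumed to be an AnthropicClient constructed elsewhere (e.g. via LLMClient.create)
# with an Anthropic API key available; the model name is the same default the
# method above falls back to.
async def _example_count_tokens(anthropic_llm_client) -> int:
    return await anthropic_llm_client.count_tokens(
        messages=[{"role": "user", "content": "hello"}],
        model="claude-3-7-sonnet-20250219",
    )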
|
||||
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
if isinstance(e, anthropic.APIConnectionError):
|
||||
logger.warning(f"[Anthropic] API connection error: {e.__cause__}")
|
||||
@ -369,11 +396,11 @@ class AnthropicClient(LLMClientBase):
|
||||
content = strip_xml_tags(string=content_part.text, tag="thinking")
|
||||
if content_part.type == "tool_use":
|
||||
# hack for tool rules
|
||||
input = json.loads(json.dumps(content_part.input))
|
||||
if "id" in input and input["id"].startswith("toolu_") and "function" in input:
|
||||
arguments = str(input["function"]["arguments"])
|
||||
tool_input = json.loads(json.dumps(content_part.input))
|
||||
if "id" in tool_input and tool_input["id"].startswith("toolu_") and "function" in tool_input:
|
||||
arguments = str(tool_input["function"]["arguments"])
|
||||
else:
|
||||
arguments = json.dumps(content_part.input, indent=2)
|
||||
arguments = json.dumps(tool_input, indent=2)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=content_part.id,
|
||||
|
@ -1,422 +1,21 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from google import genai
|
||||
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleAIClient(LLMClientBase):
|
||||
class GoogleAIClient(GoogleVertexClient):
|
||||
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
api_key = None
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
|
||||
|
||||
if not api_key:
|
||||
api_key = model_settings.gemini_api_key
|
||||
|
||||
# print("[google_ai request]", json.dumps(request_data, indent=2))
|
||||
url, headers = get_gemini_endpoint_and_headers(
|
||||
base_url=str(llm_config.model_endpoint),
|
||||
model=llm_config.model,
|
||||
api_key=str(api_key),
|
||||
key_in_header=True,
|
||||
generate_content=True,
|
||||
)
|
||||
return make_post_request(url, headers, request_data)
|
||||
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
"""
|
||||
if tools:
|
||||
tools = [{"type": "function", "function": f} for f in tools]
|
||||
tool_objs = [Tool(**t) for t in tools]
|
||||
tool_names = [t.function.name for t in tool_objs]
|
||||
# Convert to the exact payload style Google expects
|
||||
tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
|
||||
else:
|
||||
tool_names = []
|
||||
|
||||
contents = self.add_dummy_model_messages(
|
||||
[m.to_google_ai_dict() for m in messages],
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"contents": contents,
|
||||
"tools": tools,
|
||||
"generation_config": {
|
||||
"temperature": llm_config.temperature,
|
||||
"max_output_tokens": llm_config.max_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
# write tool config
|
||||
tool_config = ToolConfig(
|
||||
function_calling_config=FunctionCallingConfig(
|
||||
# ANY mode forces the model to predict only function calls
|
||||
mode=FunctionCallingConfigMode.ANY,
|
||||
# Provide the list of tools (though empty should also work, it seems not to)
|
||||
allowed_function_names=tool_names,
|
||||
)
|
||||
)
|
||||
request_data["tool_config"] = tool_config.model_dump()
|
||||
return request_data
|
||||
|
||||
def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Converts custom response format from llm client into an OpenAI
|
||||
ChatCompletionsResponse object.
|
||||
|
||||
Example Input:
|
||||
{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 9,
|
||||
"candidatesTokenCount": 27,
|
||||
"totalTokenCount": 36
|
||||
}
|
||||
}
|
||||
"""
|
||||
# print("[google_ai response]", json.dumps(response_data, indent=2))
|
||||
|
||||
try:
|
||||
choices = []
|
||||
index = 0
|
||||
for candidate in response_data["candidates"]:
|
||||
content = candidate["content"]
|
||||
|
||||
if "role" not in content or not content["role"]:
|
||||
# This means the response is malformed like MALFORMED_FUNCTION_CALL
|
||||
# NOTE: must be a ValueError to trigger a retry
|
||||
raise ValueError(f"Error in response data from LLM: {response_data}")
|
||||
role = content["role"]
|
||||
assert role == "model", f"Unknown role in response: {role}"
|
||||
|
||||
parts = content["parts"]
|
||||
|
||||
# NOTE: we aren't properly supporting multi-part responses here anyway (we're just appending choices),
|
||||
# so let's disable it for now
|
||||
|
||||
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
|
||||
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
|
||||
# To patch this, if we have multiple parts we can take the last one
|
||||
if len(parts) > 1:
|
||||
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
|
||||
parts = [parts[-1]]
|
||||
|
||||
# TODO support parts / multimodal
|
||||
# TODO support parallel tool calling natively
|
||||
# TODO Alternative here is to throw away everything else except for the first part
|
||||
for response_message in parts:
|
||||
# Convert the actual message style to OpenAI style
|
||||
if "functionCall" in response_message and response_message["functionCall"] is not None:
|
||||
function_call = response_message["functionCall"]
|
||||
assert isinstance(function_call, dict), function_call
|
||||
function_name = function_call["name"]
|
||||
assert isinstance(function_name, str), function_name
|
||||
function_args = function_call["args"]
|
||||
assert isinstance(function_args, dict), function_args
|
||||
|
||||
# NOTE: this also involves stripping the inner monologue out of the function
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
||||
|
||||
assert (
|
||||
INNER_THOUGHTS_KWARG_VERTEX in function_args
|
||||
), f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
||||
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
||||
else:
|
||||
inner_thoughts = None
|
||||
|
||||
# Google AI API doesn't generate tool call IDs
|
||||
openai_response_message = Message(
|
||||
role="assistant", # NOTE: "model" -> "assistant"
|
||||
content=inner_thoughts,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=get_tool_call_id(),
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
# Inner thoughts are the content by default
|
||||
inner_thoughts = response_message["text"]
|
||||
|
||||
# Google AI API doesn't generate tool call IDs
|
||||
openai_response_message = Message(
|
||||
role="assistant", # NOTE: "model" -> "assistant"
|
||||
content=inner_thoughts,
|
||||
)
|
||||
|
||||
# Google AI API uses different finish reason strings than OpenAI
|
||||
# OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
|
||||
# see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
|
||||
# Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
|
||||
# see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
|
||||
finish_reason = candidate["finishReason"]
|
||||
if finish_reason == "STOP":
|
||||
openai_finish_reason = (
|
||||
"function_call"
|
||||
if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
|
||||
else "stop"
|
||||
)
|
||||
elif finish_reason == "MAX_TOKENS":
|
||||
openai_finish_reason = "length"
|
||||
elif finish_reason == "SAFETY":
|
||||
openai_finish_reason = "content_filter"
|
||||
elif finish_reason == "RECITATION":
|
||||
openai_finish_reason = "content_filter"
|
||||
else:
|
||||
raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
finish_reason=openai_finish_reason,
|
||||
index=index,
|
||||
message=openai_response_message,
|
||||
)
|
||||
)
|
||||
index += 1
|
||||
|
||||
# if len(choices) > 1:
|
||||
# raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")
|
||||
|
||||
# NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
|
||||
# "usageMetadata": {
|
||||
# "promptTokenCount": 9,
|
||||
# "candidatesTokenCount": 27,
|
||||
# "totalTokenCount": 36
|
||||
# }
|
||||
if "usageMetadata" in response_data:
|
||||
usage_data = response_data["usageMetadata"]
|
||||
if "promptTokenCount" not in usage_data:
|
||||
raise ValueError(f"promptTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
|
||||
if "totalTokenCount" not in usage_data:
|
||||
raise ValueError(f"totalTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
|
||||
if "candidatesTokenCount" not in usage_data:
|
||||
raise ValueError(f"candidatesTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
|
||||
|
||||
prompt_tokens = usage_data["promptTokenCount"]
|
||||
completion_tokens = usage_data["candidatesTokenCount"]
|
||||
total_tokens = usage_data["totalTokenCount"]
|
||||
|
||||
usage = UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
else:
|
||||
# Count it ourselves
|
||||
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
||||
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
response_id = str(uuid.uuid4())
|
||||
return ChatCompletionResponse(
|
||||
id=response_id,
|
||||
choices=choices,
|
||||
model=llm_config.model, # NOTE: Google API doesn't pass back model in the response
|
||||
created=get_utc_time_int(),
|
||||
usage=usage,
|
||||
)
|
||||
except KeyError as e:
|
||||
raise e
|
||||
|
||||
def _clean_google_ai_schema_properties(self, schema_part: dict):
|
||||
"""Recursively clean schema parts to remove unsupported Google AI keywords."""
|
||||
if not isinstance(schema_part, dict):
|
||||
return
|
||||
|
||||
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
|
||||
# * Only a subset of the OpenAPI schema is supported.
|
||||
# * Supported parameter types in Python are limited.
|
||||
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties"]
|
||||
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
|
||||
for key_to_remove in keys_to_remove_at_this_level:
|
||||
logger.warning(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
|
||||
del schema_part[key_to_remove]
|
||||
|
||||
if schema_part.get("type") == "string" and "format" in schema_part:
|
||||
allowed_formats = ["enum", "date-time"]
|
||||
if schema_part["format"] not in allowed_formats:
|
||||
logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
|
||||
del schema_part["format"]
|
||||
|
||||
# Check properties within the current level
|
||||
if "properties" in schema_part and isinstance(schema_part["properties"], dict):
|
||||
for prop_name, prop_schema in schema_part["properties"].items():
|
||||
self._clean_google_ai_schema_properties(prop_schema)
|
||||
|
||||
# Check items within arrays
|
||||
if "items" in schema_part and isinstance(schema_part["items"], dict):
|
||||
self._clean_google_ai_schema_properties(schema_part["items"])
|
||||
|
||||
# Check within anyOf, allOf, oneOf lists
|
||||
for key in ["anyOf", "allOf", "oneOf"]:
|
||||
if key in schema_part and isinstance(schema_part[key], list):
|
||||
for item_schema in schema_part[key]:
|
||||
self._clean_google_ai_schema_properties(item_schema)
|
||||
|
||||
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
|
||||
"""
|
||||
OpenAI style:
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_movies",
|
||||
"description": "find ....",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
PARAM: {
|
||||
"type": PARAM_TYPE, # eg "string"
|
||||
"description": PARAM_DESCRIPTION,
|
||||
},
|
||||
...
|
||||
},
|
||||
"required": List[str],
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Google AI style:
|
||||
"tools": [{
|
||||
"functionDeclarations": [{
|
||||
"name": "find_movies",
|
||||
"description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "STRING",
|
||||
"description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
|
||||
},
|
||||
"description": {
|
||||
"type": "STRING",
|
||||
"description": "Any kind of description including category or genre, title words, attributes, etc."
|
||||
}
|
||||
},
|
||||
"required": ["description"]
|
||||
}
|
||||
}, {
|
||||
"name": "find_theaters",
|
||||
...
|
||||
"""
|
||||
function_list = [
|
||||
dict(
|
||||
name=t.function.name,
|
||||
description=t.function.description,
|
||||
parameters=t.function.parameters, # TODO need to unpack
|
||||
)
|
||||
for t in tools
|
||||
]
|
||||
|
||||
# Add inner thoughts if needed
|
||||
for func in function_list:
|
||||
# Note: Google AI API used to have weird casing requirements, but not any more
|
||||
|
||||
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
|
||||
if "parameters" in func and isinstance(func["parameters"], dict):
|
||||
self._clean_google_ai_schema_properties(func["parameters"])
|
||||
|
||||
# Add inner thoughts
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX
|
||||
|
||||
func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
|
||||
"type": "string",
|
||||
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
|
||||
}
|
||||
func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)
|
||||
|
||||
return [{"functionDeclarations": function_list}]
|
||||
|
||||
def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]:
|
||||
"""Google AI API requires all function call returns are immediately followed by a 'model' role message.
|
||||
|
||||
In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
|
||||
so there is no natural follow-up 'model' role message.
|
||||
|
||||
To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
|
||||
with role == 'model' that is placed in between the function output
|
||||
(role == 'tool') and user message (role == 'user').
|
||||
"""
|
||||
dummy_yield_message = {
|
||||
"role": "model",
|
||||
"parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
|
||||
}
|
||||
messages_with_padding = []
|
||||
for i, message in enumerate(messages):
|
||||
messages_with_padding.append(message)
|
||||
# Check if the current message role is 'tool' and the next message role is 'user'
|
||||
if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
|
||||
messages_with_padding.append(dummy_yield_message)
|
||||
|
||||
return messages_with_padding
|
||||
def _get_client(self):
|
||||
return genai.Client(api_key=model_settings.gemini_api_key)
|
||||
|
||||
|
||||
def get_gemini_endpoint_and_headers(
|
||||
@ -464,20 +63,24 @@ def google_ai_check_valid_api_key(api_key: str):
|
||||
|
||||
|
||||
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
|
||||
"""Synchronous version to get model list from Google AI API using httpx."""
|
||||
import httpx
|
||||
|
||||
from letta.utils import printd
|
||||
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
response = response.json() # convert to dict from string
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url, headers=headers)
|
||||
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
|
||||
response_data = response.json() # convert to dict from string
|
||||
|
||||
# Grab the models out
|
||||
model_list = response["models"]
|
||||
return model_list
|
||||
# Grab the models out
|
||||
model_list = response_data["models"]
|
||||
return model_list
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
printd(f"Got HTTPError, exception={http_err}")
|
||||
# Print the HTTP status code
|
||||
@ -486,8 +89,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =
|
||||
print(f"Message: {http_err.response.text}")
|
||||
raise http_err
|
||||
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
# Handle other requests-related errors (e.g., connection error)
|
||||
except httpx.RequestError as req_err:
|
||||
# Handle other httpx-related errors (e.g., connection error)
|
||||
printd(f"Got RequestException, exception={req_err}")
|
||||
raise req_err
|
||||
|
||||
@ -497,22 +100,74 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool =
|
||||
raise e
|
||||
|
||||
|
||||
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
|
||||
async def google_ai_get_model_list_async(
|
||||
base_url: str, api_key: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
|
||||
) -> List[dict]:
|
||||
"""Asynchronous version to get model list from Google AI API using httpx."""
|
||||
from letta.utils import printd
|
||||
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)
|
||||
|
||||
# Determine if we need to close the client at the end
|
||||
close_client = False
|
||||
if client is None:
|
||||
client = httpx.AsyncClient()
|
||||
close_client = True
|
||||
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
|
||||
response_data = response.json() # convert to dict from string
|
||||
|
||||
# Grab the models out
|
||||
model_list = response_data["models"]
|
||||
return model_list
|
||||
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
printd(f"Got HTTPError, exception={http_err}")
|
||||
# Print the HTTP status code
|
||||
print(f"HTTP Error: {http_err.response.status_code}")
|
||||
# Print the response content (error message from server)
|
||||
print(f"Message: {http_err.response.text}")
|
||||
raise http_err
|
||||
|
||||
except httpx.RequestError as req_err:
|
||||
# Handle other httpx-related errors (e.g., connection error)
|
||||
printd(f"Got RequestException, exception={req_err}")
|
||||
raise req_err
|
||||
|
||||
except Exception as e:
|
||||
# Handle other potential errors
|
||||
printd(f"Got unknown Exception, exception={e}")
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Close the client if we created it
|
||||
if close_client:
|
||||
await client.aclose()
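# Illustrative usage sketch (not part of this commit): callers holding a long-lived
# httpx.AsyncClient can pass it in to reuse its connection pool; the function only
# closes clients it created itself. The base_url value is a placeholder.
async def _example_list_gemini_models(api_key: str) -> list:
    async with httpx.AsyncClient() as shared:
        return await google_ai_get_model_list_async(
            base_url="https://generativelanguage.googleapis.com",
            api_key=api_key,
            client=shared,
        )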
|
||||
|
||||
|
||||
def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
|
||||
"""Synchronous version to get model details from Google AI API using httpx."""
|
||||
import httpx
|
||||
|
||||
from letta.utils import printd
|
||||
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
printd(f"response = {response}")
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
response = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response}")
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url, headers=headers)
|
||||
printd(f"response = {response}")
|
||||
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
|
||||
response_data = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response_data}")
|
||||
|
||||
# Grab the models out
|
||||
return response
|
||||
# Return the model details
|
||||
return response_data
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
printd(f"Got HTTPError, exception={http_err}")
|
||||
# Print the HTTP status code
|
||||
@ -521,8 +176,8 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_
|
||||
print(f"Message: {http_err.response.text}")
|
||||
raise http_err
|
||||
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
# Handle other requests-related errors (e.g., connection error)
|
||||
except httpx.RequestError as req_err:
|
||||
# Handle other httpx-related errors (e.g., connection error)
|
||||
printd(f"Got RequestException, exception={req_err}")
|
||||
raise req_err
|
||||
|
||||
@ -532,8 +187,66 @@ def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_
|
||||
raise e
|
||||
|
||||
|
||||
async def google_ai_get_model_details_async(
|
||||
base_url: str, api_key: str, model: str, key_in_header: bool = True, client: Optional[httpx.AsyncClient] = None
|
||||
) -> dict:
|
||||
"""Asynchronous version to get model details from Google AI API using httpx."""
|
||||
import httpx
|
||||
|
||||
from letta.utils import printd
|
||||
|
||||
url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
|
||||
|
||||
# Determine if we need to close the client at the end
|
||||
close_client = False
|
||||
if client is None:
|
||||
client = httpx.AsyncClient()
|
||||
close_client = True
|
||||
|
||||
try:
|
||||
response = await client.get(url, headers=headers)
|
||||
printd(f"response = {response}")
|
||||
response.raise_for_status() # Raises HTTPStatusError for 4XX/5XX status
|
||||
response_data = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response_data}")
|
||||
|
||||
# Return the model details
|
||||
return response_data
|
||||
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
printd(f"Got HTTPError, exception={http_err}")
|
||||
# Print the HTTP status code
|
||||
print(f"HTTP Error: {http_err.response.status_code}")
|
||||
# Print the response content (error message from server)
|
||||
print(f"Message: {http_err.response.text}")
|
||||
raise http_err
|
||||
|
||||
except httpx.RequestError as req_err:
|
||||
# Handle other httpx-related errors (e.g., connection error)
|
||||
printd(f"Got RequestException, exception={req_err}")
|
||||
raise req_err
|
||||
|
||||
except Exception as e:
|
||||
# Handle other potential errors
|
||||
printd(f"Got unknown Exception, exception={e}")
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Close the client if we created it
|
||||
if close_client:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
|
||||
model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
|
||||
# TODO should this be:
|
||||
# return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
|
||||
return int(model_details["inputTokenLimit"])
|
||||
|
||||
|
||||
async def google_ai_get_model_context_window_async(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
|
||||
model_details = await google_ai_get_model_details_async(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
|
||||
# TODO should this be:
|
||||
# return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
|
||||
return int(model_details["inputTokenLimit"])
|
||||
|
@ -5,14 +5,16 @@ from typing import List, Optional
|
||||
from google import genai
|
||||
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.llm_api.google_ai_client import GoogleAIClient
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.settings import model_settings, settings
|
||||
from letta.utils import get_tool_call_id
|
||||
@ -20,18 +22,21 @@ from letta.utils import get_tool_call_id
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleVertexClient(GoogleAIClient):
|
||||
class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
client = genai.Client(
|
||||
def _get_client(self):
|
||||
return genai.Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
location=model_settings.google_cloud_location,
|
||||
http_options={"api_version": "v1"},
|
||||
)
|
||||
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
client = self._get_client()
|
||||
response = client.models.generate_content(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
@ -43,12 +48,7 @@ class GoogleVertexClient(GoogleAIClient):
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
client = genai.Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
location=model_settings.google_cloud_location,
|
||||
http_options={"api_version": "v1"},
|
||||
)
|
||||
client = self._get_client()
|
||||
response = await client.aio.models.generate_content(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
@ -56,6 +56,139 @@ class GoogleVertexClient(GoogleAIClient):
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]:
|
||||
"""Google AI API requires all function call returns are immediately followed by a 'model' role message.
|
||||
|
||||
In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
|
||||
so there is no natural follow-up 'model' role message.
|
||||
|
||||
To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
|
||||
with role == 'model' that is placed in between the function output
|
||||
(role == 'tool') and user message (role == 'user').
|
||||
"""
|
||||
dummy_yield_message = {
|
||||
"role": "model",
|
||||
"parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
|
||||
}
|
||||
messages_with_padding = []
|
||||
for i, message in enumerate(messages):
|
||||
messages_with_padding.append(message)
|
||||
# Check if the current message role is 'tool' and the next message role is 'user'
|
||||
if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
|
||||
messages_with_padding.append(dummy_yield_message)
|
||||
|
||||
return messages_with_padding
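# Illustrative sketch (not part of this commit): role ordering before and after
# add_dummy_model_messages. A dummy 'model' turn is inserted wherever a tool/function
# result is immediately followed by a user message, which is what the Google AI API
# requires.
roles_before = ["user", "model (functionCall)", "tool (function result)", "user"]
roles_after = ["user", "model (functionCall)", "tool (function result)", "model (dummy yield)", "user"]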
|
||||
|
||||
def _clean_google_ai_schema_properties(self, schema_part: dict):
|
||||
"""Recursively clean schema parts to remove unsupported Google AI keywords."""
|
||||
if not isinstance(schema_part, dict):
|
||||
return
|
||||
|
||||
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
|
||||
# * Only a subset of the OpenAPI schema is supported.
|
||||
# * Supported parameter types in Python are limited.
|
||||
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties"]
|
||||
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
|
||||
for key_to_remove in keys_to_remove_at_this_level:
|
||||
logger.warning(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
|
||||
del schema_part[key_to_remove]
|
||||
|
||||
if schema_part.get("type") == "string" and "format" in schema_part:
|
||||
allowed_formats = ["enum", "date-time"]
|
||||
if schema_part["format"] not in allowed_formats:
|
||||
logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
|
||||
del schema_part["format"]
|
||||
|
||||
# Check properties within the current level
|
||||
if "properties" in schema_part and isinstance(schema_part["properties"], dict):
|
||||
for prop_name, prop_schema in schema_part["properties"].items():
|
||||
self._clean_google_ai_schema_properties(prop_schema)
|
||||
|
||||
# Check items within arrays
|
||||
if "items" in schema_part and isinstance(schema_part["items"], dict):
|
||||
self._clean_google_ai_schema_properties(schema_part["items"])
|
||||
|
||||
# Check within anyOf, allOf, oneOf lists
|
||||
for key in ["anyOf", "allOf", "oneOf"]:
|
||||
if key in schema_part and isinstance(schema_part[key], list):
|
||||
for item_schema in schema_part[key]:
|
||||
self._clean_google_ai_schema_properties(item_schema)
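# Illustrative sketch (not part of this commit): the effect of
# _clean_google_ai_schema_properties on a parameter schema. Unsupported keywords
# ("default", "additionalProperties", exclusive bounds) are dropped, and string
# "format" values outside {"enum", "date-time"} are removed. The schema itself is a
# made-up example.
schema_before = {
    "type": "object",
    "additionalProperties": False,
    "properties": {"when": {"type": "string", "format": "date", "default": "today"}},
}
schema_after = {
    "type": "object",
    "properties": {"when": {"type": "string"}},
}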
|
||||
|
||||
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
|
||||
"""
|
||||
OpenAI style:
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_movies",
|
||||
"description": "find ....",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
PARAM: {
|
||||
"type": PARAM_TYPE, # eg "string"
|
||||
"description": PARAM_DESCRIPTION,
|
||||
},
|
||||
...
|
||||
},
|
||||
"required": List[str],
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Google AI style:
|
||||
"tools": [{
|
||||
"functionDeclarations": [{
|
||||
"name": "find_movies",
|
||||
"description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "STRING",
|
||||
"description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
|
||||
},
|
||||
"description": {
|
||||
"type": "STRING",
|
||||
"description": "Any kind of description including category or genre, title words, attributes, etc."
|
||||
}
|
||||
},
|
||||
"required": ["description"]
|
||||
}
|
||||
}, {
|
||||
"name": "find_theaters",
|
||||
...
|
||||
"""
|
||||
function_list = [
|
||||
dict(
|
||||
name=t.function.name,
|
||||
description=t.function.description,
|
||||
parameters=t.function.parameters, # TODO need to unpack
|
||||
)
|
||||
for t in tools
|
||||
]
|
||||
|
||||
# Add inner thoughts if needed
|
||||
for func in function_list:
|
||||
# Note: Google AI API used to have weird casing requirements, but not any more
|
||||
|
||||
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
|
||||
if "parameters" in func and isinstance(func["parameters"], dict):
|
||||
self._clean_google_ai_schema_properties(func["parameters"])
|
||||
|
||||
# Add inner thoughts
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX
|
||||
|
||||
func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
|
||||
"type": "string",
|
||||
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
|
||||
}
|
||||
func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)
|
||||
|
||||
return [{"functionDeclarations": function_list}]
|
||||
|
||||
    def build_request_data(
        self,
        messages: List[PydanticMessage],
@@ -66,11 +199,29 @@ class GoogleVertexClient(GoogleAIClient):
        """
        Constructs a request object in the expected data format for this client.
        """
        request_data = super().build_request_data(messages, llm_config, tools, force_tool_call)
        request_data["config"] = request_data.pop("generation_config")
        request_data["config"]["tools"] = request_data.pop("tools")

        tool_names = [t["name"] for t in tools] if tools else []
        if tools:
            tool_objs = [Tool(type="function", function=t) for t in tools]
            tool_names = [t.function.name for t in tool_objs]
            # Convert to the exact payload style Google expects
            formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
        else:
            formatted_tools = []
            tool_names = []

        contents = self.add_dummy_model_messages(
            [m.to_google_ai_dict() for m in messages],
        )

        request_data = {
            "contents": contents,
            "config": {
                "temperature": llm_config.temperature,
                "max_output_tokens": llm_config.max_tokens,
                "tools": formatted_tools,
            },
        }

        if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental:
            request_data["config"]["response_mime_type"] = "application/json"
            request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0])
@@ -89,11 +240,11 @@ class GoogleVertexClient(GoogleAIClient):
        # Add thinking_config
        # If enable_reasoner is False, set thinking_budget to 0
        # Otherwise, use the value from max_reasoning_tokens
        thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens
        thinking_config = ThinkingConfig(
            thinking_budget=thinking_budget,
        )
        request_data["config"]["thinking_config"] = thinking_config.model_dump()
        if llm_config.enable_reasoner:
            thinking_config = ThinkingConfig(
                thinking_budget=llm_config.max_reasoning_tokens,
            )
            request_data["config"]["thinking_config"] = thinking_config.model_dump()

        return request_data

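The last hunk replaces the old always-send-a-budget behavior with a gate, so thinking_config only appears in the request when the reasoner is enabled. A minimal sketch of that gate, with a hypothetical helper name and plain dicts standing in for the real config objects:

# Hypothetical helper mirroring the new gating shown above.
def attach_thinking_config(config: dict, enable_reasoner: bool, max_reasoning_tokens: int) -> dict:
    if enable_reasoner:
        config["thinking_config"] = {"thinking_budget": max_reasoning_tokens}
    return config

print(attach_thinking_config({"temperature": 0.7}, enable_reasoner=False, max_reasoning_tokens=1024))
# -> {'temperature': 0.7}  (no thinking_config at all when the reasoner is off)
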
@@ -20,15 +20,19 @@ from letta.llm_api.openai import (
    build_openai_chat_completions_request,
    openai_chat_completions_process_stream,
    openai_chat_completions_request,
    prepare_openai_payload,
)
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.orm.user import User
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.tracing import log_event, trace_method
@@ -142,6 +146,9 @@ def create(
    model_settings: Optional[dict] = None,  # TODO: eventually pass from server
    put_inner_thoughts_first: bool = True,
    name: Optional[str] = None,
    telemetry_manager: Optional[TelemetryManager] = None,
    step_id: Optional[str] = None,
    actor: Optional[User] = None,
) -> ChatCompletionResponse:
    """Return response to chat completion with backoff"""
    from letta.utils import printd
@@ -233,6 +240,16 @@ def create(
        if isinstance(stream_interface, AgentChunkStreamingInterface):
            stream_interface.stream_end()

        telemetry_manager.create_provider_trace(
            actor=actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=prepare_openai_payload(data),
                response_json=response.model_json_schema(),
                step_id=step_id,
                organization_id=actor.organization_id,
            ),
        )

        if llm_config.put_inner_thoughts_in_kwargs:
            response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)

@@ -407,6 +424,16 @@ def create(
        if llm_config.put_inner_thoughts_in_kwargs:
            response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)

        telemetry_manager.create_provider_trace(
            actor=actor,
            provider_trace_create=ProviderTraceCreate(
                request_json=chat_completion_request.model_json_schema(),
                response_json=response.model_json_schema(),
                step_id=step_id,
                organization_id=actor.organization_id,
            ),
        )

        return response

    # elif llm_config.model_endpoint_type == "cohere":

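The new telemetry calls above record one request/response pair per step. A rough sketch of the call shape with stand-in classes (these are not the real Letta TelemetryManager or ProviderTraceCreate, just illustrations of the interface used above):

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class TraceCreate:  # hypothetical mirror of ProviderTraceCreate
    request_json: Dict[str, Any]
    response_json: Dict[str, Any]
    step_id: Optional[str]
    organization_id: str

class InMemoryTelemetry:  # hypothetical mirror of TelemetryManager
    def __init__(self) -> None:
        self.traces: list[TraceCreate] = []

    def create_provider_trace(self, actor: Any, provider_trace_create: TraceCreate) -> None:
        # In the real manager this persists a row; here we just collect it.
        self.traces.append(provider_trace_create)

telemetry = InMemoryTelemetry()
telemetry.create_provider_trace(
    actor=None,
    provider_trace_create=TraceCreate(
        request_json={"model": "gpt-4", "messages": []},
        response_json={"choices": []},
        step_id="step-123",
        organization_id="org-abc",
    ),
)
assert len(telemetry.traces) == 1
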
@@ -51,7 +51,7 @@ class LLMClient:
                put_inner_thoughts_first=put_inner_thoughts_first,
                actor=actor,
            )
        case ProviderType.openai:
        case ProviderType.openai | ProviderType.together:
            from letta.llm_api.openai_client import OpenAIClient

            return OpenAIClient(

@@ -9,7 +9,9 @@ from letta.errors import LLMError
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.tracing import log_event
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.services.telemetry_manager import TelemetryManager
from letta.tracing import log_event, trace_method

if TYPE_CHECKING:
    from letta.orm import User
@@ -31,13 +33,15 @@ class LLMClientBase:
        self.put_inner_thoughts_first = put_inner_thoughts_first
        self.use_tool_naming = use_tool_naming

    @trace_method
    def send_llm_request(
        self,
        messages: List[Message],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        stream: bool = False,
        force_tool_call: Optional[str] = None,
        telemetry_manager: Optional["TelemetryManager"] = None,
        step_id: Optional[str] = None,
    ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
        """
        Issues a request to the downstream model endpoint and parses response.
@@ -48,37 +52,51 @@ class LLMClientBase:

        try:
            log_event(name="llm_request_sent", attributes=request_data)
            if stream:
                return self.stream(request_data, llm_config)
            else:
                response_data = self.request(request_data, llm_config)
            response_data = self.request(request_data, llm_config)
            if step_id and telemetry_manager:
                telemetry_manager.create_provider_trace(
                    actor=self.actor,
                    provider_trace_create=ProviderTraceCreate(
                        request_json=request_data,
                        response_json=response_data,
                        step_id=step_id,
                        organization_id=self.actor.organization_id,
                    ),
                )
            log_event(name="llm_response_received", attributes=response_data)
        except Exception as e:
            raise self.handle_llm_error(e)

        return self.convert_response_to_chat_completion(response_data, messages, llm_config)

    @trace_method
    async def send_llm_request_async(
        self,
        request_data: dict,
        messages: List[Message],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        stream: bool = False,
        force_tool_call: Optional[str] = None,
        telemetry_manager: "TelemetryManager | None" = None,
        step_id: str | None = None,
    ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
        """
        Issues a request to the downstream model endpoint.
        If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
        Otherwise returns a ChatCompletionResponse.
        """
        request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)

        try:
            log_event(name="llm_request_sent", attributes=request_data)
            if stream:
                return await self.stream_async(request_data, llm_config)
            else:
                response_data = await self.request_async(request_data, llm_config)
            response_data = await self.request_async(request_data, llm_config)
            await telemetry_manager.create_provider_trace_async(
                actor=self.actor,
                provider_trace_create=ProviderTraceCreate(
                    request_json=request_data,
                    response_json=response_data,
                    step_id=step_id,
                    organization_id=self.actor.organization_id,
                ),
            )

            log_event(name="llm_response_received", attributes=response_data)
        except Exception as e:
            raise self.handle_llm_error(e)
@@ -133,13 +151,6 @@ class LLMClientBase:
        """
        raise NotImplementedError

    @abstractmethod
    def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
        """
        Performs underlying streaming request to llm and returns raw response.
        """
        raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")

    @abstractmethod
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """

@@ -1,6 +1,7 @@
import warnings
from typing import Generator, List, Optional, Union

import httpx
import requests
from openai import OpenAI

@@ -110,6 +111,62 @@ def openai_get_model_list(url: str, api_key: Optional[str] = None, fix_url: bool
        raise e


async def openai_get_model_list_async(
    url: str,
    api_key: Optional[str] = None,
    fix_url: bool = False,
    extra_params: Optional[dict] = None,
    client: Optional["httpx.AsyncClient"] = None,
) -> dict:
    """https://platform.openai.com/docs/api-reference/models/list"""
    from letta.utils import printd

    # In some cases we may want to double-check the URL and do basic correction
    if fix_url and not url.endswith("/v1"):
        url = smart_urljoin(url, "v1")

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    printd(f"Sending request to {url}")

    # Use provided client or create a new one
    close_client = False
    if client is None:
        client = httpx.AsyncClient()
        close_client = True

    try:
        response = await client.get(url, headers=headers, params=extra_params)
        response.raise_for_status()
        result = response.json()
        printd(f"response = {result}")
        return result
    except httpx.HTTPStatusError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        error_response = None
        try:
            error_response = http_err.response.json()
        except:
            error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
        printd(f"Got HTTPError, exception={http_err}, response={error_response}")
        raise http_err
    except httpx.RequestError as req_err:
        # Handle other httpx-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
    finally:
        if close_client:
            await client.aclose()


def build_openai_chat_completions_request(
    llm_config: LLMConfig,
    messages: List[_Message],

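A usage sketch for the new async model-list helper, assuming the letta package that ships it is installed; the shared httpx.AsyncClient is optional and, when provided, is reused across calls instead of being created and closed inside the helper:

import asyncio
import os

import httpx

from letta.llm_api.openai import openai_get_model_list_async

async def main() -> None:
    async with httpx.AsyncClient() as client:
        models = await openai_get_model_list_async(
            "https://api.openai.com",
            api_key=os.environ.get("OPENAI_API_KEY"),
            fix_url=True,   # appends /v1 when it is missing
            client=client,  # caller owns the client's lifecycle
        )
    print([m["id"] for m in models.get("data", [])][:5])

if __name__ == "__main__":
    asyncio.run(main())
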
@@ -2,7 +2,7 @@ import os
from typing import List, Optional

import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

@@ -22,7 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -113,6 +113,8 @@ class OpenAIClient(LLMClientBase):
            from letta.services.provider_manager import ProviderManager

            api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
        if llm_config.model_endpoint_type == ProviderType.together:
            api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")

        if not api_key:
            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -254,20 +256,14 @@ class OpenAIClient(LLMClientBase):

        return chat_completion_response

    def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
        """
        Performs underlying streaming request to OpenAI and returns the stream iterator.
        """
        client = OpenAI(**self._prepare_client_kwargs(llm_config))
        response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
        return response_stream

    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
        """
        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True)
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    def handle_llm_error(self, e: Exception) -> Exception:

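With stream_options={"include_usage": True}, the OpenAI SDK emits one final chunk whose usage field is populated and whose choices list is empty. A small consumption sketch (the model name is illustrative and OPENAI_API_KEY is assumed to be set):

import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")
        if chunk.usage:  # only populated on the final chunk
            print(f"\n[total tokens: {chunk.usage.total_tokens}]")

asyncio.run(main())
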
@@ -93,7 +93,6 @@ def summarize_messages(
        response = llm_client.send_llm_request(
            messages=message_sequence,
            llm_config=llm_config_no_inner_thoughts,
            stream=False,
        )
    else:
        response = create(

@@ -19,6 +19,7 @@ from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.provider import Provider
from letta.orm.provider_trace import ProviderTrace
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents

@@ -8,6 +8,7 @@ class ToolType(str, Enum):
    LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
    LETTA_SLEEPTIME_CORE = "letta_sleeptime_core"
    LETTA_VOICE_SLEEPTIME_CORE = "letta_voice_sleeptime_core"
    LETTA_BUILTIN = "letta_builtin"
    EXTERNAL_COMPOSIO = "external_composio"
    EXTERNAL_LANGCHAIN = "external_langchain"
    # TODO is "external" the right name here? Since as of now, MCP is local / doesn't support remote?

letta/orm/provider_trace.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import uuid

from sqlalchemy import JSON, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace


class ProviderTrace(SqlalchemyBase, OrganizationMixin):
    """Defines data model for storing provider trace information"""

    __tablename__ = "provider_traces"
    __pydantic_model__ = PydanticProviderTrace
    __table_args__ = (Index("ix_step_id", "step_id"),)

    id: Mapped[str] = mapped_column(
        primary_key=True, doc="Unique provider trace identifier", default=lambda: f"provider_trace-{uuid.uuid4()}"
    )
    request_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider request")
    response_json: Mapped[dict] = mapped_column(JSON, doc="JSON content of the provider response")
    step_id: Mapped[str] = mapped_column(String, nullable=True, doc="ID of the step that this trace is associated with")

    # Relationships
    organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
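The ix_step_id index above exists to make per-step lookups cheap. A hypothetical read path, assuming a SQLAlchemy Session is configured elsewhere in the application:

from sqlalchemy import select
from sqlalchemy.orm import Session

from letta.orm.provider_trace import ProviderTrace

def traces_for_step(session: Session, step_id: str) -> list[ProviderTrace]:
    # This filter is exactly the access pattern the ix_step_id index serves.
    stmt = select(ProviderTrace).where(ProviderTrace.step_id == step_id)
    return list(session.scalars(stmt))
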
@@ -35,6 +35,7 @@ class Step(SqlalchemyBase):
    )
    agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The ID of the agent that performed the step.")
    provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
    provider_category: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The category of the provider used for this step.")
    model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
    model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.")
    context_window_limit: Mapped[Optional[int]] = mapped_column(

letta/schemas/provider_trace.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.letta_base import OrmMetadataBase


class BaseProviderTrace(OrmMetadataBase):
    __id_prefix__ = "provider_trace"


class ProviderTraceCreate(BaseModel):
    """Request to create a provider trace"""

    request_json: dict[str, Any] = Field(..., description="JSON content of the provider request")
    response_json: dict[str, Any] = Field(..., description="JSON content of the provider response")
    step_id: str = Field(None, description="ID of the step that this trace is associated with")
    organization_id: str = Field(..., description="The unique identifier of the organization.")


class ProviderTrace(BaseProviderTrace):
    """
    Letta's internal representation of a provider trace.

    Attributes:
        id (str): The unique identifier of the provider trace.
        request_json (Dict[str, Any]): JSON content of the provider request.
        response_json (Dict[str, Any]): JSON content of the provider response.
        step_id (str): ID of the step that this trace is associated with.
        organization_id (str): The unique identifier of the organization.
        created_at (datetime): The timestamp when the object was created.
    """

    id: str = BaseProviderTrace.generate_id_field()
    request_json: Dict[str, Any] = Field(..., description="JSON content of the provider request")
    response_json: Dict[str, Any] = Field(..., description="JSON content of the provider response")
    step_id: Optional[str] = Field(None, description="ID of the step that this trace is associated with")
    organization_id: str = Field(..., description="The unique identifier of the organization.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
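A quick construction sketch for the request schema, assuming the letta package is importable; request_json, response_json, and organization_id are required, while step_id may be omitted (it defaults to None):

from letta.schemas.provider_trace import ProviderTraceCreate

trace = ProviderTraceCreate(
    request_json={"model": "gpt-4", "messages": []},
    response_json={"choices": []},
    organization_id="org-abc",
)
print(trace.model_dump(exclude_none=True))
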
@ -47,12 +47,21 @@ class Provider(ProviderBase):
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
return []
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
|
||||
return []
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
return []
|
||||
|
||||
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
|
||||
return []
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> Optional[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
def provider_tag(self) -> str:
|
||||
"""String representation of the provider for display purposes"""
|
||||
raise NotImplementedError
|
||||
@ -140,6 +149,19 @@ class LettaProvider(Provider):
|
||||
)
|
||||
]
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
|
||||
return [
|
||||
LLMConfig(
|
||||
model="letta-free", # NOTE: renamed
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=8192,
|
||||
handle=self.get_handle("letta-free"),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
]
|
||||
|
||||
def list_embedding_models(self):
|
||||
return [
|
||||
EmbeddingConfig(
|
||||
@ -189,9 +211,40 @@ class OpenAIProvider(Provider):
|
||||
|
||||
return data
|
||||
|
||||
async def _get_models_async(self) -> List[dict]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
|
||||
# See: https://openrouter.ai/docs/requests
|
||||
extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
|
||||
|
||||
# Similar to Nebius
|
||||
extra_params = {"verbose": True} if "nebius.com" in self.base_url else None
|
||||
|
||||
response = await openai_get_model_list_async(
|
||||
self.base_url,
|
||||
api_key=self.api_key,
|
||||
extra_params=extra_params,
|
||||
# fix_url=True, # NOTE: make sure together ends with /v1
|
||||
)
|
||||
|
||||
if "data" in response:
|
||||
data = response["data"]
|
||||
else:
|
||||
# TogetherAI's response is missing the 'data' field
|
||||
data = response
|
||||
|
||||
return data
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
data = self._get_models()
|
||||
return self._list_llm_models(data)
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
|
||||
data = await self._get_models_async()
|
||||
return self._list_llm_models(data)
|
||||
|
||||
def _list_llm_models(self, data) -> List[LLMConfig]:
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
|
||||
@ -279,7 +332,6 @@ class OpenAIProvider(Provider):
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
|
||||
if self.base_url == "https://api.openai.com/v1":
|
||||
# TODO: actually automatically list models for OpenAI
|
||||
return [
|
||||
@ -312,55 +364,92 @@ class OpenAIProvider(Provider):
|
||||
else:
|
||||
# Actually attempt to list
|
||||
data = self._get_models()
|
||||
return self._list_embedding_models(data)
|
||||
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"Model missing 'id' field: {model}"
|
||||
model_name = model["id"]
|
||||
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
|
||||
if self.base_url == "https://api.openai.com/v1":
|
||||
# TODO: actually automatically list models for OpenAI
|
||||
return [
|
||||
EmbeddingConfig(
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=1536,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle("text-embedding-ada-002", is_embedding=True),
|
||||
),
|
||||
EmbeddingConfig(
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=2000,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle("text-embedding-3-small", is_embedding=True),
|
||||
),
|
||||
EmbeddingConfig(
|
||||
embedding_model="text-embedding-3-large",
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=2000,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle("text-embedding-3-large", is_embedding=True),
|
||||
),
|
||||
]
|
||||
|
||||
if "context_length" in model:
|
||||
# Context length is returned in Nebius as "context_length"
|
||||
context_window_size = model["context_length"]
|
||||
else:
|
||||
context_window_size = self.get_model_context_window_size(model_name)
|
||||
else:
|
||||
# Actually attempt to list
|
||||
data = await self._get_models_async()
|
||||
return self._list_embedding_models(data)
|
||||
|
||||
# We need the context length for embeddings too
|
||||
if not context_window_size:
|
||||
continue
|
||||
def _list_embedding_models(self, data) -> List[EmbeddingConfig]:
|
||||
configs = []
|
||||
for model in data:
|
||||
assert "id" in model, f"Model missing 'id' field: {model}"
|
||||
model_name = model["id"]
|
||||
|
||||
if "nebius.com" in self.base_url:
|
||||
# Nebius includes the type, which we can use to filter for embedding models
|
||||
try:
|
||||
model_type = model["architecture"]["modality"]
|
||||
if model_type not in ["text->embedding"]:
|
||||
# print(f"Skipping model w/ modality {model_type}:\n{model}")
|
||||
continue
|
||||
except KeyError:
|
||||
print(f"Couldn't access architecture type field, skipping model:\n{model}")
|
||||
continue
|
||||
if "context_length" in model:
|
||||
# Context length is returned in Nebius as "context_length"
|
||||
context_window_size = model["context_length"]
|
||||
else:
|
||||
context_window_size = self.get_model_context_window_size(model_name)
|
||||
|
||||
elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
|
||||
# TogetherAI includes the type, which we can use to filter for embedding models
|
||||
if "type" in model and model["type"] not in ["embedding"]:
|
||||
# We need the context length for embeddings too
|
||||
if not context_window_size:
|
||||
continue
|
||||
|
||||
if "nebius.com" in self.base_url:
|
||||
# Nebius includes the type, which we can use to filter for embedding models
|
||||
try:
|
||||
model_type = model["architecture"]["modality"]
|
||||
if model_type not in ["text->embedding"]:
|
||||
# print(f"Skipping model w/ modality {model_type}:\n{model}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# For other providers we should skip by default, since we don't want to assume embeddings are supported
|
||||
except KeyError:
|
||||
print(f"Couldn't access architecture type field, skipping model:\n{model}")
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
embedding_model=model_name,
|
||||
embedding_endpoint_type=self.provider_type,
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=context_window_size,
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=self.get_handle(model, is_embedding=True),
|
||||
)
|
||||
)
|
||||
elif "together.ai" in self.base_url or "together.xyz" in self.base_url:
|
||||
# TogetherAI includes the type, which we can use to filter for embedding models
|
||||
if "type" in model and model["type"] not in ["embedding"]:
|
||||
# print(f"Skipping model w/ modality {model_type}:\n{model}")
|
||||
continue
|
||||
|
||||
return configs
|
||||
else:
|
||||
# For other providers we should skip by default, since we don't want to assume embeddings are supported
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
embedding_model=model_name,
|
||||
embedding_endpoint_type=self.provider_type,
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=context_window_size,
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=self.get_handle(model, is_embedding=True),
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
def get_model_context_window_size(self, model_name: str):
|
||||
if model_name in LLM_MAX_TOKENS:
|
||||
@ -647,26 +736,19 @@ class AnthropicProvider(Provider):
|
||||
anthropic_check_valid_api_key(self.api_key)
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
|
||||
from letta.llm_api.anthropic import anthropic_get_model_list
|
||||
|
||||
models = anthropic_get_model_list(self.base_url, api_key=self.api_key)
|
||||
models = anthropic_get_model_list(api_key=self.api_key)
|
||||
return self._list_llm_models(models)
|
||||
|
||||
"""
|
||||
Example response:
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"type": "model",
|
||||
"id": "claude-3-5-sonnet-20241022",
|
||||
"display_name": "Claude 3.5 Sonnet (New)",
|
||||
"created_at": "2024-10-22T00:00:00Z"
|
||||
}
|
||||
],
|
||||
"has_more": true,
|
||||
"first_id": "<string>",
|
||||
"last_id": "<string>"
|
||||
}
|
||||
"""
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.anthropic import anthropic_get_model_list_async
|
||||
|
||||
models = await anthropic_get_model_list_async(api_key=self.api_key)
|
||||
return self._list_llm_models(models)
|
||||
|
||||
def _list_llm_models(self, models) -> List[LLMConfig]:
|
||||
from letta.llm_api.anthropic import MODEL_LIST
|
||||
|
||||
configs = []
|
||||
for model in models:
|
||||
@ -724,9 +806,6 @@ class AnthropicProvider(Provider):
|
||||
)
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
return []
|
||||
|
||||
|
||||
class MistralProvider(Provider):
|
||||
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
|
||||
@ -948,14 +1027,24 @@ class TogetherProvider(OpenAIProvider):
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list
|
||||
|
||||
response = openai_get_model_list(self.base_url, api_key=self.api_key)
|
||||
models = openai_get_model_list(self.base_url, api_key=self.api_key)
|
||||
return self._list_llm_models(models)
|
||||
|
||||
async def list_llm_models_async(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
|
||||
models = await openai_get_model_list_async(self.base_url, api_key=self.api_key)
|
||||
return self._list_llm_models(models)
|
||||
|
||||
def _list_llm_models(self, models) -> List[LLMConfig]:
|
||||
pass
|
||||
|
||||
# TogetherAI's response is missing the 'data' field
|
||||
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
|
||||
if "data" in response:
|
||||
data = response["data"]
|
||||
if "data" in models:
|
||||
data = models["data"]
|
||||
else:
|
||||
data = response
|
||||
data = models
|
||||
|
||||
configs = []
|
||||
for model in data:
|
||||
@ -1057,7 +1146,6 @@ class GoogleAIProvider(Provider):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list
|
||||
|
||||
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
|
||||
# filter by 'generateContent' models
|
||||
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
|
||||
@ -1081,6 +1169,42 @@ class GoogleAIProvider(Provider):
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
async def list_llm_models_async(self):
|
||||
import asyncio
|
||||
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
|
||||
|
||||
# Get and filter the model list
|
||||
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
|
||||
model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
|
||||
# filter by model names
|
||||
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
|
||||
|
||||
# Add support for all gemini models
|
||||
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
|
||||
|
||||
# Prepare tasks for context window lookups in parallel
|
||||
async def create_config(model):
|
||||
context_window = await self.get_model_context_window_async(model)
|
||||
return LLMConfig(
|
||||
model=model,
|
||||
model_endpoint_type="google_ai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
# Execute all config creation tasks concurrently
|
||||
configs = await asyncio.gather(*[create_config(model) for model in model_options])
|
||||
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self):
|
||||
@ -1088,6 +1212,16 @@ class GoogleAIProvider(Provider):
|
||||
|
||||
# TODO: use base_url instead
|
||||
model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
|
||||
return self._list_embedding_models(model_options)
|
||||
|
||||
async def list_embedding_models_async(self):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list_async
|
||||
|
||||
# TODO: use base_url instead
|
||||
model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=self.api_key)
|
||||
return self._list_embedding_models(model_options)
|
||||
|
||||
def _list_embedding_models(self, model_options):
|
||||
# filter by 'generateContent' models
|
||||
model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
|
||||
model_options = [str(m["name"]) for m in model_options]
|
||||
@ -1110,7 +1244,18 @@ class GoogleAIProvider(Provider):
|
||||
def get_model_context_window(self, model_name: str) -> Optional[int]:
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_context_window
|
||||
|
||||
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
|
||||
if model_name in LLM_MAX_TOKENS:
|
||||
return LLM_MAX_TOKENS[model_name]
|
||||
else:
|
||||
return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)
|
||||
|
||||
async def get_model_context_window_async(self, model_name: str) -> Optional[int]:
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_context_window_async
|
||||
|
||||
if model_name in LLM_MAX_TOKENS:
|
||||
return LLM_MAX_TOKENS[model_name]
|
||||
else:
|
||||
return await google_ai_get_model_context_window_async(self.base_url, self.api_key, model_name)
|
||||
|
||||
|
||||
class GoogleVertexProvider(Provider):
|
||||
|
@@ -20,6 +20,7 @@ class Step(StepBase):
    )
    agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.")
    provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
    provider_category: Optional[str] = Field(None, description="The category of the provider used for this step.")
    model: Optional[str] = Field(None, description="The name of the model used for this step.")
    model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.")
    context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.")

@@ -5,6 +5,7 @@ from pydantic import Field, model_validator
from letta.constants import (
    COMPOSIO_TOOL_TAG_NAME,
    FUNCTION_RETURN_CHAR_LIMIT,
    LETTA_BUILTIN_TOOL_MODULE_NAME,
    LETTA_CORE_TOOL_MODULE_NAME,
    LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
    LETTA_VOICE_TOOL_MODULE_NAME,
@@ -104,6 +105,9 @@ class Tool(BaseTool):
        elif self.tool_type in {ToolType.LETTA_VOICE_SLEEPTIME_CORE}:
            # If it's a letta voice tool, we generate the json_schema on the fly here
            self.json_schema = get_json_schema_from_module(module_name=LETTA_VOICE_TOOL_MODULE_NAME, function_name=self.name)
        elif self.tool_type in {ToolType.LETTA_BUILTIN}:
            # If it's a letta builtin tool, we generate the json_schema on the fly here
            self.json_schema = get_json_schema_from_module(module_name=LETTA_BUILTIN_TOOL_MODULE_NAME, function_name=self.name)

        # At this point, we need to validate that at least json_schema is populated
        if not self.json_schema:

@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Generator
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlalchemy import Engine, NullPool, QueuePool, create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
@ -14,6 +14,8 @@ from letta.config import LettaConfig
|
||||
from letta.log import get_logger
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def print_sqlite_schema_error():
|
||||
"""Print a formatted error message for SQLite schema issues"""
|
||||
@ -76,16 +78,7 @@ class DatabaseRegistry:
|
||||
self.config.archival_storage_type = "postgres"
|
||||
self.config.archival_storage_uri = settings.letta_pg_uri_no_default
|
||||
|
||||
engine = create_engine(
|
||||
settings.letta_pg_uri,
|
||||
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
# connect_args={"client_encoding": "utf8"},
|
||||
)
|
||||
engine = create_engine(settings.letta_pg_uri, **self._build_sqlalchemy_engine_args(is_async=False))
|
||||
|
||||
self._engines["default"] = engine
|
||||
# SQLite engine
|
||||
@ -125,14 +118,7 @@ class DatabaseRegistry:
|
||||
async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri
|
||||
async_pg_uri = async_pg_uri.replace("sslmode=", "ssl=")
|
||||
|
||||
async_engine = create_async_engine(
|
||||
async_pg_uri,
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
)
|
||||
async_engine = create_async_engine(async_pg_uri, **self._build_sqlalchemy_engine_args(is_async=True))
|
||||
|
||||
self._async_engines["default"] = async_engine
|
||||
|
||||
@@ -146,6 +132,38 @@ class DatabaseRegistry:
            # TODO (cliandy): unclear around async sqlite support in sqlalchemy, we will not currently support this
            self._initialized["async"] = False

    def _build_sqlalchemy_engine_args(self, *, is_async: bool) -> dict:
        """Prepare keyword arguments for create_engine / create_async_engine."""
        use_null_pool = settings.disable_sqlalchemy_pooling

        if use_null_pool:
            logger.info("Disabling pooling on SqlAlchemy")
            pool_cls = NullPool
        else:
            logger.info("Enabling pooling on SqlAlchemy")
            pool_cls = QueuePool if not is_async else None

        base_args = {
            "echo": settings.pg_echo,
            "pool_pre_ping": settings.pool_pre_ping,
        }

        if pool_cls:
            base_args["poolclass"] = pool_cls

        if not use_null_pool and not is_async:
            base_args.update(
                {
                    "pool_size": settings.pg_pool_size,
                    "max_overflow": settings.pg_max_overflow,
                    "pool_timeout": settings.pg_pool_timeout,
                    "pool_recycle": settings.pg_pool_recycle,
                    "pool_use_lifo": settings.pool_use_lifo,
                }
            )

        return base_args

    def _wrap_sqlite_engine(self, engine: Engine) -> None:
        """Wrap SQLite engine with error handling."""
        original_connect = engine.connect

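A standalone sketch of the pooling decision above using plain SQLAlchemy; the settings dict and helper name below are stand-ins for letta.settings and the method shown in the hunk, not the real objects:

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool, QueuePool

settings = {"disable_sqlalchemy_pooling": False, "pool_size": 5, "max_overflow": 10}

def build_engine_args(*, is_async: bool) -> dict:
    if settings["disable_sqlalchemy_pooling"]:
        return {"poolclass": NullPool}
    args: dict = {}
    if not is_async:
        # Only the sync engine is pinned to QueuePool; async engines pick their own pool class.
        args.update(poolclass=QueuePool, pool_size=settings["pool_size"], max_overflow=settings["max_overflow"])
    return args

print(build_engine_args(is_async=False))  # QueuePool plus the sizing knobs
print(build_engine_args(is_async=True))   # {} -> driver default
engine = create_engine("sqlite://", poolclass=NullPool)  # how the NullPool branch would look
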
@@ -13,6 +13,7 @@ from letta.server.rest_api.routers.v1.sandbox_configs import router as sandbox_c
from letta.server.rest_api.routers.v1.sources import router as sources_router
from letta.server.rest_api.routers.v1.steps import router as steps_router
from letta.server.rest_api.routers.v1.tags import router as tags_router
from letta.server.rest_api.routers.v1.telemetry import router as telemetry_router
from letta.server.rest_api.routers.v1.tools import router as tools_router
from letta.server.rest_api.routers.v1.voice import router as voice_router

@@ -31,6 +32,7 @@ ROUTERS = [
    runs_router,
    steps_router,
    tags_router,
    telemetry_router,
    messages_router,
    voice_router,
    embeddings_router,

@ -33,6 +33,7 @@ from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager
|
||||
from letta.settings import settings
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
@ -106,14 +107,15 @@ async def list_agents(
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_agents")
|
||||
def count_agents(
|
||||
async def count_agents(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get the count of all agents associated with a given user.
|
||||
"""
|
||||
return server.agent_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.size_async(actor=actor)
|
||||
|
||||
|
||||
class IndentedORJSONResponse(Response):
|
||||
@ -124,7 +126,7 @@ class IndentedORJSONResponse(Response):
|
||||
|
||||
|
||||
@router.get("/{agent_id}/export", response_class=IndentedORJSONResponse, operation_id="export_agent_serialized")
|
||||
def export_agent_serialized(
|
||||
async def export_agent_serialized(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@ -135,7 +137,7 @@ def export_agent_serialized(
|
||||
"""
|
||||
Export the serialized JSON representation of an agent, formatted with indentation.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
||||
@ -200,7 +202,7 @@ async def import_agent_serialized(
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
|
||||
def retrieve_agent_context_window(
|
||||
async def retrieve_agent_context_window(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@ -208,9 +210,12 @@ def retrieve_agent_context_window(
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
try:
|
||||
return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
|
||||
class CreateAgentRequest(CreateAgent):
|
||||
@ -341,7 +346,7 @@ async def retrieve_agent(
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
@ -367,7 +372,7 @@ def delete_agent(
|
||||
|
||||
|
||||
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="list_agent_sources")
|
||||
def list_agent_sources(
|
||||
async def list_agent_sources(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@ -375,8 +380,8 @@ def list_agent_sources(
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.list_attached_sources_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
# TODO: remove? can also get with agent blocks
|
||||
@ -424,14 +429,14 @@ async def list_blocks(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
|
||||
return agent.memory.blocks
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/core-memory/blocks/{block_label}", response_model=Block, operation_id="modify_core_memory_block")
|
||||
def modify_block(
|
||||
async def modify_block(
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
@ -441,10 +446,11 @@ def modify_block(
|
||||
"""
|
||||
Updates a core memory block of an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
|
||||
block = await server.agent_manager.modify_block_by_label_async(
|
||||
agent_id=agent_id, block_label=block_label, block_update=block_update, actor=actor
|
||||
)
|
||||
|
||||
# This should also trigger a system prompt change in the agent
|
||||
server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False)
|
||||
@ -481,7 +487,7 @@ def detach_block(
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages")
|
||||
def list_passages(
|
||||
async def list_passages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
@ -496,11 +502,11 @@ def list_passages(
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return server.get_agent_archival(
|
||||
user_id=actor.id,
|
||||
return await server.get_agent_archival_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
after=after,
|
||||
before=before,
|
||||
query_text=search,
|
||||
@ -564,7 +570,7 @@ AgentMessagesResponse = Annotated[
|
||||
|
||||
|
||||
@router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_messages")
|
||||
def list_messages(
|
||||
async def list_messages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
@ -579,10 +585,9 @@ def list_messages(
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
return await server.get_agent_recall_async(
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
@ -593,6 +598,7 @@ def list_messages(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@ -634,7 +640,7 @@ async def send_message(
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "google_vertex", "google_ai"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
|
||||
|
||||
if agent_eligible and feature_enabled and model_compatible:
|
||||
experimental_agent = LettaAgent(
|
||||
@ -644,6 +650,8 @@ async def send_message(
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
|
||||
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
|
||||
@ -692,7 +700,8 @@ async def send_message_streaming(
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
|
||||
|
||||
if agent_eligible and feature_enabled and model_compatible and request.stream_tokens:
|
||||
experimental_agent = LettaAgent(
|
||||
@ -702,14 +711,28 @@ async def send_message_streaming(
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
|
||||
result = StreamingResponse(
|
||||
experimental_agent.step_stream(
|
||||
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
if request.stream_tokens and model_compatible_token_streaming:
|
||||
result = StreamingResponseWithStatusCode(
|
||||
experimental_agent.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=10,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
result = StreamingResponseWithStatusCode(
|
||||
experimental_agent.step_stream_no_tokens(
|
||||
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
|
@ -99,7 +99,7 @@ def retrieve_block(
|
||||
|
||||
|
||||
@router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block")
|
||||
def list_agents_for_block(
|
||||
async def list_agents_for_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@ -108,9 +108,9 @@ def list_agents_for_block(
|
||||
Retrieves all agents associated with the specified block.
|
||||
Raises a 404 if the block does not exist.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
try:
|
||||
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=block_id, actor=actor)
|
||||
return agents
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found")
|
||||
|
@ -13,7 +13,7 @@ router = APIRouter(prefix="/identities", tags=["identities"])
|
||||
|
||||
|
||||
@router.get("/", tags=["identities"], response_model=List[Identity], operation_id="list_identities")
|
||||
def list_identities(
|
||||
async def list_identities(
|
||||
name: Optional[str] = Query(None),
|
||||
project_id: Optional[str] = Query(None),
|
||||
identifier_key: Optional[str] = Query(None),
|
||||
@ -28,9 +28,9 @@ def list_identities(
|
||||
Get a list of all identities in the database
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
identities = server.identity_manager.list_identities(
|
||||
identities = await server.identity_manager.list_identities_async(
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
identifier_key=identifier_key,
|
||||
@ -50,7 +50,7 @@ def list_identities(
|
||||
|
||||
|
||||
@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities")
|
||||
def count_identities(
|
||||
async def count_identities(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@ -58,7 +58,8 @@ def count_identities(
|
||||
Get count of all identities for a user
|
||||
"""
|
||||
try:
|
||||
return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.size_async(actor=actor)
|
||||
except NoResultFound:
|
||||
return 0
|
||||
except HTTPException:
|
||||
@ -68,28 +69,28 @@ def count_identities(
|
||||
|
||||
|
||||
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
|
||||
def retrieve_identity(
|
||||
async def retrieve_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.get_identity_async(identity_id=identity_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/", tags=["identities"], response_model=Identity, operation_id="create_identity")
|
||||
def create_identity(
|
||||
async def create_identity(
|
||||
identity: IdentityCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.create_identity(identity=identity, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.create_identity_async(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except UniqueConstraintViolationError:
|
||||
@ -105,15 +106,15 @@ def create_identity(
|
||||
|
||||
|
||||
@router.put("/", tags=["identities"], response_model=Identity, operation_id="upsert_identity")
|
||||
def upsert_identity(
|
||||
async def upsert_identity(
|
||||
identity: IdentityUpsert = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.upsert_identity_async(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
@ -123,36 +124,33 @@ def upsert_identity(
|
||||
|
||||
|
||||
@router.patch("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="update_identity")
|
||||
def modify_identity(
|
||||
async def modify_identity(
|
||||
identity_id: str,
|
||||
identity: IdentityUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.update_identity_async(identity_id=identity_id, identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
@router.put("/{identity_id}/properties", tags=["identities"], operation_id="upsert_identity_properties")
|
||||
def upsert_identity_properties(
|
||||
async def upsert_identity_properties(
|
||||
identity_id: str,
|
||||
properties: List[IdentityProperty] = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.upsert_identity_properties(identity_id=identity_id, properties=properties, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.identity_manager.upsert_identity_properties_async(identity_id=identity_id, properties=properties, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
@ -162,7 +160,7 @@ def upsert_identity_properties(
|
||||
|
||||
|
||||
@router.delete("/{identity_id}", tags=["identities"], operation_id="delete_identity")
|
||||
def delete_identity(
|
||||
async def delete_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@ -171,8 +169,8 @@ def delete_identity(
|
||||
Delete an identity by its identifier key
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
await server.identity_manager.delete_identity_async(identity_id=identity_id, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except NoResultFound as e:
|
||||
|
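To illustrate the client-facing behavior, which is unchanged by the sync-to-async migration above, a hedged sketch of listing and counting identities over HTTP. The routes and the user_id header come from this router; the base URL, the /v1 mount, and the ids are assumptions about a local deployment.

import httpx

BASE_URL = "http://localhost:8283/v1"  # assumed local Letta server and API mount

with httpx.Client(base_url=BASE_URL, headers={"user_id": "user-00000000"}) as client:
    total = client.get("/identities/count").json()  # count_identities returns a bare int
    page = client.get("/identities/", params={"limit": 10}).json()  # optional filters: name, project_id, identifier_key
    print(f"{total} identities total, first page has {len(page)}")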
@ -33,16 +33,16 @@ def list_jobs(
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
|
||||
def list_active_jobs(
|
||||
async def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
|
||||
return await server.job_manager.list_jobs_async(actor=actor, statuses=[JobStatus.created, JobStatus.running])
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
|
||||
|
@ -14,30 +14,35 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
|
||||
def list_llm_models(
|
||||
async def list_llm_models(
|
||||
provider_category: Optional[List[ProviderCategory]] = Query(None),
|
||||
provider_name: Optional[str] = Query(None),
|
||||
provider_type: Optional[ProviderType] = Query(None),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""List available LLM models using the asynchronous implementation for improved performance"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
models = server.list_llm_models(
|
||||
|
||||
models = await server.list_llm_models_async(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
# print(models)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
|
||||
def list_embedding_models(
|
||||
async def list_embedding_models(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""List available embedding models using the asynchronous implementation for improved performance"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
models = server.list_embedding_models(actor=actor)
|
||||
# print(models)
|
||||
models = await server.list_embedding_models_async(actor=actor)
|
||||
|
||||
return models
|
||||
|
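The same applies to the models router: both endpoints are now async server-side, but callers are unaffected. A hedged sketch; the base URL, the /v1 mount, the actor id, and the example provider_type value are assumptions.

import httpx

BASE_URL = "http://localhost:8283/v1"  # assumed local Letta server and API mount

with httpx.Client(base_url=BASE_URL, headers={"user_id": "user-00000000"}) as client:
    llm_configs = client.get("/models/", params={"provider_type": "openai"}).json()
    embedding_configs = client.get("/models/embedding").json()
    print(len(llm_configs), "LLM configs,", len(embedding_configs), "embedding configs")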
@ -100,15 +100,15 @@ def delete_sandbox_config(
|
||||
|
||||
|
||||
@router.get("/", response_model=List[PydanticSandboxConfig])
|
||||
def list_sandbox_configs(
|
||||
async def list_sandbox_configs(
|
||||
limit: int = Query(1000, description="Number of results to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.sandbox_config_manager.list_sandbox_configs_async(actor, limit=limit, after=after, sandbox_type=sandbox_type)
|
||||
|
||||
|
||||
@router.post("/local/recreate-venv", response_model=PydanticSandboxConfig)
|
||||
@ -190,12 +190,12 @@ def delete_sandbox_env_var(
|
||||
|
||||
|
||||
@router.get("/{sandbox_config_id}/environment-variable", response_model=List[PydanticEnvVar])
|
||||
def list_sandbox_env_vars(
|
||||
async def list_sandbox_env_vars(
|
||||
sandbox_config_id: str,
|
||||
limit: int = Query(1000, description="Number of results to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.sandbox_config_manager.list_sandbox_env_vars_async(sandbox_config_id, actor, limit=limit, after=after)
|
||||
|
@ -12,7 +12,7 @@ router = APIRouter(prefix="/tags", tags=["tag", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[str], operation_id="list_tags")
|
||||
def list_tags(
|
||||
async def list_tags(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@ -22,6 +22,6 @@ def list_tags(
|
||||
"""
|
||||
Get a list of all tags in the database
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
tags = await server.agent_manager.list_tags_async(actor=actor, after=after, limit=limit, query_text=query_text)
|
||||
return tags
|
||||
|
letta/server/rest_api/routers/v1/telemetry.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends, Header

from letta.schemas.provider_trace import ProviderTrace
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer

router = APIRouter(prefix="/telemetry", tags=["telemetry"])


@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace")
async def retrieve_provider_trace_by_step_id(
    step_id: str,
    server: SyncServer = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    return await server.telemetry_manager.get_provider_trace_by_step_id_async(
        step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    )
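For illustration, a hedged sketch of fetching a provider trace over HTTP once a step id is known. The /telemetry/{step_id} path and the user_id header come from the router above; the base URL, the /v1 mount, and the example ids are assumptions about a local deployment.

import httpx

BASE_URL = "http://localhost:8283/v1"  # assumed local Letta server and API mount
STEP_ID = "step-00000000"  # hypothetical step id recorded on an earlier agent step

resp = httpx.get(f"{BASE_URL}/telemetry/{STEP_ID}", headers={"user_id": "user-00000000"})
resp.raise_for_status()
# The payload mirrors the ProviderTrace schema (the raw JSON exchanged with the LLM provider)
print(resp.json())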
@ -59,7 +59,7 @@ def count_tools(
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=Tool, operation_id="retrieve_tool")
|
||||
def retrieve_tool(
|
||||
async def retrieve_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@ -67,8 +67,8 @@ def retrieve_tool(
|
||||
"""
|
||||
Get a tool by ID
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
|
||||
@ -196,15 +196,15 @@ def modify_tool(
|
||||
|
||||
|
||||
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
|
||||
def upsert_base_tools(
|
||||
async def upsert_base_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Upsert base tools
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.tool_manager.upsert_base_tools(actor=actor)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.tool_manager.upsert_base_tools_async(actor=actor)
|
||||
|
||||
|
||||
@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source")
|
||||
|
letta/server/rest_api/streaming_response.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# Alternative implementation of StreamingResponse that allows for effectively
# streaming HTTP trailers, as we cannot set codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361

import json
from collections.abc import AsyncIterator

from fastapi.responses import StreamingResponse
from starlette.types import Send

from letta.log import get_logger

logger = get_logger(__name__)


class StreamingResponseWithStatusCode(StreamingResponse):
    """
    Variation of StreamingResponse that can dynamically decide the HTTP status code,
    based on the return value of the content iterator (parameter `content`).
    Expects the content to yield either just str content as per the original `StreamingResponse`
    or else tuples of (`content`: `str`, `status_code`: `int`).
    """

    body_iterator: AsyncIterator[str | bytes]
    response_started: bool = False

    async def stream_response(self, send: Send) -> None:
        more_body = True
        try:
            first_chunk = await self.body_iterator.__anext__()
            if isinstance(first_chunk, tuple):
                first_chunk_content, self.status_code = first_chunk
            else:
                first_chunk_content = first_chunk
            if isinstance(first_chunk_content, str):
                first_chunk_content = first_chunk_content.encode(self.charset)

            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            self.response_started = True
            await send(
                {
                    "type": "http.response.body",
                    "body": first_chunk_content,
                    "more_body": more_body,
                }
            )

            async for chunk in self.body_iterator:
                if isinstance(chunk, tuple):
                    content, status_code = chunk
                    if status_code // 100 != 2:
                        # An error occurred mid-stream
                        if not isinstance(content, bytes):
                            content = content.encode(self.charset)
                        more_body = False
                        await send(
                            {
                                "type": "http.response.body",
                                "body": content,
                                "more_body": more_body,
                            }
                        )
                        return
                else:
                    content = chunk

                if isinstance(content, str):
                    content = content.encode(self.charset)
                more_body = True
                await send(
                    {
                        "type": "http.response.body",
                        "body": content,
                        "more_body": more_body,
                    }
                )

        except Exception:
            logger.exception("unhandled_streaming_error")
            more_body = False
            error_resp = {"error": {"message": "Internal Server Error"}}
            error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
            if not self.response_started:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 500,
                        "headers": self.raw_headers,
                    }
                )
            await send(
                {
                    "type": "http.response.body",
                    "body": error_event,
                    "more_body": more_body,
                }
            )
        if more_body:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
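A minimal usage sketch, assuming a FastAPI app is already mounted: the generator yields (content, status_code) tuples as described in the class docstring, so the first yield fixes the HTTP status and a later non-2xx tuple flushes the body and ends the stream. The route and payloads here are illustrative only.

from fastapi import FastAPI

from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode

app = FastAPI()


@app.get("/demo-stream")  # hypothetical route, for illustration only
async def demo_stream():
    async def event_source():
        # The first tuple sets the response status before any bytes go out
        yield "event: open\ndata: {}\n\n", 200
        for i in range(3):
            yield f"data: chunk {i}\n\n"  # plain str chunks keep the current status
        # Yielding e.g. ("event: error\ndata: {}\n\n", 500) here would send that body and close the stream

    return StreamingResponseWithStatusCode(event_source(), media_type="text/event-stream")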
@ -190,6 +190,7 @@ def create_letta_messages_from_llm_response(
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
pre_computed_tool_message_id: Optional[str] = None,
|
||||
llm_batch_item_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
@ -244,6 +245,9 @@ def create_letta_messages_from_llm_response(
|
||||
)
|
||||
messages.append(heartbeat_system_message)
|
||||
|
||||
for message in messages:
|
||||
message.step_id = step_id
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
|
@ -94,6 +94,7 @@ from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
@ -213,6 +214,7 @@ class SyncServer(Server):
|
||||
self.identity_manager = IdentityManager()
|
||||
self.group_manager = GroupManager()
|
||||
self.batch_manager = LLMBatchManager()
|
||||
self.telemetry_manager = TelemetryManager()
|
||||
|
||||
# A reusable httpx client
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
|
||||
@ -1000,6 +1002,30 @@ class SyncServer(Server):
|
||||
)
|
||||
return records
|
||||
|
||||
async def get_agent_archival_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
order_by: Optional[str] = "created_at",
|
||||
reverse: Optional[bool] = False,
|
||||
query_text: Optional[str] = None,
|
||||
ascending: Optional[bool] = True,
|
||||
) -> List[Passage]:
|
||||
# iterate over records
|
||||
records = await self.agent_manager.list_passages_async(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
query_text=query_text,
|
||||
before=before,
|
||||
ascending=ascending,
|
||||
limit=limit,
|
||||
)
|
||||
return records
|
||||
|
||||
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
|
||||
# Get the agent object (loaded in memory)
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@ -1070,6 +1096,44 @@ class SyncServer(Server):
|
||||
|
||||
return records
|
||||
|
||||
async def get_agent_recall_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
group_id: Optional[str] = None,
|
||||
reverse: Optional[bool] = False,
|
||||
return_message_object: bool = True,
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
records = await self.message_manager.list_messages_for_agent_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
records = Message.to_letta_messages_from_list(
|
||||
messages=records,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
reverse=reverse,
|
||||
)
|
||||
|
||||
if reverse:
|
||||
records = records[::-1]
|
||||
|
||||
return records
|
||||
|
||||
def get_server_config(self, include_defaults: bool = False) -> dict:
|
||||
"""Return the base config"""
|
||||
|
||||
@ -1301,6 +1365,48 @@ class SyncServer(Server):
|
||||
|
||||
return llm_models
|
||||
|
||||
@trace_method
|
||||
async def list_llm_models_async(
|
||||
self,
|
||||
actor: User,
|
||||
provider_category: Optional[List[ProviderCategory]] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
) -> List[LLMConfig]:
|
||||
"""Asynchronously list available models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
providers = self.get_enabled_providers(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
async def get_provider_models(provider):
|
||||
try:
|
||||
return await provider.list_llm_models_async()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
|
||||
return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
provider_results = await asyncio.gather(*[get_provider_models(provider) for provider in providers])
|
||||
|
||||
# Flatten the results
|
||||
llm_models = []
|
||||
for models in provider_results:
|
||||
llm_models.extend(models)
|
||||
|
||||
# Get local configs - if this is potentially slow, consider making it async too
|
||||
local_configs = self.get_local_llm_configs()
|
||||
llm_models.extend(local_configs)
|
||||
|
||||
return llm_models
|
||||
|
||||
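The fan-out used by list_llm_models_async above, and mirrored by list_embedding_models_async below, reduced to a self-contained sketch. The provider names and the simulated failure are made up; the gather-then-degrade-to-an-empty-list behavior is the part taken from the server code.

import asyncio
import warnings


async def fetch_models(provider: str) -> list[str]:
    # Stand-in for provider.list_llm_models_async(); one provider is deliberately broken
    if provider == "flaky":
        raise RuntimeError("provider unreachable")
    await asyncio.sleep(0.01)
    return [f"{provider}-model-a", f"{provider}-model-b"]


async def list_all_models(providers: list[str]) -> list[str]:
    async def safe_fetch(provider: str) -> list[str]:
        try:
            return await fetch_models(provider)
        except Exception as e:
            # Mirror the server: warn and fall back to an empty list instead of failing the whole call
            warnings.warn(f"listing models for {provider} failed: {e}")
            return []

    results = await asyncio.gather(*[safe_fetch(p) for p in providers])
    return [model for models in results for model in models]


print(asyncio.run(list_all_models(["alpha", "flaky", "local"])))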
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""List available embedding models"""
|
||||
embedding_models = []
|
||||
@ -1311,6 +1417,35 @@ class SyncServer(Server):
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""Asynchronously list available embedding models with maximum concurrency"""
|
||||
import asyncio
|
||||
|
||||
# Get all eligible providers first
|
||||
providers = self.get_enabled_providers(actor=actor)
|
||||
|
||||
# Fetch embedding models from each provider concurrently
|
||||
async def get_provider_embedding_models(provider):
|
||||
try:
|
||||
# All providers now have list_embedding_models_async
|
||||
return await provider.list_embedding_models_async()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return []
|
||||
|
||||
# Execute all provider model listing tasks concurrently
|
||||
provider_results = await asyncio.gather(*[get_provider_embedding_models(provider) for provider in providers])
|
||||
|
||||
# Flatten the results
|
||||
embedding_models = []
|
||||
for models in provider_results:
|
||||
embedding_models.extend(models)
|
||||
|
||||
return embedding_models
|
||||
|
||||
def get_enabled_providers(
|
||||
self,
|
||||
actor: User,
|
||||
@ -1482,6 +1617,10 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.get_context_window()
|
||||
|
||||
async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview:
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return await letta_agent.get_context_window_async()
|
||||
|
||||
def run_tool_from_source(
|
||||
self,
|
||||
actor: User,
|
||||
@ -1615,7 +1754,7 @@ class SyncServer(Server):
|
||||
server_name=server_name,
|
||||
command=server_params_raw["command"],
|
||||
args=server_params_raw.get("args", []),
|
||||
env=server_params_raw.get("env", {})
|
||||
env=server_params_raw.get("env", {}),
|
||||
)
|
||||
mcp_server_list[server_name] = server_params
|
||||
except Exception as e:
|
||||
|
@ -892,7 +892,7 @@ class AgentManager:
|
||||
List[PydanticAgentState]: The filtered list of matching agents.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id)
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
|
||||
# Apply filters
|
||||
@ -961,6 +961,16 @@ class AgentManager:
|
||||
with db_registry.session() as session:
|
||||
return AgentModel.size(db_session=session, actor=actor)
|
||||
|
||||
async def size_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of agents for the given user.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
return await AgentModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
@ -969,18 +979,32 @@ class AgentManager:
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_agent_by_id_async(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
async def get_agent_by_id_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
return await agent.to_pydantic_async(include_relationships=include_relationships)
|
||||
|
||||
@enforce_types
|
||||
async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]:
|
||||
async def get_agents_by_ids_async(
|
||||
self,
|
||||
agent_ids: list[str],
|
||||
actor: PydanticUser,
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
) -> list[PydanticAgentState]:
|
||||
"""Fetch a list of agents by their IDs."""
|
||||
async with db_registry.async_session() as session:
|
||||
agents = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor)
|
||||
return [await agent.to_pydantic_async() for agent in agents]
|
||||
agents = await AgentModel.read_multiple_async(
|
||||
db_session=session,
|
||||
identifiers=agent_ids,
|
||||
actor=actor,
|
||||
)
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
@ -1191,7 +1215,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
||||
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@ -1199,6 +1223,11 @@ class AgentManager:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
||||
|
||||
@enforce_types
|
||||
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
||||
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
|
||||
|
||||
# TODO: This is duplicated below
|
||||
# TODO: This is legacy code and should be cleaned up
|
||||
# TODO: A lot of the memory "compilation" should be offset to a separate class
|
||||
@ -1267,10 +1296,81 @@ class AgentManager:
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
async def rebuild_system_prompt_async(
|
||||
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True
|
||||
) -> PydanticAgentState:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||||
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
|
||||
|
||||
curr_system_message = await self.get_system_message_async(
|
||||
agent_id=agent_id, actor=actor
|
||||
) # this is the system + memory bank, not just the system prompt
|
||||
curr_system_message_openai = curr_system_message.to_openai_dict()
|
||||
|
||||
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be somewhat out of date
curr_memory_str = agent_state.memory.compile()
|
||||
if curr_memory_str in curr_system_message_openai["content"] and not force:
|
||||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||||
logger.debug(
|
||||
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return agent_state
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
if update_timestamp:
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
else:
|
||||
# NOTE: a bit of a hack - we pull the timestamp from the message's created_at
memory_edit_timestamp = curr_system_message.created_at
|
||||
|
||||
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
if len(diff) > 0: # there was a diff
|
||||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out (only if there is a diff)
|
||||
message = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||||
)
|
||||
message = await self.message_manager.update_message_by_id_async(
|
||||
message_id=curr_system_message.id,
|
||||
message_update=MessageUpdate(**message.model_dump()),
|
||||
actor=actor,
|
||||
)
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
@ -1382,17 +1482,6 @@ class AgentManager:
|
||||
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
block_ids = [b.id for b in agent_state.memory.blocks]
|
||||
if not block_ids:
|
||||
return agent_state
|
||||
|
||||
agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(
|
||||
block_ids=[b.id for b in agent_state.memory.blocks], actor=actor
|
||||
)
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
|
||||
block_ids = [b.id for b in agent_state.memory.blocks]
|
||||
@ -1482,6 +1571,25 @@ class AgentManager:
|
||||
# Use the lazy-loaded relationship to get sources
|
||||
return [source.to_pydantic() for source in agent.sources]
|
||||
|
||||
@enforce_types
|
||||
async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
||||
"""
|
||||
Lists all sources attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to list sources for
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
List[PydanticSource]: List of sources attached to the agent
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Use the lazy-loaded relationship to get sources
|
||||
return [source.to_pydantic() for source in agent.sources]
|
||||
|
||||
@enforce_types
|
||||
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
@ -1527,6 +1635,33 @@ class AgentManager:
|
||||
return block.to_pydantic()
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
@enforce_types
|
||||
async def modify_block_by_label_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
block_update: BlockUpdate,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticBlock:
|
||||
"""Gets a block attached to an agent by its label."""
|
||||
async with db_registry.async_session() as session:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# find the block with the requested label; keep None when nothing matches so the guard below actually fires
block = None
for candidate in agent.core_memory:
    if candidate.label == block_label:
        block = candidate
        break
if not block:
    raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(block, key, value)
|
||||
|
||||
await block.update_async(session, actor=actor)
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_block_with_label(
|
||||
self,
|
||||
@ -1848,6 +1983,65 @@ class AgentManager:
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@enforce_types
|
||||
async def list_passages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_only: bool = False,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Lists all passages attached to an agent."""
|
||||
async with db_registry.async_session() as session:
|
||||
main_query = self._build_passage_query(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
query_text=query_text,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
before=before,
|
||||
after=after,
|
||||
source_id=source_id,
|
||||
embed_query=embed_query,
|
||||
ascending=ascending,
|
||||
embedding_config=embedding_config,
|
||||
agent_only=agent_only,
|
||||
)
|
||||
|
||||
# Add limit
|
||||
if limit:
|
||||
main_query = main_query.limit(limit)
|
||||
|
||||
# Execute query
|
||||
result = await session.execute(main_query)
|
||||
|
||||
passages = []
|
||||
for row in result:
|
||||
data = dict(row._mapping)
|
||||
if data["agent_id"] is not None:
|
||||
# This is an AgentPassage - remove source fields
|
||||
data.pop("source_id", None)
|
||||
data.pop("file_id", None)
|
||||
passage = AgentPassage(**data)
|
||||
else:
|
||||
# This is a SourcePassage - remove agent field
|
||||
data.pop("agent_id", None)
|
||||
passage = SourcePassage(**data)
|
||||
passages.append(passage)
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@enforce_types
|
||||
def passage_size(
|
||||
self,
|
||||
@ -2010,3 +2204,42 @@ class AgentManager:
|
||||
query = query.order_by(AgentsTags.tag).limit(limit)
|
||||
results = [tag[0] for tag in query.all()]
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
async def list_tags_async(
|
||||
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all tags a user has created, ordered alphabetically.
|
||||
|
||||
Args:
|
||||
actor: User performing the action.
|
||||
after: Cursor for forward pagination.
|
||||
limit: Maximum number of tags to return.
|
||||
query_text: Query text to filter tags by.
|
||||
|
||||
Returns:
|
||||
List[str]: List of all tags.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Build the query using select() for async SQLAlchemy
|
||||
query = (
|
||||
select(AgentsTags.tag)
|
||||
.join(AgentModel, AgentModel.id == AgentsTags.agent_id)
|
||||
.where(AgentModel.organization_id == actor.organization_id)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
if query_text:
|
||||
query = query.where(AgentsTags.tag.ilike(f"%{query_text}%"))
|
||||
|
||||
if after:
|
||||
query = query.where(AgentsTags.tag > after)
|
||||
|
||||
query = query.order_by(AgentsTags.tag).limit(limit)
|
||||
|
||||
# Execute the query asynchronously
|
||||
result = await session.execute(query)
|
||||
# Extract the tag values from the result
|
||||
results = [row[0] for row in result.all()]
|
||||
return results
|
||||
|
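Since list_tags_async pages forward on the tag value itself (AgentsTags.tag > after), a caller can drain every tag with a simple cursor loop. A hedged sketch, assuming an AgentManager instance and an actor are already in hand:

async def all_tags(agent_manager, actor, page_size: int = 50) -> list[str]:
    # Keyset pagination: feed the last tag of the previous page back in as the `after` cursor
    tags: list[str] = []
    cursor = None
    while True:
        page = await agent_manager.list_tags_async(actor=actor, after=cursor, limit=page_size)
        tags.extend(page)
        if len(page) < page_size:
            return tags
        cursor = page[-1]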
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
@ -82,39 +83,64 @@ class BlockManager:
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_blocks(
|
||||
async def get_blocks_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
label: Optional[str] = None,
|
||||
is_template: Optional[bool] = None,
|
||||
template_name: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
identity_id: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks based on various optional filters."""
|
||||
with db_registry.session() as session:
|
||||
# Prepare filters
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if label:
|
||||
filters["label"] = label
|
||||
if is_template is not None:
|
||||
filters["is_template"] = is_template
|
||||
if template_name:
|
||||
filters["template_name"] = template_name
|
||||
if id:
|
||||
filters["id"] = id
|
||||
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
blocks = BlockModel.list(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
identifier_keys=identifier_keys,
|
||||
identity_id=identity_id,
|
||||
**filters,
|
||||
)
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Start with a basic query
|
||||
query = select(BlockModel)
|
||||
|
||||
# Explicitly avoid loading relationships
|
||||
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
|
||||
|
||||
# Apply access control
|
||||
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
|
||||
# Add filters
|
||||
query = query.where(BlockModel.organization_id == actor.organization_id)
|
||||
if label:
|
||||
query = query.where(BlockModel.label == label)
|
||||
|
||||
if is_template is not None:
|
||||
query = query.where(BlockModel.is_template == is_template)
|
||||
|
||||
if template_name:
|
||||
query = query.where(BlockModel.template_name == template_name)
|
||||
|
||||
if identifier_keys:
|
||||
query = (
|
||||
query.join(BlockModel.identities)
|
||||
.filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
.distinct(BlockModel.id)
|
||||
)
|
||||
|
||||
if identity_id:
|
||||
query = (
|
||||
query.join(BlockModel.identities)
|
||||
.filter(BlockModel.identities.property.mapper.class_.id == identity_id)
|
||||
.distinct(BlockModel.id)
|
||||
)
|
||||
|
||||
# Add limit
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute the query
|
||||
result = await session.execute(query)
|
||||
blocks = result.scalars().all()
|
||||
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
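The noload() trick in get_blocks_async above (keep relationships out of the query entirely when only scalar columns are needed) works the same way outside Letta. A minimal self-contained sketch, assuming SQLAlchemy 2.x with the aiosqlite driver, with toy models standing in for BlockModel and its agents relationship:

import asyncio

from sqlalchemy import ForeignKey, String, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, noload, relationship


class Base(DeclarativeBase):
    pass


class Block(Base):
    __tablename__ = "blocks"
    id: Mapped[int] = mapped_column(primary_key=True)
    label: Mapped[str] = mapped_column(String)
    agents: Mapped[list["Agent"]] = relationship(back_populates="block")


class Agent(Base):
    __tablename__ = "agents"
    id: Mapped[int] = mapped_column(primary_key=True)
    block_id: Mapped[int] = mapped_column(ForeignKey("blocks.id"))
    block: Mapped["Block"] = relationship(back_populates="agents")


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    async with async_sessionmaker(engine)() as session:
        session.add(Block(id=1, label="persona"))
        await session.commit()
        # noload() keeps the agents relationship out of the query, mirroring get_blocks_async
        query = select(Block).options(noload(Block.agents)).where(Block.label == "persona").limit(50)
        blocks = (await session.execute(query)).scalars().all()
        print([b.label for b in blocks])


asyncio.run(main())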
@ -190,15 +216,6 @@ class BlockManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks by their ids."""
|
||||
with db_registry.session() as session:
|
||||
blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)]
|
||||
# backwards compatibility. previous implementation added None for every block not found.
|
||||
blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
|
||||
return blocks
|
||||
|
||||
@enforce_types
|
||||
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
|
||||
@ -247,16 +264,14 @@ class BlockManager:
|
||||
return pydantic_blocks
|
||||
|
||||
@enforce_types
|
||||
def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
|
||||
async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Retrieve all agents associated with a given block.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
agents_orm = block.agents
|
||||
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
|
||||
|
||||
return agents_pydantic
|
||||
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
|
letta/services/helpers/noop_helper.py (new file, 10 lines)
@@ -0,0 +1,10 @@
def singleton(cls):
    """Decorator to make a class a Singleton class."""
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
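A short usage sketch of the decorator above; the class here is illustrative only.

@singleton
class ConnectionPool:
    def __init__(self) -> None:
        self.connections: list[str] = []


# Every call returns the same instance, so state is shared process-wide
pool_a = ConnectionPool()
pool_b = ConnectionPool()
assert pool_a is pool_b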
@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -17,7 +18,7 @@ from letta.utils import enforce_types
|
||||
class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
def list_identities(
|
||||
async def list_identities_async(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
@ -28,7 +29,7 @@ class IdentityManager:
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> list[PydanticIdentity]:
|
||||
with db_registry.session() as session:
|
||||
async with db_registry.async_session() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
@ -36,7 +37,7 @@ class IdentityManager:
|
||||
filters["identifier_key"] = identifier_key
|
||||
if identity_type:
|
||||
filters["identity_type"] = identity_type
|
||||
identities = IdentityModel.list(
|
||||
identities = await IdentityModel.list_async(
|
||||
db_session=session,
|
||||
query_text=name,
|
||||
before=before,
|
||||
@ -47,17 +48,17 @@ class IdentityManager:
|
||||
return [identity.to_pydantic() for identity in identities]
|
||||
|
||||
@enforce_types
|
||||
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
with db_registry.session() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
return identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
|
||||
with db_registry.session() as session:
|
||||
async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
|
||||
new_identity.organization_id = actor.organization_id
|
||||
self._process_relationship(
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
identity=new_identity,
|
||||
relationship_name="agents",
|
||||
@ -65,7 +66,7 @@ class IdentityManager:
|
||||
item_ids=identity.agent_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
self._process_relationship(
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
identity=new_identity,
|
||||
relationship_name="blocks",
|
||||
@ -73,13 +74,13 @@ class IdentityManager:
|
||||
item_ids=identity.block_ids,
|
||||
allow_partial=False,
|
||||
)
|
||||
new_identity.create(session, actor=actor)
|
||||
await new_identity.create_async(session, actor=actor)
|
||||
return new_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def upsert_identity(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
|
||||
with db_registry.session() as session:
|
||||
existing_identity = IdentityModel.read(
|
||||
async def upsert_identity_async(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
existing_identity = await IdentityModel.read_async(
|
||||
db_session=session,
|
||||
identifier_key=identity.identifier_key,
|
||||
project_id=identity.project_id,
|
||||
@ -88,7 +89,7 @@ class IdentityManager:
|
||||
)
|
||||
|
||||
if existing_identity is None:
|
||||
return self.create_identity(identity=IdentityCreate(**identity.model_dump()), actor=actor)
|
||||
return await self.create_identity_async(identity=IdentityCreate(**identity.model_dump()), actor=actor)
|
||||
else:
|
||||
identity_update = IdentityUpdate(
|
||||
name=identity.name,
|
||||
@ -97,25 +98,27 @@ class IdentityManager:
|
||||
agent_ids=identity.agent_ids,
|
||||
properties=identity.properties,
|
||||
)
|
||||
return self._update_identity(
|
||||
return await self._update_identity_async(
|
||||
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
|
||||
with db_registry.session() as session:
|
||||
async def update_identity_async(
|
||||
self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
|
||||
) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
if existing_identity.organization_id != actor.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
return self._update_identity(
|
||||
return await self._update_identity_async(
|
||||
session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
|
||||
)
|
||||
|
||||
def _update_identity(
|
||||
async def _update_identity_async(
|
||||
self,
|
||||
session: Session,
|
||||
existing_identity: IdentityModel,
|
||||
@ -139,7 +142,7 @@ class IdentityManager:
|
||||
existing_identity.properties = list(new_properties.values())
|
||||
|
||||
if identity.agent_ids is not None:
|
||||
self._process_relationship(
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
identity=existing_identity,
|
||||
relationship_name="agents",
|
||||
@ -149,7 +152,7 @@ class IdentityManager:
|
||||
replace=replace,
|
||||
)
|
||||
if identity.block_ids is not None:
|
||||
self._process_relationship(
|
||||
await self._process_relationship_async(
|
||||
session=session,
|
||||
identity=existing_identity,
|
||||
relationship_name="blocks",
|
||||
@ -158,16 +161,18 @@ class IdentityManager:
|
||||
allow_partial=False,
|
||||
replace=replace,
|
||||
)
|
||||
existing_identity.update(session, actor=actor)
|
||||
await existing_identity.update_async(session, actor=actor)
|
||||
return existing_identity.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def upsert_identity_properties(self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser) -> PydanticIdentity:
|
||||
with db_registry.session() as session:
|
||||
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
|
||||
async def upsert_identity_properties_async(
|
||||
self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
|
||||
) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
if existing_identity is None:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
return self._update_identity(
|
||||
return await self._update_identity_async(
|
||||
session=session,
|
||||
existing_identity=existing_identity,
|
||||
identity=IdentityUpdate(properties=properties),
|
||||
@ -176,28 +181,28 @@ class IdentityManager:
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
|
||||
with db_registry.session() as session:
|
||||
identity = IdentityModel.read(db_session=session, identifier=identity_id)
|
||||
async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
if identity is None:
|
||||
raise HTTPException(status_code=404, detail="Identity not found")
|
||||
if identity.organization_id != actor.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
session.delete(identity)
|
||||
session.commit()
|
||||
await session.delete(identity)
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
async def size_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of identities for the given user.
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
return IdentityModel.size(db_session=session, actor=actor)
|
||||
async with db_registry.async_session() as session:
|
||||
return await IdentityModel.size_async(db_session=session, actor=actor)
|
||||
|
||||
def _process_relationship(
|
||||
async def _process_relationship_async(
|
||||
self,
|
||||
session: Session,
|
||||
identity: PydanticIdentity,
|
||||
@ -214,7 +219,7 @@ class IdentityManager:
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
|
||||
found_items = (await session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(item_ids):
|
||||
|
@ -150,6 +150,35 @@ class JobManager:
|
||||
)
|
||||
return [job.to_pydantic() for job in jobs]
|
||||
|
||||
@enforce_types
|
||||
async def list_jobs_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
statuses: Optional[List[JobStatus]] = None,
|
||||
job_type: JobType = JobType.JOB,
|
||||
ascending: bool = True,
|
||||
) -> List[PydanticJob]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
async with db_registry.async_session() as session:
|
||||
filter_kwargs = {"user_id": actor.id, "job_type": job_type}
|
||||
|
||||
# Add status filter if provided
|
||||
if statuses:
|
||||
filter_kwargs["status"] = statuses
|
||||
|
||||
jobs = await JobModel.list_async(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
ascending=ascending,
|
||||
**filter_kwargs,
|
||||
)
|
||||
return [job.to_pydantic() for job in jobs]
|
||||
|
||||
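A hedged call-site sketch for the new list_jobs_async, reusing the same status filter as the /jobs/active route but newest-first. The import paths and the actor id are assumptions, and actually running it requires a configured Letta database.

import asyncio

from letta.schemas.enums import JobStatus  # assumed import path for the enum used above
from letta.services.job_manager import JobManager  # assumed module path
from letta.services.user_manager import UserManager


async def newest_active_jobs(actor_id: str):
    actor = await UserManager().get_actor_or_default_async(actor_id=actor_id)
    return await JobManager().list_jobs_async(
        actor=actor,
        statuses=[JobStatus.created, JobStatus.running],
        limit=20,
        ascending=False,  # newest first
    )


# asyncio.run(newest_active_jobs("user-00000000"))  # hypothetical actor id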
@enforce_types
|
||||
def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Delete a job by its ID."""
|
||||
|
@ -31,6 +31,16 @@ class MessageManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
|
||||
"""Fetch a message by ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
message = await MessageModel.read_async(db_session=session, identifier=message_id, actor=actor)
|
||||
return message.to_pydantic()
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
|
||||
"""Fetch messages by ID and return them in the requested order."""
|
||||
@ -426,6 +436,107 @@ class MessageManager:
|
||||
results = query.all()
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
|
||||
@enforce_types
|
||||
async def list_messages_for_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
roles: Optional[Sequence[MessageRole]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
group_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Most performant query to list messages for an agent by directly querying the Message table.
|
||||
|
||||
This function filters by the agent_id (leveraging the index on messages.agent_id)
|
||||
and applies pagination using sequence_id as the cursor.
|
||||
If query_text is provided, it will filter messages whose text content partially matches the query.
|
||||
If role is provided, it will filter messages by the specified role.
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent whose messages are queried.
|
||||
actor: The user performing the action (used for permission checks).
|
||||
after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned.
|
||||
before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned.
|
||||
query_text: Optional string to partially match the message text content.
|
||||
roles: Optional MessageRole to filter messages by role.
|
||||
limit: Maximum number of messages to return.
|
||||
ascending: If True, sort by sequence_id ascending; if False, sort descending.
|
||||
group_id: Optional group ID to filter messages by group_id.
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage]: A list of messages (converted via .to_pydantic()).
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the provided after/before message IDs do not exist.
|
||||
"""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Permission check: raise if the agent doesn't exist or actor is not allowed.
|
||||
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Build a query that directly filters the Message table by agent_id.
|
||||
query = select(MessageModel).where(MessageModel.agent_id == agent_id)
|
||||
|
||||
# If group_id is provided, filter messages by group_id.
|
||||
if group_id:
|
||||
query = query.where(MessageModel.group_id == group_id)
|
||||
|
||||
# If query_text is provided, filter messages using subquery + json_array_elements.
|
||||
if query_text:
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.where(
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(content_element)
|
||||
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
|
||||
.params(query_text=f"%{query_text}%")
|
||||
)
|
||||
)
|
||||
|
||||
# If role(s) are provided, filter messages by those roles.
|
||||
if roles:
|
||||
role_values = [r.value for r in roles]
|
||||
query = query.where(MessageModel.role.in_(role_values))
|
||||
|
||||
# Apply 'after' pagination if specified.
|
||||
if after:
|
||||
after_query = select(MessageModel.sequence_id).where(MessageModel.id == after)
|
||||
after_result = await session.execute(after_query)
|
||||
after_ref = after_result.one_or_none()
|
||||
if not after_ref:
|
||||
raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.")
|
||||
# Filter out any messages with a sequence_id <= after_ref.sequence_id
|
||||
query = query.where(MessageModel.sequence_id > after_ref.sequence_id)
|
||||
|
||||
# Apply 'before' pagination if specified.
|
||||
if before:
|
||||
before_query = select(MessageModel.sequence_id).where(MessageModel.id == before)
|
||||
before_result = await session.execute(before_query)
|
||||
before_ref = before_result.one_or_none()
|
||||
if not before_ref:
|
||||
raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.")
|
||||
# Filter out any messages with a sequence_id >= before_ref.sequence_id
|
||||
query = query.where(MessageModel.sequence_id < before_ref.sequence_id)
|
||||
|
||||
# Apply ordering based on the ascending flag.
|
||||
if ascending:
|
||||
query = query.order_by(MessageModel.sequence_id.asc())
|
||||
else:
|
||||
query = query.order_by(MessageModel.sequence_id.desc())
|
||||
|
||||
# Limit the number of results.
|
||||
query = query.limit(limit)
|
||||
|
||||
# Execute and convert each Message to its Pydantic representation.
|
||||
result = await session.execute(query)
|
||||
results = result.scalars().all()
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
|
||||
@enforce_types
|
||||
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""
|
||||
|
@ -122,6 +122,23 @@ class SandboxConfigManager:
|
||||
sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs)
|
||||
return [sandbox.to_pydantic() for sandbox in sandboxes]
|
||||
|
||||
@enforce_types
|
||||
async def list_sandbox_configs_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
sandbox_type: Optional[SandboxType] = None,
|
||||
) -> List[PydanticSandboxConfig]:
|
||||
"""List all sandbox configurations with optional pagination."""
|
||||
kwargs = {"organization_id": actor.organization_id}
|
||||
if sandbox_type:
|
||||
kwargs.update({"type": sandbox_type})
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
sandboxes = await SandboxConfigModel.list_async(db_session=session, after=after, limit=limit, **kwargs)
|
||||
return [sandbox.to_pydantic() for sandbox in sandboxes]
|
||||
|
||||
@enforce_types
|
||||
def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
|
||||
"""Retrieve a sandbox configuration by its ID."""
|
||||
@ -224,6 +241,25 @@ class SandboxConfigManager:
|
||||
)
|
||||
return [env_var.to_pydantic() for env_var in env_vars]
|
||||
|
||||
@enforce_types
|
||||
async def list_sandbox_env_vars_async(
|
||||
self,
|
||||
sandbox_config_id: str,
|
||||
actor: PydanticUser,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
) -> List[PydanticEnvVar]:
|
||||
"""List all sandbox environment variables with optional pagination."""
|
||||
async with db_registry.async_session() as session:
|
||||
env_vars = await SandboxEnvVarModel.list_async(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
organization_id=actor.organization_id,
|
||||
sandbox_config_id=sandbox_config_id,
|
||||
)
|
||||
return [env_var.to_pydantic() for env_var in env_vars]
|
||||
|
||||
@enforce_types
|
||||
def list_sandbox_env_vars_by_key(
|
||||
self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
|
||||
|
@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
@ -12,6 +13,7 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.noop_helper import singleton
|
||||
from letta.tracing import get_trace_id
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@ -57,12 +59,14 @@ class StepManager:
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
provider_name: str,
|
||||
provider_category: str,
|
||||
model: str,
|
||||
model_endpoint: Optional[str],
|
||||
context_window_limit: int,
|
||||
usage: UsageStatistics,
|
||||
provider_id: Optional[str] = None,
|
||||
job_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> PydanticStep:
|
||||
step_data = {
|
||||
"origin": None,
|
||||
@ -70,6 +74,7 @@ class StepManager:
|
||||
"agent_id": agent_id,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_name,
|
||||
"provider_category": provider_category,
|
||||
"model": model,
|
||||
"model_endpoint": model_endpoint,
|
||||
"context_window_limit": context_window_limit,
|
||||
@ -81,6 +86,8 @@ class StepManager:
|
||||
"tid": None,
|
||||
"trace_id": get_trace_id(), # Get the current trace ID
|
||||
}
|
||||
if step_id:
|
||||
step_data["id"] = step_id
|
||||
with db_registry.session() as session:
|
||||
if job_id:
|
||||
self._verify_job_access(session, job_id, actor, access=["write"])
|
||||
@ -88,6 +95,48 @@ class StepManager:
|
||||
new_step.create(session)
|
||||
return new_step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def log_step_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
provider_name: str,
|
||||
provider_category: str,
|
||||
model: str,
|
||||
model_endpoint: Optional[str],
|
||||
context_window_limit: int,
|
||||
usage: UsageStatistics,
|
||||
provider_id: Optional[str] = None,
|
||||
job_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> PydanticStep:
|
||||
step_data = {
|
||||
"origin": None,
|
||||
"organization_id": actor.organization_id,
|
||||
"agent_id": agent_id,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_name,
|
||||
"provider_category": provider_category,
|
||||
"model": model,
|
||||
"model_endpoint": model_endpoint,
|
||||
"context_window_limit": context_window_limit,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
"job_id": job_id,
|
||||
"tags": [],
|
||||
"tid": None,
|
||||
"trace_id": get_trace_id(), # Get the current trace ID
|
||||
}
|
||||
if step_id:
|
||||
step_data["id"] = step_id
|
||||
async with db_registry.async_session() as session:
|
||||
if job_id:
|
||||
await self._verify_job_access_async(session, job_id, actor, access=["write"])
|
||||
new_step = StepModel(**step_data)
|
||||
await new_step.create_async(session)
|
||||
return new_step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
|
||||
with db_registry.session() as session:
|
||||
@ -147,3 +196,100 @@ class StepManager:
|
||||
if not job:
|
||||
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||
return job
|
||||
|
||||
async def _verify_job_access_async(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
job_id: str,
|
||||
actor: PydanticUser,
|
||||
access: List[Literal["read", "write", "delete"]] = ["read"],
|
||||
) -> JobModel:
|
||||
"""
|
||||
Verify that a job exists and the user has the required access asynchronously.
|
||||
|
||||
Args:
|
||||
session: The async database session
|
||||
job_id: The ID of the job to verify
|
||||
actor: The user making the request
|
||||
|
||||
Returns:
|
||||
The job if it exists and the user has access
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the job does not exist or user does not have access
|
||||
"""
|
||||
job_query = select(JobModel).where(JobModel.id == job_id)
|
||||
job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
|
||||
result = await session.execute(job_query)
|
||||
job = result.scalar_one_or_none()
|
||||
if not job:
|
||||
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||
return job
|
||||
|
||||
|
||||
@singleton
|
||||
class NoopStepManager(StepManager):
|
||||
"""
|
||||
Noop implementation of StepManager.
|
||||
Temporarily used for migrations, but allows for different implementations in the future.
|
||||
Will not allow for writes, but will still allow for reads.
|
||||
"""
|
||||
|
||||
@enforce_types
|
||||
def log_step(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
provider_name: str,
|
||||
provider_category: str,
|
||||
model: str,
|
||||
model_endpoint: Optional[str],
|
||||
context_window_limit: int,
|
||||
usage: UsageStatistics,
|
||||
provider_id: Optional[str] = None,
|
||||
job_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> PydanticStep:
|
||||
return
|
||||
|
||||
@enforce_types
|
||||
async def log_step_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
provider_name: str,
|
||||
provider_category: str,
|
||||
model: str,
|
||||
model_endpoint: Optional[str],
|
||||
context_window_limit: int,
|
||||
usage: UsageStatistics,
|
||||
provider_id: Optional[str] = None,
|
||||
job_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> PydanticStep:
|
||||
step_data = {
|
||||
"origin": None,
|
||||
"organization_id": actor.organization_id,
|
||||
"agent_id": agent_id,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_name,
|
||||
"provider_category": provider_category,
|
||||
"model": model,
|
||||
"model_endpoint": model_endpoint,
|
||||
"context_window_limit": context_window_limit,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
"job_id": job_id,
|
||||
"tags": [],
|
||||
"tid": None,
|
||||
"trace_id": get_trace_id(), # Get the current trace ID
|
||||
}
|
||||
if step_id:
|
||||
step_data["id"] = step_id
|
||||
async with db_registry.async_session() as session:
|
||||
if job_id:
|
||||
await self._verify_job_access_async(session, job_id, actor, access=["write"])
|
||||
new_step = StepModel(**step_data)
|
||||
await new_step.create_async(session)
|
||||
return new_step.to_pydantic()
|
||||
|
letta/services/telemetry_manager.py (new file)
@@ -0,0 +1,58 @@
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.utils import enforce_types


class TelemetryManager:
    @enforce_types
    async def get_provider_trace_by_step_id_async(
        self,
        step_id: str,
        actor: PydanticUser,
    ) -> PydanticProviderTrace:
        async with db_registry.async_session() as session:
            provider_trace = await ProviderTraceModel.read_async(db_session=session, step_id=step_id, actor=actor)
            return provider_trace.to_pydantic()

    @enforce_types
    async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        async with db_registry.async_session() as session:
            provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
            if provider_trace_create.request_json:
                request_json_str = json_dumps(provider_trace_create.request_json)
                provider_trace.request_json = json_loads(request_json_str)

            if provider_trace_create.response_json:
                response_json_str = json_dumps(provider_trace_create.response_json)
                provider_trace.response_json = json_loads(response_json_str)
            await provider_trace.create_async(session, actor=actor)
            return provider_trace.to_pydantic()

    @enforce_types
    def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        with db_registry.session() as session:
            provider_trace = ProviderTraceModel(**provider_trace_create.model_dump())
            provider_trace.create(session, actor=actor)
            return provider_trace.to_pydantic()


@singleton
class NoopTelemetryManager(TelemetryManager):
    """
    Noop implementation of TelemetryManager.
    """

    async def create_provider_trace_async(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        return

    async def get_provider_trace_by_step_id_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
        return

    def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace:
        return
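A hedged usage sketch of the new manager, tying it to the provider_traces table added by migration cc8dc340836d. The request_json/response_json fields come from this diff; passing step_id through ProviderTraceCreate and the surrounding actor/step objects are assumptions, not confirmed by the diff:

    from letta.schemas.provider_trace import ProviderTraceCreate
    from letta.services.telemetry_manager import TelemetryManager

    async def record_llm_call(actor, step_id: str, request_json: dict, response_json: dict):
        manager = TelemetryManager()
        # Persist the raw provider request/response pair for this step...
        await manager.create_provider_trace_async(
            actor=actor,
            provider_trace_create=ProviderTraceCreate(
                step_id=step_id,
                request_json=request_json,
                response_json=response_json,
            ),
        )
        # ...and read it back later when debugging the step.
        return await manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=actor)
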
@ -8,9 +8,14 @@ from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_executor.tool_executor import (
|
||||
ExternalComposioToolExecutor,
|
||||
ExternalMCPToolExecutor,
|
||||
LettaBuiltinToolExecutor,
|
||||
LettaCoreToolExecutor,
|
||||
LettaMultiAgentToolExecutor,
|
||||
SandboxToolExecutor,
|
||||
@ -28,15 +33,30 @@ class ToolExecutorFactory:
|
||||
ToolType.LETTA_MEMORY_CORE: LettaCoreToolExecutor,
|
||||
ToolType.LETTA_SLEEPTIME_CORE: LettaCoreToolExecutor,
|
||||
ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
|
||||
ToolType.LETTA_BUILTIN: LettaBuiltinToolExecutor,
|
||||
ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
|
||||
ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
|
||||
def get_executor(
|
||||
cls,
|
||||
tool_type: ToolType,
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
) -> ToolExecutor:
|
||||
"""Get the appropriate executor for the given tool type."""
|
||||
executor_class = cls._executor_map.get(tool_type, SandboxToolExecutor)
|
||||
return executor_class()
|
||||
return executor_class(
|
||||
message_manager=message_manager,
|
||||
agent_manager=agent_manager,
|
||||
block_manager=block_manager,
|
||||
passage_manager=passage_manager,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
class ToolExecutionManager:
|
||||
@ -44,11 +64,19 @@ class ToolExecutionManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
self.block_manager = block_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.agent_state = agent_state
|
||||
self.logger = get_logger(__name__)
|
||||
self.actor = actor
|
||||
@ -68,7 +96,14 @@ class ToolExecutionManager:
|
||||
Tuple containing the function response and sandbox run result (if applicable)
|
||||
"""
|
||||
try:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
executor = ToolExecutorFactory.get_executor(
|
||||
tool.tool_type,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
return executor.execute(
|
||||
function_name,
|
||||
function_args,
|
||||
@ -98,9 +133,18 @@ class ToolExecutionManager:
|
||||
Execute a tool asynchronously and persist any state changes.
|
||||
"""
|
||||
try:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
executor = ToolExecutorFactory.get_executor(
|
||||
tool.tool_type,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
# TODO: Extend this async model to composio
|
||||
if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)):
|
||||
if isinstance(
|
||||
executor, (SandboxToolExecutor, ExternalComposioToolExecutor, LettaBuiltinToolExecutor, LettaMultiAgentToolExecutor)
|
||||
):
|
||||
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
else:
|
||||
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
|
@ -73,6 +73,7 @@ class ToolExecutionSandbox:
|
||||
self.force_recreate = force_recreate
|
||||
self.force_recreate_venv = force_recreate_venv
|
||||
|
||||
@trace_method
|
||||
def run(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
@ -321,6 +322,7 @@ class ToolExecutionSandbox:
|
||||
|
||||
# e2b sandbox specific functions
|
||||
|
||||
@trace_method
|
||||
def run_e2b_sandbox(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
@ -352,10 +354,22 @@ class ToolExecutionSandbox:
|
||||
if additional_env_vars:
|
||||
env_vars.update(additional_env_vars)
|
||||
code = self.generate_execution_script(agent_state=agent_state)
|
||||
log_event(
|
||||
"e2b_execution_started",
|
||||
{"tool": self.tool_name, "sandbox_id": sbx.sandbox_id, "code": code, "env_vars": env_vars},
|
||||
)
|
||||
execution = sbx.run_code(code, envs=env_vars)
|
||||
|
||||
if execution.results:
|
||||
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
|
||||
log_event(
|
||||
"e2b_execution_succeeded",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
elif execution.error:
|
||||
logger.error(f"Executing tool {self.tool_name} raised a {execution.error.name} with message: \n{execution.error.value}")
|
||||
logger.error(f"Traceback from e2b sandbox: \n{execution.error.traceback}")
|
||||
@ -363,7 +377,25 @@ class ToolExecutionSandbox:
|
||||
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
|
||||
)
|
||||
execution.logs.stderr.append(execution.error.traceback)
|
||||
log_event(
|
||||
"e2b_execution_failed",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"error_type": execution.error.name,
|
||||
"error_message": execution.error.value,
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
else:
|
||||
log_event(
|
||||
"e2b_execution_empty",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"status": "no_results_no_error",
|
||||
},
|
||||
)
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return ToolExecutionResult(
|
||||
@ -395,16 +427,31 @@ class ToolExecutionSandbox:
|
||||
|
||||
return None
|
||||
|
||||
@trace_method
|
||||
def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox":
|
||||
from e2b_code_interpreter import Sandbox
|
||||
|
||||
state_hash = sandbox_config.fingerprint()
|
||||
e2b_config = sandbox_config.get_e2b_config()
|
||||
log_event(
|
||||
"e2b_sandbox_create_started",
|
||||
{
|
||||
"sandbox_fingerprint": state_hash,
|
||||
"e2b_config": e2b_config.model_dump(),
|
||||
},
|
||||
)
|
||||
if e2b_config.template:
|
||||
sbx = Sandbox(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash})
|
||||
else:
|
||||
# no template
|
||||
sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"}))
|
||||
log_event(
|
||||
"e2b_sandbox_create_finished",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"sandbox_fingerprint": state_hash,
|
||||
},
|
||||
)
|
||||
|
||||
# install pip requirements
|
||||
if e2b_config.pip_requirements:
|
||||
|
@ -1,35 +1,64 @@
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from textwrap import shorten
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from letta.constants import (
|
||||
COMPOSIO_ENTITY_ENV_VAR_KEY,
|
||||
CORE_MEMORY_LINE_NUMBER_WARNING,
|
||||
READ_ONLY_BLOCK_EDIT_ERROR,
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||||
WEB_SEARCH_CLIP_CONTENT,
|
||||
WEB_SEARCH_INCLUDE_SCORE,
|
||||
WEB_SEARCH_SEPARATOR,
|
||||
)
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import AssistantMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
|
||||
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
|
||||
from letta.settings import tool_settings
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ToolExecutor(ABC):
|
||||
"""Abstract base class for tool executors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
):
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
self.block_manager = block_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.actor = actor
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self,
|
||||
@ -493,17 +522,113 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA multi-agent core tools."""
|
||||
|
||||
# TODO: Implement
|
||||
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> ToolExecutionResult:
|
||||
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
# function_response = callable_func(**function_args)
|
||||
# return ToolExecutionResult(func_return=function_response)
|
||||
async def execute(
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
function_map = {
|
||||
"send_message_to_agent_and_wait_for_reply": self.send_message_to_agent_and_wait_for_reply,
|
||||
"send_message_to_agent_async": self.send_message_to_agent_async,
|
||||
"send_message_to_agents_matching_tags": self.send_message_to_agents_matching_tags,
|
||||
}
|
||||
|
||||
if function_name not in function_map:
|
||||
raise ValueError(f"Unknown function: {function_name}")
|
||||
|
||||
# Execute the appropriate function
|
||||
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
|
||||
function_response = await function_map[function_name](agent_state, **function_args_copy)
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
async def send_message_to_agent_and_wait_for_reply(self, agent_state: AgentState, message: str, other_agent_id: str) -> str:
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
return str(await self._process_agent(agent_id=other_agent_id, message=augmented_message))
|
||||
|
||||
async def send_message_to_agent_async(self, agent_state: AgentState, message: str, other_agent_id: str) -> str:
|
||||
# 1) Build the prefixed system‐message
|
||||
prefixed = (
|
||||
f"[Incoming message from agent with ID '{agent_state.id}' - "
|
||||
f"to reply to this message, make sure to use the "
|
||||
f"'send_message_to_agent_async' tool, or the agent will not receive your message] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
task = asyncio.create_task(self._process_agent(agent_id=other_agent_id, message=prefixed))
|
||||
|
||||
task.add_done_callback(lambda t: (logger.error(f"Async send_message task failed: {t.exception()}") if t.exception() else None))
|
||||
|
||||
return "Successfully sent message"
|
||||
|
||||
async def send_message_to_agents_matching_tags(
|
||||
self, agent_state: AgentState, message: str, match_all: List[str], match_some: List[str]
|
||||
) -> str:
|
||||
# Find matching agents
|
||||
matching_agents = self.agent_manager.list_agents_matching_tags(actor=self.actor, match_all=match_all, match_some=match_some)
|
||||
if not matching_agents:
|
||||
return str([])
|
||||
|
||||
augmented_message = (
|
||||
"[Incoming message from external Letta agent - to reply to this message, "
|
||||
"make sure to use the 'send_message' at the end, and the system will notify "
|
||||
"the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self._process_agent(agent_id=agent_state.id, message=augmented_message)) for agent_state in matching_agents
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return str(results)
|
||||
|
||||
async def _process_agent(self, agent_id: str, message: str) -> Dict[str, Any]:
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
|
||||
try:
|
||||
letta_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
letta_response = await letta_agent.step([MessageCreate(role=MessageRole.system, content=[TextContent(text=message)])])
|
||||
messages = letta_response.messages
|
||||
|
||||
send_message_content = [message.content for message in messages if isinstance(message, AssistantMessage)]
|
||||
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"response": send_message_content if send_message_content else ["<no response>"],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"error": str(e),
|
||||
"type": type(e).__name__,
|
||||
}
|
||||
|
||||
|
||||
class ExternalComposioToolExecutor(ToolExecutor):
|
||||
"""Executor for external Composio tools."""
|
||||
|
||||
@trace_method
|
||||
async def execute(
|
||||
self,
|
||||
function_name: str,
|
||||
@ -595,6 +720,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
class SandboxToolExecutor(ToolExecutor):
|
||||
"""Executor for sandboxed tools."""
|
||||
|
||||
@trace_method
|
||||
async def execute(
|
||||
self,
|
||||
function_name: str,
|
||||
@ -674,3 +800,106 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
func_return=error_message,
|
||||
stderr=[stderr],
|
||||
)
|
||||
|
||||
|
||||
class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
"""Executor for built in Letta tools."""
|
||||
|
||||
@trace_method
|
||||
async def execute(
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
function_map = {"run_code": self.run_code, "web_search": self.web_search}
|
||||
|
||||
if function_name not in function_map:
|
||||
raise ValueError(f"Unknown function: {function_name}")
|
||||
|
||||
# Execute the appropriate function
|
||||
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
|
||||
function_response = await function_map[function_name](**function_args_copy)
|
||||
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
async def run_code(self, code: str, language: Literal["python", "js", "ts", "r", "java"]) -> str:
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
|
||||
if tool_settings.e2b_api_key is None:
|
||||
raise ValueError("E2B_API_KEY is not set")
|
||||
|
||||
sbx = await AsyncSandbox.create(api_key=tool_settings.e2b_api_key)
|
||||
params = {"code": code}
|
||||
if language != "python":
|
||||
# Leave empty for python
|
||||
params["language"] = language
|
||||
|
||||
res = self._llm_friendly_result(await sbx.run_code(**params))
|
||||
return json.dumps(res, ensure_ascii=False)
|
||||
|
||||
def _llm_friendly_result(self, res):
|
||||
out = {
|
||||
"results": [r.text if hasattr(r, "text") else str(r) for r in res.results],
|
||||
"logs": {
|
||||
"stdout": getattr(res.logs, "stdout", []),
|
||||
"stderr": getattr(res.logs, "stderr", []),
|
||||
},
|
||||
}
|
||||
err = getattr(res, "error", None)
|
||||
if err is not None:
|
||||
out["error"] = err
|
||||
return out
|
||||
|
||||
async def web_search(agent_state: "AgentState", query: str) -> str:
|
||||
"""
|
||||
Search the web for information.
|
||||
Args:
|
||||
query (str): The query to search the web for.
|
||||
Returns:
|
||||
str: The search results.
|
||||
"""
|
||||
|
||||
try:
|
||||
from tavily import AsyncTavilyClient
|
||||
except ImportError:
|
||||
raise ImportError("tavily is not installed in the tool execution environment")
|
||||
|
||||
# Check if the API key exists
|
||||
if tool_settings.tavily_api_key is None:
|
||||
raise ValueError("TAVILY_API_KEY is not set")
|
||||
|
||||
# Instantiate client and search
|
||||
tavily_client = AsyncTavilyClient(api_key=tool_settings.tavily_api_key)
|
||||
search_results = await tavily_client.search(query=query, auto_parameters=True)
|
||||
|
||||
results = search_results.get("results", [])
|
||||
if not results:
|
||||
return "No search results found."
|
||||
|
||||
# ---- format for the LLM -------------------------------------------------
|
||||
formatted_blocks = []
|
||||
for idx, item in enumerate(results, start=1):
|
||||
title = item.get("title") or "Untitled"
|
||||
url = item.get("url") or "Unknown URL"
|
||||
# keep each content snippet reasonably short so you don’t blow up context
|
||||
content = (
|
||||
shorten(item.get("content", "").strip(), width=600, placeholder=" …")
|
||||
if WEB_SEARCH_CLIP_CONTENT
|
||||
else item.get("content", "").strip()
|
||||
)
|
||||
score = item.get("score")
|
||||
if WEB_SEARCH_INCLUDE_SCORE:
|
||||
block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Relevance score: {score:.4f}\n" f"Content: {content}\n"
|
||||
else:
|
||||
block = f"\nRESULT {idx}:\n" f"Title: {title}\n" f"URL: {url}\n" f"Content: {content}\n"
|
||||
formatted_blocks.append(block)
|
||||
|
||||
return WEB_SEARCH_SEPARATOR.join(formatted_blocks)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
@ -9,6 +10,7 @@ from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
|
||||
BASE_VOICE_SLEEPTIME_TOOLS,
|
||||
BUILTIN_TOOLS,
|
||||
LETTA_TOOL_SET,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
MULTI_AGENT_TOOLS,
|
||||
@ -59,6 +61,32 @@ class ToolManager:
|
||||
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool_id:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_tool.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
# If there's anything to update
|
||||
if update_data:
|
||||
# In case we want to update the tool type
|
||||
# Useful if we are shuffling around base tools
|
||||
updated_tool_type = None
|
||||
if "tool_type" in update_data:
|
||||
updated_tool_type = update_data.get("tool_type")
|
||||
tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
||||
)
|
||||
tool = await self.get_tool_by_id_async(tool_id, actor=actor)
|
||||
else:
|
||||
tool = await self.create_tool_async(pydantic_tool, actor=actor)
|
||||
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
|
||||
@ -96,6 +124,21 @@ class ToolManager:
|
||||
tool.create(session, actor=actor) # Re-raise other database-related errors
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_tool.organization_id = actor.organization_id
|
||||
# Auto-generate description if not provided
|
||||
if pydantic_tool.description is None:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump(to_orm=True)
|
||||
|
||||
tool = ToolModel(**tool_data)
|
||||
await tool.create_async(session, actor=actor) # Re-raise other database-related errors
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Fetch a tool by its ID."""
|
||||
@ -105,6 +148,15 @@ class ToolManager:
|
||||
# Convert the SQLAlchemy Tool object to PydanticTool
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Fetch a tool by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Retrieve tool by id using the Tool model's read method
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
# Convert the SQLAlchemy Tool object to PydanticTool
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]:
|
||||
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
|
||||
@ -135,6 +187,16 @@ class ToolManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
|
||||
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
tool = await ToolModel.read_async(db_session=session, name=tool_name, actor=actor)
|
||||
return tool.id
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
|
||||
"""List all tools with optional pagination."""
|
||||
@ -204,6 +266,35 @@ class ToolManager:
|
||||
# Save the updated tool to the database
|
||||
return tool.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def update_tool_by_id_async(
|
||||
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
|
||||
) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch the tool by ID
|
||||
tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = tool_update.model_dump(to_orm=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(tool, key, value)
|
||||
|
||||
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)
|
||||
|
||||
tool.json_schema = new_schema
|
||||
tool.name = new_schema["name"]
|
||||
|
||||
if updated_tool_type:
|
||||
tool.tool_type = updated_tool_type
|
||||
|
||||
# Save the updated tool to the database
|
||||
tool = await tool.update_async(db_session=session, actor=actor)
|
||||
return tool.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a tool by its ID."""
|
||||
@ -218,7 +309,7 @@ class ToolManager:
|
||||
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""Add default tools in base.py and multi_agent.py"""
|
||||
functions_to_schema = {}
|
||||
module_names = ["base", "multi_agent", "voice"]
|
||||
module_names = ["base", "multi_agent", "voice", "builtin"]
|
||||
|
||||
for module_name in module_names:
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
@ -254,6 +345,9 @@ class ToolManager:
|
||||
elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
|
||||
tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BUILTIN_TOOLS:
|
||||
tool_type = ToolType.LETTA_BUILTIN
|
||||
tags = [tool_type.value]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}"
|
||||
@ -275,3 +369,68 @@ class ToolManager:
|
||||
|
||||
# TODO: Delete any base tools that are stale
|
||||
return tools
|
||||
|
||||
@enforce_types
|
||||
async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""Add default tools in base.py and multi_agent.py"""
|
||||
functions_to_schema = {}
|
||||
module_names = ["base", "multi_agent", "voice", "builtin"]
|
||||
|
||||
for module_name in module_names:
|
||||
full_module_name = f"letta.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
raise e
|
||||
|
||||
try:
|
||||
# Load the function set
|
||||
functions_to_schema.update(load_function_set(module))
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
warnings.warn(err)
|
||||
|
||||
# create tool in db
|
||||
tools = []
|
||||
for name, schema in functions_to_schema.items():
|
||||
if name in LETTA_TOOL_SET:
|
||||
if name in BASE_TOOLS:
|
||||
tool_type = ToolType.LETTA_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_MEMORY_TOOLS:
|
||||
tool_type = ToolType.LETTA_MEMORY_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in MULTI_AGENT_TOOLS:
|
||||
tool_type = ToolType.LETTA_MULTI_AGENT_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_SLEEPTIME_TOOLS:
|
||||
tool_type = ToolType.LETTA_SLEEPTIME_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BASE_VOICE_SLEEPTIME_TOOLS or name in BASE_VOICE_SLEEPTIME_CHAT_TOOLS:
|
||||
tool_type = ToolType.LETTA_VOICE_SLEEPTIME_CORE
|
||||
tags = [tool_type.value]
|
||||
elif name in BUILTIN_TOOLS:
|
||||
tool_type = ToolType.LETTA_BUILTIN
|
||||
tags = [tool_type.value]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS + BASE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_TOOLS + BASE_VOICE_SLEEPTIME_CHAT_TOOLS}"
|
||||
)
|
||||
|
||||
# create to tool
|
||||
tools.append(
|
||||
self.create_or_update_tool_async(
|
||||
PydanticTool(
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
tool_type=tool_type,
|
||||
return_char_limit=BASE_FUNCTION_RETURN_CHAR_LIMIT,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Delete any base tools that are stale
|
||||
return await asyncio.gather(*tools)
|
||||
|
@ -6,6 +6,7 @@ from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.tracing import log_event, trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -27,6 +28,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
|
||||
self.force_recreate = force_recreate
|
||||
|
||||
@trace_method
|
||||
async def run(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
@ -44,6 +46,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
|
||||
return result
|
||||
|
||||
@trace_method
|
||||
async def run_e2b_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> ToolExecutionResult:
|
||||
@ -81,10 +84,21 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
env_vars.update(additional_env_vars)
|
||||
code = self.generate_execution_script(agent_state=agent_state)
|
||||
|
||||
log_event(
|
||||
"e2b_execution_started",
|
||||
{"tool": self.tool_name, "sandbox_id": e2b_sandbox.sandbox_id, "code": code, "env_vars": env_vars},
|
||||
)
|
||||
execution = await e2b_sandbox.run_code(code, envs=env_vars)
|
||||
|
||||
if execution.results:
|
||||
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
|
||||
log_event(
|
||||
"e2b_execution_succeeded",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"sandbox_id": e2b_sandbox.sandbox_id,
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
elif execution.error:
|
||||
logger.error(f"Executing tool {self.tool_name} raised a {execution.error.name} with message: \n{execution.error.value}")
|
||||
logger.error(f"Traceback from e2b sandbox: \n{execution.error.traceback}")
|
||||
@ -92,7 +106,25 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
|
||||
)
|
||||
execution.logs.stderr.append(execution.error.traceback)
|
||||
log_event(
|
||||
"e2b_execution_failed",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"sandbox_id": e2b_sandbox.sandbox_id,
|
||||
"error_type": execution.error.name,
|
||||
"error_message": execution.error.value,
|
||||
"func_return": func_return,
|
||||
},
|
||||
)
|
||||
else:
|
||||
log_event(
|
||||
"e2b_execution_empty",
|
||||
{
|
||||
"tool": self.tool_name,
|
||||
"sandbox_id": e2b_sandbox.sandbox_id,
|
||||
"status": "no_results_no_error",
|
||||
},
|
||||
)
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return ToolExecutionResult(
|
||||
@ -110,24 +142,54 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
exception_class = builtins_dict.get(e2b_execution.error.name, Exception)
|
||||
return exception_class(e2b_execution.error.value)
|
||||
|
||||
@trace_method
|
||||
async def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox":
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
|
||||
state_hash = sandbox_config.fingerprint()
|
||||
e2b_config = sandbox_config.get_e2b_config()
|
||||
|
||||
log_event(
|
||||
"e2b_sandbox_create_started",
|
||||
{
|
||||
"sandbox_fingerprint": state_hash,
|
||||
"e2b_config": e2b_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
if e2b_config.template:
|
||||
sbx = await AsyncSandbox.create(sandbox_config.get_e2b_config().template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash})
|
||||
else:
|
||||
# no template
|
||||
sbx = await AsyncSandbox.create(
|
||||
metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"})
|
||||
)
|
||||
|
||||
# install pip requirements
|
||||
log_event(
|
||||
"e2b_sandbox_create_finished",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"sandbox_fingerprint": state_hash,
|
||||
},
|
||||
)
|
||||
|
||||
if e2b_config.pip_requirements:
|
||||
for package in e2b_config.pip_requirements:
|
||||
log_event(
|
||||
"e2b_pip_install_started",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
await sbx.commands.run(f"pip install {package}")
|
||||
log_event(
|
||||
"e2b_pip_install_finished",
|
||||
{
|
||||
"sandbox_id": sbx.sandbox_id,
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
|
||||
return sbx
|
||||
|
||||
async def list_running_e2b_sandboxes(self):
|
||||
|
@@ -15,6 +15,9 @@ class ToolSettings(BaseSettings):
    e2b_api_key: Optional[str] = None
    e2b_sandbox_template_id: Optional[str] = None  # Updated manually

    # Tavily search
    tavily_api_key: Optional[str] = None

    # Local Sandbox configurations
    tool_exec_dir: Optional[str] = None
    tool_sandbox_timeout: float = 180
@@ -95,6 +98,7 @@ class ModelSettings(BaseSettings):

    # anthropic
    anthropic_api_key: Optional[str] = None
    anthropic_max_retries: int = 3

    # ollama
    ollama_base_url: Optional[str] = None
@@ -175,11 +179,14 @@ class Settings(BaseSettings):
    pg_host: Optional[str] = None
    pg_port: Optional[int] = None
    pg_uri: Optional[str] = default_pg_uri  # option to specify full uri
    pg_pool_size: int = 80  # Concurrent connections
    pg_max_overflow: int = 30  # Overflow limit
    pg_pool_size: int = 25  # Concurrent connections
    pg_max_overflow: int = 10  # Overflow limit
    pg_pool_timeout: int = 30  # Seconds to wait for a connection
    pg_pool_recycle: int = 1800  # When to recycle connections
    pg_echo: bool = False  # Logging
    pool_pre_ping: bool = True  # Pre ping to check for dead connections
    pool_use_lifo: bool = True
    disable_sqlalchemy_pooling: bool = False

    # multi agent settings
    multi_agent_send_message_max_retries: int = 3
@@ -190,6 +197,7 @@ class Settings(BaseSettings):
    verbose_telemetry_logging: bool = False
    otel_exporter_otlp_endpoint: Optional[str] = None  # otel default: "http://localhost:4317"
    disable_tracing: bool = False
    llm_api_logging: bool = True

    # uvicorn settings
    uvicorn_workers: int = 1

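The new tavily_api_key setting follows the same shape as e2b_api_key: an optional field that the built-in tools check only at call time (web_search raises if it is unset). A small sketch of that guard, assuming the usual pydantic-settings environment mapping; the exact environment variable name depends on ToolSettings' configuration and is not shown in this diff:

    # Hypothetical shell setup (names assumed):
    #   export TAVILY_API_KEY="tvly-..."
    #   export E2B_API_KEY="e2b_..."
    from letta.settings import tool_settings

    if tool_settings.tavily_api_key is None:
        # Mirrors the check performed inside LettaBuiltinToolExecutor.web_search
        raise ValueError("TAVILY_API_KEY is not set")
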
@@ -19,11 +19,11 @@ from opentelemetry.trace import Status, StatusCode
tracer = trace.get_tracer(__name__)
_is_tracing_initialized = False
_excluded_v1_endpoints_regex: List[str] = [
    "^GET /v1/agents/(?P<agent_id>[^/]+)/messages$",
    "^GET /v1/agents/(?P<agent_id>[^/]+)/context$",
    "^GET /v1/agents/(?P<agent_id>[^/]+)/archival-memory$",
    "^GET /v1/agents/(?P<agent_id>[^/]+)/sources$",
    r"^POST /v1/voice-beta/.*/chat/completions$",
    # "^GET /v1/agents/(?P<agent_id>[^/]+)/messages$",
    # "^GET /v1/agents/(?P<agent_id>[^/]+)/context$",
    # "^GET /v1/agents/(?P<agent_id>[^/]+)/archival-memory$",
    # "^GET /v1/agents/(?P<agent_id>[^/]+)/sources$",
    # r"^POST /v1/voice-beta/.*/chat/completions$",
]

poetry.lock (generated)
@@ -2123,15 +2123,15 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"]

[[package]]
name = "google-genai"
version = "1.10.0"
version = "1.15.0"
description = "GenAI Python SDK"
optional = true
python-versions = ">=3.9"
groups = ["main"]
markers = "extra == \"google\""
files = [
    {file = "google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f"},
    {file = "google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f"},
    {file = "google_genai-1.15.0-py3-none-any.whl", hash = "sha256:6d7f149cc735038b680722bed495004720514c234e2a445ab2f27967955071dd"},
    {file = "google_genai-1.15.0.tar.gz", hash = "sha256:118bb26960d6343cd64f1aeb5c2b02144a36ad06716d0d1eb1fa3e0904db51f1"},
]

[package.dependencies]
@@ -6658,6 +6658,23 @@ files = [
    {file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"},
]

[[package]]
name = "tavily-python"
version = "0.7.2"
description = "Python wrapper for the Tavily API"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
    {file = "tavily_python-0.7.2-py3-none-any.whl", hash = "sha256:0d7cc8b1a2f95ac10cf722094c3b5807aade67cc7750f7ca605edef7455d4c62"},
    {file = "tavily_python-0.7.2.tar.gz", hash = "sha256:34f713002887df2b5e6b8d7db7bc64ae107395bdb5f53611e80a89dac9cbdf19"},
]

[package.dependencies]
httpx = "*"
requests = "*"
tiktoken = ">=0.5.1"

[[package]]
name = "tenacity"
version = "9.1.2"
@@ -7570,4 +7587,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.1"
python-versions = "<3.14,>=3.10"
content-hash = "19eee9b3cd3d270cb748183bc332dd69706bb0bd3150c62e73e61ed437a40c78"
content-hash = "837f6a25033a01cca117f4c61bcf973bc6ccfcda442615bbf4af038061bf88ce"

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.20"
version = "0.7.21"
packages = [
    {include = "letta"},
]
@@ -79,7 +79,7 @@ opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0"
opentelemetry-instrumentation-requests = "0.51b0"
opentelemetry-exporter-otlp = "1.30.0"
google-genai = {version = "^1.1.0", optional = true}
google-genai = {version = "^1.15.0", optional = true}
faker = "^36.1.0"
colorama = "^0.4.6"
marshmallow-sqlalchemy = "^1.4.1"
@@ -91,6 +91,7 @@ apscheduler = "^3.11.0"
aiomultiprocess = "^0.9.1"
matplotlib = "^3.10.1"
asyncpg = "^0.30.0"
tavily-python = "^0.7.2"


[tool.poetry.extras]

@@ -1,5 +1,5 @@
{
    "model": "gemini-2.5-pro-exp-03-25",
    "model": "gemini-2.5-pro-preview-05-06",
    "model_endpoint_type": "google_vertex",
    "model_endpoint": "https://us-central1-aiplatform.googleapis.com/v1/projects/memgpt-428419/locations/us-central1",
    "context_window": 1048576,
@@ -0,0 +1,7 @@
{
    "context_window": 16000,
    "model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
    "model_endpoint_type": "together",
    "model_endpoint": "https://api.together.ai/v1",
    "model_wrapper": "chatml"
}
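A quick sketch of how this new Together/Qwen config can be loaded in tests, mirroring the get_llm_config helper added later in this commit; the directory and filename follow the test layout, and the snippet is illustrative rather than a required API:

# Illustrative: load the new config JSON into an LLMConfig, as the tests below do.
import json
import os

from letta.schemas.llm_config import LLMConfig

config_path = os.path.join("tests/configs/llm_model_configs", "together-qwen-2.5-72b-instruct.json")
with open(config_path, "r") as f:
    llm_config = LLMConfig(**json.load(f))

assert llm_config.model_endpoint_type == "together"
assert llm_config.context_window == 16000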
@@ -63,6 +63,7 @@ def check_composio_key_set():
    yield


# --- Tool Fixtures ---
@pytest.fixture
def weather_tool_func():
    def get_weather(location: str) -> str:

@@ -110,6 +111,23 @@ def print_tool_func():
    yield print_tool


@pytest.fixture
def roll_dice_tool_func():
    def roll_dice():
        """
        Rolls a 6 sided die.

        Returns:
            str: The roll result.
        """
        import time

        time.sleep(1)
        return "Rolled a 10!"

    yield roll_dice


@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    return BetaMessageBatch(
@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock

import pytest
from anthropic.types import BetaErrorResponse, BetaRateLimitError
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import (
    BetaMessageBatch,
    BetaMessageBatchErroredResult,

@@ -53,7 +53,7 @@ def _run_server():
    start_server(debug=True)


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def server_url():
    """Ensures a server is running and returns its base URL."""
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
@ -255,148 +255,7 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
|
||||
# -----------------------------
|
||||
# End-to-End Test
|
||||
# -----------------------------
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_polling_simple_real_batch(default_user, server):
|
||||
# --- Step 1: Prepare test data ---
|
||||
# Create batch responses with different statuses
|
||||
# NOTE: This is a REAL batch id!
|
||||
# For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ
|
||||
batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended")
|
||||
|
||||
# Create test agents
|
||||
agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0")
|
||||
agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d")
|
||||
agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56")
|
||||
|
||||
# --- Step 2: Create batch jobs ---
|
||||
job_a = await create_test_llm_batch_job_async(server, batch_a_resp, default_user)
|
||||
|
||||
# --- Step 3: Create batch items ---
|
||||
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
|
||||
item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user)
|
||||
item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user)
|
||||
|
||||
print("HI")
|
||||
print(agent_a.id)
|
||||
print(agent_b.id)
|
||||
print(agent_c.id)
|
||||
print("BYE")
|
||||
|
||||
# --- Step 4: Run the polling job ---
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# --- Step 5: Verify batch job status updates ---
|
||||
updated_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user)
|
||||
|
||||
assert updated_job_a.status == JobStatus.completed
|
||||
|
||||
# Both jobs should have been polled
|
||||
assert updated_job_a.last_polled_at is not None
|
||||
assert updated_job_a.latest_polling_response is not None
|
||||
|
||||
# --- Step 7: Verify batch item status updates ---
|
||||
# Item A should be marked as completed with a successful result
|
||||
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
|
||||
assert updated_item_a.request_status == JobStatus.completed
|
||||
assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse(
|
||||
custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0",
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
message=BetaMessage(
|
||||
id="msg_01T1iSejDS5qENRqqEZauMHy",
|
||||
content=[
|
||||
BetaToolUseBlock(
|
||||
id="toolu_01GKUYVWcajjTaE1stxZZHcG",
|
||||
input={
|
||||
"inner_thoughts": "First login detected. Time to make a great first impression!",
|
||||
"message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?",
|
||||
"request_heartbeat": False,
|
||||
},
|
||||
name="send_message",
|
||||
type="tool_use",
|
||||
)
|
||||
],
|
||||
model="claude-3-5-haiku-20241022",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
type="message",
|
||||
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94),
|
||||
),
|
||||
type="succeeded",
|
||||
),
|
||||
)
|
||||
|
||||
# Item B should be marked as completed with a successful result
|
||||
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
|
||||
assert updated_item_b.request_status == JobStatus.completed
|
||||
assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse(
|
||||
custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d",
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
message=BetaMessage(
|
||||
id="msg_01N2ZfxpbjdoeofpufUFPCMS",
|
||||
content=[
|
||||
BetaTextBlock(
|
||||
citations=None, text="<thinking>User first login detected. Initializing persona.</thinking>", type="text"
|
||||
),
|
||||
BetaToolUseBlock(
|
||||
id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C",
|
||||
input={
|
||||
"label": "persona",
|
||||
"content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.",
|
||||
"request_heartbeat": True,
|
||||
},
|
||||
name="core_memory_append",
|
||||
type="tool_use",
|
||||
),
|
||||
],
|
||||
model="claude-3-opus-20240229",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
type="message",
|
||||
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153),
|
||||
),
|
||||
type="succeeded",
|
||||
),
|
||||
)
|
||||
|
||||
# Item C should be marked as failed with an error result
|
||||
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
|
||||
assert updated_item_c.request_status == JobStatus.completed
|
||||
assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse(
|
||||
custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56",
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
message=BetaMessage(
|
||||
id="msg_01RL2g4aBgbZPeaMEokm6HZm",
|
||||
content=[
|
||||
BetaTextBlock(
|
||||
citations=None,
|
||||
text="First time meeting this user. I should introduce myself and establish a friendly connection.</thinking>",
|
||||
type="text",
|
||||
),
|
||||
BetaToolUseBlock(
|
||||
id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ",
|
||||
input={
|
||||
"message": "Hey there! I'm Letta. Really nice to meet you! I love getting to know new people - what brings you here today?",
|
||||
"request_heartbeat": False,
|
||||
},
|
||||
name="send_message",
|
||||
type="tool_use",
|
||||
),
|
||||
],
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
type="message",
|
||||
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111),
|
||||
),
|
||||
type="succeeded",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_polling_mixed_batch_jobs(default_user, server):
|
||||
"""
|
||||
End-to-end test for polling batch jobs with mixed statuses and idempotency.
|
||||
|
tests/integration_test_builtin_tools.py (new file, 206 lines)
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate
|
||||
from letta_client.types import ToolReturnMessage
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import settings
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it’s accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
temp = settings.use_experimental
|
||||
settings.use_experimental = True
|
||||
yield url
|
||||
settings.use_experimental = temp
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_state(client: Letta) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
"""
|
||||
client.tools.upsert_base_tools()
|
||||
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
run_code_tool = client.tools.list(name="run_code")[0]
|
||||
web_search_tool = client.tools.list(name="web_search")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="supervisor",
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id, run_code_tool.id, web_search_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
|
||||
filename = os.path.join(llm_config_dir, filename)
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
return llm_config
|
||||
|
||||
|
||||
USER_MESSAGE_OTID = str(uuid.uuid4())
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
]
|
||||
requested = os.getenv("LLM_CONFIG_FILE")
|
||||
filenames = [requested] if requested else all_configs
|
||||
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
TEST_LANGUAGES = ["Python", "Javascript", "Typescript"]
|
||||
EXPECTED_INTEGER_PARTITION_OUTPUT = "190569292"
|
||||
|
||||
|
||||
# Reference implementation in Python, to embed in the user prompt
|
||||
REFERENCE_CODE = """\
|
||||
def reference_partition(n):
|
||||
partitions = [1] + [0] * (n + 1)
|
||||
for k in range(1, n + 1):
|
||||
for i in range(k, n + 1):
|
||||
partitions[i] += partitions[i - k]
|
||||
return partitions[n]
|
||||
"""
|
||||
|
||||
|
||||
def reference_partition(n: int) -> int:
|
||||
# Same logic, used to compute expected result in the test
|
||||
partitions = [1] + [0] * (n + 1)
|
||||
for k in range(1, n + 1):
|
||||
for i in range(k, n + 1):
|
||||
partitions[i] += partitions[i - k]
|
||||
return partitions[n]
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("language", TEST_LANGUAGES, ids=TEST_LANGUAGES)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
|
||||
def test_run_code(
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
language: str,
|
||||
) -> None:
|
||||
"""
|
||||
Sends a reference Python implementation, asks the model to translate & run it
|
||||
in different languages, and verifies the exact partition(100) result.
|
||||
"""
|
||||
expected = str(reference_partition(100))
|
||||
|
||||
user_message = MessageCreate(
|
||||
role="user",
|
||||
content=(
|
||||
"Here is a Python reference implementation:\n\n"
|
||||
f"{REFERENCE_CODE}\n"
|
||||
f"Please translate and execute this code in {language} to compute p(100), "
|
||||
"and return **only** the result with no extra formatting."
|
||||
),
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[user_message],
|
||||
)
|
||||
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert tool_returns, f"No ToolReturnMessage found for language: {language}"
|
||||
|
||||
returns = [m.tool_return for m in tool_returns]
|
||||
assert any(expected in ret for ret in returns), (
|
||||
f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
|
||||
def test_web_search(
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
user_message = MessageCreate(
|
||||
role="user",
|
||||
content=("Use the web search tool to find the latest news about San Francisco."),
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[user_message],
|
||||
)
|
||||
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert tool_returns, "No ToolReturnMessage found"
|
||||
|
||||
returns = [m.tool_return for m in tool_returns]
|
||||
expected = "RESULT 1:"
|
||||
assert any(expected in ret for ret in returns), f"Expected to find '{expected}' in tool_return, " f"but got {returns!r}"
|
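The test_run_code test above hinges on the integer-partition count p(100); a small standalone check of the reference implementation (copied from the new test file) confirms the expected value 190569292 that the assertions compare against:

# Sanity check for the partition reference used by integration_test_builtin_tools.py.
def reference_partition(n: int) -> int:
    partitions = [1] + [0] * (n + 1)
    for k in range(1, n + 1):
        for i in range(k, n + 1):
            partitions[i] += partitions[i - k]
    return partitions[n]


# p(100) is the value the run_code test expects the agent's tool output to contain.
assert str(reference_partition(100)) == "190569292"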
@@ -67,9 +67,14 @@ async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_
        actor=default_user,
    )

    tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool(
        function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
    )
    tool_execution_result = await ToolExecutionManager(
        message_manager=server.message_manager,
        agent_manager=server.agent_manager,
        block_manager=server.block_manager,
        passage_manager=server.passage_manager,
        agent_state=agent_state,
        actor=default_user,
    ).execute_tool(function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis)

    # Small check, it should return something at least
    assert len(tool_execution_result.func_return.keys()) > 10
@ -1,56 +1,120 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
|
||||
from letta import LocalClient, create_client
|
||||
from letta.config import LettaConfig
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.settings import settings
|
||||
from tests.helpers.utils import retry_until_success
|
||||
from tests.utils import wait_for_incoming_message
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client():
|
||||
client = create_client()
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it’s accepting connections.
|
||||
"""
|
||||
|
||||
yield client
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
temp = settings.use_experimental
|
||||
settings.use_experimental = True
|
||||
yield url
|
||||
settings.use_experimental = temp
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
client_instance.tools.upsert_base_tools()
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def remove_stale_agents(client):
|
||||
stale_agents = AgentManager().list_agents(actor=client.user, limit=300)
|
||||
stale_agents = client.agents.list(limit=300)
|
||||
for agent in stale_agents:
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_obj(client: LocalClient):
|
||||
def agent_obj(client):
|
||||
"""Create a test agent that we can call functions on"""
|
||||
send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply")
|
||||
agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id])
|
||||
send_message_to_agent_tool = client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
include_base_tools=True,
|
||||
tool_ids=[send_message_to_agent_tool.id],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
yield agent_state_instance
|
||||
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
yield agent_obj
|
||||
|
||||
# client.delete_agent(agent_obj.agent_state.id)
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def other_agent_obj(client: LocalClient):
|
||||
def other_agent_obj(client):
|
||||
"""Create another test agent that we can call functions on"""
|
||||
agent_state = client.create_agent(include_multi_agent_tools=False)
|
||||
agent_state_instance = client.agents.create(
|
||||
include_base_tools=True,
|
||||
include_multi_agent_tools=False,
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
other_agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
yield other_agent_obj
|
||||
yield agent_state_instance
|
||||
|
||||
client.delete_agent(other_agent_obj.agent_state.id)
|
||||
client.agents.delete(agent_state_instance.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -77,48 +141,68 @@ def roll_dice_tool(client):
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user)
|
||||
tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
def test_send_message_to_agent(client, server, agent_obj, other_agent_obj):
|
||||
secret_word = "banana"
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
|
||||
# Encourage the agent to send a message to the other agent_obj with the secret string
|
||||
client.send_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
role="user",
|
||||
message=f"Use your tool to send a message to another agent with id {other_agent_obj.agent_state.id} to share the secret word: {secret_word}!",
|
||||
client.agents.messages.create(
|
||||
agent_id=agent_obj.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Use your tool to send a message to another agent with id {other_agent_obj.id} to share the secret word: {secret_word}!",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Conversation search the other agent
|
||||
messages = client.get_messages(other_agent_obj.agent_state.id)
|
||||
messages = server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
agent_id=other_agent_obj.id,
|
||||
reverse=True,
|
||||
return_message_object=False,
|
||||
)
|
||||
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {other_agent_obj.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
print(f"\n\n {other_agent_obj.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Search the sender agent for the response from another agent
|
||||
in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user)
|
||||
in_context_messages = AgentManager().get_in_context_messages(agent_id=agent_obj.id, actor=actor)
|
||||
found = False
|
||||
target_snippet = f"{other_agent_obj.agent_state.id} said:"
|
||||
target_snippet = f"'agent_id': '{other_agent_obj.id}', 'response': ["
|
||||
|
||||
for m in in_context_messages:
|
||||
if target_snippet in m.content[0].text:
|
||||
found = True
|
||||
break
|
||||
|
||||
print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.content[0].text for m in in_context_messages[1:]])}")
|
||||
joined = "\n".join([m.content[0].text for m in in_context_messages[1:]])
|
||||
print(f"In context messages of the sender agent (without system):\n\n{joined}")
|
||||
if not found:
|
||||
raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="So what did the other agent say?")
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_obj.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "So what did the other agent say?",
|
||||
}
|
||||
],
|
||||
)
|
||||
print(response.messages)
|
||||
|
||||
|
||||
@ -127,39 +211,50 @@ def test_send_message_to_agents_with_tags_simple(client):
|
||||
worker_tags_123 = ["worker", "user-123"]
|
||||
worker_tags_456 = ["worker", "user-456"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(
|
||||
client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True
|
||||
)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
secret_word = "banana"
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(name="manager_agent", tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
|
||||
manager_agent_state = client.agents.create(
|
||||
name="manager_agent",
|
||||
tool_ids=[send_message_to_agents_matching_tags_tool_id],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
# Create 3 non-matching worker agents (These should NOT get the message)
|
||||
worker_agents_123 = []
|
||||
for idx in range(2):
|
||||
worker_agent_state = client.create_agent(name=f"not_worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_123)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents_123.append(worker_agent)
|
||||
worker_agent_state = client.agents.create(
|
||||
name=f"not_worker_{idx}",
|
||||
include_multi_agent_tools=False,
|
||||
tags=worker_tags_123,
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
worker_agents_123.append(worker_agent_state)
|
||||
|
||||
# Create 3 worker agents that should get the message
|
||||
worker_agents_456 = []
|
||||
for idx in range(2):
|
||||
worker_agent_state = client.create_agent(name=f"worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_456)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents_456.append(worker_agent)
|
||||
worker_agent_state = client.agents.create(
|
||||
name=f"worker_{idx}",
|
||||
include_multi_agent_tools=False,
|
||||
tags=worker_tags_456,
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
worker_agents_456.append(worker_agent_state)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
|
||||
response = client.agents.messages.create(
|
||||
agent_id=manager_agent_state.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
@ -172,62 +267,70 @@ def test_send_message_to_agents_with_tags_simple(client):
|
||||
break
|
||||
|
||||
# Conversation search the worker agents
|
||||
for agent in worker_agents_456:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
for agent_state in worker_agents_456:
|
||||
messages = client.agents.messages.list(agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Ensure it's NOT in the non matching worker agents
|
||||
for agent in worker_agents_123:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
for agent_state in worker_agents_123:
|
||||
messages = client.agents.messages.list(agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
print(f"\n\n {agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word not in m.content
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
response = client.agents.messages.create(
|
||||
agent_id=manager_agent_state.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "So what did the other agent say?",
|
||||
}
|
||||
],
|
||||
)
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents_456 + worker_agents_123:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
|
||||
worker_tags = ["dice-rollers"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
send_message_to_agents_matching_tags_tool_id = client.tools.list(name="send_message_to_agents_matching_tags")[0].id
|
||||
manager_agent_state = client.agents.create(
|
||||
tool_ids=[send_message_to_agents_matching_tags_tool_id],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
|
||||
# Create 3 worker agents
|
||||
worker_agents = []
|
||||
worker_tags = ["dice-rollers"]
|
||||
for _ in range(2):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id])
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
worker_agent_state = client.agents.create(
|
||||
include_multi_agent_tools=False,
|
||||
tags=worker_tags,
|
||||
tool_ids=[roll_dice_tool.id],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
worker_agents.append(worker_agent_state)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=broadcast_message,
|
||||
response = client.agents.messages.create(
|
||||
agent_id=manager_agent_state.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": broadcast_message,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
@ -240,47 +343,65 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
|
||||
break
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
response = client.agents.messages.create(
|
||||
agent_id=manager_agent_state.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "So what did the other agent say?",
|
||||
}
|
||||
],
|
||||
)
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
# @retry_until_success(max_attempts=5, sleep_time_seconds=2)
|
||||
def test_agents_async_simple(client):
|
||||
"""
|
||||
Test two agents with multi-agent tools sending messages back and forth to count to 5.
|
||||
The chain is started by prompting one of the agents.
|
||||
"""
|
||||
# Cleanup from potentially failed previous runs
|
||||
existing_agents = client.server.agent_manager.list_agents(client.user)
|
||||
for agent in existing_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create two agents with multi-agent tools
|
||||
send_message_to_agent_async_tool_id = client.get_tool_id(name="send_message_to_agent_async")
|
||||
memory_a = ChatMemory(
|
||||
human="Chad - I'm interested in hearing poem.",
|
||||
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
|
||||
send_message_to_agent_async_tool_id = client.tools.list(name="send_message_to_agent_async")[0].id
|
||||
charles_state = client.agents.create(
|
||||
name="charles",
|
||||
tool_ids=[send_message_to_agent_async_tool_id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "Chad - I'm interested in hearing poem.",
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who has some great poem ideas (so I've heard).",
|
||||
},
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
charles_state = client.create_agent(name="charles", memory=memory_a, tool_ids=[send_message_to_agent_async_tool_id])
|
||||
charles = client.server.load_agent(agent_id=charles_state.id, actor=client.user)
|
||||
|
||||
memory_b = ChatMemory(
|
||||
human="No human - you are to only communicate with the other AI agent.",
|
||||
persona="You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
|
||||
sarah_state = client.agents.create(
|
||||
name="sarah",
|
||||
tool_ids=[send_message_to_agent_async_tool_id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "No human - you are to only communicate with the other AI agent.",
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "You are an AI agent that can communicate with your agent buddy using `send_message_to_agent_async`, who is interested in great poem ideas.",
|
||||
},
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="letta/letta-free",
|
||||
)
|
||||
sarah_state = client.create_agent(name="sarah", memory=memory_b, tool_ids=[send_message_to_agent_async_tool_id])
|
||||
|
||||
# Start the count chain with Agent1
|
||||
initial_prompt = f"I want you to talk to the other agent with ID {sarah_state.id} using `send_message_to_agent_async`. Specifically, I want you to ask him for a poem idea, and then craft a poem for me."
|
||||
client.send_message(
|
||||
agent_id=charles.agent_state.id,
|
||||
role="user",
|
||||
message=initial_prompt,
|
||||
client.agents.messages.create(
|
||||
agent_id=charles_state.id,
|
||||
messages=[{"role": "user", "content": initial_prompt}],
|
||||
)
|
||||
|
||||
found_in_charles = wait_for_incoming_message(
|
||||
|
@@ -135,6 +135,7 @@ all_configs = [
    "gemini-1.5-pro.json",
    "gemini-2.5-flash-vertex.json",
    "gemini-2.5-pro-vertex.json",
    "together-qwen-2.5-72b-instruct.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs

@@ -170,6 +171,10 @@ def assert_greeting_with_assistant_message_response(

    if streaming:
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0


def assert_greeting_without_assistant_message_response(

@@ -636,6 +641,33 @@ async def test_streaming_tool_call_async_client(
    assert_tool_call_response(messages, streaming=True)


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_step_streaming_greeting_with_assistant_message(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Tests sending a streaming message with a synchronous client.
    Checks that each chunk in the stream has the correct message types.
    """
    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
    response = client.agents.messages.create_stream(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_GREETING,
        stream_tokens=False,
    )
    messages = []
    for message in response:
        messages.append(message)
    assert_greeting_with_assistant_message_response(messages, streaming=True)


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
@ -6,7 +6,7 @@ from sqlalchemy import delete
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import DEFAULT_HUMAN
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.orm import Provider, Step
|
||||
from letta.orm import Provider, ProviderTrace, Step
|
||||
from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import CreateAgent
|
||||
@ -39,6 +39,7 @@ def org_id(server):
|
||||
|
||||
# cleanup
|
||||
with db_registry.session() as session:
|
||||
session.execute(delete(ProviderTrace))
|
||||
session.execute(delete(Step))
|
||||
session.execute(delete(Provider))
|
||||
session.commit()
|
||||
@ -54,7 +55,7 @@ def actor(server, org_id):
|
||||
server.user_manager.delete_user_by_id(user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_sleeptime_group_chat(server, actor):
|
||||
# 0. Refresh base tools
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
@ -105,7 +106,7 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert main_agent.id in [agent.id for agent in agents]
|
||||
@ -169,7 +170,7 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_sleeptime_group_chat_v2(server, actor):
|
||||
# 0. Refresh base tools
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
@ -220,7 +221,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert main_agent.id in [agent.id for agent in agents]
|
||||
@ -292,7 +293,7 @@ async def test_sleeptime_group_chat_v2(server, actor):
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_sleeptime_removes_redundant_information(server, actor):
|
||||
# 1. set up sleep-time agent as in test_sleeptime_group_chat
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
@ -360,7 +361,7 @@ async def test_sleeptime_removes_redundant_information(server, actor):
|
||||
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_sleeptime_edit(server, actor):
|
||||
sleeptime_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
|
@ -1,10 +1,10 @@
|
||||
import os
|
||||
import threading
|
||||
import subprocess
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
@ -35,7 +35,7 @@ from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import get_persona_text
|
||||
from tests.utils import wait_for_server
|
||||
from tests.utils import create_tool_from_func, wait_for_server
|
||||
|
||||
MESSAGE_TRANSCRIPTS = [
|
||||
"user: Hey, I’ve been thinking about planning a road trip up the California coast next month.",
|
||||
@ -92,17 +92,6 @@ You’re a memory-recall helper for an AI that can only keep the last 4 messages
|
||||
# --- Server Management --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
|
||||
|
||||
def _run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
@ -111,31 +100,66 @@ def _run_server():
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url():
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
"""
|
||||
Starts the Letta HTTP server in a separate process using the 'uvicorn' CLI,
|
||||
so its event loop and DB pool stay completely isolated from pytest-asyncio.
|
||||
"""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://127.0.0.1:8283")
|
||||
|
||||
# Only spawn our own server if the user hasn't overridden LETTA_SERVER_URL
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(url) # Allow server startup time
|
||||
# Build the command to launch uvicorn on your FastAPI app
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"letta.server.rest_api.app:app",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"8283",
|
||||
]
|
||||
# If you need TLS or reload settings from start_server(), you can add
|
||||
# "--reload" or "--ssl-keyfile", "--ssl-certfile" here as well.
|
||||
|
||||
return url
|
||||
server_proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
# wait until the HTTP port is accepting connections
|
||||
wait_for_server(url)
|
||||
|
||||
yield url
|
||||
|
||||
# Teardown: kill the subprocess if we started it
|
||||
server_proc.terminate()
|
||||
server_proc.wait(timeout=10)
|
||||
else:
|
||||
yield url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
return server
|
||||
|
||||
|
||||
# --- Client Setup --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
client = AsyncLetta(base_url=server_url)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def roll_dice_tool(client):
|
||||
@pytest.fixture
|
||||
async def roll_dice_tool(server):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
@ -145,13 +169,13 @@ async def roll_dice_tool(client):
|
||||
"""
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = await client.tools.upsert_from_function(func=roll_dice)
|
||||
# Yield the created tool
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=roll_dice), actor=actor)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def weather_tool(client):
|
||||
@pytest.fixture
|
||||
async def weather_tool(server):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
@ -176,22 +200,20 @@ async def weather_tool(client):
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = await client.tools.upsert_from_function(func=get_weather)
|
||||
# Yield the created tool
|
||||
actor = server.user_manager.get_user_or_default()
|
||||
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture
|
||||
def composio_gmail_get_profile_tool(default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest.fixture
|
||||
def voice_agent(server, actor):
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
main_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent_type=AgentType.voice_convo_agent,
|
||||
@ -268,9 +290,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
# --- Tests --- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"])
|
||||
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor):
|
||||
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, server_url, group_id, actor):
|
||||
request = _get_chat_request("How are you?")
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
@ -303,10 +325,10 @@ async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, serv
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint):
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint, server_url):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
@ -318,9 +340,9 @@ async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, en
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor):
|
||||
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor, server_url):
|
||||
server.group_manager.modify_group(
|
||||
group_id=group_id,
|
||||
group_update=GroupUpdate(
|
||||
@ -350,8 +372,8 @@ async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, g
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_summarization(disable_e2b_api_key, voice_agent):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_summarization(disable_e2b_api_key, voice_agent, server_url):
|
||||
agent_manager = AgentManager()
|
||||
user_manager = UserManager()
|
||||
actor = user_manager.get_default_user()
|
||||
@ -422,8 +444,8 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
|
||||
summarizer.fire_and_forget.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent, server_url):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
agent_manager = AgentManager()
|
||||
tool_manager = ToolManager()
|
||||
@ -488,8 +510,8 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
|
||||
assert not missing, f"Did not see calls to: {', '.join(missing)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_init_voice_convo_agent(voice_agent, server, actor):
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_init_voice_convo_agent(voice_agent, server, actor, server_url):
|
||||
|
||||
assert voice_agent.enable_sleeptime == True
|
||||
main_agent_tools = [tool.name for tool in voice_agent.tools]
|
||||
@ -511,7 +533,7 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
agents = await server.block_manager.get_agents_for_block_async(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert voice_agent.id in [agent.id for agent in agents]
|
||||
|
@ -1,12 +1,15 @@
|
||||
import difflib
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Mapping
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
|
||||
@ -23,11 +26,51 @@ from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
console = Console()
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it’s accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
|
||||

def _clear_tables():
    from letta.server.db import db_context
@ -38,12 +81,6 @@ def _clear_tables():
        session.commit()


@pytest.fixture
def fastapi_client():
    """Fixture to create a FastAPI test client."""
    return TestClient(app)


@pytest.fixture(autouse=True)
def clear_tables():
    _clear_tables()
@ -57,14 +94,14 @@ def local_client():
    yield client


@pytest.fixture(scope="module")
@pytest.fixture
def server():
    config = LettaConfig.load()

    config.save()

    server = SyncServer(init_with_default_org_and_user=False)
    return server
    yield server


@pytest.fixture
@ -562,14 +599,17 @@ def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server

@pytest.mark.parametrize("append_copy_suffix", [True, False])
@pytest.mark.parametrize("project_id", ["project-12345", None])
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
def test_agent_download_upload_flow(server, server_url, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
    """
    Test the full E2E serialization and deserialization flow using FastAPI endpoints.
    """
    agent_id = serialize_test_agent.id

    # Step 1: Download the serialized agent
    response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id})
    response = requests.get(
        f"{server_url}/v1/agents/{agent_id}/export",
        headers={"user_id": default_user.id},
    )
    assert response.status_code == 200, f"Download failed: {response.text}"

    # Ensure response matches expected schema
@ -580,10 +620,14 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
    # Step 2: Upload the serialized agent as a copy
    agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
    files = {"file": ("agent.json", agent_bytes, "application/json")}
    upload_response = fastapi_client.post(
        "/v1/agents/import",
    upload_response = requests.post(
        f"{server_url}/v1/agents/import",
        headers={"user_id": other_user.id},
        params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
        params={
            "append_copy_suffix": append_copy_suffix,
            "override_existing_tools": False,
            "project_id": project_id,
        },
        files=files,
    )
    assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
@ -613,16 +657,16 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
        "memgpt_agent_with_convo.af",
    ],
)
def test_upload_agentfile_from_disk(server, disable_e2b_api_key, fastapi_client, other_user, filename):
def test_upload_agentfile_from_disk(server, server_url, disable_e2b_api_key, other_user, filename):
    """
    Test uploading each .af file from the test_agent_files directory via FastAPI.
    Test uploading each .af file from the test_agent_files directory via live FastAPI server.
    """
    file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", filename)

    with open(file_path, "rb") as f:
        files = {"file": (filename, f, "application/json")}
        response = fastapi_client.post(
            "/v1/agents/import",
        response = requests.post(
            f"{server_url}/v1/agents/import",
            headers={"user_id": other_user.id},
            params={"append_copy_suffix": True, "override_existing_tools": False},
            files=files,
@ -1,5 +1,3 @@
import os
import threading
from datetime import datetime, timezone
from typing import Tuple
from unittest.mock import AsyncMock, patch
@ -14,15 +12,13 @@ from anthropic.types.beta.messages import (
    BetaMessageBatchRequestCounts,
    BetaMessageBatchSucceededResult,
)
from dotenv import load_dotenv
from letta_client import Letta

from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.config import LettaConfig
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.agent import AgentState, AgentStepState, CreateAgent
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.letta_message_content import TextContent
@ -31,10 +27,10 @@ from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from tests.utils import wait_for_server
from tests.utils import create_tool_from_func

# --------------------------------------------------------------------------- #
# Test Constants
# Test Constants / Helpers
# --------------------------------------------------------------------------- #

# Model identifiers used in tests
@ -54,7 +50,7 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]


@pytest.fixture(scope="function")
def weather_tool(client):
def weather_tool(server):
    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.
@ -79,13 +75,14 @@ def weather_tool(client):
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")

    tool = client.tools.upsert_from_function(func=get_weather)
    actor = server.user_manager.get_user_or_default()
    tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=get_weather), actor=actor)
    # Yield the created tool
    yield tool


@pytest.fixture(scope="function")
def rethink_tool(client):
def rethink_tool(server):
    def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
        """
        Re-evaluate the memory in block_name, integrating new and updated facts.
@ -101,28 +98,33 @@ def rethink_tool(client):
        agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
        return None

    tool = client.tools.upsert_from_function(func=rethink_memory)
    actor = server.user_manager.get_user_or_default()
    tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=rethink_memory), actor=actor)
    # Yield the created tool
    yield tool


@pytest.fixture
def agents(client, weather_tool):
def agents(server, weather_tool):
    """
    Create three test agents with different models.

    Returns:
        Tuple[Agent, Agent, Agent]: Three agents with sonnet, haiku, and opus models
    """
    actor = server.user_manager.get_user_or_default()

    def create_agent(suffix, model_name):
        return client.agents.create(
            name=f"test_agent_{suffix}",
            include_base_tools=True,
            model=model_name,
            tags=["test_agents"],
            embedding="letta/letta-free",
            tool_ids=[weather_tool.id],
        return server.create_agent(
            CreateAgent(
                name=f"test_agent_{suffix}",
                include_base_tools=True,
                model=model_name,
                tags=["test_agents"],
                embedding="letta/letta-free",
                tool_ids=[weather_tool.id],
            ),
            actor=actor,
        )

    return (
@ -290,32 +292,6 @@ def clear_batch_tables():
        session.commit()


def run_server():
    """Starts the Letta server in a background thread."""
    load_dotenv()
    from letta.server.rest_api.app import start_server

    start_server(debug=True)


@pytest.fixture(scope="session")
def server_url():
    """
    Ensures a server is running and returns its base URL.

    Uses environment variable if available, otherwise starts a server
    in a background thread.
    """
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        wait_for_server(url)

    return url


@pytest.fixture(scope="module")
def server():
    """
@ -324,14 +300,11 @@ def server():
    Loads and saves config to ensure proper initialization.
    """
    config = LettaConfig.load()

    config.save()
    return SyncServer()


@pytest.fixture(scope="session")
def client(server_url):
    """Creates a REST client connected to the test server."""
    return Letta(base_url=server_url)
    server = SyncServer(init_with_default_org_and_user=True)
    yield server


@pytest.fixture
@ -368,23 +341,27 @@ class MockAsyncIterable:
# --------------------------------------------------------------------------- #


@pytest.mark.asyncio(loop_scope="session")
async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
@pytest.mark.asyncio(loop_scope="module")
async def test_rethink_tool_modify_agent_state(disable_e2b_api_key, server, default_user, batch_job, rethink_tool):
    target_block_label = "human"
    new_memory = "banana"
    agent = client.agents.create(
        name=f"test_agent_rethink",
        include_base_tools=True,
        model=MODELS["sonnet"],
        tags=["test_agents"],
        embedding="letta/letta-free",
        tool_ids=[rethink_tool.id],
        memory_blocks=[
            {
                "label": target_block_label,
                "value": "Name: Matt",
            },
        ],
    actor = server.user_manager.get_user_or_default()
    agent = await server.create_agent_async(
        request=CreateAgent(
            name=f"test_agent_rethink",
            include_base_tools=True,
            model=MODELS["sonnet"],
            tags=["test_agents"],
            embedding="letta/letta-free",
            tool_ids=[rethink_tool.id],
            memory_blocks=[
                {
                    "label": target_block_label,
                    "value": "Name: Matt",
                },
            ],
        ),
        actor=actor,
    )
    agents = [agent]
    batch_requests = [
@ -444,13 +421,13 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
    await poll_running_llm_batches(server)

    # Check that the tool has been executed correctly
    agent = client.agents.retrieve(agent_id=agent.id)
    agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=actor)
    for block in agent.memory.blocks:
        if block.label == target_block_label:
            assert block.value == new_memory


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_partial_error_from_anthropic_batch(
    disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@ -610,7 +587,7 @@ async def test_partial_error_from_anthropic_batch(
    assert agent_messages[0].role == MessageRole.user, "Expected initial user message"


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_resume_step_some_stop(
    disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@ -773,7 +750,7 @@ def _assert_descending_order(messages):
    return True


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_resume_step_after_request_all_continue(
    disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
):
@ -911,7 +888,7 @@ async def test_resume_step_after_request_all_continue(
    assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio(loop_scope="module")
async def test_step_until_request_prepares_and_submits_batch_correctly(
    disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
):
File diff suppressed because it is too large
@ -2,7 +2,7 @@ import pytest
from sqlalchemy import delete

from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import (
@ -38,6 +38,7 @@ def org_id(server):

    # cleanup
    with db_registry.session() as session:
        session.execute(delete(ProviderTrace))
        session.execute(delete(Step))
        session.execute(delete(Provider))
        session.commit()
205 tests/test_provider_trace.py Normal file
@ -0,0 +1,205 @@
import asyncio
import json
import os
import threading
import time
import uuid

import pytest
from dotenv import load_dotenv
from letta_client import Letta

from letta.agents.letta_agent import LettaAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager


def _run_server():
    """Starts the Letta server in a background thread."""
    load_dotenv()
    from letta.server.rest_api.app import start_server

    start_server(debug=True)


@pytest.fixture(scope="session")
def server_url():
    """Ensures a server is running and returns its base URL."""
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        time.sleep(5)  # Allow server startup time

    return url


# # --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
    """Creates a REST client for testing."""
    client = Letta(base_url=server_url)
    yield client


@pytest.fixture(scope="session")
def event_loop(request):
    """Create an instance of the default event loop for each test case."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="function")
def roll_dice_tool(client, roll_dice_tool_func):
    print_tool = client.tools.upsert_from_function(func=roll_dice_tool_func)
    yield print_tool


@pytest.fixture(scope="function")
def weather_tool(client, weather_tool_func):
    weather_tool = client.tools.upsert_from_function(func=weather_tool_func)
    yield weather_tool


@pytest.fixture(scope="function")
def print_tool(client, print_tool_func):
    print_tool = client.tools.upsert_from_function(func=print_tool_func)
    yield print_tool


@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool):
    """Creates an agent and ensures cleanup after tests."""
    agent_state = client.agents.create(
        name=f"test_compl_{str(uuid.uuid4())[5:]}",
        tool_ids=[roll_dice_tool.id, weather_tool.id],
        include_base_tools=True,
        memory_blocks=[
            {
                "label": "human",
                "value": "Name: Matt",
            },
            {
                "label": "persona",
                "value": "Friendly agent",
            },
        ],
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )
    yield agent_state
    client.agents.delete(agent_state.id)


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step(message, agent_state, default_user):
    experimental_agent = LettaAgent(
        agent_id=agent_state.id,
        message_manager=MessageManager(),
        agent_manager=AgentManager(),
        block_manager=BlockManager(),
        passage_manager=PassageManager(),
        step_manager=StepManager(),
        telemetry_manager=TelemetryManager(),
        actor=default_user,
    )

    response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
    tool_step = response.messages[0].step_id
    reply_step = response.messages[-1].step_id

    tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
    reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
    assert tool_telemetry.request_json
    assert reply_telemetry.request_json


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop):
    experimental_agent = LettaAgent(
        agent_id=agent_state.id,
        message_manager=MessageManager(),
        agent_manager=AgentManager(),
        block_manager=BlockManager(),
        passage_manager=PassageManager(),
        step_manager=StepManager(),
        telemetry_manager=TelemetryManager(),
        actor=default_user,
    )
    stream = experimental_agent.step_stream([MessageCreate(role="user", content=[TextContent(text=message)])])

    result = StreamingResponseWithStatusCode(
        stream,
        media_type="text/event-stream",
    )

    message_id = None

    async def test_send(message) -> None:
        nonlocal message_id
        if "body" in message and not message_id:
            body = message["body"].decode("utf-8").split("data:")
            message_id = json.loads(body[1])["id"]

    await result.stream_response(send=test_send)

    messages = await experimental_agent.message_manager.get_messages_by_ids_async([message_id], actor=default_user)
    step_ids = set((message.step_id for message in messages))
    for step_id in step_ids:
        telemetry_data = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=step_id, actor=default_user)
        assert telemetry_data.request_json
        assert telemetry_data.response_json


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_provider_trace_step(client, agent_state, default_user, message, event_loop):
    client.agents.messages.create(agent_id=agent_state.id, messages=[])
    response = client.agents.messages.create(
        agent_id=agent_state.id,
        messages=[MessageCreate(role="user", content=[TextContent(text=message)])],
    )
    tool_step = response.messages[0].step_id
    reply_step = response.messages[-1].step_id

    tool_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
    reply_telemetry = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
    assert tool_telemetry.request_json
    assert reply_telemetry.request_json


@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
async def test_noop_provider_trace(message, agent_state, default_user, event_loop):
    experimental_agent = LettaAgent(
        agent_id=agent_state.id,
        message_manager=MessageManager(),
        agent_manager=AgentManager(),
        block_manager=BlockManager(),
        passage_manager=PassageManager(),
        step_manager=StepManager(),
        telemetry_manager=NoopTelemetryManager(),
        actor=default_user,
    )

    response = await experimental_agent.step([MessageCreate(role="user", content=[TextContent(text=message)])])
    tool_step = response.messages[0].step_id
    reply_step = response.messages[-1].step_id

    tool_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=tool_step, actor=default_user)
    reply_telemetry = await experimental_agent.telemetry_manager.get_provider_trace_by_step_id_async(step_id=reply_step, actor=default_user)
    assert tool_telemetry is None
    assert reply_telemetry is None
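A small sketch of the lookup pattern these tests rely on (the helper name is hypothetical; the manager call and the `request_json`/`response_json` fields are the ones exercised above):

async def fetch_trace(step_id, actor):
    # Hypothetical helper: returns the stored provider request/response JSON for a step,
    # or None when no trace was recorded (e.g. the NoopTelemetryManager path above never writes one).
    trace = await TelemetryManager().get_provider_trace_by_step_id_async(step_id=step_id, actor=actor)
    if trace is None:
        return None
    return trace.request_json, trace.response_json
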
@ -1,15 +1,12 @@
import os
import pytest

from letta.schemas.providers import (
    AnthropicBedrockProvider,
    AnthropicProvider,
    AzureProvider,
    DeepSeekProvider,
    GoogleAIProvider,
    GoogleVertexProvider,
    GroqProvider,
    MistralProvider,
    OllamaProvider,
    OpenAIProvider,
    TogetherProvider,
)
@ -17,11 +14,9 @@ from letta.settings import model_settings


def test_openai():
    api_key = os.getenv("OPENAI_API_KEY")
    assert api_key is not None
    provider = OpenAIProvider(
        name="openai",
        api_key=api_key,
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )
    models = provider.list_llm_models()
@ -33,34 +28,54 @@ def test_openai():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_deepseek():
    api_key = os.getenv("DEEPSEEK_API_KEY")
    assert api_key is not None
    provider = DeepSeekProvider(
        name="deepseek",
        api_key=api_key,
@pytest.mark.asyncio
async def test_openai_async():
    provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = await provider.list_embedding_models_async()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_deepseek():
    provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_anthropic():
    api_key = os.getenv("ANTHROPIC_API_KEY")
    assert api_key is not None
    provider = AnthropicProvider(
        name="anthropic",
        api_key=api_key,
        api_key=model_settings.anthropic_api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


@pytest.mark.asyncio
async def test_anthropic_async():
    provider = AnthropicProvider(
        name="anthropic",
        api_key=model_settings.anthropic_api_key,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_groq():
    provider = GroqProvider(
        name="groq",
        api_key=os.getenv("GROQ_API_KEY"),
        api_key=model_settings.groq_api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@ -70,8 +85,9 @@ def test_groq():
def test_azure():
    provider = AzureProvider(
        name="azure",
        api_key=os.getenv("AZURE_API_KEY"),
        base_url=os.getenv("AZURE_BASE_URL"),
        api_key=model_settings.azure_api_key,
        base_url=model_settings.azure_base_url,
        api_version=model_settings.azure_api_version,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@ -82,26 +98,24 @@ def test_azure():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_ollama():
    base_url = os.getenv("OLLAMA_BASE_URL")
    assert base_url is not None
    provider = OllamaProvider(
        name="ollama",
        base_url=base_url,
        default_prompt_formatter=model_settings.default_prompt_formatter,
        api_key=None,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
# def test_ollama():
#     provider = OllamaProvider(
#         name="ollama",
#         base_url=model_settings.ollama_base_url,
#         api_key=None,
#         default_prompt_formatter=model_settings.default_prompt_formatter,
#     )
#     models = provider.list_llm_models()
#     assert len(models) > 0
#     assert models[0].handle == f"{provider.name}/{models[0].model}"
#
#     embedding_models = provider.list_embedding_models()
#     assert len(embedding_models) > 0
#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_googleai():
    api_key = os.getenv("GEMINI_API_KEY")
    api_key = model_settings.gemini_api_key
    assert api_key is not None
    provider = GoogleAIProvider(
        name="google_ai",
@ -116,11 +130,28 @@ def test_googleai():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


@pytest.mark.asyncio
async def test_googleai_async():
    api_key = model_settings.gemini_api_key
    assert api_key is not None
    provider = GoogleAIProvider(
        name="google_ai",
        api_key=api_key,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = await provider.list_embedding_models_async()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_google_vertex():
    provider = GoogleVertexProvider(
        name="google_vertex",
        google_cloud_project=os.getenv("GCP_PROJECT_ID"),
        google_cloud_location=os.getenv("GCP_REGION"),
        google_cloud_project=model_settings.google_cloud_project,
        google_cloud_location=model_settings.google_cloud_location,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
@ -131,50 +162,57 @@ def test_google_vertex():
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_mistral():
    provider = MistralProvider(
        name="mistral",
        api_key=os.getenv("MISTRAL_API_KEY"),
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_together():
    provider = TogetherProvider(
        name="together",
        api_key=os.getenv("TOGETHER_API_KEY"),
        default_prompt_formatter="chatml",
        api_key=model_settings.together_api_key,
        default_prompt_formatter=model_settings.default_prompt_formatter,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
    # TODO: We don't have embedding models on together for CI
    # embedding_models = provider.list_embedding_models()
    # assert len(embedding_models) > 0
    # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_anthropic_bedrock():
    from letta.settings import model_settings

    provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
    models = provider.list_llm_models()
@pytest.mark.asyncio
async def test_together_async():
    provider = TogetherProvider(
        name="together",
        api_key=model_settings.together_api_key,
        default_prompt_formatter=model_settings.default_prompt_formatter,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
    # TODO: We don't have embedding models on together for CI
    # embedding_models = provider.list_embedding_models()
    # assert len(embedding_models) > 0
    # assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


# TODO: Add back in, difficulty adding this to CI properly, need boto credentials
# def test_anthropic_bedrock():
#     from letta.settings import model_settings
#
#     provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
#     models = provider.list_llm_models()
#     assert len(models) > 0
#     assert models[0].handle == f"{provider.name}/{models[0].model}"
#
#     embedding_models = provider.list_embedding_models()
#     assert len(embedding_models) > 0
#     assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_custom_anthropic():
    api_key = os.getenv("ANTHROPIC_API_KEY")
    assert api_key is not None
    provider = AnthropicProvider(
        name="custom_anthropic",
        api_key=api_key,
        api_key=model_settings.anthropic_api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
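A condensed sketch of the pattern the updated provider tests follow: credentials come from `letta.settings.model_settings` instead of `os.getenv`, and the new async listing methods are exercised next to the sync ones (the test name below is hypothetical):

@pytest.mark.asyncio
async def test_openai_handles_sketch():
    # Hypothetical condensed example of the model_settings + async-listing pattern above.
    provider = OpenAIProvider(
        name="openai",
        api_key=model_settings.openai_api_key,
        base_url=model_settings.openai_api_base,
    )
    models = await provider.list_llm_models_async()
    assert models and models[0].handle == f"{provider.name}/{models[0].model}"
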
@ -11,7 +11,7 @@ from sqlalchemy import delete

import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.orm import Provider, ProviderTrace, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole, ProviderCategory, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
@ -286,6 +286,7 @@ def org_id(server):

    # cleanup
    with db_registry.session() as session:
        session.execute(delete(ProviderTrace))
        session.execute(delete(Step))
        session.execute(delete(Provider))
        session.commit()
@ -565,7 +566,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
    server.agent_manager.delete_agent(agent_state.id, actor=another_user)


def test_read_local_llm_configs(server: SyncServer, user: User):
@pytest.mark.asyncio
async def test_read_local_llm_configs(server: SyncServer, user: User):
    configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
    clean_up_dir = False
    if not os.path.exists(configs_base_dir):
@ -588,7 +590,7 @@ def test_read_local_llm_configs(server: SyncServer, user: User):

    # Call list_llm_models
    assert os.path.exists(configs_base_dir)
    llm_models = server.list_llm_models(actor=user)
    llm_models = await server.list_llm_models_async(actor=user)

    # Assert that the config is in the returned models
    assert any(
@ -935,7 +937,7 @@ def test_composio_client_simple(server):
    assert len(actions) > 0


def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
    """Test that the memory rebuild is generating the correct number of role=system messages"""
    actor = user
    # create agent
@ -1223,7 +1225,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
    assert len(agent_state.tools) == len(base_tools) - 2


def test_messages_with_provider_override(server: SyncServer, user_id: str):
@pytest.mark.asyncio
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
    actor = server.user_manager.get_user_or_default(user_id)
    provider = server.provider_manager.create_provider(
        request=ProviderCreate(
@ -1233,10 +1236,10 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
        ),
        actor=actor,
    )
    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.byok])
    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.byok])
    assert provider.name in [model.provider_name for model in models]

    models = server.list_llm_models(actor=actor, provider_category=[ProviderCategory.base])
    models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
    assert provider.name not in [model.provider_name for model in models]

    agent = server.create_agent(
@ -1302,11 +1305,12 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
    assert total_tokens == usage.total_tokens


def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
    models = server.list_llm_models(actor=user)
@pytest.mark.asyncio
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
    models = await server.list_llm_models_async(actor=user)
    model_handles = [model.handle for model in models]
    assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
    embeddings = server.list_embedding_models(actor=user)
    embeddings = await server.list_embedding_models_async(actor=user)
    embedding_handles = [embedding.handle for embedding in embeddings]
    assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"

@ -4,15 +4,16 @@ import string
import time
from datetime import datetime, timezone
from importlib import util
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Optional, Tuple

import requests
from letta_client import Letta, SystemMessage

from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector
from letta.schemas.enums import MessageRole
from letta.functions.functions import parse_source_code
from letta.schemas.file import FileMetadata
from letta.schemas.message import Message
from letta.schemas.tool import Tool
from letta.settings import TestSettings

from .constants import TIMEOUT
@ -152,7 +153,7 @@ def with_qdrant_storage(storage: list[str]):


def wait_for_incoming_message(
    client,
    client: Letta,
    agent_id: str,
    substring: str = "[Incoming message from agent with ID",
    max_wait_seconds: float = 10.0,
@ -166,13 +167,13 @@ def wait_for_incoming_message(
    deadline = time.time() + max_wait_seconds

    while time.time() < deadline:
        messages = client.server.message_manager.list_messages_for_agent(agent_id=agent_id, actor=client.user)
        messages = client.agents.messages.list(agent_id)[1:]

        # Check for the system message containing `substring`
        def get_message_text(message: Message) -> str:
            return message.content[0].text if message.content and len(message.content) == 1 else ""
        def get_message_text(message: SystemMessage) -> str:
            return message.content if message.content else ""

        if any(message.role == MessageRole.system and substring in get_message_text(message) for message in messages):
        if any(isinstance(message, SystemMessage) and substring in get_message_text(message) for message in messages):
            return True
        time.sleep(sleep_interval)

@ -199,3 +200,21 @@ def wait_for_server(url, timeout=30, interval=0.5):

def random_string(length: int) -> str:
    return "".join(random.choices(string.ascii_letters + string.digits, k=length))


def create_tool_from_func(
    func,
    tags: Optional[List[str]] = None,
    description: Optional[str] = None,
):
    source_code = parse_source_code(func)
    source_type = "python"
    if not tags:
        tags = []

    return Tool(
        source_type=source_type,
        source_code=source_code,
        tags=tags,
        description=description,
    )
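A short sketch of how the batch tests above consume this helper: wrap a plain function into a `Tool` and register it through the server's tool manager rather than the REST client (the wrapper name is hypothetical):

def register_tool(server, func):
    # Hypothetical convenience wrapper mirroring the weather_tool/rethink_tool fixtures above.
    actor = server.user_manager.get_user_or_default()
    return server.tool_manager.create_or_update_tool(create_tool_from_func(func=func), actor=actor)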