chore: bump version 0.7.8 (#2604)

Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
cthomas committed via GitHub on 2025-04-30 23:39:58 -07:00
parent e07edd840b
commit 20ecab29a1
65 changed files with 1248 additions and 649 deletions

View File

@@ -167,7 +167,7 @@ docker exec -it $(docker ps -q -f ancestor=letta/letta) letta run
 In the CLI tool, you'll be able to create new agents, or load existing agents:
 ```
 🧬 Creating new agent...
-? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai]
+? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com]
 ? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai]
 -> 🤖 Using persona profile: 'sam_pov'
 -> 🧑 Using human profile: 'basic'
@@ -233,7 +233,7 @@ letta run
 ```
 ```
 🧬 Creating new agent...
-? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai]
+? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com]
 ? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai]
 -> 🤖 Using persona profile: 'sam_pov'
 -> 🧑 Using human profile: 'basic'

View File

@@ -0,0 +1,35 @@
"""add byok fields and unique constraint

Revision ID: 373dabcba6cf
Revises: c56081a05371
Create Date: 2025-04-30 19:38:25.010856

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "373dabcba6cf"
down_revision: Union[str, None] = "c56081a05371"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
    op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
    op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
    op.drop_column("providers", "base_url")
    op.drop_column("providers", "provider_type")
    # ### end Alembic commands ###

View File

@@ -0,0 +1,33 @@
"""Add buffer length min max for voice sleeptime

Revision ID: c56081a05371
Revises: 28b8765bdd0a
Create Date: 2025-04-30 16:03:41.213750

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c56081a05371"
down_revision: Union[str, None] = "28b8765bdd0a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("groups", sa.Column("max_message_buffer_length", sa.Integer(), nullable=True))
    op.add_column("groups", sa.Column("min_message_buffer_length", sa.Integer(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("groups", "min_message_buffer_length")
    op.drop_column("groups", "max_message_buffer_length")
    # ### end Alembic commands ###

View File

@@ -60,7 +60,7 @@ Last updated Oct 2, 2024. Please check `composio` documentation for any composio
 def main():
-    from composio_langchain import Action
+    from composio import Action

     # Add the composio tool
     tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)

View File

@@ -0,0 +1,32 @@
from letta_client import Letta, VoiceSleeptimeManagerUpdate

client = Letta(base_url="http://localhost:8283")

agent = client.agents.create(
    name="low_latency_voice_agent_demo",
    agent_type="voice_convo_agent",
    memory_blocks=[
        {"value": "Name: ?", "label": "human"},
        {"value": "You are a helpful assistant.", "label": "persona"},
    ],
    model="openai/gpt-4o-mini",  # Use 4o-mini for speed
    embedding="openai/text-embedding-3-small",
    enable_sleeptime=True,
    initial_message_sequence=[],
)
print(f"Created agent id {agent.id}")

# get the group
group_id = agent.multi_agent_group.id
max_message_buffer_length = agent.multi_agent_group.max_message_buffer_length
min_message_buffer_length = agent.multi_agent_group.min_message_buffer_length
print(f"Group id: {group_id}, max_message_buffer_length: {max_message_buffer_length}, min_message_buffer_length: {min_message_buffer_length}")

# change it to be more frequent
group = client.groups.modify(
    group_id=group_id,
    manager_config=VoiceSleeptimeManagerUpdate(
        max_message_buffer_length=10,
        min_message_buffer_length=6,
    )
)
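A possible follow-up, shown here only as a hedged sketch and not part of the commit: once the buffer thresholds are lowered, sending a normal user message exercises the voice sleeptime flow. The call follows the documented letta_client pattern; the message content is illustrative.

```python
# Hedged sketch only - not part of this commit; message content is illustrative.
from letta_client import MessageCreate

response = client.agents.messages.create(
    agent_id=agent.id,
    messages=[MessageCreate(role="user", content="Hi, my name is Sam.")],
)
for message in response.messages:
    print(message)
```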

View File

@ -1,4 +1,4 @@
__version__ = "0.7.7" __version__ = "0.7.8"
# import clients # import clients
from letta.client.client import LocalClient, RESTClient, create_client from letta.client.client import LocalClient, RESTClient, create_client

View File

@@ -21,14 +21,14 @@ from letta.constants import (
 )
 from letta.errors import ContextWindowExceededError
 from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
+from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
 from letta.functions.functions import get_function_from_module
-from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
 from letta.functions.mcp_client.base_client import BaseMCPClient
 from letta.helpers import ToolRulesSolver
 from letta.helpers.composio_helpers import get_composio_api_key
 from letta.helpers.datetime_helpers import get_utc_time
 from letta.helpers.json_helpers import json_dumps, json_loads
-from letta.helpers.message_helper import prepare_input_message_create
+from letta.helpers.message_helper import convert_message_creates_to_messages
 from letta.interface import AgentInterface
 from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
@@ -331,8 +331,10 @@ class Agent(BaseAgent):
         log_telemetry(self.logger, "_get_ai_reply create start")
         # New LLM client flow
         llm_client = LLMClient.create(
-            provider=self.agent_state.llm_config.model_endpoint_type,
+            provider_name=self.agent_state.llm_config.provider_name,
+            provider_type=self.agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=put_inner_thoughts_first,
+            actor_id=self.user.id,
         )

         if llm_client and not stream:
@@ -726,8 +728,7 @@ class Agent(BaseAgent):
         self.tool_rules_solver.clear_tool_history()

         # Convert MessageCreate objects to Message objects
-        message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
-        next_input_messages = message_objects
+        next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id)
         counter = 0
         total_usage = UsageStatistics()
         step_count = 0
@@ -942,12 +943,7 @@ class Agent(BaseAgent):
             model_endpoint=self.agent_state.llm_config.model_endpoint,
             context_window_limit=self.agent_state.llm_config.context_window,
             usage=response.usage,
-            # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
-            provider_id=(
-                self.provider_manager.get_anthropic_override_provider_id()
-                if self.agent_state.llm_config.model_endpoint_type == "anthropic"
-                else None
-            ),
+            provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
             job_id=job_id,
         )
         for message in all_new_messages:

View File

@@ -0,0 +1,6 @@
class IncompatibleAgentType(ValueError):
    def __init__(self, expected_type: str, actual_type: str):
        message = f"Incompatible agent type: expected '{expected_type}', but got '{actual_type}'."
        super().__init__(message)
        self.expected_type = expected_type
        self.actual_type = actual_type
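For context, a minimal usage sketch of the new exception (an illustration, not part of the diff): callers can catch it and inspect the recorded types.

```python
# Hedged sketch only - not part of this commit. The agent type strings below are illustrative.
from letta.agents.exceptions import IncompatibleAgentType

try:
    raise IncompatibleAgentType(expected_type="voice_convo_agent", actual_type="memgpt_agent")
except IncompatibleAgentType as e:
    print(e)                # Incompatible agent type: expected 'voice_convo_agent', but got 'memgpt_agent'.
    print(e.expected_type)  # voice_convo_agent
    print(e.actual_type)    # memgpt_agent
```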

View File

@@ -67,8 +67,10 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            provider=agent_state.llm_config.model_endpoint_type,
+            provider_name=agent_state.llm_config.provider_name,
+            provider_type=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
+            actor_id=self.actor.id,
         )
         for step in range(max_steps):
             response = await self._get_ai_reply(
@@ -109,8 +111,10 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            llm_config=agent_state.llm_config,
+            provider_name=agent_state.llm_config.provider_name,
+            provider_type=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
+            actor_id=self.actor.id,
         )
         for step in range(max_steps):
@@ -125,7 +129,7 @@ class LettaAgent(BaseAgent):
             # TODO: THIS IS INCREDIBLY UGLY
             # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
             interface = AnthropicStreamingInterface(
-                use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
+                use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
             )
             async for chunk in interface.process(stream):
                 yield f"data: {chunk.model_dump_json()}\n\n"
@@ -179,6 +183,7 @@ class LettaAgent(BaseAgent):
                 ToolType.LETTA_SLEEPTIME_CORE,
             }
             or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
+            or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
         ]
         valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
@@ -274,45 +279,49 @@ class LettaAgent(BaseAgent):
         return persisted_messages, continue_stepping

     def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
+        try:
             self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
             # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
             curr_system_message = in_context_messages[0]
             curr_memory_str = agent_state.memory.compile()
             curr_system_message_text = curr_system_message.content[0].text
             if curr_memory_str in curr_system_message_text:
                 # NOTE: could this cause issues if a block is removed? (substring match would still work)
                 logger.debug(
                     f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
                 )
                 return in_context_messages
             memory_edit_timestamp = get_utc_time()
             num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
             num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
             new_system_message_str = compile_system_message(
                 system_prompt=agent_state.system,
                 in_context_memory=agent_state.memory,
                 in_context_memory_last_edit=memory_edit_timestamp,
                 previous_message_count=num_messages,
                 archival_memory_size=num_archival_memories,
             )
             diff = united_diff(curr_system_message_text, new_system_message_str)
             if len(diff) > 0:
                 logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
                 new_system_message = self.message_manager.update_message_by_id(
                     curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
                 )
                 # Skip pulling down the agent's memory again to save on a db call
                 return [new_system_message] + in_context_messages[1:]
             else:
                 return in_context_messages
+        except:
+            logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
+            raise

     @trace_method
     async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
@@ -331,6 +340,10 @@ class LettaAgent(BaseAgent):
             results = await self._send_message_to_agents_matching_tags(**tool_args)
             log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
             return json.dumps(results), True
+        elif target_tool.type == ToolType.EXTERNAL_COMPOSIO:
+            log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args)
+            log_event(name=f"finish_compsio_{tool_name}_execution", attributes=tool_args)
+            return tool_execution_result.func_return, True
         else:
             tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
             # TODO: Integrate sandbox result

View File

@@ -156,8 +156,10 @@ class LettaAgentBatch:
         log_event(name="init_llm_client")
         llm_client = LLMClient.create(
-            provider=agent_states[0].llm_config.model_endpoint_type,
+            provider_name=agent_states[0].llm_config.provider_name,
+            provider_type=agent_states[0].llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
+            actor_id=self.actor.id,
         )
         agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@@ -273,8 +275,10 @@ class LettaAgentBatch:
         # translate provider-specific response → OpenAI-style tool call (unchanged)
         llm_client = LLMClient.create(
-            provider=item.llm_config.model_endpoint_type,
+            provider_name=item.llm_config.provider_name,
+            provider_type=item.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
+            actor_id=self.actor.id,
         )
         tool_call = (
             llm_client.convert_response_to_chat_completion(

View File

@@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 import openai

 from letta.agents.base_agent import BaseAgent
+from letta.agents.exceptions import IncompatibleAgentType
 from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
 from letta.constants import NON_USER_MSG_PREFIX
 from letta.helpers.datetime_helpers import get_utc_time
@@ -18,7 +19,7 @@ from letta.helpers.tool_execution_helper import (
 from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
 from letta.log import get_logger
 from letta.orm.enums import ToolType
-from letta.schemas.agent import AgentState
+from letta.schemas.agent import AgentState, AgentType
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_response import LettaResponse
 from letta.schemas.message import Message, MessageCreate, MessageUpdate
@@ -68,8 +69,6 @@ class VoiceAgent(BaseAgent):
         block_manager: BlockManager,
         passage_manager: PassageManager,
         actor: User,
-        message_buffer_limit: int,
-        message_buffer_min: int,
     ):
         super().__init__(
             agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
@@ -80,8 +79,6 @@ class VoiceAgent(BaseAgent):
         self.passage_manager = passage_manager
         # TODO: This is not guaranteed to exist!
         self.summary_block_label = "human"
-        self.message_buffer_limit = message_buffer_limit
-        self.message_buffer_min = message_buffer_min
         # Cached archival memory/message size
         self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
@@ -108,8 +105,8 @@ class VoiceAgent(BaseAgent):
                 target_block_label=self.summary_block_label,
                 message_transcripts=[],
             ),
-            message_buffer_limit=self.message_buffer_limit,
-            message_buffer_min=self.message_buffer_min,
+            message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
+            message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,
         )
         return summarizer
@@ -124,9 +121,15 @@ class VoiceAgent(BaseAgent):
         """
         if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
             raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}")

         user_query = input_messages[0].content[0].text
         agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
+
+        # Safety check
+        if agent_state.agent_type != AgentType.voice_convo_agent:
+            raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type)
+
         summarizer = self.init_summarizer(agent_state=agent_state)
         in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)

View File

@@ -4,7 +4,7 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
 LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
 LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")

-LETTA_MODEL_ENDPOINT = "https://inference.memgpt.ai"
+LETTA_MODEL_ENDPOINT = "https://inference.letta.com"

 ADMIN_PREFIX = "/v1/admin"
 API_PREFIX = "/v1"
@@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29
 # minimum context window size
 MIN_CONTEXT_WINDOW = 4096

+# Voice Sleeptime message buffer lengths
+DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30
+DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15
+
 # embeddings
 MAX_EMBEDDING_DIM = 4096  # maximum supported embeding size - do NOT change or else DBs will need to be reset
 DEFAULT_EMBEDDING_CHUNK_SIZE = 300

View File

@@ -0,0 +1,100 @@
import asyncio
import os
from typing import Any, Optional

from composio import ComposioToolSet
from composio.constants import DEFAULT_ENTITY_ID
from composio.exceptions import (
    ApiKeyNotProvidedError,
    ComposioSDKError,
    ConnectedAccountNotFoundError,
    EnumMetadataNotFound,
    EnumStringNotFound,
)

from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY


# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
# TODO: So be very careful changing/removing these pair of functions
def _generate_func_name_from_composio_action(action_name: str) -> str:
    """
    Generates the composio function name from the composio action.

    Args:
        action_name: The composio action name

    Returns:
        function name
    """
    return action_name.lower()


def generate_composio_action_from_func_name(func_name: str) -> str:
    """
    Generates the composio action from the composio function name.

    Args:
        func_name: The composio function name

    Returns:
        composio action name
    """
    return func_name.upper()


def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
    # Generate func name
    func_name = _generate_func_name_from_composio_action(action_name)

    wrapper_function_str = f"""\
def {func_name}(**kwargs):
    raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team")
"""

    # Compile safety check
    _assert_code_gen_compilable(wrapper_function_str.strip())

    return func_name, wrapper_function_str.strip()


async def execute_composio_action_async(
    action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None
) -> tuple[str, str]:
    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, execute_composio_action, action_name, args, api_key, entity_id)
    except Exception as e:
        raise RuntimeError(f"Error in execute_composio_action_async: {e}") from e


def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
    entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
    try:
        composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
        response = composio_toolset.execute_action(action=action_name, params=args)
    except ApiKeyNotProvidedError:
        raise RuntimeError(
            f"Composio API key is missing for action '{action_name}'. "
            "Please set the sandbox environment variables either through the ADE or the API."
        )
    except ConnectedAccountNotFoundError:
        raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.")
    except EnumStringNotFound as e:
        raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
    except EnumMetadataNotFound as e:
        raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
    except ComposioSDKError as e:
        raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))

    if "error" in response and response["error"]:
        raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))

    return response.get("data")


def _assert_code_gen_compilable(code_str):
    try:
        compile(code_str, "<string>", "exec")
    except SyntaxError as e:
        print(f"Syntax error in code: {e}")

View File

@@ -1,8 +1,9 @@
 import importlib
 import inspect
+from collections.abc import Callable
 from textwrap import dedent  # remove indentation
 from types import ModuleType
-from typing import Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional

 from letta.errors import LettaToolCreateError
 from letta.functions.schema_generator import generate_schema
@@ -66,7 +67,8 @@ def parse_source_code(func) -> str:
     return source_code

-def get_function_from_module(module_name: str, function_name: str):
+# TODO (cliandy) refactor below two funcs
+def get_function_from_module(module_name: str, function_name: str) -> Callable[..., Any]:
     """
     Dynamically imports a function from a specified module.

View File

@@ -6,10 +6,9 @@ from random import uniform
 from typing import Any, Dict, List, Optional, Type, Union

 import humps
-from composio.constants import DEFAULT_ENTITY_ID
 from pydantic import BaseModel, Field, create_model

-from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.functions.interface import MultiAgentMessagingInterface
 from letta.orm.errors import NoResultFound
 from letta.schemas.enums import MessageRole
@@ -21,34 +20,6 @@ from letta.server.rest_api.utils import get_letta_server
 from letta.settings import settings

-
-# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
-# TODO: So be very careful changing/removing these pair of functions
-def generate_func_name_from_composio_action(action_name: str) -> str:
-    """
-    Generates the composio function name from the composio action.
-
-    Args:
-        action_name: The composio action name
-
-    Returns:
-        function name
-    """
-    return action_name.lower()
-
-
-def generate_composio_action_from_func_name(func_name: str) -> str:
-    """
-    Generates the composio action from the composio function name.
-
-    Args:
-        func_name: The composio function name
-
-    Returns:
-        composio action name
-    """
-    return func_name.upper()
-
-
 # TODO needed?
 def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]:
@@ -58,71 +29,20 @@ def {mcp_tool_name}(**kwargs):
 """

     # Compile safety check
-    assert_code_gen_compilable(wrapper_function_str.strip())
+    _assert_code_gen_compilable(wrapper_function_str.strip())

     return mcp_tool_name, wrapper_function_str.strip()

-
-def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
-    # Generate func name
-    func_name = generate_func_name_from_composio_action(action_name)
-
-    wrapper_function_str = f"""\
-def {func_name}(**kwargs):
-    raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team")
-"""
-
-    # Compile safety check
-    assert_code_gen_compilable(wrapper_function_str.strip())
-
-    return func_name, wrapper_function_str.strip()
-
-
-def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
-    import os
-
-    from composio.exceptions import (
-        ApiKeyNotProvidedError,
-        ComposioSDKError,
-        ConnectedAccountNotFoundError,
-        EnumMetadataNotFound,
-        EnumStringNotFound,
-    )
-    from composio_langchain import ComposioToolSet
-
-    entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
-    try:
-        composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
-        response = composio_toolset.execute_action(action=action_name, params=args)
-    except ApiKeyNotProvidedError:
-        raise RuntimeError(
-            f"Composio API key is missing for action '{action_name}'. "
-            "Please set the sandbox environment variables either through the ADE or the API."
-        )
-    except ConnectedAccountNotFoundError:
-        raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.")
-    except EnumStringNotFound as e:
-        raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
-    except EnumMetadataNotFound as e:
-        raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
-    except ComposioSDKError as e:
-        raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))
-
-    if "error" in response:
-        raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))
-
-    return response.get("data")
-
-
 def generate_langchain_tool_wrapper(
     tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
 ) -> tuple[str, str]:
     tool_name = tool.__class__.__name__
     import_statement = f"from langchain_community.tools import {tool_name}"
-    extra_module_imports = generate_import_code(additional_imports_module_attr_map)
+    extra_module_imports = _generate_import_code(additional_imports_module_attr_map)

     # Safety check that user has passed in all required imports:
-    assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
+    _assert_all_classes_are_imported(tool, additional_imports_module_attr_map)

     tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
     run_call = f"return tool._run(**kwargs)"
@@ -139,25 +59,25 @@ def {func_name}(**kwargs):
 """

     # Compile safety check
-    assert_code_gen_compilable(wrapper_function_str)
+    _assert_code_gen_compilable(wrapper_function_str)

     return func_name, wrapper_function_str


-def assert_code_gen_compilable(code_str):
+def _assert_code_gen_compilable(code_str):
     try:
         compile(code_str, "<string>", "exec")
     except SyntaxError as e:
         print(f"Syntax error in code: {e}")


-def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
+def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
     # Safety check that user has passed in all required imports:
     tool_name = tool.__class__.__name__
     current_class_imports = {tool_name}
     if additional_imports_module_attr_map:
         current_class_imports.update(set(additional_imports_module_attr_map.values()))
-    required_class_imports = set(find_required_class_names_for_import(tool))
+    required_class_imports = set(_find_required_class_names_for_import(tool))

     if not current_class_imports.issuperset(required_class_imports):
         err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}"
@@ -165,7 +85,7 @@ def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional
     raise RuntimeError(err_msg)


-def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
+def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
     """
     Finds all the class names for required imports when instantiating the `obj`.
     NOTE: This does not return the full import path, only the class name.
@@ -181,7 +101,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod
         # Collect all possible candidates for BaseModel objects
         candidates = []
-        if is_base_model(curr_obj):
+        if _is_base_model(curr_obj):
             # If it is a base model, we get all the values of the object parameters
             # i.e., if obj('b' = <class A>), we would want to inspect <class A>
             fields = dict(curr_obj)
@@ -198,7 +118,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod
         # Filter out all candidates that are not BaseModels
         # In the list example above, ['a', 3, None, <class A>], we want to filter out 'a', 3, and None
-        candidates = filter(lambda x: is_base_model(x), candidates)
+        candidates = filter(lambda x: _is_base_model(x), candidates)

         # Classic BFS here
         for c in candidates:
@@ -216,7 +136,7 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
         # If it is a basic Python type, we trivially return the string version of that value
         # Handle basic types
         return repr(obj)
-    elif is_base_model(obj):
+    elif _is_base_model(obj):
         # Otherwise, if it is a BaseModel
         # We want to pull out all the parameters, and reformat them into strings
         # e.g. {arg}={value}
@@ -269,11 +189,11 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
     return None


-def is_base_model(obj: Any):
+def _is_base_model(obj: Any):
     return isinstance(obj, BaseModel)


-def generate_import_code(module_attr_map: Optional[dict]):
+def _generate_import_code(module_attr_map: Optional[dict]):
     if not module_attr_map:
         return ""
@@ -286,7 +206,7 @@ def generate_import_code(module_attr_map: Optional[dict]):
     return "\n".join(code_lines)


-def parse_letta_response_for_assistant_message(
+def _parse_letta_response_for_assistant_message(
     target_agent_id: str,
     letta_response: LettaResponse,
 ) -> Optional[str]:
@@ -346,7 +266,7 @@ def execute_send_message_to_agent(
     return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix))


-async def send_message_to_agent_no_stream(
+async def _send_message_to_agent_no_stream(
     server: "SyncServer",
     agent_id: str,
     actor: User,
@@ -375,7 +295,7 @@ async def send_message_to_agent_no_stream(
     return LettaResponse(messages=final_messages, usage=usage_stats)


-async def async_send_message_with_retries(
+async def _async_send_message_with_retries(
     server: "SyncServer",
     sender_agent: "Agent",
     target_agent_id: str,
@@ -389,7 +309,7 @@ async def async_send_message_with_retries(
     for attempt in range(1, max_retries + 1):
         try:
             response = await asyncio.wait_for(
-                send_message_to_agent_no_stream(
+                _send_message_to_agent_no_stream(
                     server=server,
                     agent_id=target_agent_id,
                     actor=sender_agent.user,
@@ -399,7 +319,7 @@ async def async_send_message_with_retries(
             )

             # Then parse out the assistant message
-            assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
+            assistant_message = _parse_letta_response_for_assistant_message(target_agent_id, response)
             if assistant_message:
                 sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
                 return assistant_message

View File

@@ -76,6 +76,7 @@ def load_multi_agent(
             agent_state=agent_state,
             interface=interface,
             user=actor,
+            mcp_clients=mcp_clients,
             group_id=group.id,
             agent_ids=group.agent_ids,
             description=group.description,

View File

@@ -1,9 +1,10 @@
 import asyncio
 import threading
 from datetime import datetime, timezone
-from typing import List, Optional
+from typing import Dict, List, Optional

 from letta.agent import Agent, AgentState
+from letta.functions.mcp_client.base_client import BaseMCPClient
 from letta.groups.helpers import stringify_message
 from letta.interface import AgentInterface
 from letta.orm import User
@@ -26,6 +27,7 @@ class SleeptimeMultiAgent(Agent):
         interface: AgentInterface,
         agent_state: AgentState,
         user: User,
+        mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
         # custom
         group_id: str = "",
         agent_ids: List[str] = [],
@@ -115,6 +117,7 @@ class SleeptimeMultiAgent(Agent):
                 agent_state=participant_agent_state,
                 interface=StreamingServerInterface(),
                 user=self.user,
+                mcp_clients=self.mcp_clients,
             )

             prior_messages = []
@@ -212,6 +215,7 @@ class SleeptimeMultiAgent(Agent):
             agent_state=self.agent_state,
             interface=self.interface,
             user=self.user,
+            mcp_clients=self.mcp_clients,
         )
         # Perform main agent step
         usage_stats = main_agent.step(

View File

@@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent
 from letta.schemas.message import Message, MessageCreate


-def prepare_input_message_create(
+def convert_message_creates_to_messages(
+    messages: list[MessageCreate],
+    agent_id: str,
+    wrap_user_message: bool = True,
+    wrap_system_message: bool = True,
+) -> list[Message]:
+    return [
+        _convert_message_create_to_message(
+            message=message,
+            agent_id=agent_id,
+            wrap_user_message=wrap_user_message,
+            wrap_system_message=wrap_system_message,
+        )
+        for message in messages
+    ]
+
+
+def _convert_message_create_to_message(
     message: MessageCreate,
     agent_id: str,
     wrap_user_message: bool = True,
@@ -23,12 +40,12 @@ def prepare_input_message_create(
         raise ValueError("Message content is empty or invalid")

     # Apply wrapping if needed
-    if message.role == MessageRole.user and wrap_user_message:
+    if message.role not in {MessageRole.user, MessageRole.system}:
+        raise ValueError(f"Invalid message role: {message.role}")
+    elif message.role == MessageRole.user and wrap_user_message:
         message_content = system.package_user_message(user_message=message_content)
     elif message.role == MessageRole.system and wrap_system_message:
         message_content = system.package_system_message(system_message=message_content)
-    elif message.role not in {MessageRole.user, MessageRole.system}:
-        raise ValueError(f"Invalid message role: {message.role}")

     return Message(
         agent_id=agent_id,

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional

 from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG
 from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
-from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
+from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
 from letta.helpers.composio_helpers import get_composio_api_key
 from letta.orm.enums import ToolType
 from letta.schemas.agent import AgentState

View File

@@ -35,7 +35,7 @@ from letta.schemas.letta_message import (
 from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
-from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
+from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser

 logger = get_logger(__name__)
@@ -56,7 +56,7 @@ class AnthropicStreamingInterface:
     """

     def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
-        self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
+        self.json_parser: JSONParser = PydanticJSONParser()
         self.use_assistant_message = use_assistant_message

         # Premake IDs for database writes
@@ -68,7 +68,7 @@ class AnthropicStreamingInterface:
         self.accumulated_inner_thoughts = []
         self.tool_call_id = None
         self.tool_call_name = None
-        self.accumulated_tool_call_args = []
+        self.accumulated_tool_call_args = ""
         self.previous_parse = {}

         # usage trackers
@@ -85,193 +85,200 @@ class AnthropicStreamingInterface:
     def get_tool_call_object(self) -> ToolCall:
         """Useful for agent loop"""
-        return ToolCall(
-            id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
-        )
+        return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))

     def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
         """
         Check if inner thoughts are complete in the current tool call arguments
         by looking for a closing quote after the inner_thoughts field
         """
-        if not self.put_inner_thoughts_in_kwarg:
-            # None of the things should have inner thoughts in kwargs
-            return True
-        else:
-            parsed = self.optimistic_json_parser.parse(combined_args)
-            # TODO: This will break on tools with 0 input
-            return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
+        try:
+            if not self.put_inner_thoughts_in_kwarg:
+                # None of the things should have inner thoughts in kwargs
+                return True
+            else:
+                parsed = self.json_parser.parse(combined_args)
+                # TODO: This will break on tools with 0 input
+                return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
+        except Exception as e:
+            logger.error("Error checking inner thoughts: %s", e)
+            raise
async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]: async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
async with stream: try:
async for event in stream: async with stream:
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock async for event in stream:
if isinstance(event, BetaRawContentBlockStartEvent): # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
content = event.content_block if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block
if isinstance(content, BetaTextBlock): if isinstance(content, BetaTextBlock):
self.anthropic_mode = EventMode.TEXT self.anthropic_mode = EventMode.TEXT
# TODO: Can capture citations, etc. # TODO: Can capture citations, etc.
elif isinstance(content, BetaToolUseBlock): elif isinstance(content, BetaToolUseBlock):
self.anthropic_mode = EventMode.TOOL_USE self.anthropic_mode = EventMode.TOOL_USE
self.tool_call_id = content.id self.tool_call_id = content.id
self.tool_call_name = content.name self.tool_call_name = content.name
self.inner_thoughts_complete = False self.inner_thoughts_complete = False
if not self.use_assistant_message: if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately # Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage( tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id, id=self.letta_tool_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
hidden_reasoning_message = HiddenReasoningMessage(
id=self.letta_assistant_message_id,
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(), date=datetime.now(timezone.utc).isoformat(),
) )
self.tool_call_buffer.append(tool_call_msg) self.reasoning_messages.append(hidden_reasoning_message)
elif isinstance(content, BetaThinkingBlock): yield hidden_reasoning_message
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
hidden_reasoning_message = HiddenReasoningMessage( elif isinstance(event, BetaRawContentBlockDeltaEvent):
id=self.letta_assistant_message_id, delta = event.delta
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(hidden_reasoning_message)
yield hidden_reasoning_message
elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(delta, BetaTextDelta):
delta = event.delta # Safety check
if not self.anthropic_mode == EventMode.TEXT:
raise RuntimeError(
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
)
if isinstance(delta, BetaTextDelta): # TODO: Strip out </thinking> more robustly, this is pretty hacky lol
# Safety check delta.text = delta.text.replace("</thinking>", "")
if not self.anthropic_mode == EventMode.TEXT: self.accumulated_inner_thoughts.append(delta.text)
raise RuntimeError(
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
)
# TODO: Strip out </thinking> more robustly, this is pretty hacky lol
delta.text = delta.text.replace("</thinking>", "")
self.accumulated_inner_thoughts.append(delta.text)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaInputJSONDelta):
if not self.anthropic_mode == EventMode.TOOL_USE:
raise RuntimeError(
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args.append(delta.partial_json)
combined_args = "".join(self.accumulated_tool_call_args)
current_parsed = self.optimistic_json_parser.parse(combined_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
reasoning_message = ReasoningMessage( reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id, id=self.letta_assistant_message_id,
reasoning=inner_thoughts_diff, reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(), date=datetime.now(timezone.utc).isoformat(),
) )
self.reasoning_messages.append(reasoning_message) self.reasoning_messages.append(reasoning_message)
yield reasoning_message yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer elif isinstance(delta, BetaInputJSONDelta):
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args): if not self.anthropic_mode == EventMode.TOOL_USE:
self.inner_thoughts_complete = True raise RuntimeError(
# Flush all buffered tool call messages f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args += delta.partial_json
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=inner_thoughts_diff,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
# Start detecting special case of "send_message"
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
send_message_diff = current_send_message[len(previous_send_message) :]
# Only stream out if it's not an empty string
if send_message_diff:
yield AssistantMessage(
id=self.letta_assistant_message_id,
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
)
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
yield tool_call_msg
else:
self.tool_call_buffer.append(tool_call_msg)
# Set previous parse
self.previous_parse = current_parsed
elif isinstance(delta, BetaThinkingDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
self.anthropic_mode = None
except Exception as e:
logger.error("Error processing stream: %s", e)
raise
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]: def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
def _process_group( def _process_group(

View File

@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice,
from letta.constants import PRE_EXECUTION_MESSAGE_ARG from letta.constants import PRE_EXECUTION_MESSAGE_ARG
from letta.interfaces.utils import _format_sse_chunk from letta.interfaces.utils import _format_sse_chunk
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser from letta.server.rest_api.json_parser import OptimisticJSONParser
class OpenAIChatCompletionsStreamingInterface: class OpenAIChatCompletionsStreamingInterface:
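The parser import above was moved from `letta.server.rest_api.optimistic_json_parser` to `letta.server.rest_api.json_parser` in this commit. A minimal sketch of how the parser is used on partial tool-call arguments, assuming it behaves as the call sites elsewhere in this diff suggest (the example payload is made up):
```
from letta.server.rest_api.json_parser import OptimisticJSONParser

parser = OptimisticJSONParser()
# A truncated JSON fragment, as it would arrive mid-stream for a tool call.
partial_args = '{"message": "Hello, wor'
parsed = parser.parse(partial_args)  # best-effort parse of the incomplete payload
print(parsed.get("message", ""))
```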

View File

@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.message import Message as _Message from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
# NOTE: currently there is no GET /models, so we need to hardcode # NOTE: currently there is no GET /models, so we need to hardcode
# return MODEL_LIST # return MODEL_LIST
anthropic_override_key = ProviderManager().get_anthropic_override_key() if api_key:
if anthropic_override_key: anthropic_client = anthropic.Anthropic(api_key=api_key)
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
elif model_settings.anthropic_api_key: elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic() anthropic_client = anthropic.Anthropic()
else:
raise ValueError("No API key provided")
models = anthropic_client.models.list() models = anthropic_client.models.list()
models_json = models.model_dump() models_json = models.model_dump()
@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
put_inner_thoughts_in_kwargs: bool = False, put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False, extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None, max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
betas: List[str] = ["tools-2024-04-04"], betas: List[str] = ["tools-2024-04-04"],
) -> ChatCompletionResponse: ) -> ChatCompletionResponse:
"""https://docs.anthropic.com/claude/docs/tool-use""" """https://docs.anthropic.com/claude/docs/tool-use"""
anthropic_client = None anthropic_client = None
anthropic_override_key = ProviderManager().get_anthropic_override_key() if provider_name and provider_name != ProviderType.anthropic.value:
if anthropic_override_key: api_key = ProviderManager().get_override_key(provider_name)
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key: elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic() anthropic_client = anthropic.Anthropic()
else: else:
@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
put_inner_thoughts_in_kwargs: bool = False, put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False, extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None, max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
betas: List[str] = ["tools-2024-04-04"], betas: List[str] = ["tools-2024-04-04"],
) -> Generator[ChatCompletionChunkResponse, None, None]: ) -> Generator[ChatCompletionChunkResponse, None, None]:
"""Stream chat completions from Anthropic API. """Stream chat completions from Anthropic API.
@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
extended_thinking=extended_thinking, extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens, max_reasoning_tokens=max_reasoning_tokens,
) )
if provider_name and provider_name != ProviderType.anthropic.value:
anthropic_override_key = ProviderManager().get_anthropic_override_key() api_key = ProviderManager().get_override_key(provider_name)
if anthropic_override_key: anthropic_client = anthropic.Anthropic(api_key=api_key)
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
elif model_settings.anthropic_api_key: elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic() anthropic_client = anthropic.Anthropic()
@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
put_inner_thoughts_in_kwargs: bool = False, put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False, extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None, max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
create_message_id: bool = True, create_message_id: bool = True,
create_message_datetime: bool = True, create_message_datetime: bool = True,
betas: List[str] = ["tools-2024-04-04"], betas: List[str] = ["tools-2024-04-04"],
@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
extended_thinking=extended_thinking, extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens, max_reasoning_tokens=max_reasoning_tokens,
provider_name=provider_name,
betas=betas, betas=betas,
) )
): ):

View File

@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
from letta.llm_api.llm_client_base import LLMClientBase from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool from letta.schemas.openai.chat_completion_request import Tool
@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):
@trace_method @trace_method
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]: def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
override_key = ProviderManager().get_anthropic_override_key() override_key = None
if self.provider_name and self.provider_name != ProviderType.anthropic.value:
override_key = ProviderManager().get_override_key(self.provider_name)
if async_client: if async_client:
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic() return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic() return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
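A hedged sketch of the BYOK key-resolution pattern the client now follows: a named non-default provider resolves its stored key through `ProviderManager.get_override_key`, otherwise the caller falls back to the environment settings (the helper function itself is illustrative, not part of the diff):
```
from typing import Optional

from letta.schemas.enums import ProviderType
from letta.services.provider_manager import ProviderManager


def resolve_anthropic_key(provider_name: Optional[str]) -> Optional[str]:
    # A custom (non-"anthropic") provider name means an org-scoped BYOK key.
    if provider_name and provider_name != ProviderType.anthropic.value:
        return ProviderManager().get_override_key(provider_name)
    # None lets the caller fall back to model_settings.anthropic_api_key / env vars.
    return None
```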

View File

@ -63,7 +63,7 @@ class GoogleVertexClient(GoogleAIClient):
# Add thinking_config # Add thinking_config
# If enable_reasoner is False, set thinking_budget to 0 # If enable_reasoner is False, set thinking_budget to 0
# Otherwise, use the value from max_reasoning_tokens # Otherwise, use the value from max_reasoning_tokens
thinking_budget = 0 if not self.llm_config.enable_reasoner else self.llm_config.max_reasoning_tokens thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens
thinking_config = ThinkingConfig( thinking_config = ThinkingConfig(
thinking_budget=thinking_budget, thinking_budget=thinking_budget,
) )

View File

@ -24,6 +24,7 @@ from letta.llm_api.openai import (
from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@ -171,6 +172,10 @@ def create(
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy # only is a problem if we are *not* using an openai proxy
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"]) raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
elif model_settings.openai_api_key is None: elif model_settings.openai_api_key is None:
# the openai python client requires a dummy API key # the openai python client requires a dummy API key
api_key = "DUMMY_API_KEY" api_key = "DUMMY_API_KEY"
@ -373,6 +378,7 @@ def create(
stream_interface=stream_interface, stream_interface=stream_interface,
extended_thinking=llm_config.enable_reasoner, extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens, max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
name=name, name=name,
) )
@ -383,6 +389,7 @@ def create(
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
extended_thinking=llm_config.enable_reasoner, extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens, max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
) )
if llm_config.put_inner_thoughts_in_kwargs: if llm_config.put_inner_thoughts_in_kwargs:

View File

@ -9,8 +9,10 @@ class LLMClient:
@staticmethod @staticmethod
def create( def create(
provider: ProviderType, provider_type: ProviderType,
provider_name: Optional[str] = None,
put_inner_thoughts_first: bool = True, put_inner_thoughts_first: bool = True,
actor_id: Optional[str] = None,
) -> Optional[LLMClientBase]: ) -> Optional[LLMClientBase]:
""" """
Create an LLM client based on the model endpoint type. Create an LLM client based on the model endpoint type.
@ -25,30 +27,38 @@ class LLMClient:
Raises: Raises:
ValueError: If the model endpoint type is not supported ValueError: If the model endpoint type is not supported
""" """
match provider: match provider_type:
case ProviderType.google_ai: case ProviderType.google_ai:
from letta.llm_api.google_ai_client import GoogleAIClient from letta.llm_api.google_ai_client import GoogleAIClient
return GoogleAIClient( return GoogleAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first, put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
) )
case ProviderType.google_vertex: case ProviderType.google_vertex:
from letta.llm_api.google_vertex_client import GoogleVertexClient from letta.llm_api.google_vertex_client import GoogleVertexClient
return GoogleVertexClient( return GoogleVertexClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first, put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
) )
case ProviderType.anthropic: case ProviderType.anthropic:
from letta.llm_api.anthropic_client import AnthropicClient from letta.llm_api.anthropic_client import AnthropicClient
return AnthropicClient( return AnthropicClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first, put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
) )
case ProviderType.openai: case ProviderType.openai:
from letta.llm_api.openai_client import OpenAIClient from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient( return OpenAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first, put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
) )
case _: case _:
return None return None
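A short usage sketch of the updated factory, mirroring the `summarize_messages` call site further down in this diff (the wrapper function and the import path are assumptions):
```
from typing import Optional

from letta.llm_api.llm_client import LLMClient  # assumed import path
from letta.schemas.llm_config import LLMConfig


def build_llm_client(llm_config: LLMConfig, actor_id: Optional[str] = None):
    # provider_type replaces the old `provider` argument; provider_name and actor_id are new.
    return LLMClient.create(
        provider_type=llm_config.model_endpoint_type,
        provider_name=llm_config.provider_name,
        put_inner_thoughts_first=False,
        actor_id=actor_id,
    )
```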

View File

@ -20,9 +20,13 @@ class LLMClientBase:
def __init__( def __init__(
self, self,
provider_name: Optional[str] = None,
put_inner_thoughts_first: Optional[bool] = True, put_inner_thoughts_first: Optional[bool] = True,
use_tool_naming: bool = True, use_tool_naming: bool = True,
actor_id: Optional[str] = None,
): ):
self.actor_id = actor_id
self.provider_name = provider_name
self.put_inner_thoughts_first = put_inner_thoughts_first self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming self.use_tool_naming = use_tool_naming

View File

@ -157,11 +157,17 @@ def build_openai_chat_completions_request(
# if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model: # if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model:
# data.response_format = {"type": "json_object"} # data.response_format = {"type": "json_object"}
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: # always set user id for openai requests
# override user id for inference.memgpt.ai if user_id:
import uuid data.user = str(user_id)
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
if not user_id:
# override user id for inference.letta.com
import uuid
data.user = str(uuid.UUID(int=0))
data.user = str(uuid.UUID(int=0))
data.model = "memgpt-openai" data.model = "memgpt-openai"
if use_structured_output and data.tools is not None and len(data.tools) > 0: if use_structured_output and data.tools is not None and len(data.tools) > 0:
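A hedged summary of the user-id behavior introduced above: when a user id is known it is always attached to the OpenAI request, and only the hosted Letta endpoint falls back to the all-zero UUID when it is missing (the helper below is illustrative, not part of the diff):
```
import uuid
from typing import Optional


def resolve_request_user(user_id: Optional[str], is_letta_endpoint: bool) -> Optional[str]:
    if user_id:
        return str(user_id)  # always forward the real user id when available
    if is_letta_endpoint:
        return str(uuid.UUID(int=0))  # anonymous placeholder for inference.letta.com
    return None
```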

View File

@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@ -64,7 +65,14 @@ def supports_parallel_tool_calling(model: str) -> bool:
class OpenAIClient(LLMClientBase): class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict: def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") api_key = None
if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
if not api_key:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
# supposedly the openai python client requires a dummy API key # supposedly the openai python client requires a dummy API key
api_key = api_key or "DUMMY_API_KEY" api_key = api_key or "DUMMY_API_KEY"
kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint} kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}
@ -135,11 +143,17 @@ class OpenAIClient(LLMClientBase):
temperature=llm_config.temperature if supports_temperature_param(model) else None, temperature=llm_config.temperature if supports_temperature_param(model) else None,
) )
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: # always set user id for openai requests
# override user id for inference.memgpt.ai if self.actor_id:
import uuid data.user = self.actor_id
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
if not self.actor_id:
# override user id for inference.letta.com
import uuid
data.user = str(uuid.UUID(int=0))
data.user = str(uuid.UUID(int=0))
data.model = "memgpt-openai" data.model = "memgpt-openai"
if data.tools is not None and len(data.tools) > 0: if data.tools is not None and len(data.tools) > 0:

View File

@ -79,8 +79,10 @@ def summarize_messages(
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
llm_client = LLMClient.create( llm_client = LLMClient.create(
provider=llm_config_no_inner_thoughts.model_endpoint_type, provider_name=llm_config_no_inner_thoughts.provider_name,
provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
put_inner_thoughts_first=False, put_inner_thoughts_first=False,
actor_id=agent_state.created_by_id,
) )
# try to use new client, otherwise fallback to old flow # try to use new client, otherwise fallback to old flow
# TODO: we can just directly call the LLM here? # TODO: we can just directly call the LLM here?

View File

@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin from letta.orm.mixins import OrganizationMixin
@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):
__tablename__ = "providers" __tablename__ = "providers"
__pydantic_model__ = PydanticProvider __pydantic_model__ = PydanticProvider
__table_args__ = (
UniqueConstraint(
"name",
"organization_id",
name="unique_name_organization_id",
),
)
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider") name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.") api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
# relationships # relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers") organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")

View File

@ -56,7 +56,6 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
name: str = Field(..., description="The name of the agent.") name: str = Field(..., description="The name of the agent.")
# tool rules # tool rules
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.") tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
# in-context memory # in-context memory
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.") message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")

View File

@ -6,6 +6,17 @@ class ProviderType(str, Enum):
google_ai = "google_ai" google_ai = "google_ai"
google_vertex = "google_vertex" google_vertex = "google_vertex"
openai = "openai" openai = "openai"
letta = "letta"
deepseek = "deepseek"
lmstudio_openai = "lmstudio_openai"
xai = "xai"
mistral = "mistral"
ollama = "ollama"
groq = "groq"
together = "together"
azure = "azure"
vllm = "vllm"
bedrock = "bedrock"
class MessageRole(str, Enum): class MessageRole(str, Enum):

View File

@ -32,6 +32,14 @@ class Group(GroupBase):
sleeptime_agent_frequency: Optional[int] = Field(None, description="") sleeptime_agent_frequency: Optional[int] = Field(None, description="")
turns_counter: Optional[int] = Field(None, description="") turns_counter: Optional[int] = Field(None, description="")
last_processed_message_id: Optional[str] = Field(None, description="") last_processed_message_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class ManagerConfig(BaseModel): class ManagerConfig(BaseModel):
@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig):
class VoiceSleeptimeManager(ManagerConfig): class VoiceSleeptimeManager(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="") manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: str = Field(..., description="") manager_agent_id: str = Field(..., description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class VoiceSleeptimeManagerUpdate(ManagerConfig): class VoiceSleeptimeManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="") manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: Optional[str] = Field(None, description="") manager_agent_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
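A hedged sketch of configuring the new buffer bounds on a voice sleeptime group; the agent id and numbers are made up, and the module path is assumed:
```
from letta.schemas.group import VoiceSleeptimeManager  # assumed module path

manager_config = VoiceSleeptimeManager(
    manager_agent_id="agent-123",      # hypothetical id of the sleeptime manager agent
    max_message_buffer_length=20,      # best-effort cap on buffered conversation messages
    min_message_buffer_length=8,       # best-effort floor kept after the buffer is trimmed
)
```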
# class SwarmGroup(ManagerConfig): # class SwarmGroup(ManagerConfig):

View File

@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
"xai", "xai",
] = Field(..., description="The endpoint type for the model.") ] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.") context_window: int = Field(..., description="The context window size for the model.")
put_inner_thoughts_in_kwargs: Optional[bool] = Field( put_inner_thoughts_in_kwargs: Optional[bool] = Field(

View File

@ -2,8 +2,8 @@ from typing import Dict
LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = { LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
"anthropic": { "anthropic": {
"claude-3-5-haiku-20241022": "claude-3.5-haiku", "claude-3-5-haiku-20241022": "claude-3-5-haiku",
"claude-3-5-sonnet-20241022": "claude-3.5-sonnet", "claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
"claude-3-opus-20240229": "claude-3-opus", "claude-3-opus-20240229": "claude-3-opus",
}, },
"openai": { "openai": {

View File

@ -1,6 +1,6 @@
import warnings import warnings
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Literal, Optional
from pydantic import Field, model_validator from pydantic import Field, model_validator
@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderType
from letta.schemas.letta_base import LettaBase from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
from letta.settings import model_settings
class ProviderBase(LettaBase): class ProviderBase(LettaBase):
@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
class Provider(ProviderBase): class Provider(ProviderBase):
id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
name: str = Field(..., description="The name of the provider") name: str = Field(..., description="The name of the provider")
provider_type: ProviderType = Field(..., description="The type of the provider")
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.") api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
base_url: Optional[str] = Field(None, description="Base URL for the provider.")
organization_id: Optional[str] = Field(None, description="The organization id of the user") organization_id: Optional[str] = Field(None, description="The organization id of the user")
updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.") updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
@model_validator(mode="after")
def default_base_url(self):
if self.provider_type == ProviderType.openai and self.base_url is None:
self.base_url = model_settings.openai_api_base
return self
def resolve_identifier(self): def resolve_identifier(self):
if not self.id: if not self.id:
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
@ -59,9 +69,41 @@ class Provider(ProviderBase):
return f"{self.name}/{model_name}" return f"{self.name}/{model_name}"
def cast_to_subtype(self):
match (self.provider_type):
case ProviderType.letta:
return LettaProvider(**self.model_dump(exclude_none=True))
case ProviderType.openai:
return OpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic:
return AnthropicProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic_bedrock:
return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
case ProviderType.ollama:
return OllamaProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_ai:
return GoogleAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_vertex:
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
case ProviderType.azure:
return AzureProvider(**self.model_dump(exclude_none=True))
case ProviderType.groq:
return GroqProvider(**self.model_dump(exclude_none=True))
case ProviderType.together:
return TogetherProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm_chat_completions:
return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm_completions:
return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
case ProviderType.xai:
return XAIProvider(**self.model_dump(exclude_none=True))
case _:
raise ValueError(f"Unknown provider type: {self.provider_type}")
class ProviderCreate(ProviderBase): class ProviderCreate(ProviderBase):
name: str = Field(..., description="The name of the provider.") name: str = Field(..., description="The name of the provider.")
provider_type: ProviderType = Field(..., description="The type of the provider.")
api_key: str = Field(..., description="API key used for requests to the provider.") api_key: str = Field(..., description="API key used for requests to the provider.")
@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):
class LettaProvider(Provider): class LettaProvider(Provider):
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
name: str = "letta"
def list_llm_models(self) -> List[LLMConfig]: def list_llm_models(self) -> List[LLMConfig]:
return [ return [
@ -81,6 +122,7 @@ class LettaProvider(Provider):
model_endpoint=LETTA_MODEL_ENDPOINT, model_endpoint=LETTA_MODEL_ENDPOINT,
context_window=8192, context_window=8192,
handle=self.get_handle("letta-free"), handle=self.get_handle("letta-free"),
provider_name=self.name,
) )
] ]
@ -98,7 +140,7 @@ class LettaProvider(Provider):
class OpenAIProvider(Provider): class OpenAIProvider(Provider):
name: str = "openai" provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the OpenAI API.") api_key: str = Field(..., description="API key for the OpenAI API.")
base_url: str = Field(..., description="Base URL for the OpenAI API.") base_url: str = Field(..., description="Base URL for the OpenAI API.")
@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=context_window_size, context_window=context_window_size,
handle=self.get_handle(model_name), handle=self.get_handle(model_name),
provider_name=self.name,
) )
) )
@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
* It also does not support native function calling * It also does not support native function calling
""" """
name: str = "deepseek" provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
api_key: str = Field(..., description="API key for the DeepSeek API.") api_key: str = Field(..., description="API key for the DeepSeek API.")
@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
context_window=context_window_size, context_window=context_window_size,
handle=self.get_handle(model_name), handle=self.get_handle(model_name),
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
provider_name=self.name,
) )
) )
@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):
class LMStudioOpenAIProvider(OpenAIProvider): class LMStudioOpenAIProvider(OpenAIProvider):
name: str = "lmstudio-openai" provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.") api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
class XAIProvider(OpenAIProvider): class XAIProvider(OpenAIProvider):
"""https://docs.x.ai/docs/api-reference""" """https://docs.x.ai/docs/api-reference"""
name: str = "xai" provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the xAI/Grok API.") api_key: str = Field(..., description="API key for the xAI/Grok API.")
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=context_window_size, context_window=context_window_size,
handle=self.get_handle(model_name), handle=self.get_handle(model_name),
provider_name=self.name,
) )
) )
@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):
class AnthropicProvider(Provider): class AnthropicProvider(Provider):
name: str = "anthropic" provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Anthropic API.") api_key: str = Field(..., description="API key for the Anthropic API.")
base_url: str = "https://api.anthropic.com/v1" base_url: str = "https://api.anthropic.com/v1"
@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
handle=self.get_handle(model["id"]), handle=self.get_handle(model["id"]),
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
max_tokens=max_tokens, max_tokens=max_tokens,
provider_name=self.name,
) )
) )
return configs return configs
@ -572,7 +618,7 @@ class AnthropicProvider(Provider):
class MistralProvider(Provider): class MistralProvider(Provider):
name: str = "mistral" provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Mistral API.") api_key: str = Field(..., description="API key for the Mistral API.")
base_url: str = "https://api.mistral.ai/v1" base_url: str = "https://api.mistral.ai/v1"
@ -596,6 +642,7 @@ class MistralProvider(Provider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=model["max_context_length"], context_window=model["max_context_length"],
handle=self.get_handle(model["id"]), handle=self.get_handle(model["id"]),
provider_name=self.name,
) )
) )
@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
""" """
name: str = "ollama" provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the Ollama API.") base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field( default_prompt_formatter: str = Field(
@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
model_wrapper=self.default_prompt_formatter, model_wrapper=self.default_prompt_formatter,
context_window=context_window, context_window=context_window,
handle=self.get_handle(model["name"]), handle=self.get_handle(model["name"]),
provider_name=self.name,
) )
) )
return configs return configs
@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):
class GroqProvider(OpenAIProvider): class GroqProvider(OpenAIProvider):
name: str = "groq" provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
base_url: str = "https://api.groq.com/openai/v1" base_url: str = "https://api.groq.com/openai/v1"
api_key: str = Field(..., description="API key for the Groq API.") api_key: str = Field(..., description="API key for the Groq API.")
@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=model["context_window"], context_window=model["context_window"],
handle=self.get_handle(model["id"]), handle=self.get_handle(model["id"]),
provider_name=self.name,
) )
) )
return configs return configs
@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
function calling support is limited. function calling support is limited.
""" """
name: str = "together" provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
base_url: str = "https://api.together.ai/v1" base_url: str = "https://api.together.ai/v1"
api_key: str = Field(..., description="API key for the TogetherAI API.") api_key: str = Field(..., description="API key for the TogetherAI API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
model_wrapper=self.default_prompt_formatter, model_wrapper=self.default_prompt_formatter,
context_window=context_window_size, context_window=context_window_size,
handle=self.get_handle(model_name), handle=self.get_handle(model_name),
provider_name=self.name,
) )
) )
@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):
class GoogleAIProvider(Provider): class GoogleAIProvider(Provider):
# gemini # gemini
name: str = "google_ai" provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Google AI API.") api_key: str = Field(..., description="API key for the Google AI API.")
base_url: str = "https://generativelanguage.googleapis.com" base_url: str = "https://generativelanguage.googleapis.com"
@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
# filter by model names # filter by model names
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
# TODO remove manual filtering for gemini-pro
# Add support for all gemini models # Add support for all gemini models
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
context_window=self.get_model_context_window(model), context_window=self.get_model_context_window(model),
handle=self.get_handle(model), handle=self.get_handle(model),
max_tokens=8192, max_tokens=8192,
provider_name=self.name,
) )
) )
return configs return configs
@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):
class GoogleVertexProvider(Provider): class GoogleVertexProvider(Provider):
name: str = "google_vertex" provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
context_window=context_length, context_window=context_length,
handle=self.get_handle(model), handle=self.get_handle(model),
max_tokens=8192, max_tokens=8192,
provider_name=self.name,
) )
) )
return configs return configs
@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):
class AzureProvider(Provider): class AzureProvider(Provider):
name: str = "azure" provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
base_url: str = Field( base_url: str = Field(
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
model_endpoint=model_endpoint, model_endpoint=model_endpoint,
context_window=context_window_size, context_window=context_window_size,
handle=self.get_handle(model_name), handle=self.get_handle(model_name),
provider_name=self.name,
), ),
) )
return configs return configs
@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy""" """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
# NOTE: vLLM only serves one model at a time (so could configure that through env variables) # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm" provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the vLLM API.") base_url: str = Field(..., description="Base URL for the vLLM API.")
def list_llm_models(self) -> List[LLMConfig]: def list_llm_models(self) -> List[LLMConfig]:
@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
model_endpoint=self.base_url, model_endpoint=self.base_url,
context_window=model["max_model_len"], context_window=model["max_model_len"],
handle=self.get_handle(model["id"]), handle=self.get_handle(model["id"]),
provider_name=self.name,
) )
) )
return configs return configs
@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper""" """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
# NOTE: vLLM only serves one model at a time (so could configure that through env variables) # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm" provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the vLLM API.") base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
model_wrapper=self.default_prompt_formatter, model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"], context_window=model["max_model_len"],
handle=self.get_handle(model["id"]), handle=self.get_handle(model["id"]),
provider_name=self.name,
) )
) )
return configs return configs
@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):
class AnthropicBedrockProvider(Provider): class AnthropicBedrockProvider(Provider):
name: str = "bedrock" provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
aws_region: str = Field(..., description="AWS region for Bedrock") aws_region: str = Field(..., description="AWS region for Bedrock")
def list_llm_models(self): def list_llm_models(self):
@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
configs.append( configs.append(
LLMConfig( LLMConfig(
model=model_arn, model=model_arn,
model_endpoint_type=self.name, model_endpoint_type=self.provider_type.value,
model_endpoint=None, model_endpoint=None,
context_window=self.get_model_context_window(model_arn), context_window=self.get_model_context_window(model_arn),
handle=self.get_handle(model_arn), handle=self.get_handle(model_arn),
provider_name=self.name,
) )
) )
return configs return configs

View File

@ -11,13 +11,9 @@ from letta.constants import (
MCP_TOOL_TAG_NAME_PREFIX, MCP_TOOL_TAG_NAME_PREFIX,
) )
from letta.functions.ast_parsers import get_function_name_and_description from letta.functions.ast_parsers import get_function_name_and_description
from letta.functions.composio_helpers import generate_composio_tool_wrapper
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
from letta.functions.helpers import ( from letta.functions.helpers import generate_langchain_tool_wrapper, generate_mcp_tool_wrapper, generate_model_from_args_json_schema
generate_composio_tool_wrapper,
generate_langchain_tool_wrapper,
generate_mcp_tool_wrapper,
generate_model_from_args_json_schema,
)
from letta.functions.mcp_client.types import MCPTool from letta.functions.mcp_client.types import MCPTool
from letta.functions.schema_generator import ( from letta.functions.schema_generator import (
generate_schema_from_args_schema_v2, generate_schema_from_args_schema_v2,
@ -176,8 +172,7 @@ class ToolCreate(LettaBase):
Returns: Returns:
Tool: A Letta Tool initialized with attributes derived from the Composio tool. Tool: A Letta Tool initialized with attributes derived from the Composio tool.
""" """
from composio import LogLevel from composio import ComposioToolSet, LogLevel
from composio_langchain import ComposioToolSet
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, lock=False) composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, lock=False)
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False) composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)

View File

@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__ from letta.__init__ import __version__
from letta.agents.exceptions import IncompatibleAgentType
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
@ -173,6 +174,17 @@ def create_application() -> "FastAPI":
def shutdown_scheduler(): def shutdown_scheduler():
shutdown_cron_scheduler() shutdown_cron_scheduler()
@app.exception_handler(IncompatibleAgentType)
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
return JSONResponse(
status_code=400,
content={
"detail": str(exc),
"expected_type": exc.expected_type,
"actual_type": exc.actual_type,
},
)
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception): async def generic_error_handler(request: Request, exc: Exception):
# Log the actual error for debugging # Log the actual error for debugging

View File

@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface from letta.streaming_interface import AgentChunkStreamingInterface
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -28,7 +28,7 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
from letta.utils import parse_json from letta.utils import parse_json
@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg) self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
# @matt's changes here, adopting new optimistic json parser # @matt's changes here, adopting new optimistic json parser
self.current_function_arguments = [] self.current_function_arguments = ""
self.optimistic_json_parser = OptimisticJSONParser() self.optimistic_json_parser = OptimisticJSONParser()
self.current_json_parse_result = {} self.current_json_parse_result = {}
@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_start(self): def stream_start(self):
"""Initialize streaming by activating the generator and clearing any old chunks.""" """Initialize streaming by activating the generator and clearing any old chunks."""
self.streaming_chat_completion_mode_function_name = None self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = [] self.current_function_arguments = ""
self.current_json_parse_result = {} self.current_json_parse_result = {}
if not self._active: if not self._active:
@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_end(self): def stream_end(self):
"""Clean up the stream by deactivating and clearing chunks.""" """Clean up the stream by deactivating and clearing chunks."""
self.streaming_chat_completion_mode_function_name = None self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = [] self.current_function_arguments = ""
self.current_json_parse_result = {} self.current_json_parse_result = {}
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: # if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# early exit to turn into content mode # early exit to turn into content mode
return None return None
if tool_call.function.arguments: if tool_call.function.arguments:
self.current_function_arguments.append(tool_call.function.arguments) self.current_function_arguments += tool_call.function.arguments
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name: if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
# Strip out any extras tokens # Strip out any extras tokens
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
combined_args = "".join(self.current_function_arguments) parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg self.assistant_message_tool_kwarg
@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# updates_inner_thoughts = "" # updates_inner_thoughts = ""
# else: # OpenAI # else: # OpenAI
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) # updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
self.current_function_arguments.append(tool_call.function.arguments) self.current_function_arguments += tool_call.function.arguments
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
# If we have inner thoughts, we should output them as a chunk # If we have inner thoughts, we should output them as a chunk
@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: THIS IS HORRIBLE # TODO: THIS IS HORRIBLE
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
combined_args = "".join(self.current_function_arguments) parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg self.assistant_message_tool_kwarg
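A toy illustration of the change in this file: streamed tool-call argument deltas are now accumulated into a single string (instead of a list that was joined each time) and re-parsed optimistically on every chunk so the assistant message surfaces as early as the parser can recover it. This assumes the letta package from this commit is importable; the fragments are made up.
```
from letta.server.rest_api.json_parser import OptimisticJSONParser

parser = OptimisticJSONParser()
current_function_arguments = ""

# Simulated streamed argument deltas for a send_message tool call
for fragment in ['{"mess', 'age": "Hel', 'lo there!"', "}"]:
    current_function_arguments += fragment
    parsed_args = parser.parse(current_function_arguments)
    if parsed_args.get("message"):
        # Prints the partial message as soon as the non-strict parser recovers it
        print("partial message so far:", parsed_args["message"])
```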

View File

@ -1,7 +1,43 @@
import json import json
from abc import ABC, abstractmethod
from typing import Any
from pydantic_core import from_json
from letta.log import get_logger
logger = get_logger(__name__)
class OptimisticJSONParser: class JSONParser(ABC):
@abstractmethod
def parse(self, input_str: str) -> Any:
raise NotImplementedError()
class PydanticJSONParser(JSONParser):
"""
https://docs.pydantic.dev/latest/concepts/json/#json-parsing
If `strict` is True, we will not allow for partial parsing of JSON.
Compared with `OptimisticJSONParser`, this parser is more strict.
    Note: This will not partially parse strings, which may decrease parsing speed for message strings
"""
def __init__(self, strict=False):
self.strict = strict
def parse(self, input_str: str) -> Any:
if not input_str:
return {}
try:
return from_json(input_str, allow_partial="trailing-strings" if not self.strict else False)
except ValueError as e:
logger.error(f"Failed to parse JSON: {e}")
raise
class OptimisticJSONParser(JSONParser):
""" """
A JSON parser that attempts to parse a given string using `json.loads`, A JSON parser that attempts to parse a given string using `json.loads`,
and if that fails, it parses as much valid JSON as possible while and if that fails, it parses as much valid JSON as possible while
@ -13,25 +49,25 @@ class OptimisticJSONParser:
def __init__(self, strict=False): def __init__(self, strict=False):
self.strict = strict self.strict = strict
self.parsers = { self.parsers = {
" ": self.parse_space, " ": self._parse_space,
"\r": self.parse_space, "\r": self._parse_space,
"\n": self.parse_space, "\n": self._parse_space,
"\t": self.parse_space, "\t": self._parse_space,
"[": self.parse_array, "[": self._parse_array,
"{": self.parse_object, "{": self._parse_object,
'"': self.parse_string, '"': self._parse_string,
"t": self.parse_true, "t": self._parse_true,
"f": self.parse_false, "f": self._parse_false,
"n": self.parse_null, "n": self._parse_null,
} }
# Register number parser for digits and signs # Register number parser for digits and signs
for char in "0123456789.-": for char in "0123456789.-":
self.parsers[char] = self.parse_number self.parsers[char] = self.parse_number
self.last_parse_reminding = None self.last_parse_reminding = None
self.on_extra_token = self.default_on_extra_token self.on_extra_token = self._default_on_extra_token
def default_on_extra_token(self, text, data, reminding): def _default_on_extra_token(self, text, data, reminding):
print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}") print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}")
def parse(self, input_str): def parse(self, input_str):
@ -45,7 +81,7 @@ class OptimisticJSONParser:
try: try:
return json.loads(input_str) return json.loads(input_str)
except json.JSONDecodeError as decode_error: except json.JSONDecodeError as decode_error:
data, reminding = self.parse_any(input_str, decode_error) data, reminding = self._parse_any(input_str, decode_error)
self.last_parse_reminding = reminding self.last_parse_reminding = reminding
if self.on_extra_token and reminding: if self.on_extra_token and reminding:
self.on_extra_token(input_str, data, reminding) self.on_extra_token(input_str, data, reminding)
@ -53,7 +89,7 @@ class OptimisticJSONParser:
else: else:
return json.loads("{}") return json.loads("{}")
def parse_any(self, input_str, decode_error): def _parse_any(self, input_str, decode_error):
"""Determine which parser to use based on the first character.""" """Determine which parser to use based on the first character."""
if not input_str: if not input_str:
raise decode_error raise decode_error
@ -62,11 +98,11 @@ class OptimisticJSONParser:
raise decode_error raise decode_error
return parser(input_str, decode_error) return parser(input_str, decode_error)
def parse_space(self, input_str, decode_error): def _parse_space(self, input_str, decode_error):
"""Strip leading whitespace and parse again.""" """Strip leading whitespace and parse again."""
return self.parse_any(input_str.strip(), decode_error) return self._parse_any(input_str.strip(), decode_error)
def parse_array(self, input_str, decode_error): def _parse_array(self, input_str, decode_error):
"""Parse a JSON array, returning the list and remaining string.""" """Parse a JSON array, returning the list and remaining string."""
# Skip the '[' # Skip the '['
input_str = input_str[1:] input_str = input_str[1:]
@ -77,7 +113,7 @@ class OptimisticJSONParser:
# Skip the ']' # Skip the ']'
input_str = input_str[1:] input_str = input_str[1:]
break break
value, input_str = self.parse_any(input_str, decode_error) value, input_str = self._parse_any(input_str, decode_error)
array_values.append(value) array_values.append(value)
input_str = input_str.strip() input_str = input_str.strip()
if input_str.startswith(","): if input_str.startswith(","):
@ -85,7 +121,7 @@ class OptimisticJSONParser:
input_str = input_str[1:].strip() input_str = input_str[1:].strip()
return array_values, input_str return array_values, input_str
def parse_object(self, input_str, decode_error): def _parse_object(self, input_str, decode_error):
"""Parse a JSON object, returning the dict and remaining string.""" """Parse a JSON object, returning the dict and remaining string."""
# Skip the '{' # Skip the '{'
input_str = input_str[1:] input_str = input_str[1:]
@ -96,7 +132,7 @@ class OptimisticJSONParser:
# Skip the '}' # Skip the '}'
input_str = input_str[1:] input_str = input_str[1:]
break break
key, input_str = self.parse_any(input_str, decode_error) key, input_str = self._parse_any(input_str, decode_error)
input_str = input_str.strip() input_str = input_str.strip()
if not input_str or input_str[0] == "}": if not input_str or input_str[0] == "}":
@ -113,7 +149,7 @@ class OptimisticJSONParser:
input_str = input_str[1:] input_str = input_str[1:]
break break
value, input_str = self.parse_any(input_str, decode_error) value, input_str = self._parse_any(input_str, decode_error)
obj[key] = value obj[key] = value
input_str = input_str.strip() input_str = input_str.strip()
if input_str.startswith(","): if input_str.startswith(","):
@ -121,7 +157,7 @@ class OptimisticJSONParser:
input_str = input_str[1:].strip() input_str = input_str[1:].strip()
return obj, input_str return obj, input_str
def parse_string(self, input_str, decode_error): def _parse_string(self, input_str, decode_error):
"""Parse a JSON string, respecting escaped quotes if present.""" """Parse a JSON string, respecting escaped quotes if present."""
end = input_str.find('"', 1) end = input_str.find('"', 1)
while end != -1 and input_str[end - 1] == "\\": while end != -1 and input_str[end - 1] == "\\":
@ -166,19 +202,19 @@ class OptimisticJSONParser:
return num, remainder return num, remainder
def parse_true(self, input_str, decode_error): def _parse_true(self, input_str, decode_error):
"""Parse a 'true' value.""" """Parse a 'true' value."""
if input_str.startswith(("t", "T")): if input_str.startswith(("t", "T")):
return True, input_str[4:] return True, input_str[4:]
raise decode_error raise decode_error
def parse_false(self, input_str, decode_error): def _parse_false(self, input_str, decode_error):
"""Parse a 'false' value.""" """Parse a 'false' value."""
if input_str.startswith(("f", "F")): if input_str.startswith(("f", "F")):
return False, input_str[5:] return False, input_str[5:]
raise decode_error raise decode_error
def parse_null(self, input_str, decode_error): def _parse_null(self, input_str, decode_error):
"""Parse a 'null' value.""" """Parse a 'null' value."""
if input_str.startswith("n"): if input_str.startswith("n"):
return None, input_str[4:] return None, input_str[4:]
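A quick comparison of the two parsers defined in this file, assuming both classes live in `letta.server.rest_api.json_parser` as added in this commit and that `pydantic_core` is installed.
```
from letta.server.rest_api.json_parser import OptimisticJSONParser, PydanticJSONParser

truncated = '{"message": "partial resp'

# Best-effort dict recovered from the truncated prefix
print(OptimisticJSONParser().parse(truncated))

# Delegates to pydantic_core.from_json with allow_partial="trailing-strings"
print(PydanticJSONParser(strict=False).parse(truncated))

# strict=True refuses partial input: from_json raises and parse() re-raises
try:
    PydanticJSONParser(strict=True).parse(truncated)
except ValueError as e:
    print("strict parser rejected partial JSON:", e)
```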

View File

@ -678,7 +678,7 @@ async def send_message_streaming(
server: SyncServer = Depends(get_letta_server), server: SyncServer = Depends(get_letta_server),
request: LettaStreamingRequest = Body(...), request: LettaStreamingRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
): ) -> StreamingResponse | LettaResponse:
""" """
Process a user message and return the agent's response. Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent. This endpoint accepts a message from a user and processes it through the agent.

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, Query
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[LLMConfig], operation_id="list_models") @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
def list_llm_models( def list_llm_models(
byok_only: Optional[bool] = Query(None),
server: "SyncServer" = Depends(get_letta_server), server: "SyncServer" = Depends(get_letta_server),
): ):
models = server.list_llm_models() models = server.list_llm_models(byok_only=byok_only)
# print(models) # print(models)
return models return models
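A hedged example of hitting the models route with the new `byok_only` filter. The port (8283) and the `/v1` prefix are assumptions about a default local server, and the response field names are not guaranteed by this diff.
```
import requests

resp = requests.get("http://localhost:8283/v1/models/", params={"byok_only": "true"})
resp.raise_for_status()
for model in resp.json():
    # Field names assumed from LLMConfig; fall back to the raw entry if absent
    print(model.get("handle") or model.get("model") or model)
```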

View File

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
from letta.server.rest_api.utils import get_letta_server from letta.server.rest_api.utils import get_letta_server
@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/", response_model=List[Provider], operation_id="list_providers") @router.get("/", response_model=List[Provider], operation_id="list_providers")
def list_providers( def list_providers(
name: Optional[str] = Query(None),
provider_type: Optional[ProviderType] = Query(None),
after: Optional[str] = Query(None), after: Optional[str] = Query(None),
limit: Optional[int] = Query(50), limit: Optional[int] = Query(50),
actor_id: Optional[str] = Header(None, alias="user_id"), actor_id: Optional[str] = Header(None, alias="user_id"),
@ -23,7 +26,7 @@ def list_providers(
""" """
try: try:
actor = server.user_manager.get_user_or_default(user_id=actor_id) actor = server.user_manager.get_user_or_default(user_id=actor_id)
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor) providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
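A similar hedged example for the providers route, exercising the new `name` and `provider_type` query filters. The provider name, actor id, port, and `/v1` prefix are all placeholders; `anthropic` is assumed to be a valid `ProviderType` value.
```
import requests

resp = requests.get(
    "http://localhost:8283/v1/providers/",
    params={"name": "my-anthropic-byok", "provider_type": "anthropic"},
    headers={"user_id": "user-00000000"},  # optional actor header, placeholder id
)
resp.raise_for_status()
print(resp.json())
```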

View File

@ -54,8 +54,6 @@ async def create_voice_chat_completions(
block_manager=server.block_manager, block_manager=server.block_manager,
passage_manager=server.passage_manager, passage_manager=server.passage_manager,
actor=actor, actor=actor,
message_buffer_limit=8,
message_buffer_min=4,
) )
# Return the streaming generator # Return the streaming generator

View File

@ -16,6 +16,7 @@ from pydantic import BaseModel
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
@ -143,27 +144,15 @@ def log_error_to_sentry(e):
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]: def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
""" """
Converts a user input message into the internal structured format. Converts a user input message into the internal structured format.
"""
new_messages = []
for input_message in input_messages:
# Construct the Message object
new_message = Message(
id=f"message-{uuid.uuid4()}",
role=input_message.role,
content=input_message.content,
name=input_message.name,
otid=input_message.otid,
sender_id=input_message.sender_id,
organization_id=actor.organization_id,
agent_id=agent_id,
model=None,
tool_calls=None,
tool_call_id=None,
created_at=get_utc_time(),
)
new_messages.append(new_message)
return new_messages TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`,
we should unify this when it's clear what message attributes we need.
"""
messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False)
for message in messages:
message.organization_id = actor.organization_id
return messages
def create_letta_messages_from_llm_response( def create_letta_messages_from_llm_response(
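A hedged usage sketch of the helper the refactor above now delegates to: a `MessageCreate` is converted to internal `Message` objects without user/system wrapping. Module paths and the `MessageCreate`/`TextContent` schemas are assumed from the letta package at this commit; the agent id is a placeholder.
```
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import MessageCreate

creates = [MessageCreate(role=MessageRole.user, content=[TextContent(text="hi there")])]
messages = convert_message_creates_to_messages(
    creates, agent_id="agent-1234", wrap_user_message=False, wrap_system_message=False
)
print([m.role for m in messages])
```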

View File

@ -268,10 +268,11 @@ class SyncServer(Server):
) )
# collect providers (always has Letta as a default) # collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()] self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
if model_settings.openai_api_key: if model_settings.openai_api_key:
self._enabled_providers.append( self._enabled_providers.append(
OpenAIProvider( OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key, api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base, base_url=model_settings.openai_api_base,
) )
@ -279,12 +280,14 @@ class SyncServer(Server):
if model_settings.anthropic_api_key: if model_settings.anthropic_api_key:
self._enabled_providers.append( self._enabled_providers.append(
AnthropicProvider( AnthropicProvider(
name="anthropic",
api_key=model_settings.anthropic_api_key, api_key=model_settings.anthropic_api_key,
) )
) )
if model_settings.ollama_base_url: if model_settings.ollama_base_url:
self._enabled_providers.append( self._enabled_providers.append(
OllamaProvider( OllamaProvider(
name="ollama",
base_url=model_settings.ollama_base_url, base_url=model_settings.ollama_base_url,
api_key=None, api_key=None,
default_prompt_formatter=model_settings.default_prompt_formatter, default_prompt_formatter=model_settings.default_prompt_formatter,
@ -293,12 +296,14 @@ class SyncServer(Server):
if model_settings.gemini_api_key: if model_settings.gemini_api_key:
self._enabled_providers.append( self._enabled_providers.append(
GoogleAIProvider( GoogleAIProvider(
name="google_ai",
api_key=model_settings.gemini_api_key, api_key=model_settings.gemini_api_key,
) )
) )
if model_settings.google_cloud_location and model_settings.google_cloud_project: if model_settings.google_cloud_location and model_settings.google_cloud_project:
self._enabled_providers.append( self._enabled_providers.append(
GoogleVertexProvider( GoogleVertexProvider(
name="google_vertex",
google_cloud_project=model_settings.google_cloud_project, google_cloud_project=model_settings.google_cloud_project,
google_cloud_location=model_settings.google_cloud_location, google_cloud_location=model_settings.google_cloud_location,
) )
@ -307,6 +312,7 @@ class SyncServer(Server):
assert model_settings.azure_api_version, "AZURE_API_VERSION is required" assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append( self._enabled_providers.append(
AzureProvider( AzureProvider(
name="azure",
api_key=model_settings.azure_api_key, api_key=model_settings.azure_api_key,
base_url=model_settings.azure_base_url, base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version, api_version=model_settings.azure_api_version,
@ -315,12 +321,14 @@ class SyncServer(Server):
if model_settings.groq_api_key: if model_settings.groq_api_key:
self._enabled_providers.append( self._enabled_providers.append(
GroqProvider( GroqProvider(
name="groq",
api_key=model_settings.groq_api_key, api_key=model_settings.groq_api_key,
) )
) )
if model_settings.together_api_key: if model_settings.together_api_key:
self._enabled_providers.append( self._enabled_providers.append(
TogetherProvider( TogetherProvider(
name="together",
api_key=model_settings.together_api_key, api_key=model_settings.together_api_key,
default_prompt_formatter=model_settings.default_prompt_formatter, default_prompt_formatter=model_settings.default_prompt_formatter,
) )
@ -329,6 +337,7 @@ class SyncServer(Server):
# vLLM exposes both a /chat/completions and a /completions endpoint # vLLM exposes both a /chat/completions and a /completions endpoint
self._enabled_providers.append( self._enabled_providers.append(
VLLMCompletionsProvider( VLLMCompletionsProvider(
name="vllm",
base_url=model_settings.vllm_api_base, base_url=model_settings.vllm_api_base,
default_prompt_formatter=model_settings.default_prompt_formatter, default_prompt_formatter=model_settings.default_prompt_formatter,
) )
@ -338,12 +347,14 @@ class SyncServer(Server):
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes" # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append( self._enabled_providers.append(
VLLMChatCompletionsProvider( VLLMChatCompletionsProvider(
name="vllm",
base_url=model_settings.vllm_api_base, base_url=model_settings.vllm_api_base,
) )
) )
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region: if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
self._enabled_providers.append( self._enabled_providers.append(
AnthropicBedrockProvider( AnthropicBedrockProvider(
name="bedrock",
aws_region=model_settings.aws_region, aws_region=model_settings.aws_region,
) )
) )
@ -355,11 +366,11 @@ class SyncServer(Server):
if model_settings.lmstudio_base_url.endswith("/v1") if model_settings.lmstudio_base_url.endswith("/v1")
else model_settings.lmstudio_base_url + "/v1" else model_settings.lmstudio_base_url + "/v1"
) )
self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url)) self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
if model_settings.deepseek_api_key: if model_settings.deepseek_api_key:
self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key)) self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
if model_settings.xai_api_key: if model_settings.xai_api_key:
self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key)) self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
# For MCP # For MCP
"""Initialize the MCP clients (there may be multiple)""" """Initialize the MCP clients (there may be multiple)"""
@ -862,6 +873,8 @@ class SyncServer(Server):
agent_ids=[voice_sleeptime_agent.id], agent_ids=[voice_sleeptime_agent.id],
manager_config=VoiceSleeptimeManager( manager_config=VoiceSleeptimeManager(
manager_agent_id=main_agent.id, manager_agent_id=main_agent.id,
max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
), ),
), ),
actor=actor, actor=actor,
@ -1182,10 +1195,10 @@ class SyncServer(Server):
except NoResultFound: except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self) -> List[LLMConfig]: def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
"""List available models""" """List available models"""
llm_models = [] llm_models = []
for provider in self.get_enabled_providers(): for provider in self.get_enabled_providers(byok_only=byok_only):
try: try:
llm_models.extend(provider.list_llm_models()) llm_models.extend(provider.list_llm_models())
except Exception as e: except Exception as e:
@ -1205,11 +1218,12 @@ class SyncServer(Server):
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models return embedding_models
def get_enabled_providers(self): def get_enabled_providers(self, byok_only: bool = False):
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
if byok_only:
return list(providers_from_db.values())
providers_from_env = {p.name: p for p in self._enabled_providers} providers_from_env = {p.name: p for p in self._enabled_providers}
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()} return list(providers_from_env.values()) + list(providers_from_db.values())
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
return {**providers_from_env, **providers_from_db}.values()
@trace_method @trace_method
def get_llm_config_from_handle( def get_llm_config_from_handle(
@ -1294,7 +1308,7 @@ class SyncServer(Server):
return embedding_config return embedding_config
def get_provider_from_name(self, provider_name: str) -> Provider: def get_provider_from_name(self, provider_name: str) -> Provider:
providers = [provider for provider in self._enabled_providers if provider.name == provider_name] providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
if not providers: if not providers:
raise ValueError(f"Provider {provider_name} is not supported") raise ValueError(f"Provider {provider_name} is not supported")
elif len(providers) > 1: elif len(providers) > 1:
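A standalone sketch of the listing logic above: env-configured providers (now given explicit names) come first, followed by BYOK providers stored in the database, and `byok_only=True` returns only the database providers. Plain dicts stand in for Provider objects here.
```
def get_enabled_providers(env_providers, db_providers, byok_only=False):
    providers_from_db = {p["name"]: p for p in db_providers}
    if byok_only:
        return list(providers_from_db.values())
    providers_from_env = {p["name"]: p for p in env_providers}
    return list(providers_from_env.values()) + list(providers_from_db.values())

env = [{"name": "openai"}, {"name": "anthropic"}]
db = [{"name": "my-anthropic-byok", "provider_type": "anthropic"}]
print([p["name"] for p in get_enabled_providers(env, db)])        # env first, then BYOK
print([p["name"] for p in get_enabled_providers(env, db, True)])  # BYOK only
```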

View File

@ -80,6 +80,12 @@ class GroupManager:
case ManagerType.voice_sleeptime: case ManagerType.voice_sleeptime:
new_group.manager_type = ManagerType.voice_sleeptime new_group.manager_type = ManagerType.voice_sleeptime
new_group.manager_agent_id = group.manager_config.manager_agent_id new_group.manager_agent_id = group.manager_config.manager_agent_id
max_message_buffer_length = group.manager_config.max_message_buffer_length
min_message_buffer_length = group.manager_config.min_message_buffer_length
# Safety check for buffer length range
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
new_group.max_message_buffer_length = max_message_buffer_length
new_group.min_message_buffer_length = min_message_buffer_length
case _: case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
@ -97,6 +103,8 @@ class GroupManager:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
sleeptime_agent_frequency = None sleeptime_agent_frequency = None
max_message_buffer_length = None
min_message_buffer_length = None
max_turns = None max_turns = None
termination_token = None termination_token = None
manager_agent_id = None manager_agent_id = None
@ -117,11 +125,24 @@ class GroupManager:
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
if sleeptime_agent_frequency and group.turns_counter is None: if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1 group.turns_counter = -1
case ManagerType.voice_sleeptime:
manager_agent_id = group_update.manager_config.manager_agent_id
max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length
min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case _: case _:
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}") raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
# Safety check for buffer length range
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
if sleeptime_agent_frequency: if sleeptime_agent_frequency:
group.sleeptime_agent_frequency = sleeptime_agent_frequency group.sleeptime_agent_frequency = sleeptime_agent_frequency
if max_message_buffer_length:
group.max_message_buffer_length = max_message_buffer_length
if min_message_buffer_length:
group.min_message_buffer_length = min_message_buffer_length
if max_turns: if max_turns:
group.max_turns = max_turns group.max_turns = max_turns
if termination_token: if termination_token:
@ -274,3 +295,40 @@ class GroupManager:
if manager_agent: if manager_agent:
for block in blocks: for block in blocks:
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label)) session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
@staticmethod
def ensure_buffer_length_range_valid(
max_value: Optional[int],
min_value: Optional[int],
max_name: str = "max_message_buffer_length",
min_name: str = "min_message_buffer_length",
) -> None:
"""
1) Both-or-none: if one is set, the other must be set.
2) Both must be ints > 4.
3) max_value must be strictly greater than min_value.
"""
# 1) require both-or-none
if (max_value is None) != (min_value is None):
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})"
)
# no further checks if neither is provided
if max_value is None:
return
# 2) type & lowerbound checks
if not isinstance(max_value, int) or not isinstance(min_value, int):
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be integers "
f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})"
)
if max_value <= 4 or min_value <= 4:
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})"
)
# 3) ordering
if max_value <= min_value:
raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")

View File

@ -1,6 +1,7 @@
from typing import List, Optional from typing import List, Optional, Union
from letta.orm.provider import Provider as ProviderModel from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderUpdate from letta.schemas.providers import ProviderUpdate
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
@ -18,6 +19,9 @@ class ProviderManager:
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider: def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist.""" """Create a new provider if it doesn't already exist."""
with self.session_maker() as session: with self.session_maker() as session:
if provider.name == provider.provider_type.value:
raise ValueError("Provider name must be unique and different from provider type")
# Assign the organization id based on the actor # Assign the organization id based on the actor
provider.organization_id = actor.organization_id provider.organization_id = actor.organization_id
@ -59,29 +63,36 @@ class ProviderManager:
session.commit() session.commit()
@enforce_types @enforce_types
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]: def list_providers(
self,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> List[PydanticProvider]:
"""List all providers with optional pagination.""" """List all providers with optional pagination."""
filter_kwargs = {}
if name:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
with self.session_maker() as session: with self.session_maker() as session:
providers = ProviderModel.list( providers = ProviderModel.list(
db_session=session, db_session=session,
after=after, after=after,
limit=limit, limit=limit,
actor=actor, actor=actor,
**filter_kwargs,
) )
return [provider.to_pydantic() for provider in providers] return [provider.to_pydantic() for provider in providers]
@enforce_types @enforce_types
def get_anthropic_override_provider_id(self) -> Optional[str]: def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature""" providers = self.list_providers(name=provider_name)
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] return providers[0].id if providers else None
if len(anthropic_provider) != 0:
return anthropic_provider[0].id
return None
@enforce_types @enforce_types
def get_anthropic_override_key(self) -> Optional[str]: def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature""" providers = self.list_providers(name=provider_name)
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] return providers[0].api_key if providers else None
if len(anthropic_provider) != 0:
return anthropic_provider[0].api_key
return None
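A hedged sketch of the generalized BYOK helpers that replace the Anthropic-specific ones above. The module path and a no-argument `ProviderManager()` constructor backed by a configured database are assumptions; the provider name is a placeholder.
```
from letta.services.provider_manager import ProviderManager

manager = ProviderManager()
provider_id = manager.get_provider_id_from_name("my-anthropic-byok")  # placeholder name
api_key = manager.get_override_key("my-anthropic-byok")
print(provider_id, "override key set" if api_key else "no override key")
```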

View File

@ -4,6 +4,7 @@ import traceback
from typing import List, Tuple from typing import List, Tuple
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_message_content import TextContent
@ -77,7 +78,7 @@ class Summarizer:
logger.info("Buffer length hit, evicting messages.") logger.info("Buffer length hit, evicting messages.")
target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1 target_trim_index = len(all_in_context_messages) - self.message_buffer_min
while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user: while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user:
target_trim_index += 1 target_trim_index += 1
@ -112,11 +113,12 @@ class Summarizer:
summary_request_text = f"""Youre a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they arent lost. summary_request_text = f"""Youre a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they arent lost.
(Older) Evicted Messages:\n (Older) Evicted Messages:\n
{evicted_messages_str} {evicted_messages_str}\n
(Newer) In-Context Messages:\n (Newer) In-Context Messages:\n
{in_context_messages_str} {in_context_messages_str}
""" """
print(summary_request_text)
# Fire-and-forget the summarization task # Fire-and-forget the summarization task
self.fire_and_forget( self.fire_and_forget(
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])]) self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
@ -149,6 +151,9 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
# 1) Try plain content # 1) Try plain content
if msg.content: if msg.content:
# Skip tool messages where the name is "send_message"
if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL:
continue
text = "".join(c.text for c in msg.content).strip() text = "".join(c.text for c in msg.content).strip()
# 2) Otherwise, try extracting from function calls # 2) Otherwise, try extracting from function calls
@ -156,11 +161,14 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
parts = [] parts = []
for call in msg.tool_calls: for call in msg.tool_calls:
args_str = call.function.arguments args_str = call.function.arguments
try: if call.function.name == DEFAULT_MESSAGE_TOOL:
args = json.loads(args_str) try:
# pull out a "message" field if present args = json.loads(args_str)
parts.append(args.get("message", args_str)) # pull out a "message" field if present
except json.JSONDecodeError: parts.append(args.get(DEFAULT_MESSAGE_TOOL_KWARG, args_str))
except json.JSONDecodeError:
parts.append(args_str)
else:
parts.append(args_str) parts.append(args_str)
text = " ".join(parts).strip() text = " ".join(parts).strip()

View File

@ -100,7 +100,7 @@ class ToolExecutionManager:
try: try:
executor = ToolExecutorFactory.get_executor(tool.tool_type) executor = ToolExecutorFactory.get_executor(tool.tool_type)
# TODO: Extend this async model to composio # TODO: Extend this async model to composio
if isinstance(executor, SandboxToolExecutor): if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)):
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor) result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
else: else:
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor) result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
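A toy sketch of the dispatch change above: executors whose `execute()` is a coroutine (sandbox and, with this commit, external Composio) are awaited, while the rest are called synchronously. The classes here are stand-ins for the real executors, and only the sandbox case is shown.
```
import asyncio

class SandboxToolExecutor:
    async def execute(self, name, args):
        return f"sandboxed {name}({args})"

class CoreToolExecutor:
    def execute(self, name, args):
        return f"core {name}({args})"

async def run_tool(executor, name, args):
    if isinstance(executor, SandboxToolExecutor):  # async path, as in the diff
        return await executor.execute(name, args)
    return executor.execute(name, args)

print(asyncio.run(run_tool(SandboxToolExecutor(), "web_search", {"q": "letta"})))
print(asyncio.run(run_tool(CoreToolExecutor(), "core_memory_append", {"content": "x"})))
```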

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, Optional
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
@ -486,7 +486,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
class ExternalComposioToolExecutor(ToolExecutor): class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools.""" """Executor for external Composio tools."""
def execute( async def execute(
self, self,
function_name: str, function_name: str,
function_args: dict, function_args: dict,
@ -505,7 +505,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
composio_api_key = get_composio_api_key(actor=actor) composio_api_key = get_composio_api_key(actor=actor)
# TODO (matt): Roll in execute_composio_action into this class # TODO (matt): Roll in execute_composio_action into this class
function_response = execute_composio_action( function_response = await execute_composio_action_async(
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
) )

104
poetry.lock generated
View File

@ -1016,25 +1016,6 @@ e2b = ["e2b (>=0.17.2a37,<1.1.0)", "e2b-code-interpreter"]
flyio = ["gql", "requests_toolbelt"] flyio = ["gql", "requests_toolbelt"]
tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "transformers"] tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "transformers"]
[[package]]
name = "composio-langchain"
version = "0.7.15"
description = "Use Composio to get an array of tools with your LangChain agent."
optional = false
python-versions = "<4,>=3.9"
groups = ["main"]
files = [
{file = "composio_langchain-0.7.15-py3-none-any.whl", hash = "sha256:a71b5371ad6c3ee4d4289c7a994fad1424e24c29a38e820b6b2ed259056abb65"},
{file = "composio_langchain-0.7.15.tar.gz", hash = "sha256:cb75c460289ecdf9590caf7ddc0d7888b0a6622ca4f800c9358abe90c25d055e"},
]
[package.dependencies]
composio_core = ">=0.7.0,<0.8.0"
langchain = ">=0.1.0"
langchain-openai = ">=0.0.2.post1"
langchainhub = ">=0.1.15"
pydantic = ">=2.6.4"
[[package]] [[package]]
name = "configargparse" name = "configargparse"
version = "1.7" version = "1.7"
@ -2842,9 +2823,10 @@ files = [
name = "jsonpatch" name = "jsonpatch"
version = "1.33" version = "1.33"
description = "Apply JSON-Patches (RFC 6902)" description = "Apply JSON-Patches (RFC 6902)"
optional = false optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
{file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
@ -2857,9 +2839,10 @@ jsonpointer = ">=1.9"
name = "jsonpointer" name = "jsonpointer"
version = "3.0.0" version = "3.0.0"
description = "Identify specific nodes in a JSON document (RFC 6901)" description = "Identify specific nodes in a JSON document (RFC 6901)"
optional = false optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"},
{file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"}, {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"},
@ -3052,9 +3035,10 @@ files = [
name = "langchain" name = "langchain"
version = "0.3.23" version = "0.3.23"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = true
python-versions = "<4.0,>=3.9" python-versions = "<4.0,>=3.9"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "langchain-0.3.23-py3-none-any.whl", hash = "sha256:084f05ee7e80b7c3f378ebadd7309f2a37868ce2906fa0ae64365a67843ade3d"}, {file = "langchain-0.3.23-py3-none-any.whl", hash = "sha256:084f05ee7e80b7c3f378ebadd7309f2a37868ce2906fa0ae64365a67843ade3d"},
{file = "langchain-0.3.23.tar.gz", hash = "sha256:d95004afe8abebb52d51d6026270248da3f4b53d93e9bf699f76005e0c83ad34"}, {file = "langchain-0.3.23.tar.gz", hash = "sha256:d95004afe8abebb52d51d6026270248da3f4b53d93e9bf699f76005e0c83ad34"},
@ -3120,9 +3104,10 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10"
name = "langchain-core" name = "langchain-core"
version = "0.3.51" version = "0.3.51"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = true
python-versions = "<4.0,>=3.9" python-versions = "<4.0,>=3.9"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "langchain_core-0.3.51-py3-none-any.whl", hash = "sha256:4bd71e8acd45362aa428953f2a91d8162318014544a2216e4b769463caf68e13"}, {file = "langchain_core-0.3.51-py3-none-any.whl", hash = "sha256:4bd71e8acd45362aa428953f2a91d8162318014544a2216e4b769463caf68e13"},
{file = "langchain_core-0.3.51.tar.gz", hash = "sha256:db76b9cc331411602cb40ba0469a161febe7a0663fbcaddbc9056046ac2d22f4"}, {file = "langchain_core-0.3.51.tar.gz", hash = "sha256:db76b9cc331411602cb40ba0469a161febe7a0663fbcaddbc9056046ac2d22f4"},
@ -3140,30 +3125,14 @@ PyYAML = ">=5.3"
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
typing-extensions = ">=4.7" typing-extensions = ">=4.7"
[[package]]
name = "langchain-openai"
version = "0.3.12"
description = "An integration package connecting OpenAI and LangChain"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "langchain_openai-0.3.12-py3-none-any.whl", hash = "sha256:0fab64d58ec95e65ffbaf659470cd362e815685e15edbcb171641e90eca4eb86"},
{file = "langchain_openai-0.3.12.tar.gz", hash = "sha256:c9dbff63551f6bd91913bca9f99a2d057fd95dc58d4778657d67e5baa1737f61"},
]
[package.dependencies]
langchain-core = ">=0.3.49,<1.0.0"
openai = ">=1.68.2,<2.0.0"
tiktoken = ">=0.7,<1"
[[package]] [[package]]
name = "langchain-text-splitters" name = "langchain-text-splitters"
version = "0.3.8" version = "0.3.8"
description = "LangChain text splitting utilities" description = "LangChain text splitting utilities"
optional = false optional = true
python-versions = "<4.0,>=3.9" python-versions = "<4.0,>=3.9"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02"}, {file = "langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02"},
{file = "langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e"}, {file = "langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e"},
@ -3172,30 +3141,14 @@ files = [
[package.dependencies] [package.dependencies]
langchain-core = ">=0.3.51,<1.0.0" langchain-core = ">=0.3.51,<1.0.0"
[[package]]
name = "langchainhub"
version = "0.1.21"
description = "The LangChain Hub API client"
optional = false
python-versions = "<4.0,>=3.8.1"
groups = ["main"]
files = [
{file = "langchainhub-0.1.21-py3-none-any.whl", hash = "sha256:1cc002dc31e0d132a776afd044361e2b698743df5202618cf2bad399246b895f"},
{file = "langchainhub-0.1.21.tar.gz", hash = "sha256:723383b3964a47dbaea6ad5d0ef728accefbc9d2c07480e800bdec43510a8c10"},
]
[package.dependencies]
packaging = ">=23.2,<25"
requests = ">=2,<3"
types-requests = ">=2.31.0.2,<3.0.0.0"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.3.28" version = "0.3.28"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = true
python-versions = "<4.0,>=3.9" python-versions = "<4.0,>=3.9"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "langsmith-0.3.28-py3-none-any.whl", hash = "sha256:54ac8815514af52d9c801ad7970086693667e266bf1db90fc453c1759e8407cd"}, {file = "langsmith-0.3.28-py3-none-any.whl", hash = "sha256:54ac8815514af52d9c801ad7970086693667e266bf1db90fc453c1759e8407cd"},
{file = "langsmith-0.3.28.tar.gz", hash = "sha256:4666595207131d7f8d83418e54dc86c05e28562e5c997633e7c33fc18f9aeb89"}, {file = "langsmith-0.3.28.tar.gz", hash = "sha256:4666595207131d7f8d83418e54dc86c05e28562e5c997633e7c33fc18f9aeb89"},
@ -3221,14 +3174,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
[[package]] [[package]]
name = "letta-client" name = "letta-client"
version = "0.1.124" version = "0.1.129"
description = "" description = ""
optional = false optional = false
python-versions = "<4.0,>=3.8" python-versions = "<4.0,>=3.8"
groups = ["main"] groups = ["main"]
files = [ files = [
{file = "letta_client-0.1.124-py3-none-any.whl", hash = "sha256:a7901437ef91f395cd85d24c0312046b7c82e5a4dd8e04de0d39b5ca085c65d3"}, {file = "letta_client-0.1.129-py3-none-any.whl", hash = "sha256:87a5fc32471e5b9fefbfc1e1337fd667d5e2e340ece5d2a6c782afbceab4bf36"},
{file = "letta_client-0.1.124.tar.gz", hash = "sha256:e8b5716930824cc98c62ee01343e358f88619d346578d48a466277bc8282036d"}, {file = "letta_client-0.1.129.tar.gz", hash = "sha256:b00f611c18a2ad802ec9265f384e1666938c5fc5c86364b2c410d72f0331d597"},
] ]
[package.dependencies] [package.dependencies]
@ -4366,10 +4319,10 @@ files = [
name = "orjson" name = "orjson"
version = "3.10.16" version = "3.10.16"
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
optional = false optional = true
python-versions = ">=3.9" python-versions = ">=3.9"
groups = ["main"] groups = ["main"]
markers = "platform_python_implementation != \"PyPy\"" markers = "platform_python_implementation != \"PyPy\" and (extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\")"
files = [ files = [
{file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"}, {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
{file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"}, {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@ -6069,9 +6022,10 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
name = "requests-toolbelt" name = "requests-toolbelt"
version = "1.0.0" version = "1.0.0"
description = "A utility belt for advanced users of python-requests" description = "A utility belt for advanced users of python-requests"
optional = false optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
{file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"},
@ -6855,21 +6809,6 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
[[package]]
name = "types-requests"
version = "2.32.0.20250328"
description = "Typing stubs for requests"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "types_requests-2.32.0.20250328-py3-none-any.whl", hash = "sha256:72ff80f84b15eb3aa7a8e2625fffb6a93f2ad5a0c20215fc1dcfa61117bcb2a2"},
{file = "types_requests-2.32.0.20250328.tar.gz", hash = "sha256:c9e67228ea103bd811c96984fac36ed2ae8da87a36a633964a21f199d60baf32"},
]
[package.dependencies]
urllib3 = ">=2"
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.13.2" version = "4.13.2"
@ -7438,9 +7377,10 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
name = "zstandard" name = "zstandard"
version = "0.23.0" version = "0.23.0"
description = "Zstandard bindings for Python" description = "Zstandard bindings for Python"
optional = false optional = true
python-versions = ">=3.8" python-versions = ">=3.8"
groups = ["main"] groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [ files = [
{file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"}, {file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"},
{file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"}, {file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"},
@ -7563,4 +7503,4 @@ tests = ["wikipedia"]
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = "<3.14,>=3.10" python-versions = "<3.14,>=3.10"
content-hash = "75c1c949aa6c0ef8d681bddd91999f97ed4991451be93ca45bf9c01dd19d8a8a" content-hash = "ba9cf0e00af2d5542aa4beecbd727af92b77ba584033f05c222b00ae47f96585"

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "letta" name = "letta"
version = "0.7.7" version = "0.7.8"
packages = [ packages = [
{include = "letta"}, {include = "letta"},
] ]
@ -56,7 +56,6 @@ nltk = "^3.8.1"
jinja2 = "^3.1.5" jinja2 = "^3.1.5"
locust = {version = "^2.31.5", optional = true} locust = {version = "^2.31.5", optional = true}
wikipedia = {version = "^1.4.0", optional = true} wikipedia = {version = "^1.4.0", optional = true}
composio-langchain = "^0.7.7"
composio-core = "^0.7.7" composio-core = "^0.7.7"
alembic = "^1.13.3" alembic = "^1.13.3"
pyhumps = "^3.8.0" pyhumps = "^3.8.0"
@ -74,7 +73,7 @@ llama-index = "^0.12.2"
llama-index-embeddings-openai = "^0.3.1" llama-index-embeddings-openai = "^0.3.1"
e2b-code-interpreter = {version = "^1.0.3", optional = true} e2b-code-interpreter = {version = "^1.0.3", optional = true}
anthropic = "^0.49.0" anthropic = "^0.49.0"
letta_client = "^0.1.124" letta_client = "^0.1.127"
openai = "^1.60.0" openai = "^1.60.0"
opentelemetry-api = "1.30.0" opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0" opentelemetry-sdk = "1.30.0"

View File

@ -1,11 +1,11 @@
{ {
"context_window": 8192, "context_window": 8192,
"model_endpoint_type": "openai", "model_endpoint_type": "openai",
"model_endpoint": "https://inference.memgpt.ai", "model_endpoint": "https://inference.letta.com",
"model": "memgpt-openai", "model": "memgpt-openai",
"embedding_endpoint_type": "hugging-face", "embedding_endpoint_type": "hugging-face",
"embedding_endpoint": "https://embeddings.memgpt.ai", "embedding_endpoint": "https://embeddings.memgpt.ai",
"embedding_model": "BAAI/bge-large-en-v1.5", "embedding_model": "BAAI/bge-large-en-v1.5",
"embedding_dim": 1024, "embedding_dim": 1024,
"embedding_chunk_size": 300 "embedding_chunk_size": 300
} }

View File

@ -1,7 +1,7 @@
{ {
"context_window": 8192, "context_window": 8192,
"model_endpoint_type": "openai", "model_endpoint_type": "openai",
"model_endpoint": "https://inference.memgpt.ai", "model_endpoint": "https://inference.letta.com",
"model": "memgpt-openai", "model": "memgpt-openai",
"put_inner_thoughts_in_kwargs": true "put_inner_thoughts_in_kwargs": true
} }

View File

@ -105,7 +105,9 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
llm_client = LLMClient.create( llm_client = LLMClient.create(
provider=agent_state.llm_config.model_endpoint_type, provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
actor_id=client.user.id,
) )
if llm_client: if llm_client:
response = llm_client.send_llm_request( response = llm_client.send_llm_request(
@ -179,7 +181,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
Note: This is acting on the Letta response, note the usage of `user_message` Note: This is acting on the Letta response, note the usage of `user_message`
""" """
from composio_langchain import Action from composio import Action
# Set up client # Set up client
client = create_client() client = create_client()

View File

@ -56,7 +56,7 @@ def test_add_composio_tool(fastapi_client):
assert "name" in response.json() assert "name" in response.json()
def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user): async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user):
agent_state = server.agent_manager.create_agent( agent_state = server.agent_manager.create_agent(
agent_create=CreateAgent( agent_create=CreateAgent(
name="sarah_agent", name="sarah_agent",
@ -67,7 +67,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis
actor=default_user, actor=default_user,
) )
tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool( tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool(
function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
) )

View File

@ -1,26 +1,26 @@
import os import os
import threading import threading
from unittest.mock import MagicMock
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from letta_client import Letta from letta_client import Letta
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk from openai.types.chat import ChatCompletionChunk
from sqlalchemy import delete
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.config import LettaConfig from letta.config import LettaConfig
from letta.orm import Provider, Step from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentType, CreateAgent from letta.schemas.agent import AgentType, CreateAgent
from letta.schemas.block import CreateBlock from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.group import ManagerType from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
from letta.schemas.tool import ToolCreate from letta.schemas.tool import ToolCreate
@ -29,6 +29,8 @@ from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager from letta.services.message_manager import MessageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager from letta.services.user_manager import UserManager
from letta.utils import get_persona_text from letta.utils import get_persona_text
@ -48,16 +50,24 @@ MESSAGE_TRANSCRIPTS = [
"user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.", "user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.",
"assistant: How about Vegan Treats in Santa Barbara? Theyre highly rated.", "assistant: How about Vegan Treats in Santa Barbara? Theyre highly rated.",
"user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.", "user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.",
"user: I want to make sure my itinerary isnt too tight—aiming for 34 days total.",
"assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.", "assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.",
"user: Yes, lets do that.", "user: Yes, lets do that.",
"assistant: Ill put together a day-by-day plan now.", "assistant: Ill put together a day-by-day plan now.",
] ]
SUMMARY_REQ_TEXT = """ SYSTEM_MESSAGE = Message(role=MessageRole.system, content=[TextContent(text="System message")])
Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity: MESSAGE_OBJECTS = [SYSTEM_MESSAGE]
for entry in MESSAGE_TRANSCRIPTS:
role_str, text = entry.split(":", 1)
role = MessageRole.user if role_str.strip() == "user" else MessageRole.assistant
MESSAGE_OBJECTS.append(Message(role=role, content=[TextContent(text=text.strip())]))
MESSAGE_EVICT_BREAKPOINT = 14
SUMMARY_REQ_TEXT = """
You’re a memory-recall helper for an AI that can only keep the last 4 messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost.
(Older) Evicted Messages:
(Older)
0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month. 0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month.
1. assistant: That sounds amazing! Do you have any particular cities or sights in mind? 1. assistant: That sounds amazing! Do you have any particular cities or sights in mind?
2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops. 2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.
@ -70,16 +80,13 @@ Here is the conversation history. Lines marked (Older) are about to be evicted;
9. assistant: Happy early birthday! Would you like gift ideas or celebration tips? 9. assistant: Happy early birthday! Would you like gift ideas or celebration tips?
10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat. 10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.
11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated. 11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.
(Newer) In-Context Messages:
12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro. 12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.
13. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
(Newer) 14. user: Yes, let’s do that.
13. user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total. 15. assistant: I’ll put together a day-by-day plan now."""
14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
15. user: Yes, let’s do that.
16. assistant: I’ll put together a day-by-day plan now.
Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk’s `start_index`, `end_index`, and a one-sentence `contextual_description`.
"""
# --- Server Management --- # # --- Server Management --- #
@ -214,22 +221,12 @@ def org_id(server):
yield org.id yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
server.organization_manager.delete_organization_by_id(org.id)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def actor(server, org_id): def actor(server, org_id):
user = server.user_manager.create_default_user() user = server.user_manager.create_default_user()
yield user yield user
# cleanup
server.user_manager.delete_user_by_id(user.id)
# --- Helper Functions --- # # --- Helper Functions --- #
@ -301,6 +298,80 @@ async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpo
print(chunk.choices[0].delta.content) print(chunk.choices[0].delta.content)
@pytest.mark.asyncio
async def test_summarization(disable_e2b_api_key, voice_agent):
agent_manager = AgentManager()
user_manager = UserManager()
actor = user_manager.get_default_user()
request = CreateAgent(
name=voice_agent.name + "-sleeptime",
agent_type=AgentType.voice_sleeptime_agent,
block_ids=[block.id for block in voice_agent.memory.blocks],
memory_blocks=[
CreateBlock(
label="memory_persona",
value=get_persona_text("voice_memory_persona"),
),
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
project_id=voice_agent.project_id,
)
sleeptime_agent = agent_manager.create_agent(request, actor=actor)
async_client = AsyncOpenAI()
memory_agent = VoiceSleeptimeAgent(
agent_id=sleeptime_agent.id,
convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent
openai_client=async_client,
message_manager=MessageManager(),
agent_manager=agent_manager,
actor=actor,
block_manager=BlockManager(),
target_block_label="human",
message_transcripts=MESSAGE_TRANSCRIPTS,
)
summarizer = Summarizer(
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
summarizer_agent=memory_agent,
message_buffer_limit=8,
message_buffer_min=4,
)
# stub out the agent.step so it returns a known sentinel
memory_agent.step = MagicMock(return_value="STEP_RESULT")
# patch fire_and_forget on *this* summarizer instance to a MagicMock
summarizer.fire_and_forget = MagicMock()
# now call the method under test
in_ctx = MESSAGE_OBJECTS[:MESSAGE_EVICT_BREAKPOINT]
new_msgs = MESSAGE_OBJECTS[MESSAGE_EVICT_BREAKPOINT:]
# call under test (this is sync)
updated, did_summarize = summarizer._static_buffer_summarization(
in_context_messages=in_ctx,
new_letta_messages=new_msgs,
)
assert did_summarize is True
assert len(updated) == summarizer.message_buffer_min + 1 # One extra for system message
assert updated[0].role == MessageRole.system # Preserved system message
# 2) the summarizer_agent.step() should have been *called* exactly once
memory_agent.step.assert_called_once()
call_args = memory_agent.step.call_args.args[0] # the single positional argument: a list of MessageCreate
assert isinstance(call_args, list)
assert isinstance(call_args[0], MessageCreate)
assert call_args[0].role == MessageRole.user
assert "15. assistant: Ill put together a day-by-day plan now." in call_args[0].content[0].text
# 3) fire_and_forget should have been called exactly once (its argument is whatever the stubbed step() returned)
summarizer.fire_and_forget.assert_called_once()
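In rough terms, the static-buffer strategy these assertions rely on keeps the leading system message plus the most recent `message_buffer_min` messages in context and hands everything older to the summarizer agent. Below is a minimal sketch of that trimming step; `trim_static_buffer` is a hypothetical helper named here only for illustration, while the real logic lives in `Summarizer._static_buffer_summarization`, which also consults `message_buffer_limit` and builds the summarization prompt:
```
def trim_static_buffer(messages, buffer_min):
    # Keep the leading system message untouched.
    system, rest = messages[0], messages[1:]
    # Everything except the newest `buffer_min` messages is evicted.
    evicted, kept = rest[:-buffer_min], rest[-buffer_min:]
    return [system] + kept, evicted

# With the 16 transcript messages above and buffer_min=4, 12 messages are
# evicted and 5 stay in context, matching the
# `summarizer.message_buffer_min + 1` assertion.
```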
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
"""Tests chat completion streaming using the Async OpenAI client.""" """Tests chat completion streaming using the Async OpenAI client."""
@ -427,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
server.group_manager.retrieve_group(group_id=group.id, actor=actor) server.group_manager.retrieve_group(group_id=group.id, actor=actor)
with pytest.raises(NoResultFound): with pytest.raises(NoResultFound):
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
def _modify(group_id, server, actor, max_val, min_val):
"""Helper to invoke modify_group with voice_sleeptime config."""
return server.group_manager.modify_group(
group_id=group_id,
group_update=GroupUpdate(
manager_config=VoiceSleeptimeManagerUpdate(
manager_type=ManagerType.voice_sleeptime,
max_message_buffer_length=max_val,
min_message_buffer_length=min_val,
)
),
actor=actor,
)
@pytest.fixture
def group_id(voice_agent):
return voice_agent.multi_agent_group.id
def test_valid_buffer_lengths_above_four(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=10, min_val=5)
assert updated.max_message_buffer_length == 10
assert updated.min_message_buffer_length == 5
def test_valid_buffer_lengths_only_max(group_id, server, actor):
# only max set; min falls back to DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
def test_valid_buffer_lengths_only_min(group_id, server, actor):
# only min set; max falls back to DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1
@pytest.mark.parametrize(
"max_val,min_val,err_part",
[
# only one bound set, colliding with the other's default
(None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"),
(DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"),
# ordering violations
(5, 5, "must be greater than"),
(6, 7, "must be greater than"),
# lower-bound (must both be > 4)
(4, 5, "greater than 4"),
(5, 4, "greater than 4"),
(1, 10, "greater than 4"),
(10, 1, "greater than 4"),
],
)
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
with pytest.raises(ValueError) as exc:
_modify(group_id, server, actor, max_val, min_val)
assert err_part in str(exc.value)
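Taken together, the cases above imply a validation step roughly like the sketch below. The function name, default values, and exact error messages are illustrative assumptions; the real constants are DEFAULT_MAX_MESSAGE_BUFFER_LENGTH and DEFAULT_MIN_MESSAGE_BUFFER_LENGTH from letta.constants, and the real check is performed when the group update is applied:
```
def validate_buffer_lengths(max_len, min_len, default_max=30, default_min=10):
    # Placeholder defaults purely for illustration.
    max_len = default_max if max_len is None else max_len
    min_len = default_min if min_len is None else min_len

    # Both bounds must leave a usable window.
    if max_len <= 4 or min_len <= 4:
        raise ValueError("Buffer lengths must be greater than 4")

    # The maximum must strictly exceed the minimum.
    if max_len <= min_len:
        raise ValueError("max_message_buffer_length must be greater than min_message_buffer_length")

    return max_len, min_len
```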

View File

@ -124,7 +124,7 @@ def test_agent(client: LocalClient):
def test_agent_add_remove_tools(client: LocalClient, agent): def test_agent_add_remove_tools(client: LocalClient, agent):
# Create and add two tools to the client # Create and add two tools to the client
# tool 1 # tool 1
from composio_langchain import Action from composio import Action
github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
@ -316,7 +316,7 @@ def test_tools(client: LocalClient):
def test_tools_from_composio_basic(client: LocalClient): def test_tools_from_composio_basic(client: LocalClient):
from composio_langchain import Action from composio import Action
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client() client = create_client()

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
import pytest import pytest
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser from letta.server.rest_api.json_parser import OptimisticJSONParser
@pytest.fixture @pytest.fixture

View File

@ -19,97 +19,166 @@ from letta.settings import model_settings
def test_openai(): def test_openai():
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None assert api_key is not None
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base) provider = OpenAIProvider(
name="openai",
api_key=api_key,
base_url=model_settings.openai_api_base,
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_deepseek(): def test_deepseek():
api_key = os.getenv("DEEPSEEK_API_KEY") api_key = os.getenv("DEEPSEEK_API_KEY")
assert api_key is not None assert api_key is not None
provider = DeepSeekProvider(api_key=api_key) provider = DeepSeekProvider(
name="deepseek",
api_key=api_key,
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_anthropic(): def test_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY") api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None assert api_key is not None
provider = AnthropicProvider(api_key=api_key) provider = AnthropicProvider(
name="anthropic",
api_key=api_key,
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_groq(): def test_groq():
provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY")) provider = GroqProvider(
name="groq",
api_key=os.getenv("GROQ_API_KEY"),
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_azure(): def test_azure():
provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL")) provider = AzureProvider(
name="azure",
api_key=os.getenv("AZURE_API_KEY"),
base_url=os.getenv("AZURE_BASE_URL"),
)
models = provider.list_llm_models() models = provider.list_llm_models()
print([m.model for m in models]) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embed_models = provider.list_embedding_models() embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embed_models]) assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_ollama(): def test_ollama():
base_url = os.getenv("OLLAMA_BASE_URL") base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None assert base_url is not None
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None) provider = OllamaProvider(
name="ollama",
base_url=base_url,
default_prompt_formatter=model_settings.default_prompt_formatter,
api_key=None,
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models() embedding_models = provider.list_embedding_models()
print(embedding_models) assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_googleai(): def test_googleai():
api_key = os.getenv("GEMINI_API_KEY") api_key = os.getenv("GEMINI_API_KEY")
assert api_key is not None assert api_key is not None
provider = GoogleAIProvider(api_key=api_key) provider = GoogleAIProvider(
name="google_ai",
api_key=api_key,
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
provider.list_embedding_models() embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_google_vertex(): def test_google_vertex():
provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION")) provider = GoogleVertexProvider(
name="google_vertex",
google_cloud_project=os.getenv("GCP_PROJECT_ID"),
google_cloud_location=os.getenv("GCP_REGION"),
)
models = provider.list_llm_models() models = provider.list_llm_models()
print(models) assert len(models) > 0
print([m.model for m in models]) assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models() embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models]) assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_mistral(): def test_mistral():
provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY")) provider = MistralProvider(
name="mistral",
api_key=os.getenv("MISTRAL_API_KEY"),
)
models = provider.list_llm_models() models = provider.list_llm_models()
print([m.model for m in models]) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_together(): def test_together():
provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml") provider = TogetherProvider(
name="together",
api_key=os.getenv("TOGETHER_API_KEY"),
default_prompt_formatter="chatml",
)
models = provider.list_llm_models() models = provider.list_llm_models()
print([m.model for m in models]) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models() embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models]) assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_anthropic_bedrock(): def test_anthropic_bedrock():
from letta.settings import model_settings from letta.settings import model_settings
provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region) provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
models = provider.list_llm_models() models = provider.list_llm_models()
print([m.model for m in models]) assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models() embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models]) assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_custom_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(
name="custom_anthropic",
api_key=api_key,
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
# def test_vllm(): # def test_vllm():

View File

@ -13,7 +13,7 @@ import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import Provider as PydanticProvider
@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id) actor = server.user_manager.get_user_or_default(user_id)
provider = server.provider_manager.create_provider( provider = server.provider_manager.create_provider(
provider=PydanticProvider( provider=PydanticProvider(
name="anthropic", name="caren-anthropic",
provider_type=ProviderType.anthropic,
api_key=os.getenv("ANTHROPIC_API_KEY"), api_key=os.getenv("ANTHROPIC_API_KEY"),
), ),
actor=actor, actor=actor,
@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
agent = server.create_agent( agent = server.create_agent(
request=CreateAgent( request=CreateAgent(
memory_blocks=[], memory_blocks=[],
model="anthropic/claude-3-opus-20240229", model="caren-anthropic/claude-3-opus-20240229",
context_window_limit=200000, context_window_limit=100000,
embedding="openai/text-embedding-ada-002", embedding="openai/text-embedding-ada-002",
), ),
actor=actor, actor=actor,