chore: bump version 0.7.8 (#2604)

Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
cthomas 2025-04-30 23:39:58 -07:00 committed by GitHub
parent e07edd840b
commit 20ecab29a1
65 changed files with 1248 additions and 649 deletions

View File

@ -167,7 +167,7 @@ docker exec -it $(docker ps -q -f ancestor=letta/letta) letta run
In the CLI tool, you'll be able to create new agents or load existing agents:
```
🧬 Creating new agent...
? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai]
? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com]
? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai]
-> 🤖 Using persona profile: 'sam_pov'
-> 🧑 Using human profile: 'basic'
@ -233,7 +233,7 @@ letta run
```
```
🧬 Creating new agent...
? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai]
? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com]
? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai]
-> 🤖 Using persona profile: 'sam_pov'
-> 🧑 Using human profile: 'basic'

View File

@ -0,0 +1,35 @@
"""add byok fields and unique constraint
Revision ID: 373dabcba6cf
Revises: c56081a05371
Create Date: 2025-04-30 19:38:25.010856
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "373dabcba6cf"
down_revision: Union[str, None] = "c56081a05371"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
op.drop_column("providers", "base_url")
op.drop_column("providers", "provider_type")
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""Add buffer length min max for voice sleeptime
Revision ID: c56081a05371
Revises: 28b8765bdd0a
Create Date: 2025-04-30 16:03:41.213750
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c56081a05371"
down_revision: Union[str, None] = "28b8765bdd0a"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("groups", sa.Column("max_message_buffer_length", sa.Integer(), nullable=True))
op.add_column("groups", sa.Column("min_message_buffer_length", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("groups", "min_message_buffer_length")
op.drop_column("groups", "max_message_buffer_length")
# ### end Alembic commands ###
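Both revisions apply in sequence (`c56081a05371` adds the buffer-length columns, then `373dabcba6cf` adds the BYOK fields and unique constraint). A minimal sketch of running them programmatically, assuming a standard `alembic.ini` at the project root; the CLI equivalent is `alembic upgrade head`:
```
from alembic import command
from alembic.config import Config

# Assumed config location; upgrades through both new revisions
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
```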

View File

@ -60,7 +60,7 @@ Last updated Oct 2, 2024. Please check `composio` documentation for any composio
def main():
from composio_langchain import Action
from composio import Action
# Add the composio tool
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)

View File

@ -0,0 +1,32 @@
from letta_client import Letta, VoiceSleeptimeManagerUpdate
client = Letta(base_url="http://localhost:8283")
agent = client.agents.create(
name="low_latency_voice_agent_demo",
agent_type="voice_convo_agent",
memory_blocks=[
{"value": "Name: ?", "label": "human"},
{"value": "You are a helpful assistant.", "label": "persona"},
],
model="openai/gpt-4o-mini", # Use 4o-mini for speed
embedding="openai/text-embedding-3-small",
enable_sleeptime=True,
initial_message_sequence = [],
)
print(f"Created agent id {agent.id}")
# get the group
group_id = agent.multi_agent_group.id
max_message_buffer_length = agent.multi_agent_group.max_message_buffer_length
min_message_buffer_length = agent.multi_agent_group.min_message_buffer_length
print(f"Group id: {group_id}, max_message_buffer_length: {max_message_buffer_length}, min_message_buffer_length: {min_message_buffer_length}")
# change it to be more frequent
group = client.groups.modify(
group_id=group_id,
manager_config=VoiceSleeptimeManagerUpdate(
max_message_buffer_length=10,
min_message_buffer_length=6,
)
)
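Assuming `groups.modify` returns the updated group (as assigned above), the tighter bounds can be confirmed directly:
```
# The returned group should reflect the new buffer bounds
print(f"Updated buffer lengths: max={group.max_message_buffer_length}, min={group.min_message_buffer_length}")
```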

View File

@ -1,4 +1,4 @@
__version__ = "0.7.7"
__version__ = "0.7.8"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client

View File

@ -21,14 +21,14 @@ from letta.constants import (
)
from letta.errors import ContextWindowExceededError
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.functions import get_function_from_module
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.helpers import ToolRulesSolver
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.interface import AgentInterface
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
@ -331,8 +331,10 @@ class Agent(BaseAgent):
log_telemetry(self.logger, "_get_ai_reply create start")
# New LLM client flow
llm_client = LLMClient.create(
provider=self.agent_state.llm_config.model_endpoint_type,
provider_name=self.agent_state.llm_config.provider_name,
provider_type=self.agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=self.user.id,
)
if llm_client and not stream:
@ -726,8 +728,7 @@ class Agent(BaseAgent):
self.tool_rules_solver.clear_tool_history()
# Convert MessageCreate objects to Message objects
message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
next_input_messages = message_objects
next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id)
counter = 0
total_usage = UsageStatistics()
step_count = 0
@ -942,12 +943,7 @@ class Agent(BaseAgent):
model_endpoint=self.agent_state.llm_config.model_endpoint,
context_window_limit=self.agent_state.llm_config.context_window,
usage=response.usage,
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
provider_id=(
self.provider_manager.get_anthropic_override_provider_id()
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
else None
),
provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
job_id=job_id,
)
for message in all_new_messages:

View File

@ -0,0 +1,6 @@
class IncompatibleAgentType(ValueError):
def __init__(self, expected_type: str, actual_type: str):
message = f"Incompatible agent type: expected '{expected_type}', but got '{actual_type}'."
super().__init__(message)
self.expected_type = expected_type
self.actual_type = actual_type
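The exception is used as a type guard in the voice agent path later in this commit; a minimal sketch of the pattern, with the helper name being illustrative:
```
from letta.agents.exceptions import IncompatibleAgentType

def ensure_voice_convo_agent(agent_state):
    # Hypothetical helper mirroring the check added in VoiceAgent.step below
    if agent_state.agent_type != "voice_convo_agent":
        raise IncompatibleAgentType(expected_type="voice_convo_agent", actual_type=agent_state.agent_type)
```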

View File

@ -67,8 +67,10 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider=agent_state.llm_config.model_endpoint_type,
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
)
for step in range(max_steps):
response = await self._get_ai_reply(
@ -109,8 +111,10 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
)
for step in range(max_steps):
@ -125,7 +129,7 @@ class LettaAgent(BaseAgent):
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
)
async for chunk in interface.process(stream):
yield f"data: {chunk.model_dump_json()}\n\n"
@ -179,6 +183,7 @@ class LettaAgent(BaseAgent):
ToolType.LETTA_SLEEPTIME_CORE,
}
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
]
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
@ -274,6 +279,7 @@ class LettaAgent(BaseAgent):
return persisted_messages, continue_stepping
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
try:
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
@ -313,6 +319,9 @@ class LettaAgent(BaseAgent):
else:
return in_context_messages
except:
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
@ -331,6 +340,10 @@ class LettaAgent(BaseAgent):
results = await self._send_message_to_agents_matching_tags(**tool_args)
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
return json.dumps(results), True
elif target_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args)
log_event(name=f"finish_composio_{tool_name}_execution", attributes=tool_args)
return tool_execution_result.func_return, True
else:
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
# TODO: Integrate sandbox result

View File

@ -156,8 +156,10 @@ class LettaAgentBatch:
log_event(name="init_llm_client")
llm_client = LLMClient.create(
provider=agent_states[0].llm_config.model_endpoint_type,
provider_name=agent_states[0].llm_config.provider_name,
provider_type=agent_states[0].llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
)
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@ -273,8 +275,10 @@ class LettaAgentBatch:
# translate provider-specific response → OpenAI-style tool call (unchanged)
llm_client = LLMClient.create(
provider=item.llm_config.model_endpoint_type,
provider_name=item.llm_config.provider_name,
provider_type=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
actor_id=self.actor.id,
)
tool_call = (
llm_client.convert_response_to_chat_completion(

View File

@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import openai
from letta.agents.base_agent import BaseAgent
from letta.agents.exceptions import IncompatibleAgentType
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
@ -18,7 +19,7 @@ from letta.helpers.tool_execution_helper import (
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate, MessageUpdate
@ -68,8 +69,6 @@ class VoiceAgent(BaseAgent):
block_manager: BlockManager,
passage_manager: PassageManager,
actor: User,
message_buffer_limit: int,
message_buffer_min: int,
):
super().__init__(
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
@ -80,8 +79,6 @@ class VoiceAgent(BaseAgent):
self.passage_manager = passage_manager
# TODO: This is not guaranteed to exist!
self.summary_block_label = "human"
self.message_buffer_limit = message_buffer_limit
self.message_buffer_min = message_buffer_min
# Cached archival memory/message size
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
@ -108,8 +105,8 @@ class VoiceAgent(BaseAgent):
target_block_label=self.summary_block_label,
message_transcripts=[],
),
message_buffer_limit=self.message_buffer_limit,
message_buffer_min=self.message_buffer_min,
message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,
)
return summarizer
@ -124,9 +121,15 @@ class VoiceAgent(BaseAgent):
"""
if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}")
user_query = input_messages[0].content[0].text
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
# Safety check
if agent_state.agent_type != AgentType.voice_convo_agent:
raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type)
summarizer = self.init_summarizer(agent_state=agent_state)
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)

View File

@ -4,7 +4,7 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
LETTA_MODEL_ENDPOINT = "https://inference.memgpt.ai"
LETTA_MODEL_ENDPOINT = "https://inference.letta.com"
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29
# minimum context window size
MIN_CONTEXT_WINDOW = 4096
# Voice Sleeptime message buffer lengths
DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30
DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15
# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embedding size - do NOT change or else DBs will need to be reset
DEFAULT_EMBEDDING_CHUNK_SIZE = 300

View File

@ -0,0 +1,100 @@
import asyncio
import os
from typing import Any, Optional
from composio import ComposioToolSet
from composio.constants import DEFAULT_ENTITY_ID
from composio.exceptions import (
ApiKeyNotProvidedError,
ComposioSDKError,
ConnectedAccountNotFoundError,
EnumMetadataNotFound,
EnumStringNotFound,
)
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
# TODO: So be very careful changing/removing these pair of functions
def _generate_func_name_from_composio_action(action_name: str) -> str:
"""
Generates the composio function name from the composio action.
Args:
action_name: The composio action name
Returns:
function name
"""
return action_name.lower()
def generate_composio_action_from_func_name(func_name: str) -> str:
"""
Generates the composio action from the composio function name.
Args:
func_name: The composio function name
Returns:
composio action name
"""
return func_name.upper()
def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
# Generate func name
func_name = _generate_func_name_from_composio_action(action_name)
wrapper_function_str = f"""\
def {func_name}(**kwargs):
raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team")
"""
# Compile safety check
_assert_code_gen_compilable(wrapper_function_str.strip())
return func_name, wrapper_function_str.strip()
async def execute_composio_action_async(
action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None
) -> tuple[str, str]:
try:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, execute_composio_action, action_name, args, api_key, entity_id)
except Exception as e:
raise RuntimeError(f"Error in execute_composio_action_async: {e}") from e
def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
try:
composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
response = composio_toolset.execute_action(action=action_name, params=args)
except ApiKeyNotProvidedError:
raise RuntimeError(
f"Composio API key is missing for action '{action_name}'. "
"Please set the sandbox environment variables either through the ADE or the API."
)
except ConnectedAccountNotFoundError:
raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.")
except EnumStringNotFound as e:
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
except EnumMetadataNotFound as e:
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
except ComposioSDKError as e:
raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))
if "error" in response and response["error"]:
raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))
return response.get("data")
def _assert_code_gen_compilable(code_str):
try:
compile(code_str, "<string>", "exec")
except SyntaxError as e:
print(f"Syntax error in code: {e}")

View File

@ -1,8 +1,9 @@
import importlib
import inspect
from collections.abc import Callable
from textwrap import dedent # remove indentation
from types import ModuleType
from typing import Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional
from letta.errors import LettaToolCreateError
from letta.functions.schema_generator import generate_schema
@ -66,7 +67,8 @@ def parse_source_code(func) -> str:
return source_code
def get_function_from_module(module_name: str, function_name: str):
# TODO (cliandy) refactor below two funcs
def get_function_from_module(module_name: str, function_name: str) -> Callable[..., Any]:
"""
Dynamically imports a function from a specified module.

View File

@ -6,10 +6,9 @@ from random import uniform
from typing import Any, Dict, List, Optional, Type, Union
import humps
from composio.constants import DEFAULT_ENTITY_ID
from pydantic import BaseModel, Field, create_model
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.functions.interface import MultiAgentMessagingInterface
from letta.orm.errors import NoResultFound
from letta.schemas.enums import MessageRole
@ -21,34 +20,6 @@ from letta.server.rest_api.utils import get_letta_server
from letta.settings import settings
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
# TODO: So be very careful changing/removing these pair of functions
def generate_func_name_from_composio_action(action_name: str) -> str:
"""
Generates the composio function name from the composio action.
Args:
action_name: The composio action name
Returns:
function name
"""
return action_name.lower()
def generate_composio_action_from_func_name(func_name: str) -> str:
"""
Generates the composio action from the composio function name.
Args:
func_name: The composio function name
Returns:
composio action name
"""
return func_name.upper()
# TODO needed?
def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]:
@ -58,71 +29,20 @@ def {mcp_tool_name}(**kwargs):
"""
# Compile safety check
assert_code_gen_compilable(wrapper_function_str.strip())
_assert_code_gen_compilable(wrapper_function_str.strip())
return mcp_tool_name, wrapper_function_str.strip()
def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
# Generate func name
func_name = generate_func_name_from_composio_action(action_name)
wrapper_function_str = f"""\
def {func_name}(**kwargs):
raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team")
"""
# Compile safety check
assert_code_gen_compilable(wrapper_function_str.strip())
return func_name, wrapper_function_str.strip()
def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
import os
from composio.exceptions import (
ApiKeyNotProvidedError,
ComposioSDKError,
ConnectedAccountNotFoundError,
EnumMetadataNotFound,
EnumStringNotFound,
)
from composio_langchain import ComposioToolSet
entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
try:
composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
response = composio_toolset.execute_action(action=action_name, params=args)
except ApiKeyNotProvidedError:
raise RuntimeError(
f"Composio API key is missing for action '{action_name}'. "
"Please set the sandbox environment variables either through the ADE or the API."
)
except ConnectedAccountNotFoundError:
raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.")
except EnumStringNotFound as e:
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
except EnumMetadataNotFound as e:
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
except ComposioSDKError as e:
raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))
if "error" in response:
raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))
return response.get("data")
def generate_langchain_tool_wrapper(
tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
) -> tuple[str, str]:
tool_name = tool.__class__.__name__
import_statement = f"from langchain_community.tools import {tool_name}"
extra_module_imports = generate_import_code(additional_imports_module_attr_map)
extra_module_imports = _generate_import_code(additional_imports_module_attr_map)
# Safety check that user has passed in all required imports:
assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
_assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
run_call = f"return tool._run(**kwargs)"
@ -139,25 +59,25 @@ def {func_name}(**kwargs):
"""
# Compile safety check
assert_code_gen_compilable(wrapper_function_str)
_assert_code_gen_compilable(wrapper_function_str)
return func_name, wrapper_function_str
def assert_code_gen_compilable(code_str):
def _assert_code_gen_compilable(code_str):
try:
compile(code_str, "<string>", "exec")
except SyntaxError as e:
print(f"Syntax error in code: {e}")
def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
# Safety check that user has passed in all required imports:
tool_name = tool.__class__.__name__
current_class_imports = {tool_name}
if additional_imports_module_attr_map:
current_class_imports.update(set(additional_imports_module_attr_map.values()))
required_class_imports = set(find_required_class_names_for_import(tool))
required_class_imports = set(_find_required_class_names_for_import(tool))
if not current_class_imports.issuperset(required_class_imports):
err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}"
@ -165,7 +85,7 @@ def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional
raise RuntimeError(err_msg)
def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
"""
Finds all the class names for required imports when instantiating the `obj`.
NOTE: This does not return the full import path, only the class name.
@ -181,7 +101,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod
# Collect all possible candidates for BaseModel objects
candidates = []
if is_base_model(curr_obj):
if _is_base_model(curr_obj):
# If it is a base model, we get all the values of the object parameters
# i.e., if obj('b' = <class A>), we would want to inspect <class A>
fields = dict(curr_obj)
@ -198,7 +118,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod
# Filter out all candidates that are not BaseModels
# In the list example above, ['a', 3, None, <class A>], we want to filter out 'a', 3, and None
candidates = filter(lambda x: is_base_model(x), candidates)
candidates = filter(lambda x: _is_base_model(x), candidates)
# Classic BFS here
for c in candidates:
@ -216,7 +136,7 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
# If it is a basic Python type, we trivially return the string version of that value
# Handle basic types
return repr(obj)
elif is_base_model(obj):
elif _is_base_model(obj):
# Otherwise, if it is a BaseModel
# We want to pull out all the parameters, and reformat them into strings
# e.g. {arg}={value}
@ -269,11 +189,11 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
return None
def is_base_model(obj: Any):
def _is_base_model(obj: Any):
return isinstance(obj, BaseModel)
def generate_import_code(module_attr_map: Optional[dict]):
def _generate_import_code(module_attr_map: Optional[dict]):
if not module_attr_map:
return ""
@ -286,7 +206,7 @@ def generate_import_code(module_attr_map: Optional[dict]):
return "\n".join(code_lines)
def parse_letta_response_for_assistant_message(
def _parse_letta_response_for_assistant_message(
target_agent_id: str,
letta_response: LettaResponse,
) -> Optional[str]:
@ -346,7 +266,7 @@ def execute_send_message_to_agent(
return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix))
async def send_message_to_agent_no_stream(
async def _send_message_to_agent_no_stream(
server: "SyncServer",
agent_id: str,
actor: User,
@ -375,7 +295,7 @@ async def send_message_to_agent_no_stream(
return LettaResponse(messages=final_messages, usage=usage_stats)
async def async_send_message_with_retries(
async def _async_send_message_with_retries(
server: "SyncServer",
sender_agent: "Agent",
target_agent_id: str,
@ -389,7 +309,7 @@ async def async_send_message_with_retries(
for attempt in range(1, max_retries + 1):
try:
response = await asyncio.wait_for(
send_message_to_agent_no_stream(
_send_message_to_agent_no_stream(
server=server,
agent_id=target_agent_id,
actor=sender_agent.user,
@ -399,7 +319,7 @@ async def async_send_message_with_retries(
)
# Then parse out the assistant message
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
assistant_message = _parse_letta_response_for_assistant_message(target_agent_id, response)
if assistant_message:
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
return assistant_message

View File

@ -76,6 +76,7 @@ def load_multi_agent(
agent_state=agent_state,
interface=interface,
user=actor,
mcp_clients=mcp_clients,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,

View File

@ -1,9 +1,10 @@
import asyncio
import threading
from datetime import datetime, timezone
from typing import List, Optional
from typing import Dict, List, Optional
from letta.agent import Agent, AgentState
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.groups.helpers import stringify_message
from letta.interface import AgentInterface
from letta.orm import User
@ -26,6 +27,7 @@ class SleeptimeMultiAgent(Agent):
interface: AgentInterface,
agent_state: AgentState,
user: User,
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
# custom
group_id: str = "",
agent_ids: List[str] = [],
@ -115,6 +117,7 @@ class SleeptimeMultiAgent(Agent):
agent_state=participant_agent_state,
interface=StreamingServerInterface(),
user=self.user,
mcp_clients=self.mcp_clients,
)
prior_messages = []
@ -212,6 +215,7 @@ class SleeptimeMultiAgent(Agent):
agent_state=self.agent_state,
interface=self.interface,
user=self.user,
mcp_clients=self.mcp_clients,
)
# Perform main agent step
usage_stats = main_agent.step(

View File

@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
def prepare_input_message_create(
def convert_message_creates_to_messages(
messages: list[MessageCreate],
agent_id: str,
wrap_user_message: bool = True,
wrap_system_message: bool = True,
) -> list[Message]:
return [
_convert_message_create_to_message(
message=message,
agent_id=agent_id,
wrap_user_message=wrap_user_message,
wrap_system_message=wrap_system_message,
)
for message in messages
]
def _convert_message_create_to_message(
message: MessageCreate,
agent_id: str,
wrap_user_message: bool = True,
@ -23,12 +40,12 @@ def prepare_input_message_create(
raise ValueError("Message content is empty or invalid")
# Apply wrapping if needed
if message.role == MessageRole.user and wrap_user_message:
if message.role not in {MessageRole.user, MessageRole.system}:
raise ValueError(f"Invalid message role: {message.role}")
elif message.role == MessageRole.user and wrap_user_message:
message_content = system.package_user_message(user_message=message_content)
elif message.role == MessageRole.system and wrap_system_message:
message_content = system.package_system_message(system_message=message_content)
elif message.role not in {MessageRole.user, MessageRole.system}:
raise ValueError(f"Invalid message role: {message.role}")
return Message(
agent_id=agent_id,
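A minimal usage sketch of the new plural helper, assuming `MessageCreate` accepts plain-text content and using an illustrative agent id; the wrapping defaults match the old `prepare_input_message_create` behavior:
```
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate

creates = [MessageCreate(role=MessageRole.user, content="hello")]
messages = convert_message_creates_to_messages(creates, agent_id="agent-123")
```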

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Optional
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState

View File

@ -35,7 +35,7 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
logger = get_logger(__name__)
@ -56,7 +56,7 @@ class AnthropicStreamingInterface:
"""
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.json_parser: JSONParser = PydanticJSONParser()
self.use_assistant_message = use_assistant_message
# Premake IDs for database writes
@ -68,7 +68,7 @@ class AnthropicStreamingInterface:
self.accumulated_inner_thoughts = []
self.tool_call_id = None
self.tool_call_name = None
self.accumulated_tool_call_args = []
self.accumulated_tool_call_args = ""
self.previous_parse = {}
# usage trackers
@ -85,24 +85,27 @@ class AnthropicStreamingInterface:
def get_tool_call_object(self) -> ToolCall:
"""Useful for agent loop"""
return ToolCall(
id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
)
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
"""
Check if inner thoughts are complete in the current tool call arguments
by looking for a closing quote after the inner_thoughts field
"""
try:
if not self.put_inner_thoughts_in_kwarg:
# None of the things should have inner thoughts in kwargs
return True
else:
parsed = self.optimistic_json_parser.parse(combined_args)
parsed = self.json_parser.parse(combined_args)
# TODO: This will break on tools with 0 input
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
except Exception as e:
logger.error("Error checking inner thoughts: %s", e)
raise
async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
try:
async with stream:
async for event in stream:
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
@ -169,9 +172,8 @@ class AnthropicStreamingInterface:
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args.append(delta.partial_json)
combined_args = "".join(self.accumulated_tool_call_args)
current_parsed = self.optimistic_json_parser.parse(combined_args)
self.accumulated_tool_call_args += delta.partial_json
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
@ -188,7 +190,7 @@ class AnthropicStreamingInterface:
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args):
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
for buffered_msg in self.tool_call_buffer:
@ -272,6 +274,11 @@ class AnthropicStreamingInterface:
self.tool_call_buffer = []
self.anthropic_mode = None
except Exception as e:
logger.error("Error processing stream: %s", e)
raise
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
def _process_group(

View File

@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice,
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
from letta.interfaces.utils import _format_sse_chunk
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
class OpenAIChatCompletionsStreamingInterface:

View File

@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
# NOTE: currently there is no GET /models, so we need to hardcode
# return MODEL_LIST
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
if api_key:
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
else:
raise ValueError("No API key provided")
models = anthropic_client.models.list()
models_json = models.model_dump()
@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
betas: List[str] = ["tools-2024-04-04"],
) -> ChatCompletionResponse:
"""https://docs.anthropic.com/claude/docs/tool-use"""
anthropic_client = None
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
if provider_name and provider_name != ProviderType.anthropic.value:
api_key = ProviderManager().get_override_key(provider_name)
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
else:
@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
betas: List[str] = ["tools-2024-04-04"],
) -> Generator[ChatCompletionChunkResponse, None, None]:
"""Stream chat completions from Anthropic API.
@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens,
)
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
if provider_name and provider_name != ProviderType.anthropic.value:
api_key = ProviderManager().get_override_key(provider_name)
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
betas: List[str] = ["tools-2024-04-04"],
@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens,
provider_name=provider_name,
betas=betas,
)
):

View File

@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):
@trace_method
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
override_key = ProviderManager().get_anthropic_override_key()
override_key = None
if self.provider_name and self.provider_name != ProviderType.anthropic.value:
override_key = ProviderManager().get_override_key(self.provider_name)
if async_client:
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()

View File

@ -63,7 +63,7 @@ class GoogleVertexClient(GoogleAIClient):
# Add thinking_config
# If enable_reasoner is False, set thinking_budget to 0
# Otherwise, use the value from max_reasoning_tokens
thinking_budget = 0 if not self.llm_config.enable_reasoner else self.llm_config.max_reasoning_tokens
thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens
thinking_config = ThinkingConfig(
thinking_budget=thinking_budget,
)

View File

@ -24,6 +24,7 @@ from letta.llm_api.openai import (
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@ -171,6 +172,10 @@ def create(
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
elif model_settings.openai_api_key is None:
# the openai python client requires a dummy API key
api_key = "DUMMY_API_KEY"
@ -373,6 +378,7 @@ def create(
stream_interface=stream_interface,
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
name=name,
)
@ -383,6 +389,7 @@ def create(
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
)
if llm_config.put_inner_thoughts_in_kwargs:

View File

@ -9,8 +9,10 @@ class LLMClient:
@staticmethod
def create(
provider: ProviderType,
provider_type: ProviderType,
provider_name: Optional[str] = None,
put_inner_thoughts_first: bool = True,
actor_id: Optional[str] = None,
) -> Optional[LLMClientBase]:
"""
Create an LLM client based on the model endpoint type.
@ -25,30 +27,38 @@ class LLMClient:
Raises:
ValueError: If the model endpoint type is not supported
"""
match provider:
match provider_type:
case ProviderType.google_ai:
from letta.llm_api.google_ai_client import GoogleAIClient
return GoogleAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
)
case ProviderType.google_vertex:
from letta.llm_api.google_vertex_client import GoogleVertexClient
return GoogleVertexClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
)
case ProviderType.anthropic:
from letta.llm_api.anthropic_client import AnthropicClient
return AnthropicClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
)
case ProviderType.openai:
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=actor_id,
)
case _:
return None
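Call sites throughout this commit switch from the single `provider` argument to the `provider_type`/`provider_name` pair plus `actor_id`; a minimal sketch of the new signature, assuming an `agent_state` and `actor` are in scope and that the class lives in `letta.llm_api.llm_client`:
```
from letta.llm_api.llm_client import LLMClient  # assumed module path

llm_client = LLMClient.create(
    provider_type=agent_state.llm_config.model_endpoint_type,
    provider_name=agent_state.llm_config.provider_name,
    put_inner_thoughts_first=True,
    actor_id=actor.id,
)
```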

View File

@ -20,9 +20,13 @@ class LLMClientBase:
def __init__(
self,
provider_name: Optional[str] = None,
put_inner_thoughts_first: Optional[bool] = True,
use_tool_naming: bool = True,
actor_id: Optional[str] = None,
):
self.actor_id = actor_id
self.provider_name = provider_name
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming

View File

@ -157,11 +157,17 @@ def build_openai_chat_completions_request(
# if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model:
# data.response_format = {"type": "json_object"}
# always set user id for openai requests
if user_id:
data.user = str(user_id)
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
# override user id for inference.memgpt.ai
if not user_id:
# override user id for inference.letta.com
import uuid
data.user = str(uuid.UUID(int=0))
data.model = "memgpt-openai"
if use_structured_output and data.tools is not None and len(data.tools) > 0:

View File

@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@ -64,6 +65,13 @@ def supports_parallel_tool_calling(model: str) -> bool:
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = None
if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
if not api_key:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
# supposedly the openai python client requires a dummy API key
api_key = api_key or "DUMMY_API_KEY"
@ -135,11 +143,17 @@ class OpenAIClient(LLMClientBase):
temperature=llm_config.temperature if supports_temperature_param(model) else None,
)
# always set user id for openai requests
if self.actor_id:
data.user = self.actor_id
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
# override user id for inference.memgpt.ai
if not self.actor_id:
# override user id for inference.letta.com
import uuid
data.user = str(uuid.UUID(int=0))
data.model = "memgpt-openai"
if data.tools is not None and len(data.tools) > 0:

View File

@ -79,8 +79,10 @@ def summarize_messages(
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
llm_client = LLMClient.create(
provider=llm_config_no_inner_thoughts.model_endpoint_type,
provider_name=llm_config_no_inner_thoughts.provider_name,
provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
put_inner_thoughts_first=False,
actor_id=agent_state.created_by_id,
)
# try to use new client, otherwise fallback to old flow
# TODO: we can just directly call the LLM here?

View File

@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):
__tablename__ = "providers"
__pydantic_model__ = PydanticProvider
__table_args__ = (
UniqueConstraint(
"name",
"organization_id",
name="unique_name_organization_id",
),
)
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
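With the new `provider_type` and `base_url` columns and the per-organization name constraint, a BYOK provider can be registered once per organization. A hedged sketch, assuming the client SDK exposes a `providers.create` route mirroring the `ProviderCreate` schema later in this commit:
```
from letta_client import Letta

client = Letta(base_url="http://localhost:8283")

# Hypothetical BYOK registration; field names mirror ProviderCreate
provider = client.providers.create(
    name="my-anthropic",
    provider_type="anthropic",
    api_key="sk-ant-...",
)
```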

View File

@ -56,7 +56,6 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
name: str = Field(..., description="The name of the agent.")
# tool rules
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
# in-context memory
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")

View File

@ -6,6 +6,17 @@ class ProviderType(str, Enum):
google_ai = "google_ai"
google_vertex = "google_vertex"
openai = "openai"
letta = "letta"
deepseek = "deepseek"
lmstudio_openai = "lmstudio_openai"
xai = "xai"
mistral = "mistral"
ollama = "ollama"
groq = "groq"
together = "together"
azure = "azure"
vllm = "vllm"
bedrock = "bedrock"
class MessageRole(str, Enum):

View File

@ -32,6 +32,14 @@ class Group(GroupBase):
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
turns_counter: Optional[int] = Field(None, description="")
last_processed_message_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class ManagerConfig(BaseModel):
@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig):
class VoiceSleeptimeManager(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: str = Field(..., description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class VoiceSleeptimeManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
# class SwarmGroup(ManagerConfig):

View File

@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
"xai",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")
put_inner_thoughts_in_kwargs: Optional[bool] = Field(

View File

@ -2,8 +2,8 @@ from typing import Dict
LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
"anthropic": {
"claude-3-5-haiku-20241022": "claude-3.5-haiku",
"claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
"claude-3-5-haiku-20241022": "claude-3-5-haiku",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
"claude-3-opus-20240229": "claude-3-opus",
},
"openai": {

View File

@ -1,6 +1,6 @@
import warnings
from datetime import datetime
from typing import List, Optional
from typing import List, Literal, Optional
from pydantic import Field, model_validator
@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderType
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
from letta.settings import model_settings
class ProviderBase(LettaBase):
@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
class Provider(ProviderBase):
id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
name: str = Field(..., description="The name of the provider")
provider_type: ProviderType = Field(..., description="The type of the provider")
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
base_url: Optional[str] = Field(None, description="Base URL for the provider.")
organization_id: Optional[str] = Field(None, description="The organization id of the user")
updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
@model_validator(mode="after")
def default_base_url(self):
if self.provider_type == ProviderType.openai and self.base_url is None:
self.base_url = model_settings.openai_api_base
return self
def resolve_identifier(self):
if not self.id:
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
@ -59,9 +69,41 @@ class Provider(ProviderBase):
return f"{self.name}/{model_name}"
def cast_to_subtype(self):
match (self.provider_type):
case ProviderType.letta:
return LettaProvider(**self.model_dump(exclude_none=True))
case ProviderType.openai:
return OpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic:
return AnthropicProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic_bedrock:
return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
case ProviderType.ollama:
return OllamaProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_ai:
return GoogleAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_vertex:
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
case ProviderType.azure:
return AzureProvider(**self.model_dump(exclude_none=True))
case ProviderType.groq:
return GroqProvider(**self.model_dump(exclude_none=True))
case ProviderType.together:
return TogetherProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm_chat_completions:
return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm_completions:
return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
case ProviderType.xai:
return XAIProvider(**self.model_dump(exclude_none=True))
case _:
raise ValueError(f"Unknown provider type: {self.provider_type}")
class ProviderCreate(ProviderBase):
name: str = Field(..., description="The name of the provider.")
provider_type: ProviderType = Field(..., description="The type of the provider.")
api_key: str = Field(..., description="API key used for requests to the provider.")
@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):
class LettaProvider(Provider):
name: str = "letta"
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
def list_llm_models(self) -> List[LLMConfig]:
return [
@ -81,6 +122,7 @@ class LettaProvider(Provider):
model_endpoint=LETTA_MODEL_ENDPOINT,
context_window=8192,
handle=self.get_handle("letta-free"),
provider_name=self.name,
)
]
@ -98,7 +140,7 @@ class LettaProvider(Provider):
class OpenAIProvider(Provider):
name: str = "openai"
provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the OpenAI API.")
base_url: str = Field(..., description="Base URL for the OpenAI API.")
@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
)
)
@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
* It also does not support native function calling
"""
name: str = "deepseek"
provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
api_key: str = Field(..., description="API key for the DeepSeek API.")
@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
context_window=context_window_size,
handle=self.get_handle(model_name),
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
provider_name=self.name,
)
)
@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):
class LMStudioOpenAIProvider(OpenAIProvider):
name: str = "lmstudio-openai"
provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
class XAIProvider(OpenAIProvider):
"""https://docs.x.ai/docs/api-reference"""
name: str = "xai"
provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the xAI/Grok API.")
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
)
)
@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):
class AnthropicProvider(Provider):
name: str = "anthropic"
provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Anthropic API.")
base_url: str = "https://api.anthropic.com/v1"
@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
handle=self.get_handle(model["id"]),
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
max_tokens=max_tokens,
provider_name=self.name,
)
)
return configs
@ -572,7 +618,7 @@ class AnthropicProvider(Provider):
class MistralProvider(Provider):
name: str = "mistral"
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Mistral API.")
base_url: str = "https://api.mistral.ai/v1"
@ -596,6 +642,7 @@ class MistralProvider(Provider):
model_endpoint=self.base_url,
context_window=model["max_context_length"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
"""
name: str = "ollama"
provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field(
@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
handle=self.get_handle(model["name"]),
provider_name=self.name,
)
)
return configs
@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):
class GroqProvider(OpenAIProvider):
name: str = "groq"
provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
base_url: str = "https://api.groq.com/openai/v1"
api_key: str = Field(..., description="API key for the Groq API.")
@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
return configs
@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
function calling support is limited.
"""
name: str = "together"
provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
base_url: str = "https://api.together.ai/v1"
api_key: str = Field(..., description="API key for the TogetherAI API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
)
)
@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):
class GoogleAIProvider(Provider):
# gemini
name: str = "google_ai"
provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Google AI API.")
base_url: str = "https://generativelanguage.googleapis.com"
@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
# filter by model names
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
# TODO remove manual filtering for gemini-pro
# Add support for all gemini models
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
context_window=self.get_model_context_window(model),
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
)
)
return configs
@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):
class GoogleVertexProvider(Provider):
name: str = "google_vertex"
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
context_window=context_length,
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
)
)
return configs
@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):
class AzureProvider(Provider):
name: str = "azure"
provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
base_url: str = Field(
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
model_endpoint=model_endpoint,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
),
)
return configs
@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the vLLM API.")
def list_llm_models(self) -> List[LLMConfig]:
@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
model_endpoint=self.base_url,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
return configs
@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
return configs
@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):
class AnthropicBedrockProvider(Provider):
name: str = "bedrock"
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
aws_region: str = Field(..., description="AWS region for Bedrock")
def list_llm_models(self):
@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
configs.append(
LLMConfig(
model=model_arn,
model_endpoint_type=self.name,
model_endpoint_type=self.provider_type.value,
model_endpoint=None,
context_window=self.get_model_context_window(model_arn),
handle=self.get_handle(model_arn),
provider_name=self.name,
)
)
return configs

View File

@ -11,13 +11,9 @@ from letta.constants import (
MCP_TOOL_TAG_NAME_PREFIX,
)
from letta.functions.ast_parsers import get_function_name_and_description
from letta.functions.composio_helpers import generate_composio_tool_wrapper
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
from letta.functions.helpers import (
generate_composio_tool_wrapper,
generate_langchain_tool_wrapper,
generate_mcp_tool_wrapper,
generate_model_from_args_json_schema,
)
from letta.functions.helpers import generate_langchain_tool_wrapper, generate_mcp_tool_wrapper, generate_model_from_args_json_schema
from letta.functions.mcp_client.types import MCPTool
from letta.functions.schema_generator import (
generate_schema_from_args_schema_v2,
@ -176,8 +172,7 @@ class ToolCreate(LettaBase):
Returns:
Tool: A Letta Tool initialized with attributes derived from the Composio tool.
"""
from composio import LogLevel
from composio_langchain import ComposioToolSet
from composio import ComposioToolSet, LogLevel
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, lock=False)
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)

View File

@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.agents.exceptions import IncompatibleAgentType
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
@ -173,6 +174,17 @@ def create_application() -> "FastAPI":
def shutdown_scheduler():
shutdown_cron_scheduler()
@app.exception_handler(IncompatibleAgentType)
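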
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
return JSONResponse(
status_code=400,
content={
"detail": str(exc),
"expected_type": exc.expected_type,
"actual_type": exc.actual_type,
},
)
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
# Log the actual error for debugging

View File

@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
logger = get_logger(__name__)

View File

@ -28,7 +28,7 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
from letta.utils import parse_json
@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
# @matt's changes here, adopting new optimistic json parser
self.current_function_arguments = []
self.current_function_arguments = ""
self.optimistic_json_parser = OptimisticJSONParser()
self.current_json_parse_result = {}
@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_start(self):
"""Initialize streaming by activating the generator and clearing any old chunks."""
self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = []
self.current_function_arguments = ""
self.current_json_parse_result = {}
if not self._active:
@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_end(self):
"""Clean up the stream by deactivating and clearing chunks."""
self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = []
self.current_function_arguments = ""
self.current_json_parse_result = {}
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# early exit to turn into content mode
return None
if tool_call.function.arguments:
self.current_function_arguments.append(tool_call.function.arguments)
self.current_function_arguments += tool_call.function.arguments
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
# Strip out any extras tokens
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
combined_args = "".join(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg
@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# updates_inner_thoughts = ""
# else: # OpenAI
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
self.current_function_arguments.append(tool_call.function.arguments)
self.current_function_arguments += tool_call.function.arguments
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
# If we have inner thoughts, we should output them as a chunk
@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: THIS IS HORRIBLE
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
combined_args = "".join(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg

View File

@ -1,7 +1,43 @@
import json
from abc import ABC, abstractmethod
from typing import Any
from pydantic_core import from_json
from letta.log import get_logger
logger = get_logger(__name__)
class OptimisticJSONParser:
class JSONParser(ABC):
@abstractmethod
def parse(self, input_str: str) -> Any:
raise NotImplementedError()
class PydanticJSONParser(JSONParser):
"""
https://docs.pydantic.dev/latest/concepts/json/#json-parsing
If `strict` is True, we will not allow for partial parsing of JSON.
Compared with `OptimisticJSONParser`, this parser is more strict.
Note: This will not partially parse strings, which may decrease parsing speed for message strings
"""
def __init__(self, strict=False):
self.strict = strict
def parse(self, input_str: str) -> Any:
if not input_str:
return {}
try:
return from_json(input_str, allow_partial="trailing-strings" if not self.strict else False)
except ValueError as e:
logger.error(f"Failed to parse JSON: {e}")
raise
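A small sketch of the two modes, assuming `PydanticJSONParser` is importable from `letta.server.rest_api.json_parser` (the module path this commit uses elsewhere); the truncated payload is made up:

```python
from letta.server.rest_api.json_parser import PydanticJSONParser

chunk = '{"message": "Hi there", "request_heartbeat": tr'  # truncated streaming fragment

print(PydanticJSONParser().parse(chunk))
# Expected: {'message': 'Hi there'} -- the incomplete trailing field is dropped.

try:
    PydanticJSONParser(strict=True).parse(chunk)
except ValueError as e:
    print(f"strict mode rejects partial JSON: {e}")
```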
class OptimisticJSONParser(JSONParser):
"""
A JSON parser that attempts to parse a given string using `json.loads`,
and if that fails, it parses as much valid JSON as possible while
@ -13,25 +49,25 @@ class OptimisticJSONParser:
def __init__(self, strict=False):
self.strict = strict
self.parsers = {
" ": self.parse_space,
"\r": self.parse_space,
"\n": self.parse_space,
"\t": self.parse_space,
"[": self.parse_array,
"{": self.parse_object,
'"': self.parse_string,
"t": self.parse_true,
"f": self.parse_false,
"n": self.parse_null,
" ": self._parse_space,
"\r": self._parse_space,
"\n": self._parse_space,
"\t": self._parse_space,
"[": self._parse_array,
"{": self._parse_object,
'"': self._parse_string,
"t": self._parse_true,
"f": self._parse_false,
"n": self._parse_null,
}
# Register number parser for digits and signs
for char in "0123456789.-":
self.parsers[char] = self.parse_number
self.last_parse_reminding = None
self.on_extra_token = self.default_on_extra_token
self.on_extra_token = self._default_on_extra_token
def default_on_extra_token(self, text, data, reminding):
def _default_on_extra_token(self, text, data, reminding):
print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}")
def parse(self, input_str):
@ -45,7 +81,7 @@ class OptimisticJSONParser:
try:
return json.loads(input_str)
except json.JSONDecodeError as decode_error:
data, reminding = self.parse_any(input_str, decode_error)
data, reminding = self._parse_any(input_str, decode_error)
self.last_parse_reminding = reminding
if self.on_extra_token and reminding:
self.on_extra_token(input_str, data, reminding)
@ -53,7 +89,7 @@ class OptimisticJSONParser:
else:
return json.loads("{}")
def parse_any(self, input_str, decode_error):
def _parse_any(self, input_str, decode_error):
"""Determine which parser to use based on the first character."""
if not input_str:
raise decode_error
@ -62,11 +98,11 @@ class OptimisticJSONParser:
raise decode_error
return parser(input_str, decode_error)
def parse_space(self, input_str, decode_error):
def _parse_space(self, input_str, decode_error):
"""Strip leading whitespace and parse again."""
return self.parse_any(input_str.strip(), decode_error)
return self._parse_any(input_str.strip(), decode_error)
def parse_array(self, input_str, decode_error):
def _parse_array(self, input_str, decode_error):
"""Parse a JSON array, returning the list and remaining string."""
# Skip the '['
input_str = input_str[1:]
@ -77,7 +113,7 @@ class OptimisticJSONParser:
# Skip the ']'
input_str = input_str[1:]
break
value, input_str = self.parse_any(input_str, decode_error)
value, input_str = self._parse_any(input_str, decode_error)
array_values.append(value)
input_str = input_str.strip()
if input_str.startswith(","):
@ -85,7 +121,7 @@ class OptimisticJSONParser:
input_str = input_str[1:].strip()
return array_values, input_str
def parse_object(self, input_str, decode_error):
def _parse_object(self, input_str, decode_error):
"""Parse a JSON object, returning the dict and remaining string."""
# Skip the '{'
input_str = input_str[1:]
@ -96,7 +132,7 @@ class OptimisticJSONParser:
# Skip the '}'
input_str = input_str[1:]
break
key, input_str = self.parse_any(input_str, decode_error)
key, input_str = self._parse_any(input_str, decode_error)
input_str = input_str.strip()
if not input_str or input_str[0] == "}":
@ -113,7 +149,7 @@ class OptimisticJSONParser:
input_str = input_str[1:]
break
value, input_str = self.parse_any(input_str, decode_error)
value, input_str = self._parse_any(input_str, decode_error)
obj[key] = value
input_str = input_str.strip()
if input_str.startswith(","):
@ -121,7 +157,7 @@ class OptimisticJSONParser:
input_str = input_str[1:].strip()
return obj, input_str
def parse_string(self, input_str, decode_error):
def _parse_string(self, input_str, decode_error):
"""Parse a JSON string, respecting escaped quotes if present."""
end = input_str.find('"', 1)
while end != -1 and input_str[end - 1] == "\\":
@ -166,19 +202,19 @@ class OptimisticJSONParser:
return num, remainder
def parse_true(self, input_str, decode_error):
def _parse_true(self, input_str, decode_error):
"""Parse a 'true' value."""
if input_str.startswith(("t", "T")):
return True, input_str[4:]
raise decode_error
def parse_false(self, input_str, decode_error):
def _parse_false(self, input_str, decode_error):
"""Parse a 'false' value."""
if input_str.startswith(("f", "F")):
return False, input_str[5:]
raise decode_error
def parse_null(self, input_str, decode_error):
def _parse_null(self, input_str, decode_error):
"""Parse a 'null' value."""
if input_str.startswith("n"):
return None, input_str[4:]
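And a sketch of how the streaming interface above now drives this parser: tool-call argument fragments are appended to a single string and re-parsed on every chunk. This assumes the non-strict string parser keeps returning the incomplete trailing string, as in the pre-existing implementation; the fragments are made up:

```python
from letta.server.rest_api.json_parser import OptimisticJSONParser

parser = OptimisticJSONParser()
buffer = ""  # mirrors self.current_function_arguments, now a str instead of a list
for fragment in ['{"message": "Pla', 'nning your trip n', 'ow!"']:
    buffer += fragment
    parsed = parser.parse(buffer)
    print(parsed.get("message"))
# Expected: "Pla", then "Planning your trip n", then "Planning your trip now!"
```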

View File

@ -678,7 +678,7 @@ async def send_message_streaming(
server: SyncServer = Depends(get_letta_server),
request: LettaStreamingRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
) -> StreamingResponse | LettaResponse:
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
def list_llm_models(
byok_only: Optional[bool] = Query(None),
server: "SyncServer" = Depends(get_letta_server),
):
models = server.list_llm_models()
models = server.list_llm_models(byok_only=byok_only)
# print(models)
return models
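A hedged client-side sketch of the new query parameter; the base URL, port, and response fields shown are assumptions based on the server defaults and the `LLMConfig` schema:

```python
import requests

# Only list models that come from BYOK providers stored in the database.
resp = requests.get("http://localhost:8283/v1/models/", params={"byok_only": True})
for model in resp.json():
    print(model["handle"], model["context_window"])
```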

View File

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
from letta.server.rest_api.utils import get_letta_server
@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/", response_model=List[Provider], operation_id="list_providers")
def list_providers(
name: Optional[str] = Query(None),
provider_type: Optional[ProviderType] = Query(None),
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
actor_id: Optional[str] = Header(None, alias="user_id"),
@ -23,7 +26,7 @@ def list_providers(
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
except HTTPException:
raise
except Exception as e:

View File

@ -54,8 +54,6 @@ async def create_voice_chat_completions(
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
message_buffer_limit=8,
message_buffer_min=4,
)
# Return the streaming generator

View File

@ -16,6 +16,7 @@ from pydantic import BaseModel
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
@ -143,27 +144,15 @@ def log_error_to_sentry(e):
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
"""
Converts a user input message into the internal structured format.
"""
new_messages = []
for input_message in input_messages:
# Construct the Message object
new_message = Message(
id=f"message-{uuid.uuid4()}",
role=input_message.role,
content=input_message.content,
name=input_message.name,
otid=input_message.otid,
sender_id=input_message.sender_id,
organization_id=actor.organization_id,
agent_id=agent_id,
model=None,
tool_calls=None,
tool_call_id=None,
created_at=get_utc_time(),
)
new_messages.append(new_message)
return new_messages
TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`,
we should unify this when it's clear what message attributes we need.
"""
messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False)
for message in messages:
message.organization_id = actor.organization_id
return messages
def create_letta_messages_from_llm_response(

View File

@ -268,10 +268,11 @@ class SyncServer(Server):
)
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
if model_settings.openai_api_key:
self._enabled_providers.append(
OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base,
)
@ -279,12 +280,14 @@ class SyncServer(Server):
if model_settings.anthropic_api_key:
self._enabled_providers.append(
AnthropicProvider(
name="anthropic",
api_key=model_settings.anthropic_api_key,
)
)
if model_settings.ollama_base_url:
self._enabled_providers.append(
OllamaProvider(
name="ollama",
base_url=model_settings.ollama_base_url,
api_key=None,
default_prompt_formatter=model_settings.default_prompt_formatter,
@ -293,12 +296,14 @@ class SyncServer(Server):
if model_settings.gemini_api_key:
self._enabled_providers.append(
GoogleAIProvider(
name="google_ai",
api_key=model_settings.gemini_api_key,
)
)
if model_settings.google_cloud_location and model_settings.google_cloud_project:
self._enabled_providers.append(
GoogleVertexProvider(
name="google_vertex",
google_cloud_project=model_settings.google_cloud_project,
google_cloud_location=model_settings.google_cloud_location,
)
@ -307,6 +312,7 @@ class SyncServer(Server):
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append(
AzureProvider(
name="azure",
api_key=model_settings.azure_api_key,
base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version,
@ -315,12 +321,14 @@ class SyncServer(Server):
if model_settings.groq_api_key:
self._enabled_providers.append(
GroqProvider(
name="groq",
api_key=model_settings.groq_api_key,
)
)
if model_settings.together_api_key:
self._enabled_providers.append(
TogetherProvider(
name="together",
api_key=model_settings.together_api_key,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
@ -329,6 +337,7 @@ class SyncServer(Server):
# vLLM exposes both a /chat/completions and a /completions endpoint
self._enabled_providers.append(
VLLMCompletionsProvider(
name="vllm",
base_url=model_settings.vllm_api_base,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
@ -338,12 +347,14 @@ class SyncServer(Server):
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(
VLLMChatCompletionsProvider(
name="vllm",
base_url=model_settings.vllm_api_base,
)
)
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
self._enabled_providers.append(
AnthropicBedrockProvider(
name="bedrock",
aws_region=model_settings.aws_region,
)
)
@ -355,11 +366,11 @@ class SyncServer(Server):
if model_settings.lmstudio_base_url.endswith("/v1")
else model_settings.lmstudio_base_url + "/v1"
)
self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url))
self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
if model_settings.deepseek_api_key:
self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
if model_settings.xai_api_key:
self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))
self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
# For MCP
"""Initialize the MCP clients (there may be multiple)"""
@ -862,6 +873,8 @@ class SyncServer(Server):
agent_ids=[voice_sleeptime_agent.id],
manager_config=VoiceSleeptimeManager(
manager_agent_id=main_agent.id,
max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
),
),
actor=actor,
@ -1182,10 +1195,10 @@ class SyncServer(Server):
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self) -> List[LLMConfig]:
def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
"""List available models"""
llm_models = []
for provider in self.get_enabled_providers():
for provider in self.get_enabled_providers(byok_only=byok_only):
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
@ -1205,11 +1218,12 @@ class SyncServer(Server):
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
def get_enabled_providers(self):
def get_enabled_providers(self, byok_only: bool = False):
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
if byok_only:
return list(providers_from_db.values())
providers_from_env = {p.name: p for p in self._enabled_providers}
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
return {**providers_from_env, **providers_from_db}.values()
return list(providers_from_env.values()) + list(providers_from_db.values())
@trace_method
def get_llm_config_from_handle(
@ -1294,7 +1308,7 @@ class SyncServer(Server):
return embedding_config
def get_provider_from_name(self, provider_name: str) -> Provider:
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
if not providers:
raise ValueError(f"Provider {provider_name} is not supported")
elif len(providers) > 1:

View File

@ -80,6 +80,12 @@ class GroupManager:
case ManagerType.voice_sleeptime:
new_group.manager_type = ManagerType.voice_sleeptime
new_group.manager_agent_id = group.manager_config.manager_agent_id
max_message_buffer_length = group.manager_config.max_message_buffer_length
min_message_buffer_length = group.manager_config.min_message_buffer_length
# Safety check for buffer length range
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
new_group.max_message_buffer_length = max_message_buffer_length
new_group.min_message_buffer_length = min_message_buffer_length
case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
@ -97,6 +103,8 @@ class GroupManager:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
sleeptime_agent_frequency = None
max_message_buffer_length = None
min_message_buffer_length = None
max_turns = None
termination_token = None
manager_agent_id = None
@ -117,11 +125,24 @@ class GroupManager:
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case ManagerType.voice_sleeptime:
manager_agent_id = group_update.manager_config.manager_agent_id
max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length
min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case _:
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
# Safety check for buffer length range
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
if sleeptime_agent_frequency:
group.sleeptime_agent_frequency = sleeptime_agent_frequency
if max_message_buffer_length:
group.max_message_buffer_length = max_message_buffer_length
if min_message_buffer_length:
group.min_message_buffer_length = min_message_buffer_length
if max_turns:
group.max_turns = max_turns
if termination_token:
@ -274,3 +295,40 @@ class GroupManager:
if manager_agent:
for block in blocks:
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
@staticmethod
def ensure_buffer_length_range_valid(
max_value: Optional[int],
min_value: Optional[int],
max_name: str = "max_message_buffer_length",
min_name: str = "min_message_buffer_length",
) -> None:
"""
1) Both-or-none: if one is set, the other must be set.
2) Both must be ints > 4.
3) max_value must be strictly greater than min_value.
"""
# 1) require both-or-none
if (max_value is None) != (min_value is None):
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})"
)
# no further checks if neither is provided
if max_value is None:
return
# 2) type & lowerbound checks
if not isinstance(max_value, int) or not isinstance(min_value, int):
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be integers "
f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})"
)
if max_value <= 4 or min_value <= 4:
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})"
)
# 3) ordering
if max_value <= min_value:
raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")

View File

@ -1,6 +1,7 @@
from typing import List, Optional
from typing import List, Optional, Union
from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderUpdate
from letta.schemas.user import User as PydanticUser
@ -18,6 +19,9 @@ class ProviderManager:
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with self.session_maker() as session:
if provider.name == provider.provider_type.value:
raise ValueError("Provider name must be unique and different from provider type")
# Assign the organization id based on the actor
provider.organization_id = actor.organization_id
@ -59,29 +63,36 @@ class ProviderManager:
session.commit()
@enforce_types
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
def list_providers(
self,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> List[PydanticProvider]:
"""List all providers with optional pagination."""
filter_kwargs = {}
if name:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
with self.session_maker() as session:
providers = ProviderModel.list(
db_session=session,
after=after,
limit=limit,
actor=actor,
**filter_kwargs,
)
return [provider.to_pydantic() for provider in providers]
@enforce_types
def get_anthropic_override_provider_id(self) -> Optional[str]:
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
if len(anthropic_provider) != 0:
return anthropic_provider[0].id
return None
def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
providers = self.list_providers(name=provider_name)
return providers[0].id if providers else None
@enforce_types
def get_anthropic_override_key(self) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
if len(anthropic_provider) != 0:
return anthropic_provider[0].api_key
return None
def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
providers = self.list_providers(name=provider_name)
return providers[0].api_key if providers else None
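A hedged usage sketch of the new filters and the renamed BYOK helpers; the provider name is illustrative, `actor` stands in for an existing `PydanticUser`, and the `ProviderManager` module path is assumed:

```python
from letta.schemas.enums import ProviderType
from letta.services.provider_manager import ProviderManager  # assumed module path

manager = ProviderManager()
actor = ...  # an existing PydanticUser for the organization

# Filter server-side instead of scanning every provider client-side.
anthropic_byok = manager.list_providers(provider_type=ProviderType.anthropic, actor=actor)
named = manager.list_providers(name="my-anthropic", actor=actor)

# The old anthropic-only helpers are now generic over the provider name.
provider_id = manager.get_provider_id_from_name("my-anthropic")
override_key = manager.get_override_key("my-anthropic")
```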

View File

@ -4,6 +4,7 @@ import traceback
from typing import List, Tuple
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
@ -77,7 +78,7 @@ class Summarizer:
logger.info("Buffer length hit, evicting messages.")
target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1
target_trim_index = len(all_in_context_messages) - self.message_buffer_min
while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user:
target_trim_index += 1
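A toy illustration of the adjusted trim index: the cut now starts exactly `message_buffer_min` messages from the end instead of one message later, and is then pushed forward to the next user turn, so the retained window always starts on a user message. The role sequence below is made up:

```python
roles = ["user", "assistant", "tool", "assistant", "user", "assistant"] * 3  # 18 in-context messages
message_buffer_min = 4

target_trim_index = len(roles) - message_buffer_min  # 14; the old "+ 1" would give 15
while target_trim_index < len(roles) and roles[target_trim_index] != "user":
    target_trim_index += 1

evicted, kept = roles[:target_trim_index], roles[target_trim_index:]
print(target_trim_index, kept)  # 16 ['user', 'assistant'] -> window starts on a user turn
```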
@ -112,11 +113,12 @@ class Summarizer:
summary_request_text = f"""You're a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren't lost.
(Older) Evicted Messages:\n
{evicted_messages_str}
{evicted_messages_str}\n
(Newer) In-Context Messages:\n
{in_context_messages_str}
"""
print(summary_request_text)
# Fire-and-forget the summarization task
self.fire_and_forget(
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
@ -149,6 +151,9 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
# 1) Try plain content
if msg.content:
# Skip tool messages where the name is "send_message"
if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL:
continue
text = "".join(c.text for c in msg.content).strip()
# 2) Otherwise, try extracting from function calls
@ -156,12 +161,15 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
parts = []
for call in msg.tool_calls:
args_str = call.function.arguments
if call.function.name == DEFAULT_MESSAGE_TOOL:
try:
args = json.loads(args_str)
# pull out a "message" field if present
parts.append(args.get("message", args_str))
parts.append(args.get(DEFAULT_MESSAGE_TOOL_KWARG, args_str))
except json.JSONDecodeError:
parts.append(args_str)
else:
parts.append(args_str)
text = " ".join(parts).strip()
else:

View File

@ -100,7 +100,7 @@ class ToolExecutionManager:
try:
executor = ToolExecutorFactory.get_executor(tool.tool_type)
# TODO: Extend this async model to composio
if isinstance(executor, SandboxToolExecutor):
if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)):
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
else:
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)

View File

@ -5,7 +5,7 @@ from typing import Any, Dict, Optional
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState
@ -486,7 +486,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools."""
def execute(
async def execute(
self,
function_name: str,
function_args: dict,
@ -505,7 +505,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
composio_api_key = get_composio_api_key(actor=actor)
# TODO (matt): Roll in execute_composio_action into this class
function_response = execute_composio_action(
function_response = await execute_composio_action_async(
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
)

poetry.lock (generated)
View File

@ -1016,25 +1016,6 @@ e2b = ["e2b (>=0.17.2a37,<1.1.0)", "e2b-code-interpreter"]
flyio = ["gql", "requests_toolbelt"]
tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "transformers"]
[[package]]
name = "composio-langchain"
version = "0.7.15"
description = "Use Composio to get an array of tools with your LangChain agent."
optional = false
python-versions = "<4,>=3.9"
groups = ["main"]
files = [
{file = "composio_langchain-0.7.15-py3-none-any.whl", hash = "sha256:a71b5371ad6c3ee4d4289c7a994fad1424e24c29a38e820b6b2ed259056abb65"},
{file = "composio_langchain-0.7.15.tar.gz", hash = "sha256:cb75c460289ecdf9590caf7ddc0d7888b0a6622ca4f800c9358abe90c25d055e"},
]
[package.dependencies]
composio_core = ">=0.7.0,<0.8.0"
langchain = ">=0.1.0"
langchain-openai = ">=0.0.2.post1"
langchainhub = ">=0.1.15"
pydantic = ">=2.6.4"
[[package]]
name = "configargparse"
version = "1.7"
@ -2842,9 +2823,10 @@ files = [
name = "jsonpatch"
version = "1.33"
description = "Apply JSON-Patches (RFC 6902)"
optional = false
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
{file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
@ -2857,9 +2839,10 @@ jsonpointer = ">=1.9"
name = "jsonpointer"
version = "3.0.0"
description = "Identify specific nodes in a JSON document (RFC 6901)"
optional = false
optional = true
python-versions = ">=3.7"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"},
{file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"},
@ -3052,9 +3035,10 @@ files = [
name = "langchain"
version = "0.3.23"
description = "Building applications with LLMs through composability"
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "langchain-0.3.23-py3-none-any.whl", hash = "sha256:084f05ee7e80b7c3f378ebadd7309f2a37868ce2906fa0ae64365a67843ade3d"},
{file = "langchain-0.3.23.tar.gz", hash = "sha256:d95004afe8abebb52d51d6026270248da3f4b53d93e9bf699f76005e0c83ad34"},
@ -3120,9 +3104,10 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10"
name = "langchain-core"
version = "0.3.51"
description = "Building applications with LLMs through composability"
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "langchain_core-0.3.51-py3-none-any.whl", hash = "sha256:4bd71e8acd45362aa428953f2a91d8162318014544a2216e4b769463caf68e13"},
{file = "langchain_core-0.3.51.tar.gz", hash = "sha256:db76b9cc331411602cb40ba0469a161febe7a0663fbcaddbc9056046ac2d22f4"},
@ -3140,30 +3125,14 @@ PyYAML = ">=5.3"
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
typing-extensions = ">=4.7"
[[package]]
name = "langchain-openai"
version = "0.3.12"
description = "An integration package connecting OpenAI and LangChain"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "langchain_openai-0.3.12-py3-none-any.whl", hash = "sha256:0fab64d58ec95e65ffbaf659470cd362e815685e15edbcb171641e90eca4eb86"},
{file = "langchain_openai-0.3.12.tar.gz", hash = "sha256:c9dbff63551f6bd91913bca9f99a2d057fd95dc58d4778657d67e5baa1737f61"},
]
[package.dependencies]
langchain-core = ">=0.3.49,<1.0.0"
openai = ">=1.68.2,<2.0.0"
tiktoken = ">=0.7,<1"
[[package]]
name = "langchain-text-splitters"
version = "0.3.8"
description = "LangChain text splitting utilities"
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02"},
{file = "langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e"},
@ -3172,30 +3141,14 @@ files = [
[package.dependencies]
langchain-core = ">=0.3.51,<1.0.0"
[[package]]
name = "langchainhub"
version = "0.1.21"
description = "The LangChain Hub API client"
optional = false
python-versions = "<4.0,>=3.8.1"
groups = ["main"]
files = [
{file = "langchainhub-0.1.21-py3-none-any.whl", hash = "sha256:1cc002dc31e0d132a776afd044361e2b698743df5202618cf2bad399246b895f"},
{file = "langchainhub-0.1.21.tar.gz", hash = "sha256:723383b3964a47dbaea6ad5d0ef728accefbc9d2c07480e800bdec43510a8c10"},
]
[package.dependencies]
packaging = ">=23.2,<25"
requests = ">=2,<3"
types-requests = ">=2.31.0.2,<3.0.0.0"
[[package]]
name = "langsmith"
version = "0.3.28"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "langsmith-0.3.28-py3-none-any.whl", hash = "sha256:54ac8815514af52d9c801ad7970086693667e266bf1db90fc453c1759e8407cd"},
{file = "langsmith-0.3.28.tar.gz", hash = "sha256:4666595207131d7f8d83418e54dc86c05e28562e5c997633e7c33fc18f9aeb89"},
@ -3221,14 +3174,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
[[package]]
name = "letta-client"
version = "0.1.124"
version = "0.1.129"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
groups = ["main"]
files = [
{file = "letta_client-0.1.124-py3-none-any.whl", hash = "sha256:a7901437ef91f395cd85d24c0312046b7c82e5a4dd8e04de0d39b5ca085c65d3"},
{file = "letta_client-0.1.124.tar.gz", hash = "sha256:e8b5716930824cc98c62ee01343e358f88619d346578d48a466277bc8282036d"},
{file = "letta_client-0.1.129-py3-none-any.whl", hash = "sha256:87a5fc32471e5b9fefbfc1e1337fd667d5e2e340ece5d2a6c782afbceab4bf36"},
{file = "letta_client-0.1.129.tar.gz", hash = "sha256:b00f611c18a2ad802ec9265f384e1666938c5fc5c86364b2c410d72f0331d597"},
]
[package.dependencies]
@ -4366,10 +4319,10 @@ files = [
name = "orjson"
version = "3.10.16"
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
optional = false
optional = true
python-versions = ">=3.9"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\""
markers = "platform_python_implementation != \"PyPy\" and (extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\")"
files = [
{file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
{file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@ -6069,9 +6022,10 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
name = "requests-toolbelt"
version = "1.0.0"
description = "A utility belt for advanced users of python-requests"
optional = false
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
{file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"},
@ -6855,21 +6809,6 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
[[package]]
name = "types-requests"
version = "2.32.0.20250328"
description = "Typing stubs for requests"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "types_requests-2.32.0.20250328-py3-none-any.whl", hash = "sha256:72ff80f84b15eb3aa7a8e2625fffb6a93f2ad5a0c20215fc1dcfa61117bcb2a2"},
{file = "types_requests-2.32.0.20250328.tar.gz", hash = "sha256:c9e67228ea103bd811c96984fac36ed2ae8da87a36a633964a21f199d60baf32"},
]
[package.dependencies]
urllib3 = ">=2"
[[package]]
name = "typing-extensions"
version = "4.13.2"
@ -7438,9 +7377,10 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
name = "zstandard"
version = "0.23.0"
description = "Zstandard bindings for Python"
optional = false
optional = true
python-versions = ">=3.8"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
{file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"},
{file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"},
@ -7563,4 +7503,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.1"
python-versions = "<3.14,>=3.10"
content-hash = "75c1c949aa6c0ef8d681bddd91999f97ed4991451be93ca45bf9c01dd19d8a8a"
content-hash = "ba9cf0e00af2d5542aa4beecbd727af92b77ba584033f05c222b00ae47f96585"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.7"
version = "0.7.8"
packages = [
{include = "letta"},
]
@ -56,7 +56,6 @@ nltk = "^3.8.1"
jinja2 = "^3.1.5"
locust = {version = "^2.31.5", optional = true}
wikipedia = {version = "^1.4.0", optional = true}
composio-langchain = "^0.7.7"
composio-core = "^0.7.7"
alembic = "^1.13.3"
pyhumps = "^3.8.0"
@ -74,7 +73,7 @@ llama-index = "^0.12.2"
llama-index-embeddings-openai = "^0.3.1"
e2b-code-interpreter = {version = "^1.0.3", optional = true}
anthropic = "^0.49.0"
letta_client = "^0.1.124"
letta_client = "^0.1.127"
openai = "^1.60.0"
opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0"

View File

@ -1,7 +1,7 @@
{
"context_window": 8192,
"model_endpoint_type": "openai",
"model_endpoint": "https://inference.memgpt.ai",
"model_endpoint": "https://inference.letta.com",
"model": "memgpt-openai",
"embedding_endpoint_type": "hugging-face",
"embedding_endpoint": "https://embeddings.memgpt.ai",

View File

@ -1,7 +1,7 @@
{
"context_window": 8192,
"model_endpoint_type": "openai",
"model_endpoint": "https://inference.memgpt.ai",
"model_endpoint": "https://inference.letta.com",
"model": "memgpt-openai",
"put_inner_thoughts_in_kwargs": true
}

View File

@ -105,7 +105,9 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
llm_client = LLMClient.create(
provider=agent_state.llm_config.model_endpoint_type,
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
actor_id=client.user.id,
)
if llm_client:
response = llm_client.send_llm_request(
@ -179,7 +181,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
Note: This is acting on the Letta response, note the usage of `user_message`
"""
from composio_langchain import Action
from composio import Action
# Set up client
client = create_client()

View File

@ -56,7 +56,7 @@ def test_add_composio_tool(fastapi_client):
assert "name" in response.json()
def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user):
async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user):
agent_state = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="sarah_agent",
@ -67,7 +67,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis
actor=default_user,
)
tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool(
function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
)

View File

@ -1,26 +1,26 @@
import os
import threading
from unittest.mock import MagicMock
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
from sqlalchemy import delete
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentType, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.group import ManagerType
from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
from letta.schemas.tool import ToolCreate
@ -29,6 +29,8 @@ from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import get_persona_text
@ -48,16 +50,24 @@ MESSAGE_TRANSCRIPTS = [
"user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.",
"assistant: How about Vegan Treats in Santa Barbara? Theyre highly rated.",
"user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.",
"user: I want to make sure my itinerary isnt too tight—aiming for 34 days total.",
"assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.",
"user: Yes, lets do that.",
"assistant: Ill put together a day-by-day plan now.",
]
SUMMARY_REQ_TEXT = """
Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity:
SYSTEM_MESSAGE = Message(role=MessageRole.system, content=[TextContent(text="System message")])
MESSAGE_OBJECTS = [SYSTEM_MESSAGE]
for entry in MESSAGE_TRANSCRIPTS:
role_str, text = entry.split(":", 1)
role = MessageRole.user if role_str.strip() == "user" else MessageRole.assistant
MESSAGE_OBJECTS.append(Message(role=role, content=[TextContent(text=text.strip())]))
MESSAGE_EVICT_BREAKPOINT = 14
SUMMARY_REQ_TEXT = """
You're a memory-recall helper for an AI that can only keep the last 4 messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren't lost.
(Older) Evicted Messages:
(Older)
0. user: Hey, I've been thinking about planning a road trip up the California coast next month.
1. assistant: That sounds amazing! Do you have any particular cities or sights in mind?
2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.
@ -70,16 +80,13 @@ Here is the conversation history. Lines marked (Older) are about to be evicted;
9. assistant: Happy early birthday! Would you like gift ideas or celebration tips?
10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.
11. assistant: How about Vegan Treats in Santa Barbara? They're highly rated.
(Newer) In-Context Messages:
12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.
(Newer)
13. user: I want to make sure my itinerary isn't too tight—aiming for 3-4 days total.
14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
15. user: Yes, let's do that.
16. assistant: I'll put together a day-by-day plan now.
Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk's `start_index`, `end_index`, and a one-sentence `contextual_description`.
"""
13. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
14. user: Yes, let's do that.
15. assistant: I'll put together a day-by-day plan now."""
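The argument shape of the `store_memory` call that the prompt above requests is not shown anywhere in this diff. As a rough sketch of what a conforming tool call might carry (the top-level `chunks` key and the overall layout are assumptions; only the `start_index`, `end_index`, and `contextual_description` field names come from the prompt), a segmentation of the evicted lines could look like:
```python
# Hypothetical illustration only: the exact `store_memory` schema is assumed, not taken
# from this diff; the three field names below come from the prompt text above.
store_memory_args = {
    "chunks": [
        {
            "start_index": 0,
            "end_index": 2,
            "contextual_description": "User is planning a road trip up the California coast next month, wants to stop in Big Sur and Santa Barbara, and loves craft coffee shops.",
        },
        {
            "start_index": 9,
            "end_index": 11,
            "contextual_description": "User has a birthday coming up and accepted a recommendation for the vegan bakery Vegan Treats in Santa Barbara.",
        },
        # ...one entry per coherent chunk of the evicted (Older) lines
    ]
}
```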
# --- Server Management --- #
@ -214,22 +221,12 @@ def org_id(server):
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()
server.organization_manager.delete_organization_by_id(org.id)
@pytest.fixture(scope="module")
def actor(server, org_id):
user = server.user_manager.create_default_user()
yield user
# cleanup
server.user_manager.delete_user_by_id(user.id)
# --- Helper Functions --- #
@ -301,6 +298,80 @@ async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpo
print(chunk.choices[0].delta.content)
@pytest.mark.asyncio
async def test_summarization(disable_e2b_api_key, voice_agent):
agent_manager = AgentManager()
user_manager = UserManager()
actor = user_manager.get_default_user()
request = CreateAgent(
name=voice_agent.name + "-sleeptime",
agent_type=AgentType.voice_sleeptime_agent,
block_ids=[block.id for block in voice_agent.memory.blocks],
memory_blocks=[
CreateBlock(
label="memory_persona",
value=get_persona_text("voice_memory_persona"),
),
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
project_id=voice_agent.project_id,
)
sleeptime_agent = agent_manager.create_agent(request, actor=actor)
async_client = AsyncOpenAI()
memory_agent = VoiceSleeptimeAgent(
agent_id=sleeptime_agent.id,
convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent
openai_client=async_client,
message_manager=MessageManager(),
agent_manager=agent_manager,
actor=actor,
block_manager=BlockManager(),
target_block_label="human",
message_transcripts=MESSAGE_TRANSCRIPTS,
)
summarizer = Summarizer(
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
summarizer_agent=memory_agent,
message_buffer_limit=8,
message_buffer_min=4,
)
# stub out the agent.step so it returns a known sentinel
memory_agent.step = MagicMock(return_value="STEP_RESULT")
# patch fire_and_forget on *this* summarizer instance to a MagicMock
summarizer.fire_and_forget = MagicMock()
# now call the method under test
in_ctx = MESSAGE_OBJECTS[:MESSAGE_EVICT_BREAKPOINT]
new_msgs = MESSAGE_OBJECTS[MESSAGE_EVICT_BREAKPOINT:]
# call under test (this is sync)
updated, did_summarize = summarizer._static_buffer_summarization(
in_context_messages=in_ctx,
new_letta_messages=new_msgs,
)
assert did_summarize is True
assert len(updated) == summarizer.message_buffer_min + 1 # One extra for system message
assert updated[0].role == MessageRole.system # Preserved system message
# 2) the summarizer_agent.step() should have been *called* exactly once
memory_agent.step.assert_called_once()
call_args = memory_agent.step.call_args.args[0] # the single positional argument: a list of MessageCreate
assert isinstance(call_args, list)
assert isinstance(call_args[0], MessageCreate)
assert call_args[0].role == MessageRole.user
assert "15. assistant: Ill put together a day-by-day plan now." in call_args[0].content[0].text
# 3) fire_and_forget should have been called once, and its argument must be the coroutine returned by step()
summarizer.fire_and_forget.assert_called_once()
@pytest.mark.asyncio
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
"""Tests chat completion streaming using the Async OpenAI client."""
@ -427,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
server.group_manager.retrieve_group(group_id=group.id, actor=actor)
with pytest.raises(NoResultFound):
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
def _modify(group_id, server, actor, max_val, min_val):
"""Helper to invoke modify_group with voice_sleeptime config."""
return server.group_manager.modify_group(
group_id=group_id,
group_update=GroupUpdate(
manager_config=VoiceSleeptimeManagerUpdate(
manager_type=ManagerType.voice_sleeptime,
max_message_buffer_length=max_val,
min_message_buffer_length=min_val,
)
),
actor=actor,
)
@pytest.fixture
def group_id(voice_agent):
return voice_agent.multi_agent_group.id
def test_valid_buffer_lengths_above_four(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=10, min_val=5)
assert updated.max_message_buffer_length == 10
assert updated.min_message_buffer_length == 5
def test_valid_buffer_lengths_only_max(group_id, server, actor):
# only max provided; min falls back to DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
def test_valid_buffer_lengths_only_min(group_id, server, actor):
# only min provided; max falls back to DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1
@pytest.mark.parametrize(
"max_val,min_val,err_part",
[
# only one set → both-or-none
(None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"),
(DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"),
# ordering violations
(5, 5, "must be greater than"),
(6, 7, "must be greater than"),
# lower-bound (must both be > 4)
(4, 5, "greater than 4"),
(5, 4, "greater than 4"),
(1, 10, "greater than 4"),
(10, 1, "greater than 4"),
],
)
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
with pytest.raises(ValueError) as exc:
_modify(group_id, server, actor, max_val, min_val)
assert err_part in str(exc.value)
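Read together, the valid and invalid cases above pin down the rule being enforced: if only one bound is supplied, the other falls back to its default; both bounds must be greater than 4; and the maximum must be strictly greater than the minimum. A minimal sketch of that check (not the actual group manager implementation; the default values and error strings here are placeholders) might be:
```python
def check_buffer_lengths(max_len, min_len, default_max=20, default_min=8):
    """Sketch of the validation the tests above exercise; defaults are assumed placeholders."""
    # Fall back to the configured defaults when only one bound is supplied.
    max_len = default_max if max_len is None else max_len
    min_len = default_min if min_len is None else min_len
    # Both bounds must be greater than 4.
    if max_len <= 4 or min_len <= 4:
        raise ValueError("message buffer lengths must be greater than 4")
    # The maximum must be strictly greater than the minimum.
    if max_len <= min_len:
        raise ValueError("max_message_buffer_length must be greater than min_message_buffer_length")
    return max_len, min_len
```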

View File

@ -124,7 +124,7 @@ def test_agent(client: LocalClient):
def test_agent_add_remove_tools(client: LocalClient, agent):
# Create and add two tools to the client
# tool 1
from composio_langchain import Action
from composio import Action
github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
@ -316,7 +316,7 @@ def test_tools(client: LocalClient):
def test_tools_from_composio_basic(client: LocalClient):
from composio_langchain import Action
from composio import Action
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
import pytest
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
@pytest.fixture

View File

@ -19,97 +19,166 @@ from letta.settings import model_settings
def test_openai():
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
provider = OpenAIProvider(
name="openai",
api_key=api_key,
base_url=model_settings.openai_api_base,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_deepseek():
api_key = os.getenv("DEEPSEEK_API_KEY")
assert api_key is not None
provider = DeepSeekProvider(api_key=api_key)
provider = DeepSeekProvider(
name="deepseek",
api_key=api_key,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(api_key=api_key)
provider = AnthropicProvider(
name="anthropic",
api_key=api_key,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_groq():
provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
provider = GroqProvider(
name="groq",
api_key=os.getenv("GROQ_API_KEY"),
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_azure():
provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
provider = AzureProvider(
name="azure",
api_key=os.getenv("AZURE_API_KEY"),
base_url=os.getenv("AZURE_BASE_URL"),
)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embed_models = provider.list_embedding_models()
print([m.embedding_model for m in embed_models])
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_ollama():
base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
provider = OllamaProvider(
name="ollama",
base_url=base_url,
default_prompt_formatter=model_settings.default_prompt_formatter,
api_key=None,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print(embedding_models)
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_googleai():
api_key = os.getenv("GEMINI_API_KEY")
assert api_key is not None
provider = GoogleAIProvider(api_key=api_key)
provider = GoogleAIProvider(
name="google_ai",
api_key=api_key,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
provider.list_embedding_models()
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_google_vertex():
provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
provider = GoogleVertexProvider(
name="google_vertex",
google_cloud_project=os.getenv("GCP_PROJECT_ID"),
google_cloud_location=os.getenv("GCP_REGION"),
)
models = provider.list_llm_models()
print(models)
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_mistral():
provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
provider = MistralProvider(
name="mistral",
api_key=os.getenv("MISTRAL_API_KEY"),
)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_together():
provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
provider = TogetherProvider(
name="together",
api_key=os.getenv("TOGETHER_API_KEY"),
default_prompt_formatter="chatml",
)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_anthropic_bedrock():
from letta.settings import model_settings
provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_custom_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(
name="custom_anthropic",
api_key=api_key,
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
# def test_vllm():

View File

@ -13,7 +13,7 @@ import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import Provider as PydanticProvider
@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id)
provider = server.provider_manager.create_provider(
provider=PydanticProvider(
name="anthropic",
name="caren-anthropic",
provider_type=ProviderType.anthropic,
api_key=os.getenv("ANTHROPIC_API_KEY"),
),
actor=actor,
@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
agent = server.create_agent(
request=CreateAgent(
memory_blocks=[],
model="anthropic/claude-3-opus-20240229",
context_window_limit=200000,
model="caren-anthropic/claude-3-opus-20240229",
context_window_limit=100000,
embedding="openai/text-embedding-ada-002",
),
actor=actor,