From 20ecab29a1458772d9591e4629f4e820fa926ecf Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 30 Apr 2025 23:39:58 -0700 Subject: [PATCH] chore: bump version 0.7.8 (#2604) Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com> Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: Matthew Zhou --- README.md | 4 +- ...f_add_byok_fields_and_unique_constraint.py | 35 ++ ...71_add_buffer_length_min_max_for_voice_.py | 33 ++ examples/composio_tool_usage.py | 2 +- examples/sleeptime/voice_sleeptime_example.py | 32 ++ letta/__init__.py | 2 +- letta/agent.py | 18 +- letta/agents/exceptions.py | 6 + letta/agents/letta_agent.py | 83 +++-- letta/agents/letta_agent_batch.py | 8 +- letta/agents/voice_agent.py | 17 +- letta/constants.py | 6 +- letta/functions/composio_helpers.py | 100 ++++++ letta/functions/functions.py | 6 +- letta/functions/helpers.py | 118 +------ letta/groups/helpers.py | 1 + letta/groups/sleeptime_multi_agent.py | 6 +- letta/helpers/message_helper.py | 25 +- letta/helpers/tool_execution_helper.py | 2 +- .../anthropic_streaming_interface.py | 333 +++++++++--------- ...ai_chat_completions_streaming_interface.py | 2 +- letta/llm_api/anthropic.py | 25 +- letta/llm_api/anthropic_client.py | 6 +- letta/llm_api/google_vertex_client.py | 2 +- letta/llm_api/llm_api_tools.py | 7 + letta/llm_api/llm_client.py | 14 +- letta/llm_api/llm_client_base.py | 4 + letta/llm_api/openai.py | 14 +- letta/llm_api/openai_client.py | 24 +- letta/memory.py | 4 +- letta/orm/group.py | 2 + letta/orm/provider.py | 10 + letta/schemas/agent.py | 1 - letta/schemas/enums.py | 11 + letta/schemas/group.py | 24 ++ letta/schemas/llm_config.py | 1 + letta/schemas/llm_config_overrides.py | 4 +- letta/schemas/providers.py | 95 +++-- letta/schemas/tool.py | 11 +- letta/server/rest_api/app.py | 12 + .../rest_api/chat_completions_interface.py | 2 +- letta/server/rest_api/interface.py | 18 +- ...timistic_json_parser.py => json_parser.py} | 88 +++-- letta/server/rest_api/routers/v1/agents.py | 2 +- letta/server/rest_api/routers/v1/llms.py | 7 +- letta/server/rest_api/routers/v1/providers.py | 5 +- letta/server/rest_api/routers/v1/voice.py | 2 - letta/server/rest_api/utils.py | 29 +- letta/server/server.py | 36 +- letta/services/group_manager.py | 58 +++ letta/services/provider_manager.py | 39 +- letta/services/summarizer/summarizer.py | 22 +- .../tool_executor/tool_execution_manager.py | 2 +- letta/services/tool_executor/tool_executor.py | 6 +- poetry.lock | 104 ++---- pyproject.toml | 5 +- tests/configs/letta_hosted.json | 18 +- .../llm_model_configs/letta-hosted.json | 2 +- tests/helpers/endpoints_helper.py | 6 +- tests/integration_test_composio.py | 4 +- tests/integration_test_voice_agent.py | 188 ++++++++-- tests/test_local_client.py | 4 +- tests/test_optimistic_json_parser.py | 2 +- tests/test_providers.py | 129 +++++-- tests/test_server.py | 9 +- 65 files changed, 1248 insertions(+), 649 deletions(-) create mode 100644 alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py create mode 100644 alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py create mode 100644 examples/sleeptime/voice_sleeptime_example.py create mode 100644 letta/agents/exceptions.py create mode 100644 letta/functions/composio_helpers.py rename letta/server/rest_api/{optimistic_json_parser.py => json_parser.py} (70%) diff --git a/README.md b/README.md index b8ccade94..aa102ba37 100644 --- a/README.md +++ b/README.md @@ -167,7 +167,7 @@ docker exec -it $(docker ps -q -f ancestor=letta/letta) letta run In the CLI tool, you'll be able to create new agents, or load existing agents: ``` 🧬 Creating new agent... -? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai] +? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com] ? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai] -> 🤖 Using persona profile: 'sam_pov' -> 🧑 Using human profile: 'basic' @@ -233,7 +233,7 @@ letta run ``` ``` 🧬 Creating new agent... -? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai] +? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com] ? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai] -> 🤖 Using persona profile: 'sam_pov' -> 🧑 Using human profile: 'basic' diff --git a/alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py b/alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py new file mode 100644 index 000000000..3b94ceddb --- /dev/null +++ b/alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py @@ -0,0 +1,35 @@ +"""add byok fields and unique constraint + +Revision ID: 373dabcba6cf +Revises: c56081a05371 +Create Date: 2025-04-30 19:38:25.010856 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "373dabcba6cf" +down_revision: Union[str, None] = "c56081a05371" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True)) + op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True)) + op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("unique_name_organization_id", "providers", type_="unique") + op.drop_column("providers", "base_url") + op.drop_column("providers", "provider_type") + # ### end Alembic commands ### diff --git a/alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py b/alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py new file mode 100644 index 000000000..44f9a87f6 --- /dev/null +++ b/alembic/versions/c56081a05371_add_buffer_length_min_max_for_voice_.py @@ -0,0 +1,33 @@ +"""Add buffer length min max for voice sleeptime + +Revision ID: c56081a05371 +Revises: 28b8765bdd0a +Create Date: 2025-04-30 16:03:41.213750 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c56081a05371" +down_revision: Union[str, None] = "28b8765bdd0a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("groups", sa.Column("max_message_buffer_length", sa.Integer(), nullable=True)) + op.add_column("groups", sa.Column("min_message_buffer_length", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("groups", "min_message_buffer_length") + op.drop_column("groups", "max_message_buffer_length") + # ### end Alembic commands ### diff --git a/examples/composio_tool_usage.py b/examples/composio_tool_usage.py index d32546d1f..89c662b00 100644 --- a/examples/composio_tool_usage.py +++ b/examples/composio_tool_usage.py @@ -60,7 +60,7 @@ Last updated Oct 2, 2024. Please check `composio` documentation for any composio def main(): - from composio_langchain import Action + from composio import Action # Add the composio tool tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) diff --git a/examples/sleeptime/voice_sleeptime_example.py b/examples/sleeptime/voice_sleeptime_example.py new file mode 100644 index 000000000..66c0be7d3 --- /dev/null +++ b/examples/sleeptime/voice_sleeptime_example.py @@ -0,0 +1,32 @@ +from letta_client import Letta, VoiceSleeptimeManagerUpdate + +client = Letta(base_url="http://localhost:8283") + +agent = client.agents.create( + name="low_latency_voice_agent_demo", + agent_type="voice_convo_agent", + memory_blocks=[ + {"value": "Name: ?", "label": "human"}, + {"value": "You are a helpful assistant.", "label": "persona"}, + ], + model="openai/gpt-4o-mini", # Use 4o-mini for speed + embedding="openai/text-embedding-3-small", + enable_sleeptime=True, + initial_message_sequence = [], +) +print(f"Created agent id {agent.id}") + +# get the group +group_id = agent.multi_agent_group.id +max_message_buffer_length = agent.multi_agent_group.max_message_buffer_length +min_message_buffer_length = agent.multi_agent_group.min_message_buffer_length +print(f"Group id: {group_id}, max_message_buffer_length: {max_message_buffer_length}, min_message_buffer_length: {min_message_buffer_length}") + +# change it to be more frequent +group = client.groups.modify( + group_id=group_id, + manager_config=VoiceSleeptimeManagerUpdate( + max_message_buffer_length=10, + min_message_buffer_length=6, + ) +) diff --git a/letta/__init__.py b/letta/__init__.py index d240209bf..1b3c8af68 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.7.7" +__version__ = "0.7.8" # import clients from letta.client.client import LocalClient, RESTClient, create_client diff --git a/letta/agent.py b/letta/agent.py index 38062c557..40019673f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -21,14 +21,14 @@ from letta.constants import ( ) from letta.errors import ContextWindowExceededError from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name from letta.functions.functions import get_function_from_module -from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name from letta.functions.mcp_client.base_client import BaseMCPClient from letta.helpers import ToolRulesSolver from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads -from letta.helpers.message_helper import prepare_input_message_create +from letta.helpers.message_helper import convert_message_creates_to_messages from letta.interface import AgentInterface from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error from letta.llm_api.llm_api_tools import create @@ -331,8 +331,10 @@ class Agent(BaseAgent): log_telemetry(self.logger, "_get_ai_reply create start") # New LLM client flow llm_client = LLMClient.create( - provider=self.agent_state.llm_config.model_endpoint_type, + provider_name=self.agent_state.llm_config.provider_name, + provider_type=self.agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=put_inner_thoughts_first, + actor_id=self.user.id, ) if llm_client and not stream: @@ -726,8 +728,7 @@ class Agent(BaseAgent): self.tool_rules_solver.clear_tool_history() # Convert MessageCreate objects to Message objects - message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages] - next_input_messages = message_objects + next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id) counter = 0 total_usage = UsageStatistics() step_count = 0 @@ -942,12 +943,7 @@ class Agent(BaseAgent): model_endpoint=self.agent_state.llm_config.model_endpoint, context_window_limit=self.agent_state.llm_config.context_window, usage=response.usage, - # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature - provider_id=( - self.provider_manager.get_anthropic_override_provider_id() - if self.agent_state.llm_config.model_endpoint_type == "anthropic" - else None - ), + provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name), job_id=job_id, ) for message in all_new_messages: diff --git a/letta/agents/exceptions.py b/letta/agents/exceptions.py new file mode 100644 index 000000000..270cfc356 --- /dev/null +++ b/letta/agents/exceptions.py @@ -0,0 +1,6 @@ +class IncompatibleAgentType(ValueError): + def __init__(self, expected_type: str, actual_type: str): + message = f"Incompatible agent type: expected '{expected_type}', but got '{actual_type}'." + super().__init__(message) + self.expected_type = expected_type + self.actual_type = actual_type diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 834098f97..5d859b342 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -67,8 +67,10 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - provider=agent_state.llm_config.model_endpoint_type, + provider_name=agent_state.llm_config.provider_name, + provider_type=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, + actor_id=self.actor.id, ) for step in range(max_steps): response = await self._get_ai_reply( @@ -109,8 +111,10 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - llm_config=agent_state.llm_config, + provider_name=agent_state.llm_config.provider_name, + provider_type=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, + actor_id=self.actor.id, ) for step in range(max_steps): @@ -125,7 +129,7 @@ class LettaAgent(BaseAgent): # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED interface = AnthropicStreamingInterface( - use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs + use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs ) async for chunk in interface.process(stream): yield f"data: {chunk.model_dump_json()}\n\n" @@ -179,6 +183,7 @@ class LettaAgent(BaseAgent): ToolType.LETTA_SLEEPTIME_CORE, } or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags") + or (t.tool_type == ToolType.EXTERNAL_COMPOSIO) ] valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) @@ -274,45 +279,49 @@ class LettaAgent(BaseAgent): return persisted_messages, continue_stepping def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: - self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) + try: + self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) - # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this + curr_system_message = in_context_messages[0] + curr_memory_str = agent_state.memory.compile() + curr_system_message_text = curr_system_message.content[0].text + if curr_memory_str in curr_system_message_text: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.debug( + f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + ) + return in_context_messages - memory_edit_timestamp = get_utc_time() + memory_edit_timestamp = get_utc_time() - num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) + num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=num_messages, - archival_memory_size=num_archival_memories, - ) - - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, ) - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] + diff = united_diff(curr_system_message_text, new_system_message_str) + if len(diff) > 0: + logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - else: - return in_context_messages + new_system_message = self.message_manager.update_message_by_id( + curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + ) + + # Skip pulling down the agent's memory again to save on a db call + return [new_system_message] + in_context_messages[1:] + + else: + return in_context_messages + except: + logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") + raise @trace_method async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: @@ -331,6 +340,10 @@ class LettaAgent(BaseAgent): results = await self._send_message_to_agents_matching_tags(**tool_args) log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args) return json.dumps(results), True + elif target_tool.type == ToolType.EXTERNAL_COMPOSIO: + log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args) + log_event(name=f"finish_compsio_{tool_name}_execution", attributes=tool_args) + return tool_execution_result.func_return, True else: tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor) # TODO: Integrate sandbox result diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index a6d31a09c..3610bf2ec 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -156,8 +156,10 @@ class LettaAgentBatch: log_event(name="init_llm_client") llm_client = LLMClient.create( - provider=agent_states[0].llm_config.model_endpoint_type, + provider_name=agent_states[0].llm_config.provider_name, + provider_type=agent_states[0].llm_config.model_endpoint_type, put_inner_thoughts_first=True, + actor_id=self.actor.id, ) agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states} @@ -273,8 +275,10 @@ class LettaAgentBatch: # translate provider‑specific response → OpenAI‑style tool call (unchanged) llm_client = LLMClient.create( - provider=item.llm_config.model_endpoint_type, + provider_name=item.llm_config.provider_name, + provider_type=item.llm_config.model_endpoint_type, put_inner_thoughts_first=True, + actor_id=self.actor.id, ) tool_call = ( llm_client.convert_response_to_chat_completion( diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index 8d3077fc0..390964608 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple import openai from letta.agents.base_agent import BaseAgent +from letta.agents.exceptions import IncompatibleAgentType from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent from letta.constants import NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time @@ -18,7 +19,7 @@ from letta.helpers.tool_execution_helper import ( from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface from letta.log import get_logger from letta.orm.enums import ToolType -from letta.schemas.agent import AgentState +from letta.schemas.agent import AgentState, AgentType from letta.schemas.enums import MessageRole from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate, MessageUpdate @@ -68,8 +69,6 @@ class VoiceAgent(BaseAgent): block_manager: BlockManager, passage_manager: PassageManager, actor: User, - message_buffer_limit: int, - message_buffer_min: int, ): super().__init__( agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor @@ -80,8 +79,6 @@ class VoiceAgent(BaseAgent): self.passage_manager = passage_manager # TODO: This is not guaranteed to exist! self.summary_block_label = "human" - self.message_buffer_limit = message_buffer_limit - self.message_buffer_min = message_buffer_min # Cached archival memory/message size self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id) @@ -108,8 +105,8 @@ class VoiceAgent(BaseAgent): target_block_label=self.summary_block_label, message_transcripts=[], ), - message_buffer_limit=self.message_buffer_limit, - message_buffer_min=self.message_buffer_min, + message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length, + message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length, ) return summarizer @@ -124,9 +121,15 @@ class VoiceAgent(BaseAgent): """ if len(input_messages) != 1 or input_messages[0].role != MessageRole.user: raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}") + user_query = input_messages[0].content[0].text agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor) + + # Safety check + if agent_state.agent_type != AgentType.voice_convo_agent: + raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type) + summarizer = self.init_summarizer(agent_state=agent_state) in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor) diff --git a/letta/constants.py b/letta/constants.py index 6466798e8..448277f84 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -4,7 +4,7 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir") -LETTA_MODEL_ENDPOINT = "https://inference.memgpt.ai" +LETTA_MODEL_ENDPOINT = "https://inference.letta.com" ADMIN_PREFIX = "/v1/admin" API_PREFIX = "/v1" @@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29 # minimum context window size MIN_CONTEXT_WINDOW = 4096 +# Voice Sleeptime message buffer lengths +DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30 +DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15 + # embeddings MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset DEFAULT_EMBEDDING_CHUNK_SIZE = 300 diff --git a/letta/functions/composio_helpers.py b/letta/functions/composio_helpers.py new file mode 100644 index 000000000..ae5cbb35a --- /dev/null +++ b/letta/functions/composio_helpers.py @@ -0,0 +1,100 @@ +import asyncio +import os +from typing import Any, Optional + +from composio import ComposioToolSet +from composio.constants import DEFAULT_ENTITY_ID +from composio.exceptions import ( + ApiKeyNotProvidedError, + ComposioSDKError, + ConnectedAccountNotFoundError, + EnumMetadataNotFound, + EnumStringNotFound, +) + +from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY + + +# TODO: This is kind of hacky, as this is used to search up the action later on composio's side +# TODO: So be very careful changing/removing these pair of functions +def _generate_func_name_from_composio_action(action_name: str) -> str: + """ + Generates the composio function name from the composio action. + + Args: + action_name: The composio action name + + Returns: + function name + """ + return action_name.lower() + + +def generate_composio_action_from_func_name(func_name: str) -> str: + """ + Generates the composio action from the composio function name. + + Args: + func_name: The composio function name + + Returns: + composio action name + """ + return func_name.upper() + + +def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]: + # Generate func name + func_name = _generate_func_name_from_composio_action(action_name) + + wrapper_function_str = f"""\ +def {func_name}(**kwargs): + raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team") +""" + + # Compile safety check + _assert_code_gen_compilable(wrapper_function_str.strip()) + + return func_name, wrapper_function_str.strip() + + +async def execute_composio_action_async( + action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None +) -> tuple[str, str]: + try: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, execute_composio_action, action_name, args, api_key, entity_id) + except Exception as e: + raise RuntimeError(f"Error in execute_composio_action_async: {e}") from e + + +def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any: + entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID) + try: + composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False) + response = composio_toolset.execute_action(action=action_name, params=args) + except ApiKeyNotProvidedError: + raise RuntimeError( + f"Composio API key is missing for action '{action_name}'. " + "Please set the sandbox environment variables either through the ADE or the API." + ) + except ConnectedAccountNotFoundError: + raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.") + except EnumStringNotFound as e: + raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.") + except EnumMetadataNotFound as e: + raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.") + except ComposioSDKError as e: + raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e)) + + if "error" in response and response["error"]: + raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"])) + + return response.get("data") + + +def _assert_code_gen_compilable(code_str): + try: + compile(code_str, "", "exec") + except SyntaxError as e: + print(f"Syntax error in code: {e}") diff --git a/letta/functions/functions.py b/letta/functions/functions.py index 007d587d1..b0c41a86b 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -1,8 +1,9 @@ import importlib import inspect +from collections.abc import Callable from textwrap import dedent # remove indentation from types import ModuleType -from typing import Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional from letta.errors import LettaToolCreateError from letta.functions.schema_generator import generate_schema @@ -66,7 +67,8 @@ def parse_source_code(func) -> str: return source_code -def get_function_from_module(module_name: str, function_name: str): +# TODO (cliandy) refactor below two funcs +def get_function_from_module(module_name: str, function_name: str) -> Callable[..., Any]: """ Dynamically imports a function from a specified module. diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 54ca2740b..9797796dc 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -6,10 +6,9 @@ from random import uniform from typing import Any, Dict, List, Optional, Type, Union import humps -from composio.constants import DEFAULT_ENTITY_ID from pydantic import BaseModel, Field, create_model -from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.functions.interface import MultiAgentMessagingInterface from letta.orm.errors import NoResultFound from letta.schemas.enums import MessageRole @@ -21,34 +20,6 @@ from letta.server.rest_api.utils import get_letta_server from letta.settings import settings -# TODO: This is kind of hacky, as this is used to search up the action later on composio's side -# TODO: So be very careful changing/removing these pair of functions -def generate_func_name_from_composio_action(action_name: str) -> str: - """ - Generates the composio function name from the composio action. - - Args: - action_name: The composio action name - - Returns: - function name - """ - return action_name.lower() - - -def generate_composio_action_from_func_name(func_name: str) -> str: - """ - Generates the composio action from the composio function name. - - Args: - func_name: The composio function name - - Returns: - composio action name - """ - return func_name.upper() - - # TODO needed? def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]: @@ -58,71 +29,20 @@ def {mcp_tool_name}(**kwargs): """ # Compile safety check - assert_code_gen_compilable(wrapper_function_str.strip()) + _assert_code_gen_compilable(wrapper_function_str.strip()) return mcp_tool_name, wrapper_function_str.strip() -def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]: - # Generate func name - func_name = generate_func_name_from_composio_action(action_name) - - wrapper_function_str = f"""\ -def {func_name}(**kwargs): - raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team") -""" - - # Compile safety check - assert_code_gen_compilable(wrapper_function_str.strip()) - - return func_name, wrapper_function_str.strip() - - -def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any: - import os - - from composio.exceptions import ( - ApiKeyNotProvidedError, - ComposioSDKError, - ConnectedAccountNotFoundError, - EnumMetadataNotFound, - EnumStringNotFound, - ) - from composio_langchain import ComposioToolSet - - entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID) - try: - composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False) - response = composio_toolset.execute_action(action=action_name, params=args) - except ApiKeyNotProvidedError: - raise RuntimeError( - f"Composio API key is missing for action '{action_name}'. " - "Please set the sandbox environment variables either through the ADE or the API." - ) - except ConnectedAccountNotFoundError: - raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.") - except EnumStringNotFound as e: - raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.") - except EnumMetadataNotFound as e: - raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.") - except ComposioSDKError as e: - raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e)) - - if "error" in response: - raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"])) - - return response.get("data") - - def generate_langchain_tool_wrapper( tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None ) -> tuple[str, str]: tool_name = tool.__class__.__name__ import_statement = f"from langchain_community.tools import {tool_name}" - extra_module_imports = generate_import_code(additional_imports_module_attr_map) + extra_module_imports = _generate_import_code(additional_imports_module_attr_map) # Safety check that user has passed in all required imports: - assert_all_classes_are_imported(tool, additional_imports_module_attr_map) + _assert_all_classes_are_imported(tool, additional_imports_module_attr_map) tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}" run_call = f"return tool._run(**kwargs)" @@ -139,25 +59,25 @@ def {func_name}(**kwargs): """ # Compile safety check - assert_code_gen_compilable(wrapper_function_str) + _assert_code_gen_compilable(wrapper_function_str) return func_name, wrapper_function_str -def assert_code_gen_compilable(code_str): +def _assert_code_gen_compilable(code_str): try: compile(code_str, "", "exec") except SyntaxError as e: print(f"Syntax error in code: {e}") -def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None: +def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None: # Safety check that user has passed in all required imports: tool_name = tool.__class__.__name__ current_class_imports = {tool_name} if additional_imports_module_attr_map: current_class_imports.update(set(additional_imports_module_attr_map.values())) - required_class_imports = set(find_required_class_names_for_import(tool)) + required_class_imports = set(_find_required_class_names_for_import(tool)) if not current_class_imports.issuperset(required_class_imports): err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}" @@ -165,7 +85,7 @@ def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional raise RuntimeError(err_msg) -def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]: +def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]: """ Finds all the class names for required imports when instantiating the `obj`. NOTE: This does not return the full import path, only the class name. @@ -181,7 +101,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod # Collect all possible candidates for BaseModel objects candidates = [] - if is_base_model(curr_obj): + if _is_base_model(curr_obj): # If it is a base model, we get all the values of the object parameters # i.e., if obj('b' = ), we would want to inspect fields = dict(curr_obj) @@ -198,7 +118,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod # Filter out all candidates that are not BaseModels # In the list example above, ['a', 3, None, ], we want to filter out 'a', 3, and None - candidates = filter(lambda x: is_base_model(x), candidates) + candidates = filter(lambda x: _is_base_model(x), candidates) # Classic BFS here for c in candidates: @@ -216,7 +136,7 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]: # If it is a basic Python type, we trivially return the string version of that value # Handle basic types return repr(obj) - elif is_base_model(obj): + elif _is_base_model(obj): # Otherwise, if it is a BaseModel # We want to pull out all the parameters, and reformat them into strings # e.g. {arg}={value} @@ -269,11 +189,11 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]: return None -def is_base_model(obj: Any): +def _is_base_model(obj: Any): return isinstance(obj, BaseModel) -def generate_import_code(module_attr_map: Optional[dict]): +def _generate_import_code(module_attr_map: Optional[dict]): if not module_attr_map: return "" @@ -286,7 +206,7 @@ def generate_import_code(module_attr_map: Optional[dict]): return "\n".join(code_lines) -def parse_letta_response_for_assistant_message( +def _parse_letta_response_for_assistant_message( target_agent_id: str, letta_response: LettaResponse, ) -> Optional[str]: @@ -346,7 +266,7 @@ def execute_send_message_to_agent( return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix)) -async def send_message_to_agent_no_stream( +async def _send_message_to_agent_no_stream( server: "SyncServer", agent_id: str, actor: User, @@ -375,7 +295,7 @@ async def send_message_to_agent_no_stream( return LettaResponse(messages=final_messages, usage=usage_stats) -async def async_send_message_with_retries( +async def _async_send_message_with_retries( server: "SyncServer", sender_agent: "Agent", target_agent_id: str, @@ -389,7 +309,7 @@ async def async_send_message_with_retries( for attempt in range(1, max_retries + 1): try: response = await asyncio.wait_for( - send_message_to_agent_no_stream( + _send_message_to_agent_no_stream( server=server, agent_id=target_agent_id, actor=sender_agent.user, @@ -399,7 +319,7 @@ async def async_send_message_with_retries( ) # Then parse out the assistant message - assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response) + assistant_message = _parse_letta_response_for_assistant_message(target_agent_id, response) if assistant_message: sender_agent.logger.info(f"{logging_prefix} - {assistant_message}") return assistant_message diff --git a/letta/groups/helpers.py b/letta/groups/helpers.py index 039230df4..f66269c71 100644 --- a/letta/groups/helpers.py +++ b/letta/groups/helpers.py @@ -76,6 +76,7 @@ def load_multi_agent( agent_state=agent_state, interface=interface, user=actor, + mcp_clients=mcp_clients, group_id=group.id, agent_ids=group.agent_ids, description=group.description, diff --git a/letta/groups/sleeptime_multi_agent.py b/letta/groups/sleeptime_multi_agent.py index 6349b57b3..87f49b105 100644 --- a/letta/groups/sleeptime_multi_agent.py +++ b/letta/groups/sleeptime_multi_agent.py @@ -1,9 +1,10 @@ import asyncio import threading from datetime import datetime, timezone -from typing import List, Optional +from typing import Dict, List, Optional from letta.agent import Agent, AgentState +from letta.functions.mcp_client.base_client import BaseMCPClient from letta.groups.helpers import stringify_message from letta.interface import AgentInterface from letta.orm import User @@ -26,6 +27,7 @@ class SleeptimeMultiAgent(Agent): interface: AgentInterface, agent_state: AgentState, user: User, + mcp_clients: Optional[Dict[str, BaseMCPClient]] = None, # custom group_id: str = "", agent_ids: List[str] = [], @@ -115,6 +117,7 @@ class SleeptimeMultiAgent(Agent): agent_state=participant_agent_state, interface=StreamingServerInterface(), user=self.user, + mcp_clients=self.mcp_clients, ) prior_messages = [] @@ -212,6 +215,7 @@ class SleeptimeMultiAgent(Agent): agent_state=self.agent_state, interface=self.interface, user=self.user, + mcp_clients=self.mcp_clients, ) # Perform main agent step usage_stats = main_agent.step( diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 41d2b8f69..be05b85a2 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate -def prepare_input_message_create( +def convert_message_creates_to_messages( + messages: list[MessageCreate], + agent_id: str, + wrap_user_message: bool = True, + wrap_system_message: bool = True, +) -> list[Message]: + return [ + _convert_message_create_to_message( + message=message, + agent_id=agent_id, + wrap_user_message=wrap_user_message, + wrap_system_message=wrap_system_message, + ) + for message in messages + ] + + +def _convert_message_create_to_message( message: MessageCreate, agent_id: str, wrap_user_message: bool = True, @@ -23,12 +40,12 @@ def prepare_input_message_create( raise ValueError("Message content is empty or invalid") # Apply wrapping if needed - if message.role == MessageRole.user and wrap_user_message: + if message.role not in {MessageRole.user, MessageRole.system}: + raise ValueError(f"Invalid message role: {message.role}") + elif message.role == MessageRole.user and wrap_user_message: message_content = system.package_user_message(user_message=message_content) elif message.role == MessageRole.system and wrap_system_message: message_content = system.package_system_message(system_message=message_content) - elif message.role not in {MessageRole.user, MessageRole.system}: - raise ValueError(f"Invalid message role: {message.role}") return Message( agent_id=agent_id, diff --git a/letta/helpers/tool_execution_helper.py b/letta/helpers/tool_execution_helper.py index 2ea281578..1ec3f6462 100644 --- a/letta/helpers/tool_execution_helper.py +++ b/letta/helpers/tool_execution_helper.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source -from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name +from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name from letta.helpers.composio_helpers import get_composio_api_key from letta.orm.enums import ToolType from letta.schemas.agent import AgentState diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 84178932a..974673f84 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -35,7 +35,7 @@ from letta.schemas.letta_message import ( from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser logger = get_logger(__name__) @@ -56,7 +56,7 @@ class AnthropicStreamingInterface: """ def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False): - self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() + self.json_parser: JSONParser = PydanticJSONParser() self.use_assistant_message = use_assistant_message # Premake IDs for database writes @@ -68,7 +68,7 @@ class AnthropicStreamingInterface: self.accumulated_inner_thoughts = [] self.tool_call_id = None self.tool_call_name = None - self.accumulated_tool_call_args = [] + self.accumulated_tool_call_args = "" self.previous_parse = {} # usage trackers @@ -85,193 +85,200 @@ class AnthropicStreamingInterface: def get_tool_call_object(self) -> ToolCall: """Useful for agent loop""" - return ToolCall( - id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name) - ) + return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name)) def _check_inner_thoughts_complete(self, combined_args: str) -> bool: """ Check if inner thoughts are complete in the current tool call arguments by looking for a closing quote after the inner_thoughts field """ - if not self.put_inner_thoughts_in_kwarg: - # None of the things should have inner thoughts in kwargs - return True - else: - parsed = self.optimistic_json_parser.parse(combined_args) - # TODO: This will break on tools with 0 input - return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys() + try: + if not self.put_inner_thoughts_in_kwarg: + # None of the things should have inner thoughts in kwargs + return True + else: + parsed = self.json_parser.parse(combined_args) + # TODO: This will break on tools with 0 input + return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys() + except Exception as e: + logger.error("Error checking inner thoughts: %s", e) + raise async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]: - async with stream: - async for event in stream: - # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock - if isinstance(event, BetaRawContentBlockStartEvent): - content = event.content_block + try: + async with stream: + async for event in stream: + # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock + if isinstance(event, BetaRawContentBlockStartEvent): + content = event.content_block - if isinstance(content, BetaTextBlock): - self.anthropic_mode = EventMode.TEXT - # TODO: Can capture citations, etc. - elif isinstance(content, BetaToolUseBlock): - self.anthropic_mode = EventMode.TOOL_USE - self.tool_call_id = content.id - self.tool_call_name = content.name - self.inner_thoughts_complete = False + if isinstance(content, BetaTextBlock): + self.anthropic_mode = EventMode.TEXT + # TODO: Can capture citations, etc. + elif isinstance(content, BetaToolUseBlock): + self.anthropic_mode = EventMode.TOOL_USE + self.tool_call_id = content.id + self.tool_call_name = content.name + self.inner_thoughts_complete = False - if not self.use_assistant_message: - # Buffer the initial tool call message instead of yielding immediately - tool_call_msg = ToolCallMessage( - id=self.letta_tool_message_id, - tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), + if not self.use_assistant_message: + # Buffer the initial tool call message instead of yielding immediately + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), + date=datetime.now(timezone.utc).isoformat(), + ) + self.tool_call_buffer.append(tool_call_msg) + elif isinstance(content, BetaThinkingBlock): + self.anthropic_mode = EventMode.THINKING + # TODO: Can capture signature, etc. + elif isinstance(content, BetaRedactedThinkingBlock): + self.anthropic_mode = EventMode.REDACTED_THINKING + + hidden_reasoning_message = HiddenReasoningMessage( + id=self.letta_assistant_message_id, + state="redacted", + hidden_reasoning=content.data, date=datetime.now(timezone.utc).isoformat(), ) - self.tool_call_buffer.append(tool_call_msg) - elif isinstance(content, BetaThinkingBlock): - self.anthropic_mode = EventMode.THINKING - # TODO: Can capture signature, etc. - elif isinstance(content, BetaRedactedThinkingBlock): - self.anthropic_mode = EventMode.REDACTED_THINKING + self.reasoning_messages.append(hidden_reasoning_message) + yield hidden_reasoning_message - hidden_reasoning_message = HiddenReasoningMessage( - id=self.letta_assistant_message_id, - state="redacted", - hidden_reasoning=content.data, - date=datetime.now(timezone.utc).isoformat(), - ) - self.reasoning_messages.append(hidden_reasoning_message) - yield hidden_reasoning_message + elif isinstance(event, BetaRawContentBlockDeltaEvent): + delta = event.delta - elif isinstance(event, BetaRawContentBlockDeltaEvent): - delta = event.delta + if isinstance(delta, BetaTextDelta): + # Safety check + if not self.anthropic_mode == EventMode.TEXT: + raise RuntimeError( + f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}" + ) - if isinstance(delta, BetaTextDelta): - # Safety check - if not self.anthropic_mode == EventMode.TEXT: - raise RuntimeError( - f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}" - ) + # TODO: Strip out more robustly, this is pretty hacky lol + delta.text = delta.text.replace("", "") + self.accumulated_inner_thoughts.append(delta.text) - # TODO: Strip out more robustly, this is pretty hacky lol - delta.text = delta.text.replace("", "") - self.accumulated_inner_thoughts.append(delta.text) - - reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, - reasoning=self.accumulated_inner_thoughts[-1], - date=datetime.now(timezone.utc).isoformat(), - ) - self.reasoning_messages.append(reasoning_message) - yield reasoning_message - - elif isinstance(delta, BetaInputJSONDelta): - if not self.anthropic_mode == EventMode.TOOL_USE: - raise RuntimeError( - f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}" - ) - - self.accumulated_tool_call_args.append(delta.partial_json) - combined_args = "".join(self.accumulated_tool_call_args) - current_parsed = self.optimistic_json_parser.parse(combined_args) - - # Start detecting a difference in inner thoughts - previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "") - current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "") - inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :] - - if inner_thoughts_diff: reasoning_message = ReasoningMessage( id=self.letta_assistant_message_id, - reasoning=inner_thoughts_diff, + reasoning=self.accumulated_inner_thoughts[-1], date=datetime.now(timezone.utc).isoformat(), ) self.reasoning_messages.append(reasoning_message) yield reasoning_message - # Check if inner thoughts are complete - if so, flush the buffer - if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args): - self.inner_thoughts_complete = True - # Flush all buffered tool call messages + elif isinstance(delta, BetaInputJSONDelta): + if not self.anthropic_mode == EventMode.TOOL_USE: + raise RuntimeError( + f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}" + ) + + self.accumulated_tool_call_args += delta.partial_json + current_parsed = self.json_parser.parse(self.accumulated_tool_call_args) + + # Start detecting a difference in inner thoughts + previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "") + current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "") + inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :] + + if inner_thoughts_diff: + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + reasoning=inner_thoughts_diff, + date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + + # Check if inner thoughts are complete - if so, flush the buffer + if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args): + self.inner_thoughts_complete = True + # Flush all buffered tool call messages + for buffered_msg in self.tool_call_buffer: + yield buffered_msg + self.tool_call_buffer = [] + + # Start detecting special case of "send_message" + if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message: + previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + send_message_diff = current_send_message[len(previous_send_message) :] + + # Only stream out if it's not an empty string + if send_message_diff: + yield AssistantMessage( + id=self.letta_assistant_message_id, + content=[TextContent(text=send_message_diff)], + date=datetime.now(timezone.utc).isoformat(), + ) + else: + # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + tool_call=ToolCallDelta(arguments=delta.partial_json), + date=datetime.now(timezone.utc).isoformat(), + ) + + if self.inner_thoughts_complete: + yield tool_call_msg + else: + self.tool_call_buffer.append(tool_call_msg) + + # Set previous parse + self.previous_parse = current_parsed + elif isinstance(delta, BetaThinkingDelta): + # Safety check + if not self.anthropic_mode == EventMode.THINKING: + raise RuntimeError( + f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}" + ) + + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + source="reasoner_model", + reasoning=delta.thinking, + date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + elif isinstance(delta, BetaSignatureDelta): + # Safety check + if not self.anthropic_mode == EventMode.THINKING: + raise RuntimeError( + f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}" + ) + + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + source="reasoner_model", + reasoning="", + date=datetime.now(timezone.utc).isoformat(), + signature=delta.signature, + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + elif isinstance(event, BetaRawMessageStartEvent): + self.message_id = event.message.id + self.input_tokens += event.message.usage.input_tokens + self.output_tokens += event.message.usage.output_tokens + elif isinstance(event, BetaRawMessageDeltaEvent): + self.output_tokens += event.usage.output_tokens + elif isinstance(event, BetaRawMessageStopEvent): + # Don't do anything here! We don't want to stop the stream. + pass + elif isinstance(event, BetaRawContentBlockStopEvent): + # If we're exiting a tool use block and there are still buffered messages, + # we should flush them now + if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer: for buffered_msg in self.tool_call_buffer: yield buffered_msg self.tool_call_buffer = [] - # Start detecting special case of "send_message" - if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message: - previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "") - current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "") - send_message_diff = current_send_message[len(previous_send_message) :] - - # Only stream out if it's not an empty string - if send_message_diff: - yield AssistantMessage( - id=self.letta_assistant_message_id, - content=[TextContent(text=send_message_diff)], - date=datetime.now(timezone.utc).isoformat(), - ) - else: - # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status - tool_call_msg = ToolCallMessage( - id=self.letta_tool_message_id, - tool_call=ToolCallDelta(arguments=delta.partial_json), - date=datetime.now(timezone.utc).isoformat(), - ) - - if self.inner_thoughts_complete: - yield tool_call_msg - else: - self.tool_call_buffer.append(tool_call_msg) - - # Set previous parse - self.previous_parse = current_parsed - elif isinstance(delta, BetaThinkingDelta): - # Safety check - if not self.anthropic_mode == EventMode.THINKING: - raise RuntimeError( - f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}" - ) - - reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, - source="reasoner_model", - reasoning=delta.thinking, - date=datetime.now(timezone.utc).isoformat(), - ) - self.reasoning_messages.append(reasoning_message) - yield reasoning_message - elif isinstance(delta, BetaSignatureDelta): - # Safety check - if not self.anthropic_mode == EventMode.THINKING: - raise RuntimeError( - f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}" - ) - - reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, - source="reasoner_model", - reasoning="", - date=datetime.now(timezone.utc).isoformat(), - signature=delta.signature, - ) - self.reasoning_messages.append(reasoning_message) - yield reasoning_message - elif isinstance(event, BetaRawMessageStartEvent): - self.message_id = event.message.id - self.input_tokens += event.message.usage.input_tokens - self.output_tokens += event.message.usage.output_tokens - elif isinstance(event, BetaRawMessageDeltaEvent): - self.output_tokens += event.usage.output_tokens - elif isinstance(event, BetaRawMessageStopEvent): - # Don't do anything here! We don't want to stop the stream. - pass - elif isinstance(event, BetaRawContentBlockStopEvent): - # If we're exiting a tool use block and there are still buffered messages, - # we should flush them now - if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer: - for buffered_msg in self.tool_call_buffer: - yield buffered_msg - self.tool_call_buffer = [] - - self.anthropic_mode = None + self.anthropic_mode = None + except Exception as e: + logger.error("Error processing stream: %s", e) + raise + finally: + logger.info("AnthropicStreamingInterface: Stream processing complete.") def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]: def _process_group( diff --git a/letta/interfaces/openai_chat_completions_streaming_interface.py b/letta/interfaces/openai_chat_completions_streaming_interface.py index 0f3bd8419..6ff38cab1 100644 --- a/letta/interfaces/openai_chat_completions_streaming_interface.py +++ b/letta/interfaces/openai_chat_completions_streaming_interface.py @@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, from letta.constants import PRE_EXECUTION_MESSAGE_ARG from letta.interfaces.utils import _format_sse_chunk -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser class OpenAIChatCompletionsStreamingInterface: diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 59939e4d6..08e70d069 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.log import get_logger +from letta.schemas.enums import ProviderType from letta.schemas.message import Message as _Message from letta.schemas.message import MessageRole as _MessageRole from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool @@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict: # NOTE: currently there is no GET /models, so we need to hardcode # return MODEL_LIST - anthropic_override_key = ProviderManager().get_anthropic_override_key() - if anthropic_override_key: - anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + if api_key: + anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() + else: + raise ValueError("No API key provided") models = anthropic_client.models.list() models_json = models.model_dump() @@ -738,13 +740,14 @@ def anthropic_chat_completions_request( put_inner_thoughts_in_kwargs: bool = False, extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, + provider_name: Optional[str] = None, betas: List[str] = ["tools-2024-04-04"], ) -> ChatCompletionResponse: """https://docs.anthropic.com/claude/docs/tool-use""" anthropic_client = None - anthropic_override_key = ProviderManager().get_anthropic_override_key() - if anthropic_override_key: - anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + if provider_name and provider_name != ProviderType.anthropic.value: + api_key = ProviderManager().get_override_key(provider_name) + anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() else: @@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream( put_inner_thoughts_in_kwargs: bool = False, extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, + provider_name: Optional[str] = None, betas: List[str] = ["tools-2024-04-04"], ) -> Generator[ChatCompletionChunkResponse, None, None]: """Stream chat completions from Anthropic API. @@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream( extended_thinking=extended_thinking, max_reasoning_tokens=max_reasoning_tokens, ) - - anthropic_override_key = ProviderManager().get_anthropic_override_key() - if anthropic_override_key: - anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + if provider_name and provider_name != ProviderType.anthropic.value: + api_key = ProviderManager().get_override_key(provider_name) + anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() @@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream( put_inner_thoughts_in_kwargs: bool = False, extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, + provider_name: Optional[str] = None, create_message_id: bool = True, create_message_datetime: bool = True, betas: List[str] = ["tools-2024-04-04"], @@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream( put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, extended_thinking=extended_thinking, max_reasoning_tokens=max_reasoning_tokens, + provider_name=provider_name, betas=betas, ) ): diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 863fcef0d..35317dd82 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.log import get_logger +from letta.schemas.enums import ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool @@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase): @trace_method def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]: - override_key = ProviderManager().get_anthropic_override_key() + override_key = None + if self.provider_name and self.provider_name != ProviderType.anthropic.value: + override_key = ProviderManager().get_override_key(self.provider_name) + if async_client: return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic() return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic() diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index a987d8a94..177eac8d9 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -63,7 +63,7 @@ class GoogleVertexClient(GoogleAIClient): # Add thinking_config # If enable_reasoner is False, set thinking_budget to 0 # Otherwise, use the value from max_reasoning_tokens - thinking_budget = 0 if not self.llm_config.enable_reasoner else self.llm_config.max_reasoning_tokens + thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens thinking_config = ThinkingConfig( thinking_budget=thinking_budget, ) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index be1b9d82a..b1112290c 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -24,6 +24,7 @@ from letta.llm_api.openai import ( from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from letta.schemas.enums import ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype @@ -171,6 +172,10 @@ def create( if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": # only is a problem if we are *not* using an openai proxy raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"]) + elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value: + from letta.services.provider_manager import ProviderManager + + api_key = ProviderManager().get_override_key(llm_config.provider_name) elif model_settings.openai_api_key is None: # the openai python client requires a dummy API key api_key = "DUMMY_API_KEY" @@ -373,6 +378,7 @@ def create( stream_interface=stream_interface, extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, + provider_name=llm_config.provider_name, name=name, ) @@ -383,6 +389,7 @@ def create( put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, + provider_name=llm_config.provider_name, ) if llm_config.put_inner_thoughts_in_kwargs: diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 674f94974..a63913a4d 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -9,8 +9,10 @@ class LLMClient: @staticmethod def create( - provider: ProviderType, + provider_type: ProviderType, + provider_name: Optional[str] = None, put_inner_thoughts_first: bool = True, + actor_id: Optional[str] = None, ) -> Optional[LLMClientBase]: """ Create an LLM client based on the model endpoint type. @@ -25,30 +27,38 @@ class LLMClient: Raises: ValueError: If the model endpoint type is not supported """ - match provider: + match provider_type: case ProviderType.google_ai: from letta.llm_api.google_ai_client import GoogleAIClient return GoogleAIClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, + actor_id=actor_id, ) case ProviderType.google_vertex: from letta.llm_api.google_vertex_client import GoogleVertexClient return GoogleVertexClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, + actor_id=actor_id, ) case ProviderType.anthropic: from letta.llm_api.anthropic_client import AnthropicClient return AnthropicClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, + actor_id=actor_id, ) case ProviderType.openai: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, + actor_id=actor_id, ) case _: return None diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 5c7dcab9e..223921f9e 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -20,9 +20,13 @@ class LLMClientBase: def __init__( self, + provider_name: Optional[str] = None, put_inner_thoughts_first: Optional[bool] = True, use_tool_naming: bool = True, + actor_id: Optional[str] = None, ): + self.actor_id = actor_id + self.provider_name = provider_name self.put_inner_thoughts_first = put_inner_thoughts_first self.use_tool_naming = use_tool_naming diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 578f2d020..d72fb2597 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -157,11 +157,17 @@ def build_openai_chat_completions_request( # if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model: # data.response_format = {"type": "json_object"} - if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: - # override user id for inference.memgpt.ai - import uuid + # always set user id for openai requests + if user_id: + data.user = str(user_id) + + if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: + if not user_id: + # override user id for inference.letta.com + import uuid + + data.user = str(uuid.UUID(int=0)) - data.user = str(uuid.UUID(int=0)) data.model = "memgpt-openai" if use_structured_output and data.tools is not None and len(data.tools) > 0: diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 96e473c72..c5e512d09 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.log import get_logger +from letta.schemas.enums import ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import ChatCompletionRequest @@ -64,7 +65,14 @@ def supports_parallel_tool_calling(model: str) -> bool: class OpenAIClient(LLMClientBase): def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict: - api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") + api_key = None + if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value: + from letta.services.provider_manager import ProviderManager + + api_key = ProviderManager().get_override_key(llm_config.provider_name) + + if not api_key: + api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") # supposedly the openai python client requires a dummy API key api_key = api_key or "DUMMY_API_KEY" kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint} @@ -135,11 +143,17 @@ class OpenAIClient(LLMClientBase): temperature=llm_config.temperature if supports_temperature_param(model) else None, ) - if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: - # override user id for inference.memgpt.ai - import uuid + # always set user id for openai requests + if self.actor_id: + data.user = self.actor_id + + if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: + if not self.actor_id: + # override user id for inference.letta.com + import uuid + + data.user = str(uuid.UUID(int=0)) - data.user = str(uuid.UUID(int=0)) data.model = "memgpt-openai" if data.tools is not None and len(data.tools) > 0: diff --git a/letta/memory.py b/letta/memory.py index 6d29963f0..100d3966d 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -79,8 +79,10 @@ def summarize_messages( llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False llm_client = LLMClient.create( - provider=llm_config_no_inner_thoughts.model_endpoint_type, + provider_name=llm_config_no_inner_thoughts.provider_name, + provider_type=llm_config_no_inner_thoughts.model_endpoint_type, put_inner_thoughts_first=False, + actor_id=agent_state.created_by_id, ) # try to use new client, otherwise fallback to old flow # TODO: we can just directly call the LLM here? diff --git a/letta/orm/group.py b/letta/orm/group.py index 48c3b65be..489e563f5 100644 --- a/letta/orm/group.py +++ b/letta/orm/group.py @@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin): termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") diff --git a/letta/orm/provider.py b/letta/orm/provider.py index 2ae524b56..d85e5ef2b 100644 --- a/letta/orm/provider.py +++ b/letta/orm/provider.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from sqlalchemy import UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin): __tablename__ = "providers" __pydantic_model__ = PydanticProvider + __table_args__ = ( + UniqueConstraint( + "name", + "organization_id", + name="unique_name_organization_id", + ), + ) name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider") + provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider") api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.") + base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="providers") diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index caf7b3cde..13f74d826 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -56,7 +56,6 @@ class AgentState(OrmMetadataBase, validate_assignment=True): name: str = Field(..., description="The name of the agent.") # tool rules tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.") - # in-context memory message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index c1d54d776..6258e1e51 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -6,6 +6,17 @@ class ProviderType(str, Enum): google_ai = "google_ai" google_vertex = "google_vertex" openai = "openai" + letta = "letta" + deepseek = "deepseek" + lmstudio_openai = "lmstudio_openai" + xai = "xai" + mistral = "mistral" + ollama = "ollama" + groq = "groq" + together = "together" + azure = "azure" + vllm = "vllm" + bedrock = "bedrock" class MessageRole(str, Enum): diff --git a/letta/schemas/group.py b/letta/schemas/group.py index dce4a9e5c..de40ba5d0 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -32,6 +32,14 @@ class Group(GroupBase): sleeptime_agent_frequency: Optional[int] = Field(None, description="") turns_counter: Optional[int] = Field(None, description="") last_processed_message_id: Optional[str] = Field(None, description="") + max_message_buffer_length: Optional[int] = Field( + None, + description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.", + ) + min_message_buffer_length: Optional[int] = Field( + None, + description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", + ) class ManagerConfig(BaseModel): @@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig): class VoiceSleeptimeManager(ManagerConfig): manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="") manager_agent_id: str = Field(..., description="") + max_message_buffer_length: Optional[int] = Field( + None, + description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.", + ) + min_message_buffer_length: Optional[int] = Field( + None, + description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", + ) class VoiceSleeptimeManagerUpdate(ManagerConfig): manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="") manager_agent_id: Optional[str] = Field(None, description="") + max_message_buffer_length: Optional[int] = Field( + None, + description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.", + ) + min_message_buffer_length: Optional[int] = Field( + None, + description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", + ) # class SwarmGroup(ManagerConfig): diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 9c7f467c6..7b6b99978 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -50,6 +50,7 @@ class LLMConfig(BaseModel): "xai", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") + provider_name: Optional[str] = Field(None, description="The provider name for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") put_inner_thoughts_in_kwargs: Optional[bool] = Field( diff --git a/letta/schemas/llm_config_overrides.py b/letta/schemas/llm_config_overrides.py index f8f286ae2..407c73a29 100644 --- a/letta/schemas/llm_config_overrides.py +++ b/letta/schemas/llm_config_overrides.py @@ -2,8 +2,8 @@ from typing import Dict LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = { "anthropic": { - "claude-3-5-haiku-20241022": "claude-3.5-haiku", - "claude-3-5-sonnet-20241022": "claude-3.5-sonnet", + "claude-3-5-haiku-20241022": "claude-3-5-haiku", + "claude-3-5-sonnet-20241022": "claude-3-5-sonnet", "claude-3-opus-20240229": "claude-3-opus", }, "openai": { diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index a985a412a..f067007a3 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -1,6 +1,6 @@ import warnings from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional from pydantic import Field, model_validator @@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_ from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES +from letta.schemas.enums import ProviderType from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES +from letta.settings import model_settings class ProviderBase(LettaBase): @@ -21,10 +23,18 @@ class ProviderBase(LettaBase): class Provider(ProviderBase): id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") name: str = Field(..., description="The name of the provider") + provider_type: ProviderType = Field(..., description="The type of the provider") api_key: Optional[str] = Field(None, description="API key used for requests to the provider.") + base_url: Optional[str] = Field(None, description="Base URL for the provider.") organization_id: Optional[str] = Field(None, description="The organization id of the user") updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.") + @model_validator(mode="after") + def default_base_url(self): + if self.provider_type == ProviderType.openai and self.base_url is None: + self.base_url = model_settings.openai_api_base + return self + def resolve_identifier(self): if not self.id: self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) @@ -59,9 +69,41 @@ class Provider(ProviderBase): return f"{self.name}/{model_name}" + def cast_to_subtype(self): + match (self.provider_type): + case ProviderType.letta: + return LettaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.openai: + return OpenAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.anthropic: + return AnthropicProvider(**self.model_dump(exclude_none=True)) + case ProviderType.anthropic_bedrock: + return AnthropicBedrockProvider(**self.model_dump(exclude_none=True)) + case ProviderType.ollama: + return OllamaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_ai: + return GoogleAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_vertex: + return GoogleVertexProvider(**self.model_dump(exclude_none=True)) + case ProviderType.azure: + return AzureProvider(**self.model_dump(exclude_none=True)) + case ProviderType.groq: + return GroqProvider(**self.model_dump(exclude_none=True)) + case ProviderType.together: + return TogetherProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm_chat_completions: + return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm_completions: + return VLLMCompletionsProvider(**self.model_dump(exclude_none=True)) + case ProviderType.xai: + return XAIProvider(**self.model_dump(exclude_none=True)) + case _: + raise ValueError(f"Unknown provider type: {self.provider_type}") + class ProviderCreate(ProviderBase): name: str = Field(..., description="The name of the provider.") + provider_type: ProviderType = Field(..., description="The type of the provider.") api_key: str = Field(..., description="API key used for requests to the provider.") @@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase): class LettaProvider(Provider): - - name: str = "letta" + provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") def list_llm_models(self) -> List[LLMConfig]: return [ @@ -81,6 +122,7 @@ class LettaProvider(Provider): model_endpoint=LETTA_MODEL_ENDPOINT, context_window=8192, handle=self.get_handle("letta-free"), + provider_name=self.name, ) ] @@ -98,7 +140,7 @@ class LettaProvider(Provider): class OpenAIProvider(Provider): - name: str = "openai" + provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") api_key: str = Field(..., description="API key for the OpenAI API.") base_url: str = Field(..., description="Base URL for the OpenAI API.") @@ -180,6 +222,7 @@ class OpenAIProvider(Provider): model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ) ) @@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider): * It also does not support native function calling """ - name: str = "deepseek" + provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") api_key: str = Field(..., description="API key for the DeepSeek API.") @@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider): context_window=context_window_size, handle=self.get_handle(model_name), put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + provider_name=self.name, ) ) @@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider): class LMStudioOpenAIProvider(OpenAIProvider): - name: str = "lmstudio-openai" + provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") api_key: Optional[str] = Field(None, description="API key for the LMStudio API.") @@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider): class XAIProvider(OpenAIProvider): """https://docs.x.ai/docs/api-reference""" - name: str = "xai" + provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") api_key: str = Field(..., description="API key for the xAI/Grok API.") base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") @@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider): model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ) ) @@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider): class AnthropicProvider(Provider): - name: str = "anthropic" + provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") api_key: str = Field(..., description="API key for the Anthropic API.") base_url: str = "https://api.anthropic.com/v1" @@ -563,6 +608,7 @@ class AnthropicProvider(Provider): handle=self.get_handle(model["id"]), put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, max_tokens=max_tokens, + provider_name=self.name, ) ) return configs @@ -572,7 +618,7 @@ class AnthropicProvider(Provider): class MistralProvider(Provider): - name: str = "mistral" + provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") api_key: str = Field(..., description="API key for the Mistral API.") base_url: str = "https://api.mistral.ai/v1" @@ -596,6 +642,7 @@ class MistralProvider(Provider): model_endpoint=self.base_url, context_window=model["max_context_length"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) @@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider): See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion """ - name: str = "ollama" + provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the Ollama API.") api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") default_prompt_formatter: str = Field( @@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider): model_wrapper=self.default_prompt_formatter, context_window=context_window, handle=self.get_handle(model["name"]), + provider_name=self.name, ) ) return configs @@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider): class GroqProvider(OpenAIProvider): - name: str = "groq" + provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") base_url: str = "https://api.groq.com/openai/v1" api_key: str = Field(..., description="API key for the Groq API.") @@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider): model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) return configs @@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider): function calling support is limited. """ - name: str = "together" + provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") base_url: str = "https://api.together.ai/v1" api_key: str = Field(..., description="API key for the TogetherAI API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") @@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider): model_wrapper=self.default_prompt_formatter, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ) ) @@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider): class GoogleAIProvider(Provider): # gemini - name: str = "google_ai" + provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") api_key: str = Field(..., description="API key for the Google AI API.") base_url: str = "https://generativelanguage.googleapis.com" @@ -889,7 +939,6 @@ class GoogleAIProvider(Provider): # filter by model names model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] - # TODO remove manual filtering for gemini-pro # Add support for all gemini models model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] @@ -903,6 +952,7 @@ class GoogleAIProvider(Provider): context_window=self.get_model_context_window(model), handle=self.get_handle(model), max_tokens=8192, + provider_name=self.name, ) ) return configs @@ -938,7 +988,7 @@ class GoogleAIProvider(Provider): class GoogleVertexProvider(Provider): - name: str = "google_vertex" + provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") @@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider): context_window=context_length, handle=self.get_handle(model), max_tokens=8192, + provider_name=self.name, ) ) return configs @@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider): class AzureProvider(Provider): - name: str = "azure" + provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation base_url: str = Field( ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." @@ -1011,6 +1062,7 @@ class AzureProvider(Provider): model_endpoint=model_endpoint, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ), ) return configs @@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider): """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy""" # NOTE: vLLM only serves one model at a time (so could configure that through env variables) - name: str = "vllm" + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the vLLM API.") def list_llm_models(self) -> List[LLMConfig]: @@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider): model_endpoint=self.base_url, context_window=model["max_model_len"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) return configs @@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider): """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper""" # NOTE: vLLM only serves one model at a time (so could configure that through env variables) - name: str = "vllm" + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the vLLM API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") @@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider): model_wrapper=self.default_prompt_formatter, context_window=model["max_model_len"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) return configs @@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider): class AnthropicBedrockProvider(Provider): - name: str = "bedrock" + provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") aws_region: str = Field(..., description="AWS region for Bedrock") def list_llm_models(self): @@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider): configs.append( LLMConfig( model=model_arn, - model_endpoint_type=self.name, + model_endpoint_type=self.provider_type.value, model_endpoint=None, context_window=self.get_model_context_window(model_arn), handle=self.get_handle(model_arn), + provider_name=self.name, ) ) return configs diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index bc2de0e4a..6c8f9bd3e 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -11,13 +11,9 @@ from letta.constants import ( MCP_TOOL_TAG_NAME_PREFIX, ) from letta.functions.ast_parsers import get_function_name_and_description +from letta.functions.composio_helpers import generate_composio_tool_wrapper from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module -from letta.functions.helpers import ( - generate_composio_tool_wrapper, - generate_langchain_tool_wrapper, - generate_mcp_tool_wrapper, - generate_model_from_args_json_schema, -) +from letta.functions.helpers import generate_langchain_tool_wrapper, generate_mcp_tool_wrapper, generate_model_from_args_json_schema from letta.functions.mcp_client.types import MCPTool from letta.functions.schema_generator import ( generate_schema_from_args_schema_v2, @@ -176,8 +172,7 @@ class ToolCreate(LettaBase): Returns: Tool: A Letta Tool initialized with attributes derived from the Composio tool. """ - from composio import LogLevel - from composio_langchain import ComposioToolSet + from composio import ComposioToolSet, LogLevel composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, lock=False) composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False) diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 5aeb206c3..476b818f3 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from letta.__init__ import __version__ +from letta.agents.exceptions import IncompatibleAgentType from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs @@ -173,6 +174,17 @@ def create_application() -> "FastAPI": def shutdown_scheduler(): shutdown_cron_scheduler() + @app.exception_handler(IncompatibleAgentType) + async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType): + return JSONResponse( + status_code=400, + content={ + "detail": str(exc), + "expected_type": exc.expected_type, + "actual_type": exc.actual_type, + }, + ) + @app.exception_handler(Exception) async def generic_error_handler(request: Request, exc: Exception): # Log the actual error for debugging diff --git a/letta/server/rest_api/chat_completions_interface.py b/letta/server/rest_api/chat_completions_interface.py index 0f684ed7c..9b05ca845 100644 --- a/letta/server/rest_api/chat_completions_interface.py +++ b/letta/server/rest_api/chat_completions_interface.py @@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser from letta.streaming_interface import AgentChunkStreamingInterface logger = get_logger(__name__) diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index edf8a2330..9a89f9074 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -28,7 +28,7 @@ from letta.schemas.letta_message import ( from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser from letta.streaming_interface import AgentChunkStreamingInterface from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor from letta.utils import parse_json @@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg) # @matt's changes here, adopting new optimistic json parser - self.current_function_arguments = [] + self.current_function_arguments = "" self.optimistic_json_parser = OptimisticJSONParser() self.current_json_parse_result = {} @@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def stream_start(self): """Initialize streaming by activating the generator and clearing any old chunks.""" self.streaming_chat_completion_mode_function_name = None - self.current_function_arguments = [] + self.current_function_arguments = "" self.current_json_parse_result = {} if not self._active: @@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def stream_end(self): """Clean up the stream by deactivating and clearing chunks.""" self.streaming_chat_completion_mode_function_name = None - self.current_function_arguments = [] + self.current_function_arguments = "" self.current_json_parse_result = {} # if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: @@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # early exit to turn into content mode return None if tool_call.function.arguments: - self.current_function_arguments.append(tool_call.function.arguments) + self.current_function_arguments += tool_call.function.arguments # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name: # Strip out any extras tokens # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk - combined_args = "".join(self.current_function_arguments) - parsed_args = self.optimistic_json_parser.parse(combined_args) + parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments) if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( self.assistant_message_tool_kwarg @@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # updates_inner_thoughts = "" # else: # OpenAI # updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) - self.current_function_arguments.append(tool_call.function.arguments) + self.current_function_arguments += tool_call.function.arguments updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) # If we have inner thoughts, we should output them as a chunk @@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: THIS IS HORRIBLE # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE - combined_args = "".join(self.current_function_arguments) - parsed_args = self.optimistic_json_parser.parse(combined_args) + parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments) if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( self.assistant_message_tool_kwarg diff --git a/letta/server/rest_api/optimistic_json_parser.py b/letta/server/rest_api/json_parser.py similarity index 70% rename from letta/server/rest_api/optimistic_json_parser.py rename to letta/server/rest_api/json_parser.py index c3a2f069d..27b4f0cfb 100644 --- a/letta/server/rest_api/optimistic_json_parser.py +++ b/letta/server/rest_api/json_parser.py @@ -1,7 +1,43 @@ import json +from abc import ABC, abstractmethod +from typing import Any + +from pydantic_core import from_json + +from letta.log import get_logger + +logger = get_logger(__name__) -class OptimisticJSONParser: +class JSONParser(ABC): + @abstractmethod + def parse(self, input_str: str) -> Any: + raise NotImplementedError() + + +class PydanticJSONParser(JSONParser): + """ + https://docs.pydantic.dev/latest/concepts/json/#json-parsing + If `strict` is True, we will not allow for partial parsing of JSON. + + Compared with `OptimisticJSONParser`, this parser is more strict. + Note: This will not partially parse strings which may be decrease parsing speed for message strings + """ + + def __init__(self, strict=False): + self.strict = strict + + def parse(self, input_str: str) -> Any: + if not input_str: + return {} + try: + return from_json(input_str, allow_partial="trailing-strings" if not self.strict else False) + except ValueError as e: + logger.error(f"Failed to parse JSON: {e}") + raise + + +class OptimisticJSONParser(JSONParser): """ A JSON parser that attempts to parse a given string using `json.loads`, and if that fails, it parses as much valid JSON as possible while @@ -13,25 +49,25 @@ class OptimisticJSONParser: def __init__(self, strict=False): self.strict = strict self.parsers = { - " ": self.parse_space, - "\r": self.parse_space, - "\n": self.parse_space, - "\t": self.parse_space, - "[": self.parse_array, - "{": self.parse_object, - '"': self.parse_string, - "t": self.parse_true, - "f": self.parse_false, - "n": self.parse_null, + " ": self._parse_space, + "\r": self._parse_space, + "\n": self._parse_space, + "\t": self._parse_space, + "[": self._parse_array, + "{": self._parse_object, + '"': self._parse_string, + "t": self._parse_true, + "f": self._parse_false, + "n": self._parse_null, } # Register number parser for digits and signs for char in "0123456789.-": self.parsers[char] = self.parse_number self.last_parse_reminding = None - self.on_extra_token = self.default_on_extra_token + self.on_extra_token = self._default_on_extra_token - def default_on_extra_token(self, text, data, reminding): + def _default_on_extra_token(self, text, data, reminding): print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}") def parse(self, input_str): @@ -45,7 +81,7 @@ class OptimisticJSONParser: try: return json.loads(input_str) except json.JSONDecodeError as decode_error: - data, reminding = self.parse_any(input_str, decode_error) + data, reminding = self._parse_any(input_str, decode_error) self.last_parse_reminding = reminding if self.on_extra_token and reminding: self.on_extra_token(input_str, data, reminding) @@ -53,7 +89,7 @@ class OptimisticJSONParser: else: return json.loads("{}") - def parse_any(self, input_str, decode_error): + def _parse_any(self, input_str, decode_error): """Determine which parser to use based on the first character.""" if not input_str: raise decode_error @@ -62,11 +98,11 @@ class OptimisticJSONParser: raise decode_error return parser(input_str, decode_error) - def parse_space(self, input_str, decode_error): + def _parse_space(self, input_str, decode_error): """Strip leading whitespace and parse again.""" - return self.parse_any(input_str.strip(), decode_error) + return self._parse_any(input_str.strip(), decode_error) - def parse_array(self, input_str, decode_error): + def _parse_array(self, input_str, decode_error): """Parse a JSON array, returning the list and remaining string.""" # Skip the '[' input_str = input_str[1:] @@ -77,7 +113,7 @@ class OptimisticJSONParser: # Skip the ']' input_str = input_str[1:] break - value, input_str = self.parse_any(input_str, decode_error) + value, input_str = self._parse_any(input_str, decode_error) array_values.append(value) input_str = input_str.strip() if input_str.startswith(","): @@ -85,7 +121,7 @@ class OptimisticJSONParser: input_str = input_str[1:].strip() return array_values, input_str - def parse_object(self, input_str, decode_error): + def _parse_object(self, input_str, decode_error): """Parse a JSON object, returning the dict and remaining string.""" # Skip the '{' input_str = input_str[1:] @@ -96,7 +132,7 @@ class OptimisticJSONParser: # Skip the '}' input_str = input_str[1:] break - key, input_str = self.parse_any(input_str, decode_error) + key, input_str = self._parse_any(input_str, decode_error) input_str = input_str.strip() if not input_str or input_str[0] == "}": @@ -113,7 +149,7 @@ class OptimisticJSONParser: input_str = input_str[1:] break - value, input_str = self.parse_any(input_str, decode_error) + value, input_str = self._parse_any(input_str, decode_error) obj[key] = value input_str = input_str.strip() if input_str.startswith(","): @@ -121,7 +157,7 @@ class OptimisticJSONParser: input_str = input_str[1:].strip() return obj, input_str - def parse_string(self, input_str, decode_error): + def _parse_string(self, input_str, decode_error): """Parse a JSON string, respecting escaped quotes if present.""" end = input_str.find('"', 1) while end != -1 and input_str[end - 1] == "\\": @@ -166,19 +202,19 @@ class OptimisticJSONParser: return num, remainder - def parse_true(self, input_str, decode_error): + def _parse_true(self, input_str, decode_error): """Parse a 'true' value.""" if input_str.startswith(("t", "T")): return True, input_str[4:] raise decode_error - def parse_false(self, input_str, decode_error): + def _parse_false(self, input_str, decode_error): """Parse a 'false' value.""" if input_str.startswith(("f", "F")): return False, input_str[5:] raise decode_error - def parse_null(self, input_str, decode_error): + def _parse_null(self, input_str, decode_error): """Parse a 'null' value.""" if input_str.startswith("n"): return None, input_str[4:] diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 971805c26..698f5d4a9 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -678,7 +678,7 @@ async def send_message_streaming( server: SyncServer = Depends(get_letta_server), request: LettaStreamingRequest = Body(...), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): +) -> StreamingResponse | LettaResponse: """ Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 173b1a578..02c369f66 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig @@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"]) @router.get("/", response_model=List[LLMConfig], operation_id="list_models") def list_llm_models( + byok_only: Optional[bool] = Query(None), server: "SyncServer" = Depends(get_letta_server), ): - models = server.list_llm_models() + models = server.list_llm_models(byok_only=byok_only) # print(models) return models diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 1de78ba57..02615f633 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from letta.schemas.enums import ProviderType from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate from letta.server.rest_api.utils import get_letta_server @@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"]) @router.get("/", response_model=List[Provider], operation_id="list_providers") def list_providers( + name: Optional[str] = Query(None), + provider_type: Optional[ProviderType] = Query(None), after: Optional[str] = Query(None), limit: Optional[int] = Query(50), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -23,7 +26,7 @@ def list_providers( """ try: actor = server.user_manager.get_user_or_default(user_id=actor_id) - providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor) + providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type) except HTTPException: raise except Exception as e: diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 47c989e74..4517a1a05 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -54,8 +54,6 @@ async def create_voice_chat_completions( block_manager=server.block_manager, passage_manager=server.passage_manager, actor=actor, - message_buffer_limit=8, - message_buffer_min=4, ) # Return the streaming generator diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 40471eab5..2e9b3e9a5 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -16,6 +16,7 @@ from pydantic import BaseModel from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.message_helper import convert_message_creates_to_messages from letta.log import get_logger from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent @@ -143,27 +144,15 @@ def log_error_to_sentry(e): def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]: """ Converts a user input message into the internal structured format. - """ - new_messages = [] - for input_message in input_messages: - # Construct the Message object - new_message = Message( - id=f"message-{uuid.uuid4()}", - role=input_message.role, - content=input_message.content, - name=input_message.name, - otid=input_message.otid, - sender_id=input_message.sender_id, - organization_id=actor.organization_id, - agent_id=agent_id, - model=None, - tool_calls=None, - tool_call_id=None, - created_at=get_utc_time(), - ) - new_messages.append(new_message) - return new_messages + TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`, + we should unify this when it's clear what message attributes we need. + """ + + messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False) + for message in messages: + message.organization_id = actor.organization_id + return messages def create_letta_messages_from_llm_response( diff --git a/letta/server/server.py b/letta/server/server.py index 5a7deff73..4553de2f1 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -268,10 +268,11 @@ class SyncServer(Server): ) # collect providers (always has Letta as a default) - self._enabled_providers: List[Provider] = [LettaProvider()] + self._enabled_providers: List[Provider] = [LettaProvider(name="letta")] if model_settings.openai_api_key: self._enabled_providers.append( OpenAIProvider( + name="openai", api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base, ) @@ -279,12 +280,14 @@ class SyncServer(Server): if model_settings.anthropic_api_key: self._enabled_providers.append( AnthropicProvider( + name="anthropic", api_key=model_settings.anthropic_api_key, ) ) if model_settings.ollama_base_url: self._enabled_providers.append( OllamaProvider( + name="ollama", base_url=model_settings.ollama_base_url, api_key=None, default_prompt_formatter=model_settings.default_prompt_formatter, @@ -293,12 +296,14 @@ class SyncServer(Server): if model_settings.gemini_api_key: self._enabled_providers.append( GoogleAIProvider( + name="google_ai", api_key=model_settings.gemini_api_key, ) ) if model_settings.google_cloud_location and model_settings.google_cloud_project: self._enabled_providers.append( GoogleVertexProvider( + name="google_vertex", google_cloud_project=model_settings.google_cloud_project, google_cloud_location=model_settings.google_cloud_location, ) @@ -307,6 +312,7 @@ class SyncServer(Server): assert model_settings.azure_api_version, "AZURE_API_VERSION is required" self._enabled_providers.append( AzureProvider( + name="azure", api_key=model_settings.azure_api_key, base_url=model_settings.azure_base_url, api_version=model_settings.azure_api_version, @@ -315,12 +321,14 @@ class SyncServer(Server): if model_settings.groq_api_key: self._enabled_providers.append( GroqProvider( + name="groq", api_key=model_settings.groq_api_key, ) ) if model_settings.together_api_key: self._enabled_providers.append( TogetherProvider( + name="together", api_key=model_settings.together_api_key, default_prompt_formatter=model_settings.default_prompt_formatter, ) @@ -329,6 +337,7 @@ class SyncServer(Server): # vLLM exposes both a /chat/completions and a /completions endpoint self._enabled_providers.append( VLLMCompletionsProvider( + name="vllm", base_url=model_settings.vllm_api_base, default_prompt_formatter=model_settings.default_prompt_formatter, ) @@ -338,12 +347,14 @@ class SyncServer(Server): # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes" self._enabled_providers.append( VLLMChatCompletionsProvider( + name="vllm", base_url=model_settings.vllm_api_base, ) ) if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region: self._enabled_providers.append( AnthropicBedrockProvider( + name="bedrock", aws_region=model_settings.aws_region, ) ) @@ -355,11 +366,11 @@ class SyncServer(Server): if model_settings.lmstudio_base_url.endswith("/v1") else model_settings.lmstudio_base_url + "/v1" ) - self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url)) + self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url)) if model_settings.deepseek_api_key: - self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key)) + self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)) if model_settings.xai_api_key: - self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key)) + self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key)) # For MCP """Initialize the MCP clients (there may be multiple)""" @@ -862,6 +873,8 @@ class SyncServer(Server): agent_ids=[voice_sleeptime_agent.id], manager_config=VoiceSleeptimeManager( manager_agent_id=main_agent.id, + max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, + min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, ), ), actor=actor, @@ -1182,10 +1195,10 @@ class SyncServer(Server): except NoResultFound: raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") - def list_llm_models(self) -> List[LLMConfig]: + def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]: """List available models""" llm_models = [] - for provider in self.get_enabled_providers(): + for provider in self.get_enabled_providers(byok_only=byok_only): try: llm_models.extend(provider.list_llm_models()) except Exception as e: @@ -1205,11 +1218,12 @@ class SyncServer(Server): warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models - def get_enabled_providers(self): + def get_enabled_providers(self, byok_only: bool = False): + providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()} + if byok_only: + return list(providers_from_db.values()) providers_from_env = {p.name: p for p in self._enabled_providers} - providers_from_db = {p.name: p for p in self.provider_manager.list_providers()} - # Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur - return {**providers_from_env, **providers_from_db}.values() + return list(providers_from_env.values()) + list(providers_from_db.values()) @trace_method def get_llm_config_from_handle( @@ -1294,7 +1308,7 @@ class SyncServer(Server): return embedding_config def get_provider_from_name(self, provider_name: str) -> Provider: - providers = [provider for provider in self._enabled_providers if provider.name == provider_name] + providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name] if not providers: raise ValueError(f"Provider {provider_name} is not supported") elif len(providers) > 1: diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index e24d508d5..8bae455f5 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -80,6 +80,12 @@ class GroupManager: case ManagerType.voice_sleeptime: new_group.manager_type = ManagerType.voice_sleeptime new_group.manager_agent_id = group.manager_config.manager_agent_id + max_message_buffer_length = group.manager_config.max_message_buffer_length + min_message_buffer_length = group.manager_config.min_message_buffer_length + # Safety check for buffer length range + self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length) + new_group.max_message_buffer_length = max_message_buffer_length + new_group.min_message_buffer_length = min_message_buffer_length case _: raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") @@ -97,6 +103,8 @@ class GroupManager: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) sleeptime_agent_frequency = None + max_message_buffer_length = None + min_message_buffer_length = None max_turns = None termination_token = None manager_agent_id = None @@ -117,11 +125,24 @@ class GroupManager: sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency if sleeptime_agent_frequency and group.turns_counter is None: group.turns_counter = -1 + case ManagerType.voice_sleeptime: + manager_agent_id = group_update.manager_config.manager_agent_id + max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length + min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length + if sleeptime_agent_frequency and group.turns_counter is None: + group.turns_counter = -1 case _: raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}") + # Safety check for buffer length range + self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length) + if sleeptime_agent_frequency: group.sleeptime_agent_frequency = sleeptime_agent_frequency + if max_message_buffer_length: + group.max_message_buffer_length = max_message_buffer_length + if min_message_buffer_length: + group.min_message_buffer_length = min_message_buffer_length if max_turns: group.max_turns = max_turns if termination_token: @@ -274,3 +295,40 @@ class GroupManager: if manager_agent: for block in blocks: session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label)) + + @staticmethod + def ensure_buffer_length_range_valid( + max_value: Optional[int], + min_value: Optional[int], + max_name: str = "max_message_buffer_length", + min_name: str = "min_message_buffer_length", + ) -> None: + """ + 1) Both-or-none: if one is set, the other must be set. + 2) Both must be ints > 4. + 3) max_value must be strictly greater than min_value. + """ + # 1) require both-or-none + if (max_value is None) != (min_value is None): + raise ValueError( + f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})" + ) + + # no further checks if neither is provided + if max_value is None: + return + + # 2) type & lower‐bound checks + if not isinstance(max_value, int) or not isinstance(min_value, int): + raise ValueError( + f"Both '{max_name}' and '{min_name}' must be integers " + f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})" + ) + if max_value <= 4 or min_value <= 4: + raise ValueError( + f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})" + ) + + # 3) ordering + if max_value <= min_value: + raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})") diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 39596e17f..d012171d2 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -1,6 +1,7 @@ -from typing import List, Optional +from typing import List, Optional, Union from letta.orm.provider import Provider as ProviderModel +from letta.schemas.enums import ProviderType from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import ProviderUpdate from letta.schemas.user import User as PydanticUser @@ -18,6 +19,9 @@ class ProviderManager: def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider: """Create a new provider if it doesn't already exist.""" with self.session_maker() as session: + if provider.name == provider.provider_type.value: + raise ValueError("Provider name must be unique and different from provider type") + # Assign the organization id based on the actor provider.organization_id = actor.organization_id @@ -59,29 +63,36 @@ class ProviderManager: session.commit() @enforce_types - def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]: + def list_providers( + self, + name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + actor: PydanticUser = None, + ) -> List[PydanticProvider]: """List all providers with optional pagination.""" + filter_kwargs = {} + if name: + filter_kwargs["name"] = name + if provider_type: + filter_kwargs["provider_type"] = provider_type with self.session_maker() as session: providers = ProviderModel.list( db_session=session, after=after, limit=limit, actor=actor, + **filter_kwargs, ) return [provider.to_pydantic() for provider in providers] @enforce_types - def get_anthropic_override_provider_id(self) -> Optional[str]: - """Helper function to fetch custom anthropic provider id for v0 BYOK feature""" - anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] - if len(anthropic_provider) != 0: - return anthropic_provider[0].id - return None + def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]: + providers = self.list_providers(name=provider_name) + return providers[0].id if providers else None @enforce_types - def get_anthropic_override_key(self) -> Optional[str]: - """Helper function to fetch custom anthropic key for v0 BYOK feature""" - anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] - if len(anthropic_provider) != 0: - return anthropic_provider[0].api_key - return None + def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]: + providers = self.list_providers(name=provider_name) + return providers[0].api_key if providers else None diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index b138bd98e..efbadea33 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -4,6 +4,7 @@ import traceback from typing import List, Tuple from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent @@ -77,7 +78,7 @@ class Summarizer: logger.info("Buffer length hit, evicting messages.") - target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1 + target_trim_index = len(all_in_context_messages) - self.message_buffer_min while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user: target_trim_index += 1 @@ -112,11 +113,12 @@ class Summarizer: summary_request_text = f"""You’re a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost. (Older) Evicted Messages:\n -{evicted_messages_str} +{evicted_messages_str}\n (Newer) In-Context Messages:\n {in_context_messages_str} """ + print(summary_request_text) # Fire-and-forget the summarization task self.fire_and_forget( self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])]) @@ -149,6 +151,9 @@ def format_transcript(messages: List[Message], include_system: bool = False) -> # 1) Try plain content if msg.content: + # Skip tool messages where the name is "send_message" + if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL: + continue text = "".join(c.text for c in msg.content).strip() # 2) Otherwise, try extracting from function calls @@ -156,11 +161,14 @@ def format_transcript(messages: List[Message], include_system: bool = False) -> parts = [] for call in msg.tool_calls: args_str = call.function.arguments - try: - args = json.loads(args_str) - # pull out a "message" field if present - parts.append(args.get("message", args_str)) - except json.JSONDecodeError: + if call.function.name == DEFAULT_MESSAGE_TOOL: + try: + args = json.loads(args_str) + # pull out a "message" field if present + parts.append(args.get(DEFAULT_MESSAGE_TOOL_KWARG, args_str)) + except json.JSONDecodeError: + parts.append(args_str) + else: parts.append(args_str) text = " ".join(parts).strip() diff --git a/letta/services/tool_executor/tool_execution_manager.py b/letta/services/tool_executor/tool_execution_manager.py index fcc96759e..6ba8679c3 100644 --- a/letta/services/tool_executor/tool_execution_manager.py +++ b/letta/services/tool_executor/tool_execution_manager.py @@ -100,7 +100,7 @@ class ToolExecutionManager: try: executor = ToolExecutorFactory.get_executor(tool.tool_type) # TODO: Extend this async model to composio - if isinstance(executor, SandboxToolExecutor): + if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)): result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor) else: result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor) diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index 7d9cac41f..50879e570 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source -from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name +from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.json_helpers import json_dumps from letta.schemas.agent import AgentState @@ -486,7 +486,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor): class ExternalComposioToolExecutor(ToolExecutor): """Executor for external Composio tools.""" - def execute( + async def execute( self, function_name: str, function_args: dict, @@ -505,7 +505,7 @@ class ExternalComposioToolExecutor(ToolExecutor): composio_api_key = get_composio_api_key(actor=actor) # TODO (matt): Roll in execute_composio_action into this class - function_response = execute_composio_action( + function_response = await execute_composio_action_async( action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id ) diff --git a/poetry.lock b/poetry.lock index 8812c8ccf..e67c55d35 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1016,25 +1016,6 @@ e2b = ["e2b (>=0.17.2a37,<1.1.0)", "e2b-code-interpreter"] flyio = ["gql", "requests_toolbelt"] tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "transformers"] -[[package]] -name = "composio-langchain" -version = "0.7.15" -description = "Use Composio to get an array of tools with your LangChain agent." -optional = false -python-versions = "<4,>=3.9" -groups = ["main"] -files = [ - {file = "composio_langchain-0.7.15-py3-none-any.whl", hash = "sha256:a71b5371ad6c3ee4d4289c7a994fad1424e24c29a38e820b6b2ed259056abb65"}, - {file = "composio_langchain-0.7.15.tar.gz", hash = "sha256:cb75c460289ecdf9590caf7ddc0d7888b0a6622ca4f800c9358abe90c25d055e"}, -] - -[package.dependencies] -composio_core = ">=0.7.0,<0.8.0" -langchain = ">=0.1.0" -langchain-openai = ">=0.0.2.post1" -langchainhub = ">=0.1.15" -pydantic = ">=2.6.4" - [[package]] name = "configargparse" version = "1.7" @@ -2842,9 +2823,10 @@ files = [ name = "jsonpatch" version = "1.33" description = "Apply JSON-Patches (RFC 6902)" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, @@ -2857,9 +2839,10 @@ jsonpointer = ">=1.9" name = "jsonpointer" version = "3.0.0" description = "Identify specific nodes in a JSON document (RFC 6901)" -optional = false +optional = true python-versions = ">=3.7" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"}, {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"}, @@ -3052,9 +3035,10 @@ files = [ name = "langchain" version = "0.3.23" description = "Building applications with LLMs through composability" -optional = false +optional = true python-versions = "<4.0,>=3.9" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "langchain-0.3.23-py3-none-any.whl", hash = "sha256:084f05ee7e80b7c3f378ebadd7309f2a37868ce2906fa0ae64365a67843ade3d"}, {file = "langchain-0.3.23.tar.gz", hash = "sha256:d95004afe8abebb52d51d6026270248da3f4b53d93e9bf699f76005e0c83ad34"}, @@ -3120,9 +3104,10 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" name = "langchain-core" version = "0.3.51" description = "Building applications with LLMs through composability" -optional = false +optional = true python-versions = "<4.0,>=3.9" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "langchain_core-0.3.51-py3-none-any.whl", hash = "sha256:4bd71e8acd45362aa428953f2a91d8162318014544a2216e4b769463caf68e13"}, {file = "langchain_core-0.3.51.tar.gz", hash = "sha256:db76b9cc331411602cb40ba0469a161febe7a0663fbcaddbc9056046ac2d22f4"}, @@ -3140,30 +3125,14 @@ PyYAML = ">=5.3" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" typing-extensions = ">=4.7" -[[package]] -name = "langchain-openai" -version = "0.3.12" -description = "An integration package connecting OpenAI and LangChain" -optional = false -python-versions = "<4.0,>=3.9" -groups = ["main"] -files = [ - {file = "langchain_openai-0.3.12-py3-none-any.whl", hash = "sha256:0fab64d58ec95e65ffbaf659470cd362e815685e15edbcb171641e90eca4eb86"}, - {file = "langchain_openai-0.3.12.tar.gz", hash = "sha256:c9dbff63551f6bd91913bca9f99a2d057fd95dc58d4778657d67e5baa1737f61"}, -] - -[package.dependencies] -langchain-core = ">=0.3.49,<1.0.0" -openai = ">=1.68.2,<2.0.0" -tiktoken = ">=0.7,<1" - [[package]] name = "langchain-text-splitters" version = "0.3.8" description = "LangChain text splitting utilities" -optional = false +optional = true python-versions = "<4.0,>=3.9" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02"}, {file = "langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e"}, @@ -3172,30 +3141,14 @@ files = [ [package.dependencies] langchain-core = ">=0.3.51,<1.0.0" -[[package]] -name = "langchainhub" -version = "0.1.21" -description = "The LangChain Hub API client" -optional = false -python-versions = "<4.0,>=3.8.1" -groups = ["main"] -files = [ - {file = "langchainhub-0.1.21-py3-none-any.whl", hash = "sha256:1cc002dc31e0d132a776afd044361e2b698743df5202618cf2bad399246b895f"}, - {file = "langchainhub-0.1.21.tar.gz", hash = "sha256:723383b3964a47dbaea6ad5d0ef728accefbc9d2c07480e800bdec43510a8c10"}, -] - -[package.dependencies] -packaging = ">=23.2,<25" -requests = ">=2,<3" -types-requests = ">=2.31.0.2,<3.0.0.0" - [[package]] name = "langsmith" version = "0.3.28" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." -optional = false +optional = true python-versions = "<4.0,>=3.9" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "langsmith-0.3.28-py3-none-any.whl", hash = "sha256:54ac8815514af52d9c801ad7970086693667e266bf1db90fc453c1759e8407cd"}, {file = "langsmith-0.3.28.tar.gz", hash = "sha256:4666595207131d7f8d83418e54dc86c05e28562e5c997633e7c33fc18f9aeb89"}, @@ -3221,14 +3174,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"] [[package]] name = "letta-client" -version = "0.1.124" +version = "0.1.129" description = "" optional = false python-versions = "<4.0,>=3.8" groups = ["main"] files = [ - {file = "letta_client-0.1.124-py3-none-any.whl", hash = "sha256:a7901437ef91f395cd85d24c0312046b7c82e5a4dd8e04de0d39b5ca085c65d3"}, - {file = "letta_client-0.1.124.tar.gz", hash = "sha256:e8b5716930824cc98c62ee01343e358f88619d346578d48a466277bc8282036d"}, + {file = "letta_client-0.1.129-py3-none-any.whl", hash = "sha256:87a5fc32471e5b9fefbfc1e1337fd667d5e2e340ece5d2a6c782afbceab4bf36"}, + {file = "letta_client-0.1.129.tar.gz", hash = "sha256:b00f611c18a2ad802ec9265f384e1666938c5fc5c86364b2c410d72f0331d597"}, ] [package.dependencies] @@ -4366,10 +4319,10 @@ files = [ name = "orjson" version = "3.10.16" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" -optional = false +optional = true python-versions = ">=3.9" groups = ["main"] -markers = "platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\" and (extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\")" files = [ {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"}, {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"}, @@ -6069,9 +6022,10 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-toolbelt" version = "1.0.0" description = "A utility belt for advanced users of python-requests" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, @@ -6855,21 +6809,6 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2 doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] -[[package]] -name = "types-requests" -version = "2.32.0.20250328" -description = "Typing stubs for requests" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "types_requests-2.32.0.20250328-py3-none-any.whl", hash = "sha256:72ff80f84b15eb3aa7a8e2625fffb6a93f2ad5a0c20215fc1dcfa61117bcb2a2"}, - {file = "types_requests-2.32.0.20250328.tar.gz", hash = "sha256:c9e67228ea103bd811c96984fac36ed2ae8da87a36a633964a21f199d60baf32"}, -] - -[package.dependencies] -urllib3 = ">=2" - [[package]] name = "typing-extensions" version = "4.13.2" @@ -7438,9 +7377,10 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] name = "zstandard" version = "0.23.0" description = "Zstandard bindings for Python" -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\"" files = [ {file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"}, {file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"}, @@ -7563,4 +7503,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.1" python-versions = "<3.14,>=3.10" -content-hash = "75c1c949aa6c0ef8d681bddd91999f97ed4991451be93ca45bf9c01dd19d8a8a" +content-hash = "ba9cf0e00af2d5542aa4beecbd727af92b77ba584033f05c222b00ae47f96585" diff --git a/pyproject.toml b/pyproject.toml index e967c6709..a20be9ac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.7.7" +version = "0.7.8" packages = [ {include = "letta"}, ] @@ -56,7 +56,6 @@ nltk = "^3.8.1" jinja2 = "^3.1.5" locust = {version = "^2.31.5", optional = true} wikipedia = {version = "^1.4.0", optional = true} -composio-langchain = "^0.7.7" composio-core = "^0.7.7" alembic = "^1.13.3" pyhumps = "^3.8.0" @@ -74,7 +73,7 @@ llama-index = "^0.12.2" llama-index-embeddings-openai = "^0.3.1" e2b-code-interpreter = {version = "^1.0.3", optional = true} anthropic = "^0.49.0" -letta_client = "^0.1.124" +letta_client = "^0.1.127" openai = "^1.60.0" opentelemetry-api = "1.30.0" opentelemetry-sdk = "1.30.0" diff --git a/tests/configs/letta_hosted.json b/tests/configs/letta_hosted.json index 3fd85a4c1..278050a64 100644 --- a/tests/configs/letta_hosted.json +++ b/tests/configs/letta_hosted.json @@ -1,11 +1,11 @@ { - "context_window": 8192, - "model_endpoint_type": "openai", - "model_endpoint": "https://inference.memgpt.ai", - "model": "memgpt-openai", - "embedding_endpoint_type": "hugging-face", - "embedding_endpoint": "https://embeddings.memgpt.ai", - "embedding_model": "BAAI/bge-large-en-v1.5", - "embedding_dim": 1024, - "embedding_chunk_size": 300 + "context_window": 8192, + "model_endpoint_type": "openai", + "model_endpoint": "https://inference.letta.com", + "model": "memgpt-openai", + "embedding_endpoint_type": "hugging-face", + "embedding_endpoint": "https://embeddings.memgpt.ai", + "embedding_model": "BAAI/bge-large-en-v1.5", + "embedding_dim": 1024, + "embedding_chunk_size": 300 } diff --git a/tests/configs/llm_model_configs/letta-hosted.json b/tests/configs/llm_model_configs/letta-hosted.json index 82ece9e4d..419cda814 100644 --- a/tests/configs/llm_model_configs/letta-hosted.json +++ b/tests/configs/llm_model_configs/letta-hosted.json @@ -1,7 +1,7 @@ { "context_window": 8192, "model_endpoint_type": "openai", - "model_endpoint": "https://inference.memgpt.ai", + "model_endpoint": "https://inference.letta.com", "model": "memgpt-openai", "put_inner_thoughts_in_kwargs": true } diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 9b8f9a9f1..b0cb28025 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -105,7 +105,9 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) llm_client = LLMClient.create( - provider=agent_state.llm_config.model_endpoint_type, + provider_name=agent_state.llm_config.provider_name, + provider_type=agent_state.llm_config.model_endpoint_type, + actor_id=client.user.id, ) if llm_client: response = llm_client.send_llm_request( @@ -179,7 +181,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse: Note: This is acting on the Letta response, note the usage of `user_message` """ - from composio_langchain import Action + from composio import Action # Set up client client = create_client() diff --git a/tests/integration_test_composio.py b/tests/integration_test_composio.py index fd6b32cab..e1219d1ea 100644 --- a/tests/integration_test_composio.py +++ b/tests/integration_test_composio.py @@ -56,7 +56,7 @@ def test_add_composio_tool(fastapi_client): assert "name" in response.json() -def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user): +async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user): agent_state = server.agent_manager.create_agent( agent_create=CreateAgent( name="sarah_agent", @@ -67,7 +67,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis actor=default_user, ) - tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool( + tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool( function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis ) diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 271091165..bc6c09dbe 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -1,26 +1,26 @@ import os import threading +from unittest.mock import MagicMock import pytest from dotenv import load_dotenv from letta_client import Letta from openai import AsyncOpenAI from openai.types.chat import ChatCompletionChunk -from sqlalchemy import delete from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent from letta.config import LettaConfig -from letta.orm import Provider, Step +from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentType, CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, MessageStreamStatus -from letta.schemas.group import ManagerType +from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import MessageCreate +from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage from letta.schemas.tool import ToolCreate @@ -29,6 +29,8 @@ from letta.server.server import SyncServer from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager +from letta.services.summarizer.enums import SummarizationMode +from letta.services.summarizer.summarizer import Summarizer from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import get_persona_text @@ -48,16 +50,24 @@ MESSAGE_TRANSCRIPTS = [ "user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.", "assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.", "user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.", - "user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.", "assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.", "user: Yes, let’s do that.", "assistant: I’ll put together a day-by-day plan now.", ] -SUMMARY_REQ_TEXT = """ -Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity: +SYSTEM_MESSAGE = Message(role=MessageRole.system, content=[TextContent(text="System message")]) +MESSAGE_OBJECTS = [SYSTEM_MESSAGE] +for entry in MESSAGE_TRANSCRIPTS: + role_str, text = entry.split(":", 1) + role = MessageRole.user if role_str.strip() == "user" else MessageRole.assistant + MESSAGE_OBJECTS.append(Message(role=role, content=[TextContent(text=text.strip())])) +MESSAGE_EVICT_BREAKPOINT = 14 + +SUMMARY_REQ_TEXT = """ +You’re a memory-recall helper for an AI that can only keep the last 4 messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost. + +(Older) Evicted Messages: -(Older) 0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month. 1. assistant: That sounds amazing! Do you have any particular cities or sights in mind? 2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops. @@ -70,16 +80,13 @@ Here is the conversation history. Lines marked (Older) are about to be evicted; 9. assistant: Happy early birthday! Would you like gift ideas or celebration tips? 10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat. 11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated. + +(Newer) In-Context Messages: + 12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro. - -(Newer) -13. user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total. -14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops. -15. user: Yes, let’s do that. -16. assistant: I’ll put together a day-by-day plan now. - -Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk’s `start_index`, `end_index`, and a one-sentence `contextual_description`. - """ +13. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops. +14. user: Yes, let’s do that. +15. assistant: I’ll put together a day-by-day plan now.""" # --- Server Management --- # @@ -214,22 +221,12 @@ def org_id(server): yield org.id - # cleanup - with server.organization_manager.session_maker() as session: - session.execute(delete(Step)) - session.execute(delete(Provider)) - session.commit() - server.organization_manager.delete_organization_by_id(org.id) - @pytest.fixture(scope="module") def actor(server, org_id): user = server.user_manager.create_default_user() yield user - # cleanup - server.user_manager.delete_user_by_id(user.id) - # --- Helper Functions --- # @@ -301,6 +298,80 @@ async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpo print(chunk.choices[0].delta.content) +@pytest.mark.asyncio +async def test_summarization(disable_e2b_api_key, voice_agent): + agent_manager = AgentManager() + user_manager = UserManager() + actor = user_manager.get_default_user() + + request = CreateAgent( + name=voice_agent.name + "-sleeptime", + agent_type=AgentType.voice_sleeptime_agent, + block_ids=[block.id for block in voice_agent.memory.blocks], + memory_blocks=[ + CreateBlock( + label="memory_persona", + value=get_persona_text("voice_memory_persona"), + ), + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + project_id=voice_agent.project_id, + ) + sleeptime_agent = agent_manager.create_agent(request, actor=actor) + + async_client = AsyncOpenAI() + + memory_agent = VoiceSleeptimeAgent( + agent_id=sleeptime_agent.id, + convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent + openai_client=async_client, + message_manager=MessageManager(), + agent_manager=agent_manager, + actor=actor, + block_manager=BlockManager(), + target_block_label="human", + message_transcripts=MESSAGE_TRANSCRIPTS, + ) + + summarizer = Summarizer( + mode=SummarizationMode.STATIC_MESSAGE_BUFFER, + summarizer_agent=memory_agent, + message_buffer_limit=8, + message_buffer_min=4, + ) + + # stub out the agent.step so it returns a known sentinel + memory_agent.step = MagicMock(return_value="STEP_RESULT") + + # patch fire_and_forget on *this* summarizer instance to a MagicMock + summarizer.fire_and_forget = MagicMock() + + # now call the method under test + in_ctx = MESSAGE_OBJECTS[:MESSAGE_EVICT_BREAKPOINT] + new_msgs = MESSAGE_OBJECTS[MESSAGE_EVICT_BREAKPOINT:] + # call under test (this is sync) + updated, did_summarize = summarizer._static_buffer_summarization( + in_context_messages=in_ctx, + new_letta_messages=new_msgs, + ) + + assert did_summarize is True + assert len(updated) == summarizer.message_buffer_min + 1 # One extra for system message + assert updated[0].role == MessageRole.system # Preserved system message + + # 2) the summarizer_agent.step() should have been *called* exactly once + memory_agent.step.assert_called_once() + call_args = memory_agent.step.call_args.args[0] # the single positional argument: a list of MessageCreate + assert isinstance(call_args, list) + assert isinstance(call_args[0], MessageCreate) + assert call_args[0].role == MessageRole.user + assert "15. assistant: I’ll put together a day-by-day plan now." in call_args[0].content[0].text + + # 3) fire_and_forget should have been called once, and its argument must be the coroutine returned by step() + summarizer.fire_and_forget.assert_called_once() + + @pytest.mark.asyncio async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): """Tests chat completion streaming using the Async OpenAI client.""" @@ -427,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor): server.group_manager.retrieve_group(group_id=group.id, actor=actor) with pytest.raises(NoResultFound): server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) + + +def _modify(group_id, server, actor, max_val, min_val): + """Helper to invoke modify_group with voice_sleeptime config.""" + return server.group_manager.modify_group( + group_id=group_id, + group_update=GroupUpdate( + manager_config=VoiceSleeptimeManagerUpdate( + manager_type=ManagerType.voice_sleeptime, + max_message_buffer_length=max_val, + min_message_buffer_length=min_val, + ) + ), + actor=actor, + ) + + +@pytest.fixture +def group_id(voice_agent): + return voice_agent.multi_agent_group.id + + +def test_valid_buffer_lengths_above_four(group_id, server, actor): + # both > 4 and max > min + updated = _modify(group_id, server, actor, max_val=10, min_val=5) + assert updated.max_message_buffer_length == 10 + assert updated.min_message_buffer_length == 5 + + +def test_valid_buffer_lengths_only_max(group_id, server, actor): + # both > 4 and max > min + updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None) + assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1 + assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + + +def test_valid_buffer_lengths_only_min(group_id, server, actor): + # both > 4 and max > min + updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1) + assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1 + + +@pytest.mark.parametrize( + "max_val,min_val,err_part", + [ + # only one set → both-or-none + (None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"), + (DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"), + # ordering violations + (5, 5, "must be greater than"), + (6, 7, "must be greater than"), + # lower-bound (must both be > 4) + (4, 5, "greater than 4"), + (5, 4, "greater than 4"), + (1, 10, "greater than 4"), + (10, 1, "greater than 4"), + ], +) +def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part): + with pytest.raises(ValueError) as exc: + _modify(group_id, server, actor, max_val, min_val) + assert err_part in str(exc.value) diff --git a/tests/test_local_client.py b/tests/test_local_client.py index 0bd9a1401..a3967e4a0 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -124,7 +124,7 @@ def test_agent(client: LocalClient): def test_agent_add_remove_tools(client: LocalClient, agent): # Create and add two tools to the client # tool 1 - from composio_langchain import Action + from composio import Action github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) @@ -316,7 +316,7 @@ def test_tools(client: LocalClient): def test_tools_from_composio_basic(client: LocalClient): - from composio_langchain import Action + from composio import Action # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) client = create_client() diff --git a/tests/test_optimistic_json_parser.py b/tests/test_optimistic_json_parser.py index f7741f7ce..08bb11c13 100644 --- a/tests/test_optimistic_json_parser.py +++ b/tests/test_optimistic_json_parser.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pytest -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser @pytest.fixture diff --git a/tests/test_providers.py b/tests/test_providers.py index 0394dec01..2ab6606d7 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -19,97 +19,166 @@ from letta.settings import model_settings def test_openai(): api_key = os.getenv("OPENAI_API_KEY") assert api_key is not None - provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base) + provider = OpenAIProvider( + name="openai", + api_key=api_key, + base_url=model_settings.openai_api_base, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_deepseek(): api_key = os.getenv("DEEPSEEK_API_KEY") assert api_key is not None - provider = DeepSeekProvider(api_key=api_key) + provider = DeepSeekProvider( + name="deepseek", + api_key=api_key, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_anthropic(): api_key = os.getenv("ANTHROPIC_API_KEY") assert api_key is not None - provider = AnthropicProvider(api_key=api_key) + provider = AnthropicProvider( + name="anthropic", + api_key=api_key, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_groq(): - provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY")) + provider = GroqProvider( + name="groq", + api_key=os.getenv("GROQ_API_KEY"), + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_azure(): - provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL")) + provider = AzureProvider( + name="azure", + api_key=os.getenv("AZURE_API_KEY"), + base_url=os.getenv("AZURE_BASE_URL"), + ) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" - embed_models = provider.list_embedding_models() - print([m.embedding_model for m in embed_models]) + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_ollama(): base_url = os.getenv("OLLAMA_BASE_URL") assert base_url is not None - provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None) + provider = OllamaProvider( + name="ollama", + base_url=base_url, + default_prompt_formatter=model_settings.default_prompt_formatter, + api_key=None, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print(embedding_models) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_googleai(): api_key = os.getenv("GEMINI_API_KEY") assert api_key is not None - provider = GoogleAIProvider(api_key=api_key) + provider = GoogleAIProvider( + name="google_ai", + api_key=api_key, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" - provider.list_embedding_models() + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_google_vertex(): - provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION")) + provider = GoogleVertexProvider( + name="google_vertex", + google_cloud_project=os.getenv("GCP_PROJECT_ID"), + google_cloud_location=os.getenv("GCP_REGION"), + ) models = provider.list_llm_models() - print(models) - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print([m.embedding_model for m in embedding_models]) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_mistral(): - provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY")) + provider = MistralProvider( + name="mistral", + api_key=os.getenv("MISTRAL_API_KEY"), + ) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_together(): - provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml") + provider = TogetherProvider( + name="together", + api_key=os.getenv("TOGETHER_API_KEY"), + default_prompt_formatter="chatml", + ) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print([m.embedding_model for m in embedding_models]) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_anthropic_bedrock(): from letta.settings import model_settings - provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region) + provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print([m.embedding_model for m in embedding_models]) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + +def test_custom_anthropic(): + api_key = os.getenv("ANTHROPIC_API_KEY") + assert api_key is not None + provider = AnthropicProvider( + name="custom_anthropic", + api_key=api_key, + ) + models = provider.list_llm_models() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" # def test_vllm(): diff --git a/tests/test_server.py b/tests/test_server.py index 023897cdd..7d6d73e67 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,7 +13,7 @@ import letta.utils as utils from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR from letta.orm import Provider, Step from letta.schemas.block import CreateBlock -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, ProviderType from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.providers import Provider as PydanticProvider @@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): actor = server.user_manager.get_user_or_default(user_id) provider = server.provider_manager.create_provider( provider=PydanticProvider( - name="anthropic", + name="caren-anthropic", + provider_type=ProviderType.anthropic, api_key=os.getenv("ANTHROPIC_API_KEY"), ), actor=actor, @@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): agent = server.create_agent( request=CreateAgent( memory_blocks=[], - model="anthropic/claude-3-opus-20240229", - context_window_limit=200000, + model="caren-anthropic/claude-3-opus-20240229", + context_window_limit=100000, embedding="openai/text-embedding-ada-002", ), actor=actor,