mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
chore: bump version 0.7.8 (#2604)
Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com> Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
This commit is contained in:
parent
e07edd840b
commit
20ecab29a1
@ -167,7 +167,7 @@ docker exec -it $(docker ps -q -f ancestor=letta/letta) letta run
|
||||
In the CLI tool, you'll be able to create new agents, or load existing agents:
|
||||
```
|
||||
🧬 Creating new agent...
|
||||
? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai]
|
||||
? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com]
|
||||
? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai]
|
||||
-> 🤖 Using persona profile: 'sam_pov'
|
||||
-> 🧑 Using human profile: 'basic'
|
||||
@ -233,7 +233,7 @@ letta run
|
||||
```
|
||||
```
|
||||
🧬 Creating new agent...
|
||||
? Select LLM model: letta-free [type=openai] [ip=https://inference.memgpt.ai]
|
||||
? Select LLM model: letta-free [type=openai] [ip=https://inference.letta.com]
|
||||
? Select embedding model: letta-free [type=hugging-face] [ip=https://embeddings.memgpt.ai]
|
||||
-> 🤖 Using persona profile: 'sam_pov'
|
||||
-> 🧑 Using human profile: 'basic'
|
||||
|
@ -0,0 +1,35 @@
|
||||
"""add byok fields and unique constraint
|
||||
|
||||
Revision ID: 373dabcba6cf
|
||||
Revises: c56081a05371
|
||||
Create Date: 2025-04-30 19:38:25.010856
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "373dabcba6cf"
|
||||
down_revision: Union[str, None] = "c56081a05371"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
|
||||
op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
|
||||
op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
|
||||
op.drop_column("providers", "base_url")
|
||||
op.drop_column("providers", "provider_type")
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,33 @@
|
||||
"""Add buffer length min max for voice sleeptime
|
||||
|
||||
Revision ID: c56081a05371
|
||||
Revises: 28b8765bdd0a
|
||||
Create Date: 2025-04-30 16:03:41.213750
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c56081a05371"
|
||||
down_revision: Union[str, None] = "28b8765bdd0a"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("groups", sa.Column("max_message_buffer_length", sa.Integer(), nullable=True))
|
||||
op.add_column("groups", sa.Column("min_message_buffer_length", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("groups", "min_message_buffer_length")
|
||||
op.drop_column("groups", "max_message_buffer_length")
|
||||
# ### end Alembic commands ###
|
@ -60,7 +60,7 @@ Last updated Oct 2, 2024. Please check `composio` documentation for any composio
|
||||
|
||||
|
||||
def main():
|
||||
from composio_langchain import Action
|
||||
from composio import Action
|
||||
|
||||
# Add the composio tool
|
||||
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
|
32
examples/sleeptime/voice_sleeptime_example.py
Normal file
32
examples/sleeptime/voice_sleeptime_example.py
Normal file
@ -0,0 +1,32 @@
|
||||
from letta_client import Letta, VoiceSleeptimeManagerUpdate
|
||||
|
||||
client = Letta(base_url="http://localhost:8283")
|
||||
|
||||
agent = client.agents.create(
|
||||
name="low_latency_voice_agent_demo",
|
||||
agent_type="voice_convo_agent",
|
||||
memory_blocks=[
|
||||
{"value": "Name: ?", "label": "human"},
|
||||
{"value": "You are a helpful assistant.", "label": "persona"},
|
||||
],
|
||||
model="openai/gpt-4o-mini", # Use 4o-mini for speed
|
||||
embedding="openai/text-embedding-3-small",
|
||||
enable_sleeptime=True,
|
||||
initial_message_sequence = [],
|
||||
)
|
||||
print(f"Created agent id {agent.id}")
|
||||
|
||||
# get the group
|
||||
group_id = agent.multi_agent_group.id
|
||||
max_message_buffer_length = agent.multi_agent_group.max_message_buffer_length
|
||||
min_message_buffer_length = agent.multi_agent_group.min_message_buffer_length
|
||||
print(f"Group id: {group_id}, max_message_buffer_length: {max_message_buffer_length}, min_message_buffer_length: {min_message_buffer_length}")
|
||||
|
||||
# change it to be more frequent
|
||||
group = client.groups.modify(
|
||||
group_id=group_id,
|
||||
manager_config=VoiceSleeptimeManagerUpdate(
|
||||
max_message_buffer_length=10,
|
||||
min_message_buffer_length=6,
|
||||
)
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
__version__ = "0.7.7"
|
||||
__version__ = "0.7.8"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
@ -21,14 +21,14 @@ from letta.constants import (
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.message_helper import prepare_input_message_create
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
@ -331,8 +331,10 @@ class Agent(BaseAgent):
|
||||
log_telemetry(self.logger, "_get_ai_reply create start")
|
||||
# New LLM client flow
|
||||
llm_client = LLMClient.create(
|
||||
provider=self.agent_state.llm_config.model_endpoint_type,
|
||||
provider_name=self.agent_state.llm_config.provider_name,
|
||||
provider_type=self.agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor_id=self.user.id,
|
||||
)
|
||||
|
||||
if llm_client and not stream:
|
||||
@ -726,8 +728,7 @@ class Agent(BaseAgent):
|
||||
self.tool_rules_solver.clear_tool_history()
|
||||
|
||||
# Convert MessageCreate objects to Message objects
|
||||
message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
|
||||
next_input_messages = message_objects
|
||||
next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id)
|
||||
counter = 0
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
@ -942,12 +943,7 @@ class Agent(BaseAgent):
|
||||
model_endpoint=self.agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=self.agent_state.llm_config.context_window,
|
||||
usage=response.usage,
|
||||
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
|
||||
provider_id=(
|
||||
self.provider_manager.get_anthropic_override_provider_id()
|
||||
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
|
||||
else None
|
||||
),
|
||||
provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
|
||||
job_id=job_id,
|
||||
)
|
||||
for message in all_new_messages:
|
||||
|
6
letta/agents/exceptions.py
Normal file
6
letta/agents/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
||||
class IncompatibleAgentType(ValueError):
|
||||
def __init__(self, expected_type: str, actual_type: str):
|
||||
message = f"Incompatible agent type: expected '{expected_type}', but got '{actual_type}'."
|
||||
super().__init__(message)
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
@ -67,8 +67,10 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider=agent_state.llm_config.model_endpoint_type,
|
||||
provider_name=agent_state.llm_config.provider_name,
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor_id=self.actor.id,
|
||||
)
|
||||
for step in range(max_steps):
|
||||
response = await self._get_ai_reply(
|
||||
@ -109,8 +111,10 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=agent_state.llm_config,
|
||||
provider_name=agent_state.llm_config.provider_name,
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor_id=self.actor.id,
|
||||
)
|
||||
|
||||
for step in range(max_steps):
|
||||
@ -125,7 +129,7 @@ class LettaAgent(BaseAgent):
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
|
||||
)
|
||||
async for chunk in interface.process(stream):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
@ -179,6 +183,7 @@ class LettaAgent(BaseAgent):
|
||||
ToolType.LETTA_SLEEPTIME_CORE,
|
||||
}
|
||||
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
|
||||
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
|
||||
]
|
||||
|
||||
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
|
||||
@ -274,6 +279,7 @@ class LettaAgent(BaseAgent):
|
||||
return persisted_messages, continue_stepping
|
||||
|
||||
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
|
||||
try:
|
||||
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
|
||||
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
@ -313,6 +319,9 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
except:
|
||||
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
@ -331,6 +340,10 @@ class LettaAgent(BaseAgent):
|
||||
results = await self._send_message_to_agents_matching_tags(**tool_args)
|
||||
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
|
||||
return json.dumps(results), True
|
||||
elif target_tool.type == ToolType.EXTERNAL_COMPOSIO:
|
||||
log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args)
|
||||
log_event(name=f"finish_compsio_{tool_name}_execution", attributes=tool_args)
|
||||
return tool_execution_result.func_return, True
|
||||
else:
|
||||
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
|
||||
# TODO: Integrate sandbox result
|
||||
|
@ -156,8 +156,10 @@ class LettaAgentBatch:
|
||||
|
||||
log_event(name="init_llm_client")
|
||||
llm_client = LLMClient.create(
|
||||
provider=agent_states[0].llm_config.model_endpoint_type,
|
||||
provider_name=agent_states[0].llm_config.provider_name,
|
||||
provider_type=agent_states[0].llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor_id=self.actor.id,
|
||||
)
|
||||
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
|
||||
|
||||
@ -273,8 +275,10 @@ class LettaAgentBatch:
|
||||
|
||||
# translate provider‑specific response → OpenAI‑style tool call (unchanged)
|
||||
llm_client = LLMClient.create(
|
||||
provider=item.llm_config.model_endpoint_type,
|
||||
provider_name=item.llm_config.provider_name,
|
||||
provider_type=item.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
actor_id=self.actor.id,
|
||||
)
|
||||
tool_call = (
|
||||
llm_client.convert_response_to_chat_completion(
|
||||
|
@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
import openai
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@ -18,7 +19,7 @@ from letta.helpers.tool_execution_helper import (
|
||||
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
@ -68,8 +69,6 @@ class VoiceAgent(BaseAgent):
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
actor: User,
|
||||
message_buffer_limit: int,
|
||||
message_buffer_min: int,
|
||||
):
|
||||
super().__init__(
|
||||
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
|
||||
@ -80,8 +79,6 @@ class VoiceAgent(BaseAgent):
|
||||
self.passage_manager = passage_manager
|
||||
# TODO: This is not guaranteed to exist!
|
||||
self.summary_block_label = "human"
|
||||
self.message_buffer_limit = message_buffer_limit
|
||||
self.message_buffer_min = message_buffer_min
|
||||
|
||||
# Cached archival memory/message size
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
@ -108,8 +105,8 @@ class VoiceAgent(BaseAgent):
|
||||
target_block_label=self.summary_block_label,
|
||||
message_transcripts=[],
|
||||
),
|
||||
message_buffer_limit=self.message_buffer_limit,
|
||||
message_buffer_min=self.message_buffer_min,
|
||||
message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
|
||||
message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,
|
||||
)
|
||||
|
||||
return summarizer
|
||||
@ -124,9 +121,15 @@ class VoiceAgent(BaseAgent):
|
||||
"""
|
||||
if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
|
||||
raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}")
|
||||
|
||||
user_query = input_messages[0].content[0].text
|
||||
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
|
||||
# Safety check
|
||||
if agent_state.agent_type != AgentType.voice_convo_agent:
|
||||
raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type)
|
||||
|
||||
summarizer = self.init_summarizer(agent_state=agent_state)
|
||||
|
||||
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
|
@ -4,7 +4,7 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
|
||||
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
|
||||
LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
|
||||
|
||||
LETTA_MODEL_ENDPOINT = "https://inference.memgpt.ai"
|
||||
LETTA_MODEL_ENDPOINT = "https://inference.letta.com"
|
||||
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
API_PREFIX = "/v1"
|
||||
@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29
|
||||
# minimum context window size
|
||||
MIN_CONTEXT_WINDOW = 4096
|
||||
|
||||
# Voice Sleeptime message buffer lengths
|
||||
DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30
|
||||
DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15
|
||||
|
||||
# embeddings
|
||||
MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset
|
||||
DEFAULT_EMBEDDING_CHUNK_SIZE = 300
|
||||
|
100
letta/functions/composio_helpers.py
Normal file
100
letta/functions/composio_helpers.py
Normal file
@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from composio import ComposioToolSet
|
||||
from composio.constants import DEFAULT_ENTITY_ID
|
||||
from composio.exceptions import (
|
||||
ApiKeyNotProvidedError,
|
||||
ComposioSDKError,
|
||||
ConnectedAccountNotFoundError,
|
||||
EnumMetadataNotFound,
|
||||
EnumStringNotFound,
|
||||
)
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY
|
||||
|
||||
|
||||
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
|
||||
# TODO: So be very careful changing/removing these pair of functions
|
||||
def _generate_func_name_from_composio_action(action_name: str) -> str:
|
||||
"""
|
||||
Generates the composio function name from the composio action.
|
||||
|
||||
Args:
|
||||
action_name: The composio action name
|
||||
|
||||
Returns:
|
||||
function name
|
||||
"""
|
||||
return action_name.lower()
|
||||
|
||||
|
||||
def generate_composio_action_from_func_name(func_name: str) -> str:
|
||||
"""
|
||||
Generates the composio action from the composio function name.
|
||||
|
||||
Args:
|
||||
func_name: The composio function name
|
||||
|
||||
Returns:
|
||||
composio action name
|
||||
"""
|
||||
return func_name.upper()
|
||||
|
||||
|
||||
def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
|
||||
# Generate func name
|
||||
func_name = _generate_func_name_from_composio_action(action_name)
|
||||
|
||||
wrapper_function_str = f"""\
|
||||
def {func_name}(**kwargs):
|
||||
raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team")
|
||||
"""
|
||||
|
||||
# Compile safety check
|
||||
_assert_code_gen_compilable(wrapper_function_str.strip())
|
||||
|
||||
return func_name, wrapper_function_str.strip()
|
||||
|
||||
|
||||
async def execute_composio_action_async(
|
||||
action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None
|
||||
) -> tuple[str, str]:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, execute_composio_action, action_name, args, api_key, entity_id)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in execute_composio_action_async: {e}") from e
|
||||
|
||||
|
||||
def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
|
||||
entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
|
||||
try:
|
||||
composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
|
||||
response = composio_toolset.execute_action(action=action_name, params=args)
|
||||
except ApiKeyNotProvidedError:
|
||||
raise RuntimeError(
|
||||
f"Composio API key is missing for action '{action_name}'. "
|
||||
"Please set the sandbox environment variables either through the ADE or the API."
|
||||
)
|
||||
except ConnectedAccountNotFoundError:
|
||||
raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.")
|
||||
except EnumStringNotFound as e:
|
||||
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
|
||||
except EnumMetadataNotFound as e:
|
||||
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
|
||||
except ComposioSDKError as e:
|
||||
raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))
|
||||
|
||||
if "error" in response and response["error"]:
|
||||
raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))
|
||||
|
||||
return response.get("data")
|
||||
|
||||
|
||||
def _assert_code_gen_compilable(code_str):
|
||||
try:
|
||||
compile(code_str, "<string>", "exec")
|
||||
except SyntaxError as e:
|
||||
print(f"Syntax error in code: {e}")
|
@ -1,8 +1,9 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from textwrap import dedent # remove indentation
|
||||
from types import ModuleType
|
||||
from typing import Dict, List, Literal, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
@ -66,7 +67,8 @@ def parse_source_code(func) -> str:
|
||||
return source_code
|
||||
|
||||
|
||||
def get_function_from_module(module_name: str, function_name: str):
|
||||
# TODO (cliandy) refactor below two funcs
|
||||
def get_function_from_module(module_name: str, function_name: str) -> Callable[..., Any]:
|
||||
"""
|
||||
Dynamically imports a function from a specified module.
|
||||
|
||||
|
@ -6,10 +6,9 @@ from random import uniform
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import humps
|
||||
from composio.constants import DEFAULT_ENTITY_ID
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.functions.interface import MultiAgentMessagingInterface
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import MessageRole
|
||||
@ -21,34 +20,6 @@ from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
|
||||
# TODO: So be very careful changing/removing these pair of functions
|
||||
def generate_func_name_from_composio_action(action_name: str) -> str:
|
||||
"""
|
||||
Generates the composio function name from the composio action.
|
||||
|
||||
Args:
|
||||
action_name: The composio action name
|
||||
|
||||
Returns:
|
||||
function name
|
||||
"""
|
||||
return action_name.lower()
|
||||
|
||||
|
||||
def generate_composio_action_from_func_name(func_name: str) -> str:
|
||||
"""
|
||||
Generates the composio action from the composio function name.
|
||||
|
||||
Args:
|
||||
func_name: The composio function name
|
||||
|
||||
Returns:
|
||||
composio action name
|
||||
"""
|
||||
return func_name.upper()
|
||||
|
||||
|
||||
# TODO needed?
|
||||
def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]:
|
||||
|
||||
@ -58,71 +29,20 @@ def {mcp_tool_name}(**kwargs):
|
||||
"""
|
||||
|
||||
# Compile safety check
|
||||
assert_code_gen_compilable(wrapper_function_str.strip())
|
||||
_assert_code_gen_compilable(wrapper_function_str.strip())
|
||||
|
||||
return mcp_tool_name, wrapper_function_str.strip()
|
||||
|
||||
|
||||
def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
|
||||
# Generate func name
|
||||
func_name = generate_func_name_from_composio_action(action_name)
|
||||
|
||||
wrapper_function_str = f"""\
|
||||
def {func_name}(**kwargs):
|
||||
raise RuntimeError("Something went wrong - we should never be using the persisted source code for Composio. Please reach out to Letta team")
|
||||
"""
|
||||
|
||||
# Compile safety check
|
||||
assert_code_gen_compilable(wrapper_function_str.strip())
|
||||
|
||||
return func_name, wrapper_function_str.strip()
|
||||
|
||||
|
||||
def execute_composio_action(action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None) -> Any:
|
||||
import os
|
||||
|
||||
from composio.exceptions import (
|
||||
ApiKeyNotProvidedError,
|
||||
ComposioSDKError,
|
||||
ConnectedAccountNotFoundError,
|
||||
EnumMetadataNotFound,
|
||||
EnumStringNotFound,
|
||||
)
|
||||
from composio_langchain import ComposioToolSet
|
||||
|
||||
entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
|
||||
try:
|
||||
composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id, lock=False)
|
||||
response = composio_toolset.execute_action(action=action_name, params=args)
|
||||
except ApiKeyNotProvidedError:
|
||||
raise RuntimeError(
|
||||
f"Composio API key is missing for action '{action_name}'. "
|
||||
"Please set the sandbox environment variables either through the ADE or the API."
|
||||
)
|
||||
except ConnectedAccountNotFoundError:
|
||||
raise RuntimeError(f"No connected account was found for action '{action_name}'. " "Please link an account and try again.")
|
||||
except EnumStringNotFound as e:
|
||||
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
|
||||
except EnumMetadataNotFound as e:
|
||||
raise RuntimeError(f"Invalid value provided for action '{action_name}': " + str(e) + ". Please check the action parameters.")
|
||||
except ComposioSDKError as e:
|
||||
raise RuntimeError(f"An unexpected error occurred in Composio SDK while executing action '{action_name}': " + str(e))
|
||||
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Error while executing action '{action_name}': " + str(response["error"]))
|
||||
|
||||
return response.get("data")
|
||||
|
||||
|
||||
def generate_langchain_tool_wrapper(
|
||||
tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
|
||||
) -> tuple[str, str]:
|
||||
tool_name = tool.__class__.__name__
|
||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||
extra_module_imports = generate_import_code(additional_imports_module_attr_map)
|
||||
extra_module_imports = _generate_import_code(additional_imports_module_attr_map)
|
||||
|
||||
# Safety check that user has passed in all required imports:
|
||||
assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
|
||||
_assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
|
||||
|
||||
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
@ -139,25 +59,25 @@ def {func_name}(**kwargs):
|
||||
"""
|
||||
|
||||
# Compile safety check
|
||||
assert_code_gen_compilable(wrapper_function_str)
|
||||
_assert_code_gen_compilable(wrapper_function_str)
|
||||
|
||||
return func_name, wrapper_function_str
|
||||
|
||||
|
||||
def assert_code_gen_compilable(code_str):
|
||||
def _assert_code_gen_compilable(code_str):
|
||||
try:
|
||||
compile(code_str, "<string>", "exec")
|
||||
except SyntaxError as e:
|
||||
print(f"Syntax error in code: {e}")
|
||||
|
||||
|
||||
def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
|
||||
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
|
||||
# Safety check that user has passed in all required imports:
|
||||
tool_name = tool.__class__.__name__
|
||||
current_class_imports = {tool_name}
|
||||
if additional_imports_module_attr_map:
|
||||
current_class_imports.update(set(additional_imports_module_attr_map.values()))
|
||||
required_class_imports = set(find_required_class_names_for_import(tool))
|
||||
required_class_imports = set(_find_required_class_names_for_import(tool))
|
||||
|
||||
if not current_class_imports.issuperset(required_class_imports):
|
||||
err_msg = f"[ERROR] You are missing module_attr pairs in `additional_imports_module_attr_map`. Currently, you have imports for {current_class_imports}, but the required classes for import are {required_class_imports}"
|
||||
@ -165,7 +85,7 @@ def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
|
||||
def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
|
||||
def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
|
||||
"""
|
||||
Finds all the class names for required imports when instantiating the `obj`.
|
||||
NOTE: This does not return the full import path, only the class name.
|
||||
@ -181,7 +101,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod
|
||||
|
||||
# Collect all possible candidates for BaseModel objects
|
||||
candidates = []
|
||||
if is_base_model(curr_obj):
|
||||
if _is_base_model(curr_obj):
|
||||
# If it is a base model, we get all the values of the object parameters
|
||||
# i.e., if obj('b' = <class A>), we would want to inspect <class A>
|
||||
fields = dict(curr_obj)
|
||||
@ -198,7 +118,7 @@ def find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseMod
|
||||
|
||||
# Filter out all candidates that are not BaseModels
|
||||
# In the list example above, ['a', 3, None, <class A>], we want to filter out 'a', 3, and None
|
||||
candidates = filter(lambda x: is_base_model(x), candidates)
|
||||
candidates = filter(lambda x: _is_base_model(x), candidates)
|
||||
|
||||
# Classic BFS here
|
||||
for c in candidates:
|
||||
@ -216,7 +136,7 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
|
||||
# If it is a basic Python type, we trivially return the string version of that value
|
||||
# Handle basic types
|
||||
return repr(obj)
|
||||
elif is_base_model(obj):
|
||||
elif _is_base_model(obj):
|
||||
# Otherwise, if it is a BaseModel
|
||||
# We want to pull out all the parameters, and reformat them into strings
|
||||
# e.g. {arg}={value}
|
||||
@ -269,11 +189,11 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def is_base_model(obj: Any):
|
||||
def _is_base_model(obj: Any):
|
||||
return isinstance(obj, BaseModel)
|
||||
|
||||
|
||||
def generate_import_code(module_attr_map: Optional[dict]):
|
||||
def _generate_import_code(module_attr_map: Optional[dict]):
|
||||
if not module_attr_map:
|
||||
return ""
|
||||
|
||||
@ -286,7 +206,7 @@ def generate_import_code(module_attr_map: Optional[dict]):
|
||||
return "\n".join(code_lines)
|
||||
|
||||
|
||||
def parse_letta_response_for_assistant_message(
|
||||
def _parse_letta_response_for_assistant_message(
|
||||
target_agent_id: str,
|
||||
letta_response: LettaResponse,
|
||||
) -> Optional[str]:
|
||||
@ -346,7 +266,7 @@ def execute_send_message_to_agent(
|
||||
return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix))
|
||||
|
||||
|
||||
async def send_message_to_agent_no_stream(
|
||||
async def _send_message_to_agent_no_stream(
|
||||
server: "SyncServer",
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
@ -375,7 +295,7 @@ async def send_message_to_agent_no_stream(
|
||||
return LettaResponse(messages=final_messages, usage=usage_stats)
|
||||
|
||||
|
||||
async def async_send_message_with_retries(
|
||||
async def _async_send_message_with_retries(
|
||||
server: "SyncServer",
|
||||
sender_agent: "Agent",
|
||||
target_agent_id: str,
|
||||
@ -389,7 +309,7 @@ async def async_send_message_with_retries(
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
send_message_to_agent_no_stream(
|
||||
_send_message_to_agent_no_stream(
|
||||
server=server,
|
||||
agent_id=target_agent_id,
|
||||
actor=sender_agent.user,
|
||||
@ -399,7 +319,7 @@ async def async_send_message_with_retries(
|
||||
)
|
||||
|
||||
# Then parse out the assistant message
|
||||
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
|
||||
assistant_message = _parse_letta_response_for_assistant_message(target_agent_id, response)
|
||||
if assistant_message:
|
||||
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
|
||||
return assistant_message
|
||||
|
@ -76,6 +76,7 @@ def load_multi_agent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
mcp_clients=mcp_clients,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
|
@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
@ -26,6 +27,7 @@ class SleeptimeMultiAgent(Agent):
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User,
|
||||
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
@ -115,6 +117,7 @@ class SleeptimeMultiAgent(Agent):
|
||||
agent_state=participant_agent_state,
|
||||
interface=StreamingServerInterface(),
|
||||
user=self.user,
|
||||
mcp_clients=self.mcp_clients,
|
||||
)
|
||||
|
||||
prior_messages = []
|
||||
@ -212,6 +215,7 @@ class SleeptimeMultiAgent(Agent):
|
||||
agent_state=self.agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
mcp_clients=self.mcp_clients,
|
||||
)
|
||||
# Perform main agent step
|
||||
usage_stats = main_agent.step(
|
||||
|
@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
|
||||
|
||||
def prepare_input_message_create(
|
||||
def convert_message_creates_to_messages(
|
||||
messages: list[MessageCreate],
|
||||
agent_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
) -> list[Message]:
|
||||
return [
|
||||
_convert_message_create_to_message(
|
||||
message=message,
|
||||
agent_id=agent_id,
|
||||
wrap_user_message=wrap_user_message,
|
||||
wrap_system_message=wrap_system_message,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
|
||||
def _convert_message_create_to_message(
|
||||
message: MessageCreate,
|
||||
agent_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
@ -23,12 +40,12 @@ def prepare_input_message_create(
|
||||
raise ValueError("Message content is empty or invalid")
|
||||
|
||||
# Apply wrapping if needed
|
||||
if message.role == MessageRole.user and wrap_user_message:
|
||||
if message.role not in {MessageRole.user, MessageRole.system}:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
elif message.role == MessageRole.user and wrap_user_message:
|
||||
message_content = system.package_user_message(user_message=message_content)
|
||||
elif message.role == MessageRole.system and wrap_system_message:
|
||||
message_content = system.package_system_message(system_message=message_content)
|
||||
elif message.role not in {MessageRole.user, MessageRole.system}:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
return Message(
|
||||
agent_id=agent_id,
|
||||
|
@ -3,7 +3,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
|
@ -35,7 +35,7 @@ from letta.schemas.letta_message import (
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -56,7 +56,7 @@ class AnthropicStreamingInterface:
|
||||
"""
|
||||
|
||||
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.json_parser: JSONParser = PydanticJSONParser()
|
||||
self.use_assistant_message = use_assistant_message
|
||||
|
||||
# Premake IDs for database writes
|
||||
@ -68,7 +68,7 @@ class AnthropicStreamingInterface:
|
||||
self.accumulated_inner_thoughts = []
|
||||
self.tool_call_id = None
|
||||
self.tool_call_name = None
|
||||
self.accumulated_tool_call_args = []
|
||||
self.accumulated_tool_call_args = ""
|
||||
self.previous_parse = {}
|
||||
|
||||
# usage trackers
|
||||
@ -85,24 +85,27 @@ class AnthropicStreamingInterface:
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Useful for agent loop"""
|
||||
return ToolCall(
|
||||
id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
|
||||
)
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))
|
||||
|
||||
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
|
||||
"""
|
||||
Check if inner thoughts are complete in the current tool call arguments
|
||||
by looking for a closing quote after the inner_thoughts field
|
||||
"""
|
||||
try:
|
||||
if not self.put_inner_thoughts_in_kwarg:
|
||||
# None of the things should have inner thoughts in kwargs
|
||||
return True
|
||||
else:
|
||||
parsed = self.optimistic_json_parser.parse(combined_args)
|
||||
parsed = self.json_parser.parse(combined_args)
|
||||
# TODO: This will break on tools with 0 input
|
||||
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
|
||||
except Exception as e:
|
||||
logger.error("Error checking inner thoughts: %s", e)
|
||||
raise
|
||||
|
||||
async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
|
||||
try:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
|
||||
@ -169,9 +172,8 @@ class AnthropicStreamingInterface:
|
||||
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
|
||||
)
|
||||
|
||||
self.accumulated_tool_call_args.append(delta.partial_json)
|
||||
combined_args = "".join(self.accumulated_tool_call_args)
|
||||
current_parsed = self.optimistic_json_parser.parse(combined_args)
|
||||
self.accumulated_tool_call_args += delta.partial_json
|
||||
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
|
||||
|
||||
# Start detecting a difference in inner thoughts
|
||||
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
|
||||
@ -188,7 +190,7 @@ class AnthropicStreamingInterface:
|
||||
yield reasoning_message
|
||||
|
||||
# Check if inner thoughts are complete - if so, flush the buffer
|
||||
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args):
|
||||
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
|
||||
self.inner_thoughts_complete = True
|
||||
# Flush all buffered tool call messages
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
@ -272,6 +274,11 @@ class AnthropicStreamingInterface:
|
||||
self.tool_call_buffer = []
|
||||
|
||||
self.anthropic_mode = None
|
||||
except Exception as e:
|
||||
logger.error("Error processing stream: %s", e)
|
||||
raise
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
|
||||
def _process_group(
|
||||
|
@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice,
|
||||
|
||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.interfaces.utils import _format_sse_chunk
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
|
||||
|
||||
class OpenAIChatCompletionsStreamingInterface:
|
||||
|
@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.message import MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
|
||||
# NOTE: currently there is no GET /models, so we need to hardcode
|
||||
# return MODEL_LIST
|
||||
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
if api_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
else:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
models = anthropic_client.models.list()
|
||||
models_json = models.model_dump()
|
||||
@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
extended_thinking: bool = False,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> ChatCompletionResponse:
|
||||
"""https://docs.anthropic.com/claude/docs/tool-use"""
|
||||
anthropic_client = None
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
if provider_name and provider_name != ProviderType.anthropic.value:
|
||||
api_key = ProviderManager().get_override_key(provider_name)
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
else:
|
||||
@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
extended_thinking: bool = False,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> Generator[ChatCompletionChunkResponse, None, None]:
|
||||
"""Stream chat completions from Anthropic API.
|
||||
@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
|
||||
extended_thinking=extended_thinking,
|
||||
max_reasoning_tokens=max_reasoning_tokens,
|
||||
)
|
||||
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
if provider_name and provider_name != ProviderType.anthropic.value:
|
||||
api_key = ProviderManager().get_override_key(provider_name)
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
|
||||
@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
extended_thinking: bool = False,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
create_message_id: bool = True,
|
||||
create_message_datetime: bool = True,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
extended_thinking=extended_thinking,
|
||||
max_reasoning_tokens=max_reasoning_tokens,
|
||||
provider_name=provider_name,
|
||||
betas=betas,
|
||||
)
|
||||
):
|
||||
|
@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
|
||||
override_key = ProviderManager().get_anthropic_override_key()
|
||||
override_key = None
|
||||
if self.provider_name and self.provider_name != ProviderType.anthropic.value:
|
||||
override_key = ProviderManager().get_override_key(self.provider_name)
|
||||
|
||||
if async_client:
|
||||
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
|
||||
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
|
||||
|
@ -63,7 +63,7 @@ class GoogleVertexClient(GoogleAIClient):
|
||||
# Add thinking_config
|
||||
# If enable_reasoner is False, set thinking_budget to 0
|
||||
# Otherwise, use the value from max_reasoning_tokens
|
||||
thinking_budget = 0 if not self.llm_config.enable_reasoner else self.llm_config.max_reasoning_tokens
|
||||
thinking_budget = 0 if not llm_config.enable_reasoner else llm_config.max_reasoning_tokens
|
||||
thinking_config = ThinkingConfig(
|
||||
thinking_budget=thinking_budget,
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ from letta.llm_api.openai import (
|
||||
from letta.local_llm.chat_completion_proxy import get_chat_completion
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
|
||||
@ -171,6 +172,10 @@ def create(
|
||||
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
|
||||
# only is a problem if we are *not* using an openai proxy
|
||||
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
|
||||
elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name)
|
||||
elif model_settings.openai_api_key is None:
|
||||
# the openai python client requires a dummy API key
|
||||
api_key = "DUMMY_API_KEY"
|
||||
@ -373,6 +378,7 @@ def create(
|
||||
stream_interface=stream_interface,
|
||||
extended_thinking=llm_config.enable_reasoner,
|
||||
max_reasoning_tokens=llm_config.max_reasoning_tokens,
|
||||
provider_name=llm_config.provider_name,
|
||||
name=name,
|
||||
)
|
||||
|
||||
@ -383,6 +389,7 @@ def create(
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
extended_thinking=llm_config.enable_reasoner,
|
||||
max_reasoning_tokens=llm_config.max_reasoning_tokens,
|
||||
provider_name=llm_config.provider_name,
|
||||
)
|
||||
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
|
@ -9,8 +9,10 @@ class LLMClient:
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
provider: ProviderType,
|
||||
provider_type: ProviderType,
|
||||
provider_name: Optional[str] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
actor_id: Optional[str] = None,
|
||||
) -> Optional[LLMClientBase]:
|
||||
"""
|
||||
Create an LLM client based on the model endpoint type.
|
||||
@ -25,30 +27,38 @@ class LLMClient:
|
||||
Raises:
|
||||
ValueError: If the model endpoint type is not supported
|
||||
"""
|
||||
match provider:
|
||||
match provider_type:
|
||||
case ProviderType.google_ai:
|
||||
from letta.llm_api.google_ai_client import GoogleAIClient
|
||||
|
||||
return GoogleAIClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor_id=actor_id,
|
||||
)
|
||||
case ProviderType.google_vertex:
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
|
||||
return GoogleVertexClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor_id=actor_id,
|
||||
)
|
||||
case ProviderType.anthropic:
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
|
||||
return AnthropicClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor_id=actor_id,
|
||||
)
|
||||
case ProviderType.openai:
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
|
||||
return OpenAIClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor_id=actor_id,
|
||||
)
|
||||
case _:
|
||||
return None
|
||||
|
@ -20,9 +20,13 @@ class LLMClientBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: Optional[str] = None,
|
||||
put_inner_thoughts_first: Optional[bool] = True,
|
||||
use_tool_naming: bool = True,
|
||||
actor_id: Optional[str] = None,
|
||||
):
|
||||
self.actor_id = actor_id
|
||||
self.provider_name = provider_name
|
||||
self.put_inner_thoughts_first = put_inner_thoughts_first
|
||||
self.use_tool_naming = use_tool_naming
|
||||
|
||||
|
@ -157,11 +157,17 @@ def build_openai_chat_completions_request(
|
||||
# if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model:
|
||||
# data.response_format = {"type": "json_object"}
|
||||
|
||||
# always set user id for openai requests
|
||||
if user_id:
|
||||
data.user = str(user_id)
|
||||
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
# override user id for inference.memgpt.ai
|
||||
if not user_id:
|
||||
# override user id for inference.letta.com
|
||||
import uuid
|
||||
|
||||
data.user = str(uuid.UUID(int=0))
|
||||
|
||||
data.model = "memgpt-openai"
|
||||
|
||||
if use_structured_output and data.tools is not None and len(data.tools) > 0:
|
||||
|
@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
@ -64,6 +65,13 @@ def supports_parallel_tool_calling(model: str) -> bool:
|
||||
|
||||
class OpenAIClient(LLMClientBase):
|
||||
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
|
||||
api_key = None
|
||||
if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name)
|
||||
|
||||
if not api_key:
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
# supposedly the openai python client requires a dummy API key
|
||||
api_key = api_key or "DUMMY_API_KEY"
|
||||
@ -135,11 +143,17 @@ class OpenAIClient(LLMClientBase):
|
||||
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
||||
)
|
||||
|
||||
# always set user id for openai requests
|
||||
if self.actor_id:
|
||||
data.user = self.actor_id
|
||||
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
# override user id for inference.memgpt.ai
|
||||
if not self.actor_id:
|
||||
# override user id for inference.letta.com
|
||||
import uuid
|
||||
|
||||
data.user = str(uuid.UUID(int=0))
|
||||
|
||||
data.model = "memgpt-openai"
|
||||
|
||||
if data.tools is not None and len(data.tools) > 0:
|
||||
|
@ -79,8 +79,10 @@ def summarize_messages(
|
||||
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
provider=llm_config_no_inner_thoughts.model_endpoint_type,
|
||||
provider_name=llm_config_no_inner_thoughts.provider_name,
|
||||
provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
|
||||
put_inner_thoughts_first=False,
|
||||
actor_id=agent_state.created_by_id,
|
||||
)
|
||||
# try to use new client, otherwise fallback to old flow
|
||||
# TODO: we can just directly call the LLM here?
|
||||
|
@ -21,6 +21,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
|
||||
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
sleeptime_agent_frequency: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
max_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
min_message_buffer_length: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
|
||||
|
@ -1,5 +1,6 @@
from typing import TYPE_CHECKING
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):
__tablename__ = "providers"
__pydantic_model__ = PydanticProvider
__table_args__ = (
UniqueConstraint(
"name",
"organization_id",
name="unique_name_organization_id",
),
)
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
@ -56,7 +56,6 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
name: str = Field(..., description="The name of the agent.")
# tool rules
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
# in-context memory
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
@ -6,6 +6,17 @@ class ProviderType(str, Enum):
google_ai = "google_ai"
google_vertex = "google_vertex"
openai = "openai"
letta = "letta"
deepseek = "deepseek"
lmstudio_openai = "lmstudio_openai"
xai = "xai"
mistral = "mistral"
ollama = "ollama"
groq = "groq"
together = "together"
azure = "azure"
vllm = "vllm"
bedrock = "bedrock"
class MessageRole(str, Enum):
@ -32,6 +32,14 @@ class Group(GroupBase):
sleeptime_agent_frequency: Optional[int] = Field(None, description="")
turns_counter: Optional[int] = Field(None, description="")
last_processed_message_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class ManagerConfig(BaseModel):
@ -87,11 +95,27 @@ class SleeptimeManagerUpdate(ManagerConfig):
class VoiceSleeptimeManager(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: str = Field(..., description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class VoiceSleeptimeManagerUpdate(ManagerConfig):
manager_type: Literal[ManagerType.voice_sleeptime] = Field(ManagerType.voice_sleeptime, description="")
manager_agent_id: Optional[str] = Field(None, description="")
max_message_buffer_length: Optional[int] = Field(
None,
description="The desired maximum length of messages in the context window of the convo agent. This is a best effort, and may be off slightly due to user/assistant interleaving.",
)
min_message_buffer_length: Optional[int] = Field(
None,
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
# class SwarmGroup(ManagerConfig):
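The new `max_message_buffer_length` / `min_message_buffer_length` fields configure when the voice sleeptime agent evicts and summarizes conversation history. A hedged construction example follows, assuming these schemas live in `letta.schemas.group` and using placeholder values; the server separately enforces that both bounds are greater than 4 and that the maximum exceeds the minimum.
```python
from letta.schemas.group import VoiceSleeptimeManager

manager_config = VoiceSleeptimeManager(
    manager_agent_id="agent-00000000",   # placeholder agent id
    max_message_buffer_length=20,        # evict once the buffer grows past this
    min_message_buffer_length=8,         # trim back down to roughly this many messages
)
print(manager_config.model_dump(exclude_none=True))
```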
@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
"xai",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")
put_inner_thoughts_in_kwargs: Optional[bool] = Field(
@ -2,8 +2,8 @@ from typing import Dict
LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
"anthropic": {
"claude-3-5-haiku-20241022": "claude-3.5-haiku",
"claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
"claude-3-5-haiku-20241022": "claude-3-5-haiku",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
"claude-3-opus-20240229": "claude-3-opus",
},
"openai": {
@ -1,6 +1,6 @@
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class ProviderBase(LettaBase):
|
||||
@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
|
||||
class Provider(ProviderBase):
|
||||
id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
|
||||
name: str = Field(..., description="The name of the provider")
|
||||
provider_type: ProviderType = Field(..., description="The type of the provider")
|
||||
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
|
||||
base_url: Optional[str] = Field(None, description="Base URL for the provider.")
|
||||
organization_id: Optional[str] = Field(None, description="The organization id of the user")
|
||||
updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def default_base_url(self):
|
||||
if self.provider_type == ProviderType.openai and self.base_url is None:
|
||||
self.base_url = model_settings.openai_api_base
|
||||
return self
|
||||
|
||||
def resolve_identifier(self):
|
||||
if not self.id:
|
||||
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
|
||||
@ -59,9 +69,41 @@ class Provider(ProviderBase):
|
||||
|
||||
return f"{self.name}/{model_name}"
|
||||
|
||||
def cast_to_subtype(self):
|
||||
match (self.provider_type):
|
||||
case ProviderType.letta:
|
||||
return LettaProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.openai:
|
||||
return OpenAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.anthropic:
|
||||
return AnthropicProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.anthropic_bedrock:
|
||||
return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.ollama:
|
||||
return OllamaProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.google_ai:
|
||||
return GoogleAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.google_vertex:
|
||||
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.azure:
|
||||
return AzureProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.groq:
|
||||
return GroqProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.together:
|
||||
return TogetherProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.vllm_chat_completions:
|
||||
return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.vllm_completions:
|
||||
return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.xai:
|
||||
return XAIProvider(**self.model_dump(exclude_none=True))
|
||||
case _:
|
||||
raise ValueError(f"Unknown provider type: {self.provider_type}")
|
||||
|
||||
|
||||
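The `cast_to_subtype` method above lets a generic `Provider` row loaded from the database be promoted to its concrete subclass so provider-specific behavior (such as `list_llm_models`) becomes available. A hedged round-trip sketch, with made-up field values and assuming a working Letta install:
```python
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider

# a generic row, as it might come back from the providers table
row = Provider(
    name="my-openai-byok",
    provider_type=ProviderType.openai,
    api_key="sk-example",                   # illustrative, not a real key
    base_url="https://api.openai.com/v1",
)

concrete = row.cast_to_subtype()
print(type(concrete).__name__)  # -> OpenAIProvider
```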
class ProviderCreate(ProviderBase):
|
||||
name: str = Field(..., description="The name of the provider.")
|
||||
provider_type: ProviderType = Field(..., description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||
|
||||
|
||||
@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):
|
||||
|
||||
|
||||
class LettaProvider(Provider):
|
||||
|
||||
name: str = "letta"
|
||||
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
return [
|
||||
@ -81,6 +122,7 @@ class LettaProvider(Provider):
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=8192,
|
||||
handle=self.get_handle("letta-free"),
|
||||
provider_name=self.name,
|
||||
)
|
||||
]
|
||||
|
||||
@ -98,7 +140,7 @@ class LettaProvider(Provider):
|
||||
|
||||
|
||||
class OpenAIProvider(Provider):
|
||||
name: str = "openai"
|
||||
provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the OpenAI API.")
|
||||
base_url: str = Field(..., description="Base URL for the OpenAI API.")
|
||||
|
||||
@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
* It also does not support native function calling
|
||||
"""
|
||||
|
||||
name: str = "deepseek"
|
||||
provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
|
||||
base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
|
||||
api_key: str = Field(..., description="API key for the DeepSeek API.")
|
||||
|
||||
@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
name: str = "lmstudio-openai"
|
||||
provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
|
||||
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
|
||||
|
||||
@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
class XAIProvider(OpenAIProvider):
|
||||
"""https://docs.x.ai/docs/api-reference"""
|
||||
|
||||
name: str = "xai"
|
||||
provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the xAI/Grok API.")
|
||||
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
|
||||
|
||||
@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class AnthropicProvider(Provider):
|
||||
name: str = "anthropic"
|
||||
provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the Anthropic API.")
|
||||
base_url: str = "https://api.anthropic.com/v1"
|
||||
|
||||
@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
|
||||
handle=self.get_handle(model["id"]),
|
||||
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
|
||||
max_tokens=max_tokens,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -572,7 +618,7 @@ class AnthropicProvider(Provider):
|
||||
|
||||
|
||||
class MistralProvider(Provider):
|
||||
name: str = "mistral"
|
||||
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the Mistral API.")
|
||||
base_url: str = "https://api.mistral.ai/v1"
|
||||
|
||||
@ -596,6 +642,7 @@ class MistralProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_context_length"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
|
||||
"""
|
||||
|
||||
name: str = "ollama"
|
||||
provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the Ollama API.")
|
||||
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
|
||||
default_prompt_formatter: str = Field(
|
||||
@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model["name"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class GroqProvider(OpenAIProvider):
|
||||
name: str = "groq"
|
||||
provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
|
||||
base_url: str = "https://api.groq.com/openai/v1"
|
||||
api_key: str = Field(..., description="API key for the Groq API.")
|
||||
|
||||
@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["context_window"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
function calling support is limited.
|
||||
"""
|
||||
|
||||
name: str = "together"
|
||||
provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
|
||||
base_url: str = "https://api.together.ai/v1"
|
||||
api_key: str = Field(..., description="API key for the TogetherAI API.")
|
||||
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
|
||||
@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
|
||||
class GoogleAIProvider(Provider):
|
||||
# gemini
|
||||
name: str = "google_ai"
|
||||
provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the Google AI API.")
|
||||
base_url: str = "https://generativelanguage.googleapis.com"
|
||||
|
||||
@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
|
||||
# filter by model names
|
||||
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
|
||||
|
||||
# TODO remove manual filtering for gemini-pro
|
||||
# Add support for all gemini models
|
||||
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
|
||||
|
||||
@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
|
||||
context_window=self.get_model_context_window(model),
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):
|
||||
|
||||
|
||||
class GoogleVertexProvider(Provider):
|
||||
name: str = "google_vertex"
|
||||
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
|
||||
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
|
||||
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
|
||||
|
||||
@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
|
||||
context_window=context_length,
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):
|
||||
|
||||
|
||||
class AzureProvider(Provider):
|
||||
name: str = "azure"
|
||||
provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
|
||||
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
|
||||
base_url: str = Field(
|
||||
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
|
||||
@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
|
||||
model_endpoint=model_endpoint,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
),
|
||||
)
|
||||
return configs
|
||||
@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
|
||||
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
|
||||
|
||||
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
|
||||
name: str = "vllm"
|
||||
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the vLLM API.")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
|
||||
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
|
||||
|
||||
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
|
||||
name: str = "vllm"
|
||||
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the vLLM API.")
|
||||
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
|
||||
|
||||
@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class AnthropicBedrockProvider(Provider):
|
||||
name: str = "bedrock"
|
||||
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
|
||||
aws_region: str = Field(..., description="AWS region for Bedrock")
|
||||
|
||||
def list_llm_models(self):
|
||||
@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_arn,
|
||||
model_endpoint_type=self.name,
|
||||
model_endpoint_type=self.provider_type.value,
|
||||
model_endpoint=None,
|
||||
context_window=self.get_model_context_window(model_arn),
|
||||
handle=self.get_handle(model_arn),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
@ -11,13 +11,9 @@ from letta.constants import (
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
)
|
||||
from letta.functions.ast_parsers import get_function_name_and_description
|
||||
from letta.functions.composio_helpers import generate_composio_tool_wrapper
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
from letta.functions.helpers import (
|
||||
generate_composio_tool_wrapper,
|
||||
generate_langchain_tool_wrapper,
|
||||
generate_mcp_tool_wrapper,
|
||||
generate_model_from_args_json_schema,
|
||||
)
|
||||
from letta.functions.helpers import generate_langchain_tool_wrapper, generate_mcp_tool_wrapper, generate_model_from_args_json_schema
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.functions.schema_generator import (
|
||||
generate_schema_from_args_schema_v2,
|
||||
@ -176,8 +172,7 @@ class ToolCreate(LettaBase):
|
||||
Returns:
|
||||
Tool: A Letta Tool initialized with attributes derived from the Composio tool.
|
||||
"""
|
||||
from composio import LogLevel
|
||||
from composio_langchain import ComposioToolSet
|
||||
from composio import ComposioToolSet, LogLevel
|
||||
|
||||
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, lock=False)
|
||||
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)
|
||||
|
@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.__init__ import __version__
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
|
||||
@ -173,6 +174,17 @@ def create_application() -> "FastAPI":
|
||||
def shutdown_scheduler():
|
||||
shutdown_cron_scheduler()
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": str(exc),
|
||||
"expected_type": exc.expected_type,
|
||||
"actual_type": exc.actual_type,
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
# Log the actual error for debugging
|
||||
|
@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -28,7 +28,7 @@ from letta.schemas.letta_message import (
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
|
||||
from letta.utils import parse_json
|
||||
@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
|
||||
|
||||
# @matt's changes here, adopting new optimistic json parser
|
||||
self.current_function_arguments = []
|
||||
self.current_function_arguments = ""
|
||||
self.optimistic_json_parser = OptimisticJSONParser()
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
def stream_start(self):
|
||||
"""Initialize streaming by activating the generator and clearing any old chunks."""
|
||||
self.streaming_chat_completion_mode_function_name = None
|
||||
self.current_function_arguments = []
|
||||
self.current_function_arguments = ""
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
if not self._active:
|
||||
@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
def stream_end(self):
|
||||
"""Clean up the stream by deactivating and clearing chunks."""
|
||||
self.streaming_chat_completion_mode_function_name = None
|
||||
self.current_function_arguments = []
|
||||
self.current_function_arguments = ""
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# early exit to turn into content mode
|
||||
return None
|
||||
if tool_call.function.arguments:
|
||||
self.current_function_arguments.append(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
|
||||
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
|
||||
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
|
||||
# Strip out any extras tokens
|
||||
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
|
||||
combined_args = "".join(self.current_function_arguments)
|
||||
parsed_args = self.optimistic_json_parser.parse(combined_args)
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# updates_inner_thoughts = ""
|
||||
# else: # OpenAI
|
||||
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
self.current_function_arguments.append(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: THIS IS HORRIBLE
|
||||
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
|
||||
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
|
||||
combined_args = "".join(self.current_function_arguments)
|
||||
parsed_args = self.optimistic_json_parser.parse(combined_args)
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
|
@ -1,7 +1,43 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import from_json
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OptimisticJSONParser:
|
||||
class JSONParser(ABC):
|
||||
@abstractmethod
|
||||
def parse(self, input_str: str) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PydanticJSONParser(JSONParser):
|
||||
"""
|
||||
https://docs.pydantic.dev/latest/concepts/json/#json-parsing
|
||||
If `strict` is True, we will not allow for partial parsing of JSON.
|
||||
|
||||
Compared with `OptimisticJSONParser`, this parser is more strict.
|
||||
Note: This will not partially parse strings, which may decrease parsing speed for message strings.
|
||||
"""
|
||||
|
||||
def __init__(self, strict=False):
|
||||
self.strict = strict
|
||||
|
||||
def parse(self, input_str: str) -> Any:
|
||||
if not input_str:
|
||||
return {}
|
||||
try:
|
||||
return from_json(input_str, allow_partial="trailing-strings" if not self.strict else False)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse JSON: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class OptimisticJSONParser(JSONParser):
|
||||
"""
|
||||
A JSON parser that attempts to parse a given string using `json.loads`,
|
||||
and if that fails, it parses as much valid JSON as possible while
|
||||
@ -13,25 +49,25 @@ class OptimisticJSONParser:
|
||||
def __init__(self, strict=False):
|
||||
self.strict = strict
|
||||
self.parsers = {
|
||||
" ": self.parse_space,
|
||||
"\r": self.parse_space,
|
||||
"\n": self.parse_space,
|
||||
"\t": self.parse_space,
|
||||
"[": self.parse_array,
|
||||
"{": self.parse_object,
|
||||
'"': self.parse_string,
|
||||
"t": self.parse_true,
|
||||
"f": self.parse_false,
|
||||
"n": self.parse_null,
|
||||
" ": self._parse_space,
|
||||
"\r": self._parse_space,
|
||||
"\n": self._parse_space,
|
||||
"\t": self._parse_space,
|
||||
"[": self._parse_array,
|
||||
"{": self._parse_object,
|
||||
'"': self._parse_string,
|
||||
"t": self._parse_true,
|
||||
"f": self._parse_false,
|
||||
"n": self._parse_null,
|
||||
}
|
||||
# Register number parser for digits and signs
|
||||
for char in "0123456789.-":
|
||||
self.parsers[char] = self.parse_number
|
||||
|
||||
self.last_parse_reminding = None
|
||||
self.on_extra_token = self.default_on_extra_token
|
||||
self.on_extra_token = self._default_on_extra_token
|
||||
|
||||
def default_on_extra_token(self, text, data, reminding):
|
||||
def _default_on_extra_token(self, text, data, reminding):
|
||||
print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}")
|
||||
|
||||
def parse(self, input_str):
|
||||
@ -45,7 +81,7 @@ class OptimisticJSONParser:
|
||||
try:
|
||||
return json.loads(input_str)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
data, reminding = self.parse_any(input_str, decode_error)
|
||||
data, reminding = self._parse_any(input_str, decode_error)
|
||||
self.last_parse_reminding = reminding
|
||||
if self.on_extra_token and reminding:
|
||||
self.on_extra_token(input_str, data, reminding)
|
||||
@ -53,7 +89,7 @@ class OptimisticJSONParser:
|
||||
else:
|
||||
return json.loads("{}")
|
||||
|
||||
def parse_any(self, input_str, decode_error):
|
||||
def _parse_any(self, input_str, decode_error):
|
||||
"""Determine which parser to use based on the first character."""
|
||||
if not input_str:
|
||||
raise decode_error
|
||||
@ -62,11 +98,11 @@ class OptimisticJSONParser:
|
||||
raise decode_error
|
||||
return parser(input_str, decode_error)
|
||||
|
||||
def parse_space(self, input_str, decode_error):
|
||||
def _parse_space(self, input_str, decode_error):
|
||||
"""Strip leading whitespace and parse again."""
|
||||
return self.parse_any(input_str.strip(), decode_error)
|
||||
return self._parse_any(input_str.strip(), decode_error)
|
||||
|
||||
def parse_array(self, input_str, decode_error):
|
||||
def _parse_array(self, input_str, decode_error):
|
||||
"""Parse a JSON array, returning the list and remaining string."""
|
||||
# Skip the '['
|
||||
input_str = input_str[1:]
|
||||
@ -77,7 +113,7 @@ class OptimisticJSONParser:
|
||||
# Skip the ']'
|
||||
input_str = input_str[1:]
|
||||
break
|
||||
value, input_str = self.parse_any(input_str, decode_error)
|
||||
value, input_str = self._parse_any(input_str, decode_error)
|
||||
array_values.append(value)
|
||||
input_str = input_str.strip()
|
||||
if input_str.startswith(","):
|
||||
@ -85,7 +121,7 @@ class OptimisticJSONParser:
|
||||
input_str = input_str[1:].strip()
|
||||
return array_values, input_str
|
||||
|
||||
def parse_object(self, input_str, decode_error):
|
||||
def _parse_object(self, input_str, decode_error):
|
||||
"""Parse a JSON object, returning the dict and remaining string."""
|
||||
# Skip the '{'
|
||||
input_str = input_str[1:]
|
||||
@ -96,7 +132,7 @@ class OptimisticJSONParser:
|
||||
# Skip the '}'
|
||||
input_str = input_str[1:]
|
||||
break
|
||||
key, input_str = self.parse_any(input_str, decode_error)
|
||||
key, input_str = self._parse_any(input_str, decode_error)
|
||||
input_str = input_str.strip()
|
||||
|
||||
if not input_str or input_str[0] == "}":
|
||||
@ -113,7 +149,7 @@ class OptimisticJSONParser:
|
||||
input_str = input_str[1:]
|
||||
break
|
||||
|
||||
value, input_str = self.parse_any(input_str, decode_error)
|
||||
value, input_str = self._parse_any(input_str, decode_error)
|
||||
obj[key] = value
|
||||
input_str = input_str.strip()
|
||||
if input_str.startswith(","):
|
||||
@ -121,7 +157,7 @@ class OptimisticJSONParser:
|
||||
input_str = input_str[1:].strip()
|
||||
return obj, input_str
|
||||
|
||||
def parse_string(self, input_str, decode_error):
|
||||
def _parse_string(self, input_str, decode_error):
|
||||
"""Parse a JSON string, respecting escaped quotes if present."""
|
||||
end = input_str.find('"', 1)
|
||||
while end != -1 and input_str[end - 1] == "\\":
|
||||
@ -166,19 +202,19 @@ class OptimisticJSONParser:
|
||||
|
||||
return num, remainder
|
||||
|
||||
def parse_true(self, input_str, decode_error):
|
||||
def _parse_true(self, input_str, decode_error):
|
||||
"""Parse a 'true' value."""
|
||||
if input_str.startswith(("t", "T")):
|
||||
return True, input_str[4:]
|
||||
raise decode_error
|
||||
|
||||
def parse_false(self, input_str, decode_error):
|
||||
def _parse_false(self, input_str, decode_error):
|
||||
"""Parse a 'false' value."""
|
||||
if input_str.startswith(("f", "F")):
|
||||
return False, input_str[5:]
|
||||
raise decode_error
|
||||
|
||||
def parse_null(self, input_str, decode_error):
|
||||
def _parse_null(self, input_str, decode_error):
|
||||
"""Parse a 'null' value."""
|
||||
if input_str.startswith("n"):
|
||||
return None, input_str[4:]
|
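Both parsers above exist to make sense of incomplete JSON while tool-call arguments are still streaming in. Here is a small demonstration of the partial-parsing behavior that `PydanticJSONParser` delegates to `pydantic_core.from_json`; the input string is an invented fragment of a streamed `send_message` call, not taken from this diff.
```python
from pydantic_core import from_json

chunk = '{"message": "Hello the'  # an incomplete tool-call argument stream

# by default, a dangling string value is dropped from the partial result
print(from_json(chunk, allow_partial=True))
# -> {}

# "trailing-strings" keeps the incomplete string, which suits the non-strict streaming case
print(from_json(chunk, allow_partial="trailing-strings"))
# -> {'message': 'Hello the'}
```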
@ -678,7 +678,7 @@ async def send_message_streaming(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
|
||||
|
||||
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
|
||||
def list_llm_models(
|
||||
byok_only: Optional[bool] = Query(None),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
|
||||
models = server.list_llm_models()
|
||||
models = server.list_llm_models(byok_only=byok_only)
|
||||
# print(models)
|
||||
return models
|
||||
|
||||
|
@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])
|
||||
|
||||
@router.get("/", response_model=List[Provider], operation_id="list_providers")
|
||||
def list_providers(
|
||||
name: Optional[str] = Query(None),
|
||||
provider_type: Optional[ProviderType] = Query(None),
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@ -23,7 +26,7 @@ def list_providers(
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
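The two REST changes above add a `byok_only` filter to the models listing and `name`/`provider_type` filters to the providers listing. A hedged client-side sketch using `requests`, assuming a local Letta server on its default port with these routers mounted under the usual `/v1` prefix (the URL details are assumptions, not taken from this diff):
```python
import requests

BASE = "http://localhost:8283/v1"  # assumed local server address

# only models served by bring-your-own-key providers stored in the database
byok_models = requests.get(f"{BASE}/models/", params={"byok_only": True}).json()

# filter stored providers by name and type instead of listing everything
providers = requests.get(
    f"{BASE}/providers/",
    params={"name": "my-anthropic-byok", "provider_type": "anthropic"},
).json()

print(len(byok_models), len(providers))
```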
@ -54,8 +54,6 @@ async def create_voice_chat_completions(
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
message_buffer_limit=8,
|
||||
message_buffer_min=4,
|
||||
)
|
||||
|
||||
# Return the streaming generator
|
||||
|
@ -16,6 +16,7 @@ from pydantic import BaseModel
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
||||
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
@ -143,27 +144,15 @@ def log_error_to_sentry(e):
|
||||
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
|
||||
"""
|
||||
Converts a user input message into the internal structured format.
|
||||
"""
|
||||
new_messages = []
|
||||
for input_message in input_messages:
|
||||
# Construct the Message object
|
||||
new_message = Message(
|
||||
id=f"message-{uuid.uuid4()}",
|
||||
role=input_message.role,
|
||||
content=input_message.content,
|
||||
name=input_message.name,
|
||||
otid=input_message.otid,
|
||||
sender_id=input_message.sender_id,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
created_at=get_utc_time(),
|
||||
)
|
||||
new_messages.append(new_message)
|
||||
|
||||
return new_messages
|
||||
TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`,
|
||||
we should unify this when it's clear what message attributes we need.
|
||||
"""
|
||||
|
||||
messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False)
|
||||
for message in messages:
|
||||
message.organization_id = actor.organization_id
|
||||
return messages
|
||||
|
||||
|
||||
def create_letta_messages_from_llm_response(
|
||||
|
@ -268,10 +268,11 @@ class SyncServer(Server):
|
||||
)
|
||||
|
||||
# collect providers (always has Letta as a default)
|
||||
self._enabled_providers: List[Provider] = [LettaProvider()]
|
||||
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
|
||||
if model_settings.openai_api_key:
|
||||
self._enabled_providers.append(
|
||||
OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
@ -279,12 +280,14 @@ class SyncServer(Server):
|
||||
if model_settings.anthropic_api_key:
|
||||
self._enabled_providers.append(
|
||||
AnthropicProvider(
|
||||
name="anthropic",
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
)
|
||||
)
|
||||
if model_settings.ollama_base_url:
|
||||
self._enabled_providers.append(
|
||||
OllamaProvider(
|
||||
name="ollama",
|
||||
base_url=model_settings.ollama_base_url,
|
||||
api_key=None,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
@ -293,12 +296,14 @@ class SyncServer(Server):
|
||||
if model_settings.gemini_api_key:
|
||||
self._enabled_providers.append(
|
||||
GoogleAIProvider(
|
||||
name="google_ai",
|
||||
api_key=model_settings.gemini_api_key,
|
||||
)
|
||||
)
|
||||
if model_settings.google_cloud_location and model_settings.google_cloud_project:
|
||||
self._enabled_providers.append(
|
||||
GoogleVertexProvider(
|
||||
name="google_vertex",
|
||||
google_cloud_project=model_settings.google_cloud_project,
|
||||
google_cloud_location=model_settings.google_cloud_location,
|
||||
)
|
||||
@ -307,6 +312,7 @@ class SyncServer(Server):
|
||||
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
|
||||
self._enabled_providers.append(
|
||||
AzureProvider(
|
||||
name="azure",
|
||||
api_key=model_settings.azure_api_key,
|
||||
base_url=model_settings.azure_base_url,
|
||||
api_version=model_settings.azure_api_version,
|
||||
@ -315,12 +321,14 @@ class SyncServer(Server):
|
||||
if model_settings.groq_api_key:
|
||||
self._enabled_providers.append(
|
||||
GroqProvider(
|
||||
name="groq",
|
||||
api_key=model_settings.groq_api_key,
|
||||
)
|
||||
)
|
||||
if model_settings.together_api_key:
|
||||
self._enabled_providers.append(
|
||||
TogetherProvider(
|
||||
name="together",
|
||||
api_key=model_settings.together_api_key,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
@ -329,6 +337,7 @@ class SyncServer(Server):
|
||||
# vLLM exposes both a /chat/completions and a /completions endpoint
|
||||
self._enabled_providers.append(
|
||||
VLLMCompletionsProvider(
|
||||
name="vllm",
|
||||
base_url=model_settings.vllm_api_base,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
@ -338,12 +347,14 @@ class SyncServer(Server):
|
||||
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
||||
self._enabled_providers.append(
|
||||
VLLMChatCompletionsProvider(
|
||||
name="vllm",
|
||||
base_url=model_settings.vllm_api_base,
|
||||
)
|
||||
)
|
||||
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
|
||||
self._enabled_providers.append(
|
||||
AnthropicBedrockProvider(
|
||||
name="bedrock",
|
||||
aws_region=model_settings.aws_region,
|
||||
)
|
||||
)
|
||||
@ -355,11 +366,11 @@ class SyncServer(Server):
|
||||
if model_settings.lmstudio_base_url.endswith("/v1")
|
||||
else model_settings.lmstudio_base_url + "/v1"
|
||||
)
|
||||
self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url))
|
||||
self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
|
||||
if model_settings.deepseek_api_key:
|
||||
self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
|
||||
self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
|
||||
if model_settings.xai_api_key:
|
||||
self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))
|
||||
self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
|
||||
|
||||
# For MCP
|
||||
"""Initialize the MCP clients (there may be multiple)"""
|
||||
@ -862,6 +873,8 @@ class SyncServer(Server):
|
||||
agent_ids=[voice_sleeptime_agent.id],
|
||||
manager_config=VoiceSleeptimeManager(
|
||||
manager_agent_id=main_agent.id,
|
||||
max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
|
||||
min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
@ -1182,10 +1195,10 @@ class SyncServer(Server):
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
|
||||
"""List available models"""
|
||||
llm_models = []
|
||||
for provider in self.get_enabled_providers():
|
||||
for provider in self.get_enabled_providers(byok_only=byok_only):
|
||||
try:
|
||||
llm_models.extend(provider.list_llm_models())
|
||||
except Exception as e:
|
||||
@ -1205,11 +1218,12 @@ class SyncServer(Server):
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
def get_enabled_providers(self):
|
||||
def get_enabled_providers(self, byok_only: bool = False):
|
||||
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
|
||||
if byok_only:
|
||||
return list(providers_from_db.values())
|
||||
providers_from_env = {p.name: p for p in self._enabled_providers}
|
||||
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
|
||||
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
|
||||
return {**providers_from_env, **providers_from_db}.values()
|
||||
return list(providers_from_env.values()) + list(providers_from_db.values())
|
||||
|
||||
@trace_method
|
||||
def get_llm_config_from_handle(
|
||||
@ -1294,7 +1308,7 @@ class SyncServer(Server):
|
||||
return embedding_config
|
||||
|
||||
def get_provider_from_name(self, provider_name: str) -> Provider:
|
||||
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
|
||||
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
|
||||
if not providers:
|
||||
raise ValueError(f"Provider {provider_name} is not supported")
|
||||
elif len(providers) > 1:
|
||||
|
@ -80,6 +80,12 @@ class GroupManager:
|
||||
case ManagerType.voice_sleeptime:
|
||||
new_group.manager_type = ManagerType.voice_sleeptime
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
max_message_buffer_length = group.manager_config.max_message_buffer_length
|
||||
min_message_buffer_length = group.manager_config.min_message_buffer_length
|
||||
# Safety check for buffer length range
|
||||
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
|
||||
new_group.max_message_buffer_length = max_message_buffer_length
|
||||
new_group.min_message_buffer_length = min_message_buffer_length
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
|
||||
@ -97,6 +103,8 @@ class GroupManager:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
sleeptime_agent_frequency = None
|
||||
max_message_buffer_length = None
|
||||
min_message_buffer_length = None
|
||||
max_turns = None
|
||||
termination_token = None
|
||||
manager_agent_id = None
|
||||
@ -117,11 +125,24 @@ class GroupManager:
|
||||
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
|
||||
if sleeptime_agent_frequency and group.turns_counter is None:
|
||||
group.turns_counter = -1
|
||||
case ManagerType.voice_sleeptime:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length
|
||||
min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length
|
||||
if sleeptime_agent_frequency and group.turns_counter is None:
|
||||
group.turns_counter = -1
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
|
||||
|
||||
# Safety check for buffer length range
|
||||
self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length)
|
||||
|
||||
if sleeptime_agent_frequency:
|
||||
group.sleeptime_agent_frequency = sleeptime_agent_frequency
|
||||
if max_message_buffer_length:
|
||||
group.max_message_buffer_length = max_message_buffer_length
|
||||
if min_message_buffer_length:
|
||||
group.min_message_buffer_length = min_message_buffer_length
|
||||
if max_turns:
|
||||
group.max_turns = max_turns
|
||||
if termination_token:
|
||||
@ -274,3 +295,40 @@ class GroupManager:
|
||||
if manager_agent:
|
||||
for block in blocks:
|
||||
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
|
||||
|
||||
@staticmethod
|
||||
def ensure_buffer_length_range_valid(
|
||||
max_value: Optional[int],
|
||||
min_value: Optional[int],
|
||||
max_name: str = "max_message_buffer_length",
|
||||
min_name: str = "min_message_buffer_length",
|
||||
) -> None:
|
||||
"""
|
||||
1) Both-or-none: if one is set, the other must be set.
|
||||
2) Both must be ints > 4.
|
||||
3) max_value must be strictly greater than min_value.
|
||||
"""
|
||||
# 1) require both-or-none
|
||||
if (max_value is None) != (min_value is None):
|
||||
raise ValueError(
|
||||
f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})"
|
||||
)
|
||||
|
||||
# no further checks if neither is provided
|
||||
if max_value is None:
|
||||
return
|
||||
|
||||
# 2) type & lower‐bound checks
|
||||
if not isinstance(max_value, int) or not isinstance(min_value, int):
|
||||
raise ValueError(
|
||||
f"Both '{max_name}' and '{min_name}' must be integers "
|
||||
f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})"
|
||||
)
|
||||
if max_value <= 4 or min_value <= 4:
|
||||
raise ValueError(
|
||||
f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})"
|
||||
)
|
||||
|
||||
# 3) ordering
|
||||
if max_value <= min_value:
|
||||
raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")
|
||||
|
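`ensure_buffer_length_range_valid` enforces the three rules spelled out in its docstring: the bounds are both-or-none, both must be integers greater than 4, and the maximum must be strictly larger than the minimum. A standalone sketch of the same checks (not the Letta class itself) makes the accepted and rejected combinations concrete:
```python
from typing import Optional


def check_buffer_range(max_value: Optional[int], min_value: Optional[int]) -> None:
    # 1) both-or-none
    if (max_value is None) != (min_value is None):
        raise ValueError("max and min buffer lengths must be provided together")
    if max_value is None:
        return  # neither set: feature not configured, nothing to validate
    # 2) integers strictly greater than 4
    if not isinstance(max_value, int) or not isinstance(min_value, int):
        raise ValueError("buffer lengths must be integers")
    if max_value <= 4 or min_value <= 4:
        raise ValueError("buffer lengths must be greater than 4")
    # 3) ordering
    if max_value <= min_value:
        raise ValueError("max buffer length must be strictly greater than min")


check_buffer_range(None, None)  # ok: not configured
check_buffer_range(20, 8)       # ok: an example of a valid pair
# check_buffer_range(8, 20)     # would raise: max must exceed min
```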
@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from letta.orm.provider import Provider as ProviderModel
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.providers import ProviderUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@ -18,6 +19,9 @@ class ProviderManager:
|
||||
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
|
||||
"""Create a new provider if it doesn't already exist."""
|
||||
with self.session_maker() as session:
|
||||
if provider.name == provider.provider_type.value:
|
||||
raise ValueError("Provider name must be unique and different from provider type")
|
||||
|
||||
# Assign the organization id based on the actor
|
||||
provider.organization_id = actor.organization_id
|
||||
|
||||
@ -59,29 +63,36 @@ class ProviderManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
|
||||
def list_providers(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> List[PydanticProvider]:
|
||||
"""List all providers with optional pagination."""
|
||||
filter_kwargs = {}
|
||||
if name:
|
||||
filter_kwargs["name"] = name
|
||||
if provider_type:
|
||||
filter_kwargs["provider_type"] = provider_type
|
||||
with self.session_maker() as session:
|
||||
providers = ProviderModel.list(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
**filter_kwargs,
|
||||
)
|
||||
return [provider.to_pydantic() for provider in providers]
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_override_provider_id(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
|
||||
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||
if len(anthropic_provider) != 0:
|
||||
return anthropic_provider[0].id
|
||||
return None
|
||||
def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
|
||||
providers = self.list_providers(name=provider_name)
|
||||
return providers[0].id if providers else None
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_override_key(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||
if len(anthropic_provider) != 0:
|
||||
return anthropic_provider[0].api_key
|
||||
return None
|
||||
def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
|
||||
providers = self.list_providers(name=provider_name)
|
||||
return providers[0].api_key if providers else None
|
||||
|
@ -4,6 +4,7 @@ import traceback
from typing import List, Tuple

from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
@ -77,7 +78,7 @@ class Summarizer:

        logger.info("Buffer length hit, evicting messages.")

        target_trim_index = len(all_in_context_messages) - self.message_buffer_min + 1
        target_trim_index = len(all_in_context_messages) - self.message_buffer_min

        while target_trim_index < len(all_in_context_messages) and all_in_context_messages[target_trim_index].role != MessageRole.user:
            target_trim_index += 1
@ -112,11 +113,12 @@ class Summarizer:
        summary_request_text = f"""You’re a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost.

(Older) Evicted Messages:\n
{evicted_messages_str}
{evicted_messages_str}\n

(Newer) In-Context Messages:\n
{in_context_messages_str}
"""
        print(summary_request_text)
        # Fire-and-forget the summarization task
        self.fire_and_forget(
            self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
@ -149,6 +151,9 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->

        # 1) Try plain content
        if msg.content:
            # Skip tool messages where the name is "send_message"
            if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL:
                continue
            text = "".join(c.text for c in msg.content).strip()

        # 2) Otherwise, try extracting from function calls
@ -156,12 +161,15 @@ def format_transcript(messages: List[Message], include_system: bool = False) ->
            parts = []
            for call in msg.tool_calls:
                args_str = call.function.arguments
                if call.function.name == DEFAULT_MESSAGE_TOOL:
                    try:
                        args = json.loads(args_str)
                        # pull out a "message" field if present
                        parts.append(args.get("message", args_str))
                        parts.append(args.get(DEFAULT_MESSAGE_TOOL_KWARG, args_str))
                    except json.JSONDecodeError:
                        parts.append(args_str)
                else:
                    parts.append(args_str)
            text = " ".join(parts).strip()

        else:
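The off-by-one fix above changes which message becomes the first survivor of the static buffer. Here is a self-contained sketch of the trimming rule, using a plain dataclass as a stand-in for Letta's `Message` type; it is an illustration, not the library code.

```
# Standalone illustration of the eviction-index rule; stand-in types only.
from dataclasses import dataclass

@dataclass
class Msg:
    role: str  # "system", "user", or "assistant"
    text: str

def find_trim_index(messages: list[Msg], buffer_min: int) -> int:
    # Keep roughly the last `buffer_min` messages (was `- buffer_min + 1` before the fix),
    # then advance until the kept window starts on a user message.
    target = len(messages) - buffer_min
    while target < len(messages) and messages[target].role != "user":
        target += 1
    return target

msgs = [Msg("system", "sys")] + [Msg("user" if i % 2 == 0 else "assistant", f"m{i}") for i in range(10)]
idx = find_trim_index(msgs, buffer_min=4)
print(f"evict {idx - 1} messages, keep {len(msgs) - idx}")  # the system message is preserved separately
```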
@ -100,7 +100,7 @@ class ToolExecutionManager:
        try:
            executor = ToolExecutorFactory.get_executor(tool.tool_type)
            # TODO: Extend this async model to composio
            if isinstance(executor, SandboxToolExecutor):
            if isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor)):
                result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
            else:
                result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
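With `ExternalComposioToolExecutor.execute` becoming a coroutine (next hunk), the manager has to await some executors and call the rest synchronously. A small sketch of that dispatch pattern, using stand-in classes rather than Letta's actual executor types:

```
# Illustrative only: SandboxToolExecutor / ExternalComposioToolExecutor are replaced by stand-ins.
import asyncio

class SyncExecutor:
    def execute(self, name: str, args: dict) -> str:
        return f"sync:{name}"

class AsyncExecutor:
    async def execute(self, name: str, args: dict) -> str:
        return f"async:{name}"

async def run_tool(executor, name: str, args: dict) -> str:
    # Mirrors: isinstance(executor, (SandboxToolExecutor, ExternalComposioToolExecutor))
    if isinstance(executor, AsyncExecutor):
        return await executor.execute(name, args)
    return executor.execute(name, args)

print(asyncio.run(run_tool(AsyncExecutor(), "get_emojis", {})))  # async:get_emojis
print(asyncio.run(run_tool(SyncExecutor(), "get_emojis", {})))   # sync:get_emojis
```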
@ -5,7 +5,7 @@ from typing import Any, Dict, Optional

from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, CORE_MEMORY_LINE_NUMBER_WARNING, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState
@ -486,7 +486,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
class ExternalComposioToolExecutor(ToolExecutor):
    """Executor for external Composio tools."""

    def execute(
    async def execute(
        self,
        function_name: str,
        function_args: dict,
@ -505,7 +505,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
        composio_api_key = get_composio_api_key(actor=actor)

        # TODO (matt): Roll in execute_composio_action into this class
        function_response = execute_composio_action(
        function_response = await execute_composio_action_async(
            action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
        )

poetry.lock (generated)
@ -1016,25 +1016,6 @@ e2b = ["e2b (>=0.17.2a37,<1.1.0)", "e2b-code-interpreter"]
flyio = ["gql", "requests_toolbelt"]
tools = ["diskcache", "flake8", "networkx", "pathspec", "pygments", "ruff", "transformers"]

[[package]]
name = "composio-langchain"
version = "0.7.15"
description = "Use Composio to get an array of tools with your LangChain agent."
optional = false
python-versions = "<4,>=3.9"
groups = ["main"]
files = [
    {file = "composio_langchain-0.7.15-py3-none-any.whl", hash = "sha256:a71b5371ad6c3ee4d4289c7a994fad1424e24c29a38e820b6b2ed259056abb65"},
    {file = "composio_langchain-0.7.15.tar.gz", hash = "sha256:cb75c460289ecdf9590caf7ddc0d7888b0a6622ca4f800c9358abe90c25d055e"},
]

[package.dependencies]
composio_core = ">=0.7.0,<0.8.0"
langchain = ">=0.1.0"
langchain-openai = ">=0.0.2.post1"
langchainhub = ">=0.1.15"
pydantic = ">=2.6.4"

[[package]]
name = "configargparse"
version = "1.7"
@ -2842,9 +2823,10 @@ files = [
name = "jsonpatch"
version = "1.33"
description = "Apply JSON-Patches (RFC 6902)"
optional = false
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
    {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
@ -2857,9 +2839,10 @@ jsonpointer = ">=1.9"
name = "jsonpointer"
version = "3.0.0"
description = "Identify specific nodes in a JSON document (RFC 6901)"
optional = false
optional = true
python-versions = ">=3.7"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942"},
    {file = "jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef"},
@ -3052,9 +3035,10 @@ files = [
name = "langchain"
version = "0.3.23"
description = "Building applications with LLMs through composability"
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "langchain-0.3.23-py3-none-any.whl", hash = "sha256:084f05ee7e80b7c3f378ebadd7309f2a37868ce2906fa0ae64365a67843ade3d"},
    {file = "langchain-0.3.23.tar.gz", hash = "sha256:d95004afe8abebb52d51d6026270248da3f4b53d93e9bf699f76005e0c83ad34"},
@ -3120,9 +3104,10 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10"
name = "langchain-core"
version = "0.3.51"
description = "Building applications with LLMs through composability"
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "langchain_core-0.3.51-py3-none-any.whl", hash = "sha256:4bd71e8acd45362aa428953f2a91d8162318014544a2216e4b769463caf68e13"},
    {file = "langchain_core-0.3.51.tar.gz", hash = "sha256:db76b9cc331411602cb40ba0469a161febe7a0663fbcaddbc9056046ac2d22f4"},
@ -3140,30 +3125,14 @@ PyYAML = ">=5.3"
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0"
typing-extensions = ">=4.7"

[[package]]
name = "langchain-openai"
version = "0.3.12"
description = "An integration package connecting OpenAI and LangChain"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
    {file = "langchain_openai-0.3.12-py3-none-any.whl", hash = "sha256:0fab64d58ec95e65ffbaf659470cd362e815685e15edbcb171641e90eca4eb86"},
    {file = "langchain_openai-0.3.12.tar.gz", hash = "sha256:c9dbff63551f6bd91913bca9f99a2d057fd95dc58d4778657d67e5baa1737f61"},
]

[package.dependencies]
langchain-core = ">=0.3.49,<1.0.0"
openai = ">=1.68.2,<2.0.0"
tiktoken = ">=0.7,<1"

[[package]]
name = "langchain-text-splitters"
version = "0.3.8"
description = "LangChain text splitting utilities"
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02"},
    {file = "langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e"},
@ -3172,30 +3141,14 @@ files = [
[package.dependencies]
langchain-core = ">=0.3.51,<1.0.0"

[[package]]
name = "langchainhub"
version = "0.1.21"
description = "The LangChain Hub API client"
optional = false
python-versions = "<4.0,>=3.8.1"
groups = ["main"]
files = [
    {file = "langchainhub-0.1.21-py3-none-any.whl", hash = "sha256:1cc002dc31e0d132a776afd044361e2b698743df5202618cf2bad399246b895f"},
    {file = "langchainhub-0.1.21.tar.gz", hash = "sha256:723383b3964a47dbaea6ad5d0ef728accefbc9d2c07480e800bdec43510a8c10"},
]

[package.dependencies]
packaging = ">=23.2,<25"
requests = ">=2,<3"
types-requests = ">=2.31.0.2,<3.0.0.0"

[[package]]
name = "langsmith"
version = "0.3.28"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
optional = true
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "langsmith-0.3.28-py3-none-any.whl", hash = "sha256:54ac8815514af52d9c801ad7970086693667e266bf1db90fc453c1759e8407cd"},
    {file = "langsmith-0.3.28.tar.gz", hash = "sha256:4666595207131d7f8d83418e54dc86c05e28562e5c997633e7c33fc18f9aeb89"},
@ -3221,14 +3174,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]

[[package]]
name = "letta-client"
version = "0.1.124"
version = "0.1.129"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
groups = ["main"]
files = [
    {file = "letta_client-0.1.124-py3-none-any.whl", hash = "sha256:a7901437ef91f395cd85d24c0312046b7c82e5a4dd8e04de0d39b5ca085c65d3"},
    {file = "letta_client-0.1.124.tar.gz", hash = "sha256:e8b5716930824cc98c62ee01343e358f88619d346578d48a466277bc8282036d"},
    {file = "letta_client-0.1.129-py3-none-any.whl", hash = "sha256:87a5fc32471e5b9fefbfc1e1337fd667d5e2e340ece5d2a6c782afbceab4bf36"},
    {file = "letta_client-0.1.129.tar.gz", hash = "sha256:b00f611c18a2ad802ec9265f384e1666938c5fc5c86364b2c410d72f0331d597"},
]

[package.dependencies]
@ -4366,10 +4319,10 @@ files = [
name = "orjson"
version = "3.10.16"
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
optional = false
optional = true
python-versions = ">=3.9"
groups = ["main"]
markers = "platform_python_implementation != \"PyPy\""
markers = "platform_python_implementation != \"PyPy\" and (extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\")"
files = [
    {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
    {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@ -6069,9 +6022,10 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
name = "requests-toolbelt"
version = "1.0.0"
description = "A utility belt for advanced users of python-requests"
optional = false
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
    {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"},
@ -6855,21 +6809,6 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]

[[package]]
name = "types-requests"
version = "2.32.0.20250328"
description = "Typing stubs for requests"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "types_requests-2.32.0.20250328-py3-none-any.whl", hash = "sha256:72ff80f84b15eb3aa7a8e2625fffb6a93f2ad5a0c20215fc1dcfa61117bcb2a2"},
    {file = "types_requests-2.32.0.20250328.tar.gz", hash = "sha256:c9e67228ea103bd811c96984fac36ed2ae8da87a36a633964a21f199d60baf32"},
]

[package.dependencies]
urllib3 = ">=2"

[[package]]
name = "typing-extensions"
version = "4.13.2"
@ -7438,9 +7377,10 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
name = "zstandard"
version = "0.23.0"
description = "Zstandard bindings for Python"
optional = false
optional = true
python-versions = ">=3.8"
groups = ["main"]
markers = "extra == \"external-tools\" or extra == \"desktop\" or extra == \"all\""
files = [
    {file = "zstandard-0.23.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9"},
    {file = "zstandard-0.23.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880"},
@ -7563,4 +7503,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.1"
python-versions = "<3.14,>=3.10"
content-hash = "75c1c949aa6c0ef8d681bddd91999f97ed4991451be93ca45bf9c01dd19d8a8a"
content-hash = "ba9cf0e00af2d5542aa4beecbd727af92b77ba584033f05c222b00ae47f96585"

@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.7"
version = "0.7.8"
packages = [
    {include = "letta"},
]
@ -56,7 +56,6 @@ nltk = "^3.8.1"
jinja2 = "^3.1.5"
locust = {version = "^2.31.5", optional = true}
wikipedia = {version = "^1.4.0", optional = true}
composio-langchain = "^0.7.7"
composio-core = "^0.7.7"
alembic = "^1.13.3"
pyhumps = "^3.8.0"
@ -74,7 +73,7 @@ llama-index = "^0.12.2"
llama-index-embeddings-openai = "^0.3.1"
e2b-code-interpreter = {version = "^1.0.3", optional = true}
anthropic = "^0.49.0"
letta_client = "^0.1.124"
letta_client = "^0.1.127"
openai = "^1.60.0"
opentelemetry-api = "1.30.0"
opentelemetry-sdk = "1.30.0"

@ -1,7 +1,7 @@
{
    "context_window": 8192,
    "model_endpoint_type": "openai",
    "model_endpoint": "https://inference.memgpt.ai",
    "model_endpoint": "https://inference.letta.com",
    "model": "memgpt-openai",
    "embedding_endpoint_type": "hugging-face",
    "embedding_endpoint": "https://embeddings.memgpt.ai",

@ -1,7 +1,7 @@
{
    "context_window": 8192,
    "model_endpoint_type": "openai",
    "model_endpoint": "https://inference.memgpt.ai",
    "model_endpoint": "https://inference.letta.com",
    "model": "memgpt-openai",
    "put_inner_thoughts_in_kwargs": true
}

@ -105,7 +105,9 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
    agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)

    llm_client = LLMClient.create(
        provider=agent_state.llm_config.model_endpoint_type,
        provider_name=agent_state.llm_config.provider_name,
        provider_type=agent_state.llm_config.model_endpoint_type,
        actor_id=client.user.id,
    )
    if llm_client:
        response = llm_client.send_llm_request(
@ -179,7 +181,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:

    Note: This is acting on the Letta response, note the usage of `user_message`
    """
    from composio_langchain import Action
    from composio import Action

    # Set up client
    client = create_client()
@ -56,7 +56,7 @@ def test_add_composio_tool(fastapi_client):
    assert "name" in response.json()


def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user):
async def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis, server: SyncServer, default_user):
    agent_state = server.agent_manager.create_agent(
        agent_create=CreateAgent(
            name="sarah_agent",
@ -67,7 +67,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis
        actor=default_user,
    )

    tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
    tool_execution_result = await ToolExecutionManager(agent_state, actor=default_user).execute_tool(
        function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
    )

@ -1,26 +1,26 @@
import os
import threading
from unittest.mock import MagicMock

import pytest
from dotenv import load_dotenv
from letta_client import Letta
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
from sqlalchemy import delete

from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentType, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.group import ManagerType
from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
from letta.schemas.tool import ToolCreate
@ -29,6 +29,8 @@ from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import get_persona_text
@ -48,16 +50,24 @@ MESSAGE_TRANSCRIPTS = [
    "user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.",
    "assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.",
    "user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.",
    "user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.",
    "assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.",
    "user: Yes, let’s do that.",
    "assistant: I’ll put together a day-by-day plan now.",
]

SUMMARY_REQ_TEXT = """
Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity:
SYSTEM_MESSAGE = Message(role=MessageRole.system, content=[TextContent(text="System message")])
MESSAGE_OBJECTS = [SYSTEM_MESSAGE]
for entry in MESSAGE_TRANSCRIPTS:
    role_str, text = entry.split(":", 1)
    role = MessageRole.user if role_str.strip() == "user" else MessageRole.assistant
    MESSAGE_OBJECTS.append(Message(role=role, content=[TextContent(text=text.strip())]))
MESSAGE_EVICT_BREAKPOINT = 14

SUMMARY_REQ_TEXT = """
You’re a memory-recall helper for an AI that can only keep the last 4 messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost.

(Older) Evicted Messages:

(Older)
0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month.
1. assistant: That sounds amazing! Do you have any particular cities or sights in mind?
2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.
@ -70,16 +80,13 @@ Here is the conversation history. Lines marked (Older) are about to be evicted;
9. assistant: Happy early birthday! Would you like gift ideas or celebration tips?
10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.
11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.

(Newer) In-Context Messages:

12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.

(Newer)
13. user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.
14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
15. user: Yes, let’s do that.
16. assistant: I’ll put together a day-by-day plan now.

Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk’s `start_index`, `end_index`, and a one-sentence `contextual_description`.
"""
13. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
14. user: Yes, let’s do that.
15. assistant: I’ll put together a day-by-day plan now."""

# --- Server Management --- #

@ -214,22 +221,12 @@ def org_id(server):

    yield org.id

    # cleanup
    with server.organization_manager.session_maker() as session:
        session.execute(delete(Step))
        session.execute(delete(Provider))
        session.commit()
    server.organization_manager.delete_organization_by_id(org.id)


@pytest.fixture(scope="module")
def actor(server, org_id):
    user = server.user_manager.create_default_user()
    yield user

    # cleanup
    server.user_manager.delete_user_by_id(user.id)


# --- Helper Functions --- #

@ -301,6 +298,80 @@ async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpo
        print(chunk.choices[0].delta.content)


@pytest.mark.asyncio
async def test_summarization(disable_e2b_api_key, voice_agent):
    agent_manager = AgentManager()
    user_manager = UserManager()
    actor = user_manager.get_default_user()

    request = CreateAgent(
        name=voice_agent.name + "-sleeptime",
        agent_type=AgentType.voice_sleeptime_agent,
        block_ids=[block.id for block in voice_agent.memory.blocks],
        memory_blocks=[
            CreateBlock(
                label="memory_persona",
                value=get_persona_text("voice_memory_persona"),
            ),
        ],
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        project_id=voice_agent.project_id,
    )
    sleeptime_agent = agent_manager.create_agent(request, actor=actor)

    async_client = AsyncOpenAI()

    memory_agent = VoiceSleeptimeAgent(
        agent_id=sleeptime_agent.id,
        convo_agent_state=sleeptime_agent,  # In reality, this will be the main convo agent
        openai_client=async_client,
        message_manager=MessageManager(),
        agent_manager=agent_manager,
        actor=actor,
        block_manager=BlockManager(),
        target_block_label="human",
        message_transcripts=MESSAGE_TRANSCRIPTS,
    )

    summarizer = Summarizer(
        mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
        summarizer_agent=memory_agent,
        message_buffer_limit=8,
        message_buffer_min=4,
    )

    # stub out the agent.step so it returns a known sentinel
    memory_agent.step = MagicMock(return_value="STEP_RESULT")

    # patch fire_and_forget on *this* summarizer instance to a MagicMock
    summarizer.fire_and_forget = MagicMock()

    # now call the method under test
    in_ctx = MESSAGE_OBJECTS[:MESSAGE_EVICT_BREAKPOINT]
    new_msgs = MESSAGE_OBJECTS[MESSAGE_EVICT_BREAKPOINT:]
    # call under test (this is sync)
    updated, did_summarize = summarizer._static_buffer_summarization(
        in_context_messages=in_ctx,
        new_letta_messages=new_msgs,
    )

    assert did_summarize is True
    assert len(updated) == summarizer.message_buffer_min + 1  # One extra for system message
    assert updated[0].role == MessageRole.system  # Preserved system message

    # 2) the summarizer_agent.step() should have been *called* exactly once
    memory_agent.step.assert_called_once()
    call_args = memory_agent.step.call_args.args[0]  # the single positional argument: a list of MessageCreate
    assert isinstance(call_args, list)
    assert isinstance(call_args[0], MessageCreate)
    assert call_args[0].role == MessageRole.user
    assert "15. assistant: I’ll put together a day-by-day plan now." in call_args[0].content[0].text

    # 3) fire_and_forget should have been called once, and its argument must be the coroutine returned by step()
    summarizer.fire_and_forget.assert_called_once()


@pytest.mark.asyncio
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
    """Tests chat completion streaming using the Async OpenAI client."""
@ -427,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
        server.group_manager.retrieve_group(group_id=group.id, actor=actor)
    with pytest.raises(NoResultFound):
        server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)


def _modify(group_id, server, actor, max_val, min_val):
    """Helper to invoke modify_group with voice_sleeptime config."""
    return server.group_manager.modify_group(
        group_id=group_id,
        group_update=GroupUpdate(
            manager_config=VoiceSleeptimeManagerUpdate(
                manager_type=ManagerType.voice_sleeptime,
                max_message_buffer_length=max_val,
                min_message_buffer_length=min_val,
            )
        ),
        actor=actor,
    )


@pytest.fixture
def group_id(voice_agent):
    return voice_agent.multi_agent_group.id


def test_valid_buffer_lengths_above_four(group_id, server, actor):
    # both > 4 and max > min
    updated = _modify(group_id, server, actor, max_val=10, min_val=5)
    assert updated.max_message_buffer_length == 10
    assert updated.min_message_buffer_length == 5


def test_valid_buffer_lengths_only_max(group_id, server, actor):
    # both > 4 and max > min
    updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
    assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
    assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH


def test_valid_buffer_lengths_only_min(group_id, server, actor):
    # both > 4 and max > min
    updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
    assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
    assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1


@pytest.mark.parametrize(
    "max_val,min_val,err_part",
    [
        # only one set → both-or-none
        (None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"),
        (DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"),
        # ordering violations
        (5, 5, "must be greater than"),
        (6, 7, "must be greater than"),
        # lower-bound (must both be > 4)
        (4, 5, "greater than 4"),
        (5, 4, "greater than 4"),
        (1, 10, "greater than 4"),
        (10, 1, "greater than 4"),
    ],
)
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
    with pytest.raises(ValueError) as exc:
        _modify(group_id, server, actor, max_val, min_val)
    assert err_part in str(exc.value)

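The parametrized cases above pin down the buffer-length validation for voice sleeptime groups. The sketch below is an inferred reconstruction of that rule from the test expectations, with placeholder defaults standing in for the real constants in `letta.constants`; treat it as an illustration rather than the server's actual implementation.

```
# Assumed reconstruction of the validation rule; DEFAULT_MAX / DEFAULT_MIN are stand-ins.
from typing import Optional, Tuple

DEFAULT_MAX = 30  # placeholder for DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
DEFAULT_MIN = 10  # placeholder for DEFAULT_MIN_MESSAGE_BUFFER_LENGTH

def validate_buffer_lengths(max_len: Optional[int], min_len: Optional[int]) -> Tuple[int, int]:
    # Unset values fall back to the defaults; then both must exceed 4 and max must exceed min.
    max_len = DEFAULT_MAX if max_len is None else max_len
    min_len = DEFAULT_MIN if min_len is None else min_len
    if max_len <= 4 or min_len <= 4:
        raise ValueError("buffer lengths must be greater than 4")
    if max_len <= min_len:
        raise ValueError("max_message_buffer_length must be greater than min_message_buffer_length")
    return max_len, min_len

print(validate_buffer_lengths(10, 5))   # (10, 5) — accepted
try:
    validate_buffer_lengths(6, 7)       # rejected: max must be greater than min
except ValueError as err:
    print(err)
```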
@ -124,7 +124,7 @@ def test_agent(client: LocalClient):
def test_agent_add_remove_tools(client: LocalClient, agent):
    # Create and add two tools to the client
    # tool 1
    from composio_langchain import Action
    from composio import Action

    github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)

@ -316,7 +316,7 @@ def test_tools(client: LocalClient):


def test_tools_from_composio_basic(client: LocalClient):
    from composio_langchain import Action
    from composio import Action

    # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
    client = create_client()

@ -3,7 +3,7 @@ from unittest.mock import patch

import pytest

from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser


@pytest.fixture
@ -19,97 +19,166 @@ from letta.settings import model_settings
def test_openai():
    api_key = os.getenv("OPENAI_API_KEY")
    assert api_key is not None
    provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
    provider = OpenAIProvider(
        name="openai",
        api_key=api_key,
        base_url=model_settings.openai_api_base,
    )
    models = provider.list_llm_models()
    print(models)
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_deepseek():
    api_key = os.getenv("DEEPSEEK_API_KEY")
    assert api_key is not None
    provider = DeepSeekProvider(api_key=api_key)
    provider = DeepSeekProvider(
        name="deepseek",
        api_key=api_key,
    )
    models = provider.list_llm_models()
    print(models)
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_anthropic():
    api_key = os.getenv("ANTHROPIC_API_KEY")
    assert api_key is not None
    provider = AnthropicProvider(api_key=api_key)
    provider = AnthropicProvider(
        name="anthropic",
        api_key=api_key,
    )
    models = provider.list_llm_models()
    print(models)
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_groq():
    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
    provider = GroqProvider(
        name="groq",
        api_key=os.getenv("GROQ_API_KEY"),
    )
    models = provider.list_llm_models()
    print(models)
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_azure():
    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
    provider = AzureProvider(
        name="azure",
        api_key=os.getenv("AZURE_API_KEY"),
        base_url=os.getenv("AZURE_BASE_URL"),
    )
    models = provider.list_llm_models()
    print([m.model for m in models])
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embed_models = provider.list_embedding_models()
    print([m.embedding_model for m in embed_models])
    embedding_models = provider.list_embedding_models()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_ollama():
    base_url = os.getenv("OLLAMA_BASE_URL")
    assert base_url is not None
    provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
    provider = OllamaProvider(
        name="ollama",
        base_url=base_url,
        default_prompt_formatter=model_settings.default_prompt_formatter,
        api_key=None,
    )
    models = provider.list_llm_models()
    print(models)
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    print(embedding_models)
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_googleai():
    api_key = os.getenv("GEMINI_API_KEY")
    assert api_key is not None
    provider = GoogleAIProvider(api_key=api_key)
    provider = GoogleAIProvider(
        name="google_ai",
        api_key=api_key,
    )
    models = provider.list_llm_models()
    print(models)
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    provider.list_embedding_models()
    embedding_models = provider.list_embedding_models()
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_google_vertex():
    provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
    provider = GoogleVertexProvider(
        name="google_vertex",
        google_cloud_project=os.getenv("GCP_PROJECT_ID"),
        google_cloud_location=os.getenv("GCP_REGION"),
    )
    models = provider.list_llm_models()
    print(models)
    print([m.model for m in models])
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    print([m.embedding_model for m in embedding_models])
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_mistral():
    provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
    provider = MistralProvider(
        name="mistral",
        api_key=os.getenv("MISTRAL_API_KEY"),
    )
    models = provider.list_llm_models()
    print([m.model for m in models])
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


def test_together():
    provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
    provider = TogetherProvider(
        name="together",
        api_key=os.getenv("TOGETHER_API_KEY"),
        default_prompt_formatter="chatml",
    )
    models = provider.list_llm_models()
    print([m.model for m in models])
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    print([m.embedding_model for m in embedding_models])
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_anthropic_bedrock():
    from letta.settings import model_settings

    provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
    provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
    models = provider.list_llm_models()
    print([m.model for m in models])
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"

    embedding_models = provider.list_embedding_models()
    print([m.embedding_model for m in embedding_models])
    assert len(embedding_models) > 0
    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


def test_custom_anthropic():
    api_key = os.getenv("ANTHROPIC_API_KEY")
    assert api_key is not None
    provider = AnthropicProvider(
        name="custom_anthropic",
        api_key=api_key,
    )
    models = provider.list_llm_models()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"


# def test_vllm():

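Every provider test now asserts the same handle convention, and the handle is derived from the provider *name* (which is now user-settable) rather than the provider type; that is what makes custom-named BYOK providers addressable. A tiny sketch of the convention the assertions encode:

```
# The convention the assertions encode: handle == f"{provider.name}/{model}".
def make_handle(provider_name: str, model: str) -> str:
    return f"{provider_name}/{model}"

print(make_handle("anthropic", "claude-3-opus-20240229"))         # anthropic/claude-3-opus-20240229
print(make_handle("custom_anthropic", "claude-3-opus-20240229"))  # custom_anthropic/claude-3-opus-20240229
```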
@ -13,7 +13,7 @@ import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import Provider as PydanticProvider
@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
    actor = server.user_manager.get_user_or_default(user_id)
    provider = server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="anthropic",
            name="caren-anthropic",
            provider_type=ProviderType.anthropic,
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
    agent = server.create_agent(
        request=CreateAgent(
            memory_blocks=[],
            model="anthropic/claude-3-opus-20240229",
            context_window_limit=200000,
            model="caren-anthropic/claude-3-opus-20240229",
            context_window_limit=100000,
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
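Putting the pieces together, the override test above shows the intended BYOK flow: register a uniquely named provider for the organization, then reference its models through the name-prefixed handle. The sketch below mirrors the test body; the provider name is illustrative, and `server`/`actor` are assumed to come from the surrounding test fixtures.

```
# Mirrors the test body above; an illustration, not standalone documentation of the API.
import os

from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider as PydanticProvider

def register_byok_provider(server, actor):
    # The (name, organization_id) pair is unique, per the new Alembic constraint.
    return server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="caren-anthropic",                # illustrative custom name
            provider_type=ProviderType.anthropic,  # stored in the new provider_type column
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
    )

# An agent can then be created with model="caren-anthropic/claude-3-opus-20240229".
```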