chore: bump version 0.7.2 (#2584)

Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
Co-authored-by: Charles Packer <packercharles@gmail.com>
Author: cthomas, 2025-04-23 15:23:09 -07:00 (committed by GitHub)
Parent: 435b754286
Commit: 6495180ee2
28 changed files with 477 additions and 134 deletions

View File

@ -1,4 +1,4 @@
__version__ = "0.7.1"
__version__ = "0.7.2"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client

View File

@ -10,7 +10,7 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
if not api_keys:
if logger:
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...")
if tool_settings.composio_api_key:
return tool_settings.composio_api_key
else:

View File

@ -66,6 +66,15 @@ def get_utc_time() -> datetime:
return datetime.now(timezone.utc)
def get_utc_time_int() -> int:
return int(get_utc_time().timestamp())
def timestamp_to_datetime(timestamp_seconds: int) -> datetime:
"""Convert Unix timestamp in seconds to UTC datetime object"""
return datetime.fromtimestamp(timestamp_seconds, tz=timezone.utc)
def format_datetime(dt):
return dt.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
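
Note: the rest of this commit switches ChatCompletion `created` fields from datetimes to Unix-second integers using these helpers. A minimal usage sketch (the import path is taken from the other files in this diff; the assertions are illustrative only, not part of the commit):

from datetime import timezone
from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime

ts = get_utc_time_int()            # e.g. 1713216662, seconds since the epoch
dt = timestamp_to_datetime(ts)     # timezone-aware UTC datetime
assert dt.tzinfo == timezone.utc
assert int(dt.timestamp()) == ts   # lossless round trip at second precision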

View File

@ -73,7 +73,8 @@ async def fetch_batch_items(server: SyncServer, batch_id: str, batch_resp_id: st
"""
updates = []
try:
async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id):
results = await server.anthropic_async_client.beta.messages.batches.results(batch_resp_id)
async for item_result in results:
# Here, custom_id should be the agent_id
item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result)
updates.append(ItemUpdateInfo(batch_id, item_result.custom_id, item_status, item_result))
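
Note: the fix above awaits `results(...)` first and only then iterates the value it returns. A self-contained sketch of that call pattern; `fake_results` below is a hypothetical stand-in for the SDK call, not part of this commit:

import asyncio

async def fake_results(batch_resp_id: str):
    # awaitable that resolves to an async iterator of per-item results
    async def gen():
        for item in ("item-1", "item-2"):
            yield item
    return gen()

async def main():
    results = await fake_results("msgbatch_xyz")  # await the call first...
    async for item_result in results:             # ...then iterate the results
        print(item_result)

asyncio.run(main())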

View File

@ -20,7 +20,7 @@ from anthropic.types.beta import (
)
from letta.errors import BedrockError, BedrockPermissionError
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime
from letta.llm_api.aws_bedrock import get_bedrock_client
from letta.llm_api.helpers import add_inner_thoughts_to_functions
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
@ -396,7 +396,7 @@ def convert_anthropic_response_to_chatcompletion(
return ChatCompletionResponse(
id=response.id,
choices=[choice],
created=get_utc_time(),
created=get_utc_time_int(),
model=response.model,
usage=UsageStatistics(
prompt_tokens=prompt_tokens,
@ -451,7 +451,7 @@ def convert_anthropic_stream_event_to_chatcompletion(
'logprobs': None
}
],
'created': datetime.datetime(2025, 1, 24, 0, 18, 55, tzinfo=TzInfo(UTC)),
'created': 1713216662,
'model': 'gpt-4o-mini-2024-07-18',
'system_fingerprint': 'fp_bd83329f63',
'object': 'chat.completion.chunk'
@ -613,7 +613,7 @@ def convert_anthropic_stream_event_to_chatcompletion(
return ChatCompletionChunkResponse(
id=message_id,
choices=[choice],
created=get_utc_time(),
created=get_utc_time_int(),
model=model,
output_tokens=completion_chunk_tokens,
)
@ -920,7 +920,7 @@ def anthropic_chat_completions_process_stream(
chat_completion_response = ChatCompletionResponse(
id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
choices=[],
created=dummy_message.created_at,
created=int(dummy_message.created_at.timestamp()),
model=chat_completion_request.model,
usage=UsageStatistics(
prompt_tokens=prompt_tokens,
@ -954,7 +954,11 @@ def anthropic_chat_completions_process_stream(
message_type = stream_interface.process_chunk(
chat_completion_chunk,
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
message_date=(
timestamp_to_datetime(chat_completion_response.created)
if create_message_datetime
else timestamp_to_datetime(chat_completion_chunk.created)
),
# if extended_thinking is on, then reasoning_content will be flowing as chunks
# TODO handle emitting redacted reasoning content (e.g. as concat?)
expect_reasoning_content=extended_thinking,

View File

@ -22,7 +22,7 @@ from letta.errors import (
LLMServerError,
LLMUnprocessableEntityError,
)
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
@ -403,7 +403,7 @@ class AnthropicClient(LLMClientBase):
chat_completion_response = ChatCompletionResponse(
id=response.id,
choices=[choice],
created=get_utc_time(),
created=get_utc_time_int(),
model=response.model,
usage=UsageStatistics(
prompt_tokens=prompt_tokens,

View File

@ -4,7 +4,7 @@ from typing import List, Optional, Union
import requests
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps
from letta.local_llm.utils import count_tokens
from letta.schemas.message import Message
@ -207,7 +207,7 @@ def convert_cohere_response_to_chatcompletion(
return ChatCompletionResponse(
id=response_json["response_id"],
choices=[choice],
created=get_utc_time(),
created=get_utc_time_int(),
model=model,
usage=UsageStatistics(
prompt_tokens=prompt_tokens,

View File

@ -6,7 +6,7 @@ import requests
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.helpers import make_post_request
from letta.llm_api.llm_client_base import LLMClientBase
@ -260,7 +260,7 @@ class GoogleAIClient(LLMClientBase):
id=response_id,
choices=choices,
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
created=get_utc_time(),
created=get_utc_time_int(),
usage=usage,
)
except KeyError as e:

View File

@ -4,7 +4,7 @@ from typing import List, Optional
from google import genai
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.local_llm.json_parser import clean_json_string_extra_backslash
@ -234,7 +234,7 @@ class GoogleVertexClient(GoogleAIClient):
id=response_id,
choices=choices,
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
created=get_utc_time(),
created=get_utc_time_int(),
usage=usage,
)
except KeyError as e:

View File

@ -4,7 +4,9 @@ from typing import Generator, List, Optional, Union
import requests
from openai import OpenAI
from letta.helpers.datetime_helpers import timestamp_to_datetime
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
@ -135,7 +137,7 @@ def build_openai_chat_completions_request(
tool_choice=tool_choice,
user=str(user_id),
max_completion_tokens=llm_config.max_tokens,
temperature=1.0 if llm_config.enable_reasoner else llm_config.temperature,
temperature=llm_config.temperature if supports_temperature_param(model) else None,
reasoning_effort=llm_config.reasoning_effort,
)
else:
@ -237,7 +239,7 @@ def openai_chat_completions_process_stream(
chat_completion_response = ChatCompletionResponse(
id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
choices=[],
created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time()
created=int(dummy_message.created_at.timestamp()), # NOTE: doesn't matter since both will do get_utc_time()
model=chat_completion_request.model,
usage=UsageStatistics(
completion_tokens=0,
@ -274,7 +276,11 @@ def openai_chat_completions_process_stream(
message_type = stream_interface.process_chunk(
chat_completion_chunk,
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
message_date=(
timestamp_to_datetime(chat_completion_response.created)
if create_message_datetime
else timestamp_to_datetime(chat_completion_chunk.created)
),
expect_reasoning_content=expect_reasoning_content,
name=name,
message_index=message_idx,
@ -489,6 +495,7 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
# except ValueError as e:
# warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
if "o3-mini" in chat_completion_request.model or "o1" in chat_completion_request.model:
if not supports_parallel_tool_calling(chat_completion_request.model):
data.pop("parallel_tool_calls", None)
return data

View File

@ -34,6 +34,33 @@ from letta.settings import model_settings
logger = get_logger(__name__)
def is_openai_reasoning_model(model: str) -> bool:
"""Utility function to check if the model is a 'reasoner'"""
# NOTE: needs to be updated with new model releases
return model.startswith("o1") or model.startswith("o3")
def supports_temperature_param(model: str) -> bool:
"""Certain OpenAI models don't support configuring the temperature.
Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
"""
if is_openai_reasoning_model(model):
return False
else:
return True
def supports_parallel_tool_calling(model: str) -> bool:
"""Certain OpenAI models don't support parallel tool calls."""
if is_openai_reasoning_model(model):
return False
else:
return True
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self) -> dict:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@ -66,7 +93,8 @@ class OpenAIClient(LLMClientBase):
put_inner_thoughts_first=True,
)
use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3") # o-series models
use_developer_message = is_openai_reasoning_model(llm_config.model)
openai_message_list = [
cast_message_to_subtype(
m.to_openai_dict(
@ -103,7 +131,7 @@ class OpenAIClient(LLMClientBase):
tool_choice=tool_choice,
user=str(),
max_completion_tokens=llm_config.max_tokens,
temperature=llm_config.temperature,
temperature=llm_config.temperature if supports_temperature_param(model) else None,
)
if "inference.memgpt.ai" in llm_config.model_endpoint:
@ -160,6 +188,10 @@ class OpenAIClient(LLMClientBase):
response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
)
# If we used a reasoning model, create a content part for the omitted reasoning
if is_openai_reasoning_model(self.llm_config.model):
chat_completion_response.choices[0].message.ommitted_reasoning_content = True
return chat_completion_response
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
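
Note: a quick sketch of how the new capability helpers behave; the import path matches the imports used elsewhere in this commit, and the expected values follow directly from the string-prefix checks above:

from letta.llm_api.openai_client import (
    is_openai_reasoning_model,
    supports_parallel_tool_calling,
    supports_temperature_param,
)

assert is_openai_reasoning_model("o3-mini") is True
assert supports_temperature_param("o3-mini") is False      # temperature is dropped for reasoning models
assert supports_parallel_tool_calling("o1") is False       # parallel_tool_calls is popped from the payload
assert supports_temperature_param("gpt-4o-mini") is True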

View File

@ -6,7 +6,7 @@ import requests
from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LocalLLMConnectionError, LocalLLMError
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps
from letta.local_llm.constants import DEFAULT_WRAPPER
from letta.local_llm.function_parser import patch_function
@ -241,7 +241,7 @@ def get_chat_completion(
),
)
],
created=get_utc_time(),
created=get_utc_time_int(),
model=model,
# "This fingerprint represents the backend configuration that the model runs with."
# system_fingerprint=user if user is not None else "null",

View File

@ -145,7 +145,8 @@ class OmittedReasoningContent(MessageContent):
type: Literal[MessageContentType.omitted_reasoning] = Field(
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
)
tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
# NOTE: dropping because we don't track this kind of information for the other reasoning types
# tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
LettaMessageContentUnion = Annotated[

View File

@ -81,8 +81,11 @@ class LLMConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def set_default_enable_reasoner(cls, values):
if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
values["enable_reasoner"] = True
# NOTE: this is really only applicable for models that can toggle reasoning on-and-off, like 3.7
# We can also use this field to identify if a model is a "reasoning" model (o1/o3, etc.) if we want
# if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
# values["enable_reasoner"] = True
# values["put_inner_thoughts_in_kwargs"] = False
return values
@model_validator(mode="before")
@ -100,6 +103,13 @@ class LLMConfig(BaseModel):
if values.get("put_inner_thoughts_in_kwargs") is None:
values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
# For the o1/o3 series from OpenAI, set to False by default
# We can set this flag to `true` if desired, which will enable "double-think"
from letta.llm_api.openai_client import is_openai_reasoning_model
if is_openai_reasoning_model(model):
values["put_inner_thoughts_in_kwargs"] = False
return values
@model_validator(mode="after")
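
Note: the validator above fills in put_inner_thoughts_in_kwargs only when it is unset, but as written the reasoning-model check then overrides the value for o1/o3 models unconditionally. A standalone restatement of that decision (illustrative only, not the real validator; the avoid-list is passed in because its contents are not shown in this diff):

def default_put_inner_thoughts(model: str, current, avoid_list) -> bool:
    value = current
    if value is None:
        value = False if model in avoid_list else True
    # mirrors is_openai_reasoning_model: o1/o3-style models never use inner-thought kwargs
    if model.startswith("o1") or model.startswith("o3"):
        value = False
    return value

assert default_put_inner_thoughts("o3-mini", None, avoid_list=[]) is False
assert default_put_inner_thoughts("gpt-4o-mini", None, avoid_list=[]) is True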

View File

@ -31,6 +31,7 @@ from letta.schemas.letta_message import (
)
from letta.schemas.letta_message_content import (
LettaMessageContentUnion,
OmittedReasoningContent,
ReasoningContent,
RedactedReasoningContent,
TextContent,
@ -295,6 +296,18 @@ class Message(BaseMessage):
sender_id=self.sender_id,
)
)
elif isinstance(content_part, OmittedReasoningContent):
# Special case for "hidden reasoning" models like o1/o3
# NOTE: we also have to think about how to return this during streaming
messages.append(
HiddenReasoningMessage(
id=self.id,
date=self.created_at,
state="omitted",
name=self.name,
otid=otid,
)
)
else:
warnings.warn(f"Unrecognized content part in assistant message: {content_part}")
@ -464,6 +477,10 @@ class Message(BaseMessage):
data=openai_message_dict["redacted_reasoning_content"] if "redacted_reasoning_content" in openai_message_dict else None,
),
)
if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]:
content.append(
OmittedReasoningContent(),
)
# If we're going from deprecated function form
if openai_message_dict["role"] == "function":

View File

@ -39,9 +39,10 @@ class Message(BaseModel):
tool_calls: Optional[List[ToolCall]] = None
role: str
function_call: Optional[FunctionCall] = None # Deprecated
reasoning_content: Optional[str] = None # Used in newer reasoning APIs
reasoning_content: Optional[str] = None # Used in newer reasoning APIs, e.g. DeepSeek
reasoning_content_signature: Optional[str] = None # NOTE: for Anthropic
redacted_reasoning_content: Optional[str] = None # NOTE: for Anthropic
ommitted_reasoning_content: bool = False # NOTE: for OpenAI o1/o3
class Choice(BaseModel):
@ -52,16 +53,64 @@ class Choice(BaseModel):
seed: Optional[int] = None # found in TogetherAI
class UsageStatisticsPromptTokenDetails(BaseModel):
cached_tokens: int = 0
# NOTE: OAI specific
# audio_tokens: int = 0
def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails":
return UsageStatisticsPromptTokenDetails(
cached_tokens=self.cached_tokens + other.cached_tokens,
)
class UsageStatisticsCompletionTokenDetails(BaseModel):
reasoning_tokens: int = 0
# NOTE: OAI specific
# audio_tokens: int = 0
# accepted_prediction_tokens: int = 0
# rejected_prediction_tokens: int = 0
def __add__(self, other: "UsageStatisticsCompletionTokenDetails") -> "UsageStatisticsCompletionTokenDetails":
return UsageStatisticsCompletionTokenDetails(
reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
)
class UsageStatistics(BaseModel):
completion_tokens: int = 0
prompt_tokens: int = 0
total_tokens: int = 0
prompt_tokens_details: Optional[UsageStatisticsPromptTokenDetails] = None
completion_tokens_details: Optional[UsageStatisticsCompletionTokenDetails] = None
def __add__(self, other: "UsageStatistics") -> "UsageStatistics":
if self.prompt_tokens_details is None and other.prompt_tokens_details is None:
total_prompt_tokens_details = None
elif self.prompt_tokens_details is None:
total_prompt_tokens_details = other.prompt_tokens_details
elif other.prompt_tokens_details is None:
total_prompt_tokens_details = self.prompt_tokens_details
else:
total_prompt_tokens_details = self.prompt_tokens_details + other.prompt_tokens_details
if self.completion_tokens_details is None and other.completion_tokens_details is None:
total_completion_tokens_details = None
elif self.completion_tokens_details is None:
total_completion_tokens_details = other.completion_tokens_details
elif other.completion_tokens_details is None:
total_completion_tokens_details = self.completion_tokens_details
else:
total_completion_tokens_details = self.completion_tokens_details + other.completion_tokens_details
return UsageStatistics(
completion_tokens=self.completion_tokens + other.completion_tokens,
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
total_tokens=self.total_tokens + other.total_tokens,
prompt_tokens_details=total_prompt_tokens_details,
completion_tokens_details=total_completion_tokens_details,
)
@ -70,7 +119,7 @@ class ChatCompletionResponse(BaseModel):
id: str
choices: List[Choice]
created: datetime.datetime
created: Union[datetime.datetime, int]
model: Optional[str] = None # NOTE: this is not consistent with OpenAI API standard, however is necessary to support local LLMs
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
system_fingerprint: Optional[str] = None
@ -138,7 +187,7 @@ class ChatCompletionChunkResponse(BaseModel):
id: str
choices: List[ChunkChoice]
created: Union[datetime.datetime, str]
created: Union[datetime.datetime, int]
model: str
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
system_fingerprint: Optional[str] = None
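
Note: the new __add__ overloads allow per-chunk usage to be accumulated during streaming. A minimal sketch; the import path below is an assumption, since this module's location is not shown in the diff:

from letta.schemas.openai.chat_completion_response import (
    UsageStatistics,
    UsageStatisticsCompletionTokenDetails,
)

a = UsageStatistics(prompt_tokens=100, completion_tokens=20, total_tokens=120)
b = UsageStatistics(
    prompt_tokens=50,
    completion_tokens=30,
    total_tokens=80,
    completion_tokens_details=UsageStatisticsCompletionTokenDetails(reasoning_tokens=12),
)
total = a + b
assert total.total_tokens == 200
assert total.completion_tokens_details.reasoning_tokens == 12  # kept when only one side has details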

View File

@ -238,7 +238,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
created=chunk.created,
model=chunk.model,
choices=[
Choice(
@ -256,7 +256,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
created=chunk.created,
model=chunk.model,
choices=[
Choice(

View File

@ -1001,7 +1001,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# Example case that would trigger here:
# id='chatcmpl-AKtUvREgRRvgTW6n8ZafiKuV0mxhQ'
# choices=[ChunkChoice(finish_reason=None, index=0, delta=MessageDelta(content=None, tool_calls=None, function_call=None), logprobs=None)]
# created=datetime.datetime(2024, 10, 21, 20, 40, 57, tzinfo=TzInfo(UTC))
# created=1713216662
# model='gpt-4o-mini-2024-07-18'
# object='chat.completion.chunk'
warnings.warn(f"Couldn't find delta in chunk: {chunk}")

View File

@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header
from fastapi import APIRouter, Body, Depends, Header, status
from fastapi.exceptions import HTTPException
from starlette.requests import Request
@ -11,6 +11,7 @@ from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
from letta.schemas.letta_request import CreateBatch
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
router = APIRouter(prefix="/messages", tags=["messages"])
@ -43,6 +44,13 @@ async def create_messages_batch(
if length > max_bytes:
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
# Reject request if env var is not set
if not settings.enable_batch_job_polling:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
batch_job = BatchJob(
user_id=actor.id,

View File

@ -161,7 +161,7 @@ class AgentManager:
# Basic CRUD operations
# ======================================================================================================================
@trace_method
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser) -> PydanticAgentState:
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState:
# validate required configs
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")
@ -239,6 +239,10 @@ class AgentManager:
created_by_id=actor.id,
last_updated_by_id=actor.id,
)
if _test_only_force_id:
new_agent.id = _test_only_force_id
session.add(new_agent)
session.flush()
aid = new_agent.id

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.1"
version = "0.7.2"
packages = [
{include = "letta"},
]

View File

@ -0,0 +1,10 @@
{
"model": "claude-3-7-sonnet-20250219",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"context_window": 200000,
"put_inner_thoughts_in_kwargs": false,
"enable_reasoner": true,
"max_reasoning_tokens": 1024
}

View File

@ -0,0 +1,8 @@
{
"model": "claude-3-7-sonnet-20250219",
"model_endpoint_type": "anthropic",
"model_endpoint": "https://api.anthropic.com/v1",
"model_wrapper": null,
"context_window": 200000,
"put_inner_thoughts_in_kwargs": true
}

View File

@ -0,0 +1,7 @@
{
"context_window": 8192,
"model": "gpt-4o-mini",
"model_endpoint_type": "openai",
"model_endpoint": "https://api.openai.com/v1",
"model_wrapper": null
}
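
Note: these JSON fixtures are loaded straight into LLMConfig by the get_llm_config test helper later in this commit. A minimal sketch of that load, using the gpt-4o-mini config above (paths and values come from this diff):

import json
from letta.schemas.llm_config import LLMConfig

with open("tests/configs/llm_model_configs/openai-gpt-4o-mini.json") as f:
    cfg = LLMConfig(**json.load(f))
assert cfg.model == "gpt-4o-mini"
assert cfg.context_window == 8192

# To run the parametrized tests against a single config, set the LLM_CONFIG_FILE
# environment variable that the test module reads, e.g.:
#   LLM_CONFIG_FILE=claude-3-7-sonnet.json pytest <test file>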

View File

@ -2,11 +2,12 @@ import os
import threading
import time
from datetime import datetime, timezone
from typing import Optional
from unittest.mock import AsyncMock
import pytest
from anthropic.types import BetaErrorResponse, BetaRateLimitError
from anthropic.types.beta import BetaMessage
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage
from anthropic.types.beta.messages import (
BetaMessageBatch,
BetaMessageBatchErroredResult,
@ -21,13 +22,15 @@ from letta.config import LettaConfig
from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentStepState
from letta.schemas.agent import AgentStepState, CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import InitToolRule
from letta.server.db import db_context
from letta.server.server import SyncServer
from letta.services.agent_manager import AgentManager
# --- Server and Database Management --- #
@ -36,8 +39,10 @@ from letta.server.server import SyncServer
def _clear_tables():
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
if table.name in {"llm_batch_job", "llm_batch_items"}:
session.execute(table.delete()) # Truncate table
# If this is the block_history table, skip it
if table.name == "block_history":
continue
session.execute(table.delete()) # Truncate table
session.commit()
@ -135,16 +140,39 @@ def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse
# --- Test Setup Helpers --- #
def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"):
def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthropic/claude-3-5-sonnet-20241022"):
"""Create a test agent with standardized configuration."""
return client.agents.create(
dummy_llm_config = LLMConfig(
model="claude-3-7-sonnet-latest",
model_endpoint_type="anthropic",
model_endpoint="https://api.anthropic.com/v1",
context_window=32000,
handle=f"anthropic/claude-3-7-sonnet-latest",
put_inner_thoughts_in_kwargs=True,
max_tokens=4096,
)
dummy_embedding_config = EmbeddingConfig(
embedding_model="letta-free",
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_dim=1024,
embedding_chunk_size=300,
handle="letta/letta-free",
)
agent_manager = AgentManager()
agent_create = CreateAgent(
name=name,
include_base_tools=True,
include_base_tools=False,
model=model,
tags=["test_agents"],
embedding="letta/letta-free",
llm_config=dummy_llm_config,
embedding_config=dummy_embedding_config,
)
return agent_manager.create_agent(agent_create=agent_create, actor=actor, _test_only_force_id=test_id)
def create_test_letta_batch_job(server, default_user):
"""Create a test batch job with the given batch response."""
@ -203,17 +231,30 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
server.anthropic_async_client.beta.messages.batches.retrieve = AsyncMock(side_effect=dummy_retrieve)
class DummyAsyncIterable:
def __init__(self, items):
# copy so we can .pop()
self._items = list(items)
def __aiter__(self):
return self
async def __anext__(self):
if not self._items:
raise StopAsyncIteration
return self._items.pop(0)
# Mock the results method
def dummy_results(batch_resp_id: str):
if batch_resp_id == batch_b_resp.id:
async def dummy_results(batch_resp_id: str):
if batch_resp_id != batch_b_resp.id:
raise RuntimeError("Unexpected batch ID")
async def generator():
yield create_successful_response(agent_b_id)
yield create_failed_response(agent_c_id)
return generator()
else:
raise RuntimeError("This test should never request the results for batch_a.")
return DummyAsyncIterable(
[
create_successful_response(agent_b_id),
create_failed_response(agent_c_id),
]
)
server.anthropic_async_client.beta.messages.batches.results = dummy_results
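
Note: the mock above satisfies the async-iterator protocol (__aiter__/__anext__), so the production await-then-iterate code path needs no changes. A compact standalone check (the class is a copy of the one defined above; the item values are placeholders):

import asyncio

class DummyAsyncIterable:
    def __init__(self, items):
        self._items = list(items)  # copy so we can .pop()
    def __aiter__(self):
        return self
    async def __anext__(self):
        if not self._items:
            raise StopAsyncIteration
        return self._items.pop(0)

async def consume():
    return [item async for item in DummyAsyncIterable(["succeeded", "errored"])]

assert asyncio.run(consume()) == ["succeeded", "errored"]
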
@ -221,6 +262,147 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
# -----------------------------
# End-to-End Test
# -----------------------------
@pytest.mark.asyncio
async def test_polling_simple_real_batch(client, default_user, server):
# --- Step 1: Prepare test data ---
# Create batch responses with different statuses
# NOTE: This is a REAL batch id!
# For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ
batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended")
# Create test agents
agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0")
agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d")
agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56")
# --- Step 2: Create batch jobs ---
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)
# --- Step 3: Create batch items ---
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user)
item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user)
print("HI")
print(agent_a.id)
print(agent_b.id)
print(agent_c.id)
print("BYE")
# --- Step 4: Run the polling job ---
await poll_running_llm_batches(server)
# --- Step 5: Verify batch job status updates ---
updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
assert updated_job_a.status == JobStatus.completed
# Both jobs should have been polled
assert updated_job_a.last_polled_at is not None
assert updated_job_a.latest_polling_response is not None
# --- Step 7: Verify batch item status updates ---
# Item A should be marked as completed with a successful result
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
assert updated_item_a.request_status == JobStatus.completed
assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse(
custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0",
result=BetaMessageBatchSucceededResult(
message=BetaMessage(
id="msg_01T1iSejDS5qENRqqEZauMHy",
content=[
BetaToolUseBlock(
id="toolu_01GKUYVWcajjTaE1stxZZHcG",
input={
"inner_thoughts": "First login detected. Time to make a great first impression!",
"message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?",
"request_heartbeat": False,
},
name="send_message",
type="tool_use",
)
],
model="claude-3-5-haiku-20241022",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
type="message",
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94),
),
type="succeeded",
),
)
# Item B should be marked as completed with a successful result
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
assert updated_item_b.request_status == JobStatus.completed
assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse(
custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d",
result=BetaMessageBatchSucceededResult(
message=BetaMessage(
id="msg_01N2ZfxpbjdoeofpufUFPCMS",
content=[
BetaTextBlock(
citations=None, text="<thinking>User first login detected. Initializing persona.</thinking>", type="text"
),
BetaToolUseBlock(
id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C",
input={
"label": "persona",
"content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.",
"request_heartbeat": True,
},
name="core_memory_append",
type="tool_use",
),
],
model="claude-3-opus-20240229",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
type="message",
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153),
),
type="succeeded",
),
)
# Item C should be marked as failed with an error result
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
assert updated_item_c.request_status == JobStatus.completed
assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse(
custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56",
result=BetaMessageBatchSucceededResult(
message=BetaMessage(
id="msg_01RL2g4aBgbZPeaMEokm6HZm",
content=[
BetaTextBlock(
citations=None,
text="First time meeting this user. I should introduce myself and establish a friendly connection.</thinking>",
type="text",
),
BetaToolUseBlock(
id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ",
input={
"message": "Hey there! I'm Letta. Really nice to meet you! I love getting to know new people - what brings you here today?",
"request_heartbeat": False,
},
name="send_message",
type="tool_use",
),
],
model="claude-3-5-sonnet-20241022",
role="assistant",
stop_reason="tool_use",
stop_sequence=None,
type="message",
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111),
),
type="succeeded",
),
)
@pytest.mark.asyncio
async def test_polling_mixed_batch_jobs(client, default_user, server):
"""
@ -246,9 +428,9 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
batch_b_resp = create_batch_response("msgbatch_B", processing_status="ended")
# Create test agents
agent_a = create_test_agent(client, "agent_a")
agent_b = create_test_agent(client, "agent_b")
agent_c = create_test_agent(client, "agent_c")
agent_a = create_test_agent("agent_a", default_user)
agent_b = create_test_agent("agent_b", default_user)
agent_c = create_test_agent("agent_c", default_user)
# --- Step 2: Create batch jobs ---
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)

View File

@ -1,14 +1,17 @@
import json
import os
import threading
import time
from typing import Any, Dict, List
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta, Run, Tool
from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage
from letta_client import AsyncLetta, Letta, Run
from letta_client.types import AssistantMessage, ReasoningMessage
from letta.schemas.agent import AgentState
from letta.schemas.llm_config import LLMConfig
# ------------------------------
# Fixtures
@ -19,25 +22,35 @@ from letta.schemas.agent import AgentState
def server_url() -> str:
"""
Provides the URL for the Letta server.
If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
will start the Letta server in a background thread and return the default URL.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
"""Starts the Letta server in a background thread."""
load_dotenv() # Load environment variables from .env file
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
# Retrieve server URL from environment, or default to localhost
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
# If no environment variable is set, start the server in a background thread
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(5) # Allow time for the server to start
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@ -61,29 +74,7 @@ def async_client(server_url: str) -> AsyncLetta:
@pytest.fixture
def roll_dice_tool(client: Letta) -> Tool:
"""
Registers a simple roll dice tool with the provided client.
The tool simulates rolling a six-sided die but returns a fixed result.
"""
def roll_dice() -> str:
"""
Simulates rolling a die.
Returns:
str: The roll result.
"""
# Note: The result here is intentionally incorrect for demonstration purposes.
return "Rolled a 10!"
tool = client.tools.upsert_from_function(func=roll_dice)
yield tool
@pytest.fixture
def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
@ -91,7 +82,6 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=True,
tool_ids=[roll_dice_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
@ -103,8 +93,27 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
# Helper Functions and Constants
# ------------------------------
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Roll the dice."}]
TESTED_MODELS: List[str] = ["openai/gpt-4o"]
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
filename = os.path.join(llm_config_dir, filename)
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
return llm_config
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
all_configs = [
"openai-gpt-4o-mini.json",
"azure-gpt-4o-mini.json",
"claude-3-5-sonnet.json",
"claude-3-7-sonnet.json",
"claude-3-7-sonnet-extended.json",
"gemini-pro.json",
"gemini-vertex.json",
]
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
def assert_tool_response_messages(messages: List[Any]) -> None:
@ -114,10 +123,7 @@ def assert_tool_response_messages(messages: List[Any]) -> None:
ReasoningMessage -> AssistantMessage.
"""
assert isinstance(messages[0], ReasoningMessage)
assert isinstance(messages[1], ToolCallMessage)
assert isinstance(messages[2], ToolReturnMessage)
assert isinstance(messages[3], ReasoningMessage)
assert isinstance(messages[4], AssistantMessage)
assert isinstance(messages[1], AssistantMessage)
def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
@ -130,16 +136,10 @@ def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
return [c for c in chunks if isinstance(c, msg_type)]
reasoning_msgs = msg_groups(ReasoningMessage)
tool_calls = msg_groups(ToolCallMessage)
tool_returns = msg_groups(ToolReturnMessage)
assistant_msgs = msg_groups(AssistantMessage)
usage_stats = msg_groups(LettaUsageStatistics)
assert len(reasoning_msgs) >= 1
assert len(tool_calls) == 1
assert len(tool_returns) == 1
assert len(reasoning_msgs) == 1
assert len(assistant_msgs) == 1
assert len(usage_stats) == 1
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
@ -161,7 +161,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
"""
start = time.time()
while True:
run = client.runs.retrieve_run(run_id)
run = client.runs.retrieve(run_id)
if run.status == "completed":
return run
if run.status == "failed":
@ -184,13 +184,7 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
"""
assert isinstance(messages, list)
assert messages[0]["message_type"] == "reasoning_message"
assert messages[1]["message_type"] == "tool_call_message"
assert messages[2]["message_type"] == "tool_return_message"
assert messages[3]["message_type"] == "reasoning_message"
assert messages[4]["message_type"] == "assistant_message"
tool_return = messages[2]
assert tool_return["status"] == "success"
assert messages[1]["message_type"] == "assistant_message"
# ------------------------------
@ -198,18 +192,18 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
# ------------------------------
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
def test_send_message_sync_client(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with a synchronous client.
Verifies that the response messages follow the expected order.
"""
client.agents.modify(agent_id=agent_state.id, model=model)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@ -218,18 +212,18 @@ def test_send_message_sync_client(
@pytest.mark.asyncio
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
async def test_send_message_async_client(
disable_e2b_api_key: Any,
async_client: AsyncLetta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message with an asynchronous client.
Validates that the response messages match the expected sequence.
"""
await async_client.agents.modify(agent_id=agent_state.id, model=model)
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = await async_client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@ -237,18 +231,18 @@ async def test_send_message_async_client(
assert_tool_response_messages(response.messages)
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
def test_send_message_streaming_sync_client(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a streaming message with a synchronous client.
Checks that each chunk in the stream has the correct message types.
"""
client.agents.modify(agent_id=agent_state.id, model=model)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@ -258,18 +252,18 @@ def test_send_message_streaming_sync_client(
@pytest.mark.asyncio
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
async def test_send_message_streaming_async_client(
disable_e2b_api_key: Any,
async_client: AsyncLetta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a streaming message with an asynchronous client.
Validates that the streaming response chunks include the correct message types.
"""
await async_client.agents.modify(agent_id=agent_state.id, model=model)
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
response = async_client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE,
@ -278,18 +272,18 @@ async def test_send_message_streaming_async_client(
assert_streaming_tool_response_messages(chunks)
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
def test_send_message_job_sync_client(
disable_e2b_api_key: Any,
client: Letta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message as an asynchronous job using the synchronous client.
Waits for job completion and asserts that the result messages are as expected.
"""
client.agents.modify(agent_id=agent_state.id, model=model)
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = client.agents.messages.create_async(
agent_id=agent_state.id,
@ -305,19 +299,19 @@ def test_send_message_job_sync_client(
@pytest.mark.asyncio
@pytest.mark.parametrize("model", TESTED_MODELS)
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
async def test_send_message_job_async_client(
disable_e2b_api_key: Any,
client: Letta,
async_client: AsyncLetta,
agent_state: AgentState,
model: str,
llm_config: LLMConfig,
) -> None:
"""
Tests sending a message as an asynchronous job using the asynchronous client.
Waits for job completion and verifies that the resulting messages meet the expected format.
"""
await async_client.agents.modify(agent_id=agent_state.id, model=model)
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
run = await async_client.agents.messages.create_async(
agent_id=agent_state.id,

View File

@ -3,7 +3,7 @@ import threading
import time
from datetime import datetime, timezone
from typing import Tuple
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import pytest
from anthropic.types import BetaErrorResponse, BetaRateLimitError
@ -436,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
]
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy())
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
@ -499,7 +499,7 @@ async def test_partial_error_from_anthropic_batch(
)
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
@ -641,7 +641,7 @@ async def test_resume_step_some_stop(
)
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
@ -767,7 +767,7 @@ async def test_resume_step_after_request_all_continue(
]
# Create the mock for results
mock_results = Mock()
mock_results = AsyncMock()
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):