diff --git a/letta/__init__.py b/letta/__init__.py index 5401a9883..8858feb1e 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.7.1" +__version__ = "0.7.2" # import clients from letta.client.client import LocalClient, RESTClient, create_client diff --git a/letta/helpers/composio_helpers.py b/letta/helpers/composio_helpers.py index a3c518eca..2a0281e16 100644 --- a/letta/helpers/composio_helpers.py +++ b/letta/helpers/composio_helpers.py @@ -10,7 +10,7 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor) if not api_keys: if logger: - logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...") + logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...") if tool_settings.composio_api_key: return tool_settings.composio_api_key else: diff --git a/letta/helpers/datetime_helpers.py b/letta/helpers/datetime_helpers.py index e99074a69..7ee4aa409 100644 --- a/letta/helpers/datetime_helpers.py +++ b/letta/helpers/datetime_helpers.py @@ -66,6 +66,15 @@ def get_utc_time() -> datetime: return datetime.now(timezone.utc) +def get_utc_time_int() -> int: + return int(get_utc_time().timestamp()) + + +def timestamp_to_datetime(timestamp_seconds: int) -> datetime: + """Convert Unix timestamp in seconds to UTC datetime object""" + return datetime.fromtimestamp(timestamp_seconds, tz=timezone.utc) + + def format_datetime(dt): return dt.strftime("%Y-%m-%d %I:%M:%S %p %Z%z") diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index 6ca14f6e4..a12274758 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -73,7 +73,8 @@ async def fetch_batch_items(server: SyncServer, batch_id: str, batch_resp_id: st """ updates = [] try: - async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id): + results = await server.anthropic_async_client.beta.messages.batches.results(batch_resp_id) + async for item_result in results: # Here, custom_id should be the agent_id item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result) updates.append(ItemUpdateInfo(batch_id, item_result.custom_id, item_status, item_result)) diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 2f6bd296a..59939e4d6 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -20,7 +20,7 @@ from anthropic.types.beta import ( ) from letta.errors import BedrockError, BedrockPermissionError -from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime from letta.llm_api.aws_bedrock import get_bedrock_client from letta.llm_api.helpers import add_inner_thoughts_to_functions from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION @@ -396,7 +396,7 @@ def convert_anthropic_response_to_chatcompletion( return ChatCompletionResponse( id=response.id, choices=[choice], - created=get_utc_time(), + created=get_utc_time_int(), model=response.model, usage=UsageStatistics( prompt_tokens=prompt_tokens, @@ -451,7 +451,7 @@ def convert_anthropic_stream_event_to_chatcompletion( 'logprobs': None } ], - 'created': datetime.datetime(2025, 1, 24, 0, 18, 55, tzinfo=TzInfo(UTC)), + 'created': 1713216662, 'model': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 
'fp_bd83329f63', 'object': 'chat.completion.chunk' @@ -613,7 +613,7 @@ def convert_anthropic_stream_event_to_chatcompletion( return ChatCompletionChunkResponse( id=message_id, choices=[choice], - created=get_utc_time(), + created=get_utc_time_int(), model=model, output_tokens=completion_chunk_tokens, ) @@ -920,7 +920,7 @@ def anthropic_chat_completions_process_stream( chat_completion_response = ChatCompletionResponse( id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID, choices=[], - created=dummy_message.created_at, + created=int(dummy_message.created_at.timestamp()), model=chat_completion_request.model, usage=UsageStatistics( prompt_tokens=prompt_tokens, @@ -954,7 +954,11 @@ def anthropic_chat_completions_process_stream( message_type = stream_interface.process_chunk( chat_completion_chunk, message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, - message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, + message_date=( + timestamp_to_datetime(chat_completion_response.created) + if create_message_datetime + else timestamp_to_datetime(chat_completion_chunk.created) + ), # if extended_thinking is on, then reasoning_content will be flowing as chunks # TODO handle emitting redacted reasoning content (e.g. as concat?) expect_reasoning_content=extended_thinking, diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index cd9c08157..4c79cb688 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -22,7 +22,7 @@ from letta.errors import ( LLMServerError, LLMUnprocessableEntityError, ) -from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.datetime_helpers import get_utc_time_int from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION @@ -403,7 +403,7 @@ class AnthropicClient(LLMClientBase): chat_completion_response = ChatCompletionResponse( id=response.id, choices=[choice], - created=get_utc_time(), + created=get_utc_time_int(), model=response.model, usage=UsageStatistics( prompt_tokens=prompt_tokens, diff --git a/letta/llm_api/cohere.py b/letta/llm_api/cohere.py index 640e0c09f..4a30d7968 100644 --- a/letta/llm_api/cohere.py +++ b/letta/llm_api/cohere.py @@ -4,7 +4,7 @@ from typing import List, Optional, Union import requests -from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps from letta.local_llm.utils import count_tokens from letta.schemas.message import Message @@ -207,7 +207,7 @@ def convert_cohere_response_to_chatcompletion( return ChatCompletionResponse( id=response_json["response_id"], choices=[choice], - created=get_utc_time(), + created=get_utc_time_int(), model=model, usage=UsageStatistics( prompt_tokens=prompt_tokens, diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index 6630335c0..d8bdf1ef4 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -6,7 +6,7 @@ import requests from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig from letta.constants import NON_USER_MSG_PREFIX -from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.datetime_helpers import 
get_utc_time_int from letta.helpers.json_helpers import json_dumps from letta.llm_api.helpers import make_post_request from letta.llm_api.llm_client_base import LLMClientBase @@ -260,7 +260,7 @@ class GoogleAIClient(LLMClientBase): id=response_id, choices=choices, model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response - created=get_utc_time(), + created=get_utc_time_int(), usage=usage, ) except KeyError as e: diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 38ad4db55..1f82946e8 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -4,7 +4,7 @@ from typing import List, Optional from google import genai from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig -from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps from letta.llm_api.google_ai_client import GoogleAIClient from letta.local_llm.json_parser import clean_json_string_extra_backslash @@ -234,7 +234,7 @@ class GoogleVertexClient(GoogleAIClient): id=response_id, choices=choices, model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response - created=get_utc_time(), + created=get_utc_time_int(), usage=usage, ) except KeyError as e: diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index ffb64a99c..eda4c9a86 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -4,7 +4,9 @@ from typing import Generator, List, Optional, Union import requests from openai import OpenAI +from letta.helpers.datetime_helpers import timestamp_to_datetime from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request +from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.log import get_logger @@ -135,7 +137,7 @@ def build_openai_chat_completions_request( tool_choice=tool_choice, user=str(user_id), max_completion_tokens=llm_config.max_tokens, - temperature=1.0 if llm_config.enable_reasoner else llm_config.temperature, + temperature=llm_config.temperature if supports_temperature_param(model) else None, reasoning_effort=llm_config.reasoning_effort, ) else: @@ -237,7 +239,7 @@ def openai_chat_completions_process_stream( chat_completion_response = ChatCompletionResponse( id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID, choices=[], - created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time() + created=int(dummy_message.created_at.timestamp()), # NOTE: doesn't matter since both will do get_utc_time() model=chat_completion_request.model, usage=UsageStatistics( completion_tokens=0, @@ -274,7 +276,11 @@ def openai_chat_completions_process_stream( message_type = stream_interface.process_chunk( chat_completion_chunk, message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, - message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, + message_date=( + timestamp_to_datetime(chat_completion_response.created) + if create_message_datetime + else 
timestamp_to_datetime(chat_completion_chunk.created) + ), expect_reasoning_content=expect_reasoning_content, name=name, message_index=message_idx, @@ -489,6 +495,7 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest): # except ValueError as e: # warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}") - if "o3-mini" in chat_completion_request.model or "o1" in chat_completion_request.model: + if not supports_parallel_tool_calling(chat_completion_request.model): data.pop("parallel_tool_calls", None) + return data diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 426069ab4..ada788caa 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -34,6 +34,33 @@ from letta.settings import model_settings logger = get_logger(__name__) +def is_openai_reasoning_model(model: str) -> bool: + """Utility function to check if the model is a 'reasoner'""" + + # NOTE: needs to be updated with new model releases + return model.startswith("o1") or model.startswith("o3") + + +def supports_temperature_param(model: str) -> bool: + """Certain OpenAI models don't support configuring the temperature. + + Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}} + """ + if is_openai_reasoning_model(model): + return False + else: + return True + + +def supports_parallel_tool_calling(model: str) -> bool: + """Certain OpenAI models don't support parallel tool calls.""" + + if is_openai_reasoning_model(model): + return False + else: + return True + + class OpenAIClient(LLMClientBase): def _prepare_client_kwargs(self) -> dict: api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") @@ -66,7 +93,8 @@ class OpenAIClient(LLMClientBase): put_inner_thoughts_first=True, ) - use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3") # o-series models + use_developer_message = is_openai_reasoning_model(llm_config.model) + openai_message_list = [ cast_message_to_subtype( m.to_openai_dict( @@ -103,7 +131,7 @@ class OpenAIClient(LLMClientBase): tool_choice=tool_choice, user=str(), max_completion_tokens=llm_config.max_tokens, - temperature=llm_config.temperature, + temperature=llm_config.temperature if supports_temperature_param(model) else None, ) if "inference.memgpt.ai" in llm_config.model_endpoint: @@ -160,6 +188,10 @@ class OpenAIClient(LLMClientBase): response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG ) + # If we used a reasoning model, create a content part for the ommitted reasoning + if is_openai_reasoning_model(self.llm_config.model): + chat_completion_response.choices[0].message.ommitted_reasoning_content = True + return chat_completion_response def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]: diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index 4abc01ee6..35db97ed8 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -6,7 +6,7 @@ import requests from letta.constants import CLI_WARNING_PREFIX from letta.errors import LocalLLMConnectionError, LocalLLMError -from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps from letta.local_llm.constants import 
DEFAULT_WRAPPER from letta.local_llm.function_parser import patch_function @@ -241,7 +241,7 @@ def get_chat_completion( ), ) ], - created=get_utc_time(), + created=get_utc_time_int(), model=model, # "This fingerprint represents the backend configuration that the model runs with." # system_fingerprint=user if user is not None else "null", diff --git a/letta/schemas/letta_message_content.py b/letta/schemas/letta_message_content.py index 00ebfe788..400926989 100644 --- a/letta/schemas/letta_message_content.py +++ b/letta/schemas/letta_message_content.py @@ -145,7 +145,8 @@ class OmittedReasoningContent(MessageContent): type: Literal[MessageContentType.omitted_reasoning] = Field( MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step." ) - tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.") + # NOTE: dropping because we don't track this kind of information for the other reasoning types + # tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.") LettaMessageContentUnion = Annotated[ diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 7a3531bb5..dea376ce4 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -81,8 +81,11 @@ class LLMConfig(BaseModel): @model_validator(mode="before") @classmethod def set_default_enable_reasoner(cls, values): - if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]): - values["enable_reasoner"] = True + # NOTE: this is really only applicable for models that can toggle reasoning on-and-off, like 3.7 + # We can also use this field to identify if a model is a "reasoning" model (o1/o3, etc.) if we want + # if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]): + # values["enable_reasoner"] = True + # values["put_inner_thoughts_in_kwargs"] = False return values @model_validator(mode="before") @@ -100,6 +103,13 @@ class LLMConfig(BaseModel): if values.get("put_inner_thoughts_in_kwargs") is None: values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True + # For the o1/o3 series from OpenAI, set to False by default + # We can set this flag to `true` if desired, which will enable "double-think" + from letta.llm_api.openai_client import is_openai_reasoning_model + + if is_openai_reasoning_model(model): + values["put_inner_thoughts_in_kwargs"] = False + return values @model_validator(mode="after") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index dfc36fe2c..76d3dd05e 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -31,6 +31,7 @@ from letta.schemas.letta_message import ( ) from letta.schemas.letta_message_content import ( LettaMessageContentUnion, + OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent, @@ -295,6 +296,18 @@ class Message(BaseMessage): sender_id=self.sender_id, ) ) + elif isinstance(content_part, OmittedReasoningContent): + # Special case for "hidden reasoning" models like o1/o3 + # NOTE: we also have to think about how to return this during streaming + messages.append( + HiddenReasoningMessage( + id=self.id, + date=self.created_at, + state="omitted", + name=self.name, + otid=otid, + ) + ) else: warnings.warn(f"Unrecognized content part in assistant message: {content_part}") @@ -464,6 +477,10 @@ class Message(BaseMessage): 
data=openai_message_dict["redacted_reasoning_content"] if "redacted_reasoning_content" in openai_message_dict else None, ), ) + if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]: + content.append( + OmittedReasoningContent(), + ) # If we're going from deprecated function form if openai_message_dict["role"] == "function": diff --git a/letta/schemas/openai/chat_completion_response.py b/letta/schemas/openai/chat_completion_response.py index 920288ffe..d4332b22d 100644 --- a/letta/schemas/openai/chat_completion_response.py +++ b/letta/schemas/openai/chat_completion_response.py @@ -39,9 +39,10 @@ class Message(BaseModel): tool_calls: Optional[List[ToolCall]] = None role: str function_call: Optional[FunctionCall] = None # Deprecated - reasoning_content: Optional[str] = None # Used in newer reasoning APIs + reasoning_content: Optional[str] = None # Used in newer reasoning APIs, e.g. DeepSeek reasoning_content_signature: Optional[str] = None # NOTE: for Anthropic redacted_reasoning_content: Optional[str] = None # NOTE: for Anthropic + ommitted_reasoning_content: bool = False # NOTE: for OpenAI o1/o3 class Choice(BaseModel): @@ -52,16 +53,64 @@ class Choice(BaseModel): seed: Optional[int] = None # found in TogetherAI +class UsageStatisticsPromptTokenDetails(BaseModel): + cached_tokens: int = 0 + # NOTE: OAI specific + # audio_tokens: int = 0 + + def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails": + return UsageStatisticsPromptTokenDetails( + cached_tokens=self.cached_tokens + other.cached_tokens, + ) + + +class UsageStatisticsCompletionTokenDetails(BaseModel): + reasoning_tokens: int = 0 + # NOTE: OAI specific + # audio_tokens: int = 0 + # accepted_prediction_tokens: int = 0 + # rejected_prediction_tokens: int = 0 + + def __add__(self, other: "UsageStatisticsCompletionTokenDetails") -> "UsageStatisticsCompletionTokenDetails": + return UsageStatisticsCompletionTokenDetails( + reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens, + ) + + class UsageStatistics(BaseModel): completion_tokens: int = 0 prompt_tokens: int = 0 total_tokens: int = 0 + prompt_tokens_details: Optional[UsageStatisticsPromptTokenDetails] = None + completion_tokens_details: Optional[UsageStatisticsCompletionTokenDetails] = None + def __add__(self, other: "UsageStatistics") -> "UsageStatistics": + + if self.prompt_tokens_details is None and other.prompt_tokens_details is None: + total_prompt_tokens_details = None + elif self.prompt_tokens_details is None: + total_prompt_tokens_details = other.prompt_tokens_details + elif other.prompt_tokens_details is None: + total_prompt_tokens_details = self.prompt_tokens_details + else: + total_prompt_tokens_details = self.prompt_tokens_details + other.prompt_tokens_details + + if self.completion_tokens_details is None and other.completion_tokens_details is None: + total_completion_tokens_details = None + elif self.completion_tokens_details is None: + total_completion_tokens_details = other.completion_tokens_details + elif other.completion_tokens_details is None: + total_completion_tokens_details = self.completion_tokens_details + else: + total_completion_tokens_details = self.completion_tokens_details + other.completion_tokens_details + return UsageStatistics( completion_tokens=self.completion_tokens + other.completion_tokens, prompt_tokens=self.prompt_tokens + other.prompt_tokens, total_tokens=self.total_tokens + other.total_tokens, + 
prompt_tokens_details=total_prompt_tokens_details, + completion_tokens_details=total_completion_tokens_details, ) @@ -70,7 +119,7 @@ class ChatCompletionResponse(BaseModel): id: str choices: List[Choice] - created: datetime.datetime + created: Union[datetime.datetime, int] model: Optional[str] = None # NOTE: this is not consistent with OpenAI API standard, however is necessary to support local LLMs # system_fingerprint: str # docs say this is mandatory, but in reality API returns None system_fingerprint: Optional[str] = None @@ -138,7 +187,7 @@ class ChatCompletionChunkResponse(BaseModel): id: str choices: List[ChunkChoice] - created: Union[datetime.datetime, str] + created: Union[datetime.datetime, int] model: str # system_fingerprint: str # docs say this is mandatory, but in reality API returns None system_fingerprint: Optional[str] = None diff --git a/letta/server/rest_api/chat_completions_interface.py b/letta/server/rest_api/chat_completions_interface.py index 77550a52b..0f684ed7c 100644 --- a/letta/server/rest_api/chat_completions_interface.py +++ b/letta/server/rest_api/chat_completions_interface.py @@ -238,7 +238,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface): return ChatCompletionChunk( id=chunk.id, object=chunk.object, - created=chunk.created.timestamp(), + created=chunk.created, model=chunk.model, choices=[ Choice( @@ -256,7 +256,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface): return ChatCompletionChunk( id=chunk.id, object=chunk.object, - created=chunk.created.timestamp(), + created=chunk.created, model=chunk.model, choices=[ Choice( diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 469ff0a20..edf8a2330 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1001,7 +1001,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # Example case that would trigger here: # id='chatcmpl-AKtUvREgRRvgTW6n8ZafiKuV0mxhQ' # choices=[ChunkChoice(finish_reason=None, index=0, delta=MessageDelta(content=None, tool_calls=None, function_call=None), logprobs=None)] - # created=datetime.datetime(2024, 10, 21, 20, 40, 57, tzinfo=TzInfo(UTC)) + # created=1713216662 # model='gpt-4o-mini-2024-07-18' # object='chat.completion.chunk' warnings.warn(f"Couldn't find delta in chunk: {chunk}") diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index 5424edda4..252e6fe86 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -1,6 +1,6 @@ from typing import List, Optional -from fastapi import APIRouter, Body, Depends, Header +from fastapi import APIRouter, Body, Depends, Header, status from fastapi.exceptions import HTTPException from starlette.requests import Request @@ -11,6 +11,7 @@ from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate from letta.schemas.letta_request import CreateBatch from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer +from letta.settings import settings router = APIRouter(prefix="/messages", tags=["messages"]) @@ -43,6 +44,13 @@ async def create_messages_batch( if length > max_bytes: raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). 
Max is {max_bytes} bytes.") + # Reject request if env var is not set + if not settings.enable_batch_job_polling: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.", + ) + actor = server.user_manager.get_user_or_default(user_id=actor_id) batch_job = BatchJob( user_id=actor.id, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index a1dcdb8ef..ce228c6c0 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -161,7 +161,7 @@ class AgentManager: # Basic CRUD operations # ====================================================================================================================== @trace_method - def create_agent(self, agent_create: CreateAgent, actor: PydanticUser) -> PydanticAgentState: + def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState: # validate required configs if not agent_create.llm_config or not agent_create.embedding_config: raise ValueError("llm_config and embedding_config are required") @@ -239,6 +239,10 @@ class AgentManager: created_by_id=actor.id, last_updated_by_id=actor.id, ) + + if _test_only_force_id: + new_agent.id = _test_only_force_id + session.add(new_agent) session.flush() aid = new_agent.id diff --git a/pyproject.toml b/pyproject.toml index aaffa7785..9cd1a0dcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.7.1" +version = "0.7.2" packages = [ {include = "letta"}, ] diff --git a/tests/configs/llm_model_configs/claude-3-7-sonnet-extended.json b/tests/configs/llm_model_configs/claude-3-7-sonnet-extended.json new file mode 100644 index 000000000..a7abf66a4 --- /dev/null +++ b/tests/configs/llm_model_configs/claude-3-7-sonnet-extended.json @@ -0,0 +1,10 @@ +{ + "model": "claude-3-7-sonnet-20250219", + "model_endpoint_type": "anthropic", + "model_endpoint": "https://api.anthropic.com/v1", + "model_wrapper": null, + "context_window": 200000, + "put_inner_thoughts_in_kwargs": false, + "enable_reasoner": true, + "max_reasoning_tokens": 1024 +} diff --git a/tests/configs/llm_model_configs/claude-3-7-sonnet.json b/tests/configs/llm_model_configs/claude-3-7-sonnet.json new file mode 100644 index 000000000..beecaa759 --- /dev/null +++ b/tests/configs/llm_model_configs/claude-3-7-sonnet.json @@ -0,0 +1,8 @@ +{ + "model": "claude-3-7-sonnet-20250219", + "model_endpoint_type": "anthropic", + "model_endpoint": "https://api.anthropic.com/v1", + "model_wrapper": null, + "context_window": 200000, + "put_inner_thoughts_in_kwargs": true +} diff --git a/tests/configs/llm_model_configs/openai-gpt-4o-mini.json b/tests/configs/llm_model_configs/openai-gpt-4o-mini.json new file mode 100644 index 000000000..0e6c32b29 --- /dev/null +++ b/tests/configs/llm_model_configs/openai-gpt-4o-mini.json @@ -0,0 +1,7 @@ +{ + "context_window": 8192, + "model": "gpt-4o-mini", + "model_endpoint_type": "openai", + "model_endpoint": "https://api.openai.com/v1", + "model_wrapper": null +} diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 044192e11..39306568f 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -2,11 +2,12 @@ import os import threading import time from datetime import datetime, timezone +from typing import Optional from unittest.mock import AsyncMock import 
pytest from anthropic.types import BetaErrorResponse, BetaRateLimitError -from anthropic.types.beta import BetaMessage +from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage from anthropic.types.beta.messages import ( BetaMessageBatch, BetaMessageBatchErroredResult, @@ -21,13 +22,15 @@ from letta.config import LettaConfig from letta.helpers import ToolRulesSolver from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base -from letta.schemas.agent import AgentStepState +from letta.schemas.agent import AgentStepState, CreateAgent +from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus, ProviderType from letta.schemas.job import BatchJob from letta.schemas.llm_config import LLMConfig from letta.schemas.tool_rule import InitToolRule from letta.server.db import db_context from letta.server.server import SyncServer +from letta.services.agent_manager import AgentManager # --- Server and Database Management --- # @@ -36,8 +39,10 @@ from letta.server.server import SyncServer def _clear_tables(): with db_context() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues - if table.name in {"llm_batch_job", "llm_batch_items"}: - session.execute(table.delete()) # Truncate table + # If this is the block_history table, skip it + if table.name == "block_history": + continue + session.execute(table.delete()) # Truncate table session.commit() @@ -135,16 +140,39 @@ def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse # --- Test Setup Helpers --- # -def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"): +def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthropic/claude-3-5-sonnet-20241022"): """Create a test agent with standardized configuration.""" - return client.agents.create( + dummy_llm_config = LLMConfig( + model="claude-3-7-sonnet-latest", + model_endpoint_type="anthropic", + model_endpoint="https://api.anthropic.com/v1", + context_window=32000, + handle=f"anthropic/claude-3-7-sonnet-latest", + put_inner_thoughts_in_kwargs=True, + max_tokens=4096, + ) + + dummy_embedding_config = EmbeddingConfig( + embedding_model="letta-free", + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_dim=1024, + embedding_chunk_size=300, + handle="letta/letta-free", + ) + + agent_manager = AgentManager() + agent_create = CreateAgent( name=name, - include_base_tools=True, + include_base_tools=False, model=model, tags=["test_agents"], - embedding="letta/letta-free", + llm_config=dummy_llm_config, + embedding_config=dummy_embedding_config, ) + return agent_manager.create_agent(agent_create=agent_create, actor=actor, _test_only_force_id=test_id) + def create_test_letta_batch_job(server, default_user): """Create a test batch job with the given batch response.""" @@ -203,17 +231,30 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_ server.anthropic_async_client.beta.messages.batches.retrieve = AsyncMock(side_effect=dummy_retrieve) + class DummyAsyncIterable: + def __init__(self, items): + # copy so we can .pop() + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + # Mock the results method - def dummy_results(batch_resp_id: str): - if batch_resp_id == batch_b_resp.id: + async def 
dummy_results(batch_resp_id: str): + if batch_resp_id != batch_b_resp.id: + raise RuntimeError("Unexpected batch ID") - async def generator(): - yield create_successful_response(agent_b_id) - yield create_failed_response(agent_c_id) - - return generator() - else: - raise RuntimeError("This test should never request the results for batch_a.") + return DummyAsyncIterable( + [ + create_successful_response(agent_b_id), + create_failed_response(agent_c_id), + ] + ) server.anthropic_async_client.beta.messages.batches.results = dummy_results @@ -221,6 +262,147 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_ # ----------------------------- # End-to-End Test # ----------------------------- +@pytest.mark.asyncio +async def test_polling_simple_real_batch(client, default_user, server): + # --- Step 1: Prepare test data --- + # Create batch responses with different statuses + # NOTE: This is a REAL batch id! + # For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ + batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended") + + # Create test agents + agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0") + agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d") + agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56") + + # --- Step 2: Create batch jobs --- + job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) + + # --- Step 3: Create batch items --- + item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) + item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user) + item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user) + + # --- Step 4: Run the polling job --- + await poll_running_llm_batches(server) + + # --- Step 5: Verify batch job status updates --- + updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user) + + assert updated_job_a.status == JobStatus.completed + + # The job should have been polled + assert updated_job_a.last_polled_at is not None + assert updated_job_a.latest_polling_response is not None + + # --- Step 6: Verify batch item status updates --- + # Item A should be marked as completed with a successful result + updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) + assert updated_item_a.request_status == JobStatus.completed + assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse( + custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0", + result=BetaMessageBatchSucceededResult( + message=BetaMessage( + id="msg_01T1iSejDS5qENRqqEZauMHy", + content=[ + BetaToolUseBlock( + id="toolu_01GKUYVWcajjTaE1stxZZHcG", + input={ + "inner_thoughts": "First login detected. Time to make a great first impression!", + "message": "Hi there! I'm excited to meet you. 
Ready to start an amazing conversation?", + "request_heartbeat": False, + }, + name="send_message", + type="tool_use", + ) + ], + model="claude-3-5-haiku-20241022", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94), + ), + type="succeeded", + ), + ) + + # Item B should be marked as completed with a successful result + updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) + assert updated_item_b.request_status == JobStatus.completed + assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse( + custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d", + result=BetaMessageBatchSucceededResult( + message=BetaMessage( + id="msg_01N2ZfxpbjdoeofpufUFPCMS", + content=[ + BetaTextBlock( + citations=None, text="User first login detected. Initializing persona.", type="text" + ), + BetaToolUseBlock( + id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C", + input={ + "label": "persona", + "content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.", + "request_heartbeat": True, + }, + name="core_memory_append", + type="tool_use", + ), + ], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153), + ), + type="succeeded", + ), + ) + + # Item C should also be marked as completed with a successful result + updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) + assert updated_item_c.request_status == JobStatus.completed + assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse( + custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56", + result=BetaMessageBatchSucceededResult( + message=BetaMessage( + id="msg_01RL2g4aBgbZPeaMEokm6HZm", + content=[ + BetaTextBlock( + citations=None, + text="First time meeting this user. I should introduce myself and establish a friendly connection.", + type="text", + ), + BetaToolUseBlock( + id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ", + input={ + "message": "Hey there! I'm Letta. Really nice to meet you! 
I love getting to know new people - what brings you here today?", + "request_heartbeat": False, + }, + name="send_message", + type="tool_use", + ), + ], + model="claude-3-5-sonnet-20241022", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111), + ), + type="succeeded", + ), + ) + + @pytest.mark.asyncio async def test_polling_mixed_batch_jobs(client, default_user, server): """ @@ -246,9 +428,9 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): batch_b_resp = create_batch_response("msgbatch_B", processing_status="ended") # Create test agents - agent_a = create_test_agent(client, "agent_a") - agent_b = create_test_agent(client, "agent_b") - agent_c = create_test_agent(client, "agent_c") + agent_a = create_test_agent("agent_a", default_user) + agent_b = create_test_agent("agent_b", default_user) + agent_c = create_test_agent("agent_c", default_user) # --- Step 2: Create batch jobs --- job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) diff --git a/tests/integration_test_batch.py b/tests/integration_test_batch_sdk.py similarity index 100% rename from tests/integration_test_batch.py rename to tests/integration_test_batch_sdk.py diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 23f7af4a1..3aff0b43f 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -1,14 +1,17 @@ +import json import os import threading import time from typing import Any, Dict, List import pytest +import requests from dotenv import load_dotenv -from letta_client import AsyncLetta, Letta, Run, Tool -from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage +from letta_client import AsyncLetta, Letta, Run +from letta_client.types import AssistantMessage, ReasoningMessage from letta.schemas.agent import AgentState +from letta.schemas.llm_config import LLMConfig # ------------------------------ # Fixtures @@ -19,25 +22,35 @@ from letta.schemas.agent import AgentState def server_url() -> str: """ Provides the URL for the Letta server. - If the environment variable 'LETTA_SERVER_URL' is not set, this fixture - will start the Letta server in a background thread and return the default URL. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. 
""" def _run_server() -> None: - """Starts the Letta server in a background thread.""" - load_dotenv() # Load environment variables from .env file + load_dotenv() from letta.server.rest_api.app import start_server start_server(debug=True) - # Retrieve server URL from environment, or default to localhost url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - # If no environment variable is set, start the server in a background thread if not os.getenv("LETTA_SERVER_URL"): thread = threading.Thread(target=_run_server, daemon=True) thread.start() - time.sleep(5) # Allow time for the server to start + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") return url @@ -61,29 +74,7 @@ def async_client(server_url: str) -> AsyncLetta: @pytest.fixture -def roll_dice_tool(client: Letta) -> Tool: - """ - Registers a simple roll dice tool with the provided client. - - The tool simulates rolling a six-sided die but returns a fixed result. - """ - - def roll_dice() -> str: - """ - Simulates rolling a die. - - Returns: - str: The roll result. - """ - # Note: The result here is intentionally incorrect for demonstration purposes. - return "Rolled a 10!" - - tool = client.tools.upsert_from_function(func=roll_dice) - yield tool - - -@pytest.fixture -def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: +def agent_state(client: Letta) -> AgentState: """ Creates and returns an agent state for testing with a pre-configured agent. The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. @@ -91,7 +82,6 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: agent_state_instance = client.agents.create( name="supervisor", include_base_tools=True, - tool_ids=[roll_dice_tool.id], model="openai/gpt-4o", embedding="letta/letta-free", tags=["supervisor"], @@ -103,8 +93,27 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: # Helper Functions and Constants # ------------------------------ -USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Roll the dice."}] -TESTED_MODELS: List[str] = ["openai/gpt-4o"] + +def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig: + filename = os.path.join(llm_config_dir, filename) + config_data = json.load(open(filename, "r")) + llm_config = LLMConfig(**config_data) + return llm_config + + +USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}] +all_configs = [ + "openai-gpt-4o-mini.json", + "azure-gpt-4o-mini.json", + "claude-3-5-sonnet.json", + "claude-3-7-sonnet.json", + "claude-3-7-sonnet-extended.json", + "gemini-pro.json", + "gemini-vertex.json", +] +requested = os.getenv("LLM_CONFIG_FILE") +filenames = [requested] if requested else all_configs +TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] def assert_tool_response_messages(messages: List[Any]) -> None: @@ -114,10 +123,7 @@ def assert_tool_response_messages(messages: List[Any]) -> None: ReasoningMessage -> AssistantMessage. 
""" assert isinstance(messages[0], ReasoningMessage) - assert isinstance(messages[1], ToolCallMessage) - assert isinstance(messages[2], ToolReturnMessage) - assert isinstance(messages[3], ReasoningMessage) - assert isinstance(messages[4], AssistantMessage) + assert isinstance(messages[1], AssistantMessage) def assert_streaming_tool_response_messages(chunks: List[Any]) -> None: @@ -130,16 +136,10 @@ def assert_streaming_tool_response_messages(chunks: List[Any]) -> None: return [c for c in chunks if isinstance(c, msg_type)] reasoning_msgs = msg_groups(ReasoningMessage) - tool_calls = msg_groups(ToolCallMessage) - tool_returns = msg_groups(ToolReturnMessage) assistant_msgs = msg_groups(AssistantMessage) - usage_stats = msg_groups(LettaUsageStatistics) - assert len(reasoning_msgs) >= 1 - assert len(tool_calls) == 1 - assert len(tool_returns) == 1 + assert len(reasoning_msgs) == 1 assert len(assistant_msgs) == 1 - assert len(usage_stats) == 1 def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: @@ -161,7 +161,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i """ start = time.time() while True: - run = client.runs.retrieve_run(run_id) + run = client.runs.retrieve(run_id) if run.status == "completed": return run if run.status == "failed": @@ -184,13 +184,7 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: """ assert isinstance(messages, list) assert messages[0]["message_type"] == "reasoning_message" - assert messages[1]["message_type"] == "tool_call_message" - assert messages[2]["message_type"] == "tool_return_message" - assert messages[3]["message_type"] == "reasoning_message" - assert messages[4]["message_type"] == "assistant_message" - - tool_return = messages[2] - assert tool_return["status"] == "success" + assert messages[1]["message_type"] == "assistant_message" # ------------------------------ @@ -198,18 +192,18 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: # ------------------------------ -@pytest.mark.parametrize("model", TESTED_MODELS) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS) def test_send_message_sync_client( disable_e2b_api_key: Any, client: Letta, agent_state: AgentState, - model: str, + llm_config: LLMConfig, ) -> None: """ Tests sending a message with a synchronous client. Verifies that the response messages follow the expected order. """ - client.agents.modify(agent_id=agent_state.id, model=model) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE, @@ -218,18 +212,18 @@ def test_send_message_sync_client( @pytest.mark.asyncio -@pytest.mark.parametrize("model", TESTED_MODELS) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS) async def test_send_message_async_client( disable_e2b_api_key: Any, async_client: AsyncLetta, agent_state: AgentState, - model: str, + llm_config: LLMConfig, ) -> None: """ Tests sending a message with an asynchronous client. Validates that the response messages match the expected sequence. 
""" - await async_client.agents.modify(agent_id=agent_state.id, model=model) + await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = await async_client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE, @@ -237,18 +231,18 @@ async def test_send_message_async_client( assert_tool_response_messages(response.messages) -@pytest.mark.parametrize("model", TESTED_MODELS) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS) def test_send_message_streaming_sync_client( disable_e2b_api_key: Any, client: Letta, agent_state: AgentState, - model: str, + llm_config: LLMConfig, ) -> None: """ Tests sending a streaming message with a synchronous client. Checks that each chunk in the stream has the correct message types. """ - client.agents.modify(agent_id=agent_state.id, model=model) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = client.agents.messages.create_stream( agent_id=agent_state.id, messages=USER_MESSAGE, @@ -258,18 +252,18 @@ def test_send_message_streaming_sync_client( @pytest.mark.asyncio -@pytest.mark.parametrize("model", TESTED_MODELS) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS) async def test_send_message_streaming_async_client( disable_e2b_api_key: Any, async_client: AsyncLetta, agent_state: AgentState, - model: str, + llm_config: LLMConfig, ) -> None: """ Tests sending a streaming message with an asynchronous client. Validates that the streaming response chunks include the correct message types. """ - await async_client.agents.modify(agent_id=agent_state.id, model=model) + await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) response = async_client.agents.messages.create_stream( agent_id=agent_state.id, messages=USER_MESSAGE, @@ -278,18 +272,18 @@ async def test_send_message_streaming_async_client( assert_streaming_tool_response_messages(chunks) -@pytest.mark.parametrize("model", TESTED_MODELS) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS) def test_send_message_job_sync_client( disable_e2b_api_key: Any, client: Letta, agent_state: AgentState, - model: str, + llm_config: LLMConfig, ) -> None: """ Tests sending a message as an asynchronous job using the synchronous client. Waits for job completion and asserts that the result messages are as expected. """ - client.agents.modify(agent_id=agent_state.id, model=model) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) run = client.agents.messages.create_async( agent_id=agent_state.id, @@ -305,19 +299,19 @@ def test_send_message_job_sync_client( @pytest.mark.asyncio -@pytest.mark.parametrize("model", TESTED_MODELS) +@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS) async def test_send_message_job_async_client( disable_e2b_api_key: Any, client: Letta, async_client: AsyncLetta, agent_state: AgentState, - model: str, + llm_config: LLMConfig, ) -> None: """ Tests sending a message as an asynchronous job using the asynchronous client. Waits for job completion and verifies that the resulting messages meet the expected format. 
""" - await async_client.agents.modify(agent_id=agent_state.id, model=model) + await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) run = await async_client.agents.messages.create_async( agent_id=agent_state.id, diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 9835a6f7b..1cde5dc8b 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -3,7 +3,7 @@ import threading import time from datetime import datetime, timezone from typing import Tuple -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch import pytest from anthropic.types import BetaErrorResponse, BetaRateLimitError @@ -436,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv ] # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): @@ -499,7 +499,7 @@ async def test_partial_error_from_anthropic_batch( ) # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): @@ -641,7 +641,7 @@ async def test_resume_step_some_stop( ) # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): @@ -767,7 +767,7 @@ async def test_resume_step_after_request_all_continue( ] # Create the mock for results - mock_results = Mock() + mock_results = AsyncMock() mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):