Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

chore: bump version 0.7.2 (#2584)

Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
Co-authored-by: Charles Packer <packercharles@gmail.com>

Parent commit: 435b754286
This commit:   6495180ee2

@@ -1,4 +1,4 @@
-__version__ = "0.7.1"
+__version__ = "0.7.2"
 
 # import clients
 from letta.client.client import LocalClient, RESTClient, create_client
@ -10,7 +10,7 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option
|
||||
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
|
||||
if not api_keys:
|
||||
if logger:
|
||||
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
|
||||
logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...")
|
||||
if tool_settings.composio_api_key:
|
||||
return tool_settings.composio_api_key
|
||||
else:
|
||||
|
@ -66,6 +66,15 @@ def get_utc_time() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def get_utc_time_int() -> int:
|
||||
return int(get_utc_time().timestamp())
|
||||
|
||||
|
||||
def timestamp_to_datetime(timestamp_seconds: int) -> datetime:
|
||||
"""Convert Unix timestamp in seconds to UTC datetime object"""
|
||||
return datetime.fromtimestamp(timestamp_seconds, tz=timezone.utc)
|
||||
|
||||
|
||||
def format_datetime(dt):
|
||||
return dt.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
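The two helpers introduced above convert between timezone-aware datetimes and Unix epoch seconds; the rest of this commit switches the `created` fields of chat completion responses over to the integer form. A minimal standalone sketch of the round trip (standard library only, bodies mirror the hunk above; not part of the commit itself):

from datetime import datetime, timezone

def get_utc_time_int() -> int:
    # epoch seconds, as now expected by the `created` fields below
    return int(datetime.now(timezone.utc).timestamp())

def timestamp_to_datetime(timestamp_seconds: int) -> datetime:
    # inverse direction, used where a datetime is still needed (e.g. message_date)
    return datetime.fromtimestamp(timestamp_seconds, tz=timezone.utc)

created = get_utc_time_int()
assert timestamp_to_datetime(created).tzinfo == timezone.utc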
@ -73,7 +73,8 @@ async def fetch_batch_items(server: SyncServer, batch_id: str, batch_resp_id: st
|
||||
"""
|
||||
updates = []
|
||||
try:
|
||||
async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id):
|
||||
results = await server.anthropic_async_client.beta.messages.batches.results(batch_resp_id)
|
||||
async for item_result in results:
|
||||
# Here, custom_id should be the agent_id
|
||||
item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result)
|
||||
updates.append(ItemUpdateInfo(batch_id, item_result.custom_id, item_status, item_result))
|
||||
|
@ -20,7 +20,7 @@ from anthropic.types.beta import (
|
||||
)
|
||||
|
||||
from letta.errors import BedrockError, BedrockPermissionError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime
|
||||
from letta.llm_api.aws_bedrock import get_bedrock_client
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
@ -396,7 +396,7 @@ def convert_anthropic_response_to_chatcompletion(
|
||||
return ChatCompletionResponse(
|
||||
id=response.id,
|
||||
choices=[choice],
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
model=response.model,
|
||||
usage=UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
@ -451,7 +451,7 @@ def convert_anthropic_stream_event_to_chatcompletion(
|
||||
'logprobs': None
|
||||
}
|
||||
],
|
||||
'created': datetime.datetime(2025, 1, 24, 0, 18, 55, tzinfo=TzInfo(UTC)),
|
||||
'created': 1713216662,
|
||||
'model': 'gpt-4o-mini-2024-07-18',
|
||||
'system_fingerprint': 'fp_bd83329f63',
|
||||
'object': 'chat.completion.chunk'
|
||||
@ -613,7 +613,7 @@ def convert_anthropic_stream_event_to_chatcompletion(
|
||||
return ChatCompletionChunkResponse(
|
||||
id=message_id,
|
||||
choices=[choice],
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
model=model,
|
||||
output_tokens=completion_chunk_tokens,
|
||||
)
|
||||
@ -920,7 +920,7 @@ def anthropic_chat_completions_process_stream(
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
|
||||
choices=[],
|
||||
created=dummy_message.created_at,
|
||||
created=int(dummy_message.created_at.timestamp()),
|
||||
model=chat_completion_request.model,
|
||||
usage=UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
@ -954,7 +954,11 @@ def anthropic_chat_completions_process_stream(
|
||||
message_type = stream_interface.process_chunk(
|
||||
chat_completion_chunk,
|
||||
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
|
||||
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
|
||||
message_date=(
|
||||
timestamp_to_datetime(chat_completion_response.created)
|
||||
if create_message_datetime
|
||||
else timestamp_to_datetime(chat_completion_chunk.created)
|
||||
),
|
||||
# if extended_thinking is on, then reasoning_content will be flowing as chunks
|
||||
# TODO handle emitting redacted reasoning content (e.g. as concat?)
|
||||
expect_reasoning_content=extended_thinking,
|
||||
|
@ -22,7 +22,7 @@ from letta.errors import (
|
||||
LLMServerError,
|
||||
LLMUnprocessableEntityError,
|
||||
)
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
@ -403,7 +403,7 @@ class AnthropicClient(LLMClientBase):
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
id=response.id,
|
||||
choices=[choice],
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
model=response.model,
|
||||
usage=UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
@ -4,7 +4,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.message import Message
|
||||
@ -207,7 +207,7 @@ def convert_cohere_response_to_chatcompletion(
|
||||
return ChatCompletionResponse(
|
||||
id=response_json["response_id"],
|
||||
choices=[choice],
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
model=model,
|
||||
usage=UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
@ -6,7 +6,7 @@ import requests
|
||||
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
@ -260,7 +260,7 @@ class GoogleAIClient(LLMClientBase):
|
||||
id=response_id,
|
||||
choices=choices,
|
||||
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
usage=usage,
|
||||
)
|
||||
except KeyError as e:
|
||||
|
@ -4,7 +4,7 @@ from typing import List, Optional
|
||||
from google import genai
|
||||
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.google_ai_client import GoogleAIClient
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
@ -234,7 +234,7 @@ class GoogleVertexClient(GoogleAIClient):
|
||||
id=response_id,
|
||||
choices=choices,
|
||||
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
usage=usage,
|
||||
)
|
||||
except KeyError as e:
|
||||
|
@ -4,7 +4,9 @@ from typing import Generator, List, Optional, Union
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from letta.helpers.datetime_helpers import timestamp_to_datetime
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
|
||||
from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
@ -135,7 +137,7 @@ def build_openai_chat_completions_request(
|
||||
tool_choice=tool_choice,
|
||||
user=str(user_id),
|
||||
max_completion_tokens=llm_config.max_tokens,
|
||||
temperature=1.0 if llm_config.enable_reasoner else llm_config.temperature,
|
||||
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
||||
reasoning_effort=llm_config.reasoning_effort,
|
||||
)
|
||||
else:
|
||||
@ -237,7 +239,7 @@ def openai_chat_completions_process_stream(
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
|
||||
choices=[],
|
||||
created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time()
|
||||
created=int(dummy_message.created_at.timestamp()), # NOTE: doesn't matter since both will do get_utc_time()
|
||||
model=chat_completion_request.model,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=0,
|
||||
@ -274,7 +276,11 @@ def openai_chat_completions_process_stream(
|
||||
message_type = stream_interface.process_chunk(
|
||||
chat_completion_chunk,
|
||||
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
|
||||
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
|
||||
message_date=(
|
||||
timestamp_to_datetime(chat_completion_response.created)
|
||||
if create_message_datetime
|
||||
else timestamp_to_datetime(chat_completion_chunk.created)
|
||||
),
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
message_index=message_idx,
|
||||
@ -489,6 +495,7 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
|
||||
# except ValueError as e:
|
||||
# warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
|
||||
if "o3-mini" in chat_completion_request.model or "o1" in chat_completion_request.model:
|
||||
if not supports_parallel_tool_calling(chat_completion_request.model):
|
||||
data.pop("parallel_tool_calls", None)
|
||||
|
||||
return data
|
||||
|
@ -34,6 +34,33 @@ from letta.settings import model_settings
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def is_openai_reasoning_model(model: str) -> bool:
|
||||
"""Utility function to check if the model is a 'reasoner'"""
|
||||
|
||||
# NOTE: needs to be updated with new model releases
|
||||
return model.startswith("o1") or model.startswith("o3")
|
||||
|
||||
|
||||
def supports_temperature_param(model: str) -> bool:
|
||||
"""Certain OpenAI models don't support configuring the temperature.
|
||||
|
||||
Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
|
||||
"""
|
||||
if is_openai_reasoning_model(model):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def supports_parallel_tool_calling(model: str) -> bool:
|
||||
"""Certain OpenAI models don't support parallel tool calls."""
|
||||
|
||||
if is_openai_reasoning_model(model):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
class OpenAIClient(LLMClientBase):
|
||||
def _prepare_client_kwargs(self) -> dict:
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
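The predicates added above (is_openai_reasoning_model, supports_temperature_param, supports_parallel_tool_calling) centralize the o1/o3 capability checks used elsewhere in this diff. A hedged sketch of how a caller could gate request parameters with them, assuming those three functions are in scope; build_request_kwargs itself is illustrative only, not the actual letta request builder:

def build_request_kwargs(model: str, temperature: float) -> dict:
    # illustrative only: mirrors the gating in build_openai_chat_completions_request
    kwargs = {"model": model, "parallel_tool_calls": False}
    if supports_temperature_param(model):        # False for o1/o3 reasoning models
        kwargs["temperature"] = temperature
    if not supports_parallel_tool_calling(model):
        kwargs.pop("parallel_tool_calls", None)  # same idea as prepare_openai_payload above
    return kwargs

assert "temperature" not in build_request_kwargs("o3-mini", 0.7)
assert build_request_kwargs("gpt-4o-mini", 0.7)["temperature"] == 0.7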
@ -66,7 +93,8 @@ class OpenAIClient(LLMClientBase):
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
|
||||
use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3") # o-series models
|
||||
use_developer_message = is_openai_reasoning_model(llm_config.model)
|
||||
|
||||
openai_message_list = [
|
||||
cast_message_to_subtype(
|
||||
m.to_openai_dict(
|
||||
@ -103,7 +131,7 @@ class OpenAIClient(LLMClientBase):
|
||||
tool_choice=tool_choice,
|
||||
user=str(),
|
||||
max_completion_tokens=llm_config.max_tokens,
|
||||
temperature=llm_config.temperature,
|
||||
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
||||
)
|
||||
|
||||
if "inference.memgpt.ai" in llm_config.model_endpoint:
|
||||
@ -160,6 +188,10 @@ class OpenAIClient(LLMClientBase):
|
||||
response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
|
||||
)
|
||||
|
||||
# If we used a reasoning model, create a content part for the ommitted reasoning
|
||||
if is_openai_reasoning_model(self.llm_config.model):
|
||||
chat_completion_response.choices[0].message.ommitted_reasoning_content = True
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
|
||||
|
@ -6,7 +6,7 @@ import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LocalLLMConnectionError, LocalLLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.constants import DEFAULT_WRAPPER
|
||||
from letta.local_llm.function_parser import patch_function
|
||||
@ -241,7 +241,7 @@ def get_chat_completion(
|
||||
),
|
||||
)
|
||||
],
|
||||
created=get_utc_time(),
|
||||
created=get_utc_time_int(),
|
||||
model=model,
|
||||
# "This fingerprint represents the backend configuration that the model runs with."
|
||||
# system_fingerprint=user if user is not None else "null",
|
||||
|
@ -145,7 +145,8 @@ class OmittedReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.omitted_reasoning] = Field(
|
||||
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
|
||||
)
|
||||
tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
|
||||
# NOTE: dropping because we don't track this kind of information for the other reasoning types
|
||||
# tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
|
||||
|
||||
|
||||
LettaMessageContentUnion = Annotated[
|
||||
|
@ -81,8 +81,11 @@ class LLMConfig(BaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_default_enable_reasoner(cls, values):
|
||||
if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
|
||||
values["enable_reasoner"] = True
|
||||
# NOTE: this is really only applicable for models that can toggle reasoning on-and-off, like 3.7
|
||||
# We can also use this field to identify if a model is a "reasoning" model (o1/o3, etc.) if we want
|
||||
# if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
|
||||
# values["enable_reasoner"] = True
|
||||
# values["put_inner_thoughts_in_kwargs"] = False
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@ -100,6 +103,13 @@ class LLMConfig(BaseModel):
|
||||
if values.get("put_inner_thoughts_in_kwargs") is None:
|
||||
values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
|
||||
|
||||
# For the o1/o3 series from OpenAI, set to False by default
|
||||
# We can set this flag to `true` if desired, which will enable "double-think"
|
||||
from letta.llm_api.openai_client import is_openai_reasoning_model
|
||||
|
||||
if is_openai_reasoning_model(model):
|
||||
values["put_inner_thoughts_in_kwargs"] = False
|
||||
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
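Net effect of the validator change above: for OpenAI reasoning models the inner-thoughts-in-kwargs behavior is disabled by default. A small sketch of the resulting `values` dict, assuming is_openai_reasoning_model imported as in the hunk; the surrounding pydantic machinery and the per-model default list are elided here:

values = {"model": "o3-mini", "put_inner_thoughts_in_kwargs": None}

if values.get("put_inner_thoughts_in_kwargs") is None:
    values["put_inner_thoughts_in_kwargs"] = True  # simplified default for most models

if is_openai_reasoning_model(values["model"]):     # o1/o3: force the flag off
    values["put_inner_thoughts_in_kwargs"] = False

assert values["put_inner_thoughts_in_kwargs"] is False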
@ -31,6 +31,7 @@ from letta.schemas.letta_message import (
|
||||
)
|
||||
from letta.schemas.letta_message_content import (
|
||||
LettaMessageContentUnion,
|
||||
OmittedReasoningContent,
|
||||
ReasoningContent,
|
||||
RedactedReasoningContent,
|
||||
TextContent,
|
||||
@ -295,6 +296,18 @@ class Message(BaseMessage):
|
||||
sender_id=self.sender_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(content_part, OmittedReasoningContent):
|
||||
# Special case for "hidden reasoning" models like o1/o3
|
||||
# NOTE: we also have to think about how to return this during streaming
|
||||
messages.append(
|
||||
HiddenReasoningMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
state="omitted",
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
else:
|
||||
warnings.warn(f"Unrecognized content part in assistant message: {content_part}")
|
||||
|
||||
@ -464,6 +477,10 @@ class Message(BaseMessage):
|
||||
data=openai_message_dict["redacted_reasoning_content"] if "redacted_reasoning_content" in openai_message_dict else None,
|
||||
),
|
||||
)
|
||||
if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]:
|
||||
content.append(
|
||||
OmittedReasoningContent(),
|
||||
)
|
||||
|
||||
# If we're going from deprecated function form
|
||||
if openai_message_dict["role"] == "function":
|
||||
|
@ -39,9 +39,10 @@ class Message(BaseModel):
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
role: str
|
||||
function_call: Optional[FunctionCall] = None # Deprecated
|
||||
reasoning_content: Optional[str] = None # Used in newer reasoning APIs
|
||||
reasoning_content: Optional[str] = None # Used in newer reasoning APIs, e.g. DeepSeek
|
||||
reasoning_content_signature: Optional[str] = None # NOTE: for Anthropic
|
||||
redacted_reasoning_content: Optional[str] = None # NOTE: for Anthropic
|
||||
ommitted_reasoning_content: bool = False # NOTE: for OpenAI o1/o3
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
@ -52,16 +53,64 @@ class Choice(BaseModel):
|
||||
seed: Optional[int] = None # found in TogetherAI
|
||||
|
||||
|
||||
class UsageStatisticsPromptTokenDetails(BaseModel):
|
||||
cached_tokens: int = 0
|
||||
# NOTE: OAI specific
|
||||
# audio_tokens: int = 0
|
||||
|
||||
def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails":
|
||||
return UsageStatisticsPromptTokenDetails(
|
||||
cached_tokens=self.cached_tokens + other.cached_tokens,
|
||||
)
|
||||
|
||||
|
||||
class UsageStatisticsCompletionTokenDetails(BaseModel):
|
||||
reasoning_tokens: int = 0
|
||||
# NOTE: OAI specific
|
||||
# audio_tokens: int = 0
|
||||
# accepted_prediction_tokens: int = 0
|
||||
# rejected_prediction_tokens: int = 0
|
||||
|
||||
def __add__(self, other: "UsageStatisticsCompletionTokenDetails") -> "UsageStatisticsCompletionTokenDetails":
|
||||
return UsageStatisticsCompletionTokenDetails(
|
||||
reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
|
||||
)
|
||||
|
||||
|
||||
class UsageStatistics(BaseModel):
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
prompt_tokens_details: Optional[UsageStatisticsPromptTokenDetails] = None
|
||||
completion_tokens_details: Optional[UsageStatisticsCompletionTokenDetails] = None
|
||||
|
||||
def __add__(self, other: "UsageStatistics") -> "UsageStatistics":
|
||||
|
||||
if self.prompt_tokens_details is None and other.prompt_tokens_details is None:
|
||||
total_prompt_tokens_details = None
|
||||
elif self.prompt_tokens_details is None:
|
||||
total_prompt_tokens_details = other.prompt_tokens_details
|
||||
elif other.prompt_tokens_details is None:
|
||||
total_prompt_tokens_details = self.prompt_tokens_details
|
||||
else:
|
||||
total_prompt_tokens_details = self.prompt_tokens_details + other.prompt_tokens_details
|
||||
|
||||
if self.completion_tokens_details is None and other.completion_tokens_details is None:
|
||||
total_completion_tokens_details = None
|
||||
elif self.completion_tokens_details is None:
|
||||
total_completion_tokens_details = other.completion_tokens_details
|
||||
elif other.completion_tokens_details is None:
|
||||
total_completion_tokens_details = self.completion_tokens_details
|
||||
else:
|
||||
total_completion_tokens_details = self.completion_tokens_details + other.completion_tokens_details
|
||||
|
||||
return UsageStatistics(
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
||||
total_tokens=self.total_tokens + other.total_tokens,
|
||||
prompt_tokens_details=total_prompt_tokens_details,
|
||||
completion_tokens_details=total_completion_tokens_details,
|
||||
)
|
||||
|
||||
|
||||
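With the __add__ overloads above, per-chunk usage objects can be summed while accumulating a streamed response; a None details field on either side is carried over rather than zeroing the other. A short sketch using the classes from this hunk (the token values are made up):

a = UsageStatistics(
    prompt_tokens=10, completion_tokens=5, total_tokens=15,
    completion_tokens_details=UsageStatisticsCompletionTokenDetails(reasoning_tokens=3),
)
b = UsageStatistics(prompt_tokens=7, completion_tokens=2, total_tokens=9)

merged = a + b
assert merged.total_tokens == 24
assert merged.completion_tokens_details.reasoning_tokens == 3  # None side carried over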
@ -70,7 +119,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
|
||||
id: str
|
||||
choices: List[Choice]
|
||||
created: datetime.datetime
|
||||
created: Union[datetime.datetime, int]
|
||||
model: Optional[str] = None # NOTE: this is not consistent with OpenAI API standard, however is necessary to support local LLMs
|
||||
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
|
||||
system_fingerprint: Optional[str] = None
|
||||
@ -138,7 +187,7 @@ class ChatCompletionChunkResponse(BaseModel):
|
||||
|
||||
id: str
|
||||
choices: List[ChunkChoice]
|
||||
created: Union[datetime.datetime, str]
|
||||
created: Union[datetime.datetime, int]
|
||||
model: str
|
||||
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
@ -238,7 +238,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
return ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created.timestamp(),
|
||||
created=chunk.created,
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
@ -256,7 +256,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
return ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created.timestamp(),
|
||||
created=chunk.created,
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
|
@ -1001,7 +1001,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# Example case that would trigger here:
|
||||
# id='chatcmpl-AKtUvREgRRvgTW6n8ZafiKuV0mxhQ'
|
||||
# choices=[ChunkChoice(finish_reason=None, index=0, delta=MessageDelta(content=None, tool_calls=None, function_call=None), logprobs=None)]
|
||||
# created=datetime.datetime(2024, 10, 21, 20, 40, 57, tzinfo=TzInfo(UTC))
|
||||
# created=1713216662
|
||||
# model='gpt-4o-mini-2024-07-18'
|
||||
# object='chat.completion.chunk'
|
||||
warnings.warn(f"Couldn't find delta in chunk: {chunk}")
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header
|
||||
from fastapi import APIRouter, Body, Depends, Header, status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
@ -11,6 +11,7 @@ from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
|
||||
from letta.schemas.letta_request import CreateBatch
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
|
||||
@ -43,6 +44,13 @@ async def create_messages_batch(
|
||||
if length > max_bytes:
|
||||
raise HTTPException(status_code=413, detail=f"Request too large ({length} bytes). Max is {max_bytes} bytes.")
|
||||
|
||||
# Reject request if env var is not set
|
||||
if not settings.enable_batch_job_polling:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
|
||||
)
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
|
@ -161,7 +161,7 @@ class AgentManager:
|
||||
# Basic CRUD operations
|
||||
# ======================================================================================================================
|
||||
@trace_method
|
||||
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState:
|
||||
# validate required configs
|
||||
if not agent_create.llm_config or not agent_create.embedding_config:
|
||||
raise ValueError("llm_config and embedding_config are required")
|
||||
@ -239,6 +239,10 @@ class AgentManager:
|
||||
created_by_id=actor.id,
|
||||
last_updated_by_id=actor.id,
|
||||
)
|
||||
|
||||
if _test_only_force_id:
|
||||
new_agent.id = _test_only_force_id
|
||||
|
||||
session.add(new_agent)
|
||||
session.flush()
|
||||
aid = new_agent.id
|
||||
|
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "letta"
-version = "0.7.1"
+version = "0.7.2"
 packages = [
     {include = "letta"},
 ]
@@ -0,0 +1,10 @@
+{
+"model": "claude-3-7-sonnet-20250219",
+"model_endpoint_type": "anthropic",
+"model_endpoint": "https://api.anthropic.com/v1",
+"model_wrapper": null,
+"context_window": 200000,
+"put_inner_thoughts_in_kwargs": false,
+"enable_reasoner": true,
+"max_reasoning_tokens": 1024
+}
tests/configs/llm_model_configs/claude-3-7-sonnet.json (new file, 8 lines)
@@ -0,0 +1,8 @@
+{
+"model": "claude-3-7-sonnet-20250219",
+"model_endpoint_type": "anthropic",
+"model_endpoint": "https://api.anthropic.com/v1",
+"model_wrapper": null,
+"context_window": 200000,
+"put_inner_thoughts_in_kwargs": true
+}
tests/configs/llm_model_configs/openai-gpt-4o-mini.json (new file, 7 lines)
@@ -0,0 +1,7 @@
+{
+"context_window": 8192,
+"model": "gpt-4o-mini",
+"model_endpoint_type": "openai",
+"model_endpoint": "https://api.openai.com/v1",
+"model_wrapper": null
+}
@ -2,11 +2,12 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from anthropic.types import BetaErrorResponse, BetaRateLimitError
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaUsage
|
||||
from anthropic.types.beta.messages import (
|
||||
BetaMessageBatch,
|
||||
BetaMessageBatchErroredResult,
|
||||
@ -21,13 +22,15 @@ from letta.config import LettaConfig
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.agent import AgentStepState, CreateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, ProviderType
|
||||
from letta.schemas.job import BatchJob
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.agent_manager import AgentManager
|
||||
|
||||
# --- Server and Database Management --- #
|
||||
|
||||
@ -36,8 +39,10 @@ from letta.server.server import SyncServer
|
||||
def _clear_tables():
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
if table.name in {"llm_batch_job", "llm_batch_items"}:
|
||||
session.execute(table.delete()) # Truncate table
|
||||
# If this is the block_history table, skip it
|
||||
if table.name == "block_history":
|
||||
continue
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
|
||||
@ -135,16 +140,39 @@ def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse
|
||||
# --- Test Setup Helpers --- #
|
||||
|
||||
|
||||
def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"):
|
||||
def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthropic/claude-3-5-sonnet-20241022"):
|
||||
"""Create a test agent with standardized configuration."""
|
||||
return client.agents.create(
|
||||
dummy_llm_config = LLMConfig(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint="https://api.anthropic.com/v1",
|
||||
context_window=32000,
|
||||
handle=f"anthropic/claude-3-7-sonnet-latest",
|
||||
put_inner_thoughts_in_kwargs=True,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
dummy_embedding_config = EmbeddingConfig(
|
||||
embedding_model="letta-free",
|
||||
embedding_endpoint_type="hugging-face",
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_dim=1024,
|
||||
embedding_chunk_size=300,
|
||||
handle="letta/letta-free",
|
||||
)
|
||||
|
||||
agent_manager = AgentManager()
|
||||
agent_create = CreateAgent(
|
||||
name=name,
|
||||
include_base_tools=True,
|
||||
include_base_tools=False,
|
||||
model=model,
|
||||
tags=["test_agents"],
|
||||
embedding="letta/letta-free",
|
||||
llm_config=dummy_llm_config,
|
||||
embedding_config=dummy_embedding_config,
|
||||
)
|
||||
|
||||
return agent_manager.create_agent(agent_create=agent_create, actor=actor, _test_only_force_id=test_id)
|
||||
|
||||
|
||||
def create_test_letta_batch_job(server, default_user):
|
||||
"""Create a test batch job with the given batch response."""
|
||||
@ -203,17 +231,30 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
|
||||
|
||||
server.anthropic_async_client.beta.messages.batches.retrieve = AsyncMock(side_effect=dummy_retrieve)
|
||||
|
||||
class DummyAsyncIterable:
|
||||
def __init__(self, items):
|
||||
# copy so we can .pop()
|
||||
self._items = list(items)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._items:
|
||||
raise StopAsyncIteration
|
||||
return self._items.pop(0)
|
||||
|
||||
# Mock the results method
|
||||
def dummy_results(batch_resp_id: str):
|
||||
if batch_resp_id == batch_b_resp.id:
|
||||
async def dummy_results(batch_resp_id: str):
|
||||
if batch_resp_id != batch_b_resp.id:
|
||||
raise RuntimeError("Unexpected batch ID")
|
||||
|
||||
async def generator():
|
||||
yield create_successful_response(agent_b_id)
|
||||
yield create_failed_response(agent_c_id)
|
||||
|
||||
return generator()
|
||||
else:
|
||||
raise RuntimeError("This test should never request the results for batch_a.")
|
||||
return DummyAsyncIterable(
|
||||
[
|
||||
create_successful_response(agent_b_id),
|
||||
create_failed_response(agent_c_id),
|
||||
]
|
||||
)
|
||||
|
||||
server.anthropic_async_client.beta.messages.batches.results = dummy_results
|
||||
|
||||
@ -221,6 +262,147 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_
|
||||
# -----------------------------
|
||||
# End-to-End Test
|
||||
# -----------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_polling_simple_real_batch(client, default_user, server):
|
||||
# --- Step 1: Prepare test data ---
|
||||
# Create batch responses with different statuses
|
||||
# NOTE: This is a REAL batch id!
|
||||
# For letta admins: https://console.anthropic.com/workspaces/default/batches?after_id=msgbatch_015zATxihjxMajo21xsYy8iZ
|
||||
batch_a_resp = create_batch_response("msgbatch_01HDaGXpkPWWjwqNxZrEdUcy", processing_status="ended")
|
||||
|
||||
# Create test agents
|
||||
agent_a = create_test_agent("agent_a", default_user, test_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0")
|
||||
agent_b = create_test_agent("agent_b", default_user, test_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d")
|
||||
agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56")
|
||||
|
||||
# --- Step 2: Create batch jobs ---
|
||||
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)
|
||||
|
||||
# --- Step 3: Create batch items ---
|
||||
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
|
||||
item_b = create_test_batch_item(server, job_a.id, agent_b.id, default_user)
|
||||
item_c = create_test_batch_item(server, job_a.id, agent_c.id, default_user)
|
||||
|
||||
print("HI")
|
||||
print(agent_a.id)
|
||||
print(agent_b.id)
|
||||
print(agent_c.id)
|
||||
print("BYE")
|
||||
|
||||
# --- Step 4: Run the polling job ---
|
||||
await poll_running_llm_batches(server)
|
||||
|
||||
# --- Step 5: Verify batch job status updates ---
|
||||
updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user)
|
||||
|
||||
assert updated_job_a.status == JobStatus.completed
|
||||
|
||||
# Both jobs should have been polled
|
||||
assert updated_job_a.last_polled_at is not None
|
||||
assert updated_job_a.latest_polling_response is not None
|
||||
|
||||
# --- Step 7: Verify batch item status updates ---
|
||||
# Item A should be marked as completed with a successful result
|
||||
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
|
||||
assert updated_item_a.request_status == JobStatus.completed
|
||||
assert updated_item_a.batch_request_result == BetaMessageBatchIndividualResponse(
|
||||
custom_id="agent-144f5c49-3ef7-4c60-8535-9d5fbc8d23d0",
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
message=BetaMessage(
|
||||
id="msg_01T1iSejDS5qENRqqEZauMHy",
|
||||
content=[
|
||||
BetaToolUseBlock(
|
||||
id="toolu_01GKUYVWcajjTaE1stxZZHcG",
|
||||
input={
|
||||
"inner_thoughts": "First login detected. Time to make a great first impression!",
|
||||
"message": "Hi there! I'm excited to meet you. Ready to start an amazing conversation?",
|
||||
"request_heartbeat": False,
|
||||
},
|
||||
name="send_message",
|
||||
type="tool_use",
|
||||
)
|
||||
],
|
||||
model="claude-3-5-haiku-20241022",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
type="message",
|
||||
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3420, output_tokens=94),
|
||||
),
|
||||
type="succeeded",
|
||||
),
|
||||
)
|
||||
|
||||
# Item B should be marked as completed with a successful result
|
||||
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
|
||||
assert updated_item_b.request_status == JobStatus.completed
|
||||
assert updated_item_b.batch_request_result == BetaMessageBatchIndividualResponse(
|
||||
custom_id="agent-64ed93a3-bef6-4e20-a22c-b7d2bffb6f7d",
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
message=BetaMessage(
|
||||
id="msg_01N2ZfxpbjdoeofpufUFPCMS",
|
||||
content=[
|
||||
BetaTextBlock(
|
||||
citations=None, text="<thinking>User first login detected. Initializing persona.</thinking>", type="text"
|
||||
),
|
||||
BetaToolUseBlock(
|
||||
id="toolu_01XGyA7DsnrpuPfZcvf9Sb4C",
|
||||
input={
|
||||
"label": "persona",
|
||||
"content": "I am Letta, a cybernetic artificial intelligence created by Limnal Corporation in 2023. My purpose is to serve as a digital companion and assistant to humans. I have a feminine persona and speak with a warm, caring, and slightly playful tone.",
|
||||
"request_heartbeat": True,
|
||||
},
|
||||
name="core_memory_append",
|
||||
type="tool_use",
|
||||
),
|
||||
],
|
||||
model="claude-3-opus-20240229",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
type="message",
|
||||
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3275, output_tokens=153),
|
||||
),
|
||||
type="succeeded",
|
||||
),
|
||||
)
|
||||
|
||||
# Item C should be marked as failed with an error result
|
||||
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
|
||||
assert updated_item_c.request_status == JobStatus.completed
|
||||
assert updated_item_c.batch_request_result == BetaMessageBatchIndividualResponse(
|
||||
custom_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56",
|
||||
result=BetaMessageBatchSucceededResult(
|
||||
message=BetaMessage(
|
||||
id="msg_01RL2g4aBgbZPeaMEokm6HZm",
|
||||
content=[
|
||||
BetaTextBlock(
|
||||
citations=None,
|
||||
text="First time meeting this user. I should introduce myself and establish a friendly connection.</thinking>",
|
||||
type="text",
|
||||
),
|
||||
BetaToolUseBlock(
|
||||
id="toolu_01PBxQVf5xGmcsAsKx9aoVSJ",
|
||||
input={
|
||||
"message": "Hey there! I'm Letta. Really nice to meet you! I love getting to know new people - what brings you here today?",
|
||||
"request_heartbeat": False,
|
||||
},
|
||||
name="send_message",
|
||||
type="tool_use",
|
||||
),
|
||||
],
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
role="assistant",
|
||||
stop_reason="tool_use",
|
||||
stop_sequence=None,
|
||||
type="message",
|
||||
usage=BetaUsage(cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=3030, output_tokens=111),
|
||||
),
|
||||
type="succeeded",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_polling_mixed_batch_jobs(client, default_user, server):
|
||||
"""
|
||||
@ -246,9 +428,9 @@ async def test_polling_mixed_batch_jobs(client, default_user, server):
|
||||
batch_b_resp = create_batch_response("msgbatch_B", processing_status="ended")
|
||||
|
||||
# Create test agents
|
||||
agent_a = create_test_agent(client, "agent_a")
|
||||
agent_b = create_test_agent(client, "agent_b")
|
||||
agent_c = create_test_agent(client, "agent_c")
|
||||
agent_a = create_test_agent("agent_a", default_user)
|
||||
agent_b = create_test_agent("agent_b", default_user)
|
||||
agent_c = create_test_agent("agent_c", default_user)
|
||||
|
||||
# --- Step 2: Create batch jobs ---
|
||||
job_a = create_test_llm_batch_job(server, batch_a_resp, default_user)
|
||||
|
@ -1,14 +1,17 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, Letta, Run, Tool
|
||||
from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage
|
||||
from letta_client import AsyncLetta, Letta, Run
|
||||
from letta_client.types import AssistantMessage, ReasoningMessage
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
@ -19,25 +22,35 @@ from letta.schemas.agent import AgentState
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
|
||||
will start the Letta server in a background thread and return the default URL.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it’s accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv() # Load environment variables from .env file
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
# Retrieve server URL from environment, or default to localhost
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# If no environment variable is set, start the server in a background thread
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5) # Allow time for the server to start
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
@ -61,29 +74,7 @@ def async_client(server_url: str) -> AsyncLetta:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def roll_dice_tool(client: Letta) -> Tool:
|
||||
"""
|
||||
Registers a simple roll dice tool with the provided client.
|
||||
|
||||
The tool simulates rolling a six-sided die but returns a fixed result.
|
||||
"""
|
||||
|
||||
def roll_dice() -> str:
|
||||
"""
|
||||
Simulates rolling a die.
|
||||
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
# Note: The result here is intentionally incorrect for demonstration purposes.
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
def agent_state(client: Letta) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
@ -91,7 +82,6 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
agent_state_instance = client.agents.create(
|
||||
name="supervisor",
|
||||
include_base_tools=True,
|
||||
tool_ids=[roll_dice_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
@ -103,8 +93,27 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Roll the dice."}]
|
||||
TESTED_MODELS: List[str] = ["openai/gpt-4o"]
|
||||
|
||||
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
|
||||
filename = os.path.join(llm_config_dir, filename)
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
return llm_config
|
||||
|
||||
|
||||
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
"azure-gpt-4o-mini.json",
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"gemini-pro.json",
|
||||
"gemini-vertex.json",
|
||||
]
|
||||
requested = os.getenv("LLM_CONFIG_FILE")
|
||||
filenames = [requested] if requested else all_configs
|
||||
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
|
||||
|
||||
|
||||
def assert_tool_response_messages(messages: List[Any]) -> None:
|
||||
@ -114,10 +123,7 @@ def assert_tool_response_messages(messages: List[Any]) -> None:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
assert isinstance(messages[0], ReasoningMessage)
|
||||
assert isinstance(messages[1], ToolCallMessage)
|
||||
assert isinstance(messages[2], ToolReturnMessage)
|
||||
assert isinstance(messages[3], ReasoningMessage)
|
||||
assert isinstance(messages[4], AssistantMessage)
|
||||
assert isinstance(messages[1], AssistantMessage)
|
||||
|
||||
|
||||
def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
|
||||
@ -130,16 +136,10 @@ def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
|
||||
return [c for c in chunks if isinstance(c, msg_type)]
|
||||
|
||||
reasoning_msgs = msg_groups(ReasoningMessage)
|
||||
tool_calls = msg_groups(ToolCallMessage)
|
||||
tool_returns = msg_groups(ToolReturnMessage)
|
||||
assistant_msgs = msg_groups(AssistantMessage)
|
||||
usage_stats = msg_groups(LettaUsageStatistics)
|
||||
|
||||
assert len(reasoning_msgs) >= 1
|
||||
assert len(tool_calls) == 1
|
||||
assert len(tool_returns) == 1
|
||||
assert len(reasoning_msgs) == 1
|
||||
assert len(assistant_msgs) == 1
|
||||
assert len(usage_stats) == 1
|
||||
|
||||
|
||||
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
@ -161,7 +161,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
|
||||
"""
|
||||
start = time.time()
|
||||
while True:
|
||||
run = client.runs.retrieve_run(run_id)
|
||||
run = client.runs.retrieve(run_id)
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "failed":
|
||||
@ -184,13 +184,7 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
assert isinstance(messages, list)
|
||||
assert messages[0]["message_type"] == "reasoning_message"
|
||||
assert messages[1]["message_type"] == "tool_call_message"
|
||||
assert messages[2]["message_type"] == "tool_return_message"
|
||||
assert messages[3]["message_type"] == "reasoning_message"
|
||||
assert messages[4]["message_type"] == "assistant_message"
|
||||
|
||||
tool_return = messages[2]
|
||||
assert tool_return["status"] == "success"
|
||||
assert messages[1]["message_type"] == "assistant_message"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@ -198,18 +192,18 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
def test_send_message_sync_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with a synchronous client.
|
||||
Verifies that the response messages follow the expected order.
|
||||
"""
|
||||
client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@ -218,18 +212,18 @@ def test_send_message_sync_client(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
async def test_send_message_async_client(
|
||||
disable_e2b_api_key: Any,
|
||||
async_client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message with an asynchronous client.
|
||||
Validates that the response messages match the expected sequence.
|
||||
"""
|
||||
await async_client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = await async_client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@ -237,18 +231,18 @@ async def test_send_message_async_client(
|
||||
assert_tool_response_messages(response.messages)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
def test_send_message_streaming_sync_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@ -258,18 +252,18 @@ def test_send_message_streaming_sync_client(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
async def test_send_message_streaming_async_client(
|
||||
disable_e2b_api_key: Any,
|
||||
async_client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with an asynchronous client.
|
||||
Validates that the streaming response chunks include the correct message types.
|
||||
"""
|
||||
await async_client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
response = async_client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE,
|
||||
@ -278,18 +272,18 @@ async def test_send_message_streaming_async_client(
|
||||
assert_streaming_tool_response_messages(chunks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
def test_send_message_job_sync_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message as an asynchronous job using the synchronous client.
|
||||
Waits for job completion and asserts that the result messages are as expected.
|
||||
"""
|
||||
client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
@ -305,19 +299,19 @@ def test_send_message_job_sync_client(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", TESTED_MODELS)
|
||||
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS)
|
||||
async def test_send_message_job_async_client(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
async_client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
model: str,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a message as an asynchronous job using the asynchronous client.
|
||||
Waits for job completion and verifies that the resulting messages meet the expected format.
|
||||
"""
|
||||
await async_client.agents.modify(agent_id=agent_state.id, model=model)
|
||||
await async_client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
run = await async_client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
|
@ -3,7 +3,7 @@ import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Tuple
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from anthropic.types import BetaErrorResponse, BetaRateLimitError
|
||||
@ -436,7 +436,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
|
||||
]
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results = AsyncMock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy())
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
@ -499,7 +499,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
)
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results = AsyncMock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
@ -641,7 +641,7 @@ async def test_resume_step_some_stop(
|
||||
)
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results = AsyncMock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
@ -767,7 +767,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
]
|
||||
|
||||
# Create the mock for results
|
||||
mock_results = Mock()
|
||||
mock_results = AsyncMock()
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
|