fix: patch o-series (#1699)

Author: Charles Packer, 2025-04-23 13:41:34 -07:00 (committed by GitHub)
parent 801188c0e7
commit 69ba0607e4
7 changed files with 120 additions and 9 deletions

View File

@@ -10,7 +10,7 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
if not api_keys:
if logger:
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...")
if tool_settings.composio_api_key:
return tool_settings.composio_api_key
else:

View File

@@ -5,6 +5,7 @@ import requests
from openai import OpenAI
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
@@ -135,7 +136,7 @@ def build_openai_chat_completions_request(
tool_choice=tool_choice,
user=str(user_id),
max_completion_tokens=llm_config.max_tokens,
temperature=1.0 if llm_config.enable_reasoner else llm_config.temperature,
temperature=llm_config.temperature if supports_temperature_param(model) else None,
reasoning_effort=llm_config.reasoning_effort,
)
else:
@@ -489,6 +490,7 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
# except ValueError as e:
# warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
if "o3-mini" in chat_completion_request.model or "o1" in chat_completion_request.model:
if not supports_parallel_tool_calling(chat_completion_request.model):
data.pop("parallel_tool_calls", None)
return data
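
The net effect of these two changes is that request fields the o-series endpoints reject are gated on shared helpers instead of hard-coded model-name checks. A minimal sketch of that gating, assuming the helper definitions added below; the data dict stands in for the serialized request and is illustrative only:

from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param

model = "o3-mini"
data = {"model": model, "temperature": 0.7, "parallel_tool_calls": False}

# o-series models reject 'temperature', so the builder now sends None instead of a value
if not supports_temperature_param(model):
    data["temperature"] = None

# o-series models also reject 'parallel_tool_calls', so it is dropped from the payload
if not supports_parallel_tool_calling(model):
    data.pop("parallel_tool_calls", None)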

View File

@@ -34,6 +34,33 @@ from letta.settings import model_settings
logger = get_logger(__name__)
def is_openai_reasoning_model(model: str) -> bool:
"""Utility function to check if the model is a 'reasoner'"""
# NOTE: needs to be updated with new model releases
return model.startswith("o1") or model.startswith("o3")
def supports_temperature_param(model: str) -> bool:
"""Certain OpenAI models don't support configuring the temperature.
Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
"""
if is_openai_reasoning_model(model):
return False
else:
return True
def supports_parallel_tool_calling(model: str) -> bool:
"""Certain OpenAI models don't support parallel tool calls."""
if is_openai_reasoning_model(model):
return False
else:
return True
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self) -> dict:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -66,7 +93,8 @@ class OpenAIClient(LLMClientBase):
put_inner_thoughts_first=True,
)
use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3") # o-series models
use_developer_message = is_openai_reasoning_model(llm_config.model)
openai_message_list = [
cast_message_to_subtype(
m.to_openai_dict(
@@ -103,7 +131,7 @@ class OpenAIClient(LLMClientBase):
tool_choice=tool_choice,
user=str(),
max_completion_tokens=llm_config.max_tokens,
temperature=llm_config.temperature,
temperature=llm_config.temperature if supports_temperature_param(model) else None,
)
if "inference.memgpt.ai" in llm_config.model_endpoint:
@@ -160,6 +188,10 @@ class OpenAIClient(LLMClientBase):
response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
)
# If we used a reasoning model, create a content part for the omitted reasoning
if is_openai_reasoning_model(self.llm_config.model):
chat_completion_response.choices[0].message.ommitted_reasoning_content = True
return chat_completion_response
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
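
Taken together, the new helpers classify models as follows; a quick sanity check based on the definitions above (import path as in this diff):

from letta.llm_api.openai_client import (
    is_openai_reasoning_model,
    supports_parallel_tool_calling,
    supports_temperature_param,
)

# o1/o3 prefixes are treated as reasoning models...
assert is_openai_reasoning_model("o1-mini") and is_openai_reasoning_model("o3-mini")
assert not is_openai_reasoning_model("gpt-4o-mini")

# ...and reasoning models get neither a temperature setting nor parallel tool calls
assert not supports_temperature_param("o1") and not supports_parallel_tool_calling("o3")
assert supports_temperature_param("gpt-4o") and supports_parallel_tool_calling("gpt-4o")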

View File

@@ -145,7 +145,8 @@ class OmittedReasoningContent(MessageContent):
type: Literal[MessageContentType.omitted_reasoning] = Field(
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
)
tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
# NOTE: dropping because we don't track this kind of information for the other reasoning types
# tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
LettaMessageContentUnion = Annotated[

View File

@@ -81,8 +81,11 @@ class LLMConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def set_default_enable_reasoner(cls, values):
if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
values["enable_reasoner"] = True
# NOTE: this is really only applicable for models that can toggle reasoning on-and-off, like 3.7
# We can also use this field to identify if a model is a "reasoning" model (o1/o3, etc.) if we want
# if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
# values["enable_reasoner"] = True
# values["put_inner_thoughts_in_kwargs"] = False
return values
@model_validator(mode="before")
@@ -100,6 +103,13 @@
if values.get("put_inner_thoughts_in_kwargs") is None:
values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
# For the o1/o3 series from OpenAI, set to False by default
# We can set this flag to `true` if desired, which will enable "double-think"
from letta.llm_api.openai_client import is_openai_reasoning_model
if is_openai_reasoning_model(model):
values["put_inner_thoughts_in_kwargs"] = False
return values
@model_validator(mode="after")
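
The practical effect is that o-series configs no longer default to injecting inner thoughts through a tool-call kwarg. A rough sketch; every constructor field other than model and the checked flag is an assumption here, not taken from this diff:

from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="o3-mini",
    model_endpoint_type="openai",                 # assumed required field
    model_endpoint="https://api.openai.com/v1",   # assumed required field
    context_window=200000,                        # assumed required field
)

# With the validator above, reasoning models skip the inner-thoughts kwarg by default
assert config.put_inner_thoughts_in_kwargs is False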

View File

@@ -31,6 +31,7 @@ from letta.schemas.letta_message import (
)
from letta.schemas.letta_message_content import (
LettaMessageContentUnion,
OmittedReasoningContent,
ReasoningContent,
RedactedReasoningContent,
TextContent,
@@ -295,6 +296,18 @@ class Message(BaseMessage):
sender_id=self.sender_id,
)
)
elif isinstance(content_part, OmittedReasoningContent):
# Special case for "hidden reasoning" models like o1/o3
# NOTE: we also have to think about how to return this during streaming
messages.append(
HiddenReasoningMessage(
id=self.id,
date=self.created_at,
state="omitted",
name=self.name,
otid=otid,
)
)
else:
warnings.warn(f"Unrecognized content part in assistant message: {content_part}")
@@ -464,6 +477,10 @@
data=openai_message_dict["redacted_reasoning_content"] if "redacted_reasoning_content" in openai_message_dict else None,
),
)
if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]:
content.append(
OmittedReasoningContent(),
)
# If we're going from deprecated function form
if openai_message_dict["role"] == "function":
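
In short, an omitted_reasoning_content flag on the provider response becomes an OmittedReasoningContent part on the stored message, which then surfaces to clients as a HiddenReasoningMessage with state="omitted". A rough sketch of that path (constructor fields beyond those visible in this diff are assumptions):

from datetime import datetime, timezone

from letta.schemas.letta_message import HiddenReasoningMessage
from letta.schemas.letta_message_content import OmittedReasoningContent

# 1) the provider response flags hidden reasoning -> a content part with no payload
openai_message_dict = {"role": "assistant", "content": "Hi!", "omitted_reasoning_content": True}
content = []
if openai_message_dict.get("omitted_reasoning_content"):
    content.append(OmittedReasoningContent())

# 2) during conversion to Letta messages, that part maps to a hidden-reasoning message
hidden = HiddenReasoningMessage(
    id="message-123",                   # placeholder id
    date=datetime.now(timezone.utc),    # stands in for Message.created_at
    state="omitted",
)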

View File

@@ -39,9 +39,10 @@ class Message(BaseModel):
tool_calls: Optional[List[ToolCall]] = None
role: str
function_call: Optional[FunctionCall] = None # Deprecated
reasoning_content: Optional[str] = None # Used in newer reasoning APIs
reasoning_content: Optional[str] = None # Used in newer reasoning APIs, e.g. DeepSeek
reasoning_content_signature: Optional[str] = None # NOTE: for Anthropic
redacted_reasoning_content: Optional[str] = None # NOTE: for Anthropic
ommitted_reasoning_content: bool = False # NOTE: for OpenAI o1/o3
class Choice(BaseModel):
@@ -52,16 +53,64 @@ class Choice(BaseModel):
seed: Optional[int] = None # found in TogetherAI
class UsageStatisticsPromptTokenDetails(BaseModel):
cached_tokens: int = 0
# NOTE: OAI specific
# audio_tokens: int = 0
def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails":
return UsageStatisticsPromptTokenDetails(
cached_tokens=self.cached_tokens + other.cached_tokens,
)
class UsageStatisticsCompletionTokenDetails(BaseModel):
reasoning_tokens: int = 0
# NOTE: OAI specific
# audio_tokens: int = 0
# accepted_prediction_tokens: int = 0
# rejected_prediction_tokens: int = 0
def __add__(self, other: "UsageStatisticsCompletionTokenDetails") -> "UsageStatisticsCompletionTokenDetails":
return UsageStatisticsCompletionTokenDetails(
reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
)
class UsageStatistics(BaseModel):
completion_tokens: int = 0
prompt_tokens: int = 0
total_tokens: int = 0
prompt_tokens_details: Optional[UsageStatisticsPromptTokenDetails] = None
completion_tokens_details: Optional[UsageStatisticsCompletionTokenDetails] = None
def __add__(self, other: "UsageStatistics") -> "UsageStatistics":
if self.prompt_tokens_details is None and other.prompt_tokens_details is None:
total_prompt_tokens_details = None
elif self.prompt_tokens_details is None:
total_prompt_tokens_details = other.prompt_tokens_details
elif other.prompt_tokens_details is None:
total_prompt_tokens_details = self.prompt_tokens_details
else:
total_prompt_tokens_details = self.prompt_tokens_details + other.prompt_tokens_details
if self.completion_tokens_details is None and other.completion_tokens_details is None:
total_completion_tokens_details = None
elif self.completion_tokens_details is None:
total_completion_tokens_details = other.completion_tokens_details
elif other.completion_tokens_details is None:
total_completion_tokens_details = self.completion_tokens_details
else:
total_completion_tokens_details = self.completion_tokens_details + other.completion_tokens_details
return UsageStatistics(
completion_tokens=self.completion_tokens + other.completion_tokens,
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
total_tokens=self.total_tokens + other.total_tokens,
prompt_tokens_details=total_prompt_tokens_details,
completion_tokens_details=total_completion_tokens_details,
)
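
For reference, the None-safe merge above behaves like this when aggregating usage across calls (values are illustrative):

a = UsageStatistics(
    completion_tokens=100,
    prompt_tokens=400,
    total_tokens=500,
    completion_tokens_details=UsageStatisticsCompletionTokenDetails(reasoning_tokens=80),
)
b = UsageStatistics(completion_tokens=50, prompt_tokens=200, total_tokens=250)  # no details reported

merged = a + b
assert merged.completion_tokens == 150 and merged.total_tokens == 750
# The side that reported details is kept when the other side has None
assert merged.completion_tokens_details.reasoning_tokens == 80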