fix: patch o-series (#1699)

Repository: https://github.com/cpacker/MemGPT.git (mirror)
Commit: 69ba0607e4 (parent: 801188c0e7)
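
In brief: instead of scattering startswith("o1") / startswith("o3") checks, the o-series ("reasoning") models are now detected by a single is_openai_reasoning_model() helper, and that check gates request parameters these models do not accept: temperature is set to None instead of a configured value, parallel_tool_calls is popped from the payload, put_inner_thoughts_in_kwargs defaults to False, and responses are marked with an omitted-reasoning content part. A minimal sketch of the resulting gating pattern, using a hypothetical build_request_kwargs() wrapper (the real call sites are in the hunks below):

    from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param

    def build_request_kwargs(llm_config) -> dict:
        """Sketch only: shows how the new helpers gate o-series request parameters."""
        model = llm_config.model
        kwargs = {
            "model": model,
            # Reasoning models reject an explicit temperature, so pass None to omit it
            "temperature": llm_config.temperature if supports_temperature_param(model) else None,
        }
        # Reasoning models also don't support parallel tool calls
        if not supports_parallel_tool_calling(model):
            kwargs.pop("parallel_tool_calls", None)
        return kwargs
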
@@ -10,7 +10,7 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option
     api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
     if not api_keys:
         if logger:
-            logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
+            logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...")
         if tool_settings.composio_api_key:
            return tool_settings.composio_api_key
     else:
@@ -5,6 +5,7 @@ import requests
 from openai import OpenAI
 
 from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
+from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.log import get_logger
@@ -135,7 +136,7 @@ def build_openai_chat_completions_request(
             tool_choice=tool_choice,
             user=str(user_id),
             max_completion_tokens=llm_config.max_tokens,
-            temperature=1.0 if llm_config.enable_reasoner else llm_config.temperature,
+            temperature=llm_config.temperature if supports_temperature_param(model) else None,
             reasoning_effort=llm_config.reasoning_effort,
         )
     else:
@@ -489,6 +490,7 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
         # except ValueError as e:
         #     warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
 
-    if "o3-mini" in chat_completion_request.model or "o1" in chat_completion_request.model:
+    if not supports_parallel_tool_calling(chat_completion_request.model):
         data.pop("parallel_tool_calls", None)
 
     return data
@@ -34,6 +34,33 @@ from letta.settings import model_settings
 logger = get_logger(__name__)
 
 
+def is_openai_reasoning_model(model: str) -> bool:
+    """Utility function to check if the model is a 'reasoner'"""
+
+    # NOTE: needs to be updated with new model releases
+    return model.startswith("o1") or model.startswith("o3")
+
+
+def supports_temperature_param(model: str) -> bool:
+    """Certain OpenAI models don't support configuring the temperature.
+
+    Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
+    """
+    if is_openai_reasoning_model(model):
+        return False
+    else:
+        return True
+
+
+def supports_parallel_tool_calling(model: str) -> bool:
+    """Certain OpenAI models don't support parallel tool calls."""
+
+    if is_openai_reasoning_model(model):
+        return False
+    else:
+        return True
+
+
 class OpenAIClient(LLMClientBase):
     def _prepare_client_kwargs(self) -> dict:
         api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
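
For reference, the new helpers compose like this once imported from letta.llm_api.openai_client (the model names below are illustrative examples, not an exhaustive list):

    assert is_openai_reasoning_model("o1-2024-12-17") is True
    assert supports_temperature_param("o3-mini") is False         # temperature must be omitted
    assert supports_parallel_tool_calling("o3-mini") is False     # parallel_tool_calls gets popped
    assert supports_parallel_tool_calling("gpt-4o-mini") is True  # non-reasoning models are unaffected
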
@@ -66,7 +93,8 @@ class OpenAIClient(LLMClientBase):
             put_inner_thoughts_first=True,
         )
 
-        use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3")  # o-series models
+        use_developer_message = is_openai_reasoning_model(llm_config.model)
+
         openai_message_list = [
             cast_message_to_subtype(
                 m.to_openai_dict(
@@ -103,7 +131,7 @@ class OpenAIClient(LLMClientBase):
             tool_choice=tool_choice,
             user=str(),
             max_completion_tokens=llm_config.max_tokens,
-            temperature=llm_config.temperature,
+            temperature=llm_config.temperature if supports_temperature_param(model) else None,
         )
 
         if "inference.memgpt.ai" in llm_config.model_endpoint:
@@ -160,6 +188,10 @@ class OpenAIClient(LLMClientBase):
                 response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
             )
 
+        # If we used a reasoning model, create a content part for the omitted reasoning
+        if is_openai_reasoning_model(self.llm_config.model):
+            chat_completion_response.choices[0].message.ommitted_reasoning_content = True
+
         return chat_completion_response
 
     def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
@@ -145,7 +145,8 @@ class OmittedReasoningContent(MessageContent):
     type: Literal[MessageContentType.omitted_reasoning] = Field(
         MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
     )
-    tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
+    # NOTE: dropping because we don't track this kind of information for the other reasoning types
+    # tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
 
 
 LettaMessageContentUnion = Annotated[
@@ -81,8 +81,11 @@ class LLMConfig(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def set_default_enable_reasoner(cls, values):
-        if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
-            values["enable_reasoner"] = True
+        # NOTE: this is really only applicable for models that can toggle reasoning on-and-off, like 3.7
+        # We can also use this field to identify if a model is a "reasoning" model (o1/o3, etc.) if we want
+        # if any(openai_reasoner_model in values.get("model", "") for openai_reasoner_model in ["o3-mini", "o1"]):
+        #     values["enable_reasoner"] = True
+        #     values["put_inner_thoughts_in_kwargs"] = False
         return values
 
     @model_validator(mode="before")
@@ -100,6 +103,13 @@ class LLMConfig(BaseModel):
         if values.get("put_inner_thoughts_in_kwargs") is None:
             values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
 
+        # For the o1/o3 series from OpenAI, set to False by default
+        # We can set this flag to `true` if desired, which will enable "double-think"
+        from letta.llm_api.openai_client import is_openai_reasoning_model
+
+        if is_openai_reasoning_model(model):
+            values["put_inner_thoughts_in_kwargs"] = False
+
         return values
 
     @model_validator(mode="after")
@@ -31,6 +31,7 @@ from letta.schemas.letta_message import (
 )
 from letta.schemas.letta_message_content import (
     LettaMessageContentUnion,
+    OmittedReasoningContent,
     ReasoningContent,
     RedactedReasoningContent,
     TextContent,
@@ -295,6 +296,18 @@ class Message(BaseMessage):
                         sender_id=self.sender_id,
                     )
                 )
+            elif isinstance(content_part, OmittedReasoningContent):
+                # Special case for "hidden reasoning" models like o1/o3
+                # NOTE: we also have to think about how to return this during streaming
+                messages.append(
+                    HiddenReasoningMessage(
+                        id=self.id,
+                        date=self.created_at,
+                        state="omitted",
+                        name=self.name,
+                        otid=otid,
+                    )
+                )
             else:
                 warnings.warn(f"Unrecognized content part in assistant message: {content_part}")
 
@@ -464,6 +477,10 @@ class Message(BaseMessage):
                         data=openai_message_dict["redacted_reasoning_content"] if "redacted_reasoning_content" in openai_message_dict else None,
                     ),
                 )
+            if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]:
+                content.append(
+                    OmittedReasoningContent(),
+                )
 
             # If we're going from deprecated function form
             if openai_message_dict["role"] == "function":
@@ -39,9 +39,10 @@ class Message(BaseModel):
     tool_calls: Optional[List[ToolCall]] = None
     role: str
     function_call: Optional[FunctionCall] = None  # Deprecated
-    reasoning_content: Optional[str] = None  # Used in newer reasoning APIs
+    reasoning_content: Optional[str] = None  # Used in newer reasoning APIs, e.g. DeepSeek
     reasoning_content_signature: Optional[str] = None  # NOTE: for Anthropic
     redacted_reasoning_content: Optional[str] = None  # NOTE: for Anthropic
+    ommitted_reasoning_content: bool = False  # NOTE: for OpenAI o1/o3
 
 
 class Choice(BaseModel):
@@ -52,16 +53,64 @@
     seed: Optional[int] = None  # found in TogetherAI
 
 
+class UsageStatisticsPromptTokenDetails(BaseModel):
+    cached_tokens: int = 0
+    # NOTE: OAI specific
+    # audio_tokens: int = 0
+
+    def __add__(self, other: "UsageStatisticsPromptTokenDetails") -> "UsageStatisticsPromptTokenDetails":
+        return UsageStatisticsPromptTokenDetails(
+            cached_tokens=self.cached_tokens + other.cached_tokens,
+        )
+
+
+class UsageStatisticsCompletionTokenDetails(BaseModel):
+    reasoning_tokens: int = 0
+    # NOTE: OAI specific
+    # audio_tokens: int = 0
+    # accepted_prediction_tokens: int = 0
+    # rejected_prediction_tokens: int = 0
+
+    def __add__(self, other: "UsageStatisticsCompletionTokenDetails") -> "UsageStatisticsCompletionTokenDetails":
+        return UsageStatisticsCompletionTokenDetails(
+            reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
+        )
+
+
 class UsageStatistics(BaseModel):
     completion_tokens: int = 0
     prompt_tokens: int = 0
     total_tokens: int = 0
 
+    prompt_tokens_details: Optional[UsageStatisticsPromptTokenDetails] = None
+    completion_tokens_details: Optional[UsageStatisticsCompletionTokenDetails] = None
+
     def __add__(self, other: "UsageStatistics") -> "UsageStatistics":
+
+        if self.prompt_tokens_details is None and other.prompt_tokens_details is None:
+            total_prompt_tokens_details = None
+        elif self.prompt_tokens_details is None:
+            total_prompt_tokens_details = other.prompt_tokens_details
+        elif other.prompt_tokens_details is None:
+            total_prompt_tokens_details = self.prompt_tokens_details
+        else:
+            total_prompt_tokens_details = self.prompt_tokens_details + other.prompt_tokens_details
+
+        if self.completion_tokens_details is None and other.completion_tokens_details is None:
+            total_completion_tokens_details = None
+        elif self.completion_tokens_details is None:
+            total_completion_tokens_details = other.completion_tokens_details
+        elif other.completion_tokens_details is None:
+            total_completion_tokens_details = self.completion_tokens_details
+        else:
+            total_completion_tokens_details = self.completion_tokens_details + other.completion_tokens_details
+
         return UsageStatistics(
             completion_tokens=self.completion_tokens + other.completion_tokens,
             prompt_tokens=self.prompt_tokens + other.prompt_tokens,
             total_tokens=self.total_tokens + other.total_tokens,
+            prompt_tokens_details=total_prompt_tokens_details,
+            completion_tokens_details=total_completion_tokens_details,
         )
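
For reference, the new token-detail fields aggregate as follows (a usage sketch with made-up numbers; a details object present on only one side is carried over unchanged):

    turn_1 = UsageStatistics(
        prompt_tokens=100,
        completion_tokens=40,
        total_tokens=140,
        completion_tokens_details=UsageStatisticsCompletionTokenDetails(reasoning_tokens=25),
    )
    turn_2 = UsageStatistics(prompt_tokens=80, completion_tokens=10, total_tokens=90)

    combined = turn_1 + turn_2
    assert combined.total_tokens == 230
    assert combined.completion_tokens_details.reasoning_tokens == 25
    assert combined.prompt_tokens_details is None  # neither side reported prompt-token details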