MemGPT/letta/llm_api/openai_client.py

import os
from typing import List, Optional

import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.constants import LETTA_MODEL_ENDPOINT
from letta.errors import (
    ErrorCode,
    LLMAuthenticationError,
    LLMBadRequestError,
    LLMConnectionError,
    LLMNotFoundError,
    LLMPermissionDeniedError,
    LLMRateLimitError,
    LLMServerError,
    LLMUnprocessableEntityError,
)
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
from letta.schemas.openai.chat_completion_request import FunctionSchema
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings
from letta.tracing import trace_method

logger = get_logger(__name__)


def is_openai_reasoning_model(model: str) -> bool:
    """Utility function to check if the model is a 'reasoner'"""
    # NOTE: needs to be updated with new model releases
    is_reasoning = model.startswith("o1") or model.startswith("o3")
    return is_reasoning
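
# Illustrative behavior of the prefix check (follows directly from the definition above):
#   is_openai_reasoning_model("o1-mini")  -> True
#   is_openai_reasoning_model("o3-mini")  -> True
#   is_openai_reasoning_model("gpt-4o")   -> False (prefix is "gpt", despite the trailing "o")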


def accepts_developer_role(model: str) -> bool:
    """Checks if the model accepts the 'developer' role. Note that not all reasoning models accept this role.
    See: https://community.openai.com/t/developer-role-not-accepted-for-o1-o1-mini-o3-mini/1110750/7
    """
    if is_openai_reasoning_model(model):
        return True
    else:
        return False


def supports_temperature_param(model: str) -> bool:
    """Certain OpenAI models don't support configuring the temperature.
    Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
    """
    if is_openai_reasoning_model(model):
        return False
    else:
        return True


def supports_parallel_tool_calling(model: str) -> bool:
    """Certain OpenAI models don't support parallel tool calls."""
    if is_openai_reasoning_model(model):
        return False
    else:
        return True


# TODO move into LLMConfig as a field?
def supports_structured_output(llm_config: LLMConfig) -> bool:
    """Certain providers don't support structured output."""
    # FIXME pretty hacky - turn off for providers we know users will use,
    # but also don't support structured output
    if "nebius.com" in llm_config.model_endpoint:
        return False
    else:
        return True


# TODO move into LLMConfig as a field?
def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
    """Certain providers require the tool choice to be set to 'auto'."""
    if "nebius.com" in llm_config.model_endpoint:
        return True
    if "together.ai" in llm_config.model_endpoint or "together.xyz" in llm_config.model_endpoint:
        return True
    # proxy also has this issue (FIXME check)
    elif llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
        return True
    # same with vLLM (FIXME check)
    elif llm_config.handle and "vllm" in llm_config.handle:
        return True
    else:
        # will use "required" instead of "auto"
        return False
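
# Endpoint-based capability notes (derived from the checks above): Nebius endpoints get
# tool_choice="auto" and skip structured-output conversion; Together, the Letta proxy
# (LETTA_MODEL_ENDPOINT), and vLLM-style handles also fall back to tool_choice="auto",
# while everything else uses "required" when tools are passed (see build_request_data below).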


class OpenAIClient(LLMClientBase):

    def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
        api_key = None
        if llm_config.provider_category == ProviderCategory.byok:
            from letta.services.provider_manager import ProviderManager

            api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
        if llm_config.model_endpoint_type == ProviderType.together:
            api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")

        if not api_key:
            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
            # supposedly the openai python client requires a dummy API key
            api_key = api_key or "DUMMY_API_KEY"
        kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}
        return kwargs
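
    # Key resolution above: a BYOK override (via ProviderManager) when the provider category
    # is BYOK, the Together key from settings/env when the endpoint type is Together, otherwise
    # the OpenAI key from settings/env; "DUMMY_API_KEY" is a last-resort placeholder since the
    # OpenAI client reportedly refuses to be constructed without some api_key (e.g. for
    # keyless local endpoints).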

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,  # Keep as dict for now as per base class
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """
        Constructs a request object in the expected data format for the OpenAI API.
        """
        if tools and llm_config.put_inner_thoughts_in_kwargs:
            # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
            # TODO(fix)
            inner_thoughts_desc = (
                INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
            )
            tools = add_inner_thoughts_to_functions(
                functions=tools,
                inner_thoughts_key=INNER_THOUGHTS_KWARG,
                inner_thoughts_description=inner_thoughts_desc,
                put_inner_thoughts_first=True,
            )

        use_developer_message = accepts_developer_role(llm_config.model)

        openai_message_list = [
            cast_message_to_subtype(
                m.to_openai_dict(
                    put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
                    use_developer_message=use_developer_message,
                )
            )
            for m in messages
        ]

        if llm_config.model:
            model = llm_config.model
        else:
            logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
            model = None

        # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
        # TODO(matt) move into LLMConfig
        # TODO: This vllm checking is very brittle and is a patch at most
        tool_choice = None
        if requires_auto_tool_choice(llm_config):
            tool_choice = "auto"  # TODO change to "required" once proxy supports it
        elif tools:
            # only set if tools is non-Null
            tool_choice = "required"

        if force_tool_call is not None:
            tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))

        data = ChatCompletionRequest(
            model=model,
            messages=openai_message_list,
            tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
            tool_choice=tool_choice,
            user=str(),
            max_completion_tokens=llm_config.max_tokens,
            temperature=llm_config.temperature if supports_temperature_param(model) else None,
        )

        # always set user id for openai requests
        if self.actor:
            data.user = self.actor.id

        if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
            if not self.actor:
                # override user id for inference.letta.com
                import uuid

                data.user = str(uuid.UUID(int=0))

            data.model = "memgpt-openai"

        if data.tools is not None and len(data.tools) > 0:
            # Convert to structured output style (which has 'strict' and no optionals)
            for tool in data.tools:
                if supports_structured_output(llm_config):
                    try:
                        structured_output_version = convert_to_structured_output(tool.function.model_dump())
                        tool.function = FunctionSchema(**structured_output_version)
                    except ValueError as e:
                        logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")

        return data.model_dump(exclude_unset=True)
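
    # Rough shape of the dict returned by build_request_data (illustrative only; actual keys
    # depend on ChatCompletionRequest and exclude_unset=True):
    #   {
    #       "model": "gpt-4o",
    #       "messages": [{"role": "system" | "developer", "content": ...}, ...],
    #       "tools": [{"type": "function", "function": {...}}],  # converted to 'strict' structured output where supported
    #       "tool_choice": "required",                           # or "auto", or a forced {"type": "function", ...}
    #       "user": "<actor id>",
    #       "max_completion_tokens": ...,
    #       "temperature": 0.7,                                  # omitted for o1/o3 reasoning models
    #   }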

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to OpenAI API and returns raw response dict.
        """
        client = OpenAI(**self._prepare_client_kwargs(llm_config))

        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
        """
        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))

        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[PydanticMessage],  # Included for consistency, maybe used later
        llm_config: LLMConfig,
    ) -> ChatCompletionResponse:
        """
        Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
        Handles potential extraction of inner thoughts if they were added via kwargs.
        """
        # OpenAI's response structure directly maps to ChatCompletionResponse
        # We just need to instantiate the Pydantic model for validation and type safety.
        chat_completion_response = ChatCompletionResponse(**response_data)

        # Unpack inner thoughts if they were embedded in function arguments
        if llm_config.put_inner_thoughts_in_kwargs:
            chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
                response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
            )

        # If we used a reasoning model, create a content part for the omitted reasoning
        if is_openai_reasoning_model(llm_config.model):
            chat_completion_response.choices[0].message.ommitted_reasoning_content = True

        return chat_completion_response

    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
        """
        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream
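
    # The returned AsyncStream is consumed with `async for chunk in response_stream: ...`;
    # with stream_options={"include_usage": True}, the final chunk reports token usage for
    # the whole request.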

    def handle_llm_error(self, e: Exception) -> Exception:
        """
        Maps OpenAI-specific errors to common LLMError types.
        """
        if isinstance(e, openai.APIConnectionError):
            logger.warning(f"[OpenAI] API connection error: {e}")
            return LLMConnectionError(
                message=f"Failed to connect to OpenAI: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None},
            )

        if isinstance(e, openai.RateLimitError):
            logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
            return LLMRateLimitError(
                message=f"Rate limited by OpenAI: {str(e)}",
                code=ErrorCode.RATE_LIMIT_EXCEEDED,
                details=e.body,  # Include body which often has rate limit details
            )

        if isinstance(e, openai.BadRequestError):
            logger.warning(f"[OpenAI] Bad request (400): {str(e)}")
            # BadRequestError can signify different issues (e.g., invalid args, context length)
            # Check message content if finer-grained errors are needed
            # Example: if "context_length_exceeded" in str(e): return LLMContextLengthExceededError(...)
            return LLMBadRequestError(
                message=f"Bad request to OpenAI: {str(e)}",
                code=ErrorCode.INVALID_ARGUMENT,  # Or more specific if detectable
                details=e.body,
            )

        if isinstance(e, openai.AuthenticationError):
            logger.error(f"[OpenAI] Authentication error (401): {str(e)}")  # More severe log level
            return LLMAuthenticationError(
                message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body
            )

        if isinstance(e, openai.PermissionDeniedError):
            logger.error(f"[OpenAI] Permission denied (403): {str(e)}")  # More severe log level
            return LLMPermissionDeniedError(
                message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body
            )

        if isinstance(e, openai.NotFoundError):
            logger.warning(f"[OpenAI] Resource not found (404): {str(e)}")
            # Could be invalid model name, etc.
            return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)

        if isinstance(e, openai.UnprocessableEntityError):
            logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
            return LLMUnprocessableEntityError(
                message=f"Invalid request content for OpenAI: {str(e)}",
                code=ErrorCode.INVALID_ARGUMENT,  # Usually validation errors
                details=e.body,
            )

        # General API error catch-all
        if isinstance(e, openai.APIStatusError):
            logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
            # Map based on status code potentially
            if e.status_code >= 500:
                error_cls = LLMServerError
                error_code = ErrorCode.INTERNAL_SERVER_ERROR
            else:
                # Treat other 4xx as bad requests if not caught above
                error_cls = LLMBadRequestError
                error_code = ErrorCode.INVALID_ARGUMENT

            return error_cls(
                message=f"OpenAI API error: {str(e)}",
                code=error_code,
                details={
                    "status_code": e.status_code,
                    "response": str(e.response),
                    "body": e.body,
                },
            )

        # Fallback for unexpected errors
        return super().handle_llm_error(e)
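

# Illustrative caller flow (a sketch, not part of this module; constructor arguments are
# whatever LLMClientBase expects and are assumed here):
#
#   client = OpenAIClient(...)  # wiring for self.actor etc. comes from LLMClientBase
#   request_data = client.build_request_data(messages, llm_config, tools=tools)
#   try:
#       raw = client.request(request_data, llm_config)
#       completion = client.convert_response_to_chat_completion(raw, messages, llm_config)
#   except Exception as e:
#       raise client.handle_llm_error(e)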