import os
from typing import List, Optional

import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.constants import LETTA_MODEL_ENDPOINT
from letta.errors import (
    ErrorCode,
    LLMAuthenticationError,
    LLMBadRequestError,
    LLMConnectionError,
    LLMNotFoundError,
    LLMPermissionDeniedError,
    LLMRateLimitError,
    LLMServerError,
    LLMUnprocessableEntityError,
)
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
from letta.schemas.openai.chat_completion_request import FunctionSchema
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings

logger = get_logger(__name__)


def is_openai_reasoning_model(model: str) -> bool:
    """Utility function to check if the model is a 'reasoner'"""

    # NOTE: needs to be updated with new model releases
    return model.startswith("o1") or model.startswith("o3")
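
# Illustrative behavior of the prefix check above (model names are examples, not an
# exhaustive list of OpenAI reasoning models):
#
#   is_openai_reasoning_model("o1-mini")  # -> True
#   is_openai_reasoning_model("o3-mini")  # -> True
#   is_openai_reasoning_model("gpt-4o")   # -> False
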

def supports_temperature_param(model: str) -> bool:
    """Certain OpenAI models don't support configuring the temperature.

    Example error: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
    """
    if is_openai_reasoning_model(model):
        return False
    else:
        return True


def supports_parallel_tool_calling(model: str) -> bool:
    """Certain OpenAI models don't support parallel tool calls."""

    if is_openai_reasoning_model(model):
        return False
    else:
        return True
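
# How the temperature gate is applied downstream (a sketch; see build_request_data below,
# which sends temperature=None for reasoning models):
#
#   supports_temperature_param("o1-preview")  # -> False, temperature is omitted from the request
#   supports_temperature_param("gpt-4o")      # -> True, llm_config.temperature is used
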

class OpenAIClient(LLMClientBase):
    def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
        api_key = None
        if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
            from letta.services.provider_manager import ProviderManager

            api_key = ProviderManager().get_override_key(llm_config.provider_name)

        if not api_key:
            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
        # the openai python client requires an API key string, so fall back to a placeholder
        api_key = api_key or "DUMMY_API_KEY"
        kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}

        return kwargs
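
    # Key resolution order implemented above: provider-specific override from ProviderManager
    # (when provider_name is set and is not "openai") -> model_settings.openai_api_key ->
    # the OPENAI_API_KEY environment variable -> a "DUMMY_API_KEY" placeholder for keyless
    # OpenAI-compatible endpoints. Illustrative use (assumes a populated LLMConfig named
    # `llm_config`):
    #
    #   client = OpenAI(**self._prepare_client_kwargs(llm_config))
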

    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,  # Keep as dict for now as per base class
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """
        Constructs a request object in the expected data format for the OpenAI API.
        """
        if tools and llm_config.put_inner_thoughts_in_kwargs:
            # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
            # TODO(fix)
            inner_thoughts_desc = (
                INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
            )
            tools = add_inner_thoughts_to_functions(
                functions=tools,
                inner_thoughts_key=INNER_THOUGHTS_KWARG,
                inner_thoughts_description=inner_thoughts_desc,
                put_inner_thoughts_first=True,
            )

        use_developer_message = is_openai_reasoning_model(llm_config.model)

        openai_message_list = [
            cast_message_to_subtype(
                m.to_openai_dict(
                    put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
                    use_developer_message=use_developer_message,
                )
            )
            for m in messages
        ]

        if llm_config.model:
            model = llm_config.model
        else:
            logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
            model = None

        # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
        # TODO(matt) move into LLMConfig
        # TODO: this vLLM check is very brittle and is a patch at most
        tool_choice = None
        if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
            tool_choice = "auto"  # TODO change to "required" once proxy supports it
        elif tools:
            # only set if tools is non-null
            tool_choice = "required"

        if force_tool_call is not None:
            tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))

        data = ChatCompletionRequest(
            model=model,
            messages=openai_message_list,
            tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
            tool_choice=tool_choice,
            user=str(),
            max_completion_tokens=llm_config.max_tokens,
            temperature=llm_config.temperature if supports_temperature_param(model) else None,
        )

        # always set user id for openai requests
        if self.actor_id:
            data.user = self.actor_id

        if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
            if not self.actor_id:
                # override user id for inference.letta.com
                import uuid

                data.user = str(uuid.UUID(int=0))

            data.model = "memgpt-openai"

        if data.tools is not None and len(data.tools) > 0:
            # Convert to structured output style (which has 'strict' and no optionals)
            for tool in data.tools:
                try:
                    structured_output_version = convert_to_structured_output(tool.function.model_dump())
                    tool.function = FunctionSchema(**structured_output_version)
                except ValueError as e:
                    logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")

        return data.model_dump(exclude_unset=True)
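
    # Illustrative shape of the returned payload (values are placeholders, not real request
    # data; the exact keys depend on which fields were set above, since
    # model_dump(exclude_unset=True) drops unset ones):
    #
    #   {
    #       "model": "gpt-4o",
    #       "messages": [{"role": "system", "content": "..."}, ...],
    #       "tools": [{"type": "function", "function": {...}}],
    #       "tool_choice": "required",
    #       "user": "user-123",
    #       "max_completion_tokens": 4096,
    #       "temperature": 0.7,
    #   }
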

    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to OpenAI API and returns raw response dict.
        """
        client = OpenAI(**self._prepare_client_kwargs(llm_config))

        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
        """
        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()
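
    # A minimal sketch of how these pieces fit together (the real orchestration lives in
    # LLMClientBase / the calling agent loop; `client`, `messages`, `tool_schemas`, and
    # `llm_config` are assumed to already exist):
    #
    #   request_data = client.build_request_data(messages, llm_config, tools=tool_schemas)
    #   raw_response = client.request(request_data, llm_config)
    #   completion = client.convert_response_to_chat_completion(raw_response, messages, llm_config)
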

    def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[PydanticMessage],  # Included for consistency, maybe used later
        llm_config: LLMConfig,
    ) -> ChatCompletionResponse:
        """
        Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
        Handles potential extraction of inner thoughts if they were added via kwargs.
        """
        # OpenAI's response structure directly maps to ChatCompletionResponse
        # We just need to instantiate the Pydantic model for validation and type safety.
        chat_completion_response = ChatCompletionResponse(**response_data)

        # Unpack inner thoughts if they were embedded in function arguments
        if llm_config.put_inner_thoughts_in_kwargs:
            chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
                response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
            )

        # If we used a reasoning model, create a content part for the omitted reasoning
        if is_openai_reasoning_model(llm_config.model):
            chat_completion_response.choices[0].message.ommitted_reasoning_content = True

        return chat_completion_response
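
    # Sketch of the inner-thoughts unpacking above (argument contents are illustrative):
    # when put_inner_thoughts_in_kwargs is enabled, each tool call comes back with the
    # model's reasoning embedded under the key named by INNER_THOUGHTS_KWARG, e.g.
    #
    #   {"inner_thoughts": "User wants the weather...", "location": "SF"}
    #
    # and unpack_all_inner_thoughts_from_kwargs strips that value out of the tool call
    # arguments and surfaces it on the message, leaving only the real tool arguments behind.
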

    def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
        """
        Performs underlying streaming request to OpenAI and returns the stream iterator.
        """
        client = OpenAI(**self._prepare_client_kwargs(llm_config))
        response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
        return response_stream

    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
        """
        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True)
        return response_stream
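
    # Consuming the returned iterators (a sketch; chunk fields follow the OpenAI SDK's
    # ChatCompletionChunk schema, where delta fields may be None):
    #
    #   for chunk in client.stream(request_data, llm_config):
    #       delta = chunk.choices[0].delta
    #       if delta.content:
    #           print(delta.content, end="")
    #
    #   async for chunk in await client.stream_async(request_data, llm_config):
    #       ...
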

    def handle_llm_error(self, e: Exception) -> Exception:
        """
        Maps OpenAI-specific errors to common LLMError types.
        """
        if isinstance(e, openai.APIConnectionError):
            logger.warning(f"[OpenAI] API connection error: {e}")
            return LLMConnectionError(
                message=f"Failed to connect to OpenAI: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None},
            )

        if isinstance(e, openai.RateLimitError):
            logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
            return LLMRateLimitError(
                message=f"Rate limited by OpenAI: {str(e)}",
                code=ErrorCode.RATE_LIMIT_EXCEEDED,
                details=e.body,  # Include body which often has rate limit details
            )

        if isinstance(e, openai.BadRequestError):
            logger.warning(f"[OpenAI] Bad request (400): {str(e)}")
            # BadRequestError can signify different issues (e.g., invalid args, context length)
            # Check message content if finer-grained errors are needed
            # Example: if "context_length_exceeded" in str(e): return LLMContextLengthExceededError(...)
            return LLMBadRequestError(
                message=f"Bad request to OpenAI: {str(e)}",
                code=ErrorCode.INVALID_ARGUMENT,  # Or more specific if detectable
                details=e.body,
            )

        if isinstance(e, openai.AuthenticationError):
            logger.error(f"[OpenAI] Authentication error (401): {str(e)}")  # More severe log level
            return LLMAuthenticationError(
                message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body
            )

        if isinstance(e, openai.PermissionDeniedError):
            logger.error(f"[OpenAI] Permission denied (403): {str(e)}")  # More severe log level
            return LLMPermissionDeniedError(
                message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body
            )

        if isinstance(e, openai.NotFoundError):
            logger.warning(f"[OpenAI] Resource not found (404): {str(e)}")
            # Could be invalid model name, etc.
            return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)

        if isinstance(e, openai.UnprocessableEntityError):
            logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
            return LLMUnprocessableEntityError(
                message=f"Invalid request content for OpenAI: {str(e)}",
                code=ErrorCode.INVALID_ARGUMENT,  # Usually validation errors
                details=e.body,
            )

        # General API error catch-all
        if isinstance(e, openai.APIStatusError):
            logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
            # Map based on status code potentially
            if e.status_code >= 500:
                error_cls = LLMServerError
                error_code = ErrorCode.INTERNAL_SERVER_ERROR
            else:
                # Treat other 4xx as bad requests if not caught above
                error_cls = LLMBadRequestError
                error_code = ErrorCode.INVALID_ARGUMENT

            return error_cls(
                message=f"OpenAI API error: {str(e)}",
                code=error_code,
                details={
                    "status_code": e.status_code,
                    "response": str(e.response),
                    "body": e.body,
                },
            )

        # Fallback for unexpected errors
        return super().handle_llm_error(e)
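
    # Typical call pattern (a sketch; the actual try/except wrapping is expected to live
    # in LLMClientBase or the caller, not in this module):
    #
    #   try:
    #       raw_response = client.request(request_data, llm_config)
    #   except Exception as e:
    #       raise client.handle_llm_error(e) from e
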