import os
from typing import List, Optional

import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.errors import (
    ErrorCode,
    LLMAuthenticationError,
    LLMBadRequestError,
    LLMConnectionError,
    LLMNotFoundError,
    LLMPermissionDeniedError,
    LLMRateLimitError,
    LLMServerError,
    LLMUnprocessableEntityError,
)
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
from letta.schemas.openai.chat_completion_request import FunctionSchema
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings

logger = get_logger(__name__)


class OpenAIClient(LLMClientBase):
    def _prepare_client_kwargs(self) -> dict:
        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
        # supposedly the openai python client requires a dummy API key
        api_key = api_key or "DUMMY_API_KEY"
        kwargs = {"api_key": api_key, "base_url": self.llm_config.model_endpoint}
        return kwargs

    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,  # Keep as dict for now as per base class
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """
        Constructs a request object in the expected data format for the OpenAI API.
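
        Tool schemas are wrapped as OpenAI function-calling tools, inner thoughts are
        optionally injected into tool arguments, tool choice is forced for reliability,
        and tool functions are converted to structured-output ("strict") form.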
""" if tools and llm_config.put_inner_thoughts_in_kwargs: # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first # TODO(fix) inner_thoughts_desc = ( INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION ) tools = add_inner_thoughts_to_functions( functions=tools, inner_thoughts_key=INNER_THOUGHTS_KWARG, inner_thoughts_description=inner_thoughts_desc, put_inner_thoughts_first=True, ) use_developer_message = llm_config.model.startswith("o1") or llm_config.model.startswith("o3") # o-series models openai_message_list = [ cast_message_to_subtype( m.to_openai_dict( put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, use_developer_message=use_developer_message, ) ) for m in messages ] if llm_config.model: model = llm_config.model else: logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}") model = None # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # TODO(matt) move into LLMConfig # TODO: This vllm checking is very brittle and is a patch at most tool_choice = None if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in self.llm_config.handle): tool_choice = "auto" # TODO change to "required" once proxy supports it elif tools: # only set if tools is non-Null tool_choice = "required" if force_tool_call is not None: tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call)) data = ChatCompletionRequest( model=model, messages=openai_message_list, tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None, tool_choice=tool_choice, user=str(), max_completion_tokens=llm_config.max_tokens, temperature=llm_config.temperature, ) if "inference.memgpt.ai" in llm_config.model_endpoint: # override user id for inference.memgpt.ai import uuid data.user = str(uuid.UUID(int=0)) data.model = "memgpt-openai" if data.tools is not None and len(data.tools) > 0: # Convert to structured output style (which has 'strict' and no optionals) for tool in data.tools: try: structured_output_version = convert_to_structured_output(tool.function.model_dump()) tool.function = FunctionSchema(**structured_output_version) except ValueError as e: logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}") return data.model_dump(exclude_unset=True) def request(self, request_data: dict) -> dict: """ Performs underlying synchronous request to OpenAI API and returns raw response dict. """ client = OpenAI(**self._prepare_client_kwargs()) response: ChatCompletion = client.chat.completions.create(**request_data) return response.model_dump() async def request_async(self, request_data: dict) -> dict: """ Performs underlying asynchronous request to OpenAI API and returns raw response dict. """ client = AsyncOpenAI(**self._prepare_client_kwargs()) response: ChatCompletion = await client.chat.completions.create(**request_data) return response.model_dump() def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[PydanticMessage], # Included for consistency, maybe used later ) -> ChatCompletionResponse: """ Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model. Handles potential extraction of inner thoughts if they were added via kwargs. 
""" # OpenAI's response structure directly maps to ChatCompletionResponse # We just need to instantiate the Pydantic model for validation and type safety. chat_completion_response = ChatCompletionResponse(**response_data) # Unpack inner thoughts if they were embedded in function arguments if self.llm_config.put_inner_thoughts_in_kwargs: chat_completion_response = unpack_all_inner_thoughts_from_kwargs( response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG ) return chat_completion_response def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]: """ Performs underlying streaming request to OpenAI and returns the stream iterator. """ client = OpenAI(**self._prepare_client_kwargs()) response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True) return response_stream async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]: """ Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator. """ client = AsyncOpenAI(**self._prepare_client_kwargs()) response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True) return response_stream def handle_llm_error(self, e: Exception) -> Exception: """ Maps OpenAI-specific errors to common LLMError types. """ if isinstance(e, openai.APIConnectionError): logger.warning(f"[OpenAI] API connection error: {e}") return LLMConnectionError( message=f"Failed to connect to OpenAI: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"cause": str(e.__cause__) if e.__cause__ else None}, ) if isinstance(e, openai.RateLimitError): logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}") return LLMRateLimitError( message=f"Rate limited by OpenAI: {str(e)}", code=ErrorCode.RATE_LIMIT_EXCEEDED, details=e.body, # Include body which often has rate limit details ) if isinstance(e, openai.BadRequestError): logger.warning(f"[OpenAI] Bad request (400): {str(e)}") # BadRequestError can signify different issues (e.g., invalid args, context length) # Check message content if finer-grained errors are needed # Example: if "context_length_exceeded" in str(e): return LLMContextLengthExceededError(...) return LLMBadRequestError( message=f"Bad request to OpenAI: {str(e)}", code=ErrorCode.INVALID_ARGUMENT, # Or more specific if detectable details=e.body, ) if isinstance(e, openai.AuthenticationError): logger.error(f"[OpenAI] Authentication error (401): {str(e)}") # More severe log level return LLMAuthenticationError( message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body ) if isinstance(e, openai.PermissionDeniedError): logger.error(f"[OpenAI] Permission denied (403): {str(e)}") # More severe log level return LLMPermissionDeniedError( message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body ) if isinstance(e, openai.NotFoundError): logger.warning(f"[OpenAI] Resource not found (404): {str(e)}") # Could be invalid model name, etc. 
            return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)

        if isinstance(e, openai.UnprocessableEntityError):
            logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
            return LLMUnprocessableEntityError(
                message=f"Invalid request content for OpenAI: {str(e)}",
                code=ErrorCode.INVALID_ARGUMENT,  # Usually validation errors
                details=e.body,
            )

        # General API error catch-all
        if isinstance(e, openai.APIStatusError):
            logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
            # Map based on status code potentially
            if e.status_code >= 500:
                error_cls = LLMServerError
                error_code = ErrorCode.INTERNAL_SERVER_ERROR
            else:
                # Treat other 4xx as bad requests if not caught above
                error_cls = LLMBadRequestError
                error_code = ErrorCode.INVALID_ARGUMENT

            return error_cls(
                message=f"OpenAI API error: {str(e)}",
                code=error_code,
                details={
                    "status_code": e.status_code,
                    "response": str(e.response),
                    "body": e.body,
                },
            )

        # Fallback for unexpected errors
        return super().handle_llm_error(e)
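

# ---------------------------------------------------------------------------
# Illustrative usage sketch (hedged; not part of the original module).
# A minimal end-to-end flow, assuming the LLMClientBase constructor accepts an
# `llm_config` keyword and that LLMConfig exposes the fields shown below; both
# are assumptions for illustration, not confirmed signatures. `messages` stands
# in for a list of letta `Message` objects the caller already has. The method
# calls themselves (build_request_data, request, convert_response_to_chat_completion,
# handle_llm_error) match the signatures defined in this module.
#
#   config = LLMConfig(
#       model="gpt-4o-mini",
#       model_endpoint_type="openai",
#       model_endpoint="https://api.openai.com/v1",
#       context_window=128_000,
#   )
#   client = OpenAIClient(llm_config=config)
#   request_data = client.build_request_data(messages=messages, llm_config=config, tools=None)
#   try:
#       raw = client.request(request_data)  # or request_async / stream / stream_async
#       completion = client.convert_response_to_chat_completion(raw, input_messages=messages)
#   except Exception as e:
#       raise client.handle_llm_error(e)
# ---------------------------------------------------------------------------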