mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add endpoint to test connection to llm provider (#2032)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
parent
51d7872604
commit
1a5283177d
@ -19,7 +19,7 @@ from anthropic.types.beta import (
|
||||
BetaToolUseBlock,
|
||||
)
|
||||
|
||||
from letta.errors import BedrockError, BedrockPermissionError
|
||||
from letta.errors import BedrockError, BedrockPermissionError, ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime
|
||||
from letta.llm_api.aws_bedrock import get_bedrock_client
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions
|
||||
@ -119,6 +119,20 @@ DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||
VALID_EVENT_TYPES = {"content_block_stop", "message_stop"}
|
||||
|
||||
|
||||
def anthropic_check_valid_api_key(api_key: Union[str, None]) -> None:
    """Validate an Anthropic API key by issuing a minimal request.

    Args:
        api_key: The Anthropic API key to validate.

    Raises:
        ValueError: If no API key was supplied.
        LLMAuthenticationError: If Anthropic rejects the key.
        LLMError: For any other failure while contacting the API.
    """
    # Guard clause: fail fast on a missing key before building a client.
    if not api_key:
        raise ValueError("No API key provided")

    anthropic_client = anthropic.Anthropic(api_key=api_key)
    try:
        # just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models
        anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
    except anthropic.AuthenticationError as e:
        # Chain the cause so the original auth failure is preserved in tracebacks.
        raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED) from e
    except Exception as e:
        raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) from e
|
||||
|
||||
|
||||
def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
|
||||
for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
|
||||
if model_dict["name"] == model:
|
||||
|
@ -3,11 +3,14 @@ import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from google import genai
|
||||
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
@ -443,6 +446,23 @@ def get_gemini_endpoint_and_headers(
|
||||
return url, headers
|
||||
|
||||
|
||||
def google_ai_check_valid_api_key(api_key: str) -> None:
    """Validate a Google AI (Gemini) API key by issuing a minimal request.

    Args:
        api_key: The Google AI API key to validate.

    Raises:
        LLMAuthenticationError: If Google rejects the key (HTTP 400).
        LLMError: For any other failure while contacting the API.
    """
    client = genai.Client(api_key=api_key)
    # use the count token endpoint for a cheap model - as of 5/7/2025 this is slightly faster than fetching the list of models
    try:
        client.models.count_tokens(
            model=GOOGLE_MODEL_FOR_API_KEY_CHECK,
            contents="",
        )
    except genai.errors.ClientError as e:
        # google api returns 400 invalid argument for invalid api key
        if e.code == 400:
            raise LLMAuthenticationError(message=f"Failed to authenticate with Google AI: {e}", code=ErrorCode.UNAUTHENTICATED) from e
        # Bare re-raise preserves the original traceback (unlike `raise e`).
        raise
    except Exception as e:
        raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) from e
|
||||
|
||||
|
||||
def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
|
||||
from letta.utils import printd
|
||||
|
||||
|
@ -14,3 +14,5 @@ GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
|
||||
# Maximum output-token budget per Gemini model.
GOOGLE_MODEL_TO_OUTPUT_LENGTH = {
    "gemini-2.0-flash-001": 8192,
    "gemini-2.5-pro-exp-03-25": 65536,
}

# Embedding vector dimensionality per Google embedding model.
# NOTE(review): "EMBEDING" is misspelled, but renaming would break importers.
GOOGLE_EMBEDING_MODEL_TO_DIM = {
    "text-embedding-005": 768,
    "text-multilingual-embedding-002": 768,
}

# Cheap, fast model used only when validating API keys via count-tokens.
GOOGLE_MODEL_FOR_API_KEY_CHECK = "gemini-2.0-flash-lite"
|
||||
|
@ -5,6 +5,7 @@ import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT
|
||||
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
|
||||
from letta.helpers.datetime_helpers import timestamp_to_datetime
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
|
||||
from letta.llm_api.openai_client import accepts_developer_role, supports_parallel_tool_calling, supports_temperature_param
|
||||
@ -34,6 +35,21 @@ from letta.utils import get_tool_call_id, smart_urljoin
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None:
    """Validate an OpenAI API key by listing models against `base_url`.

    Args:
        base_url: Base URL of the OpenAI-compatible API.
        api_key: The API key to validate.

    Raises:
        ValueError: If no API key was supplied.
        LLMAuthenticationError: If the endpoint rejects the key (HTTP 401).
        LLMError: For any other failure while contacting the API.
    """
    # Guard clause: fail fast on a missing key.
    if not api_key:
        raise ValueError("No API key provided")

    try:
        # just get model list to check if the api key is valid until we find a cheaper / quicker endpoint
        openai_get_model_list(url=base_url, api_key=api_key)
    except requests.HTTPError as e:
        if e.response.status_code == 401:
            # Chain the cause so the original HTTP error is preserved in tracebacks.
            raise LLMAuthenticationError(message=f"Failed to authenticate with OpenAI: {e}", code=ErrorCode.UNAUTHENTICATED) from e
        # Bare re-raise preserves the original traceback (unlike `raise e`).
        raise
    except Exception as e:
        raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) from e
|
||||
|
||||
|
||||
def openai_get_model_list(
|
||||
url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
|
||||
) -> dict:
|
||||
|
@ -3,6 +3,7 @@ from enum import Enum
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
anthropic = "anthropic"
|
||||
anthropic_bedrock = "bedrock"
|
||||
google_ai = "google_ai"
|
||||
google_vertex = "google_vertex"
|
||||
openai = "openai"
|
||||
|
@ -2,7 +2,7 @@ import warnings
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
||||
@ -40,6 +40,10 @@ class Provider(ProviderBase):
|
||||
if not self.id:
|
||||
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
|
||||
|
||||
def check_api_key(self):
    """Validate this provider's API key against its backing service.

    The base class has no provider-specific validation; concrete
    provider subclasses must override this method.
    """
    raise NotImplementedError
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
    """Return the LLM configurations this provider exposes (none by default)."""
    return list()
|
||||
|
||||
@ -112,6 +116,11 @@ class ProviderUpdate(ProviderBase):
|
||||
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||
|
||||
|
||||
class ProviderCheck(BaseModel):
    """Request payload for validating a provider API key."""

    provider_type: ProviderType = Field(..., description="The type of the provider.")
    api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||
|
||||
|
||||
class LettaProvider(Provider):
|
||||
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
@ -148,6 +157,11 @@ class OpenAIProvider(Provider):
|
||||
api_key: str = Field(..., description="API key for the OpenAI API.")
|
||||
base_url: str = Field(..., description="Base URL for the OpenAI API.")
|
||||
|
||||
def check_api_key(self):
    """Verify the configured OpenAI API key against the live endpoint."""
    # Function-scoped import, matching the original — presumably to avoid
    # an import cycle at module load time; confirm.
    from letta.llm_api.openai import openai_check_valid_api_key as _validate

    _validate(self.base_url, self.api_key)
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list
|
||||
|
||||
@ -549,6 +563,11 @@ class AnthropicProvider(Provider):
|
||||
api_key: str = Field(..., description="API key for the Anthropic API.")
|
||||
base_url: str = "https://api.anthropic.com/v1"
|
||||
|
||||
def check_api_key(self):
    """Verify the configured Anthropic API key against the live endpoint."""
    # Function-scoped import, matching the original — presumably to avoid
    # an import cycle at module load time; confirm.
    from letta.llm_api.anthropic import anthropic_check_valid_api_key as _validate

    _validate(self.api_key)
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
|
||||
|
||||
@ -951,6 +970,11 @@ class GoogleAIProvider(Provider):
|
||||
api_key: str = Field(..., description="API key for the Google AI API.")
|
||||
base_url: str = "https://generativelanguage.googleapis.com"
|
||||
|
||||
def check_api_key(self):
    """Verify the configured Google AI API key against the live endpoint."""
    # Function-scoped import, matching the original — presumably to avoid
    # an import cycle at module load time; confirm.
    from letta.llm_api.google_ai_client import google_ai_check_valid_api_key as _validate

    _validate(self.api_key)
|
||||
|
||||
def list_llm_models(self):
|
||||
from letta.llm_api.google_ai_client import google_ai_get_model_list
|
||||
|
||||
|
@ -3,9 +3,10 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from letta.errors import LLMAuthenticationError
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
|
||||
from letta.schemas.providers import Provider, ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -67,6 +68,22 @@ def modify_provider(
|
||||
return server.provider_manager.update_provider(provider_id=provider_id, request=request, actor=actor)
|
||||
|
||||
|
||||
@router.get("/check", response_model=None, operation_id="check_provider")
def check_provider(
    provider_type: ProviderType = Query(...),
    api_key: str = Header(..., alias="x-api-key"),
    server: "SyncServer" = Depends(get_letta_server),
):
    """Check whether an API key is valid for the given provider type.

    Returns 200 with a confirmation message when the key validates,
    401 when the provider rejects the key, and 500 on any other failure.
    """
    try:
        provider_check = ProviderCheck(provider_type=provider_type, api_key=api_key)
        server.provider_manager.check_provider_api_key(provider_check=provider_check)
        return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={provider_type.value}"})
    except LLMAuthenticationError as e:
        # Chain the cause so the underlying auth failure survives in server logs.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"{e.message}") from e
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}") from e
|
||||
|
||||
|
||||
@router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
|
||||
def delete_provider(
|
||||
provider_id: str,
|
||||
|
@ -3,7 +3,7 @@ from typing import List, Optional, Union
|
||||
from letta.orm.provider import Provider as ProviderModel
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.providers import ProviderCreate, ProviderUpdate
|
||||
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@ -99,3 +99,18 @@ class ProviderManager:
|
||||
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
||||
providers = self.list_providers(name=provider_name, actor=actor)
|
||||
return providers[0].api_key if providers else None
|
||||
|
||||
@enforce_types
def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
    """Instantiate the matching provider subtype and validate its API key.

    Raises whatever the concrete provider's `check_api_key` raises, or
    ValueError when the key is empty.
    """
    candidate = PydanticProvider(
        name=provider_check.provider_type.value,
        provider_type=provider_check.provider_type,
        provider_category=ProviderCategory.byok,
        api_key=provider_check.api_key,
    ).cast_to_subtype()

    # TODO: add more string sanity checks here before we hit actual endpoints
    if not candidate.api_key:
        raise ValueError("API key is required")

    candidate.check_api_key()
|
||||
|
Loading…
Reference in New Issue
Block a user