feat: add endpoint to test connection to llm provider (#2032)

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
jnjpng 2025-05-07 16:26:55 -07:00 committed by GitHub
parent 51d7872604
commit 1a5283177d
8 changed files with 113 additions and 4 deletions

View File

@@ -19,7 +19,7 @@ from anthropic.types.beta import (
    BetaToolUseBlock,
)
from letta.errors import BedrockError, BedrockPermissionError
from letta.errors import BedrockError, BedrockPermissionError, ErrorCode, LLMAuthenticationError, LLMError
from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime
from letta.llm_api.aws_bedrock import get_bedrock_client
from letta.llm_api.helpers import add_inner_thoughts_to_functions
@@ -119,6 +119,20 @@ DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
VALID_EVENT_TYPES = {"content_block_stop", "message_stop"}

def anthropic_check_valid_api_key(api_key: Union[str, None]) -> None:
    if api_key:
        anthropic_client = anthropic.Anthropic(api_key=api_key)
        try:
            # just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models
            anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}])
        except anthropic.AuthenticationError as e:
            raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED)
        except Exception as e:
            raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
    else:
        raise ValueError("No API key provided")

def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int:
    for model_dict in anthropic_get_model_list(url=url, api_key=api_key):
        if model_dict["name"] == model:
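For context, a minimal usage sketch of the new Anthropic helper (not part of the commit; the import path matches the one used in the provider schema further down, and the key value is a placeholder):

from letta.errors import LLMAuthenticationError, LLMError
from letta.llm_api.anthropic import anthropic_check_valid_api_key

try:
    # a bad or revoked key surfaces as LLMAuthenticationError; other failures as LLMError
    anthropic_check_valid_api_key("sk-ant-placeholder")
except LLMAuthenticationError as e:
    print(f"invalid Anthropic key: {e.message}")
except LLMError as e:
    print(f"non-auth provider error: {e}")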

View File

@@ -3,11 +3,14 @@ import uuid
from typing import List, Optional, Tuple
import requests
from google import genai
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig
from letta.constants import NON_USER_MSG_PREFIX
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
from letta.llm_api.helpers import make_post_request
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
@@ -443,6 +446,23 @@ def get_gemini_endpoint_and_headers(
    return url, headers

def google_ai_check_valid_api_key(api_key: str):
    client = genai.Client(api_key=api_key)
    # use the count token endpoint for a cheap model - as of 5/7/2025 this is slightly faster than fetching the list of models
    try:
        client.models.count_tokens(
            model=GOOGLE_MODEL_FOR_API_KEY_CHECK,
            contents="",
        )
    except genai.errors.ClientError as e:
        # google api returns 400 invalid argument for invalid api key
        if e.code == 400:
            raise LLMAuthenticationError(message=f"Failed to authenticate with Google AI: {e}", code=ErrorCode.UNAUTHENTICATED)
        raise e
    except Exception as e:
        raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)

def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
    from letta.utils import printd
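A similar hedged sketch for the Google AI helper (illustrative only, placeholder key); note that the Gemini API reports an invalid key as a 400 ClientError, which the helper maps onto LLMAuthenticationError:

from letta.errors import LLMAuthenticationError, LLMError
from letta.llm_api.google_ai_client import google_ai_check_valid_api_key

try:
    google_ai_check_valid_api_key("AIza-placeholder")  # placeholder key
except LLMAuthenticationError as e:
    print(f"invalid Google AI key: {e.message}")
except LLMError as e:
    print(f"non-auth provider error: {e}")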

View File

@@ -14,3 +14,5 @@ GOOGLE_MODEL_TO_CONTEXT_LENGTH = {
GOOGLE_MODEL_TO_OUTPUT_LENGTH = {"gemini-2.0-flash-001": 8192, "gemini-2.5-pro-exp-03-25": 65536}
GOOGLE_EMBEDING_MODEL_TO_DIM = {"text-embedding-005": 768, "text-multilingual-embedding-002": 768}
GOOGLE_MODEL_FOR_API_KEY_CHECK = "gemini-2.0-flash-lite"

View File

@@ -5,6 +5,7 @@ import requests
from openai import OpenAI
from letta.constants import LETTA_MODEL_ENDPOINT
from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
from letta.helpers.datetime_helpers import timestamp_to_datetime
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.llm_api.openai_client import accepts_developer_role, supports_parallel_tool_calling, supports_temperature_param
@@ -34,6 +35,21 @@ from letta.utils import get_tool_call_id, smart_urljoin
logger = get_logger(__name__)

def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None:
    if api_key:
        try:
            # just get model list to check if the api key is valid until we find a cheaper / quicker endpoint
            openai_get_model_list(url=base_url, api_key=api_key)
        except requests.HTTPError as e:
            if e.response.status_code == 401:
                raise LLMAuthenticationError(message=f"Failed to authenticate with OpenAI: {e}", code=ErrorCode.UNAUTHENTICATED)
            raise e
        except Exception as e:
            raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR)
    else:
        raise ValueError("No API key provided")

def openai_get_model_list(
    url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
) -> dict:
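And a sketch for the OpenAI helper (illustrative only); unlike the other two it also takes the base URL, since the same model-list probe works for OpenAI-compatible endpoints:

from letta.errors import LLMAuthenticationError
from letta.llm_api.openai import openai_check_valid_api_key

try:
    # placeholder key; a 401 from the model listing is mapped to LLMAuthenticationError
    openai_check_valid_api_key("https://api.openai.com/v1", "sk-placeholder")
except LLMAuthenticationError as e:
    print(f"invalid OpenAI key: {e.message}")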

View File

@@ -3,6 +3,7 @@ from enum import Enum
class ProviderType(str, Enum):
    anthropic = "anthropic"
    anthropic_bedrock = "bedrock"
    google_ai = "google_ai"
    google_vertex = "google_vertex"
    openai = "openai"

View File

@@ -2,7 +2,7 @@ import warnings
from datetime import datetime
from typing import List, Literal, Optional
from pydantic import Field, model_validator
from pydantic import BaseModel, Field, model_validator
from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
@@ -40,6 +40,10 @@ class Provider(ProviderBase):
        if not self.id:
            self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)

    def check_api_key(self):
        """Check if the API key is valid for the provider"""
        raise NotImplementedError

    def list_llm_models(self) -> List[LLMConfig]:
        return []
@@ -112,6 +116,11 @@ class ProviderUpdate(ProviderBase):
    api_key: str = Field(..., description="API key used for requests to the provider.")

class ProviderCheck(BaseModel):
    provider_type: ProviderType = Field(..., description="The type of the provider.")
    api_key: str = Field(..., description="API key used for requests to the provider.")

class LettaProvider(Provider):
    provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
@@ -148,6 +157,11 @@ class OpenAIProvider(Provider):
    api_key: str = Field(..., description="API key for the OpenAI API.")
    base_url: str = Field(..., description="Base URL for the OpenAI API.")

    def check_api_key(self):
        from letta.llm_api.openai import openai_check_valid_api_key

        openai_check_valid_api_key(self.base_url, self.api_key)

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.openai import openai_get_model_list
@@ -549,6 +563,11 @@ class AnthropicProvider(Provider):
    api_key: str = Field(..., description="API key for the Anthropic API.")
    base_url: str = "https://api.anthropic.com/v1"

    def check_api_key(self):
        from letta.llm_api.anthropic import anthropic_check_valid_api_key

        anthropic_check_valid_api_key(self.api_key)

    def list_llm_models(self) -> List[LLMConfig]:
        from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
@@ -951,6 +970,11 @@ class GoogleAIProvider(Provider):
    api_key: str = Field(..., description="API key for the Google AI API.")
    base_url: str = "https://generativelanguage.googleapis.com"

    def check_api_key(self):
        from letta.llm_api.google_ai_client import google_ai_check_valid_api_key

        google_ai_check_valid_api_key(self.api_key)

    def list_llm_models(self):
        from letta.llm_api.google_ai_client import google_ai_get_model_list
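With these overrides in place, callers do not need to pick a helper themselves: constructing the concrete provider schema and calling check_api_key() dispatches to the right backend. A hedged sketch, assuming the class constructors accept these fields as shown in the diff:

from letta.schemas.enums import ProviderCategory
from letta.schemas.providers import GoogleAIProvider

# field values mirror what check_provider_api_key builds below; the key is a placeholder
provider = GoogleAIProvider(name="google_ai", api_key="AIza-placeholder", provider_category=ProviderCategory.byok)
provider.check_api_key()  # raises LLMAuthenticationError on a bad key, returns None otherwise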

View File

@@ -3,9 +3,10 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse
from letta.errors import LLMAuthenticationError
from letta.orm.errors import NoResultFound
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
from letta.schemas.providers import Provider, ProviderCheck, ProviderCreate, ProviderUpdate
from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING:
@@ -67,6 +68,22 @@ def modify_provider(
    return server.provider_manager.update_provider(provider_id=provider_id, request=request, actor=actor)

@router.get("/check", response_model=None, operation_id="check_provider")
def check_provider(
    provider_type: ProviderType = Query(...),
    api_key: str = Header(..., alias="x-api-key"),
    server: "SyncServer" = Depends(get_letta_server),
):
    try:
        provider_check = ProviderCheck(provider_type=provider_type, api_key=api_key)
        server.provider_manager.check_provider_api_key(provider_check=provider_check)
        return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={provider_type.value}"})
    except LLMAuthenticationError as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"{e.message}")
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}")

@router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
def delete_provider(
    provider_id: str,
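A hedged client-side sketch of the new route (not part of the commit). It assumes the providers router is mounted at /v1/providers on a local Letta server, so the check is a GET to /v1/providers/check with the key in the x-api-key header and the provider type as a query parameter:

import requests

resp = requests.get(
    "http://localhost:8283/v1/providers/check",  # assumed local server URL and mount path
    params={"provider_type": "anthropic"},
    headers={"x-api-key": "sk-ant-placeholder"},  # placeholder key
)
# 200 -> {"message": "Valid api key for provider_type=anthropic"}
# 401 -> authentication failed, 500 -> any other provider error
print(resp.status_code, resp.json())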

View File

@@ -3,7 +3,7 @@ from typing import List, Optional, Union
from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderCreate, ProviderUpdate
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
@@ -99,3 +99,18 @@ class ProviderManager:
    def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
        providers = self.list_providers(name=provider_name, actor=actor)
        return providers[0].api_key if providers else None

    @enforce_types
    def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
        provider = PydanticProvider(
            name=provider_check.provider_type.value,
            provider_type=provider_check.provider_type,
            api_key=provider_check.api_key,
            provider_category=ProviderCategory.byok,
        ).cast_to_subtype()

        # TODO: add more string sanity checks here before we hit actual endpoints
        if not provider.api_key:
            raise ValueError("API key is required")

        provider.check_api_key()
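Finally, a sketch of the same check driven through the manager, mirroring what the route handler does (names follow the diff; it assumes ProviderManager lives in letta.services.provider_manager and can be constructed without arguments here):

from letta.schemas.enums import ProviderType
from letta.schemas.providers import ProviderCheck
from letta.services.provider_manager import ProviderManager

check = ProviderCheck(provider_type=ProviderType.anthropic, api_key="sk-ant-placeholder")
# cast_to_subtype() turns the generic Provider into AnthropicProvider,
# so this ends up calling anthropic_check_valid_api_key under the hood
ProviderManager().check_provider_api_key(provider_check=check)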