mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add listing llm models and embedding models for Azure endpoint (#1846)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
parent
6d1b22ff58
commit
f61ac27800
@ -2,6 +2,5 @@
|
|||||||
"context_window": 128000,
|
"context_window": 128000,
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"model_endpoint_type": "azure",
|
"model_endpoint_type": "azure",
|
||||||
"api_version": "2023-03-15-preview",
|
|
||||||
"model_wrapper": null
|
"model_wrapper": null
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from typing import Union
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
@ -7,70 +5,58 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
|||||||
from letta.schemas.openai.chat_completions import ChatCompletionRequest
|
from letta.schemas.openai.chat_completions import ChatCompletionRequest
|
||||||
from letta.schemas.openai.embedding_response import EmbeddingResponse
|
from letta.schemas.openai.embedding_response import EmbeddingResponse
|
||||||
from letta.settings import ModelSettings
|
from letta.settings import ModelSettings
|
||||||
from letta.utils import smart_urljoin
|
|
||||||
|
|
||||||
MODEL_TO_AZURE_ENGINE = {
|
|
||||||
"gpt-4-1106-preview": "gpt-4",
|
|
||||||
"gpt-4": "gpt-4",
|
|
||||||
"gpt-4-32k": "gpt-4-32k",
|
|
||||||
"gpt-3.5": "gpt-35-turbo",
|
|
||||||
"gpt-3.5-turbo": "gpt-35-turbo",
|
|
||||||
"gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
|
|
||||||
"gpt-4o-mini": "gpt-4o-mini",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_azure_endpoint(llm_config: LLMConfig, model_settings: ModelSettings):
|
def get_azure_chat_completions_endpoint(base_url: str, model: str, api_version: str):
|
||||||
assert llm_config.api_version, "Missing model version! This field must be provided in the LLM config for Azure."
|
return f"{base_url}/openai/deployments/{model}/chat/completions?api-version={api_version}"
|
||||||
assert llm_config.model in MODEL_TO_AZURE_ENGINE, f"{llm_config.model} not in supported models: {list(MODEL_TO_AZURE_ENGINE.keys())}"
|
|
||||||
|
|
||||||
model = MODEL_TO_AZURE_ENGINE[llm_config.model]
|
|
||||||
return f"{model_settings.azure_base_url}/openai/deployments/{model}/chat/completions?api-version={llm_config.api_version}"
|
|
||||||
|
|
||||||
|
|
||||||
def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
|
def get_azure_embeddings_endpoint(base_url: str, model: str, api_version: str):
|
||||||
|
return f"{base_url}/openai/deployments/{model}/embeddings?api-version={api_version}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_azure_model_list_endpoint(base_url: str, api_version: str):
|
||||||
|
return f"{base_url}/openai/models?api-version={api_version}"
|
||||||
|
|
||||||
|
|
||||||
|
def azure_openai_get_model_list(base_url: str, api_key: str, api_version: str) -> list:
|
||||||
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
|
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
|
||||||
from letta.utils import printd
|
|
||||||
|
|
||||||
# https://xxx.openai.azure.com/openai/models?api-version=xxx
|
# https://xxx.openai.azure.com/openai/models?api-version=xxx
|
||||||
url = smart_urljoin(url, "openai")
|
|
||||||
url = smart_urljoin(url, f"models?api-version={api_version}")
|
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
headers["api-key"] = f"{api_key}"
|
headers["api-key"] = f"{api_key}"
|
||||||
|
|
||||||
printd(f"Sending request to {url}")
|
url = get_azure_model_list_endpoint(base_url, api_version)
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, headers=headers)
|
response = requests.get(url, headers=headers)
|
||||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
response.raise_for_status()
|
||||||
response = response.json() # convert to dict from string
|
except requests.RequestException as e:
|
||||||
printd(f"response = {response}")
|
raise RuntimeError(f"Failed to retrieve model list: {e}")
|
||||||
return response
|
|
||||||
except requests.exceptions.HTTPError as http_err:
|
return response.json().get("data", [])
|
||||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
|
||||||
try:
|
|
||||||
response = response.json()
|
def azure_openai_get_chat_completion_model_list(base_url: str, api_key: str, api_version: str) -> list:
|
||||||
except:
|
model_list = azure_openai_get_model_list(base_url, api_key, api_version)
|
||||||
pass
|
# Extract models that support text generation
|
||||||
printd(f"Got HTTPError, exception={http_err}, response={response}")
|
model_options = [m for m in model_list if m.get("capabilities").get("chat_completion") == True]
|
||||||
raise http_err
|
return model_options
|
||||||
except requests.exceptions.RequestException as req_err:
|
|
||||||
# Handle other requests-related errors (e.g., connection error)
|
|
||||||
try:
|
def azure_openai_get_embeddings_model_list(base_url: str, api_key: str, api_version: str, require_embedding_in_name: bool = True) -> list:
|
||||||
response = response.json()
|
def valid_embedding_model(m: dict):
|
||||||
except:
|
valid_name = True
|
||||||
pass
|
if require_embedding_in_name:
|
||||||
printd(f"Got RequestException, exception={req_err}, response={response}")
|
valid_name = "embedding" in m["id"]
|
||||||
raise req_err
|
|
||||||
except Exception as e:
|
return m.get("capabilities").get("embeddings") == True and valid_name
|
||||||
# Handle other potential errors
|
|
||||||
try:
|
model_list = azure_openai_get_model_list(base_url, api_key, api_version)
|
||||||
response = response.json()
|
# Extract models that support embeddings
|
||||||
except:
|
|
||||||
pass
|
model_options = [m for m in model_list if valid_embedding_model(m)]
|
||||||
printd(f"Got unknown Exception, exception={e}, response={response}")
|
return model_options
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
def azure_openai_chat_completions_request(
|
def azure_openai_chat_completions_request(
|
||||||
@ -93,7 +79,7 @@ def azure_openai_chat_completions_request(
|
|||||||
data.pop("tools")
|
data.pop("tools")
|
||||||
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
|
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
|
||||||
|
|
||||||
model_endpoint = get_azure_endpoint(llm_config, model_settings)
|
model_endpoint = get_azure_chat_completions_endpoint(model_settings.azure_base_url, llm_config.model, model_settings.api_version)
|
||||||
printd(f"Sending request to {model_endpoint}")
|
printd(f"Sending request to {model_endpoint}")
|
||||||
try:
|
try:
|
||||||
response = requests.post(model_endpoint, headers=headers, json=data)
|
response = requests.post(model_endpoint, headers=headers, json=data)
|
||||||
|
10
letta/llm_api/azure_openai_constants.py
Normal file
10
letta/llm_api/azure_openai_constants.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
AZURE_MODEL_TO_CONTEXT_LENGTH = {
|
||||||
|
"babbage-002": 16384,
|
||||||
|
"davinci-002": 16384,
|
||||||
|
"gpt-35-turbo-0613": 4096,
|
||||||
|
"gpt-35-turbo-1106": 16385,
|
||||||
|
"gpt-35-turbo-0125": 16385,
|
||||||
|
"gpt-4-0613": 8192,
|
||||||
|
"gpt-4o-mini-2024-07-18": 128000,
|
||||||
|
"gpt-4o-2024-08-06": 128000,
|
||||||
|
}
|
@ -189,6 +189,9 @@ def create(
|
|||||||
if model_settings.azure_base_url is None:
|
if model_settings.azure_base_url is None:
|
||||||
raise ValueError(f"Azure base url is missing. Did you set AZURE_BASE_URL in your env?")
|
raise ValueError(f"Azure base url is missing. Did you set AZURE_BASE_URL in your env?")
|
||||||
|
|
||||||
|
if model_settings.azure_api_version is None:
|
||||||
|
raise ValueError(f"Azure API version is missing. Did you set AZURE_API_VERSION in your env?")
|
||||||
|
|
||||||
# Set the llm config model_endpoint from model_settings
|
# Set the llm config model_endpoint from model_settings
|
||||||
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
|
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
|
||||||
llm_config.model_endpoint = model_settings.azure_base_url
|
llm_config.model_endpoint = model_settings.azure_base_url
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from letta.constants import LLM_MAX_TOKENS
|
from letta.constants import LLM_MAX_TOKENS
|
||||||
|
from letta.llm_api.azure_openai import (
|
||||||
|
get_azure_chat_completions_endpoint,
|
||||||
|
get_azure_embeddings_endpoint,
|
||||||
|
)
|
||||||
|
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
|
|
||||||
@ -274,10 +279,64 @@ class GoogleAIProvider(Provider):
|
|||||||
|
|
||||||
class AzureProvider(Provider):
|
class AzureProvider(Provider):
|
||||||
name: str = "azure"
|
name: str = "azure"
|
||||||
|
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
|
||||||
base_url: str = Field(
|
base_url: str = Field(
|
||||||
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
|
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
|
||||||
)
|
)
|
||||||
api_key: str = Field(..., description="API key for the Azure API.")
|
api_key: str = Field(..., description="API key for the Azure API.")
|
||||||
|
api_version: str = Field(latest_api_version, description="API version for the Azure API")
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def set_default_api_version(cls, values):
|
||||||
|
"""
|
||||||
|
This ensures that api_version is always set to the default if None is passed in.
|
||||||
|
"""
|
||||||
|
if values.get("api_version") is None:
|
||||||
|
values["api_version"] = cls.model_fields["latest_api_version"].default
|
||||||
|
return values
|
||||||
|
|
||||||
|
def list_llm_models(self) -> List[LLMConfig]:
|
||||||
|
from letta.llm_api.azure_openai import (
|
||||||
|
azure_openai_get_chat_completion_model_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_options = azure_openai_get_chat_completion_model_list(self.base_url, api_key=self.api_key, api_version=self.api_version)
|
||||||
|
configs = []
|
||||||
|
for model_option in model_options:
|
||||||
|
model_name = model_option["id"]
|
||||||
|
context_window_size = self.get_model_context_window(model_name)
|
||||||
|
model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
|
||||||
|
configs.append(
|
||||||
|
LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||||
|
from letta.llm_api.azure_openai import azure_openai_get_embeddings_model_list
|
||||||
|
|
||||||
|
model_options = azure_openai_get_embeddings_model_list(
|
||||||
|
self.base_url, api_key=self.api_key, api_version=self.api_version, require_embedding_in_name=True
|
||||||
|
)
|
||||||
|
configs = []
|
||||||
|
for model_option in model_options:
|
||||||
|
model_name = model_option["id"]
|
||||||
|
model_endpoint = get_azure_embeddings_endpoint(self.base_url, model_name, self.api_version)
|
||||||
|
configs.append(
|
||||||
|
EmbeddingConfig(
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_endpoint_type="azure",
|
||||||
|
embedding_endpoint=model_endpoint,
|
||||||
|
embedding_dim=768,
|
||||||
|
embedding_chunk_size=300, # NOTE: max is 2048
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
def get_model_context_window(self, model_name: str):
|
||||||
|
"""
|
||||||
|
This is hardcoded for now, since there are no API endpoints to retrieve metadata for a model.
|
||||||
|
"""
|
||||||
|
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
|
||||||
|
|
||||||
|
|
||||||
class VLLMProvider(OpenAIProvider):
|
class VLLMProvider(OpenAIProvider):
|
||||||
|
@ -35,9 +35,6 @@ class LLMConfig(BaseModel):
|
|||||||
"hugging-face",
|
"hugging-face",
|
||||||
] = Field(..., description="The endpoint type for the model.")
|
] = Field(..., description="The endpoint type for the model.")
|
||||||
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
||||||
api_version: Optional[str] = Field(
|
|
||||||
None, description="The version for the model API. Used by the Azure provider backend, e.g. 2023-03-15-preview."
|
|
||||||
)
|
|
||||||
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
|
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
|
||||||
context_window: int = Field(..., description="The context window size for the model.")
|
context_window: int = Field(..., description="The context window size for the model.")
|
||||||
|
|
||||||
|
@ -272,7 +272,13 @@ class SyncServer(Server):
|
|||||||
if model_settings.gemini_api_key:
|
if model_settings.gemini_api_key:
|
||||||
self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
|
self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
|
||||||
if model_settings.azure_api_key and model_settings.azure_base_url:
|
if model_settings.azure_api_key and model_settings.azure_base_url:
|
||||||
self._enabled_providers.append(AzureProvider(api_key=model_settings.azure_api_key, base_url=model_settings.azure_base_url))
|
self._enabled_providers.append(
|
||||||
|
AzureProvider(
|
||||||
|
api_key=model_settings.azure_api_key,
|
||||||
|
base_url=model_settings.azure_base_url,
|
||||||
|
api_version=model_settings.azure_api_version,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def save_agents(self):
|
def save_agents(self):
|
||||||
"""Saves all the agents that are in the in-memory object store"""
|
"""Saves all the agents that are in the in-memory object store"""
|
||||||
|
@ -25,6 +25,7 @@ class ModelSettings(BaseSettings):
|
|||||||
# azure
|
# azure
|
||||||
azure_api_key: Optional[str] = None
|
azure_api_key: Optional[str] = None
|
||||||
azure_base_url: Optional[str] = None
|
azure_base_url: Optional[str] = None
|
||||||
|
azure_api_version: Optional[str] = None
|
||||||
|
|
||||||
# google ai
|
# google ai
|
||||||
gemini_api_key: Optional[str] = None
|
gemini_api_key: Optional[str] = None
|
||||||
|
@ -2,6 +2,5 @@
|
|||||||
"context_window": 128000,
|
"context_window": 128000,
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"model_endpoint_type": "azure",
|
"model_endpoint_type": "azure",
|
||||||
"api_version": "2023-03-15-preview",
|
|
||||||
"model_wrapper": null
|
"model_wrapper": null
|
||||||
}
|
}
|
||||||
|
@ -29,6 +29,14 @@ def test_anthropic():
|
|||||||
# print(models)
|
# print(models)
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Add this test
|
||||||
|
# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
|
||||||
|
def test_azure():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_ollama():
|
def test_ollama():
|
||||||
provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
|
provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
|
||||||
models = provider.list_llm_models()
|
models = provider.list_llm_models()
|
||||||
|
Loading…
Reference in New Issue
Block a user