chore: fix branch (#1865)

This commit is contained in:
Sarah Wooders 2024-10-10 14:07:45 -07:00 committed by GitHub
parent 60b51b4847
commit eb1aab1fc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 50 additions and 12 deletions

View File

@ -14,7 +14,9 @@ from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
from letta.log import get_logger
from letta.metadata import MetadataStore
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import OptionState
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory, Memory
from letta.server.server import logger as server_logger
@ -233,25 +235,46 @@ def run(
# choose from list of llm_configs
llm_configs = client.list_llm_configs()
llm_options = [llm_config.model for llm_config in llm_configs]
# TODO move into LLMConfig as a class method?
def prettify_llm_config(llm_config: LLMConfig) -> str:
return f"{llm_config.model}" + f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else ""
llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
# select model
if len(llm_options) == 0:
raise ValueError("No LLM models found. Please enable a provider.")
elif len(llm_options) == 1:
llm_model_name = llm_options[0]
else:
llm_model_name = questionary.select("Select LLM model:", choices=llm_options).ask()
llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model
llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0]
# choose form list of embedding configs
embedding_configs = client.list_embedding_configs()
embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
# TODO move into EmbeddingConfig as a class method?
def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
return (
f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})"
if embedding_config.embedding_endpoint
else ""
)
embedding_choices = [
questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
for embedding_config in embedding_configs
]
# select model
if len(embedding_options) == 0:
raise ValueError("No embedding models found. Please enable a provider.")
elif len(embedding_options) == 1:
embedding_model_name = embedding_options[0]
else:
embedding_model_name = questionary.select("Select embedding model:", choices=embedding_options).ask()
embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model
embedding_config = [
embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
][0]

View File

@ -41,7 +41,9 @@ from letta.utils import smart_urljoin
OPENAI_SSE_DONE = "[DONE]"
def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
def openai_get_model_list(
url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from letta.utils import printd
@ -60,7 +62,8 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional
printd(f"Sending request to {url}")
try:
response = requests.get(url, headers=headers)
# TODO add query param "tool" to be true
response = requests.get(url, headers=headers, params=extra_params)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response = {response}")

View File

@ -53,17 +53,28 @@ class LettaProvider(Provider):
class OpenAIProvider(Provider):
name: str = "openai"
api_key: str = Field(..., description="API key for the OpenAI API.")
base_url: str = "https://api.openai.com/v1"
base_url: str = Field(..., description="Base URL for the OpenAI API.")
def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list
response = openai_get_model_list(self.base_url, api_key=self.api_key)
model_options = [obj["id"] for obj in response["data"]]
# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
# See: https://openrouter.ai/docs/requests
extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)
assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
configs = []
for model_name in model_options:
context_window_size = self.get_model_context_window_size(model_name)
for model in response["data"]:
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
model_name = model["id"]
if "context_length" in model:
# Context length is returned in OpenRouter as "context_length"
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)
if not context_window_size:
continue

View File

@ -50,6 +50,7 @@ from letta.providers import (
LettaProvider,
OllamaProvider,
OpenAIProvider,
Provider,
VLLMProvider,
)
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
@ -261,9 +262,9 @@ class SyncServer(Server):
self.add_default_tools(module_name="base")
# collect providers (always has Letta as a default)
self._enabled_providers = [LettaProvider()]
self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key:
self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key))
self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
if model_settings.anthropic_api_key:
self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
if model_settings.ollama_base_url:

View File

@ -11,7 +11,7 @@ class ModelSettings(BaseSettings):
# openai
openai_api_key: Optional[str] = None
# TODO: provide overriding BASE_URL?
openai_api_base: Optional[str] = "https://api.openai.com/v1"
# groq
groq_api_key: Optional[str] = None