Mirror of https://github.com/cpacker/MemGPT.git
chore: fix branch (#1865)

commit eb1aab1fc7 (parent 60b51b4847)
@@ -14,7 +14,9 @@ from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR
 from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
 from letta.log import get_logger
 from letta.metadata import MetadataStore
+from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import OptionState
+from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ChatMemory, Memory
 from letta.server.server import logger as server_logger

@@ -233,25 +235,46 @@ def run(
         # choose from list of llm_configs
         llm_configs = client.list_llm_configs()
         llm_options = [llm_config.model for llm_config in llm_configs]

+        # TODO move into LLMConfig as a class method?
+        def prettify_llm_config(llm_config: LLMConfig) -> str:
+            return f"{llm_config.model}" + (f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else "")
+
+        llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
+
         # select model
         if len(llm_options) == 0:
             raise ValueError("No LLM models found. Please enable a provider.")
         elif len(llm_options) == 1:
             llm_model_name = llm_options[0]
         else:
-            llm_model_name = questionary.select("Select LLM model:", choices=llm_options).ask()
+            llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model
+        llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0]

         # choose from list of embedding configs
         embedding_configs = client.list_embedding_configs()
         embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]

+        # TODO move into EmbeddingConfig as a class method?
+        def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
+            return f"{embedding_config.embedding_model}" + (
+                f" ({embedding_config.embedding_endpoint})" if embedding_config.embedding_endpoint else ""
+            )
+
+        embedding_choices = [
+            questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
+            for embedding_config in embedding_configs
+        ]
+
         # select model
         if len(embedding_options) == 0:
             raise ValueError("No embedding models found. Please enable a provider.")
         elif len(embedding_options) == 1:
             embedding_model_name = embedding_options[0]
         else:
-            embedding_model_name = questionary.select("Select embedding model:", choices=embedding_options).ask()
+            embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model
+        embedding_config = [
+            embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
+        ][0]

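The hunk above swaps plain string choices for `questionary.Choice` objects, which separate the label the user sees (`title`) from the object the code receives (`value`): `.ask()` returns the selected choice's `value`, so the CLI gets the full config back instead of re-deriving it from a name. A minimal self-contained sketch of the same pattern, with a hypothetical `Model` dataclass standing in for letta's `LLMConfig`:

```python
# Sketch only: `Model` and `prettify` are stand-ins, not letta's actual types.
from dataclasses import dataclass
from typing import Optional

import questionary


@dataclass
class Model:
    model: str
    model_endpoint: Optional[str] = None


def prettify(m: Model) -> str:
    # Parenthesize the conditional: without the parentheses, Python's
    # precedence turns the whole return expression into a conditional and
    # the function yields "" whenever the endpoint is unset.
    return m.model + (f" ({m.model_endpoint})" if m.model_endpoint else "")


configs = [Model("gpt-4", "https://api.openai.com/v1"), Model("local-llama")]
choices = [questionary.Choice(title=prettify(m), value=m) for m in configs]

# .ask() returns the selected Choice's `value`, i.e. the Model object itself
selected = questionary.select("Select LLM model:", choices=choices).ask()
print(selected.model)
```

The same precedence point applies to the `prettify_llm_config` and `prettify_embed_config` helpers above, which is why their endpoint suffixes need to be wrapped in parentheses.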
@@ -41,7 +41,9 @@ from letta.utils import smart_urljoin
 OPENAI_SSE_DONE = "[DONE]"


-def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
+def openai_get_model_list(
+    url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
+) -> dict:
     """https://platform.openai.com/docs/api-reference/models/list"""
     from letta.utils import printd

@@ -60,7 +62,8 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional

     printd(f"Sending request to {url}")
     try:
-        response = requests.get(url, headers=headers)
+        # TODO add query param "tool" to be true
+        response = requests.get(url, headers=headers, params=extra_params)
         response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
         response = response.json()  # convert to dict from string
         printd(f"response = {response}")

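Threading `extra_params` straight into `requests.get(..., params=...)` is what makes the new argument optional end to end: `params=None` adds no query string, while a dict is URL-encoded onto the request. A small illustration of that behavior using `requests` alone (the URL is just an example endpoint):

```python
# Demonstrates how an optional params dict changes the request URL.
import requests

url = "https://openrouter.ai/api/v1/models"  # example endpoint

plain = requests.Request("GET", url, params=None).prepare()
tools = requests.Request("GET", url, params={"supported_parameters": "tools"}).prepare()

print(plain.url)  # https://openrouter.ai/api/v1/models
print(tools.url)  # https://openrouter.ai/api/v1/models?supported_parameters=tools
```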
@@ -53,17 +53,28 @@ class LettaProvider(Provider):
 class OpenAIProvider(Provider):
     name: str = "openai"
     api_key: str = Field(..., description="API key for the OpenAI API.")
-    base_url: str = "https://api.openai.com/v1"
+    base_url: str = Field(..., description="Base URL for the OpenAI API.")

     def list_llm_models(self) -> List[LLMConfig]:
         from letta.llm_api.openai import openai_get_model_list

-        response = openai_get_model_list(self.base_url, api_key=self.api_key)
-        model_options = [obj["id"] for obj in response["data"]]
+        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+        # See: https://openrouter.ai/docs/requests
+        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
+        response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)
+
+        assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"

         configs = []
-        for model_name in model_options:
-            context_window_size = self.get_model_context_window_size(model_name)
+        for model in response["data"]:
+            assert "id" in model, f"OpenAI model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "context_length" in model:
+                # Context length is returned in OpenRouter as "context_length"
+                context_window_size = model["context_length"]
+            else:
+                context_window_size = self.get_model_context_window_size(model_name)

             if not context_window_size:
                 continue

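Two OpenRouter-specific accommodations appear in this hunk: the request asks OpenRouter to return only models supporting the `tools` parameter (per the request docs linked in the comment), and each model's own `context_length` field is preferred over the local context-window lookup. A hedged sketch of that fallback logic, with a hypothetical lookup table standing in for `get_model_context_window_size`:

```python
# Sketch: choose a context window from a /models entry, preferring the
# OpenRouter-style "context_length" key and falling back to a local table.
# FALLBACK_WINDOWS is a hypothetical stand-in for letta's lookup.
from typing import Optional

FALLBACK_WINDOWS = {"gpt-4": 8192}


def context_window(model: dict) -> Optional[int]:
    if "context_length" in model:  # OpenRouter reports this per model
        return model["context_length"]
    return FALLBACK_WINDOWS.get(model["id"])  # None when the model is unknown


print(context_window({"id": "gpt-4"}))  # 8192, from the local fallback
print(context_window({"id": "qwen/qwen-2-72b", "context_length": 32768}))  # 32768
print(context_window({"id": "unknown"}))  # None
```

Returning `None` mirrors the `if not context_window_size: continue` guard in the diff: models whose window can't be determined are skipped rather than misconfigured.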
@@ -50,6 +50,7 @@ from letta.providers import (
     LettaProvider,
     OllamaProvider,
     OpenAIProvider,
+    Provider,
     VLLMProvider,
 )
 from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState

@@ -261,9 +262,9 @@ class SyncServer(Server):
         self.add_default_tools(module_name="base")

         # collect providers (always has Letta as a default)
-        self._enabled_providers = [LettaProvider()]
+        self._enabled_providers: List[Provider] = [LettaProvider()]
         if model_settings.openai_api_key:
-            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key))
+            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
         if model_settings.anthropic_api_key:
             self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
         if model_settings.ollama_base_url:

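The server builds its provider list conditionally: `LettaProvider` is always present, and each additional provider is appended only when its credential (and now, for OpenAI, its base URL) is configured. The `List[Provider]` annotation lets the heterogeneous subclasses share one list. A stripped-down sketch of the pattern, using stand-in classes rather than letta's actual hierarchy:

```python
# Sketch only: Provider/Settings are stand-ins for letta's real classes.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Provider:
    name: str


@dataclass
class Settings:
    openai_api_key: Optional[str] = None
    anthropic_api_key: Optional[str] = None


def collect_providers(settings: Settings) -> List[Provider]:
    providers: List[Provider] = [Provider("letta")]  # always-on default
    if settings.openai_api_key:
        providers.append(Provider("openai"))
    if settings.anthropic_api_key:
        providers.append(Provider("anthropic"))
    return providers


print([p.name for p in collect_providers(Settings(openai_api_key="sk-..."))])
# ['letta', 'openai']
```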
@@ -11,7 +11,7 @@ class ModelSettings(BaseSettings):

     # openai
     openai_api_key: Optional[str] = None
-    # TODO: provide overriding BASE_URL?
+    openai_api_base: Optional[str] = "https://api.openai.com/v1"

     # groq
     groq_api_key: Optional[str] = None

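Because `ModelSettings` extends pydantic's `BaseSettings`, the new `openai_api_base` field should be overridable from the environment: by default, pydantic maps each field to an environment variable of the same name, case-insensitively. Assuming no `env_prefix` is configured on the class (letta may set one), pointing the OpenAI provider at an OpenAI-compatible endpoint such as OpenRouter would look like this:

```python
# Sketch of pydantic-settings' default env-var override behavior.
# Assumes pydantic v2's pydantic-settings package and no env_prefix.
import os
from typing import Optional

from pydantic_settings import BaseSettings


class ModelSettings(BaseSettings):
    openai_api_key: Optional[str] = None
    openai_api_base: Optional[str] = "https://api.openai.com/v1"


os.environ["OPENAI_API_BASE"] = "https://openrouter.ai/api/v1"
print(ModelSettings().openai_api_base)  # https://openrouter.ai/api/v1
```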