Mirror of https://github.com/cpacker/MemGPT.git, synced 2025-06-03 04:30:22 +00:00

commit eb1aab1fc7 (parent 60b51b4847)
chore: fix branch (#1865)

The diff below wires an overridable OpenAI-compatible base URL through the model settings, the server's provider list, and the model-listing helper (with OpenRouter-aware query params and context-window handling), and upgrades the CLI model pickers to return full config objects instead of bare model-name strings.
@@ -14,7 +14,9 @@ from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR
 from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
 from letta.log import get_logger
 from letta.metadata import MetadataStore
+from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import OptionState
+from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ChatMemory, Memory
 from letta.server.server import logger as server_logger

@@ -233,25 +235,46 @@ def run(
     # choose from list of llm_configs
     llm_configs = client.list_llm_configs()
     llm_options = [llm_config.model for llm_config in llm_configs]
+
+    # TODO move into LLMConfig as a class method?
+    def prettify_llm_config(llm_config: LLMConfig) -> str:
+        return f"{llm_config.model}" + f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else ""
+
+    llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
+
     # select model
     if len(llm_options) == 0:
         raise ValueError("No LLM models found. Please enable a provider.")
     elif len(llm_options) == 1:
         llm_model_name = llm_options[0]
     else:
-        llm_model_name = questionary.select("Select LLM model:", choices=llm_options).ask()
+        llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model
     llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0]

     # choose form list of embedding configs
     embedding_configs = client.list_embedding_configs()
     embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
+
+    # TODO move into EmbeddingConfig as a class method?
+    def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
+        return (
+            f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})"
+            if embedding_config.embedding_endpoint
+            else ""
+        )
+
+    embedding_choices = [
+        questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
+        for embedding_config in embedding_configs
+    ]
+
     # select model
     if len(embedding_options) == 0:
         raise ValueError("No embedding models found. Please enable a provider.")
     elif len(embedding_options) == 1:
         embedding_model_name = embedding_options[0]
     else:
-        embedding_model_name = questionary.select("Select embedding model:", choices=embedding_options).ask()
+        embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model
     embedding_config = [
         embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
     ][0]
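The core of this CLI change: the plain string list handed to questionary is replaced by questionary.Choice objects, whose title is the label the user sees and whose value is what .ask() returns, so the picker hands back the full config object rather than a bare model name. Note that in the committed helpers the conditional binds the whole concatenation (A + B if cond else ""), so a config without an endpoint renders as an empty title; the minimal sketch below falls back to the bare model name instead. The LLMConfig stand-in here is hypothetical, not the letta schema:

from dataclasses import dataclass
from typing import List, Optional

import questionary


@dataclass
class LLMConfig:  # hypothetical stand-in for letta.schemas.llm_config.LLMConfig
    model: str
    model_endpoint: Optional[str] = None


def prettify_llm_config(cfg: LLMConfig) -> str:
    # Label shown in the picker: "model (endpoint)" when an endpoint is set.
    return f"{cfg.model} ({cfg.model_endpoint})" if cfg.model_endpoint else cfg.model


configs: List[LLMConfig] = [
    LLMConfig("gpt-4", "https://api.openai.com/v1"),
    LLMConfig("local-llama", "http://localhost:11434"),
]

# title is what the user sees; value is what .ask() returns on selection.
choices = [questionary.Choice(title=prettify_llm_config(c), value=c) for c in configs]
selected: LLMConfig = questionary.select("Select LLM model:", choices=choices).ask()
print(selected.model, selected.model_endpoint)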
@@ -41,7 +41,9 @@ from letta.utils import smart_urljoin
 OPENAI_SSE_DONE = "[DONE]"


-def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
+def openai_get_model_list(
+    url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
+) -> dict:
     """https://platform.openai.com/docs/api-reference/models/list"""
     from letta.utils import printd

@@ -60,7 +62,8 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional

     printd(f"Sending request to {url}")
     try:
-        response = requests.get(url, headers=headers)
+        # TODO add query param "tool" to be true
+        response = requests.get(url, headers=headers, params=extra_params)
         response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
         response = response.json()  # convert to dict from string
         printd(f"response = {response}")
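Threading extra_params through to requests.get is backward compatible: requests ignores params=None, and a dict is URL-encoded into the query string. A quick way to see the encoding without a network call (the URL is illustrative):

import requests

# A dict passed as params is URL-encoded; None leaves the URL untouched.
req = requests.Request(
    "GET",
    "https://openrouter.ai/api/v1/models",
    params={"supported_parameters": "tools"},
).prepare()
print(req.url)  # -> https://openrouter.ai/api/v1/models?supported_parameters=tools

req_none = requests.Request("GET", "https://openrouter.ai/api/v1/models", params=None).prepare()
print(req_none.url)  # -> https://openrouter.ai/api/v1/models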
@@ -53,16 +53,27 @@ class LettaProvider(Provider):
 class OpenAIProvider(Provider):
     name: str = "openai"
     api_key: str = Field(..., description="API key for the OpenAI API.")
-    base_url: str = "https://api.openai.com/v1"
+    base_url: str = Field(..., description="Base URL for the OpenAI API.")

     def list_llm_models(self) -> List[LLMConfig]:
         from letta.llm_api.openai import openai_get_model_list

-        response = openai_get_model_list(self.base_url, api_key=self.api_key)
-        model_options = [obj["id"] for obj in response["data"]]
+        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+        # See: https://openrouter.ai/docs/requests
+        extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
+        response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)

+        assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+
         configs = []
-        for model_name in model_options:
-            context_window_size = self.get_model_context_window_size(model_name)
+        for model in response["data"]:
+            assert "id" in model, f"OpenAI model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "context_length" in model:
+                # Context length is returned in OpenRouter as "context_length"
+                context_window_size = model["context_length"]
+            else:
+                context_window_size = self.get_model_context_window_size(model_name)

             if not context_window_size:
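The listing loop now keeps the whole model object so it can prefer OpenRouter's context_length field and only fall back to the hardcoded per-model lookup when that field is absent. A sketch of the fallback with a fabricated response payload (the lookup table is a hypothetical stand-in for the provider's get_model_context_window_size):

from typing import Optional

FALLBACK_CONTEXT_WINDOWS = {"gpt-4": 8192}  # hypothetical stand-in lookup table


def context_window_for(model: dict) -> Optional[int]:
    # Prefer the explicit context_length (OpenRouter); otherwise fall back
    # to a table keyed by model id.
    if "context_length" in model:
        return model["context_length"]
    return FALLBACK_CONTEXT_WINDOWS.get(model["id"])


response = {
    "data": [
        {"id": "gpt-4"},  # OpenAI-style entry: no context_length field
        {"id": "meta-llama/llama-3-70b-instruct", "context_length": 8192},  # OpenRouter-style entry
    ]
}
for model in response["data"]:
    assert "id" in model, f"model missing 'id': {model}"
    print(model["id"], context_window_for(model))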
@@ -50,6 +50,7 @@ from letta.providers import (
     LettaProvider,
     OllamaProvider,
     OpenAIProvider,
+    Provider,
     VLLMProvider,
 )
 from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
@@ -261,9 +262,9 @@ class SyncServer(Server):
         self.add_default_tools(module_name="base")

         # collect providers (always has Letta as a default)
-        self._enabled_providers = [LettaProvider()]
+        self._enabled_providers: List[Provider] = [LettaProvider()]
         if model_settings.openai_api_key:
-            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key))
+            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base))
         if model_settings.anthropic_api_key:
             self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
         if model_settings.ollama_base_url:
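Why the new Provider import exists just for an annotation: without List[Provider], a type checker infers List[LettaProvider] from the initializer and then rejects appending sibling subclasses. A reduced sketch (class bodies are stand-ins, not the letta implementations):

from typing import List


class Provider:  # stand-in base class
    name: str = "base"


class LettaProvider(Provider):
    name = "letta"


class OpenAIProvider(Provider):
    name = "openai"

    def __init__(self, api_key: str, base_url: str):
        self.api_key = api_key
        self.base_url = base_url


# Annotating as List[Provider] keeps the list heterogeneous for the checker.
enabled_providers: List[Provider] = [LettaProvider()]
enabled_providers.append(OpenAIProvider(api_key="sk-...", base_url="https://api.openai.com/v1"))
print([p.name for p in enabled_providers])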
@@ -11,7 +11,7 @@ class ModelSettings(BaseSettings):

     # openai
     openai_api_key: Optional[str] = None
-    # TODO: provide overriding BASE_URL?
+    openai_api_base: Optional[str] = "https://api.openai.com/v1"

     # groq
     groq_api_key: Optional[str] = None
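With the TODO resolved, pointing Letta at any OpenAI-compatible endpoint becomes an environment override, since pydantic's BaseSettings reads fields from environment variables (case-insensitively by default). A minimal sketch assuming plain pydantic-settings defaults, i.e. no custom env_prefix on ModelSettings:

import os
from typing import Optional

from pydantic_settings import BaseSettings


class ModelSettings(BaseSettings):
    openai_api_key: Optional[str] = None
    openai_api_base: Optional[str] = "https://api.openai.com/v1"


# Environment variables override field defaults when the settings object is built.
os.environ["OPENAI_API_BASE"] = "https://openrouter.ai/api/v1"
print(ModelSettings().openai_api_base)  # -> https://openrouter.ai/api/v1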