feat: support custom llm configs (#737)

cthomas authored on 2025-01-23 10:13:05 -08:00; committed by GitHub
parent 663a4cf5cf
commit b443446408
2 changed files with 98 additions and 11 deletions

File 1 of 2:

@@ -1,5 +1,6 @@
 # inspecting tools
 import asyncio
+import json
 import os
 import traceback
 import warnings
@@ -1053,6 +1054,8 @@ class SyncServer(Server):
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
+
+        llm_models.extend(self.get_local_llm_configs())
         return llm_models
 
     def list_embedding_models(self) -> List[EmbeddingConfig]:
@@ -1072,20 +1075,26 @@ class SyncServer(Server):
         return {**providers_from_env, **providers_from_db}.values()
 
     def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
-        provider_name, model_name = handle.split("/", 1)
-        provider = self.get_provider_from_name(provider_name)
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = self.get_provider_from_name(provider_name)
 
-        llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
-        if len(llm_configs) == 1:
-            llm_config = llm_configs[0]
-        else:
-            llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
+            llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
             if not llm_configs:
                 raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
-            elif len(llm_configs) > 1:
-                raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
-            else:
-                llm_config = llm_configs[0]
+        except ValueError as e:
+            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
+            if not llm_configs:
+                raise e
+
+        if len(llm_configs) == 1:
+            llm_config = llm_configs[0]
+        elif len(llm_configs) > 1:
+            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
+        else:
+            llm_config = llm_configs[0]
 
         if context_window_limit:
             if context_window_limit > llm_config.context_window:
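A minimal sketch of the resolution order this hunk introduces (not part of the diff; the `server` instance and the handles are assumed for illustration): the provider lookup runs first, and the configs returned by get_local_llm_configs() are only consulted when that lookup raises a ValueError.

    # Assumes a SyncServer instance `server`; "caren/my-custom-model" mirrors the
    # config used in the test below, and "openai/gpt-4" stands in for any handle
    # served by a configured provider.
    config = server.get_llm_config_from_handle("openai/gpt-4")
    # -> resolved by the OpenAI provider, exactly as before this commit

    config = server.get_llm_config_from_handle("caren/my-custom-model")
    # -> "caren" is not a known provider, so the try block raises ValueError and
    #    the handle is matched against the configs under ~/.letta/llm_configs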
@@ -1128,6 +1137,25 @@ class SyncServer(Server):
         return provider
 
+    def get_local_llm_configs(self):
+        llm_models = []
+        try:
+            llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
+            if os.path.exists(llm_configs_dir):
+                for filename in os.listdir(llm_configs_dir):
+                    if filename.endswith(".json"):
+                        filepath = os.path.join(llm_configs_dir, filename)
+                        try:
+                            with open(filepath, "r") as f:
+                                config_data = json.load(f)
+                                llm_config = LLMConfig(**config_data)
+                                llm_models.append(llm_config)
+                        except (json.JSONDecodeError, ValueError) as e:
+                            warnings.warn(f"Error parsing LLM config file {filename}: {e}")
+        except Exception as e:
+            warnings.warn(f"Error reading LLM configs directory: {e}")
+        return llm_models
+
     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""

File 2 of 2:

@@ -1,5 +1,6 @@
 import json
 import os
+import shutil
 import uuid
 import warnings
 from typing import List, Tuple
@@ -13,6 +14,7 @@ from letta.orm import Provider, Step
 from letta.schemas.block import CreateBlock
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
+from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers import Provider as PydanticProvider
 from letta.schemas.user import User
@@ -563,6 +565,63 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
     server.agent_manager.delete_agent(agent_state.id, actor=another_user)
 
 
+def test_read_local_llm_configs(server: SyncServer, user: User):
+    configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
+    clean_up_dir = False
+    if not os.path.exists(configs_base_dir):
+        os.makedirs(configs_base_dir)
+        clean_up_dir = True
+
+    try:
+        sample_config = LLMConfig(
+            model="my-custom-model",
+            model_endpoint_type="openai",
+            model_endpoint="https://api.openai.com/v1",
+            context_window=8192,
+            handle="caren/my-custom-model",
+        )
+
+        config_filename = f"custom_llm_config_{uuid.uuid4().hex}.json"
+        config_filepath = os.path.join(configs_base_dir, config_filename)
+        with open(config_filepath, "w") as f:
+            json.dump(sample_config.model_dump(), f)
+
+        # Call list_llm_models
+        assert os.path.exists(configs_base_dir)
+        llm_models = server.list_llm_models()
+
+        # Assert that the config is in the returned models
+        assert any(
+            model.model == "my-custom-model"
+            and model.model_endpoint_type == "openai"
+            and model.model_endpoint == "https://api.openai.com/v1"
+            and model.context_window == 8192
+            and model.handle == "caren/my-custom-model"
+            for model in llm_models
+        ), "Custom LLM config not found in list_llm_models result"
+
+        # Try to use in agent creation
+        context_window_override = 4000
+        agent = server.create_agent(
+            request=CreateAgent(
+                model="caren/my-custom-model",
+                context_window_limit=context_window_override,
+                embedding="openai/text-embedding-ada-002",
+            ),
+            actor=user,
+        )
+
+        assert agent.llm_config.model == sample_config.model
+        assert agent.llm_config.model_endpoint == sample_config.model_endpoint
+        assert agent.llm_config.model_endpoint_type == sample_config.model_endpoint_type
+        assert agent.llm_config.context_window == context_window_override
+        assert agent.llm_config.handle == sample_config.handle
+    finally:
+        os.remove(config_filepath)
+        if clean_up_dir:
+            shutil.rmtree(configs_base_dir)
+
+
 def _test_get_messages_letta_format(
     server,
     user,