From b4434464085db330a2b4daa06ede33c6f844394b Mon Sep 17 00:00:00 2001
From: cthomas
Date: Thu, 23 Jan 2025 10:13:05 -0800
Subject: [PATCH] feat: support custom llm configs (#737)

---
 letta/server/server.py | 50 +++++++++++++++++++++++++++--------
 tests/test_server.py   | 59 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 98 insertions(+), 11 deletions(-)

diff --git a/letta/server/server.py b/letta/server/server.py
index d195852f5..9cafebbda 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -1,5 +1,6 @@
 # inspecting tools
 import asyncio
+import json
 import os
 import traceback
 import warnings
@@ -1053,6 +1054,8 @@ class SyncServer(Server):
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
+
+        llm_models.extend(self.get_local_llm_configs())
         return llm_models
 
     def list_embedding_models(self) -> List[EmbeddingConfig]:
@@ -1072,20 +1075,26 @@
         return {**providers_from_env, **providers_from_db}.values()
 
     def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
-        provider_name, model_name = handle.split("/", 1)
-        provider = self.get_provider_from_name(provider_name)
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = self.get_provider_from_name(provider_name)
 
-        llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
-        if len(llm_configs) == 1:
-            llm_config = llm_configs[0]
-        else:
-            llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
+            llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
             if not llm_configs:
                 raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
-            elif len(llm_configs) > 1:
-                raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
-            else:
-                llm_config = llm_configs[0]
+        except ValueError as e:
+            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
+            if not llm_configs:
+                raise e
+
+        if len(llm_configs) == 1:
+            llm_config = llm_configs[0]
+        elif len(llm_configs) > 1:
+            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
+        else:
+            llm_config = llm_configs[0]
 
         if context_window_limit:
             if context_window_limit > llm_config.context_window:
@@ -1128,6 +1137,25 @@
 
         return provider
 
+    def get_local_llm_configs(self):
+        llm_models = []
+        try:
+            llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
+            if os.path.exists(llm_configs_dir):
+                for filename in os.listdir(llm_configs_dir):
+                    if filename.endswith(".json"):
+                        filepath = os.path.join(llm_configs_dir, filename)
+                        try:
+                            with open(filepath, "r") as f:
+                                config_data = json.load(f)
+                                llm_config = LLMConfig(**config_data)
+                                llm_models.append(llm_config)
+                        except (json.JSONDecodeError, ValueError) as e:
+                            warnings.warn(f"Error parsing LLM config file {filename}: {e}")
+        except Exception as e:
+            warnings.warn(f"Error reading LLM configs directory: {e}")
+        return llm_models
+
     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""
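Note: get_local_llm_configs() above treats every *.json file under
~/.letta/llm_configs as an LLMConfig. A minimal sketch of registering a custom
model this way, assuming only the fields the test below exercises (the
filename and the endpoint/handle values are illustrative placeholders):

    import json
    import os

    # Placeholder values mirroring sample_config in the test below; any
    # *.json filename under ~/.letta/llm_configs is picked up.
    config = {
        "model": "my-custom-model",
        "model_endpoint_type": "openai",
        "model_endpoint": "https://api.openai.com/v1",
        "context_window": 8192,
        "handle": "caren/my-custom-model",
    }

    configs_dir = os.path.expanduser("~/.letta/llm_configs")
    os.makedirs(configs_dir, exist_ok=True)
    with open(os.path.join(configs_dir, "my-custom-model.json"), "w") as f:
        json.dump(config, f, indent=2)

Once this file exists, list_llm_models() includes the custom config, and the
handle "caren/my-custom-model" can be used wherever a model handle is
accepted, as the test below demonstrates.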
diff --git a/tests/test_server.py b/tests/test_server.py
index cbca00bb1..762285b73 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,5 +1,6 @@
 import json
 import os
+import shutil
 import uuid
 import warnings
 from typing import List, Tuple
@@ -13,6 +14,7 @@ from letta.orm import Provider, Step
 from letta.schemas.block import CreateBlock
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
+from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers import Provider as PydanticProvider
 from letta.schemas.user import User
 
@@ -563,6 +565,63 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
         server.agent_manager.delete_agent(agent_state.id, actor=another_user)
 
 
+def test_read_local_llm_configs(server: SyncServer, user: User):
+    configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
+    clean_up_dir = False
+    if not os.path.exists(configs_base_dir):
+        os.makedirs(configs_base_dir)
+        clean_up_dir = True
+
+    try:
+        sample_config = LLMConfig(
+            model="my-custom-model",
+            model_endpoint_type="openai",
+            model_endpoint="https://api.openai.com/v1",
+            context_window=8192,
+            handle="caren/my-custom-model",
+        )
+
+        config_filename = f"custom_llm_config_{uuid.uuid4().hex}.json"
+        config_filepath = os.path.join(configs_base_dir, config_filename)
+        with open(config_filepath, "w") as f:
+            json.dump(sample_config.model_dump(), f)
+
+        # Call list_llm_models
+        assert os.path.exists(configs_base_dir)
+        llm_models = server.list_llm_models()
+
+        # Assert that the config is in the returned models
+        assert any(
+            model.model == "my-custom-model"
+            and model.model_endpoint_type == "openai"
+            and model.model_endpoint == "https://api.openai.com/v1"
+            and model.context_window == 8192
+            and model.handle == "caren/my-custom-model"
+            for model in llm_models
+        ), "Custom LLM config not found in list_llm_models result"
+
+        # Try to use in agent creation
+        context_window_override = 4000
+        agent = server.create_agent(
+            request=CreateAgent(
+                model="caren/my-custom-model",
+                context_window_limit=context_window_override,
+                embedding="openai/text-embedding-ada-002",
+            ),
+            actor=user,
+        )
+        assert agent.llm_config.model == sample_config.model
+        assert agent.llm_config.model_endpoint == sample_config.model_endpoint
+        assert agent.llm_config.model_endpoint_type == sample_config.model_endpoint_type
+        assert agent.llm_config.context_window == context_window_override
+        assert agent.llm_config.handle == sample_config.handle
+
+    finally:
+        os.remove(config_filepath)
+        if clean_up_dir:
+            shutil.rmtree(configs_base_dir)
+
+
 def _test_get_messages_letta_format(
     server,
     user,
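A hedged usage sketch of the fallback path this test exercises (server here is
the same SyncServer fixture the tests use; "caren" is not a registered
provider, so the provider lookup in get_llm_config_from_handle() raises
ValueError and resolution falls back to get_local_llm_configs()):

    # Resolves via the local JSON configs, not via any provider.
    llm_config = server.get_llm_config_from_handle("caren/my-custom-model")
    assert llm_config.context_window == 8192

    # A context_window_limit below the config's window overrides it, as the
    # agent-creation assertions above show with context_window_override=4000.
    capped = server.get_llm_config_from_handle(
        "caren/my-custom-model", context_window_limit=4000
    )
    assert capped.context_window == 4000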