Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
feat: support custom llm configs (#737)

parent 663a4cf5cf
commit b443446408
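This commit lets a user define custom LLM configurations as JSON files under `~/.letta/llm_configs/`: `list_llm_models()` now appends them to the provider-listed models, and `get_llm_config_from_handle()` falls back to them when provider resolution fails. As a sketch of the file format (the field set below is copied from the new test's `sample_config`; any filename ending in `.json` works):

```json
{
  "model": "my-custom-model",
  "model_endpoint_type": "openai",
  "model_endpoint": "https://api.openai.com/v1",
  "context_window": 8192,
  "handle": "caren/my-custom-model"
}
```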
```diff
@@ -1,5 +1,6 @@
 # inspecting tools
 import asyncio
+import json
 import os
 import traceback
 import warnings
```
```diff
@@ -1053,6 +1054,8 @@ class SyncServer(Server):
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
                 warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}")
+
+        llm_models.extend(self.get_local_llm_configs())
         return llm_models
 
     def list_embedding_models(self) -> List[EmbeddingConfig]:
```
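A quick sketch of the effect of this hunk, assuming a running `SyncServer` named `server` and the sample config shown above installed under `~/.letta/llm_configs/`:

```python
# Local JSON configs are now appended to the provider-listed models.
llm_models = server.list_llm_models()
assert any(m.handle == "caren/my-custom-model" for m in llm_models)
```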
```diff
@@ -1072,20 +1075,26 @@ class SyncServer(Server):
         return {**providers_from_env, **providers_from_db}.values()
 
     def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
-        provider_name, model_name = handle.split("/", 1)
-        provider = self.get_provider_from_name(provider_name)
+        try:
+            provider_name, model_name = handle.split("/", 1)
+            provider = self.get_provider_from_name(provider_name)
 
-        llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
-        if len(llm_configs) == 1:
-            llm_config = llm_configs[0]
-        else:
-            llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
+            llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
+            if not llm_configs:
+                llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
             if not llm_configs:
                 raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
-            elif len(llm_configs) > 1:
-                raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
-            else:
-                llm_config = llm_configs[0]
+        except ValueError as e:
+            llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
+            if not llm_configs:
+                raise e
+
+        if len(llm_configs) == 1:
+            llm_config = llm_configs[0]
+        elif len(llm_configs) > 1:
+            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
+        else:
+            llm_config = llm_configs[0]
 
         if context_window_limit:
             if context_window_limit > llm_config.context_window:
```
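Resolution order after this hunk: match the named provider's models by `handle`, then by `model` name, and only fall back to the local JSON configs when provider resolution raises `ValueError`. A hedged usage sketch, assuming the sample config above is installed and that `caren` is not a registered provider:

```python
# "caren/my-custom-model" fails provider lookup, so the except branch
# searches get_local_llm_configs() for a matching handle instead.
llm_config = server.get_llm_config_from_handle("caren/my-custom-model")
assert llm_config.context_window == 8192
```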
```diff
@@ -1128,6 +1137,25 @@ class SyncServer(Server):
 
         return provider
 
+    def get_local_llm_configs(self):
+        llm_models = []
+        try:
+            llm_configs_dir = os.path.expanduser("~/.letta/llm_configs")
+            if os.path.exists(llm_configs_dir):
+                for filename in os.listdir(llm_configs_dir):
+                    if filename.endswith(".json"):
+                        filepath = os.path.join(llm_configs_dir, filename)
+                        try:
+                            with open(filepath, "r") as f:
+                                config_data = json.load(f)
+                                llm_config = LLMConfig(**config_data)
+                                llm_models.append(llm_config)
+                        except (json.JSONDecodeError, ValueError) as e:
+                            warnings.warn(f"Error parsing LLM config file {filename}: {e}")
+        except Exception as e:
+            warnings.warn(f"Error reading LLM configs directory: {e}")
+        return llm_models
+
     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""
 
```
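To register a config for `get_local_llm_configs()` to pick up, it is enough to write a JSON file that deserializes into `LLMConfig`. A minimal sketch, mirroring what the new test does (the filename is illustrative):

```python
import json
import os

from letta.schemas.llm_config import LLMConfig

configs_dir = os.path.expanduser("~/.letta/llm_configs")
os.makedirs(configs_dir, exist_ok=True)

config = LLMConfig(
    model="my-custom-model",
    model_endpoint_type="openai",
    model_endpoint="https://api.openai.com/v1",
    context_window=8192,
    handle="caren/my-custom-model",
)
# Any *.json filename works; get_local_llm_configs() scans the directory.
with open(os.path.join(configs_dir, "my-custom-model.json"), "w") as f:
    json.dump(config.model_dump(), f)
```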
```diff
@@ -1,5 +1,6 @@
 import json
 import os
+import shutil
 import uuid
 import warnings
 from typing import List, Tuple
```
```diff
@@ -13,6 +14,7 @@ from letta.orm import Provider, Step
 from letta.schemas.block import CreateBlock
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
+from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers import Provider as PydanticProvider
 from letta.schemas.user import User
```
```diff
@@ -563,6 +565,63 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
         server.agent_manager.delete_agent(agent_state.id, actor=another_user)
 
 
+def test_read_local_llm_configs(server: SyncServer, user: User):
+    configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
+    clean_up_dir = False
+    if not os.path.exists(configs_base_dir):
+        os.makedirs(configs_base_dir)
+        clean_up_dir = True
+
+    try:
+        sample_config = LLMConfig(
+            model="my-custom-model",
+            model_endpoint_type="openai",
+            model_endpoint="https://api.openai.com/v1",
+            context_window=8192,
+            handle="caren/my-custom-model",
+        )
+
+        config_filename = f"custom_llm_config_{uuid.uuid4().hex}.json"
+        config_filepath = os.path.join(configs_base_dir, config_filename)
+        with open(config_filepath, "w") as f:
+            json.dump(sample_config.model_dump(), f)
+
+        # Call list_llm_models
+        assert os.path.exists(configs_base_dir)
+        llm_models = server.list_llm_models()
+
+        # Assert that the config is in the returned models
+        assert any(
+            model.model == "my-custom-model"
+            and model.model_endpoint_type == "openai"
+            and model.model_endpoint == "https://api.openai.com/v1"
+            and model.context_window == 8192
+            and model.handle == "caren/my-custom-model"
+            for model in llm_models
+        ), "Custom LLM config not found in list_llm_models result"
+
+        # Try to use in agent creation
+        context_window_override = 4000
+        agent = server.create_agent(
+            request=CreateAgent(
+                model="caren/my-custom-model",
+                context_window_limit=context_window_override,
+                embedding="openai/text-embedding-ada-002",
+            ),
+            actor=user,
+        )
+        assert agent.llm_config.model == sample_config.model
+        assert agent.llm_config.model_endpoint == sample_config.model_endpoint
+        assert agent.llm_config.model_endpoint_type == sample_config.model_endpoint_type
+        assert agent.llm_config.context_window == context_window_override
+        assert agent.llm_config.handle == sample_config.handle
+
+    finally:
+        os.remove(config_filepath)
+        if clean_up_dir:
+            shutil.rmtree(configs_base_dir)
+
+
 def _test_get_messages_letta_format(
     server,
     user,
```