diff --git a/alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py b/alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py new file mode 100644 index 000000000..3b94ceddb --- /dev/null +++ b/alembic/versions/373dabcba6cf_add_byok_fields_and_unique_constraint.py @@ -0,0 +1,35 @@ +"""add byok fields and unique constraint + +Revision ID: 373dabcba6cf +Revises: c56081a05371 +Create Date: 2025-04-30 19:38:25.010856 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "373dabcba6cf" +down_revision: Union[str, None] = "c56081a05371" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True)) + op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True)) + op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"]) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("unique_name_organization_id", "providers", type_="unique") + op.drop_column("providers", "base_url") + op.drop_column("providers", "provider_type") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 01587bedf..df40fc63c 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -331,7 +331,8 @@ class Agent(BaseAgent): log_telemetry(self.logger, "_get_ai_reply create start") # New LLM client flow llm_client = LLMClient.create( - provider=self.agent_state.llm_config.model_endpoint_type, + provider_name=self.agent_state.llm_config.provider_name, + provider_type=self.agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=put_inner_thoughts_first, ) @@ -941,12 +942,7 @@ class Agent(BaseAgent): model_endpoint=self.agent_state.llm_config.model_endpoint, context_window_limit=self.agent_state.llm_config.context_window, usage=response.usage, - # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature - provider_id=( - self.provider_manager.get_anthropic_override_provider_id() - if self.agent_state.llm_config.model_endpoint_type == "anthropic" - else None - ), + provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name), job_id=job_id, ) for message in all_new_messages: diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 79a6f7d05..efc81331f 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -67,7 +67,8 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - provider=agent_state.llm_config.model_endpoint_type, + provider_name=agent_state.llm_config.provider_name, + provider_type=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, ) for step in range(max_steps): diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index a6d31a09c..b832968e6 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -156,7 +156,8 @@ class LettaAgentBatch: log_event(name="init_llm_client") llm_client = LLMClient.create( - provider=agent_states[0].llm_config.model_endpoint_type, + provider_name=agent_states[0].llm_config.provider_name, + 
provider_type=agent_states[0].llm_config.model_endpoint_type, put_inner_thoughts_first=True, ) agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states} @@ -273,7 +274,8 @@ class LettaAgentBatch: # translate provider‑specific response → OpenAI‑style tool call (unchanged) llm_client = LLMClient.create( - provider=item.llm_config.model_endpoint_type, + provider_name=item.llm_config.provider_name, + provider_type=item.llm_config.model_endpoint_type, put_inner_thoughts_first=True, ) tool_call = ( diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 59939e4d6..08e70d069 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.log import get_logger +from letta.schemas.enums import ProviderType from letta.schemas.message import Message as _Message from letta.schemas.message import MessageRole as _MessageRole from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool @@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict: # NOTE: currently there is no GET /models, so we need to hardcode # return MODEL_LIST - anthropic_override_key = ProviderManager().get_anthropic_override_key() - if anthropic_override_key: - anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + if api_key: + anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() + else: + raise ValueError("No API key provided") models = anthropic_client.models.list() models_json = models.model_dump() @@ -738,13 +740,14 @@ def anthropic_chat_completions_request( put_inner_thoughts_in_kwargs: bool = False, extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, + provider_name: Optional[str] = None, betas: List[str] = ["tools-2024-04-04"], ) -> ChatCompletionResponse: """https://docs.anthropic.com/claude/docs/tool-use""" anthropic_client = None - anthropic_override_key = ProviderManager().get_anthropic_override_key() - if anthropic_override_key: - anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + if provider_name and provider_name != ProviderType.anthropic.value: + api_key = ProviderManager().get_override_key(provider_name) + anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() else: @@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream( put_inner_thoughts_in_kwargs: bool = False, extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, + provider_name: Optional[str] = None, betas: List[str] = ["tools-2024-04-04"], ) -> Generator[ChatCompletionChunkResponse, None, None]: """Stream chat completions from Anthropic API. 
@@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream( extended_thinking=extended_thinking, max_reasoning_tokens=max_reasoning_tokens, ) - - anthropic_override_key = ProviderManager().get_anthropic_override_key() - if anthropic_override_key: - anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + if provider_name and provider_name != ProviderType.anthropic.value: + api_key = ProviderManager().get_override_key(provider_name) + anthropic_client = anthropic.Anthropic(api_key=api_key) elif model_settings.anthropic_api_key: anthropic_client = anthropic.Anthropic() @@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream( put_inner_thoughts_in_kwargs: bool = False, extended_thinking: bool = False, max_reasoning_tokens: Optional[int] = None, + provider_name: Optional[str] = None, create_message_id: bool = True, create_message_datetime: bool = True, betas: List[str] = ["tools-2024-04-04"], @@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream( put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, extended_thinking=extended_thinking, max_reasoning_tokens=max_reasoning_tokens, + provider_name=provider_name, betas=betas, ) ): diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 863fcef0d..35317dd82 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.log import get_logger +from letta.schemas.enums import ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool @@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase): @trace_method def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]: - override_key = ProviderManager().get_anthropic_override_key() + override_key = None + if self.provider_name and self.provider_name != ProviderType.anthropic.value: + override_key = ProviderManager().get_override_key(self.provider_name) + if async_client: return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic() return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic() diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index be1b9d82a..b1112290c 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -24,6 +24,7 @@ from letta.llm_api.openai import ( from letta.local_llm.chat_completion_proxy import get_chat_completion from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from letta.schemas.enums import ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype @@ -171,6 +172,10 @@ def create( if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": # only is a problem if we are *not* using an openai proxy raise LettaConfigurationError(message="OpenAI key is missing from letta config file", 
missing_fields=["openai_api_key"]) + elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value: + from letta.services.provider_manager import ProviderManager + + api_key = ProviderManager().get_override_key(llm_config.provider_name) elif model_settings.openai_api_key is None: # the openai python client requires a dummy API key api_key = "DUMMY_API_KEY" @@ -373,6 +378,7 @@ def create( stream_interface=stream_interface, extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, + provider_name=llm_config.provider_name, name=name, ) @@ -383,6 +389,7 @@ def create( put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, extended_thinking=llm_config.enable_reasoner, max_reasoning_tokens=llm_config.max_reasoning_tokens, + provider_name=llm_config.provider_name, ) if llm_config.put_inner_thoughts_in_kwargs: diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 674f94974..9028dd04a 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -9,7 +9,8 @@ class LLMClient: @staticmethod def create( - provider: ProviderType, + provider_type: ProviderType, + provider_name: Optional[str] = None, put_inner_thoughts_first: bool = True, ) -> Optional[LLMClientBase]: """ @@ -25,29 +26,33 @@ class LLMClient: Raises: ValueError: If the model endpoint type is not supported """ - match provider: + match provider_type: case ProviderType.google_ai: from letta.llm_api.google_ai_client import GoogleAIClient return GoogleAIClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, ) case ProviderType.google_vertex: from letta.llm_api.google_vertex_client import GoogleVertexClient return GoogleVertexClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, ) case ProviderType.anthropic: from letta.llm_api.anthropic_client import AnthropicClient return AnthropicClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, ) case ProviderType.openai: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( + provider_name=provider_name, put_inner_thoughts_first=put_inner_thoughts_first, ) case _: diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 5c7dcab9e..2b2512891 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -20,9 +20,11 @@ class LLMClientBase: def __init__( self, + provider_name: Optional[str] = None, put_inner_thoughts_first: Optional[bool] = True, use_tool_naming: bool = True, ): + self.provider_name = provider_name self.put_inner_thoughts_first = put_inner_thoughts_first self.use_tool_naming = use_tool_naming diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index afd6bf475..2c2057172 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST from letta.log import get_logger +from letta.schemas.enums import ProviderType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import ChatCompletionRequest @@ -64,7 +65,14 @@ def supports_parallel_tool_calling(model: str) -> 
bool: class OpenAIClient(LLMClientBase): def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict: - api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") + api_key = None + if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value: + from letta.services.provider_manager import ProviderManager + + api_key = ProviderManager().get_override_key(llm_config.provider_name) + + if not api_key: + api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") # supposedly the openai python client requires a dummy API key api_key = api_key or "DUMMY_API_KEY" kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint} diff --git a/letta/memory.py b/letta/memory.py index 6d29963f0..83b169a04 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -79,7 +79,8 @@ def summarize_messages( llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False llm_client = LLMClient.create( - provider=llm_config_no_inner_thoughts.model_endpoint_type, + provider_name=llm_config_no_inner_thoughts.provider_name, + provider_type=llm_config_no_inner_thoughts.model_endpoint_type, put_inner_thoughts_first=False, ) # try to use new client, otherwise fallback to old flow diff --git a/letta/orm/provider.py b/letta/orm/provider.py index 2ae524b56..d85e5ef2b 100644 --- a/letta/orm/provider.py +++ b/letta/orm/provider.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from sqlalchemy import UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin): __tablename__ = "providers" __pydantic_model__ = PydanticProvider + __table_args__ = ( + UniqueConstraint( + "name", + "organization_id", + name="unique_name_organization_id", + ), + ) name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider") + provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider") api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.") + base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="providers") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index c1d54d776..6258e1e51 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -6,6 +6,17 @@ class ProviderType(str, Enum): google_ai = "google_ai" google_vertex = "google_vertex" openai = "openai" + letta = "letta" + deepseek = "deepseek" + lmstudio_openai = "lmstudio_openai" + xai = "xai" + mistral = "mistral" + ollama = "ollama" + groq = "groq" + together = "together" + azure = "azure" + vllm = "vllm" + bedrock = "bedrock" class MessageRole(str, Enum): diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 3dc2c92e4..b888a6750 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -50,6 +50,7 @@ class LLMConfig(BaseModel): "xai", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") + provider_name: Optional[str] = Field(None, description="The provider name for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") put_inner_thoughts_in_kwargs: Optional[bool] = Field( diff 
--git a/letta/schemas/llm_config_overrides.py b/letta/schemas/llm_config_overrides.py index f8f286ae2..407c73a29 100644 --- a/letta/schemas/llm_config_overrides.py +++ b/letta/schemas/llm_config_overrides.py @@ -2,8 +2,8 @@ from typing import Dict LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = { "anthropic": { - "claude-3-5-haiku-20241022": "claude-3.5-haiku", - "claude-3-5-sonnet-20241022": "claude-3.5-sonnet", + "claude-3-5-haiku-20241022": "claude-3-5-haiku", + "claude-3-5-sonnet-20241022": "claude-3-5-sonnet", "claude-3-opus-20240229": "claude-3-opus", }, "openai": { diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index a985a412a..f067007a3 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -1,6 +1,6 @@ import warnings from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional from pydantic import Field, model_validator @@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_ from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES +from letta.schemas.enums import ProviderType from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES +from letta.settings import model_settings class ProviderBase(LettaBase): @@ -21,10 +23,18 @@ class ProviderBase(LettaBase): class Provider(ProviderBase): id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.") name: str = Field(..., description="The name of the provider") + provider_type: ProviderType = Field(..., description="The type of the provider") api_key: Optional[str] = Field(None, description="API key used for requests to the provider.") + base_url: Optional[str] = Field(None, description="Base URL for the provider.") organization_id: Optional[str] = Field(None, description="The organization id of the user") updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.") + @model_validator(mode="after") + def default_base_url(self): + if self.provider_type == ProviderType.openai and self.base_url is None: + self.base_url = model_settings.openai_api_base + return self + def resolve_identifier(self): if not self.id: self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) @@ -59,9 +69,41 @@ class Provider(ProviderBase): return f"{self.name}/{model_name}" + def cast_to_subtype(self): + match (self.provider_type): + case ProviderType.letta: + return LettaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.openai: + return OpenAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.anthropic: + return AnthropicProvider(**self.model_dump(exclude_none=True)) + case ProviderType.anthropic_bedrock: + return AnthropicBedrockProvider(**self.model_dump(exclude_none=True)) + case ProviderType.ollama: + return OllamaProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_ai: + return GoogleAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.google_vertex: + return GoogleVertexProvider(**self.model_dump(exclude_none=True)) + case ProviderType.azure: + return AzureProvider(**self.model_dump(exclude_none=True)) + case ProviderType.groq: + return GroqProvider(**self.model_dump(exclude_none=True)) + 
case ProviderType.together: + return TogetherProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm_chat_completions: + return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True)) + case ProviderType.vllm_completions: + return VLLMCompletionsProvider(**self.model_dump(exclude_none=True)) + case ProviderType.xai: + return XAIProvider(**self.model_dump(exclude_none=True)) + case _: + raise ValueError(f"Unknown provider type: {self.provider_type}") + class ProviderCreate(ProviderBase): name: str = Field(..., description="The name of the provider.") + provider_type: ProviderType = Field(..., description="The type of the provider.") api_key: str = Field(..., description="API key used for requests to the provider.") @@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase): class LettaProvider(Provider): - - name: str = "letta" + provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") def list_llm_models(self) -> List[LLMConfig]: return [ @@ -81,6 +122,7 @@ class LettaProvider(Provider): model_endpoint=LETTA_MODEL_ENDPOINT, context_window=8192, handle=self.get_handle("letta-free"), + provider_name=self.name, ) ] @@ -98,7 +140,7 @@ class LettaProvider(Provider): class OpenAIProvider(Provider): - name: str = "openai" + provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") api_key: str = Field(..., description="API key for the OpenAI API.") base_url: str = Field(..., description="Base URL for the OpenAI API.") @@ -180,6 +222,7 @@ class OpenAIProvider(Provider): model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ) ) @@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider): * It also does not support native function calling """ - name: str = "deepseek" + provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") api_key: str = Field(..., description="API key for the DeepSeek API.") @@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider): context_window=context_window_size, handle=self.get_handle(model_name), put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + provider_name=self.name, ) ) @@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider): class LMStudioOpenAIProvider(OpenAIProvider): - name: str = "lmstudio-openai" + provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") api_key: Optional[str] = Field(None, description="API key for the LMStudio API.") @@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider): class XAIProvider(OpenAIProvider): """https://docs.x.ai/docs/api-reference""" - name: str = "xai" + provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") api_key: str = Field(..., description="API key for the xAI/Grok API.") base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") @@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider): model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ) ) @@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider): class AnthropicProvider(Provider): - 
name: str = "anthropic" + provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") api_key: str = Field(..., description="API key for the Anthropic API.") base_url: str = "https://api.anthropic.com/v1" @@ -563,6 +608,7 @@ class AnthropicProvider(Provider): handle=self.get_handle(model["id"]), put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, max_tokens=max_tokens, + provider_name=self.name, ) ) return configs @@ -572,7 +618,7 @@ class AnthropicProvider(Provider): class MistralProvider(Provider): - name: str = "mistral" + provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") api_key: str = Field(..., description="API key for the Mistral API.") base_url: str = "https://api.mistral.ai/v1" @@ -596,6 +642,7 @@ class MistralProvider(Provider): model_endpoint=self.base_url, context_window=model["max_context_length"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) @@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider): See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion """ - name: str = "ollama" + provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the Ollama API.") api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).") default_prompt_formatter: str = Field( @@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider): model_wrapper=self.default_prompt_formatter, context_window=context_window, handle=self.get_handle(model["name"]), + provider_name=self.name, ) ) return configs @@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider): class GroqProvider(OpenAIProvider): - name: str = "groq" + provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") base_url: str = "https://api.groq.com/openai/v1" api_key: str = Field(..., description="API key for the Groq API.") @@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider): model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) return configs @@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider): function calling support is limited. 
""" - name: str = "together" + provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") base_url: str = "https://api.together.ai/v1" api_key: str = Field(..., description="API key for the TogetherAI API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") @@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider): model_wrapper=self.default_prompt_formatter, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ) ) @@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider): class GoogleAIProvider(Provider): # gemini - name: str = "google_ai" + provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") api_key: str = Field(..., description="API key for the Google AI API.") base_url: str = "https://generativelanguage.googleapis.com" @@ -889,7 +939,6 @@ class GoogleAIProvider(Provider): # filter by model names model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options] - # TODO remove manual filtering for gemini-pro # Add support for all gemini models model_options = [mo for mo in model_options if str(mo).startswith("gemini-")] @@ -903,6 +952,7 @@ class GoogleAIProvider(Provider): context_window=self.get_model_context_window(model), handle=self.get_handle(model), max_tokens=8192, + provider_name=self.name, ) ) return configs @@ -938,7 +988,7 @@ class GoogleAIProvider(Provider): class GoogleVertexProvider(Provider): - name: str = "google_vertex" + provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.") google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.") google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.") @@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider): context_window=context_length, handle=self.get_handle(model), max_tokens=8192, + provider_name=self.name, ) ) return configs @@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider): class AzureProvider(Provider): - name: str = "azure" + provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.") latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation base_url: str = Field( ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." 
@@ -1011,6 +1062,7 @@ class AzureProvider(Provider): model_endpoint=model_endpoint, context_window=context_window_size, handle=self.get_handle(model_name), + provider_name=self.name, ), ) return configs @@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider): """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy""" # NOTE: vLLM only serves one model at a time (so could configure that through env variables) - name: str = "vllm" + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the vLLM API.") def list_llm_models(self) -> List[LLMConfig]: @@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider): model_endpoint=self.base_url, context_window=model["max_model_len"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) return configs @@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider): """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper""" # NOTE: vLLM only serves one model at a time (so could configure that through env variables) - name: str = "vllm" + provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.") base_url: str = Field(..., description="Base URL for the vLLM API.") default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.") @@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider): model_wrapper=self.default_prompt_formatter, context_window=model["max_model_len"], handle=self.get_handle(model["id"]), + provider_name=self.name, ) ) return configs @@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider): class AnthropicBedrockProvider(Provider): - name: str = "bedrock" + provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") aws_region: str = Field(..., description="AWS region for Bedrock") def list_llm_models(self): @@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider): configs.append( LLMConfig( model=model_arn, - model_endpoint_type=self.name, + model_endpoint_type=self.provider_type.value, model_endpoint=None, context_window=self.get_model_context_window(model_arn), handle=self.get_handle(model_arn), + provider_name=self.name, ) ) return configs diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 173b1a578..02c369f66 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig @@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"]) @router.get("/", response_model=List[LLMConfig], operation_id="list_models") def list_llm_models( + byok_only: Optional[bool] = Query(None), server: "SyncServer" = Depends(get_letta_server), ): - models = server.list_llm_models() + models = server.list_llm_models(byok_only=byok_only) # print(models) return models diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 1de78ba57..02615f633 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ 
b/letta/server/rest_api/routers/v1/providers.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from letta.schemas.enums import ProviderType from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate from letta.server.rest_api.utils import get_letta_server @@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"]) @router.get("/", response_model=List[Provider], operation_id="list_providers") def list_providers( + name: Optional[str] = Query(None), + provider_type: Optional[ProviderType] = Query(None), after: Optional[str] = Query(None), limit: Optional[int] = Query(50), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -23,7 +26,7 @@ def list_providers( """ try: actor = server.user_manager.get_user_or_default(user_id=actor_id) - providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor) + providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type) except HTTPException: raise except Exception as e: diff --git a/letta/server/server.py b/letta/server/server.py index b46ff1ea0..4553de2f1 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -268,10 +268,11 @@ class SyncServer(Server): ) # collect providers (always has Letta as a default) - self._enabled_providers: List[Provider] = [LettaProvider()] + self._enabled_providers: List[Provider] = [LettaProvider(name="letta")] if model_settings.openai_api_key: self._enabled_providers.append( OpenAIProvider( + name="openai", api_key=model_settings.openai_api_key, base_url=model_settings.openai_api_base, ) @@ -279,12 +280,14 @@ class SyncServer(Server): if model_settings.anthropic_api_key: self._enabled_providers.append( AnthropicProvider( + name="anthropic", api_key=model_settings.anthropic_api_key, ) ) if model_settings.ollama_base_url: self._enabled_providers.append( OllamaProvider( + name="ollama", base_url=model_settings.ollama_base_url, api_key=None, default_prompt_formatter=model_settings.default_prompt_formatter, @@ -293,12 +296,14 @@ class SyncServer(Server): if model_settings.gemini_api_key: self._enabled_providers.append( GoogleAIProvider( + name="google_ai", api_key=model_settings.gemini_api_key, ) ) if model_settings.google_cloud_location and model_settings.google_cloud_project: self._enabled_providers.append( GoogleVertexProvider( + name="google_vertex", google_cloud_project=model_settings.google_cloud_project, google_cloud_location=model_settings.google_cloud_location, ) @@ -307,6 +312,7 @@ class SyncServer(Server): assert model_settings.azure_api_version, "AZURE_API_VERSION is required" self._enabled_providers.append( AzureProvider( + name="azure", api_key=model_settings.azure_api_key, base_url=model_settings.azure_base_url, api_version=model_settings.azure_api_version, @@ -315,12 +321,14 @@ class SyncServer(Server): if model_settings.groq_api_key: self._enabled_providers.append( GroqProvider( + name="groq", api_key=model_settings.groq_api_key, ) ) if model_settings.together_api_key: self._enabled_providers.append( TogetherProvider( + name="together", api_key=model_settings.together_api_key, default_prompt_formatter=model_settings.default_prompt_formatter, ) @@ -329,6 +337,7 @@ class SyncServer(Server): # vLLM exposes both a /chat/completions and a /completions endpoint self._enabled_providers.append( VLLMCompletionsProvider( + name="vllm", base_url=model_settings.vllm_api_base, 
default_prompt_formatter=model_settings.default_prompt_formatter, ) @@ -338,12 +347,14 @@ class SyncServer(Server): # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes" self._enabled_providers.append( VLLMChatCompletionsProvider( + name="vllm", base_url=model_settings.vllm_api_base, ) ) if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region: self._enabled_providers.append( AnthropicBedrockProvider( + name="bedrock", aws_region=model_settings.aws_region, ) ) @@ -355,11 +366,11 @@ class SyncServer(Server): if model_settings.lmstudio_base_url.endswith("/v1") else model_settings.lmstudio_base_url + "/v1" ) - self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url)) + self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url)) if model_settings.deepseek_api_key: - self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key)) + self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)) if model_settings.xai_api_key: - self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key)) + self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key)) # For MCP """Initialize the MCP clients (there may be multiple)""" @@ -1184,10 +1195,10 @@ class SyncServer(Server): except NoResultFound: raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") - def list_llm_models(self) -> List[LLMConfig]: + def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]: """List available models""" llm_models = [] - for provider in self.get_enabled_providers(): + for provider in self.get_enabled_providers(byok_only=byok_only): try: llm_models.extend(provider.list_llm_models()) except Exception as e: @@ -1207,11 +1218,12 @@ class SyncServer(Server): warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models - def get_enabled_providers(self): + def get_enabled_providers(self, byok_only: bool = False): + providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()} + if byok_only: + return list(providers_from_db.values()) providers_from_env = {p.name: p for p in self._enabled_providers} - providers_from_db = {p.name: p for p in self.provider_manager.list_providers()} - # Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur - return {**providers_from_env, **providers_from_db}.values() + return list(providers_from_env.values()) + list(providers_from_db.values()) @trace_method def get_llm_config_from_handle( @@ -1296,7 +1308,7 @@ class SyncServer(Server): return embedding_config def get_provider_from_name(self, provider_name: str) -> Provider: - providers = [provider for provider in self._enabled_providers if provider.name == provider_name] + providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name] if not providers: raise ValueError(f"Provider {provider_name} is not supported") elif len(providers) > 1: diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 39596e17f..d012171d2 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -1,6 +1,7 @@ -from typing import List, Optional +from typing import List, Optional, Union from letta.orm.provider import Provider as ProviderModel +from 
letta.schemas.enums import ProviderType from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import ProviderUpdate from letta.schemas.user import User as PydanticUser @@ -18,6 +19,9 @@ class ProviderManager: def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider: """Create a new provider if it doesn't already exist.""" with self.session_maker() as session: + if provider.name == provider.provider_type.value: + raise ValueError("Provider name must be unique and different from provider type") + # Assign the organization id based on the actor provider.organization_id = actor.organization_id @@ -59,29 +63,36 @@ class ProviderManager: session.commit() @enforce_types - def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]: + def list_providers( + self, + name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + actor: PydanticUser = None, + ) -> List[PydanticProvider]: """List all providers with optional pagination.""" + filter_kwargs = {} + if name: + filter_kwargs["name"] = name + if provider_type: + filter_kwargs["provider_type"] = provider_type with self.session_maker() as session: providers = ProviderModel.list( db_session=session, after=after, limit=limit, actor=actor, + **filter_kwargs, ) return [provider.to_pydantic() for provider in providers] @enforce_types - def get_anthropic_override_provider_id(self) -> Optional[str]: - """Helper function to fetch custom anthropic provider id for v0 BYOK feature""" - anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] - if len(anthropic_provider) != 0: - return anthropic_provider[0].id - return None + def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]: + providers = self.list_providers(name=provider_name) + return providers[0].id if providers else None @enforce_types - def get_anthropic_override_key(self) -> Optional[str]: - """Helper function to fetch custom anthropic key for v0 BYOK feature""" - anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"] - if len(anthropic_provider) != 0: - return anthropic_provider[0].api_key - return None + def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]: + providers = self.list_providers(name=provider_name) + return providers[0].api_key if providers else None diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index f9025c43f..51bf33203 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -105,7 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) llm_client = LLMClient.create( - provider=agent_state.llm_config.model_endpoint_type, + provider_name=agent_state.llm_config.provider_name, + provider_type=agent_state.llm_config.model_endpoint_type, ) if llm_client: response = llm_client.send_llm_request( diff --git a/tests/test_providers.py b/tests/test_providers.py index 0394dec01..2ab6606d7 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -19,97 +19,166 @@ from letta.settings import model_settings def test_openai(): api_key = os.getenv("OPENAI_API_KEY") assert api_key is not None - provider = OpenAIProvider(api_key=api_key, 
base_url=model_settings.openai_api_base) + provider = OpenAIProvider( + name="openai", + api_key=api_key, + base_url=model_settings.openai_api_base, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_deepseek(): api_key = os.getenv("DEEPSEEK_API_KEY") assert api_key is not None - provider = DeepSeekProvider(api_key=api_key) + provider = DeepSeekProvider( + name="deepseek", + api_key=api_key, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_anthropic(): api_key = os.getenv("ANTHROPIC_API_KEY") assert api_key is not None - provider = AnthropicProvider(api_key=api_key) + provider = AnthropicProvider( + name="anthropic", + api_key=api_key, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_groq(): - provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY")) + provider = GroqProvider( + name="groq", + api_key=os.getenv("GROQ_API_KEY"), + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_azure(): - provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL")) + provider = AzureProvider( + name="azure", + api_key=os.getenv("AZURE_API_KEY"), + base_url=os.getenv("AZURE_BASE_URL"), + ) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" - embed_models = provider.list_embedding_models() - print([m.embedding_model for m in embed_models]) + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_ollama(): base_url = os.getenv("OLLAMA_BASE_URL") assert base_url is not None - provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None) + provider = OllamaProvider( + name="ollama", + base_url=base_url, + default_prompt_formatter=model_settings.default_prompt_formatter, + api_key=None, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print(embedding_models) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_googleai(): api_key = os.getenv("GEMINI_API_KEY") assert api_key is not None - provider = GoogleAIProvider(api_key=api_key) + provider = GoogleAIProvider( + name="google_ai", + api_key=api_key, + ) models = provider.list_llm_models() - print(models) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" - provider.list_embedding_models() + embedding_models = provider.list_embedding_models() + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_google_vertex(): - provider = 
GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION")) + provider = GoogleVertexProvider( + name="google_vertex", + google_cloud_project=os.getenv("GCP_PROJECT_ID"), + google_cloud_location=os.getenv("GCP_REGION"), + ) models = provider.list_llm_models() - print(models) - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print([m.embedding_model for m in embedding_models]) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_mistral(): - provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY")) + provider = MistralProvider( + name="mistral", + api_key=os.getenv("MISTRAL_API_KEY"), + ) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" def test_together(): - provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml") + provider = TogetherProvider( + name="together", + api_key=os.getenv("TOGETHER_API_KEY"), + default_prompt_formatter="chatml", + ) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print([m.embedding_model for m in embedding_models]) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" def test_anthropic_bedrock(): from letta.settings import model_settings - provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region) + provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region) models = provider.list_llm_models() - print([m.model for m in models]) + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" embedding_models = provider.list_embedding_models() - print([m.embedding_model for m in embedding_models]) + assert len(embedding_models) > 0 + assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}" + + +def test_custom_anthropic(): + api_key = os.getenv("ANTHROPIC_API_KEY") + assert api_key is not None + provider = AnthropicProvider( + name="custom_anthropic", + api_key=api_key, + ) + models = provider.list_llm_models() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" # def test_vllm(): diff --git a/tests/test_server.py b/tests/test_server.py index 023897cdd..7d6d73e67 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,7 +13,7 @@ import letta.utils as utils from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR from letta.orm import Provider, Step from letta.schemas.block import CreateBlock -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, ProviderType from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.providers import Provider as PydanticProvider @@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): actor = server.user_manager.get_user_or_default(user_id) provider 
= server.provider_manager.create_provider( provider=PydanticProvider( - name="anthropic", + name="caren-anthropic", + provider_type=ProviderType.anthropic, api_key=os.getenv("ANTHROPIC_API_KEY"), ), actor=actor, @@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): agent = server.create_agent( request=CreateAgent( memory_blocks=[], - model="anthropic/claude-3-opus-20240229", - context_window_limit=200000, + model="caren-anthropic/claude-3-opus-20240229", + context_window_limit=100000, embedding="openai/text-embedding-ada-002", ), actor=actor,
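
Reviewer note (not part of the patch): a minimal sketch of how the BYOK pieces above are expected to compose end to end, using only APIs introduced or modified in this diff. The `server` and `actor` objects are assumed test fixtures, and the provider name "acme-anthropic" and the key value are placeholders, not values from the patch.

from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider
from letta.services.provider_manager import ProviderManager

# 1. Register a BYOK provider. Its name must differ from the built-in provider type
#    ("anthropic"), per the check added in ProviderManager.create_provider, and must be
#    unique per organization under the new ("name", "organization_id") constraint.
byok = Provider(
    name="acme-anthropic",             # hypothetical name
    provider_type=ProviderType.anthropic,
    api_key="sk-ant-...",              # placeholder key
)
server.provider_manager.create_provider(provider=byok, actor=actor)

# 2. LLMConfigs listed by the provider subtype now carry provider_name, which is how the
#    agent loop later resolves the per-provider key (note: list_llm_models() calls the
#    provider's models endpoint).
llm_config = byok.cast_to_subtype().list_llm_models()[0]

# 3. LLMClient.create now takes the endpoint type and the provider name separately;
#    AnthropicClient uses the name to fetch the stored override key instead of the
#    environment key.
llm_client = LLMClient.create(
    provider_type=llm_config.model_endpoint_type,
    provider_name=llm_config.provider_name,
    put_inner_thoughts_first=True,
)

# 4. The lookup the clients perform internally when a non-default provider_name is set:
api_key = ProviderManager().get_override_key(llm_config.provider_name)

The split between provider_type (which client/endpoint implementation to use) and provider_name (which stored key and base_url to use) is what lets several BYOK providers of the same type coexist, and `GET /models?byok_only=true` surfaces only the database-backed providers rather than the ones configured from environment settings.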