feat: byok 2.0 (#1963)

cthomas committed on 2025-04-30 21:26:50 -07:00 (committed by GitHub)
parent e3819cf066
commit 835792d5e0
23 changed files with 352 additions and 111 deletions

View File

@@ -0,0 +1,35 @@
+"""add byok fields and unique constraint
+
+Revision ID: 373dabcba6cf
+Revises: c56081a05371
+Create Date: 2025-04-30 19:38:25.010856
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "373dabcba6cf"
+down_revision: Union[str, None] = "c56081a05371"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
+    op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
+    op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
+    op.drop_column("providers", "base_url")
+    op.drop_column("providers", "provider_type")
+    # ### end Alembic commands ###

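A minimal sketch of applying this revision with Alembic's programmatic entry point; it assumes an alembic.ini configured for the Letta database, which is not part of this diff.

from alembic.config import main as alembic_main

# Apply the BYOK columns and the (name, organization_id) unique constraint.
alembic_main(argv=["upgrade", "373dabcba6cf"])

# Roll back to the previous revision if needed.
alembic_main(argv=["downgrade", "c56081a05371"])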
View File

@@ -331,7 +331,8 @@ class Agent(BaseAgent):
         log_telemetry(self.logger, "_get_ai_reply create start")
         # New LLM client flow
         llm_client = LLMClient.create(
-            provider=self.agent_state.llm_config.model_endpoint_type,
+            provider_name=self.agent_state.llm_config.provider_name,
+            provider_type=self.agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=put_inner_thoughts_first,
         )
@@ -941,12 +942,7 @@ class Agent(BaseAgent):
             model_endpoint=self.agent_state.llm_config.model_endpoint,
             context_window_limit=self.agent_state.llm_config.context_window,
             usage=response.usage,
-            # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
-            provider_id=(
-                self.provider_manager.get_anthropic_override_provider_id()
-                if self.agent_state.llm_config.model_endpoint_type == "anthropic"
-                else None
-            ),
+            provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
             job_id=job_id,
         )
         for message in all_new_messages:

View File

@@ -67,7 +67,8 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            provider=agent_state.llm_config.model_endpoint_type,
+            provider_name=agent_state.llm_config.provider_name,
+            provider_type=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         for step in range(max_steps):

View File

@@ -156,7 +156,8 @@ class LettaAgentBatch:
         log_event(name="init_llm_client")
         llm_client = LLMClient.create(
-            provider=agent_states[0].llm_config.model_endpoint_type,
+            provider_name=agent_states[0].llm_config.provider_name,
+            provider_type=agent_states[0].llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@@ -273,7 +274,8 @@ class LettaAgentBatch:
         # translate provider-specific response → OpenAI-style tool call (unchanged)
         llm_client = LLMClient.create(
-            provider=item.llm_config.model_endpoint_type,
+            provider_name=item.llm_config.provider_name,
+            provider_type=item.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         tool_call = (

View File

@@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.message import Message as _Message
 from letta.schemas.message import MessageRole as _MessageRole
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
     # NOTE: currently there is no GET /models, so we need to hardcode
     # return MODEL_LIST
-    anthropic_override_key = ProviderManager().get_anthropic_override_key()
-    if anthropic_override_key:
-        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    if api_key:
+        anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
+    else:
+        raise ValueError("No API key provided")

     models = anthropic_client.models.list()
     models_json = models.model_dump()
@@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
     put_inner_thoughts_in_kwargs: bool = False,
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
+    provider_name: Optional[str] = None,
     betas: List[str] = ["tools-2024-04-04"],
 ) -> ChatCompletionResponse:
     """https://docs.anthropic.com/claude/docs/tool-use"""
     anthropic_client = None
-    anthropic_override_key = ProviderManager().get_anthropic_override_key()
-    if anthropic_override_key:
-        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    if provider_name and provider_name != ProviderType.anthropic.value:
+        api_key = ProviderManager().get_override_key(provider_name)
+        anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
     else:
@@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
     put_inner_thoughts_in_kwargs: bool = False,
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
+    provider_name: Optional[str] = None,
     betas: List[str] = ["tools-2024-04-04"],
 ) -> Generator[ChatCompletionChunkResponse, None, None]:
     """Stream chat completions from Anthropic API.
@@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
         extended_thinking=extended_thinking,
         max_reasoning_tokens=max_reasoning_tokens,
     )
-
-    anthropic_override_key = ProviderManager().get_anthropic_override_key()
-    if anthropic_override_key:
-        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    if provider_name and provider_name != ProviderType.anthropic.value:
+        api_key = ProviderManager().get_override_key(provider_name)
+        anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
@@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
     put_inner_thoughts_in_kwargs: bool = False,
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
+    provider_name: Optional[str] = None,
     create_message_id: bool = True,
     create_message_datetime: bool = True,
     betas: List[str] = ["tools-2024-04-04"],
@@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
             put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
             extended_thinking=extended_thinking,
             max_reasoning_tokens=max_reasoning_tokens,
+            provider_name=provider_name,
             betas=betas,
         )
     ):

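The key-resolution rule repeated across these call sites condenses to the illustrative sketch below; names are as in the diff, and the comparison against ProviderType.anthropic.value relies on BYOK provider names never matching a provider type, which create_provider enforces.

# Minimal sketch of the shared BYOK key-resolution rule (illustrative, not a new API):
if provider_name and provider_name != ProviderType.anthropic.value:
    # A custom-named provider means "use the key the user stored for it".
    api_key = ProviderManager().get_override_key(provider_name)
    anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
    # Otherwise fall back to the server-configured Anthropic key.
    anthropic_client = anthropic.Anthropic()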
View File

@@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
@@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):

     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
-        override_key = ProviderManager().get_anthropic_override_key()
+        override_key = None
+        if self.provider_name and self.provider_name != ProviderType.anthropic.value:
+            override_key = ProviderManager().get_override_key(self.provider_name)
+
         if async_client:
             return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
         return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()

View File

@@ -24,6 +24,7 @@ from letta.llm_api.openai import (
 from letta.local_llm.chat_completion_proxy import get_chat_completion
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@@ -171,6 +172,10 @@ def create(
         if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # only is a problem if we are *not* using an openai proxy
             raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
+        elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+            from letta.services.provider_manager import ProviderManager
+
+            api_key = ProviderManager().get_override_key(llm_config.provider_name)
         elif model_settings.openai_api_key is None:
             # the openai python client requires a dummy API key
             api_key = "DUMMY_API_KEY"
@@ -373,6 +378,7 @@ def create(
                 stream_interface=stream_interface,
                 extended_thinking=llm_config.enable_reasoner,
                 max_reasoning_tokens=llm_config.max_reasoning_tokens,
+                provider_name=llm_config.provider_name,
                 name=name,
             )
@@ -383,6 +389,7 @@ def create(
                 put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
                 extended_thinking=llm_config.enable_reasoner,
                 max_reasoning_tokens=llm_config.max_reasoning_tokens,
+                provider_name=llm_config.provider_name,
             )

         if llm_config.put_inner_thoughts_in_kwargs:

View File

@@ -9,7 +9,8 @@ class LLMClient:

     @staticmethod
     def create(
-        provider: ProviderType,
+        provider_type: ProviderType,
+        provider_name: Optional[str] = None,
         put_inner_thoughts_first: bool = True,
     ) -> Optional[LLMClientBase]:
         """
@@ -25,29 +26,33 @@ class LLMClient:
         Raises:
             ValueError: If the model endpoint type is not supported
         """
-        match provider:
+        match provider_type:
             case ProviderType.google_ai:
                 from letta.llm_api.google_ai_client import GoogleAIClient

                 return GoogleAIClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case ProviderType.google_vertex:
                 from letta.llm_api.google_vertex_client import GoogleVertexClient

                 return GoogleVertexClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case ProviderType.anthropic:
                 from letta.llm_api.anthropic_client import AnthropicClient

                 return AnthropicClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case ProviderType.openai:
                 from letta.llm_api.openai_client import OpenAIClient

                 return OpenAIClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case _:

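A hedged example of the new factory signature: provider_type still selects the client implementation (it remains llm_config.model_endpoint_type), while provider_name identifies which stored key the request should use. The provider name below is hypothetical.

llm_client = LLMClient.create(
    provider_name="my-anthropic",          # hypothetical BYOK provider name
    provider_type=ProviderType.anthropic,  # still the model_endpoint_type
    put_inner_thoughts_first=True,
)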
View File

@@ -20,9 +20,11 @@ class LLMClientBase:

     def __init__(
         self,
+        provider_name: Optional[str] = None,
         put_inner_thoughts_first: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
+        self.provider_name = provider_name
         self.put_inner_thoughts_first = put_inner_thoughts_first
         self.use_tool_naming = use_tool_naming

View File

@@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -64,7 +65,14 @@ def supports_parallel_tool_calling(model: str) -> bool:

 class OpenAIClient(LLMClientBase):
     def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
-        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
+        api_key = None
+        if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+            from letta.services.provider_manager import ProviderManager
+
+            api_key = ProviderManager().get_override_key(llm_config.provider_name)
+
+        if not api_key:
+            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
         # supposedly the openai python client requires a dummy API key
         api_key = api_key or "DUMMY_API_KEY"
         kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}

View File

@@ -79,7 +79,8 @@ def summarize_messages(
     llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
     llm_client = LLMClient.create(
-        provider=llm_config_no_inner_thoughts.model_endpoint_type,
+        provider_name=llm_config_no_inner_thoughts.provider_name,
+        provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
         put_inner_thoughts_first=False,
     )
     # try to use new client, otherwise fallback to old flow

View File

@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING

+from sqlalchemy import UniqueConstraint
 from sqlalchemy.orm import Mapped, mapped_column, relationship

 from letta.orm.mixins import OrganizationMixin
@@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):

     __tablename__ = "providers"
     __pydantic_model__ = PydanticProvider
+    __table_args__ = (
+        UniqueConstraint(
+            "name",
+            "organization_id",
+            name="unique_name_organization_id",
+        ),
+    )

     name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
+    provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
     api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
+    base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")

     # relationships
     organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")

View File

@@ -6,6 +6,17 @@ class ProviderType(str, Enum):
     google_ai = "google_ai"
     google_vertex = "google_vertex"
     openai = "openai"
+    letta = "letta"
+    deepseek = "deepseek"
+    lmstudio_openai = "lmstudio_openai"
+    xai = "xai"
+    mistral = "mistral"
+    ollama = "ollama"
+    groq = "groq"
+    together = "together"
+    azure = "azure"
+    vllm = "vllm"
+    bedrock = "bedrock"


 class MessageRole(str, Enum):

View File

@@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
         "xai",
     ] = Field(..., description="The endpoint type for the model.")
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
+    provider_name: Optional[str] = Field(None, description="The provider name for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
     put_inner_thoughts_in_kwargs: Optional[bool] = Field(

View File

@@ -2,8 +2,8 @@ from typing import Dict

 LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
     "anthropic": {
-        "claude-3-5-haiku-20241022": "claude-3.5-haiku",
-        "claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
+        "claude-3-5-haiku-20241022": "claude-3-5-haiku",
+        "claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
         "claude-3-opus-20240229": "claude-3-opus",
     },
     "openai": {
View File

@@ -1,6 +1,6 @@
 import warnings
 from datetime import datetime
-from typing import List, Optional
+from typing import List, Literal, Optional

 from pydantic import Field, model_validator
@@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
+from letta.schemas.enums import ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
+from letta.settings import model_settings


 class ProviderBase(LettaBase):
@@ -21,10 +23,18 @@ class ProviderBase(LettaBase):

 class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
+    provider_type: ProviderType = Field(..., description="The type of the provider")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
+    base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
     updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")

+    @model_validator(mode="after")
+    def default_base_url(self):
+        if self.provider_type == ProviderType.openai and self.base_url is None:
+            self.base_url = model_settings.openai_api_base
+        return self
+
     def resolve_identifier(self):
         if not self.id:
             self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
@@ -59,9 +69,41 @@ class Provider(ProviderBase):
         return f"{self.name}/{model_name}"

+    def cast_to_subtype(self):
+        match (self.provider_type):
+            case ProviderType.letta:
+                return LettaProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.openai:
+                return OpenAIProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.anthropic:
+                return AnthropicProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.anthropic_bedrock:
+                return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.ollama:
+                return OllamaProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.google_ai:
+                return GoogleAIProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.google_vertex:
+                return GoogleVertexProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.azure:
+                return AzureProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.groq:
+                return GroqProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.together:
+                return TogetherProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.vllm_chat_completions:
+                return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.vllm_completions:
+                return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.xai:
+                return XAIProvider(**self.model_dump(exclude_none=True))
+            case _:
+                raise ValueError(f"Unknown provider type: {self.provider_type}")
+

 class ProviderCreate(ProviderBase):
     name: str = Field(..., description="The name of the provider.")
+    provider_type: ProviderType = Field(..., description="The type of the provider.")
     api_key: str = Field(..., description="API key used for requests to the provider.")
@@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):

 class LettaProvider(Provider):
-
-    name: str = "letta"
+    provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")

     def list_llm_models(self) -> List[LLMConfig]:
         return [
@@ -81,6 +122,7 @@ class LettaProvider(Provider):
                 model_endpoint=LETTA_MODEL_ENDPOINT,
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
+                provider_name=self.name,
             )
         ]
@@ -98,7 +140,7 @@ class LettaProvider(Provider):

 class OpenAIProvider(Provider):
-    name: str = "openai"
+    provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")
@@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )
@@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
     * It also does not support native function calling
     """

-    name: str = "deepseek"
+    provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")
@@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
+                    provider_name=self.name,
                 )
             )
@@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):

 class LMStudioOpenAIProvider(OpenAIProvider):
-    name: str = "lmstudio-openai"
+    provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
@@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):

 class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""

-    name: str = "xai"
+    provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
@@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )
@@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):

 class AnthropicProvider(Provider):
-    name: str = "anthropic"
+    provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"
@@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
                         handle=self.get_handle(model["id"]),
                         put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                         max_tokens=max_tokens,
+                        provider_name=self.name,
                     )
                 )
         return configs
@@ -572,7 +618,7 @@ class AnthropicProvider(Provider):

 class MistralProvider(Provider):
-    name: str = "mistral"
+    provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"
@@ -596,6 +642,7 @@ class MistralProvider(Provider):
                         model_endpoint=self.base_url,
                         context_window=model["max_context_length"],
                         handle=self.get_handle(model["id"]),
+                        provider_name=self.name,
                     )
                 )
@@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
     See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
     """

-    name: str = "ollama"
+    provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):

 class GroqProvider(OpenAIProvider):
-    name: str = "groq"
+    provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")
@@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
     function calling support is limited.
     """

-    name: str = "together"
+    provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )
@@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):

 class GoogleAIProvider(Provider):
     # gemini
-    name: str = "google_ai"
+    provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"
@@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
         # filter by model names
         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]

-        # TODO remove manual filtering for gemini-pro
         # Add support for all gemini models
         model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
@@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
                     context_window=self.get_model_context_window(model),
                     handle=self.get_handle(model),
                     max_tokens=8192,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):

 class GoogleVertexProvider(Provider):
-    name: str = "google_vertex"
+    provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
@@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
                     context_window=context_length,
                     handle=self.get_handle(model),
                     max_tokens=8192,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):

 class AzureProvider(Provider):
-    name: str = "azure"
+    provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
                     model_endpoint=model_endpoint,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 ),
             )
         return configs
@@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
     """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""

     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    name: str = "vllm"
+    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the vLLM API.")

     def list_llm_models(self) -> List[LLMConfig]:
@@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
     """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""

     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    name: str = "vllm"
+    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):

 class AnthropicBedrockProvider(Provider):
-    name: str = "bedrock"
+    provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
     aws_region: str = Field(..., description="AWS region for Bedrock")

     def list_llm_models(self):
@@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
             configs.append(
                 LLMConfig(
                     model=model_arn,
-                    model_endpoint_type=self.name,
+                    model_endpoint_type=self.provider_type.value,
                     model_endpoint=None,
                     context_window=self.get_model_context_window(model_arn),
                     handle=self.get_handle(model_arn),
+                    provider_name=self.name,
                 )
             )
         return configs

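A sketch of how these schema pieces compose, with illustrative values: a generic provider row is cast to its concrete subtype before model listing, and every LLMConfig it emits carries provider_name so downstream clients can resolve the right key.

from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider

row = Provider(name="my-anthropic", provider_type=ProviderType.anthropic, api_key="sk-ant-...")  # hypothetical
provider = row.cast_to_subtype()      # -> AnthropicProvider
configs = provider.list_llm_models()  # network call; each config has provider_name="my-anthropic"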
View File

@@ -1,6 +1,6 @@
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Query

 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
@@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])

 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
 def list_llm_models(
+    byok_only: Optional[bool] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
 ):
-    models = server.list_llm_models()
+    models = server.list_llm_models(byok_only=byok_only)
     # print(models)
     return models

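A client-side sketch of the new flag; the base URL and port assume a default local Letta server and are not taken from this diff.

import requests

# List only models served by BYOK providers (omitting byok_only keeps the old behavior).
resp = requests.get("http://localhost:8283/v1/models/", params={"byok_only": "true"})
byok_models = resp.json()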
View File

@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional

 from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query

+from letta.schemas.enums import ProviderType
 from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.utils import get_letta_server
@@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])

 @router.get("/", response_model=List[Provider], operation_id="list_providers")
 def list_providers(
+    name: Optional[str] = Query(None),
+    provider_type: Optional[ProviderType] = Query(None),
     after: Optional[str] = Query(None),
     limit: Optional[int] = Query(50),
     actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -23,7 +26,7 @@ def list_providers(
     """
     try:
         actor = server.user_manager.get_user_or_default(user_id=actor_id)
-        providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
+        providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
     except HTTPException:
         raise
     except Exception as e:

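A similar sketch for the new provider filters; the URL, port, and header value are assumptions, with the header name taken from the route above.

import requests

resp = requests.get(
    "http://localhost:8283/v1/providers/",
    params={"name": "my-anthropic", "provider_type": "anthropic"},  # both filters optional
    headers={"user_id": "user-123"},  # actor header read by this route; value illustrative
)
providers = resp.json()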
View File

@@ -268,10 +268,11 @@ class SyncServer(Server):
         )

         # collect providers (always has Letta as a default)
-        self._enabled_providers: List[Provider] = [LettaProvider()]
+        self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
         if model_settings.openai_api_key:
             self._enabled_providers.append(
                 OpenAIProvider(
+                    name="openai",
                     api_key=model_settings.openai_api_key,
                     base_url=model_settings.openai_api_base,
                 )
@@ -279,12 +280,14 @@ class SyncServer(Server):
         if model_settings.anthropic_api_key:
             self._enabled_providers.append(
                 AnthropicProvider(
+                    name="anthropic",
                     api_key=model_settings.anthropic_api_key,
                 )
             )
         if model_settings.ollama_base_url:
             self._enabled_providers.append(
                 OllamaProvider(
+                    name="ollama",
                     base_url=model_settings.ollama_base_url,
                     api_key=None,
                     default_prompt_formatter=model_settings.default_prompt_formatter,
@@ -293,12 +296,14 @@ class SyncServer(Server):
         if model_settings.gemini_api_key:
             self._enabled_providers.append(
                 GoogleAIProvider(
+                    name="google_ai",
                     api_key=model_settings.gemini_api_key,
                 )
             )
         if model_settings.google_cloud_location and model_settings.google_cloud_project:
             self._enabled_providers.append(
                 GoogleVertexProvider(
+                    name="google_vertex",
                     google_cloud_project=model_settings.google_cloud_project,
                     google_cloud_location=model_settings.google_cloud_location,
                 )
@@ -307,6 +312,7 @@ class SyncServer(Server):
             assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
             self._enabled_providers.append(
                 AzureProvider(
+                    name="azure",
                     api_key=model_settings.azure_api_key,
                     base_url=model_settings.azure_base_url,
                     api_version=model_settings.azure_api_version,
@@ -315,12 +321,14 @@ class SyncServer(Server):
         if model_settings.groq_api_key:
             self._enabled_providers.append(
                 GroqProvider(
+                    name="groq",
                     api_key=model_settings.groq_api_key,
                 )
             )
         if model_settings.together_api_key:
             self._enabled_providers.append(
                 TogetherProvider(
+                    name="together",
                     api_key=model_settings.together_api_key,
                     default_prompt_formatter=model_settings.default_prompt_formatter,
                 )
@@ -329,6 +337,7 @@ class SyncServer(Server):
             # vLLM exposes both a /chat/completions and a /completions endpoint
             self._enabled_providers.append(
                 VLLMCompletionsProvider(
+                    name="vllm",
                     base_url=model_settings.vllm_api_base,
                     default_prompt_formatter=model_settings.default_prompt_formatter,
                 )
@@ -338,12 +347,14 @@ class SyncServer(Server):
             # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
             self._enabled_providers.append(
                 VLLMChatCompletionsProvider(
+                    name="vllm",
                     base_url=model_settings.vllm_api_base,
                 )
             )
         if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
             self._enabled_providers.append(
                 AnthropicBedrockProvider(
+                    name="bedrock",
                     aws_region=model_settings.aws_region,
                 )
             )
@@ -355,11 +366,11 @@ class SyncServer(Server):
                 if model_settings.lmstudio_base_url.endswith("/v1")
                 else model_settings.lmstudio_base_url + "/v1"
             )
-            self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url))
+            self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
         if model_settings.deepseek_api_key:
-            self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
+            self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
         if model_settings.xai_api_key:
-            self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))
+            self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))

         # For MCP
         """Initialize the MCP clients (there may be multiple)"""
@@ -1184,10 +1195,10 @@ class SyncServer(Server):
         except NoResultFound:
             raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")

-    def list_llm_models(self) -> List[LLMConfig]:
+    def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
         """List available models"""
         llm_models = []
-        for provider in self.get_enabled_providers():
+        for provider in self.get_enabled_providers(byok_only=byok_only):
             try:
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
@@ -1207,11 +1218,12 @@ class SyncServer(Server):
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models

-    def get_enabled_providers(self):
+    def get_enabled_providers(self, byok_only: bool = False):
+        providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
+        if byok_only:
+            return list(providers_from_db.values())
         providers_from_env = {p.name: p for p in self._enabled_providers}
-        providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
-        # Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
-        return {**providers_from_env, **providers_from_db}.values()
+        return list(providers_from_env.values()) + list(providers_from_db.values())

     @trace_method
     def get_llm_config_from_handle(
@@ -1296,7 +1308,7 @@ class SyncServer(Server):
         return embedding_config

     def get_provider_from_name(self, provider_name: str) -> Provider:
-        providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
+        providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
         elif len(providers) > 1:

View File

@ -1,6 +1,7 @@
from typing import List, Optional from typing import List, Optional, Union
from letta.orm.provider import Provider as ProviderModel from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderUpdate from letta.schemas.providers import ProviderUpdate
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
@ -18,6 +19,9 @@ class ProviderManager:
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider: def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist.""" """Create a new provider if it doesn't already exist."""
with self.session_maker() as session: with self.session_maker() as session:
if provider.name == provider.provider_type.value:
raise ValueError("Provider name must be unique and different from provider type")
# Assign the organization id based on the actor # Assign the organization id based on the actor
provider.organization_id = actor.organization_id provider.organization_id = actor.organization_id
@ -59,29 +63,36 @@ class ProviderManager:
session.commit() session.commit()
@enforce_types @enforce_types
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]: def list_providers(
self,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> List[PydanticProvider]:
"""List all providers with optional pagination.""" """List all providers with optional pagination."""
filter_kwargs = {}
if name:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
with self.session_maker() as session: with self.session_maker() as session:
providers = ProviderModel.list( providers = ProviderModel.list(
db_session=session, db_session=session,
after=after, after=after,
limit=limit, limit=limit,
actor=actor, actor=actor,
**filter_kwargs,
) )
return [provider.to_pydantic() for provider in providers] return [provider.to_pydantic() for provider in providers]
     @enforce_types
-    def get_anthropic_override_provider_id(self) -> Optional[str]:
-        """Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
-        anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
-        if len(anthropic_provider) != 0:
-            return anthropic_provider[0].id
-        return None
+    def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
+        providers = self.list_providers(name=provider_name)
+        return providers[0].id if providers else None

     @enforce_types
-    def get_anthropic_override_key(self) -> Optional[str]:
-        """Helper function to fetch custom anthropic key for v0 BYOK feature"""
-        anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
-        if len(anthropic_provider) != 0:
-            return anthropic_provider[0].api_key
-        return None
+    def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
+        providers = self.list_providers(name=provider_name)
+        return providers[0].api_key if providers else None
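The two v0 helpers hard-coded "anthropic"; their replacements reduce any provider name to a filtered listing plus a first-match extraction. Note that a None name falls through the `if name:` filter and would match whichever provider is listed first, so callers are expected to pass a concrete name. A behavioral sketch with plain data in place of ORM rows:

```python
from typing import Optional

PROVIDERS = [{"id": "provider-123", "name": "caren-anthropic", "api_key": "sk-ant-example"}]


def list_providers(name: Optional[str] = None):
    # No name -> no filtering, matching the filter_kwargs semantics above.
    return [p for p in PROVIDERS if name is None or p["name"] == name]


def get_provider_id_from_name(provider_name: Optional[str]) -> Optional[str]:
    providers = list_providers(name=provider_name)
    return providers[0]["id"] if providers else None


def get_override_key(provider_name: Optional[str]) -> Optional[str]:
    providers = list_providers(name=provider_name)
    return providers[0]["api_key"] if providers else None


assert get_provider_id_from_name("caren-anthropic") == "provider-123"
assert get_override_key("no-such-provider") is None
```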

View File

@@ -105,7 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
     agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
     llm_client = LLMClient.create(
-        provider=agent_state.llm_config.model_endpoint_type,
+        provider_name=agent_state.llm_config.provider_name,
+        provider_type=agent_state.llm_config.model_endpoint_type,
     )
     if llm_client:
         response = llm_client.send_llm_request(
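As in the agent code paths, the test helper now passes LLMClient.create both halves of a provider's identity: provider_type picks the client implementation, while provider_name lets BYOK lookups (provider id, API key) resolve against a custom-named provider. A toy sketch of that split; the class names in the table are stand-ins, not letta's actual client classes:

```python
# Toy dispatch table: type -> client implementation (names are illustrative).
CLIENTS = {"anthropic": "AnthropicClient", "openai": "OpenAIClient"}


def create_client(provider_name, provider_type):
    client_cls = CLIENTS.get(provider_type)
    if client_cls is None:
        return None  # mirrors the `if llm_client:` guard above
    # provider_name rides along so key/id lookups can target the BYOK provider.
    return (client_cls, provider_name)


assert create_client("caren-anthropic", "anthropic") == ("AnthropicClient", "caren-anthropic")
assert create_client(None, "not-a-real-type") is None
```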

View File

@@ -19,97 +19,166 @@ from letta.settings import model_settings


 def test_openai():
     api_key = os.getenv("OPENAI_API_KEY")
     assert api_key is not None
-    provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
+    provider = OpenAIProvider(
+        name="openai",
+        api_key=api_key,
+        base_url=model_settings.openai_api_base,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
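The rewritten tests all pin the same convention: a model handle is the provider's (possibly custom) name joined to the model identifier with a slash. A tiny sketch, with illustrative model ids:

```python
def make_handle(provider_name: str, model: str) -> str:
    # Handle convention asserted throughout these provider tests.
    return f"{provider_name}/{model}"


assert make_handle("openai", "gpt-4") == "openai/gpt-4"
assert make_handle("caren-anthropic", "claude-3-opus-20240229") == "caren-anthropic/claude-3-opus-20240229"
```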

 def test_deepseek():
     api_key = os.getenv("DEEPSEEK_API_KEY")
     assert api_key is not None
-    provider = DeepSeekProvider(api_key=api_key)
+    provider = DeepSeekProvider(
+        name="deepseek",
+        api_key=api_key,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

 def test_anthropic():
     api_key = os.getenv("ANTHROPIC_API_KEY")
     assert api_key is not None
-    provider = AnthropicProvider(api_key=api_key)
+    provider = AnthropicProvider(
+        name="anthropic",
+        api_key=api_key,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

 def test_groq():
-    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
+    provider = GroqProvider(
+        name="groq",
+        api_key=os.getenv("GROQ_API_KEY"),
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

 def test_azure():
-    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
+    provider = AzureProvider(
+        name="azure",
+        api_key=os.getenv("AZURE_API_KEY"),
+        base_url=os.getenv("AZURE_BASE_URL"),
+    )
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

-    embed_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embed_models])
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"

 def test_ollama():
     base_url = os.getenv("OLLAMA_BASE_URL")
     assert base_url is not None
-    provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
+    provider = OllamaProvider(
+        name="ollama",
+        base_url=base_url,
+        default_prompt_formatter=model_settings.default_prompt_formatter,
+        api_key=None,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print(embedding_models)
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"

 def test_googleai():
     api_key = os.getenv("GEMINI_API_KEY")
     assert api_key is not None
-    provider = GoogleAIProvider(api_key=api_key)
+    provider = GoogleAIProvider(
+        name="google_ai",
+        api_key=api_key,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

-    provider.list_embedding_models()
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"

 def test_google_vertex():
-    provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
+    provider = GoogleVertexProvider(
+        name="google_vertex",
+        google_cloud_project=os.getenv("GCP_PROJECT_ID"),
+        google_cloud_location=os.getenv("GCP_REGION"),
+    )
     models = provider.list_llm_models()
-    print(models)
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embedding_models])
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"

 def test_mistral():
-    provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
+    provider = MistralProvider(
+        name="mistral",
+        api_key=os.getenv("MISTRAL_API_KEY"),
+    )
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

 def test_together():
-    provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
+    provider = TogetherProvider(
+        name="together",
+        api_key=os.getenv("TOGETHER_API_KEY"),
+        default_prompt_formatter="chatml",
+    )
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embedding_models])
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"

 def test_anthropic_bedrock():
     from letta.settings import model_settings

-    provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
+    provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embedding_models])
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"

+def test_custom_anthropic():
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    assert api_key is not None
+    provider = AnthropicProvider(
+        name="custom_anthropic",
+        api_key=api_key,
+    )
+    models = provider.list_llm_models()
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
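test_custom_anthropic is the BYOK 2.0 counterpart of test_anthropic: same key and underlying type, but an arbitrary name, and the handle assertion confirms the models surface under that custom prefix. Under the convention sketched earlier, the same underlying model therefore appears twice, once per provider:

```python
model = "claude-3-opus-20240229"  # illustrative model id
handles = [f"{name}/{model}" for name in ("anthropic", "custom_anthropic")]
assert handles == [
    "anthropic/claude-3-opus-20240229",
    "custom_anthropic/claude-3-opus-20240229",
]
```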
 # def test_vllm():

View File

@@ -13,7 +13,7 @@ import letta.utils as utils
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
 from letta.orm import Provider, Step
 from letta.schemas.block import CreateBlock
-from letta.schemas.enums import MessageRole
+from letta.schemas.enums import MessageRole, ProviderType
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers import Provider as PydanticProvider

@@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     actor = server.user_manager.get_user_or_default(user_id)
     provider = server.provider_manager.create_provider(
         provider=PydanticProvider(
-            name="anthropic",
+            name="caren-anthropic",
+            provider_type=ProviderType.anthropic,
             api_key=os.getenv("ANTHROPIC_API_KEY"),
         ),
         actor=actor,

@@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     agent = server.create_agent(
         request=CreateAgent(
             memory_blocks=[],
-            model="anthropic/claude-3-opus-20240229",
-            context_window_limit=200000,
+            model="caren-anthropic/claude-3-opus-20240229",
+            context_window_limit=100000,
             embedding="openai/text-embedding-ada-002",
         ),
         actor=actor,
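Taken together, the updated test exercises the full BYOK 2.0 path: register a uniquely named provider with an explicit provider_type, then address its models through the name-prefixed handle. A condensed, non-authoritative restatement of that flow (assumes a running SyncServer `server` and a resolved `actor`, exactly as in the test):

```python
import os


def register_byok_provider(server, actor):
    """Condensed sketch of the flow exercised in test_messages_with_provider_override."""
    from letta.schemas.enums import ProviderType
    from letta.schemas.providers import Provider as PydanticProvider

    provider = server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="caren-anthropic",                # must differ from the type name "anthropic"
            provider_type=ProviderType.anthropic,  # selects the Anthropic client under the hood
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
    )
    # Agents then reference models as "<provider name>/<model>", e.g.
    # model="caren-anthropic/claude-3-opus-20240229" in CreateAgent.
    return provider
```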