feat: byok 2.0 (#1963)

cthomas 2025-04-30 21:26:50 -07:00 committed by GitHub
parent e3819cf066
commit 835792d5e0
23 changed files with 352 additions and 111 deletions

View File

@ -0,0 +1,35 @@
"""add byok fields and unique constraint
Revision ID: 373dabcba6cf
Revises: c56081a05371
Create Date: 2025-04-30 19:38:25.010856
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "373dabcba6cf"
down_revision: Union[str, None] = "c56081a05371"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
op.drop_column("providers", "base_url")
op.drop_column("providers", "provider_type")
# ### end Alembic commands ###
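To exercise the new revision locally, the upgrade/downgrade pair can be driven programmatically; a minimal sketch, assuming an alembic.ini at the repository root and a reachable database:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed location of the project's Alembic config
command.upgrade(cfg, "373dabcba6cf")    # adds provider_type, base_url, and the unique constraint
command.downgrade(cfg, "c56081a05371")  # rolls back to the prior revision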

View File

@ -331,7 +331,8 @@ class Agent(BaseAgent):
log_telemetry(self.logger, "_get_ai_reply create start")
# New LLM client flow
llm_client = LLMClient.create(
provider=self.agent_state.llm_config.model_endpoint_type,
provider_name=self.agent_state.llm_config.provider_name,
provider_type=self.agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=put_inner_thoughts_first,
)
@ -941,12 +942,7 @@ class Agent(BaseAgent):
model_endpoint=self.agent_state.llm_config.model_endpoint,
context_window_limit=self.agent_state.llm_config.context_window,
usage=response.usage,
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
provider_id=(
self.provider_manager.get_anthropic_override_provider_id()
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
else None
),
provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
job_id=job_id,
)
for message in all_new_messages:

View File

@ -67,7 +67,8 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
provider=agent_state.llm_config.model_endpoint_type,
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
for step in range(max_steps):

View File

@ -156,7 +156,8 @@ class LettaAgentBatch:
log_event(name="init_llm_client")
llm_client = LLMClient.create(
provider=agent_states[0].llm_config.model_endpoint_type,
provider_name=agent_states[0].llm_config.provider_name,
provider_type=agent_states[0].llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@ -273,7 +274,8 @@ class LettaAgentBatch:
# translate provider-specific response → OpenAI-style tool call (unchanged)
llm_client = LLMClient.create(
provider=item.llm_config.model_endpoint_type,
provider_name=item.llm_config.provider_name,
provider_type=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
tool_call = (

View File

@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
# NOTE: currently there is no GET /models, so we need to hardcode
# return MODEL_LIST
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
if api_key:
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
else:
raise ValueError("No API key provided")
models = anthropic_client.models.list()
models_json = models.model_dump()
@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
betas: List[str] = ["tools-2024-04-04"],
) -> ChatCompletionResponse:
"""https://docs.anthropic.com/claude/docs/tool-use"""
anthropic_client = None
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
if provider_name and provider_name != ProviderType.anthropic.value:
api_key = ProviderManager().get_override_key(provider_name)
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
else:
@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
betas: List[str] = ["tools-2024-04-04"],
) -> Generator[ChatCompletionChunkResponse, None, None]:
"""Stream chat completions from Anthropic API.
@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens,
)
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
if provider_name and provider_name != ProviderType.anthropic.value:
api_key = ProviderManager().get_override_key(provider_name)
anthropic_client = anthropic.Anthropic(api_key=api_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
put_inner_thoughts_in_kwargs: bool = False,
extended_thinking: bool = False,
max_reasoning_tokens: Optional[int] = None,
provider_name: Optional[str] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
betas: List[str] = ["tools-2024-04-04"],
@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
extended_thinking=extended_thinking,
max_reasoning_tokens=max_reasoning_tokens,
provider_name=provider_name,
betas=betas,
)
):
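The resolution order these request helpers now follow is: a BYOK key looked up by provider_name (when the name is not the stock "anthropic" provider), then the server's anthropic_api_key, otherwise an error. A standalone sketch of that precedence; the helper name is illustrative, not part of this diff:

from typing import Optional

import anthropic

from letta.schemas.enums import ProviderType
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings


def _resolve_anthropic_client(provider_name: Optional[str]) -> anthropic.Anthropic:
    # illustrative helper only; the real logic lives inline in the request functions above
    if provider_name and provider_name != ProviderType.anthropic.value:
        api_key = ProviderManager().get_override_key(provider_name)
        return anthropic.Anthropic(api_key=api_key)
    if model_settings.anthropic_api_key:
        return anthropic.Anthropic()  # falls back to ANTHROPIC_API_KEY from settings/env
    raise ValueError("No API key provided")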

View File

@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):
@trace_method
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
override_key = ProviderManager().get_anthropic_override_key()
override_key = None
if self.provider_name and self.provider_name != ProviderType.anthropic.value:
override_key = ProviderManager().get_override_key(self.provider_name)
if async_client:
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()

View File

@ -24,6 +24,7 @@ from letta.llm_api.openai import (
from letta.local_llm.chat_completion_proxy import get_chat_completion
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@ -171,6 +172,10 @@ def create(
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
elif model_settings.openai_api_key is None:
# the openai python client requires a dummy API key
api_key = "DUMMY_API_KEY"
@ -373,6 +378,7 @@ def create(
stream_interface=stream_interface,
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
name=name,
)
@ -383,6 +389,7 @@ def create(
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
provider_name=llm_config.provider_name,
)
if llm_config.put_inner_thoughts_in_kwargs:

View File

@ -9,7 +9,8 @@ class LLMClient:
@staticmethod
def create(
provider: ProviderType,
provider_type: ProviderType,
provider_name: Optional[str] = None,
put_inner_thoughts_first: bool = True,
) -> Optional[LLMClientBase]:
"""
@ -25,29 +26,33 @@ class LLMClient:
Raises:
ValueError: If the model endpoint type is not supported
"""
match provider:
match provider_type:
case ProviderType.google_ai:
from letta.llm_api.google_ai_client import GoogleAIClient
return GoogleAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case ProviderType.google_vertex:
from letta.llm_api.google_vertex_client import GoogleVertexClient
return GoogleVertexClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case ProviderType.anthropic:
from letta.llm_api.anthropic_client import AnthropicClient
return AnthropicClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case ProviderType.openai:
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(
provider_name=provider_name,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case _:
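With the factory now taking provider_type plus an optional provider_name, call sites look like the sketch below (values are illustrative):

from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import ProviderType

# provider_type still selects the client implementation; provider_name identifies
# which stored BYOK credentials (if any) that client should use
llm_client = LLMClient.create(
    provider_name="my-anthropic",          # a BYOK provider name, or None for the built-in provider
    provider_type=ProviderType.anthropic,
    put_inner_thoughts_first=True,
)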

View File

@ -20,9 +20,11 @@ class LLMClientBase:
def __init__(
self,
provider_name: Optional[str] = None,
put_inner_thoughts_first: Optional[bool] = True,
use_tool_naming: bool = True,
):
self.provider_name = provider_name
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming

View File

@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.enums import ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@ -64,6 +65,13 @@ def supports_parallel_tool_calling(model: str) -> bool:
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = None
if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
from letta.services.provider_manager import ProviderManager
api_key = ProviderManager().get_override_key(llm_config.provider_name)
if not api_key:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
# supposedly the openai python client requires a dummy API key
api_key = api_key or "DUMMY_API_KEY"
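Net effect for the OpenAI paths (both here and in the legacy create() flow above): a BYOK key wins when provider_name names a custom provider, then settings/env, then the dummy placeholder the OpenAI client requires. A standalone sketch of that precedence; the function name is illustrative:

import os
from typing import Optional

from letta.schemas.enums import ProviderType
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings


def _resolve_openai_key(provider_name: Optional[str]) -> str:
    # illustrative only; mirrors _prepare_client_kwargs above
    api_key = None
    if provider_name and provider_name != ProviderType.openai.value:
        api_key = ProviderManager().get_override_key(provider_name)
    if not api_key:
        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
    return api_key or "DUMMY_API_KEY"  # the openai client refuses a missing key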

View File

@ -79,7 +79,8 @@ def summarize_messages(
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
llm_client = LLMClient.create(
provider=llm_config_no_inner_thoughts.model_endpoint_type,
provider_name=llm_config_no_inner_thoughts.provider_name,
provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
put_inner_thoughts_first=False,
)
# try to use new client, otherwise fallback to old flow

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING
from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):
__tablename__ = "providers"
__pydantic_model__ = PydanticProvider
__table_args__ = (
UniqueConstraint(
"name",
"organization_id",
name="unique_name_organization_id",
),
)
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")

View File

@ -6,6 +6,17 @@ class ProviderType(str, Enum):
google_ai = "google_ai"
google_vertex = "google_vertex"
openai = "openai"
letta = "letta"
deepseek = "deepseek"
lmstudio_openai = "lmstudio_openai"
xai = "xai"
mistral = "mistral"
ollama = "ollama"
groq = "groq"
together = "together"
azure = "azure"
vllm = "vllm"
bedrock = "bedrock"
class MessageRole(str, Enum):

View File

@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
"xai",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")
put_inner_thoughts_in_kwargs: Optional[bool] = Field(
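A config that routes through a BYOK provider now just carries the provider's name; a sketch with illustrative values:

from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="claude-3-5-sonnet-20241022",
    model_endpoint_type="anthropic",
    model_endpoint="https://api.anthropic.com/v1",
    provider_name="my-anthropic",   # matches the name given when the provider was created
    context_window=200000,
)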

View File

@ -2,8 +2,8 @@ from typing import Dict
LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
"anthropic": {
"claude-3-5-haiku-20241022": "claude-3.5-haiku",
"claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
"claude-3-5-haiku-20241022": "claude-3-5-haiku",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
"claude-3-opus-20240229": "claude-3-opus",
},
"openai": {

View File

@ -1,6 +1,6 @@
import warnings
from datetime import datetime
from typing import List, Optional
from typing import List, Literal, Optional
from pydantic import Field, model_validator
@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderType
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
from letta.settings import model_settings
class ProviderBase(LettaBase):
@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
class Provider(ProviderBase):
id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
name: str = Field(..., description="The name of the provider")
provider_type: ProviderType = Field(..., description="The type of the provider")
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
base_url: Optional[str] = Field(None, description="Base URL for the provider.")
organization_id: Optional[str] = Field(None, description="The organization id of the user")
updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
@model_validator(mode="after")
def default_base_url(self):
if self.provider_type == ProviderType.openai and self.base_url is None:
self.base_url = model_settings.openai_api_base
return self
def resolve_identifier(self):
if not self.id:
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
@ -59,9 +69,41 @@ class Provider(ProviderBase):
return f"{self.name}/{model_name}"
def cast_to_subtype(self):
match (self.provider_type):
case ProviderType.letta:
return LettaProvider(**self.model_dump(exclude_none=True))
case ProviderType.openai:
return OpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic:
return AnthropicProvider(**self.model_dump(exclude_none=True))
case ProviderType.anthropic_bedrock:
return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
case ProviderType.ollama:
return OllamaProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_ai:
return GoogleAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.google_vertex:
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
case ProviderType.azure:
return AzureProvider(**self.model_dump(exclude_none=True))
case ProviderType.groq:
return GroqProvider(**self.model_dump(exclude_none=True))
case ProviderType.together:
return TogetherProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm_chat_completions:
return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
case ProviderType.vllm_completions:
return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
case ProviderType.xai:
return XAIProvider(**self.model_dump(exclude_none=True))
case _:
raise ValueError(f"Unknown provider type: {self.provider_type}")
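Rows loaded from the providers table come back as the generic Provider; cast_to_subtype turns them into the concrete class so model listing works the same for env-configured and BYOK providers. A sketch (the API key is a placeholder, and list_llm_models makes a live API call):

from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider

generic = Provider(
    name="my-anthropic",
    provider_type=ProviderType.anthropic,
    api_key="sk-ant-...",  # placeholder
)
concrete = generic.cast_to_subtype()   # -> AnthropicProvider
configs = concrete.list_llm_models()   # each LLMConfig carries provider_name="my-anthropic"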
class ProviderCreate(ProviderBase):
name: str = Field(..., description="The name of the provider.")
provider_type: ProviderType = Field(..., description="The type of the provider.")
api_key: str = Field(..., description="API key used for requests to the provider.")
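Registering a BYOK provider now requires an explicit type, and create_provider (see provider_manager below) rejects names that collide with the type value itself. A sketch of the request body:

from letta.schemas.enums import ProviderType
from letta.schemas.providers import ProviderCreate

payload = ProviderCreate(
    name="my-anthropic",                   # must differ from "anthropic", the provider_type value
    provider_type=ProviderType.anthropic,
    api_key="sk-ant-...",                  # placeholder
)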
@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):
class LettaProvider(Provider):
name: str = "letta"
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
def list_llm_models(self) -> List[LLMConfig]:
return [
@ -81,6 +122,7 @@ class LettaProvider(Provider):
model_endpoint=LETTA_MODEL_ENDPOINT,
context_window=8192,
handle=self.get_handle("letta-free"),
provider_name=self.name,
)
]
@ -98,7 +140,7 @@ class LettaProvider(Provider):
class OpenAIProvider(Provider):
name: str = "openai"
provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the OpenAI API.")
base_url: str = Field(..., description="Base URL for the OpenAI API.")
@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
)
)
@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
* It also does not support native function calling
"""
name: str = "deepseek"
provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
api_key: str = Field(..., description="API key for the DeepSeek API.")
@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
context_window=context_window_size,
handle=self.get_handle(model_name),
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
provider_name=self.name,
)
)
@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):
class LMStudioOpenAIProvider(OpenAIProvider):
name: str = "lmstudio-openai"
provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
class XAIProvider(OpenAIProvider):
"""https://docs.x.ai/docs/api-reference"""
name: str = "xai"
provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the xAI/Grok API.")
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
model_endpoint=self.base_url,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
)
)
@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):
class AnthropicProvider(Provider):
name: str = "anthropic"
provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Anthropic API.")
base_url: str = "https://api.anthropic.com/v1"
@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
handle=self.get_handle(model["id"]),
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
max_tokens=max_tokens,
provider_name=self.name,
)
)
return configs
@ -572,7 +618,7 @@ class AnthropicProvider(Provider):
class MistralProvider(Provider):
name: str = "mistral"
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Mistral API.")
base_url: str = "https://api.mistral.ai/v1"
@ -596,6 +642,7 @@ class MistralProvider(Provider):
model_endpoint=self.base_url,
context_window=model["max_context_length"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
"""
name: str = "ollama"
provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the Ollama API.")
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
default_prompt_formatter: str = Field(
@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
model_wrapper=self.default_prompt_formatter,
context_window=context_window,
handle=self.get_handle(model["name"]),
provider_name=self.name,
)
)
return configs
@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):
class GroqProvider(OpenAIProvider):
name: str = "groq"
provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
base_url: str = "https://api.groq.com/openai/v1"
api_key: str = Field(..., description="API key for the Groq API.")
@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
return configs
@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
function calling support is limited.
"""
name: str = "together"
provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
base_url: str = "https://api.together.ai/v1"
api_key: str = Field(..., description="API key for the TogetherAI API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
)
)
@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):
class GoogleAIProvider(Provider):
# gemini
name: str = "google_ai"
provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
api_key: str = Field(..., description="API key for the Google AI API.")
base_url: str = "https://generativelanguage.googleapis.com"
@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
# filter by model names
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
# TODO remove manual filtering for gemini-pro
# Add support for all gemini models
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
context_window=self.get_model_context_window(model),
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
)
)
return configs
@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):
class GoogleVertexProvider(Provider):
name: str = "google_vertex"
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
context_window=context_length,
handle=self.get_handle(model),
max_tokens=8192,
provider_name=self.name,
)
)
return configs
@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):
class AzureProvider(Provider):
name: str = "azure"
provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
base_url: str = Field(
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
model_endpoint=model_endpoint,
context_window=context_window_size,
handle=self.get_handle(model_name),
provider_name=self.name,
),
)
return configs
@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the vLLM API.")
def list_llm_models(self) -> List[LLMConfig]:
@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
model_endpoint=self.base_url,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
return configs
@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
name: str = "vllm"
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
base_url: str = Field(..., description="Base URL for the vLLM API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
model_wrapper=self.default_prompt_formatter,
context_window=model["max_model_len"],
handle=self.get_handle(model["id"]),
provider_name=self.name,
)
)
return configs
@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):
class AnthropicBedrockProvider(Provider):
name: str = "bedrock"
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
aws_region: str = Field(..., description="AWS region for Bedrock")
def list_llm_models(self):
@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
configs.append(
LLMConfig(
model=model_arn,
model_endpoint_type=self.name,
model_endpoint_type=self.provider_type.value,
model_endpoint=None,
context_window=self.get_model_context_window(model_arn),
handle=self.get_handle(model_arn),
provider_name=self.name,
)
)
return configs

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
def list_llm_models(
byok_only: Optional[bool] = Query(None),
server: "SyncServer" = Depends(get_letta_server),
):
models = server.list_llm_models()
models = server.list_llm_models(byok_only=byok_only)
# print(models)
return models
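Listing only the models served by BYOK providers is now a query parameter; a sketch against a local server (the base URL and /v1 mount are assumptions):

import requests

resp = requests.get(
    "http://localhost:8283/v1/models/",   # assumed local server address
    params={"byok_only": True},
)
byok_models = resp.json()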

View File

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
from letta.server.rest_api.utils import get_letta_server
@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/", response_model=List[Provider], operation_id="list_providers")
def list_providers(
name: Optional[str] = Query(None),
provider_type: Optional[ProviderType] = Query(None),
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
actor_id: Optional[str] = Header(None, alias="user_id"),
@ -23,7 +26,7 @@ def list_providers(
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
except HTTPException:
raise
except Exception as e:
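The same filters are exposed on the REST list endpoint; a sketch, with the base URL assumed as above:

import requests

resp = requests.get(
    "http://localhost:8283/v1/providers/",    # assumed local server address
    params={"name": "my-anthropic", "provider_type": "anthropic"},
)
providers = resp.json()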

View File

@ -268,10 +268,11 @@ class SyncServer(Server):
)
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
if model_settings.openai_api_key:
self._enabled_providers.append(
OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
base_url=model_settings.openai_api_base,
)
@ -279,12 +280,14 @@ class SyncServer(Server):
if model_settings.anthropic_api_key:
self._enabled_providers.append(
AnthropicProvider(
name="anthropic",
api_key=model_settings.anthropic_api_key,
)
)
if model_settings.ollama_base_url:
self._enabled_providers.append(
OllamaProvider(
name="ollama",
base_url=model_settings.ollama_base_url,
api_key=None,
default_prompt_formatter=model_settings.default_prompt_formatter,
@ -293,12 +296,14 @@ class SyncServer(Server):
if model_settings.gemini_api_key:
self._enabled_providers.append(
GoogleAIProvider(
name="google_ai",
api_key=model_settings.gemini_api_key,
)
)
if model_settings.google_cloud_location and model_settings.google_cloud_project:
self._enabled_providers.append(
GoogleVertexProvider(
name="google_vertex",
google_cloud_project=model_settings.google_cloud_project,
google_cloud_location=model_settings.google_cloud_location,
)
@ -307,6 +312,7 @@ class SyncServer(Server):
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
self._enabled_providers.append(
AzureProvider(
name="azure",
api_key=model_settings.azure_api_key,
base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version,
@ -315,12 +321,14 @@ class SyncServer(Server):
if model_settings.groq_api_key:
self._enabled_providers.append(
GroqProvider(
name="groq",
api_key=model_settings.groq_api_key,
)
)
if model_settings.together_api_key:
self._enabled_providers.append(
TogetherProvider(
name="together",
api_key=model_settings.together_api_key,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
@ -329,6 +337,7 @@ class SyncServer(Server):
# vLLM exposes both a /chat/completions and a /completions endpoint
self._enabled_providers.append(
VLLMCompletionsProvider(
name="vllm",
base_url=model_settings.vllm_api_base,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
@ -338,12 +347,14 @@ class SyncServer(Server):
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
self._enabled_providers.append(
VLLMChatCompletionsProvider(
name="vllm",
base_url=model_settings.vllm_api_base,
)
)
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
self._enabled_providers.append(
AnthropicBedrockProvider(
name="bedrock",
aws_region=model_settings.aws_region,
)
)
@ -355,11 +366,11 @@ class SyncServer(Server):
if model_settings.lmstudio_base_url.endswith("/v1")
else model_settings.lmstudio_base_url + "/v1"
)
self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url))
self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
if model_settings.deepseek_api_key:
self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
if model_settings.xai_api_key:
self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))
self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
# For MCP
"""Initialize the MCP clients (there may be multiple)"""
@ -1184,10 +1195,10 @@ class SyncServer(Server):
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self) -> List[LLMConfig]:
def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
"""List available models"""
llm_models = []
for provider in self.get_enabled_providers():
for provider in self.get_enabled_providers(byok_only=byok_only):
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
@ -1207,11 +1218,12 @@ class SyncServer(Server):
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
def get_enabled_providers(self):
def get_enabled_providers(self, byok_only: bool = False):
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
if byok_only:
return list(providers_from_db.values())
providers_from_env = {p.name: p for p in self._enabled_providers}
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
return {**providers_from_env, **providers_from_db}.values()
return list(providers_from_env.values()) + list(providers_from_db.values())
@trace_method
def get_llm_config_from_handle(
@ -1296,7 +1308,7 @@ class SyncServer(Server):
return embedding_config
def get_provider_from_name(self, provider_name: str) -> Provider:
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
if not providers:
raise ValueError(f"Provider {provider_name} is not supported")
elif len(providers) > 1:

View File

@ -1,6 +1,7 @@
from typing import List, Optional
from typing import List, Optional, Union
from letta.orm.provider import Provider as ProviderModel
from letta.schemas.enums import ProviderType
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderUpdate
from letta.schemas.user import User as PydanticUser
@ -18,6 +19,9 @@ class ProviderManager:
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with self.session_maker() as session:
if provider.name == provider.provider_type.value:
raise ValueError("Provider name must be unique and different from provider type")
# Assign the organization id based on the actor
provider.organization_id = actor.organization_id
@ -59,29 +63,36 @@ class ProviderManager:
session.commit()
@enforce_types
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
def list_providers(
self,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> List[PydanticProvider]:
"""List all providers with optional pagination."""
filter_kwargs = {}
if name:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
with self.session_maker() as session:
providers = ProviderModel.list(
db_session=session,
after=after,
limit=limit,
actor=actor,
**filter_kwargs,
)
return [provider.to_pydantic() for provider in providers]
@enforce_types
def get_anthropic_override_provider_id(self) -> Optional[str]:
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
if len(anthropic_provider) != 0:
return anthropic_provider[0].id
return None
def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
providers = self.list_providers(name=provider_name)
return providers[0].id if providers else None
@enforce_types
def get_anthropic_override_key(self) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
if len(anthropic_provider) != 0:
return anthropic_provider[0].api_key
return None
def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
providers = self.list_providers(name=provider_name)
return providers[0].api_key if providers else None
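Both helpers replace the v0 anthropic-only shortcuts and work for any stored provider; a minimal sketch, assuming the server's database is configured and a provider named "my-anthropic" exists:

from letta.services.provider_manager import ProviderManager

manager = ProviderManager()
provider_id = manager.get_provider_id_from_name("my-anthropic")  # recorded on steps; None if not found
api_key = manager.get_override_key("my-anthropic")               # the stored BYOK key, or None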

View File

@ -105,7 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
llm_client = LLMClient.create(
provider=agent_state.llm_config.model_endpoint_type,
provider_name=agent_state.llm_config.provider_name,
provider_type=agent_state.llm_config.model_endpoint_type,
)
if llm_client:
response = llm_client.send_llm_request(

View File

@ -19,97 +19,166 @@ from letta.settings import model_settings
def test_openai():
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
provider = OpenAIProvider(
name="openai",
api_key=api_key,
base_url=model_settings.openai_api_base,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_deepseek():
api_key = os.getenv("DEEPSEEK_API_KEY")
assert api_key is not None
provider = DeepSeekProvider(api_key=api_key)
provider = DeepSeekProvider(
name="deepseek",
api_key=api_key,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(api_key=api_key)
provider = AnthropicProvider(
name="anthropic",
api_key=api_key,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_groq():
provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
provider = GroqProvider(
name="groq",
api_key=os.getenv("GROQ_API_KEY"),
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_azure():
provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
provider = AzureProvider(
name="azure",
api_key=os.getenv("AZURE_API_KEY"),
base_url=os.getenv("AZURE_BASE_URL"),
)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embed_models = provider.list_embedding_models()
print([m.embedding_model for m in embed_models])
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_ollama():
base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
provider = OllamaProvider(
name="ollama",
base_url=base_url,
default_prompt_formatter=model_settings.default_prompt_formatter,
api_key=None,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print(embedding_models)
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_googleai():
api_key = os.getenv("GEMINI_API_KEY")
assert api_key is not None
provider = GoogleAIProvider(api_key=api_key)
provider = GoogleAIProvider(
name="google_ai",
api_key=api_key,
)
models = provider.list_llm_models()
print(models)
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
provider.list_embedding_models()
embedding_models = provider.list_embedding_models()
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_google_vertex():
provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
provider = GoogleVertexProvider(
name="google_vertex",
google_cloud_project=os.getenv("GCP_PROJECT_ID"),
google_cloud_location=os.getenv("GCP_REGION"),
)
models = provider.list_llm_models()
print(models)
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_mistral():
provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
provider = MistralProvider(
name="mistral",
api_key=os.getenv("MISTRAL_API_KEY"),
)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
def test_together():
provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
provider = TogetherProvider(
name="together",
api_key=os.getenv("TOGETHER_API_KEY"),
default_prompt_formatter="chatml",
)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_anthropic_bedrock():
from letta.settings import model_settings
provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
models = provider.list_llm_models()
print([m.model for m in models])
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
assert len(embedding_models) > 0
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
def test_custom_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(
name="custom_anthropic",
api_key=api_key,
)
models = provider.list_llm_models()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
# def test_vllm():

View File

@ -13,7 +13,7 @@ import letta.utils as utils
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import Provider as PydanticProvider
@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id)
provider = server.provider_manager.create_provider(
provider=PydanticProvider(
name="anthropic",
name="caren-anthropic",
provider_type=ProviderType.anthropic,
api_key=os.getenv("ANTHROPIC_API_KEY"),
),
actor=actor,
@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
agent = server.create_agent(
request=CreateAgent(
memory_blocks=[],
model="anthropic/claude-3-opus-20240229",
context_window_limit=200000,
model="caren-anthropic/claude-3-opus-20240229",
context_window_limit=100000,
embedding="openai/text-embedding-ada-002",
),
actor=actor,