mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: byok 2.0 (#1963)
This commit is contained in:
parent
e3819cf066
commit
835792d5e0
@ -0,0 +1,35 @@
|
||||
"""add byok fields and unique constraint
|
||||
|
||||
Revision ID: 373dabcba6cf
|
||||
Revises: c56081a05371
|
||||
Create Date: 2025-04-30 19:38:25.010856
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "373dabcba6cf"
|
||||
down_revision: Union[str, None] = "c56081a05371"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
|
||||
op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
|
||||
op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
|
||||
op.drop_column("providers", "base_url")
|
||||
op.drop_column("providers", "provider_type")
|
||||
# ### end Alembic commands ###
|
@ -331,7 +331,8 @@ class Agent(BaseAgent):
|
||||
log_telemetry(self.logger, "_get_ai_reply create start")
|
||||
# New LLM client flow
|
||||
llm_client = LLMClient.create(
|
||||
provider=self.agent_state.llm_config.model_endpoint_type,
|
||||
provider_name=self.agent_state.llm_config.provider_name,
|
||||
provider_type=self.agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
|
||||
@ -941,12 +942,7 @@ class Agent(BaseAgent):
|
||||
model_endpoint=self.agent_state.llm_config.model_endpoint,
|
||||
context_window_limit=self.agent_state.llm_config.context_window,
|
||||
usage=response.usage,
|
||||
# TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
|
||||
provider_id=(
|
||||
self.provider_manager.get_anthropic_override_provider_id()
|
||||
if self.agent_state.llm_config.model_endpoint_type == "anthropic"
|
||||
else None
|
||||
),
|
||||
provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
|
||||
job_id=job_id,
|
||||
)
|
||||
for message in all_new_messages:
|
||||
|
@ -67,7 +67,8 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider=agent_state.llm_config.model_endpoint_type,
|
||||
provider_name=agent_state.llm_config.provider_name,
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
for step in range(max_steps):
|
||||
|
@ -156,7 +156,8 @@ class LettaAgentBatch:
|
||||
|
||||
log_event(name="init_llm_client")
|
||||
llm_client = LLMClient.create(
|
||||
provider=agent_states[0].llm_config.model_endpoint_type,
|
||||
provider_name=agent_states[0].llm_config.provider_name,
|
||||
provider_type=agent_states[0].llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
|
||||
@ -273,7 +274,8 @@ class LettaAgentBatch:
|
||||
|
||||
# translate provider‑specific response → OpenAI‑style tool call (unchanged)
|
||||
llm_client = LLMClient.create(
|
||||
provider=item.llm_config.model_endpoint_type,
|
||||
provider_name=item.llm_config.provider_name,
|
||||
provider_type=item.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
tool_call = (
|
||||
|
@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.message import MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
|
||||
# NOTE: currently there is no GET /models, so we need to hardcode
|
||||
# return MODEL_LIST
|
||||
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
if api_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
else:
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
models = anthropic_client.models.list()
|
||||
models_json = models.model_dump()
|
||||
@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
extended_thinking: bool = False,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> ChatCompletionResponse:
|
||||
"""https://docs.anthropic.com/claude/docs/tool-use"""
|
||||
anthropic_client = None
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
if provider_name and provider_name != ProviderType.anthropic.value:
|
||||
api_key = ProviderManager().get_override_key(provider_name)
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
else:
|
||||
@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
extended_thinking: bool = False,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> Generator[ChatCompletionChunkResponse, None, None]:
|
||||
"""Stream chat completions from Anthropic API.
|
||||
@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
|
||||
extended_thinking=extended_thinking,
|
||||
max_reasoning_tokens=max_reasoning_tokens,
|
||||
)
|
||||
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
if provider_name and provider_name != ProviderType.anthropic.value:
|
||||
api_key = ProviderManager().get_override_key(provider_name)
|
||||
anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
|
||||
@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
extended_thinking: bool = False,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
create_message_id: bool = True,
|
||||
create_message_datetime: bool = True,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
extended_thinking=extended_thinking,
|
||||
max_reasoning_tokens=max_reasoning_tokens,
|
||||
provider_name=provider_name,
|
||||
betas=betas,
|
||||
)
|
||||
):
|
||||
|
@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
|
||||
override_key = ProviderManager().get_anthropic_override_key()
|
||||
override_key = None
|
||||
if self.provider_name and self.provider_name != ProviderType.anthropic.value:
|
||||
override_key = ProviderManager().get_override_key(self.provider_name)
|
||||
|
||||
if async_client:
|
||||
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
|
||||
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
|
||||
|
@ -24,6 +24,7 @@ from letta.llm_api.openai import (
|
||||
from letta.local_llm.chat_completion_proxy import get_chat_completion
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
|
||||
@ -171,6 +172,10 @@ def create(
|
||||
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
|
||||
# only is a problem if we are *not* using an openai proxy
|
||||
raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
|
||||
elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name)
|
||||
elif model_settings.openai_api_key is None:
|
||||
# the openai python client requires a dummy API key
|
||||
api_key = "DUMMY_API_KEY"
|
||||
@ -373,6 +378,7 @@ def create(
|
||||
stream_interface=stream_interface,
|
||||
extended_thinking=llm_config.enable_reasoner,
|
||||
max_reasoning_tokens=llm_config.max_reasoning_tokens,
|
||||
provider_name=llm_config.provider_name,
|
||||
name=name,
|
||||
)
|
||||
|
||||
@ -383,6 +389,7 @@ def create(
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
extended_thinking=llm_config.enable_reasoner,
|
||||
max_reasoning_tokens=llm_config.max_reasoning_tokens,
|
||||
provider_name=llm_config.provider_name,
|
||||
)
|
||||
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
|
@ -9,7 +9,8 @@ class LLMClient:
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
provider: ProviderType,
|
||||
provider_type: ProviderType,
|
||||
provider_name: Optional[str] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
) -> Optional[LLMClientBase]:
|
||||
"""
|
||||
@ -25,29 +26,33 @@ class LLMClient:
|
||||
Raises:
|
||||
ValueError: If the model endpoint type is not supported
|
||||
"""
|
||||
match provider:
|
||||
match provider_type:
|
||||
case ProviderType.google_ai:
|
||||
from letta.llm_api.google_ai_client import GoogleAIClient
|
||||
|
||||
return GoogleAIClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
case ProviderType.google_vertex:
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
|
||||
return GoogleVertexClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
case ProviderType.anthropic:
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
|
||||
return AnthropicClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
case ProviderType.openai:
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
|
||||
return OpenAIClient(
|
||||
provider_name=provider_name,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
case _:
|
||||
|
@ -20,9 +20,11 @@ class LLMClientBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: Optional[str] = None,
|
||||
put_inner_thoughts_first: Optional[bool] = True,
|
||||
use_tool_naming: bool = True,
|
||||
):
|
||||
self.provider_name = provider_name
|
||||
self.put_inner_thoughts_first = put_inner_thoughts_first
|
||||
self.use_tool_naming = use_tool_naming
|
||||
|
||||
|
@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
@ -64,6 +65,13 @@ def supports_parallel_tool_calling(model: str) -> bool:
|
||||
|
||||
class OpenAIClient(LLMClientBase):
|
||||
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
|
||||
api_key = None
|
||||
if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name)
|
||||
|
||||
if not api_key:
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
# supposedly the openai python client requires a dummy API key
|
||||
api_key = api_key or "DUMMY_API_KEY"
|
||||
|
@ -79,7 +79,8 @@ def summarize_messages(
|
||||
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
provider=llm_config_no_inner_thoughts.model_endpoint_type,
|
||||
provider_name=llm_config_no_inner_thoughts.provider_name,
|
||||
provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
|
||||
put_inner_thoughts_first=False,
|
||||
)
|
||||
# try to use new client, otherwise fallback to old flow
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
__tablename__ = "providers"
|
||||
__pydantic_model__ = PydanticProvider
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"name",
|
||||
"organization_id",
|
||||
name="unique_name_organization_id",
|
||||
),
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
|
||||
provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
|
||||
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
|
||||
base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
|
||||
|
@ -6,6 +6,17 @@ class ProviderType(str, Enum):
|
||||
google_ai = "google_ai"
|
||||
google_vertex = "google_vertex"
|
||||
openai = "openai"
|
||||
letta = "letta"
|
||||
deepseek = "deepseek"
|
||||
lmstudio_openai = "lmstudio_openai"
|
||||
xai = "xai"
|
||||
mistral = "mistral"
|
||||
ollama = "ollama"
|
||||
groq = "groq"
|
||||
together = "together"
|
||||
azure = "azure"
|
||||
vllm = "vllm"
|
||||
bedrock = "bedrock"
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
|
@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
|
||||
"xai",
|
||||
] = Field(..., description="The endpoint type for the model.")
|
||||
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
||||
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
|
||||
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
|
||||
context_window: int = Field(..., description="The context window size for the model.")
|
||||
put_inner_thoughts_in_kwargs: Optional[bool] = Field(
|
||||
|
@ -2,8 +2,8 @@ from typing import Dict
|
||||
|
||||
LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
|
||||
"anthropic": {
|
||||
"claude-3-5-haiku-20241022": "claude-3.5-haiku",
|
||||
"claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
|
||||
"claude-3-5-haiku-20241022": "claude-3-5-haiku",
|
||||
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
|
||||
"claude-3-opus-20240229": "claude-3-opus",
|
||||
},
|
||||
"openai": {
|
||||
|
@ -1,6 +1,6 @@
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class ProviderBase(LettaBase):
|
||||
@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
|
||||
class Provider(ProviderBase):
|
||||
id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
|
||||
name: str = Field(..., description="The name of the provider")
|
||||
provider_type: ProviderType = Field(..., description="The type of the provider")
|
||||
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
|
||||
base_url: Optional[str] = Field(None, description="Base URL for the provider.")
|
||||
organization_id: Optional[str] = Field(None, description="The organization id of the user")
|
||||
updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def default_base_url(self):
|
||||
if self.provider_type == ProviderType.openai and self.base_url is None:
|
||||
self.base_url = model_settings.openai_api_base
|
||||
return self
|
||||
|
||||
def resolve_identifier(self):
|
||||
if not self.id:
|
||||
self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
|
||||
@ -59,9 +69,41 @@ class Provider(ProviderBase):
|
||||
|
||||
return f"{self.name}/{model_name}"
|
||||
|
||||
def cast_to_subtype(self):
|
||||
match (self.provider_type):
|
||||
case ProviderType.letta:
|
||||
return LettaProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.openai:
|
||||
return OpenAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.anthropic:
|
||||
return AnthropicProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.anthropic_bedrock:
|
||||
return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.ollama:
|
||||
return OllamaProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.google_ai:
|
||||
return GoogleAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.google_vertex:
|
||||
return GoogleVertexProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.azure:
|
||||
return AzureProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.groq:
|
||||
return GroqProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.together:
|
||||
return TogetherProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.vllm_chat_completions:
|
||||
return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.vllm_completions:
|
||||
return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.xai:
|
||||
return XAIProvider(**self.model_dump(exclude_none=True))
|
||||
case _:
|
||||
raise ValueError(f"Unknown provider type: {self.provider_type}")
|
||||
|
||||
|
||||
class ProviderCreate(ProviderBase):
|
||||
name: str = Field(..., description="The name of the provider.")
|
||||
provider_type: ProviderType = Field(..., description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||
|
||||
|
||||
@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):
|
||||
|
||||
|
||||
class LettaProvider(Provider):
|
||||
|
||||
name: str = "letta"
|
||||
provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
return [
|
||||
@ -81,6 +122,7 @@ class LettaProvider(Provider):
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=8192,
|
||||
handle=self.get_handle("letta-free"),
|
||||
provider_name=self.name,
|
||||
)
|
||||
]
|
||||
|
||||
@ -98,7 +140,7 @@ class LettaProvider(Provider):
|
||||
|
||||
|
||||
class OpenAIProvider(Provider):
|
||||
name: str = "openai"
|
||||
provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the OpenAI API.")
|
||||
base_url: str = Field(..., description="Base URL for the OpenAI API.")
|
||||
|
||||
@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
* It also does not support native function calling
|
||||
"""
|
||||
|
||||
name: str = "deepseek"
|
||||
provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
|
||||
base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
|
||||
api_key: str = Field(..., description="API key for the DeepSeek API.")
|
||||
|
||||
@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
name: str = "lmstudio-openai"
|
||||
provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
|
||||
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
|
||||
|
||||
@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
class XAIProvider(OpenAIProvider):
|
||||
"""https://docs.x.ai/docs/api-reference"""
|
||||
|
||||
name: str = "xai"
|
||||
provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the xAI/Grok API.")
|
||||
base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
|
||||
|
||||
@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class AnthropicProvider(Provider):
|
||||
name: str = "anthropic"
|
||||
provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the Anthropic API.")
|
||||
base_url: str = "https://api.anthropic.com/v1"
|
||||
|
||||
@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
|
||||
handle=self.get_handle(model["id"]),
|
||||
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
|
||||
max_tokens=max_tokens,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -572,7 +618,7 @@ class AnthropicProvider(Provider):
|
||||
|
||||
|
||||
class MistralProvider(Provider):
|
||||
name: str = "mistral"
|
||||
provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the Mistral API.")
|
||||
base_url: str = "https://api.mistral.ai/v1"
|
||||
|
||||
@ -596,6 +642,7 @@ class MistralProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_context_length"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
|
||||
"""
|
||||
|
||||
name: str = "ollama"
|
||||
provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the Ollama API.")
|
||||
api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
|
||||
default_prompt_formatter: str = Field(
|
||||
@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window,
|
||||
handle=self.get_handle(model["name"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class GroqProvider(OpenAIProvider):
|
||||
name: str = "groq"
|
||||
provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
|
||||
base_url: str = "https://api.groq.com/openai/v1"
|
||||
api_key: str = Field(..., description="API key for the Groq API.")
|
||||
|
||||
@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["context_window"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
function calling support is limited.
|
||||
"""
|
||||
|
||||
name: str = "together"
|
||||
provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
|
||||
base_url: str = "https://api.together.ai/v1"
|
||||
api_key: str = Field(..., description="API key for the TogetherAI API.")
|
||||
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
|
||||
@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):
|
||||
|
||||
class GoogleAIProvider(Provider):
|
||||
# gemini
|
||||
name: str = "google_ai"
|
||||
provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
|
||||
api_key: str = Field(..., description="API key for the Google AI API.")
|
||||
base_url: str = "https://generativelanguage.googleapis.com"
|
||||
|
||||
@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
|
||||
# filter by model names
|
||||
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
|
||||
|
||||
# TODO remove manual filtering for gemini-pro
|
||||
# Add support for all gemini models
|
||||
model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
|
||||
|
||||
@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
|
||||
context_window=self.get_model_context_window(model),
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):
|
||||
|
||||
|
||||
class GoogleVertexProvider(Provider):
|
||||
name: str = "google_vertex"
|
||||
provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
|
||||
google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
|
||||
google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")
|
||||
|
||||
@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
|
||||
context_window=context_length,
|
||||
handle=self.get_handle(model),
|
||||
max_tokens=8192,
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):
|
||||
|
||||
|
||||
class AzureProvider(Provider):
|
||||
name: str = "azure"
|
||||
provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
|
||||
latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
|
||||
base_url: str = Field(
|
||||
..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
|
||||
@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
|
||||
model_endpoint=model_endpoint,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
),
|
||||
)
|
||||
return configs
|
||||
@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
|
||||
"""vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
|
||||
|
||||
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
|
||||
name: str = "vllm"
|
||||
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the vLLM API.")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
|
||||
"""This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
|
||||
|
||||
# NOTE: vLLM only serves one model at a time (so could configure that through env variables)
|
||||
name: str = "vllm"
|
||||
provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
|
||||
base_url: str = Field(..., description="Base URL for the vLLM API.")
|
||||
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
|
||||
|
||||
@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model["id"]),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):
|
||||
|
||||
|
||||
class AnthropicBedrockProvider(Provider):
|
||||
name: str = "bedrock"
|
||||
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
|
||||
aws_region: str = Field(..., description="AWS region for Bedrock")
|
||||
|
||||
def list_llm_models(self):
|
||||
@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_arn,
|
||||
model_endpoint_type=self.name,
|
||||
model_endpoint_type=self.provider_type.value,
|
||||
model_endpoint=None,
|
||||
context_window=self.get_model_context_window(model_arn),
|
||||
handle=self.get_handle(model_arn),
|
||||
provider_name=self.name,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])
|
||||
|
||||
@router.get("/", response_model=List[LLMConfig], operation_id="list_models")
|
||||
def list_llm_models(
|
||||
byok_only: Optional[bool] = Query(None),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
|
||||
models = server.list_llm_models()
|
||||
models = server.list_llm_models(byok_only=byok_only)
|
||||
# print(models)
|
||||
return models
|
||||
|
||||
|
@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])
|
||||
|
||||
@router.get("/", response_model=List[Provider], operation_id="list_providers")
|
||||
def list_providers(
|
||||
name: Optional[str] = Query(None),
|
||||
provider_type: Optional[ProviderType] = Query(None),
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@ -23,7 +26,7 @@ def list_providers(
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
@ -268,10 +268,11 @@ class SyncServer(Server):
|
||||
)
|
||||
|
||||
# collect providers (always has Letta as a default)
|
||||
self._enabled_providers: List[Provider] = [LettaProvider()]
|
||||
self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
|
||||
if model_settings.openai_api_key:
|
||||
self._enabled_providers.append(
|
||||
OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
@ -279,12 +280,14 @@ class SyncServer(Server):
|
||||
if model_settings.anthropic_api_key:
|
||||
self._enabled_providers.append(
|
||||
AnthropicProvider(
|
||||
name="anthropic",
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
)
|
||||
)
|
||||
if model_settings.ollama_base_url:
|
||||
self._enabled_providers.append(
|
||||
OllamaProvider(
|
||||
name="ollama",
|
||||
base_url=model_settings.ollama_base_url,
|
||||
api_key=None,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
@ -293,12 +296,14 @@ class SyncServer(Server):
|
||||
if model_settings.gemini_api_key:
|
||||
self._enabled_providers.append(
|
||||
GoogleAIProvider(
|
||||
name="google_ai",
|
||||
api_key=model_settings.gemini_api_key,
|
||||
)
|
||||
)
|
||||
if model_settings.google_cloud_location and model_settings.google_cloud_project:
|
||||
self._enabled_providers.append(
|
||||
GoogleVertexProvider(
|
||||
name="google_vertex",
|
||||
google_cloud_project=model_settings.google_cloud_project,
|
||||
google_cloud_location=model_settings.google_cloud_location,
|
||||
)
|
||||
@ -307,6 +312,7 @@ class SyncServer(Server):
|
||||
assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
|
||||
self._enabled_providers.append(
|
||||
AzureProvider(
|
||||
name="azure",
|
||||
api_key=model_settings.azure_api_key,
|
||||
base_url=model_settings.azure_base_url,
|
||||
api_version=model_settings.azure_api_version,
|
||||
@ -315,12 +321,14 @@ class SyncServer(Server):
|
||||
if model_settings.groq_api_key:
|
||||
self._enabled_providers.append(
|
||||
GroqProvider(
|
||||
name="groq",
|
||||
api_key=model_settings.groq_api_key,
|
||||
)
|
||||
)
|
||||
if model_settings.together_api_key:
|
||||
self._enabled_providers.append(
|
||||
TogetherProvider(
|
||||
name="together",
|
||||
api_key=model_settings.together_api_key,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
@ -329,6 +337,7 @@ class SyncServer(Server):
|
||||
# vLLM exposes both a /chat/completions and a /completions endpoint
|
||||
self._enabled_providers.append(
|
||||
VLLMCompletionsProvider(
|
||||
name="vllm",
|
||||
base_url=model_settings.vllm_api_base,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
@ -338,12 +347,14 @@ class SyncServer(Server):
|
||||
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
||||
self._enabled_providers.append(
|
||||
VLLMChatCompletionsProvider(
|
||||
name="vllm",
|
||||
base_url=model_settings.vllm_api_base,
|
||||
)
|
||||
)
|
||||
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
|
||||
self._enabled_providers.append(
|
||||
AnthropicBedrockProvider(
|
||||
name="bedrock",
|
||||
aws_region=model_settings.aws_region,
|
||||
)
|
||||
)
|
||||
@ -355,11 +366,11 @@ class SyncServer(Server):
|
||||
if model_settings.lmstudio_base_url.endswith("/v1")
|
||||
else model_settings.lmstudio_base_url + "/v1"
|
||||
)
|
||||
self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url))
|
||||
self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
|
||||
if model_settings.deepseek_api_key:
|
||||
self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
|
||||
self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
|
||||
if model_settings.xai_api_key:
|
||||
self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))
|
||||
self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
|
||||
|
||||
# For MCP
|
||||
"""Initialize the MCP clients (there may be multiple)"""
|
||||
@ -1184,10 +1195,10 @@ class SyncServer(Server):
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
|
||||
"""List available models"""
|
||||
llm_models = []
|
||||
for provider in self.get_enabled_providers():
|
||||
for provider in self.get_enabled_providers(byok_only=byok_only):
|
||||
try:
|
||||
llm_models.extend(provider.list_llm_models())
|
||||
except Exception as e:
|
||||
@ -1207,11 +1218,12 @@ class SyncServer(Server):
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
def get_enabled_providers(self):
|
||||
def get_enabled_providers(self, byok_only: bool = False):
|
||||
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
|
||||
if byok_only:
|
||||
return list(providers_from_db.values())
|
||||
providers_from_env = {p.name: p for p in self._enabled_providers}
|
||||
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
|
||||
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
|
||||
return {**providers_from_env, **providers_from_db}.values()
|
||||
return list(providers_from_env.values()) + list(providers_from_db.values())
|
||||
|
||||
@trace_method
|
||||
def get_llm_config_from_handle(
|
||||
@ -1296,7 +1308,7 @@ class SyncServer(Server):
|
||||
return embedding_config
|
||||
|
||||
def get_provider_from_name(self, provider_name: str) -> Provider:
|
||||
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
|
||||
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
|
||||
if not providers:
|
||||
raise ValueError(f"Provider {provider_name} is not supported")
|
||||
elif len(providers) > 1:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from letta.orm.provider import Provider as ProviderModel
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.providers import ProviderUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@ -18,6 +19,9 @@ class ProviderManager:
|
||||
def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
|
||||
"""Create a new provider if it doesn't already exist."""
|
||||
with self.session_maker() as session:
|
||||
if provider.name == provider.provider_type.value:
|
||||
raise ValueError("Provider name must be unique and different from provider type")
|
||||
|
||||
# Assign the organization id based on the actor
|
||||
provider.organization_id = actor.organization_id
|
||||
|
||||
@ -59,29 +63,36 @@ class ProviderManager:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
|
||||
def list_providers(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> List[PydanticProvider]:
|
||||
"""List all providers with optional pagination."""
|
||||
filter_kwargs = {}
|
||||
if name:
|
||||
filter_kwargs["name"] = name
|
||||
if provider_type:
|
||||
filter_kwargs["provider_type"] = provider_type
|
||||
with self.session_maker() as session:
|
||||
providers = ProviderModel.list(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
**filter_kwargs,
|
||||
)
|
||||
return [provider.to_pydantic() for provider in providers]
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_override_provider_id(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
|
||||
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||
if len(anthropic_provider) != 0:
|
||||
return anthropic_provider[0].id
|
||||
return None
|
||||
def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
|
||||
providers = self.list_providers(name=provider_name)
|
||||
return providers[0].id if providers else None
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_override_key(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||
anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
|
||||
if len(anthropic_provider) != 0:
|
||||
return anthropic_provider[0].api_key
|
||||
return None
|
||||
def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
|
||||
providers = self.list_providers(name=provider_name)
|
||||
return providers[0].api_key if providers else None
|
||||
|
@ -105,7 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
|
||||
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
provider=agent_state.llm_config.model_endpoint_type,
|
||||
provider_name=agent_state.llm_config.provider_name,
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
)
|
||||
if llm_client:
|
||||
response = llm_client.send_llm_request(
|
||||
|
@ -19,97 +19,166 @@ from letta.settings import model_settings
|
||||
def test_openai():
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=api_key,
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_deepseek():
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = DeepSeekProvider(api_key=api_key)
|
||||
provider = DeepSeekProvider(
|
||||
name="deepseek",
|
||||
api_key=api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
def test_anthropic():
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = AnthropicProvider(api_key=api_key)
|
||||
provider = AnthropicProvider(
|
||||
name="anthropic",
|
||||
api_key=api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
def test_groq():
|
||||
provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
|
||||
provider = GroqProvider(
|
||||
name="groq",
|
||||
api_key=os.getenv("GROQ_API_KEY"),
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
def test_azure():
|
||||
provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
|
||||
provider = AzureProvider(
|
||||
name="azure",
|
||||
api_key=os.getenv("AZURE_API_KEY"),
|
||||
base_url=os.getenv("AZURE_BASE_URL"),
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print([m.model for m in models])
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embed_models = provider.list_embedding_models()
|
||||
print([m.embedding_model for m in embed_models])
|
||||
embedding_models = provider.list_embedding_models()
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_ollama():
|
||||
base_url = os.getenv("OLLAMA_BASE_URL")
|
||||
assert base_url is not None
|
||||
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
|
||||
provider = OllamaProvider(
|
||||
name="ollama",
|
||||
base_url=base_url,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
api_key=None,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
print(embedding_models)
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_googleai():
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = GoogleAIProvider(api_key=api_key)
|
||||
provider = GoogleAIProvider(
|
||||
name="google_ai",
|
||||
api_key=api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
provider.list_embedding_models()
|
||||
embedding_models = provider.list_embedding_models()
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_google_vertex():
|
||||
provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
|
||||
provider = GoogleVertexProvider(
|
||||
name="google_vertex",
|
||||
google_cloud_project=os.getenv("GCP_PROJECT_ID"),
|
||||
google_cloud_location=os.getenv("GCP_REGION"),
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print(models)
|
||||
print([m.model for m in models])
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
print([m.embedding_model for m in embedding_models])
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_mistral():
|
||||
provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
|
||||
provider = MistralProvider(
|
||||
name="mistral",
|
||||
api_key=os.getenv("MISTRAL_API_KEY"),
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print([m.model for m in models])
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
def test_together():
|
||||
provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
|
||||
provider = TogetherProvider(
|
||||
name="together",
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
default_prompt_formatter="chatml",
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
print([m.model for m in models])
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
print([m.embedding_model for m in embedding_models])
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_anthropic_bedrock():
|
||||
from letta.settings import model_settings
|
||||
|
||||
provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
|
||||
provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
|
||||
models = provider.list_llm_models()
|
||||
print([m.model for m in models])
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
embedding_models = provider.list_embedding_models()
|
||||
print([m.embedding_model for m in embedding_models])
|
||||
assert len(embedding_models) > 0
|
||||
assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"
|
||||
|
||||
|
||||
def test_custom_anthropic():
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
assert api_key is not None
|
||||
provider = AnthropicProvider(
|
||||
name="custom_anthropic",
|
||||
api_key=api_key,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
# def test_vllm():
|
||||
|
@ -13,7 +13,7 @@ import letta.utils as utils
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.orm import Provider, Step
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageRole, ProviderType
|
||||
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
provider = server.provider_manager.create_provider(
|
||||
provider=PydanticProvider(
|
||||
name="anthropic",
|
||||
name="caren-anthropic",
|
||||
provider_type=ProviderType.anthropic,
|
||||
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
),
|
||||
actor=actor,
|
||||
@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
memory_blocks=[],
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
context_window_limit=200000,
|
||||
model="caren-anthropic/claude-3-opus-20240229",
|
||||
context_window_limit=100000,
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
|
Loading…
Reference in New Issue
Block a user