feat: support custom api keys for cloud (#533)

This commit is contained in:
cthomas 2025-01-07 22:12:55 -08:00 committed by GitHub
parent 9a0613bdad
commit 4a2e321e99
9 changed files with 244 additions and 5 deletions

View File

@ -0,0 +1,47 @@
"""Add providers data to ORM
Revision ID: 915b68780108
Revises: 400501b04bf0
Create Date: 2025-01-07 10:49:04.717058
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "915b68780108"
down_revision: Union[str, None] = "400501b04bf0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``providers`` table, keyed by ``id`` and linked to ``organizations``."""

    def _timestamp_column(column_name):
        # Timezone-aware timestamp that the database fills in with now().
        return sa.Column(column_name, sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)

    # Column order is kept exactly as auto-generated by Alembic.
    columns = [
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("api_key", sa.String(), nullable=True),
        sa.Column("id", sa.String(), nullable=False),
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
    ]
    constraints = [
        sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
        sa.PrimaryKeyConstraint("id"),
    ]
    op.create_table("providers", *columns, *constraints)
def downgrade() -> None:
    """Revert this revision by dropping the ``providers`` table."""
    table_name = "providers"
    op.drop_table(table_name)

View File

@ -22,6 +22,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.services.provider_manager import ProviderManager
from letta.settings import ModelSettings from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
@ -251,9 +252,12 @@ def create(
tool_call = {"type": "function", "function": {"name": force_tool_call}} tool_call = {"type": "function", "function": {"name": force_tool_call}}
assert functions is not None assert functions is not None
# load anthropic key from db in case a custom key has been stored
anthropic_key_override = ProviderManager().get_anthropic_key_override()
return anthropic_chat_completions_request( return anthropic_chat_completions_request(
url=llm_config.model_endpoint, url=llm_config.model_endpoint,
api_key=model_settings.anthropic_api_key, api_key=anthropic_key_override if anthropic_key_override else model_settings.anthropic_api_key,
data=ChatCompletionRequest( data=ChatCompletionRequest(
model=llm_config.model, model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],

View File

@ -8,6 +8,7 @@ from letta.orm.job import Job
from letta.orm.message import Message from letta.orm.message import Message
from letta.orm.organization import Organization from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents from letta.orm.sources_agents import SourcesAgents

View File

@ -9,6 +9,7 @@ if TYPE_CHECKING:
from letta.orm.agent import Agent from letta.orm.agent import Agent
from letta.orm.file import FileMetadata from letta.orm.file import FileMetadata
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable from letta.orm.sandbox_config import AgentEnvironmentVariable
from letta.orm.tool import Tool from letta.orm.tool import Tool
from letta.orm.user import User from letta.orm.user import User
@ -45,6 +46,7 @@ class Organization(SqlalchemyBase):
"SourcePassage", back_populates="organization", cascade="all, delete-orphan" "SourcePassage", back_populates="organization", cascade="all, delete-orphan"
) )
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan") agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
@property @property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:

23
letta/orm/provider.py Normal file
View File

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.providers import Provider as PydanticProvider
if TYPE_CHECKING:
from letta.orm.organization import Organization
class Provider(SqlalchemyBase, OrganizationMixin):
    """Provider ORM class"""

    # Backing table, and the pydantic schema this ORM row is paired with.
    __tablename__ = "providers"
    __pydantic_model__ = PydanticProvider

    # Provider name is mandatory; the API key column is declared nullable.
    name: Mapped[str] = mapped_column(doc="The name of the provider", nullable=False)
    api_key: Mapped[str] = mapped_column(doc="API key used for requests to the provider.", nullable=True)

    # relationships
    # Other side of Organization.providers (see back_populates).
    organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")

View File

@ -1,16 +1,24 @@
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import Field, model_validator
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.services.organization_manager import OrganizationManager
class Provider(BaseModel): class ProviderBase(LettaBase):
__id_prefix__ = "provider"
class Provider(ProviderBase):
name: str = Field(..., description="The name of the provider") name: str = Field(..., description="The name of the provider")
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
def list_llm_models(self) -> List[LLMConfig]: def list_llm_models(self) -> List[LLMConfig]:
return [] return []
@ -29,6 +37,17 @@ class Provider(BaseModel):
return f"{self.name}/{model_name}" return f"{self.name}/{model_name}"
class ProviderCreate(ProviderBase):
    # Request payload for creating a provider; every field is required.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left unchanged.)
    name: str = Field(default=..., description="The name of the provider.")
    api_key: str = Field(default=..., description="API key used for requests to the provider.")
    organization_id: str = Field(default=..., description="The organization id that this provider information pertains to.")
class ProviderUpdate(ProviderBase):
    # Request payload for updating a provider: the target id plus the new
    # API key, both required. (Comments rather than a docstring so the
    # generated schema description is left unchanged.)
    id: str = Field(default=..., description="The id of the provider to update.")
    api_key: str = Field(default=..., description="API key used for requests to the provider.")
class LettaProvider(Provider): class LettaProvider(Provider):
name: str = "letta" name: str = "letta"

View File

@ -0,0 +1,72 @@
from typing import TYPE_CHECKING, List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Query

from letta.providers import Provider, ProviderCreate, ProviderUpdate
from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING:
from letta.server.server import SyncServer
router = APIRouter(prefix="/providers", tags=["providers", "admin"])
@router.get("/", tags=["admin"], response_model=List[Provider], operation_id="list_providers")
def list_providers(
    cursor: Optional[str] = Query(None),
    limit: Optional[int] = Query(50),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Get a list of all custom providers in the database
    """
    # `cursor` and `limit` are forwarded unchanged to the provider manager.
    manager = server.provider_manager
    try:
        return manager.list_providers(cursor=cursor, limit=limit)
    except HTTPException:
        # Already an HTTP error: let it through untouched.
        raise
    except Exception as exc:
        # Any other failure is reported to the client as a 500.
        raise HTTPException(status_code=500, detail=f"{exc}")
@router.post("/", tags=["admin"], response_model=Provider, operation_id="create_provider")
def create_provider(
    request: ProviderCreate = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Create a new custom provider
    """
    # Promote the create payload to a full Provider schema, then persist it
    # and return the stored result.
    payload = request.model_dump()
    return server.provider_manager.create_provider(Provider(**payload))
@router.put("/", tags=["admin"], response_model=Provider, operation_id="update_provider")
def update_provider(
    request: ProviderUpdate = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Update an existing custom provider
    """
    # The manager applies the update and hands back the refreshed provider.
    return server.provider_manager.update_provider(request)
@router.delete("/", tags=["admin"], response_model=Provider, operation_id="delete_provider")
def delete_provider(
    provider_id: str = Query(..., description="The provider_id key to be deleted."),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Delete an existing custom provider
    """
    # Responds with the provider as it was just before deletion.
    # Raises HTTPException 404 when the lookup yields None, and 500 for any
    # other failure (with the original exception chained as the cause).
    try:
        # Fetch first so there is something to return after the row is gone.
        provider = server.provider_manager.get_provider_by_id(provider_id=provider_id)
        if provider is None:
            raise HTTPException(status_code=404, detail="Provider does not exist")
        server.provider_manager.delete_provider_by_id(provider_id=provider_id)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}") from e
    # Previously returned the undefined name `user`, which raised NameError
    # after the provider had already been deleted.
    return provider

View File

@ -66,6 +66,7 @@ from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager from letta.services.passage_manager import PassageManager
from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_execution_sandbox import ToolExecutionSandbox
@ -290,6 +291,7 @@ class SyncServer(Server):
self.message_manager = MessageManager() self.message_manager = MessageManager()
self.job_manager = JobManager() self.job_manager = JobManager()
self.agent_manager = AgentManager() self.agent_manager = AgentManager()
self.provider_manager = ProviderManager()
# Managers that interface with parallelism # Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager() self.per_agent_lock_manager = PerAgentLockManager()
@ -1030,7 +1032,7 @@ class SyncServer(Server):
"""List available models""" """List available models"""
llm_models = [] llm_models = []
for provider in self._enabled_providers: for provider in self.get_enabled_providers():
try: try:
llm_models.extend(provider.list_llm_models()) llm_models.extend(provider.list_llm_models())
except Exception as e: except Exception as e:
@ -1040,13 +1042,19 @@ class SyncServer(Server):
def list_embedding_models(self) -> List[EmbeddingConfig]: def list_embedding_models(self) -> List[EmbeddingConfig]:
"""List available embedding models""" """List available embedding models"""
embedding_models = [] embedding_models = []
for provider in self._enabled_providers: for provider in self.get_enabled_providers():
try: try:
embedding_models.extend(provider.list_embedding_models()) embedding_models.extend(provider.list_embedding_models())
except Exception as e: except Exception as e:
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models return embedding_models
def get_enabled_providers(self):
    """Return the enabled providers, merged by name from the environment and the database.

    When both sources define a provider with the same name, the database
    entry replaces the environment one. The result is a dict values view.
    """
    merged = {env_provider.name: env_provider for env_provider in self._enabled_providers}
    # Database entries are applied second so they win on name conflicts.
    merged.update((db_provider.name, db_provider) for db_provider in self.provider_manager.list_providers())
    return merged.values()
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig: def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
provider_name, model_name = handle.split("/", 1) provider_name, model_name = handle.split("/", 1)
provider = self.get_provider_from_name(provider_name) provider = self.get_provider_from_name(provider_name)

View File

@ -0,0 +1,63 @@
from typing import List, Optional
from letta.orm.provider import Provider as ProviderModel
from letta.providers import Provider as PydanticProvider
from letta.providers import ProviderUpdate
from letta.utils import enforce_types
class ProviderManager:
    """Persistence layer for custom LLM providers (name plus API key) stored in the database."""

    def __init__(self):
        # Imported here rather than at module level, presumably to avoid a
        # circular import with letta.server.server.
        from letta.server.server import db_context

        self.session_maker = db_context

    @enforce_types
    def create_provider(self, provider: PydanticProvider) -> PydanticProvider:
        """Persist a new provider and return it as a pydantic object.

        No existence check is performed here; any uniqueness enforcement is
        left to the ORM/database layer.
        """
        with self.session_maker() as session:
            new_provider = ProviderModel(**provider.model_dump())
            new_provider.create(session)
            return new_provider.to_pydantic()

    @enforce_types
    def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
        """Apply the set, non-None fields of ``provider_update`` to the stored provider and return it."""
        with self.session_maker() as session:
            # Retrieve the existing provider by ID
            existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)

            # Copy only the fields that were actually provided. `id` is the
            # lookup key, not something to change, so it is left out.
            update_data = provider_update.model_dump(exclude={"id"}, exclude_unset=True, exclude_none=True)
            for key, value in update_data.items():
                setattr(existing_provider, key, value)

            # Commit the updated provider
            existing_provider.update(session)
            return existing_provider.to_pydantic()

    @enforce_types
    def get_provider_by_id(self, provider_id: str) -> Optional[PydanticProvider]:
        """Fetch a single provider by id.

        Added because the providers REST router calls this method before
        deleting a provider.
        """
        with self.session_maker() as session:
            provider = ProviderModel.read(db_session=session, identifier=provider_id)
            # NOTE(review): `read` may raise instead of returning None for an
            # unknown id - confirm against SqlalchemyBase.read.
            return provider.to_pydantic() if provider is not None else None

    @enforce_types
    def delete_provider_by_id(self, provider_id: str):
        """Hard-delete the provider with the given id."""
        with self.session_maker() as session:
            # Delete from provider table
            provider = ProviderModel.read(db_session=session, identifier=provider_id)
            provider.hard_delete(session)
            session.commit()

    @enforce_types
    def list_providers(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
        """List providers with pagination using cursor (id) and limit."""
        with self.session_maker() as session:
            results = ProviderModel.list(db_session=session, cursor=cursor, limit=limit)
            return [provider.to_pydantic() for provider in results]

    @enforce_types
    def get_anthropic_key_override(self) -> Optional[str]:
        """Helper function to fetch custom anthropic key for v0 BYOK feature.

        Returns the API key of the first listed provider named "anthropic",
        or None when no such provider is found.
        """
        # Scan the listed page instead of only the first row, so a stored
        # anthropic key is still found when another provider is listed first.
        # NOTE(review): the lookup is not scoped to an organization - confirm
        # this is acceptable for multi-tenant deployments.
        providers = self.list_providers()
        return next((p.api_key for p in providers if p.name == "anthropic"), None)