feat: support custom api keys for cloud (#533)

This commit is contained in:
cthomas 2025-01-07 22:12:55 -08:00 committed by GitHub
parent 9a0613bdad
commit 4a2e321e99
9 changed files with 244 additions and 5 deletions

View File

@ -0,0 +1,47 @@
"""Add providers data to ORM
Revision ID: 915b68780108
Revises: 400501b04bf0
Create Date: 2025-01-07 10:49:04.717058
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "915b68780108"
down_revision: Union[str, None] = "400501b04bf0"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"providers",
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("providers")
# ### end Alembic commands ###

View File

@ -22,6 +22,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.services.provider_manager import ProviderManager
from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
@ -251,9 +252,12 @@ def create(
tool_call = {"type": "function", "function": {"name": force_tool_call}}
assert functions is not None
# load anthropic key from db in case a custom key has been stored
anthropic_key_override = ProviderManager().get_anthropic_key_override()
return anthropic_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.anthropic_api_key,
api_key=anthropic_key_override if anthropic_key_override else model_settings.anthropic_api_key,
data=ChatCompletionRequest(
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],

View File

@ -8,6 +8,7 @@ from letta.orm.job import Job
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents

View File

@ -9,6 +9,7 @@ if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.provider import Provider
from letta.orm.sandbox_config import AgentEnvironmentVariable
from letta.orm.tool import Tool
from letta.orm.user import User
@ -45,6 +46,7 @@ class Organization(SqlalchemyBase):
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
)
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
@property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:

23
letta/orm/provider.py Normal file
View File

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.providers import Provider as PydanticProvider
if TYPE_CHECKING:
from letta.orm.organization import Organization
class Provider(SqlalchemyBase, OrganizationMixin):
"""Provider ORM class"""
__tablename__ = "providers"
__pydantic_model__ = PydanticProvider
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")

View File

@ -1,16 +1,24 @@
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
from pydantic import Field, model_validator
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.services.organization_manager import OrganizationManager
class Provider(BaseModel):
class ProviderBase(LettaBase):
__id_prefix__ = "provider"
class Provider(ProviderBase):
name: str = Field(..., description="The name of the provider")
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
def list_llm_models(self) -> List[LLMConfig]:
return []
@ -29,6 +37,17 @@ class Provider(BaseModel):
return f"{self.name}/{model_name}"
class ProviderCreate(ProviderBase):
name: str = Field(..., description="The name of the provider.")
api_key: str = Field(..., description="API key used for requests to the provider.")
organization_id: str = Field(..., description="The organization id that this provider information pertains to.")
class ProviderUpdate(ProviderBase):
id: str = Field(..., description="The id of the provider to update.")
api_key: str = Field(..., description="API key used for requests to the provider.")
class LettaProvider(Provider):
name: str = "letta"

View File

@ -0,0 +1,72 @@
from fastapi import APIRouter, Depends
from letta.providers import Provider, ProviderCreate, ProviderUpdate
from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING:
from letta.server.server import SyncServer
router = APIRouter(prefix="/providers", tags=["providers", "admin"])
@router.get("/", tags=["admin"], response_model=List[Provider], operation_id="list_providers")
def list_providers(
cursor: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Get a list of all custom providers in the database
"""
try:
providers = server.provider_manager.list_providers(cursor=cursor, limit=limit)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return providers
@router.post("/", tags=["admin"], response_model=Provider, operation_id="create_provider")
def create_provider(
request: ProviderCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Create a new custom provider
"""
provider = Provider(**request.model_dump())
provider = server.provider_manager.create_provider(provider)
return provider
@router.put("/", tags=["admin"], response_model=Provider, operation_id="update_provider")
def update_provider(
request: ProviderUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Update an existing custom provider
"""
provider = server.provider_manager.update_provider(request)
return provider
@router.delete("/", tags=["admin"], response_model=Provider, operation_id="delete_provider")
def delete_provider(
provider_id: str = Query(..., description="The provider_id key to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Delete an existing custom provider
"""
try:
provider = server.provider_manager.get_provider_by_id(provider_id=provider_id)
if provider is None:
raise HTTPException(status_code=404, detail=f"Provider does not exist")
server.provider_manager.delete_provider_by_id(provider_id=provider_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return user

View File

@ -66,6 +66,7 @@ from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.provider_manager import ProviderManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
@ -290,6 +291,7 @@ class SyncServer(Server):
self.message_manager = MessageManager()
self.job_manager = JobManager()
self.agent_manager = AgentManager()
self.provider_manager = ProviderManager()
# Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager()
@ -1030,7 +1032,7 @@ class SyncServer(Server):
"""List available models"""
llm_models = []
for provider in self._enabled_providers:
for provider in self.get_enabled_providers():
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
@ -1040,13 +1042,19 @@ class SyncServer(Server):
def list_embedding_models(self) -> List[EmbeddingConfig]:
"""List available embedding models"""
embedding_models = []
for provider in self._enabled_providers:
for provider in self.get_enabled_providers():
try:
embedding_models.extend(provider.list_embedding_models())
except Exception as e:
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
def get_enabled_providers(self):
providers_from_env = {p.name: p for p in self._enabled_providers}
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
return {**providers_from_env, **providers_from_db}.values()
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
provider_name, model_name = handle.split("/", 1)
provider = self.get_provider_from_name(provider_name)

View File

@ -0,0 +1,63 @@
from typing import List, Optional
from letta.orm.provider import Provider as ProviderModel
from letta.providers import Provider as PydanticProvider
from letta.providers import ProviderUpdate
from letta.utils import enforce_types
class ProviderManager:
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
def create_provider(self, provider: PydanticProvider) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with self.session_maker() as session:
new_provider = ProviderModel(**provider.model_dump())
new_provider.create(session)
return new_provider.to_pydantic()
@enforce_types
def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
"""Update provider details."""
with self.session_maker() as session:
# Retrieve the existing provider by ID
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
# Update only the fields that are provided in ProviderUpdate
update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(existing_provider, key, value)
# Commit the updated provider
existing_provider.update(session)
return existing_provider.to_pydantic()
@enforce_types
def delete_provider_by_id(self, provider_id: str):
"""Delete a provider."""
with self.session_maker() as session:
# Delete from provider table
provider = ProviderModel.read(db_session=session, identifier=provider_id)
provider.hard_delete(session)
session.commit()
@enforce_types
def list_providers(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
"""List providers with pagination using cursor (id) and limit."""
with self.session_maker() as session:
results = ProviderModel.list(db_session=session, cursor=cursor, limit=limit)
return [provider.to_pydantic() for provider in results]
@enforce_types
def get_anthropic_key_override(self) -> Optional[str]:
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
providers = self.list_providers(limit=1)
if len(providers) == 1 and providers[0].name == "anthropic":
return providers[0].api_key
return None