mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: support custom api keys for cloud (#533)
This commit is contained in:
parent
9a0613bdad
commit
4a2e321e99
47
alembic/versions/915b68780108_add_providers_data_to_orm.py
Normal file
47
alembic/versions/915b68780108_add_providers_data_to_orm.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Add providers data to ORM
|
||||
|
||||
Revision ID: 915b68780108
|
||||
Revises: 400501b04bf0
|
||||
Create Date: 2025-01-07 10:49:04.717058
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "915b68780108"
|
||||
down_revision: Union[str, None] = "400501b04bf0"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"providers",
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("providers")
|
||||
# ### end Alembic commands ###
|
@ -22,6 +22,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.settings import ModelSettings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
|
||||
@ -251,9 +252,12 @@ def create(
|
||||
tool_call = {"type": "function", "function": {"name": force_tool_call}}
|
||||
assert functions is not None
|
||||
|
||||
# load anthropic key from db in case a custom key has been stored
|
||||
anthropic_key_override = ProviderManager().get_anthropic_key_override()
|
||||
|
||||
return anthropic_chat_completions_request(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
api_key=anthropic_key_override if anthropic_key_override else model_settings.anthropic_api_key,
|
||||
data=ChatCompletionRequest(
|
||||
model=llm_config.model,
|
||||
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
||||
|
@ -8,6 +8,7 @@ from letta.orm.job import Job
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||
from letta.orm.provider import Provider
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
|
@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.provider import Provider
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
@ -45,6 +46,7 @@ class Organization(SqlalchemyBase):
|
||||
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
||||
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@property
|
||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||
|
23
letta/orm/provider.py
Normal file
23
letta/orm/provider.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.providers import Provider as PydanticProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
|
||||
class Provider(SqlalchemyBase, OrganizationMixin):
|
||||
"""Provider ORM class"""
|
||||
|
||||
__tablename__ = "providers"
|
||||
__pydantic_model__ = PydanticProvider
|
||||
|
||||
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
|
||||
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
|
@ -1,16 +1,24 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
class ProviderBase(LettaBase):
|
||||
__id_prefix__ = "provider"
|
||||
|
||||
|
||||
class Provider(ProviderBase):
|
||||
name: str = Field(..., description="The name of the provider")
|
||||
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
|
||||
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
return []
|
||||
@ -29,6 +37,17 @@ class Provider(BaseModel):
|
||||
return f"{self.name}/{model_name}"
|
||||
|
||||
|
||||
class ProviderCreate(ProviderBase):
|
||||
name: str = Field(..., description="The name of the provider.")
|
||||
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||
organization_id: str = Field(..., description="The organization id that this provider information pertains to.")
|
||||
|
||||
|
||||
class ProviderUpdate(ProviderBase):
|
||||
id: str = Field(..., description="The id of the provider to update.")
|
||||
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||
|
||||
|
||||
class LettaProvider(Provider):
|
||||
|
||||
name: str = "letta"
|
||||
|
72
letta/server/rest_api/routers/v1/providers.py
Normal file
72
letta/server/rest_api/routers/v1/providers.py
Normal file
@ -0,0 +1,72 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from letta.providers import Provider, ProviderCreate, ProviderUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/providers", tags=["providers", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[Provider], operation_id="list_providers")
|
||||
def list_providers(
|
||||
cursor: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all custom providers in the database
|
||||
"""
|
||||
try:
|
||||
providers = server.provider_manager.list_providers(cursor=cursor, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return providers
|
||||
|
||||
|
||||
@router.post("/", tags=["admin"], response_model=Provider, operation_id="create_provider")
|
||||
def create_provider(
|
||||
request: ProviderCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Create a new custom provider
|
||||
"""
|
||||
provider = Provider(**request.model_dump())
|
||||
provider = server.provider_manager.create_provider(provider)
|
||||
return provider
|
||||
|
||||
|
||||
@router.put("/", tags=["admin"], response_model=Provider, operation_id="update_provider")
|
||||
def update_provider(
|
||||
request: ProviderUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update an existing custom provider
|
||||
"""
|
||||
provider = server.provider_manager.update_provider(request)
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete("/", tags=["admin"], response_model=Provider, operation_id="delete_provider")
|
||||
def delete_provider(
|
||||
provider_id: str = Query(..., description="The provider_id key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete an existing custom provider
|
||||
"""
|
||||
try:
|
||||
provider = server.provider_manager.get_provider_by_id(provider_id=provider_id)
|
||||
if provider is None:
|
||||
raise HTTPException(status_code=404, detail=f"Provider does not exist")
|
||||
server.provider_manager.delete_provider_by_id(provider_id=provider_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return user
|
@ -66,6 +66,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
@ -290,6 +291,7 @@ class SyncServer(Server):
|
||||
self.message_manager = MessageManager()
|
||||
self.job_manager = JobManager()
|
||||
self.agent_manager = AgentManager()
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
# Managers that interface with parallelism
|
||||
self.per_agent_lock_manager = PerAgentLockManager()
|
||||
@ -1030,7 +1032,7 @@ class SyncServer(Server):
|
||||
"""List available models"""
|
||||
|
||||
llm_models = []
|
||||
for provider in self._enabled_providers:
|
||||
for provider in self.get_enabled_providers():
|
||||
try:
|
||||
llm_models.extend(provider.list_llm_models())
|
||||
except Exception as e:
|
||||
@ -1040,13 +1042,19 @@ class SyncServer(Server):
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
"""List available embedding models"""
|
||||
embedding_models = []
|
||||
for provider in self._enabled_providers:
|
||||
for provider in self.get_enabled_providers():
|
||||
try:
|
||||
embedding_models.extend(provider.list_embedding_models())
|
||||
except Exception as e:
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
def get_enabled_providers(self):
|
||||
providers_from_env = {p.name: p for p in self._enabled_providers}
|
||||
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
|
||||
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
|
||||
return {**providers_from_env, **providers_from_db}.values()
|
||||
|
||||
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = self.get_provider_from_name(provider_name)
|
||||
|
63
letta/services/provider_manager.py
Normal file
63
letta/services/provider_manager.py
Normal file
@ -0,0 +1,63 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.orm.provider import Provider as ProviderModel
|
||||
from letta.providers import Provider as PydanticProvider
|
||||
from letta.providers import ProviderUpdate
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_provider(self, provider: PydanticProvider) -> PydanticProvider:
|
||||
"""Create a new provider if it doesn't already exist."""
|
||||
with self.session_maker() as session:
|
||||
new_provider = ProviderModel(**provider.model_dump())
|
||||
new_provider.create(session)
|
||||
return new_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
|
||||
"""Update provider details."""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the existing provider by ID
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
|
||||
|
||||
# Update only the fields that are provided in ProviderUpdate
|
||||
update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(existing_provider, key, value)
|
||||
|
||||
# Commit the updated provider
|
||||
existing_provider.update(session)
|
||||
return existing_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_provider_by_id(self, provider_id: str):
|
||||
"""Delete a provider."""
|
||||
with self.session_maker() as session:
|
||||
# Delete from provider table
|
||||
provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||
provider.hard_delete(session)
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_providers(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
|
||||
"""List providers with pagination using cursor (id) and limit."""
|
||||
with self.session_maker() as session:
|
||||
results = ProviderModel.list(db_session=session, cursor=cursor, limit=limit)
|
||||
return [provider.to_pydantic() for provider in results]
|
||||
|
||||
@enforce_types
|
||||
def get_anthropic_key_override(self) -> Optional[str]:
|
||||
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||
providers = self.list_providers(limit=1)
|
||||
if len(providers) == 1 and providers[0].name == "anthropic":
|
||||
return providers[0].api_key
|
||||
return None
|
Loading…
Reference in New Issue
Block a user