mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: support custom api keys for cloud (#533)
This commit is contained in:
parent
9a0613bdad
commit
4a2e321e99
47
alembic/versions/915b68780108_add_providers_data_to_orm.py
Normal file
47
alembic/versions/915b68780108_add_providers_data_to_orm.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Add providers data to ORM
|
||||||
|
|
||||||
|
Revision ID: 915b68780108
|
||||||
|
Revises: 400501b04bf0
|
||||||
|
Create Date: 2025-01-07 10:49:04.717058
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "915b68780108"
|
||||||
|
down_revision: Union[str, None] = "400501b04bf0"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"providers",
|
||||||
|
sa.Column("name", sa.String(), nullable=False),
|
||||||
|
sa.Column("api_key", sa.String(), nullable=True),
|
||||||
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
|
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||||
|
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("organization_id", sa.String(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["organization_id"],
|
||||||
|
["organizations.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("providers")
|
||||||
|
# ### end Alembic commands ###
|
@ -22,6 +22,7 @@ from letta.schemas.llm_config import LLMConfig
|
|||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||||
|
from letta.services.provider_manager import ProviderManager
|
||||||
from letta.settings import ModelSettings
|
from letta.settings import ModelSettings
|
||||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||||
|
|
||||||
@ -251,9 +252,12 @@ def create(
|
|||||||
tool_call = {"type": "function", "function": {"name": force_tool_call}}
|
tool_call = {"type": "function", "function": {"name": force_tool_call}}
|
||||||
assert functions is not None
|
assert functions is not None
|
||||||
|
|
||||||
|
# load anthropic key from db in case a custom key has been stored
|
||||||
|
anthropic_key_override = ProviderManager().get_anthropic_key_override()
|
||||||
|
|
||||||
return anthropic_chat_completions_request(
|
return anthropic_chat_completions_request(
|
||||||
url=llm_config.model_endpoint,
|
url=llm_config.model_endpoint,
|
||||||
api_key=model_settings.anthropic_api_key,
|
api_key=anthropic_key_override if anthropic_key_override else model_settings.anthropic_api_key,
|
||||||
data=ChatCompletionRequest(
|
data=ChatCompletionRequest(
|
||||||
model=llm_config.model,
|
model=llm_config.model,
|
||||||
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
||||||
|
@ -8,6 +8,7 @@ from letta.orm.job import Job
|
|||||||
from letta.orm.message import Message
|
from letta.orm.message import Message
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||||
|
from letta.orm.provider import Provider
|
||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||||
from letta.orm.source import Source
|
from letta.orm.source import Source
|
||||||
from letta.orm.sources_agents import SourcesAgents
|
from letta.orm.sources_agents import SourcesAgents
|
||||||
|
@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from letta.orm.agent import Agent
|
from letta.orm.agent import Agent
|
||||||
from letta.orm.file import FileMetadata
|
from letta.orm.file import FileMetadata
|
||||||
|
from letta.orm.provider import Provider
|
||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||||
from letta.orm.tool import Tool
|
from letta.orm.tool import Tool
|
||||||
from letta.orm.user import User
|
from letta.orm.user import User
|
||||||
@ -45,6 +46,7 @@ class Organization(SqlalchemyBase):
|
|||||||
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
|
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
||||||
|
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||||
|
23
letta/orm/provider.py
Normal file
23
letta/orm/provider.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from letta.orm.mixins import OrganizationMixin
|
||||||
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
|
from letta.providers import Provider as PydanticProvider
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from letta.orm.organization import Organization
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(SqlalchemyBase, OrganizationMixin):
|
||||||
|
"""Provider ORM class"""
|
||||||
|
|
||||||
|
__tablename__ = "providers"
|
||||||
|
__pydantic_model__ = PydanticProvider
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
|
||||||
|
api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
|
@ -1,16 +1,24 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
|
|
||||||
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||||
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
||||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
|
from letta.schemas.letta_base import LettaBase
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
|
from letta.services.organization_manager import OrganizationManager
|
||||||
|
|
||||||
|
|
||||||
class Provider(BaseModel):
|
class ProviderBase(LettaBase):
|
||||||
|
__id_prefix__ = "provider"
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(ProviderBase):
|
||||||
name: str = Field(..., description="The name of the provider")
|
name: str = Field(..., description="The name of the provider")
|
||||||
|
api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
|
||||||
|
organization_id: Optional[str] = Field(OrganizationManager.DEFAULT_ORG_ID, description="The organization id of the user")
|
||||||
|
|
||||||
def list_llm_models(self) -> List[LLMConfig]:
|
def list_llm_models(self) -> List[LLMConfig]:
|
||||||
return []
|
return []
|
||||||
@ -29,6 +37,17 @@ class Provider(BaseModel):
|
|||||||
return f"{self.name}/{model_name}"
|
return f"{self.name}/{model_name}"
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCreate(ProviderBase):
|
||||||
|
name: str = Field(..., description="The name of the provider.")
|
||||||
|
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||||
|
organization_id: str = Field(..., description="The organization id that this provider information pertains to.")
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderUpdate(ProviderBase):
|
||||||
|
id: str = Field(..., description="The id of the provider to update.")
|
||||||
|
api_key: str = Field(..., description="API key used for requests to the provider.")
|
||||||
|
|
||||||
|
|
||||||
class LettaProvider(Provider):
|
class LettaProvider(Provider):
|
||||||
|
|
||||||
name: str = "letta"
|
name: str = "letta"
|
||||||
|
72
letta/server/rest_api/routers/v1/providers.py
Normal file
72
letta/server/rest_api/routers/v1/providers.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from letta.providers import Provider, ProviderCreate, ProviderUpdate
|
||||||
|
from letta.server.rest_api.utils import get_letta_server
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from letta.server.server import SyncServer
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/providers", tags=["providers", "admin"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", tags=["admin"], response_model=List[Provider], operation_id="list_providers")
|
||||||
|
def list_providers(
|
||||||
|
cursor: Optional[str] = Query(None),
|
||||||
|
limit: Optional[int] = Query(50),
|
||||||
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a list of all custom providers in the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
providers = server.provider_manager.list_providers(cursor=cursor, limit=limit)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
|
return providers
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", tags=["admin"], response_model=Provider, operation_id="create_provider")
|
||||||
|
def create_provider(
|
||||||
|
request: ProviderCreate = Body(...),
|
||||||
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new custom provider
|
||||||
|
"""
|
||||||
|
provider = Provider(**request.model_dump())
|
||||||
|
provider = server.provider_manager.create_provider(provider)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/", tags=["admin"], response_model=Provider, operation_id="update_provider")
|
||||||
|
def update_provider(
|
||||||
|
request: ProviderUpdate = Body(...),
|
||||||
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update an existing custom provider
|
||||||
|
"""
|
||||||
|
provider = server.provider_manager.update_provider(request)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/", tags=["admin"], response_model=Provider, operation_id="delete_provider")
|
||||||
|
def delete_provider(
|
||||||
|
provider_id: str = Query(..., description="The provider_id key to be deleted."),
|
||||||
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete an existing custom provider
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
provider = server.provider_manager.get_provider_by_id(provider_id=provider_id)
|
||||||
|
if provider is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Provider does not exist")
|
||||||
|
server.provider_manager.delete_provider_by_id(provider_id=provider_id)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
|
return user
|
@ -66,6 +66,7 @@ from letta.services.message_manager import MessageManager
|
|||||||
from letta.services.organization_manager import OrganizationManager
|
from letta.services.organization_manager import OrganizationManager
|
||||||
from letta.services.passage_manager import PassageManager
|
from letta.services.passage_manager import PassageManager
|
||||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||||
|
from letta.services.provider_manager import ProviderManager
|
||||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||||
from letta.services.source_manager import SourceManager
|
from letta.services.source_manager import SourceManager
|
||||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||||
@ -290,6 +291,7 @@ class SyncServer(Server):
|
|||||||
self.message_manager = MessageManager()
|
self.message_manager = MessageManager()
|
||||||
self.job_manager = JobManager()
|
self.job_manager = JobManager()
|
||||||
self.agent_manager = AgentManager()
|
self.agent_manager = AgentManager()
|
||||||
|
self.provider_manager = ProviderManager()
|
||||||
|
|
||||||
# Managers that interface with parallelism
|
# Managers that interface with parallelism
|
||||||
self.per_agent_lock_manager = PerAgentLockManager()
|
self.per_agent_lock_manager = PerAgentLockManager()
|
||||||
@ -1030,7 +1032,7 @@ class SyncServer(Server):
|
|||||||
"""List available models"""
|
"""List available models"""
|
||||||
|
|
||||||
llm_models = []
|
llm_models = []
|
||||||
for provider in self._enabled_providers:
|
for provider in self.get_enabled_providers():
|
||||||
try:
|
try:
|
||||||
llm_models.extend(provider.list_llm_models())
|
llm_models.extend(provider.list_llm_models())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1040,13 +1042,19 @@ class SyncServer(Server):
|
|||||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||||
"""List available embedding models"""
|
"""List available embedding models"""
|
||||||
embedding_models = []
|
embedding_models = []
|
||||||
for provider in self._enabled_providers:
|
for provider in self.get_enabled_providers():
|
||||||
try:
|
try:
|
||||||
embedding_models.extend(provider.list_embedding_models())
|
embedding_models.extend(provider.list_embedding_models())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||||
return embedding_models
|
return embedding_models
|
||||||
|
|
||||||
|
def get_enabled_providers(self):
|
||||||
|
providers_from_env = {p.name: p for p in self._enabled_providers}
|
||||||
|
providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
|
||||||
|
# Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
|
||||||
|
return {**providers_from_env, **providers_from_db}.values()
|
||||||
|
|
||||||
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
|
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
|
||||||
provider_name, model_name = handle.split("/", 1)
|
provider_name, model_name = handle.split("/", 1)
|
||||||
provider = self.get_provider_from_name(provider_name)
|
provider = self.get_provider_from_name(provider_name)
|
||||||
|
63
letta/services/provider_manager.py
Normal file
63
letta/services/provider_manager.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from letta.orm.provider import Provider as ProviderModel
|
||||||
|
from letta.providers import Provider as PydanticProvider
|
||||||
|
from letta.providers import ProviderUpdate
|
||||||
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderManager:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from letta.server.server import db_context
|
||||||
|
|
||||||
|
self.session_maker = db_context
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def create_provider(self, provider: PydanticProvider) -> PydanticProvider:
|
||||||
|
"""Create a new provider if it doesn't already exist."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
new_provider = ProviderModel(**provider.model_dump())
|
||||||
|
new_provider.create(session)
|
||||||
|
return new_provider.to_pydantic()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
|
||||||
|
"""Update provider details."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
# Retrieve the existing provider by ID
|
||||||
|
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
|
||||||
|
|
||||||
|
# Update only the fields that are provided in ProviderUpdate
|
||||||
|
update_data = provider_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(existing_provider, key, value)
|
||||||
|
|
||||||
|
# Commit the updated provider
|
||||||
|
existing_provider.update(session)
|
||||||
|
return existing_provider.to_pydantic()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def delete_provider_by_id(self, provider_id: str):
|
||||||
|
"""Delete a provider."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
# Delete from provider table
|
||||||
|
provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||||
|
provider.hard_delete(session)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_providers(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
|
||||||
|
"""List providers with pagination using cursor (id) and limit."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = ProviderModel.list(db_session=session, cursor=cursor, limit=limit)
|
||||||
|
return [provider.to_pydantic() for provider in results]
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_anthropic_key_override(self) -> Optional[str]:
|
||||||
|
"""Helper function to fetch custom anthropic key for v0 BYOK feature"""
|
||||||
|
providers = self.list_providers(limit=1)
|
||||||
|
if len(providers) == 1 and providers[0].name == "anthropic":
|
||||||
|
return providers[0].api_key
|
||||||
|
return None
|
Loading…
Reference in New Issue
Block a user