feat: add privileged organizations (#1365)

This commit is contained in:
Sarah Wooders 2025-03-24 16:12:53 -07:00 committed by GitHub
parent 1d552bde55
commit e1ec91baff
7 changed files with 93 additions and 2 deletions

View File

@ -0,0 +1,39 @@
"""add privileged_tools to Organization
Revision ID: bdddd421ec41
Revises: 1e553a664210
Create Date: 2025-03-21 17:55:30.405519
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "bdddd421ec41"
down_revision: Union[str, None] = "1e553a664210"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Step 1: Add `privileged_tools` column with nullable=True
op.add_column("organizations", sa.Column("privileged_tools", sa.Boolean(), nullable=True))
# fill in column with `False`
op.execute(
f"""
UPDATE organizations
SET privileged_tools = False
"""
)
# Step 2: Make `privileged_tools` non-nullable
op.alter_column("organizations", "privileged_tools", nullable=False)
def downgrade() -> None:
op.drop_column("organizations", "privileged_tools")

View File

@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):
__pydantic_model__ = PydanticOrganization
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
privileged_tools: Mapped[bool] = mapped_column(doc="Whether the organization has access to privileged tools.")
# relationships
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")

View File

@ -16,7 +16,14 @@ class Organization(OrganizationBase):
id: str = OrganizationBase.generate_id_field()
name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
privileged_tools: bool = Field(False, description="Whether the organization has access to privileged tools.")
class OrganizationCreate(OrganizationBase):
name: Optional[str] = Field(None, description="The name of the organization.")
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")
class OrganizationUpdate(OrganizationBase):
name: Optional[str] = Field(None, description="The name of the organization.")
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.organization import Organization, OrganizationCreate, OrganizationUpdate
from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING:
@ -59,3 +59,21 @@ def delete_org(
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return org
@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization")
def update_org(
org_id: str = Query(..., description="The org_id key to be updated."),
request: OrganizationUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
try:
org = server.organization_manager.get_organization_by_id(org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist")
org = server.organization_manager.update_organization(org_id=org_id, name=request.name)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return org

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.utils import enforce_types
@ -63,6 +64,18 @@ class OrganizationManager:
org.update(session)
return org.to_pydantic()
@enforce_types
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
if org_update.name:
org.name = org_update.name
if org_update.privileged_tools:
org.privileged_tools = org_update.privileged_tools
org.update(session)
return org.to_pydantic()
@enforce_types
def delete_organization_by_id(self, org_id: str):
"""Delete an organization by marking it as deleted."""

View File

@ -23,6 +23,7 @@ from letta.services.helpers.tool_execution_helper import (
find_python_executable,
install_pip_requirements_for_sandbox,
)
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.settings import tool_settings
@ -50,6 +51,9 @@ class ToolExecutionSandbox:
self.tool_name = tool_name
self.args = args
self.user = user
# get organization
self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id)
self.privileged_tools = self.organization.privileged_tools
# If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None:
@ -79,7 +83,7 @@ class ToolExecutionSandbox:
Returns:
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
"""
if tool_settings.e2b_api_key:
if tool_settings.e2b_api_key and not self.privileged_tools:
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
else:

View File

@ -32,6 +32,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.run import Run as PydanticRun
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
@ -1788,6 +1789,14 @@ def test_update_organization_name(server: SyncServer):
assert org.name == org_name_b
def test_update_organization_privileged_tools(server: SyncServer):
org_name = "test"
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
assert org.privileged_tools == False
org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
assert org.privileged_tools == True
def test_list_organizations_pagination(server: SyncServer):
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))