feat: add privileged organizations (#1365)

This commit is contained in:
Sarah Wooders 2025-03-24 16:12:53 -07:00 committed by GitHub
parent 1d552bde55
commit e1ec91baff
7 changed files with 93 additions and 2 deletions

View File

@ -0,0 +1,39 @@
"""add privileged_tools to Organization
Revision ID: bdddd421ec41
Revises: 1e553a664210
Create Date: 2025-03-21 17:55:30.405519
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "bdddd421ec41"
down_revision: Union[str, None] = "1e553a664210"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Step 1: Add `privileged_tools` column with nullable=True
op.add_column("organizations", sa.Column("privileged_tools", sa.Boolean(), nullable=True))
# fill in column with `False`
op.execute(
f"""
UPDATE organizations
SET privileged_tools = False
"""
)
# Step 2: Make `privileged_tools` non-nullable
op.alter_column("organizations", "privileged_tools", nullable=False)
def downgrade() -> None:
op.drop_column("organizations", "privileged_tools")

View File

@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):
__pydantic_model__ = PydanticOrganization __pydantic_model__ = PydanticOrganization
name: Mapped[str] = mapped_column(doc="The display name of the organization.") name: Mapped[str] = mapped_column(doc="The display name of the organization.")
privileged_tools: Mapped[bool] = mapped_column(doc="Whether the organization has access to privileged tools.")
# relationships # relationships
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")

View File

@ -16,7 +16,14 @@ class Organization(OrganizationBase):
id: str = OrganizationBase.generate_id_field() id: str = OrganizationBase.generate_id_field()
name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"}) name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.") created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
privileged_tools: bool = Field(False, description="Whether the organization has access to privileged tools.")
class OrganizationCreate(OrganizationBase): class OrganizationCreate(OrganizationBase):
name: Optional[str] = Field(None, description="The name of the organization.") name: Optional[str] = Field(None, description="The name of the organization.")
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")
class OrganizationUpdate(OrganizationBase):
name: Optional[str] = Field(None, description="The name of the organization.")
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi import APIRouter, Body, Depends, HTTPException, Query
from letta.schemas.organization import Organization, OrganizationCreate from letta.schemas.organization import Organization, OrganizationCreate, OrganizationUpdate
from letta.server.rest_api.utils import get_letta_server from letta.server.rest_api.utils import get_letta_server
if TYPE_CHECKING: if TYPE_CHECKING:
@ -59,3 +59,21 @@ def delete_org(
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return org return org
@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization")
def update_org(
org_id: str = Query(..., description="The org_id key to be updated."),
request: OrganizationUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
try:
org = server.organization_manager.get_organization_by_id(org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist")
org = server.organization_manager.update_organization(org_id=org_id, name=request.name)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return org

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.utils import enforce_types from letta.utils import enforce_types
@ -63,6 +64,18 @@ class OrganizationManager:
org.update(session) org.update(session)
return org.to_pydantic() return org.to_pydantic()
@enforce_types
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
if org_update.name:
org.name = org_update.name
if org_update.privileged_tools:
org.privileged_tools = org_update.privileged_tools
org.update(session)
return org.to_pydantic()
@enforce_types @enforce_types
def delete_organization_by_id(self, org_id: str): def delete_organization_by_id(self, org_id: str):
"""Delete an organization by marking it as deleted.""" """Delete an organization by marking it as deleted."""

View File

@ -23,6 +23,7 @@ from letta.services.helpers.tool_execution_helper import (
find_python_executable, find_python_executable,
install_pip_requirements_for_sandbox, install_pip_requirements_for_sandbox,
) )
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager from letta.services.tool_manager import ToolManager
from letta.settings import tool_settings from letta.settings import tool_settings
@ -50,6 +51,9 @@ class ToolExecutionSandbox:
self.tool_name = tool_name self.tool_name = tool_name
self.args = args self.args = args
self.user = user self.user = user
# get organization
self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id)
self.privileged_tools = self.organization.privileged_tools
# If a tool object is provided, we use it directly, otherwise pull via name # If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None: if tool_object is not None:
@ -79,7 +83,7 @@ class ToolExecutionSandbox:
Returns: Returns:
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state) Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
""" """
if tool_settings.e2b_api_key: if tool_settings.e2b_api_key and not self.privileged_tools:
logger.debug(f"Using e2b sandbox to execute {self.tool_name}") logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars) result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
else: else:

View File

@ -32,6 +32,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.run import Run as PydanticRun from letta.schemas.run import Run as PydanticRun
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
@ -1788,6 +1789,14 @@ def test_update_organization_name(server: SyncServer):
assert org.name == org_name_b assert org.name == org_name_b
def test_update_organization_privileged_tools(server: SyncServer):
org_name = "test"
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
assert org.privileged_tools == False
org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
assert org.privileged_tools == True
def test_list_organizations_pagination(server: SyncServer): def test_list_organizations_pagination(server: SyncServer):
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a")) server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b")) server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))