diff --git a/alembic/versions/bdddd421ec41_add_privileged_tools_to_organization.py b/alembic/versions/bdddd421ec41_add_privileged_tools_to_organization.py new file mode 100644 index 000000000..2d6191ca1 --- /dev/null +++ b/alembic/versions/bdddd421ec41_add_privileged_tools_to_organization.py @@ -0,0 +1,39 @@ +"""add privileged_tools to Organization + +Revision ID: bdddd421ec41 +Revises: 1e553a664210 +Create Date: 2025-03-21 17:55:30.405519 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "bdddd421ec41" +down_revision: Union[str, None] = "1e553a664210" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Step 1: Add `privileged_tools` column with nullable=True + op.add_column("organizations", sa.Column("privileged_tools", sa.Boolean(), nullable=True)) + + # fill in column with `False` + op.execute( + f""" + UPDATE organizations + SET privileged_tools = False + """ + ) + + # Step 2: Make `privileged_tools` non-nullable + op.alter_column("organizations", "privileged_tools", nullable=False) + + +def downgrade() -> None: + op.drop_column("organizations", "privileged_tools") diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 133b77d8c..ebd66f80d 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -23,6 +23,7 @@ class Organization(SqlalchemyBase): __pydantic_model__ = PydanticOrganization name: Mapped[str] = mapped_column(doc="The display name of the organization.") + privileged_tools: Mapped[bool] = mapped_column(doc="Whether the organization has access to privileged tools.") # relationships users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/schemas/organization.py b/letta/schemas/organization.py index e54523727..9af86a143 100644 --- a/letta/schemas/organization.py +++ b/letta/schemas/organization.py @@ -16,7 +16,14 @@ class Organization(OrganizationBase): id: str = OrganizationBase.generate_id_field() name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"}) created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.") + privileged_tools: bool = Field(False, description="Whether the organization has access to privileged tools.") class OrganizationCreate(OrganizationBase): name: Optional[str] = Field(None, description="The name of the organization.") + privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.") + + +class OrganizationUpdate(OrganizationBase): + name: Optional[str] = Field(None, description="The name of the organization.") + privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.") diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index e35bf5bde..dec211877 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, HTTPException, Query -from letta.schemas.organization import Organization, OrganizationCreate +from letta.schemas.organization import Organization, OrganizationCreate, OrganizationUpdate from letta.server.rest_api.utils import get_letta_server if TYPE_CHECKING: @@ -59,3 +59,21 @@ def delete_org( except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") return org + + +@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization") +def update_org( + org_id: str = Query(..., description="The org_id key to be updated."), + request: OrganizationUpdate = Body(...), + server: "SyncServer" = Depends(get_letta_server), +): + try: + org = server.organization_manager.get_organization_by_id(org_id=org_id) + if org is None: + raise HTTPException(status_code=404, detail=f"Organization does not exist") + org = server.organization_manager.update_organization(org_id=org_id, name=request.name) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return org diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 3f47f8a36..1d0b637d7 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -3,6 +3,7 @@ from typing import List, Optional from letta.orm.errors import NoResultFound from letta.orm.organization import Organization as OrganizationModel from letta.schemas.organization import Organization as PydanticOrganization +from letta.schemas.organization import OrganizationUpdate from letta.utils import enforce_types @@ -63,6 +64,18 @@ class OrganizationManager: org.update(session) return org.to_pydantic() + @enforce_types + def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization: + """Update an organization.""" + with self.session_maker() as session: + org = OrganizationModel.read(db_session=session, identifier=org_id) + if org_update.name: + org.name = org_update.name + if org_update.privileged_tools: + org.privileged_tools = org_update.privileged_tools + org.update(session) + return org.to_pydantic() + @enforce_types def delete_organization_by_id(self, org_id: str): """Delete an organization by marking it as deleted.""" diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index 1dc8d3092..5f1a46445 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -23,6 +23,7 @@ from letta.services.helpers.tool_execution_helper import ( find_python_executable, install_pip_requirements_for_sandbox, ) +from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager from letta.settings import tool_settings @@ -50,6 +51,9 @@ class ToolExecutionSandbox: self.tool_name = tool_name self.args = args self.user = user + # get organization + self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id) + self.privileged_tools = self.organization.privileged_tools # If a tool object is provided, we use it directly, otherwise pull via name if tool_object is not None: @@ -79,7 +83,7 @@ class ToolExecutionSandbox: Returns: Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state) """ - if tool_settings.e2b_api_key: + if tool_settings.e2b_api_key and not self.privileged_tools: logger.debug(f"Using e2b sandbox to execute {self.tool_name}") result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars) else: diff --git a/tests/test_managers.py b/tests/test_managers.py index ac1c8693a..3a1a9bbcd 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -32,6 +32,7 @@ from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.organization import Organization as PydanticOrganization +from letta.schemas.organization import OrganizationUpdate from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.run import Run as PydanticRun from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType @@ -1788,6 +1789,14 @@ def test_update_organization_name(server: SyncServer): assert org.name == org_name_b +def test_update_organization_privileged_tools(server: SyncServer): + org_name = "test" + org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name)) + assert org.privileged_tools == False + org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True)) + assert org.privileged_tools == True + + def test_list_organizations_pagination(server: SyncServer): server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a")) server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))