mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add privileged organizations (#1365)
This commit is contained in:
parent
1d552bde55
commit
e1ec91baff
@ -0,0 +1,39 @@
|
||||
"""add privileged_tools to Organization
|
||||
|
||||
Revision ID: bdddd421ec41
|
||||
Revises: 1e553a664210
|
||||
Create Date: 2025-03-21 17:55:30.405519
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "bdddd421ec41"
|
||||
down_revision: Union[str, None] = "1e553a664210"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Add `privileged_tools` column with nullable=True
|
||||
op.add_column("organizations", sa.Column("privileged_tools", sa.Boolean(), nullable=True))
|
||||
|
||||
# fill in column with `False`
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE organizations
|
||||
SET privileged_tools = False
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 2: Make `privileged_tools` non-nullable
|
||||
op.alter_column("organizations", "privileged_tools", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("organizations", "privileged_tools")
|
@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):
|
||||
__pydantic_model__ = PydanticOrganization
|
||||
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
||||
privileged_tools: Mapped[bool] = mapped_column(doc="Whether the organization has access to privileged tools.")
|
||||
|
||||
# relationships
|
||||
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
@ -16,7 +16,14 @@ class Organization(OrganizationBase):
|
||||
id: str = OrganizationBase.generate_id_field()
|
||||
name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
|
||||
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
|
||||
privileged_tools: bool = Field(False, description="Whether the organization has access to privileged tools.")
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
name: Optional[str] = Field(None, description="The name of the organization.")
|
||||
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")
|
||||
|
||||
|
||||
class OrganizationUpdate(OrganizationBase):
|
||||
name: Optional[str] = Field(None, description="The name of the organization.")
|
||||
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")
|
||||
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
|
||||
from letta.schemas.organization import Organization, OrganizationCreate
|
||||
from letta.schemas.organization import Organization, OrganizationCreate, OrganizationUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -59,3 +59,21 @@ def delete_org(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return org
|
||||
|
||||
|
||||
@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization")
|
||||
def update_org(
|
||||
org_id: str = Query(..., description="The org_id key to be updated."),
|
||||
request: OrganizationUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
try:
|
||||
org = server.organization_manager.get_organization_by_id(org_id=org_id)
|
||||
if org is None:
|
||||
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||
org = server.organization_manager.update_organization(org_id=org_id, name=request.name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return org
|
||||
|
@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.organization import OrganizationUpdate
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
@ -63,6 +64,18 @@ class OrganizationManager:
|
||||
org.update(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
|
||||
"""Update an organization."""
|
||||
with self.session_maker() as session:
|
||||
org = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
if org_update.name:
|
||||
org.name = org_update.name
|
||||
if org_update.privileged_tools:
|
||||
org.privileged_tools = org_update.privileged_tools
|
||||
org.update(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_organization_by_id(self, org_id: str):
|
||||
"""Delete an organization by marking it as deleted."""
|
||||
|
@ -23,6 +23,7 @@ from letta.services.helpers.tool_execution_helper import (
|
||||
find_python_executable,
|
||||
install_pip_requirements_for_sandbox,
|
||||
)
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import tool_settings
|
||||
@ -50,6 +51,9 @@ class ToolExecutionSandbox:
|
||||
self.tool_name = tool_name
|
||||
self.args = args
|
||||
self.user = user
|
||||
# get organization
|
||||
self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id)
|
||||
self.privileged_tools = self.organization.privileged_tools
|
||||
|
||||
# If a tool object is provided, we use it directly, otherwise pull via name
|
||||
if tool_object is not None:
|
||||
@ -79,7 +83,7 @@ class ToolExecutionSandbox:
|
||||
Returns:
|
||||
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
|
||||
"""
|
||||
if tool_settings.e2b_api_key:
|
||||
if tool_settings.e2b_api_key and not self.privileged_tools:
|
||||
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
|
||||
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
|
||||
else:
|
||||
|
@ -32,6 +32,7 @@ from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.organization import OrganizationUpdate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
@ -1788,6 +1789,14 @@ def test_update_organization_name(server: SyncServer):
|
||||
assert org.name == org_name_b
|
||||
|
||||
|
||||
def test_update_organization_privileged_tools(server: SyncServer):
|
||||
org_name = "test"
|
||||
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
|
||||
assert org.privileged_tools == False
|
||||
org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
|
||||
assert org.privileged_tools == True
|
||||
|
||||
|
||||
def test_list_organizations_pagination(server: SyncServer):
|
||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
|
||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
|
||||
|
Loading…
Reference in New Issue
Block a user