mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add privileged organizations (#1365)
This commit is contained in:
parent
1d552bde55
commit
e1ec91baff
@ -0,0 +1,39 @@
|
|||||||
|
"""add privileged_tools to Organization
|
||||||
|
|
||||||
|
Revision ID: bdddd421ec41
|
||||||
|
Revises: 1e553a664210
|
||||||
|
Create Date: 2025-03-21 17:55:30.405519
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "bdddd421ec41"
|
||||||
|
down_revision: Union[str, None] = "1e553a664210"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Step 1: Add `privileged_tools` column with nullable=True
|
||||||
|
op.add_column("organizations", sa.Column("privileged_tools", sa.Boolean(), nullable=True))
|
||||||
|
|
||||||
|
# fill in column with `False`
|
||||||
|
op.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE organizations
|
||||||
|
SET privileged_tools = False
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Make `privileged_tools` non-nullable
|
||||||
|
op.alter_column("organizations", "privileged_tools", nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("organizations", "privileged_tools")
|
@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):
|
|||||||
__pydantic_model__ = PydanticOrganization
|
__pydantic_model__ = PydanticOrganization
|
||||||
|
|
||||||
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
name: Mapped[str] = mapped_column(doc="The display name of the organization.")
|
||||||
|
privileged_tools: Mapped[bool] = mapped_column(doc="Whether the organization has access to privileged tools.")
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||||
|
@ -16,7 +16,14 @@ class Organization(OrganizationBase):
|
|||||||
id: str = OrganizationBase.generate_id_field()
|
id: str = OrganizationBase.generate_id_field()
|
||||||
name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
|
name: str = Field(create_random_username(), description="The name of the organization.", json_schema_extra={"default": "SincereYogurt"})
|
||||||
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
|
created_at: Optional[datetime] = Field(default_factory=get_utc_time, description="The creation date of the organization.")
|
||||||
|
privileged_tools: bool = Field(False, description="Whether the organization has access to privileged tools.")
|
||||||
|
|
||||||
|
|
||||||
class OrganizationCreate(OrganizationBase):
|
class OrganizationCreate(OrganizationBase):
|
||||||
name: Optional[str] = Field(None, description="The name of the organization.")
|
name: Optional[str] = Field(None, description="The name of the organization.")
|
||||||
|
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationUpdate(OrganizationBase):
|
||||||
|
name: Optional[str] = Field(None, description="The name of the organization.")
|
||||||
|
privileged_tools: Optional[bool] = Field(False, description="Whether the organization has access to privileged tools.")
|
||||||
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||||
|
|
||||||
from letta.schemas.organization import Organization, OrganizationCreate
|
from letta.schemas.organization import Organization, OrganizationCreate, OrganizationUpdate
|
||||||
from letta.server.rest_api.utils import get_letta_server
|
from letta.server.rest_api.utils import get_letta_server
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -59,3 +59,21 @@ def delete_org(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"{e}")
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
return org
|
return org
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization")
|
||||||
|
def update_org(
|
||||||
|
org_id: str = Query(..., description="The org_id key to be updated."),
|
||||||
|
request: OrganizationUpdate = Body(...),
|
||||||
|
server: "SyncServer" = Depends(get_letta_server),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
org = server.organization_manager.get_organization_by_id(org_id=org_id)
|
||||||
|
if org is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||||
|
org = server.organization_manager.update_organization(org_id=org_id, name=request.name)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
|
return org
|
||||||
|
@ -3,6 +3,7 @@ from typing import List, Optional
|
|||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.organization import Organization as OrganizationModel
|
from letta.orm.organization import Organization as OrganizationModel
|
||||||
from letta.schemas.organization import Organization as PydanticOrganization
|
from letta.schemas.organization import Organization as PydanticOrganization
|
||||||
|
from letta.schemas.organization import OrganizationUpdate
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
@ -63,6 +64,18 @@ class OrganizationManager:
|
|||||||
org.update(session)
|
org.update(session)
|
||||||
return org.to_pydantic()
|
return org.to_pydantic()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
|
||||||
|
"""Update an organization."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
org = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||||
|
if org_update.name:
|
||||||
|
org.name = org_update.name
|
||||||
|
if org_update.privileged_tools:
|
||||||
|
org.privileged_tools = org_update.privileged_tools
|
||||||
|
org.update(session)
|
||||||
|
return org.to_pydantic()
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def delete_organization_by_id(self, org_id: str):
|
def delete_organization_by_id(self, org_id: str):
|
||||||
"""Delete an organization by marking it as deleted."""
|
"""Delete an organization by marking it as deleted."""
|
||||||
|
@ -23,6 +23,7 @@ from letta.services.helpers.tool_execution_helper import (
|
|||||||
find_python_executable,
|
find_python_executable,
|
||||||
install_pip_requirements_for_sandbox,
|
install_pip_requirements_for_sandbox,
|
||||||
)
|
)
|
||||||
|
from letta.services.organization_manager import OrganizationManager
|
||||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
from letta.settings import tool_settings
|
from letta.settings import tool_settings
|
||||||
@ -50,6 +51,9 @@ class ToolExecutionSandbox:
|
|||||||
self.tool_name = tool_name
|
self.tool_name = tool_name
|
||||||
self.args = args
|
self.args = args
|
||||||
self.user = user
|
self.user = user
|
||||||
|
# get organization
|
||||||
|
self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id)
|
||||||
|
self.privileged_tools = self.organization.privileged_tools
|
||||||
|
|
||||||
# If a tool object is provided, we use it directly, otherwise pull via name
|
# If a tool object is provided, we use it directly, otherwise pull via name
|
||||||
if tool_object is not None:
|
if tool_object is not None:
|
||||||
@ -79,7 +83,7 @@ class ToolExecutionSandbox:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
|
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
|
||||||
"""
|
"""
|
||||||
if tool_settings.e2b_api_key:
|
if tool_settings.e2b_api_key and not self.privileged_tools:
|
||||||
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
|
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
|
||||||
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
|
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
|
||||||
else:
|
else:
|
||||||
|
@ -32,6 +32,7 @@ from letta.schemas.message import Message as PydanticMessage
|
|||||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
from letta.schemas.organization import Organization as PydanticOrganization
|
from letta.schemas.organization import Organization as PydanticOrganization
|
||||||
|
from letta.schemas.organization import OrganizationUpdate
|
||||||
from letta.schemas.passage import Passage as PydanticPassage
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
from letta.schemas.run import Run as PydanticRun
|
from letta.schemas.run import Run as PydanticRun
|
||||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||||
@ -1788,6 +1789,14 @@ def test_update_organization_name(server: SyncServer):
|
|||||||
assert org.name == org_name_b
|
assert org.name == org_name_b
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_organization_privileged_tools(server: SyncServer):
|
||||||
|
org_name = "test"
|
||||||
|
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
|
||||||
|
assert org.privileged_tools == False
|
||||||
|
org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
|
||||||
|
assert org.privileged_tools == True
|
||||||
|
|
||||||
|
|
||||||
def test_list_organizations_pagination(server: SyncServer):
|
def test_list_organizations_pagination(server: SyncServer):
|
||||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
|
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
|
||||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
|
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
|
||||||
|
Loading…
Reference in New Issue
Block a user