mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Make blocks agents mapping table (#2103)
This commit is contained in:
parent
90ec1d860f
commit
251619eb16
@ -0,0 +1,52 @@
|
||||
"""Make an blocks agents mapping table
|
||||
|
||||
Revision ID: 1c8880d671ee
|
||||
Revises: f81ceea2c08d
|
||||
Create Date: 2024-11-22 15:42:47.209229
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "1c8880d671ee"
|
||||
down_revision: Union[str, None] = "f81ceea2c08d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_unique_constraint("unique_block_id_label", "block", ["id", "label"])
|
||||
|
||||
op.create_table(
|
||||
"blocks_agents",
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("block_id", sa.String(), nullable=False),
|
||||
sa.Column("block_label", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_id"],
|
||||
["agents.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label"),
|
||||
sa.PrimaryKeyConstraint("agent_id", "block_id", "block_label", "id"),
|
||||
sa.UniqueConstraint("agent_id", "block_label", name="unique_label_per_agent"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("unique_block_id_label", "block", type_="unique")
|
||||
op.drop_table("blocks_agents")
|
||||
# ### end Alembic commands ###
|
@ -1,5 +1,6 @@
|
||||
from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, Integer
|
||||
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
@ -18,6 +18,8 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
|
||||
__tablename__ = "block"
|
||||
__pydantic_model__ = PydanticBlock
|
||||
# This may seem redundant, but is necessary for the BlocksAgents composite FK relationship
|
||||
__table_args__ = (UniqueConstraint("id", "label", name="unique_block_id_label"),)
|
||||
|
||||
template_name: Mapped[Optional[str]] = mapped_column(
|
||||
nullable=True, doc="the unique name that identifies a block in a human-readable way"
|
||||
|
29
letta/orm/blocks_agents.py
Normal file
29
letta/orm/blocks_agents.py
Normal file
@ -0,0 +1,29 @@
|
||||
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
|
||||
|
||||
|
||||
class BlocksAgents(SqlalchemyBase):
|
||||
"""Agents must have one or many blocks to make up their core memory."""
|
||||
|
||||
__tablename__ = "blocks_agents"
|
||||
__pydantic_model__ = PydanticBlocksAgents
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"agent_id",
|
||||
"block_label",
|
||||
name="unique_label_per_agent",
|
||||
),
|
||||
ForeignKeyConstraint(
|
||||
["block_id", "block_label"],
|
||||
["block.id", "block.label"],
|
||||
name="fk_block_id_label",
|
||||
),
|
||||
)
|
||||
|
||||
# unique agent + block label
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
block_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
block_label: Mapped[str] = mapped_column(String, primary_key=True)
|
32
letta/schemas/blocks_agents.py
Normal file
32
letta/schemas/blocks_agents.py
Normal file
@ -0,0 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class BlocksAgentsBase(LettaBase):
|
||||
__id_prefix__ = "blocks_agents"
|
||||
|
||||
|
||||
class BlocksAgents(BlocksAgentsBase):
|
||||
"""
|
||||
Schema representing the relationship between blocks and agents.
|
||||
|
||||
Parameters:
|
||||
agent_id (str): The ID of the associated agent.
|
||||
block_id (str): The ID of the associated block.
|
||||
block_label (str): The label of the block.
|
||||
created_at (datetime): The date this relationship was created.
|
||||
updated_at (datetime): The date this relationship was last updated.
|
||||
is_deleted (bool): Whether this block-agent relationship is deleted or not.
|
||||
"""
|
||||
|
||||
id: str = BlocksAgentsBase.generate_id_field()
|
||||
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||
block_id: str = Field(..., description="The ID of the associated block.")
|
||||
block_label: str = Field(..., description="The label of the block.")
|
||||
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
|
||||
is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.")
|
@ -77,6 +77,7 @@ from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
@ -248,6 +249,7 @@ class SyncServer(Server):
|
||||
self.block_manager = BlockManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.agents_tags_manager = AgentsTagsManager()
|
||||
self.blocks_agents_manager = BlocksAgentsManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
|
||||
# Make default user and org
|
||||
|
@ -36,7 +36,7 @@ class BlockManager:
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser, limit: Optional[int] = None) -> PydanticBlock:
|
||||
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
with self.session_maker() as session:
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
|
84
letta/services/blocks_agents_manager.py
Normal file
84
letta/services/blocks_agents_manager.py
Normal file
@ -0,0 +1,84 @@
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
# TODO: DELETE THIS ASAP
|
||||
# TODO: So we have a patch where we manually specify CRUD operations
|
||||
# TODO: This is because Agent is NOT migrated to the ORM yet
|
||||
# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers
|
||||
class BlocksAgentsManager:
|
||||
"""Manager class to handle business logic related to Blocks and Agents."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents:
|
||||
"""Add a block to an agent. If the label already exists on that agent, this will error."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Check if the block-label combination already exists for this agent
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.")
|
||||
except NoResultFound:
|
||||
blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label)
|
||||
blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True))
|
||||
blocks_agents_record.create(session)
|
||||
|
||||
return blocks_agents_record.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents:
|
||||
"""Remove a block with a label from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Find and delete the block-label association for the agent
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
blocks_agents_record.hard_delete(session)
|
||||
return blocks_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents:
|
||||
"""Remove a block with a label from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Find and delete the block-label association for the agent
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id)
|
||||
blocks_agents_record.hard_delete(session)
|
||||
return blocks_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents:
|
||||
"""Update the block ID for a specific block label for an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
blocks_agents_record.block_id = new_block_id
|
||||
return blocks_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all blocks associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.block_id for record in blocks_agents_record]
|
||||
|
||||
@enforce_types
|
||||
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
|
||||
"""List all agents associated with a specific block."""
|
||||
with self.session_maker() as session:
|
||||
blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id)
|
||||
return [record.agent_id for record in blocks_agents_record]
|
@ -1,10 +1,13 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.metadata import AgentModel
|
||||
from letta.orm import (
|
||||
Block,
|
||||
BlocksAgents,
|
||||
FileMetadata,
|
||||
Organization,
|
||||
SandboxConfig,
|
||||
@ -13,6 +16,7 @@ from letta.orm import (
|
||||
Tool,
|
||||
User,
|
||||
)
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate
|
||||
@ -60,12 +64,15 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
|
||||
def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(BlocksAgents))
|
||||
session.execute(delete(AgentsTags))
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
session.execute(delete(SandboxConfig))
|
||||
session.execute(delete(Block))
|
||||
session.execute(delete(FileMetadata))
|
||||
session.execute(delete(Source))
|
||||
session.execute(delete(Tool)) # Clear all records from the Tool table
|
||||
session.execute(delete(AgentModel))
|
||||
session.execute(delete(User)) # Clear all records from the user table
|
||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||
session.commit() # Commit the deletion
|
||||
@ -121,8 +128,6 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
@ -141,8 +146,6 @@ def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fixture(server: SyncServer, default_user, default_organization):
|
||||
@ -200,6 +203,36 @@ def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, default_
|
||||
yield created_env_var
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_block(server: SyncServer, default_user):
|
||||
"""Fixture to create and return a default block."""
|
||||
block_manager = BlockManager()
|
||||
block_data = PydanticBlock(
|
||||
label="default_label",
|
||||
value="Default Block Content",
|
||||
description="A default test block",
|
||||
limit=1000,
|
||||
metadata_={"type": "test"},
|
||||
)
|
||||
block = block_manager.create_or_update_block(block_data, actor=default_user)
|
||||
yield block
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_block(server: SyncServer, default_user):
|
||||
"""Fixture to create and return another block."""
|
||||
block_manager = BlockManager()
|
||||
block_data = PydanticBlock(
|
||||
label="other_label",
|
||||
value="Other Block Content",
|
||||
description="Another test block",
|
||||
limit=500,
|
||||
metadata_={"type": "test"},
|
||||
)
|
||||
block = block_manager.create_or_update_block(block_data, actor=default_user)
|
||||
yield block
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
@ -561,7 +594,7 @@ def test_update_block_limit(server: SyncServer, default_user):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user, limit=limit)
|
||||
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
|
||||
# Retrieve the updated block
|
||||
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
|
||||
# Assertions to verify the update
|
||||
@ -1018,3 +1051,85 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
|
||||
# Assertions to verify correct retrieval
|
||||
assert retrieved_env_var.id == sandbox_env_var_fixture.id
|
||||
assert retrieved_env_var.key == sandbox_env_var_fixture.key
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# BlocksAgentsManager Tests
|
||||
# ======================================================================================================================
|
||||
def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
|
||||
block_association = server.blocks_agents_manager.add_block_to_agent(
|
||||
agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label
|
||||
)
|
||||
|
||||
assert block_association.agent_id == sarah_agent.id
|
||||
assert block_association.block_id == default_block.id
|
||||
assert block_association.block_label == default_block.label
|
||||
|
||||
|
||||
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
|
||||
with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"):
|
||||
server.blocks_agents_manager.add_block_to_agent(
|
||||
agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label"
|
||||
)
|
||||
|
||||
|
||||
def test_add_block_to_agent_duplicate_label(server, sarah_agent, default_user, default_block, other_block):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
with pytest.warns(UserWarning, match=f"Block label '{default_block.label}' already exists for agent '{sarah_agent.id}'"):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=default_block.label)
|
||||
|
||||
|
||||
def test_remove_block_with_label_from_agent(server, sarah_agent, default_user, default_block):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
removed_block = server.blocks_agents_manager.remove_block_with_label_from_agent(
|
||||
agent_id=sarah_agent.id, block_label=default_block.label
|
||||
)
|
||||
|
||||
assert removed_block.block_label == default_block.label
|
||||
assert removed_block.block_id == default_block.id
|
||||
assert removed_block.agent_id == sarah_agent.id
|
||||
|
||||
with pytest.raises(ValueError, match=f"Block label '{default_block.label}' not found for agent '{sarah_agent.id}'"):
|
||||
server.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=sarah_agent.id, block_label=default_block.label)
|
||||
|
||||
|
||||
def test_update_block_id_for_agent(server, sarah_agent, default_user, default_block, other_block):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
updated_block = server.blocks_agents_manager.update_block_id_for_agent(
|
||||
agent_id=sarah_agent.id, block_label=default_block.label, new_block_id=other_block.id
|
||||
)
|
||||
|
||||
assert updated_block.block_id == other_block.id
|
||||
assert updated_block.block_label == default_block.label
|
||||
assert updated_block.agent_id == sarah_agent.id
|
||||
|
||||
|
||||
def test_list_block_ids_for_agent(server, sarah_agent, default_user, default_block, other_block):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=other_block.label)
|
||||
|
||||
retrieved_block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=sarah_agent.id)
|
||||
|
||||
assert set(retrieved_block_ids) == {default_block.id, other_block.id}
|
||||
|
||||
|
||||
def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_user, default_block):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=charles_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
agent_ids = server.blocks_agents_manager.list_agent_ids_with_block(block_id=default_block.id)
|
||||
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
|
||||
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
|
||||
block_manager = BlockManager()
|
||||
block_manager.delete_block(block_id=default_block.id, actor=default_user)
|
||||
|
||||
with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'):
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
Loading…
Reference in New Issue
Block a user