feat: Make blocks agents mapping table (#2103)

This commit is contained in:
Matthew Zhou 2024-11-22 16:27:47 -08:00 committed by GitHub
parent 90ec1d860f
commit 251619eb16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 324 additions and 7 deletions

View File

@ -0,0 +1,52 @@
"""Make an blocks agents mapping table
Revision ID: 1c8880d671ee
Revises: f81ceea2c08d
Create Date: 2024-11-22 15:42:47.209229
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "1c8880d671ee"
down_revision: Union[str, None] = "f81ceea2c08d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint("unique_block_id_label", "block", ["id", "label"])
op.create_table(
"blocks_agents",
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("block_id", sa.String(), nullable=False),
sa.Column("block_label", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["agent_id"],
["agents.id"],
),
sa.ForeignKeyConstraint(["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label"),
sa.PrimaryKeyConstraint("agent_id", "block_id", "block_label", "id"),
sa.UniqueConstraint("agent_id", "block_label", name="unique_label_per_agent"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_block_id_label", "block", type_="unique")
op.drop_table("blocks_agents")
# ### end Alembic commands ###

View File

@ -1,5 +1,6 @@
from letta.orm.base import Base from letta.orm.base import Base
from letta.orm.block import Block from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata from letta.orm.file import FileMetadata
from letta.orm.organization import Organization from letta.orm.organization import Organization
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type
from sqlalchemy import JSON, BigInteger, Integer from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
@ -18,6 +18,8 @@ class Block(OrganizationMixin, SqlalchemyBase):
__tablename__ = "block" __tablename__ = "block"
__pydantic_model__ = PydanticBlock __pydantic_model__ = PydanticBlock
# This may seem redundant, but is necessary for the BlocksAgents composite FK relationship
__table_args__ = (UniqueConstraint("id", "label", name="unique_block_id_label"),)
template_name: Mapped[Optional[str]] = mapped_column( template_name: Mapped[Optional[str]] = mapped_column(
nullable=True, doc="the unique name that identifies a block in a human-readable way" nullable=True, doc="the unique name that identifies a block in a human-readable way"

View File

@ -0,0 +1,29 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
class BlocksAgents(SqlalchemyBase):
"""Agents must have one or many blocks to make up their core memory."""
__tablename__ = "blocks_agents"
__pydantic_model__ = PydanticBlocksAgents
__table_args__ = (
UniqueConstraint(
"agent_id",
"block_label",
name="unique_label_per_agent",
),
ForeignKeyConstraint(
["block_id", "block_label"],
["block.id", "block.label"],
name="fk_block_id_label",
),
)
# unique agent + block label
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
block_id: Mapped[str] = mapped_column(String, primary_key=True)
block_label: Mapped[str] = mapped_column(String, primary_key=True)

View File

@ -0,0 +1,32 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class BlocksAgentsBase(LettaBase):
__id_prefix__ = "blocks_agents"
class BlocksAgents(BlocksAgentsBase):
"""
Schema representing the relationship between blocks and agents.
Parameters:
agent_id (str): The ID of the associated agent.
block_id (str): The ID of the associated block.
block_label (str): The label of the block.
created_at (datetime): The date this relationship was created.
updated_at (datetime): The date this relationship was last updated.
is_deleted (bool): Whether this block-agent relationship is deleted or not.
"""
id: str = BlocksAgentsBase.generate_id_field()
agent_id: str = Field(..., description="The ID of the associated agent.")
block_id: str = Field(..., description="The ID of the associated block.")
block_label: str = Field(..., description="The label of the block.")
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.")

View File

@ -77,6 +77,7 @@ from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User from letta.schemas.user import User
from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.block_manager import BlockManager from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.organization_manager import OrganizationManager from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager from letta.services.source_manager import SourceManager
@ -248,6 +249,7 @@ class SyncServer(Server):
self.block_manager = BlockManager() self.block_manager = BlockManager()
self.source_manager = SourceManager() self.source_manager = SourceManager()
self.agents_tags_manager = AgentsTagsManager() self.agents_tags_manager = AgentsTagsManager()
self.blocks_agents_manager = BlocksAgentsManager()
self.sandbox_config_manager = SandboxConfigManager(tool_settings) self.sandbox_config_manager = SandboxConfigManager(tool_settings)
# Make default user and org # Make default user and org

View File

@ -36,7 +36,7 @@ class BlockManager:
return block.to_pydantic() return block.to_pydantic()
@enforce_types @enforce_types
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser, limit: Optional[int] = None) -> PydanticBlock: def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object.""" """Update a block by its ID with the given BlockUpdate object."""
with self.session_maker() as session: with self.session_maker() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

View File

@ -0,0 +1,84 @@
import warnings
from typing import List
from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel
from letta.orm.errors import NoResultFound
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
from letta.utils import enforce_types
# TODO: DELETE THIS ASAP
# TODO: So we have a patch where we manually specify CRUD operations
# TODO: This is because Agent is NOT migrated to the ORM yet
# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers
class BlocksAgentsManager:
"""Manager class to handle business logic related to Blocks and Agents."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents:
"""Add a block to an agent. If the label already exists on that agent, this will error."""
with self.session_maker() as session:
try:
# Check if the block-label combination already exists for this agent
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.")
except NoResultFound:
blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label)
blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True))
blocks_agents_record.create(session)
return blocks_agents_record.to_pydantic()
@enforce_types
def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents:
"""Remove a block with a label from an agent."""
with self.session_maker() as session:
try:
# Find and delete the block-label association for the agent
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
blocks_agents_record.hard_delete(session)
return blocks_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
@enforce_types
def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents:
"""Remove a block with a label from an agent."""
with self.session_maker() as session:
try:
# Find and delete the block-label association for the agent
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id)
blocks_agents_record.hard_delete(session)
return blocks_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.")
@enforce_types
def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents:
"""Update the block ID for a specific block label for an agent."""
with self.session_maker() as session:
try:
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
blocks_agents_record.block_id = new_block_id
return blocks_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
@enforce_types
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
"""List all blocks associated with a specific agent."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.block_id for record in blocks_agents_record]
@enforce_types
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
"""List all agents associated with a specific block."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id)
return [record.agent_id for record in blocks_agents_record]

View File

@ -1,10 +1,13 @@
import pytest import pytest
from sqlalchemy import delete from sqlalchemy import delete
from sqlalchemy.exc import DBAPIError
import letta.utils as utils import letta.utils as utils
from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.metadata import AgentModel
from letta.orm import ( from letta.orm import (
Block, Block,
BlocksAgents,
FileMetadata, FileMetadata,
Organization, Organization,
SandboxConfig, SandboxConfig,
@ -13,6 +16,7 @@ from letta.orm import (
Tool, Tool,
User, User,
) )
from letta.orm.agents_tags import AgentsTags
from letta.schemas.agent import CreateAgent from letta.schemas.agent import CreateAgent
from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate from letta.schemas.block import BlockUpdate
@ -60,12 +64,15 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
def clear_tables(server: SyncServer): def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test.""" """Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session: with server.organization_manager.session_maker() as session:
session.execute(delete(BlocksAgents))
session.execute(delete(AgentsTags))
session.execute(delete(SandboxEnvironmentVariable)) session.execute(delete(SandboxEnvironmentVariable))
session.execute(delete(SandboxConfig)) session.execute(delete(SandboxConfig))
session.execute(delete(Block)) session.execute(delete(Block))
session.execute(delete(FileMetadata)) session.execute(delete(FileMetadata))
session.execute(delete(Source)) session.execute(delete(Source))
session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(Tool)) # Clear all records from the Tool table
session.execute(delete(AgentModel))
session.execute(delete(User)) # Clear all records from the user table session.execute(delete(User)) # Clear all records from the user table
session.execute(delete(Organization)) # Clear all records from the organization table session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion session.commit() # Commit the deletion
@ -121,8 +128,6 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
) )
yield agent_state yield agent_state
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
@pytest.fixture @pytest.fixture
def charles_agent(server: SyncServer, default_user, default_organization): def charles_agent(server: SyncServer, default_user, default_organization):
@ -141,8 +146,6 @@ def charles_agent(server: SyncServer, default_user, default_organization):
) )
yield agent_state yield agent_state
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
@pytest.fixture @pytest.fixture
def tool_fixture(server: SyncServer, default_user, default_organization): def tool_fixture(server: SyncServer, default_user, default_organization):
@ -200,6 +203,36 @@ def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, default_
yield created_env_var yield created_env_var
@pytest.fixture
def default_block(server: SyncServer, default_user):
"""Fixture to create and return a default block."""
block_manager = BlockManager()
block_data = PydanticBlock(
label="default_label",
value="Default Block Content",
description="A default test block",
limit=1000,
metadata_={"type": "test"},
)
block = block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@pytest.fixture
def other_block(server: SyncServer, default_user):
"""Fixture to create and return another block."""
block_manager = BlockManager()
block_data = PydanticBlock(
label="other_label",
value="Other Block Content",
description="Another test block",
limit=500,
metadata_={"type": "test"},
)
block = block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
config = LettaConfig.load() config = LettaConfig.load()
@ -561,7 +594,7 @@ def test_update_block_limit(server: SyncServer, default_user):
except Exception: except Exception:
pass pass
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user, limit=limit) block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block # Retrieve the updated block
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0] updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
# Assertions to verify the update # Assertions to verify the update
@ -1018,3 +1051,85 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# Assertions to verify correct retrieval # Assertions to verify correct retrieval
assert retrieved_env_var.id == sandbox_env_var_fixture.id assert retrieved_env_var.id == sandbox_env_var_fixture.id
assert retrieved_env_var.key == sandbox_env_var_fixture.key assert retrieved_env_var.key == sandbox_env_var_fixture.key
# ======================================================================================================================
# BlocksAgentsManager Tests
# ======================================================================================================================
def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
block_association = server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label
)
assert block_association.agent_id == sarah_agent.id
assert block_association.block_id == default_block.id
assert block_association.block_label == default_block.label
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"):
server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label"
)
def test_add_block_to_agent_duplicate_label(server, sarah_agent, default_user, default_block, other_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
with pytest.warns(UserWarning, match=f"Block label '{default_block.label}' already exists for agent '{sarah_agent.id}'"):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=default_block.label)
def test_remove_block_with_label_from_agent(server, sarah_agent, default_user, default_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
removed_block = server.blocks_agents_manager.remove_block_with_label_from_agent(
agent_id=sarah_agent.id, block_label=default_block.label
)
assert removed_block.block_label == default_block.label
assert removed_block.block_id == default_block.id
assert removed_block.agent_id == sarah_agent.id
with pytest.raises(ValueError, match=f"Block label '{default_block.label}' not found for agent '{sarah_agent.id}'"):
server.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=sarah_agent.id, block_label=default_block.label)
def test_update_block_id_for_agent(server, sarah_agent, default_user, default_block, other_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
updated_block = server.blocks_agents_manager.update_block_id_for_agent(
agent_id=sarah_agent.id, block_label=default_block.label, new_block_id=other_block.id
)
assert updated_block.block_id == other_block.id
assert updated_block.block_label == default_block.label
assert updated_block.agent_id == sarah_agent.id
def test_list_block_ids_for_agent(server, sarah_agent, default_user, default_block, other_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=other_block.label)
retrieved_block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=sarah_agent.id)
assert set(retrieved_block_ids) == {default_block.id, other_block.id}
def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_user, default_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
server.blocks_agents_manager.add_block_to_agent(agent_id=charles_agent.id, block_id=default_block.id, block_label=default_block.label)
agent_ids = server.blocks_agents_manager.list_agent_ids_with_block(block_id=default_block.id)
assert sarah_agent.id in agent_ids
assert charles_agent.id in agent_ids
assert len(agent_ids) == 2
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
block_manager = BlockManager()
block_manager.delete_block(block_id=default_block.id, actor=default_user)
with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)