feat: Make blocks agents mapping table (#2103)

This commit is contained in:
Matthew Zhou 2024-11-22 16:27:47 -08:00 committed by GitHub
parent 90ec1d860f
commit 251619eb16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 324 additions and 7 deletions

View File

@ -0,0 +1,52 @@
"""Make an blocks agents mapping table
Revision ID: 1c8880d671ee
Revises: f81ceea2c08d
Create Date: 2024-11-22 15:42:47.209229
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "1c8880d671ee"
down_revision: Union[str, None] = "f81ceea2c08d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint("unique_block_id_label", "block", ["id", "label"])
op.create_table(
"blocks_agents",
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("block_id", sa.String(), nullable=False),
sa.Column("block_label", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["agent_id"],
["agents.id"],
),
sa.ForeignKeyConstraint(["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label"),
sa.PrimaryKeyConstraint("agent_id", "block_id", "block_label", "id"),
sa.UniqueConstraint("agent_id", "block_label", name="unique_label_per_agent"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_block_id_label", "block", type_="unique")
op.drop_table("blocks_agents")
# ### end Alembic commands ###

View File

@ -1,5 +1,6 @@
from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Optional, Type
from sqlalchemy import JSON, BigInteger, Integer
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
@ -18,6 +18,8 @@ class Block(OrganizationMixin, SqlalchemyBase):
__tablename__ = "block"
__pydantic_model__ = PydanticBlock
# This may seem redundant, but is necessary for the BlocksAgents composite FK relationship
__table_args__ = (UniqueConstraint("id", "label", name="unique_block_id_label"),)
template_name: Mapped[Optional[str]] = mapped_column(
nullable=True, doc="the unique name that identifies a block in a human-readable way"

View File

@ -0,0 +1,29 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
class BlocksAgents(SqlalchemyBase):
"""Agents must have one or many blocks to make up their core memory."""
__tablename__ = "blocks_agents"
__pydantic_model__ = PydanticBlocksAgents
__table_args__ = (
UniqueConstraint(
"agent_id",
"block_label",
name="unique_label_per_agent",
),
ForeignKeyConstraint(
["block_id", "block_label"],
["block.id", "block.label"],
name="fk_block_id_label",
),
)
# unique agent + block label
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
block_id: Mapped[str] = mapped_column(String, primary_key=True)
block_label: Mapped[str] = mapped_column(String, primary_key=True)

View File

@ -0,0 +1,32 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class BlocksAgentsBase(LettaBase):
__id_prefix__ = "blocks_agents"
class BlocksAgents(BlocksAgentsBase):
"""
Schema representing the relationship between blocks and agents.
Parameters:
agent_id (str): The ID of the associated agent.
block_id (str): The ID of the associated block.
block_label (str): The label of the block.
created_at (datetime): The date this relationship was created.
updated_at (datetime): The date this relationship was last updated.
is_deleted (bool): Whether this block-agent relationship is deleted or not.
"""
id: str = BlocksAgentsBase.generate_id_field()
agent_id: str = Field(..., description="The ID of the associated agent.")
block_id: str = Field(..., description="The ID of the associated block.")
block_label: str = Field(..., description="The label of the block.")
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.")

View File

@ -77,6 +77,7 @@ from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
@ -248,6 +249,7 @@ class SyncServer(Server):
self.block_manager = BlockManager()
self.source_manager = SourceManager()
self.agents_tags_manager = AgentsTagsManager()
self.blocks_agents_manager = BlocksAgentsManager()
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
# Make default user and org

View File

@ -36,7 +36,7 @@ class BlockManager:
return block.to_pydantic()
@enforce_types
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser, limit: Optional[int] = None) -> PydanticBlock:
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
with self.session_maker() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

View File

@ -0,0 +1,84 @@
import warnings
from typing import List
from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel
from letta.orm.errors import NoResultFound
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
from letta.utils import enforce_types
# TODO: DELETE THIS ASAP
# TODO: So we have a patch where we manually specify CRUD operations
# TODO: This is because Agent is NOT migrated to the ORM yet
# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers
class BlocksAgentsManager:
"""Manager class to handle business logic related to Blocks and Agents."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents:
"""Add a block to an agent. If the label already exists on that agent, this will error."""
with self.session_maker() as session:
try:
# Check if the block-label combination already exists for this agent
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.")
except NoResultFound:
blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label)
blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True))
blocks_agents_record.create(session)
return blocks_agents_record.to_pydantic()
@enforce_types
def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents:
"""Remove a block with a label from an agent."""
with self.session_maker() as session:
try:
# Find and delete the block-label association for the agent
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
blocks_agents_record.hard_delete(session)
return blocks_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
@enforce_types
def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents:
"""Remove a block with a label from an agent."""
with self.session_maker() as session:
try:
# Find and delete the block-label association for the agent
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id)
blocks_agents_record.hard_delete(session)
return blocks_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.")
@enforce_types
def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents:
"""Update the block ID for a specific block label for an agent."""
with self.session_maker() as session:
try:
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
blocks_agents_record.block_id = new_block_id
return blocks_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
@enforce_types
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
"""List all blocks associated with a specific agent."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.block_id for record in blocks_agents_record]
@enforce_types
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
"""List all agents associated with a specific block."""
with self.session_maker() as session:
blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id)
return [record.agent_id for record in blocks_agents_record]

View File

@ -1,10 +1,13 @@
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import DBAPIError
import letta.utils as utils
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.metadata import AgentModel
from letta.orm import (
Block,
BlocksAgents,
FileMetadata,
Organization,
SandboxConfig,
@ -13,6 +16,7 @@ from letta.orm import (
Tool,
User,
)
from letta.orm.agents_tags import AgentsTags
from letta.schemas.agent import CreateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
@ -60,12 +64,15 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(BlocksAgents))
session.execute(delete(AgentsTags))
session.execute(delete(SandboxEnvironmentVariable))
session.execute(delete(SandboxConfig))
session.execute(delete(Block))
session.execute(delete(FileMetadata))
session.execute(delete(Source))
session.execute(delete(Tool)) # Clear all records from the Tool table
session.execute(delete(AgentModel))
session.execute(delete(User)) # Clear all records from the user table
session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion
@ -121,8 +128,6 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
)
yield agent_state
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
@pytest.fixture
def charles_agent(server: SyncServer, default_user, default_organization):
@ -141,8 +146,6 @@ def charles_agent(server: SyncServer, default_user, default_organization):
)
yield agent_state
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
@pytest.fixture
def tool_fixture(server: SyncServer, default_user, default_organization):
@ -200,6 +203,36 @@ def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, default_
yield created_env_var
@pytest.fixture
def default_block(server: SyncServer, default_user):
"""Fixture to create and return a default block."""
block_manager = BlockManager()
block_data = PydanticBlock(
label="default_label",
value="Default Block Content",
description="A default test block",
limit=1000,
metadata_={"type": "test"},
)
block = block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@pytest.fixture
def other_block(server: SyncServer, default_user):
"""Fixture to create and return another block."""
block_manager = BlockManager()
block_data = PydanticBlock(
label="other_label",
value="Other Block Content",
description="Another test block",
limit=500,
metadata_={"type": "test"},
)
block = block_manager.create_or_update_block(block_data, actor=default_user)
yield block
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
@ -561,7 +594,7 @@ def test_update_block_limit(server: SyncServer, default_user):
except Exception:
pass
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user, limit=limit)
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
# Assertions to verify the update
@ -1018,3 +1051,85 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# Assertions to verify correct retrieval
assert retrieved_env_var.id == sandbox_env_var_fixture.id
assert retrieved_env_var.key == sandbox_env_var_fixture.key
# ======================================================================================================================
# BlocksAgentsManager Tests
# ======================================================================================================================
def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
block_association = server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label
)
assert block_association.agent_id == sarah_agent.id
assert block_association.block_id == default_block.id
assert block_association.block_label == default_block.label
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"):
server.blocks_agents_manager.add_block_to_agent(
agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label"
)
def test_add_block_to_agent_duplicate_label(server, sarah_agent, default_user, default_block, other_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
with pytest.warns(UserWarning, match=f"Block label '{default_block.label}' already exists for agent '{sarah_agent.id}'"):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=default_block.label)
def test_remove_block_with_label_from_agent(server, sarah_agent, default_user, default_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
removed_block = server.blocks_agents_manager.remove_block_with_label_from_agent(
agent_id=sarah_agent.id, block_label=default_block.label
)
assert removed_block.block_label == default_block.label
assert removed_block.block_id == default_block.id
assert removed_block.agent_id == sarah_agent.id
with pytest.raises(ValueError, match=f"Block label '{default_block.label}' not found for agent '{sarah_agent.id}'"):
server.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=sarah_agent.id, block_label=default_block.label)
def test_update_block_id_for_agent(server, sarah_agent, default_user, default_block, other_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
updated_block = server.blocks_agents_manager.update_block_id_for_agent(
agent_id=sarah_agent.id, block_label=default_block.label, new_block_id=other_block.id
)
assert updated_block.block_id == other_block.id
assert updated_block.block_label == default_block.label
assert updated_block.agent_id == sarah_agent.id
def test_list_block_ids_for_agent(server, sarah_agent, default_user, default_block, other_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=other_block.label)
retrieved_block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=sarah_agent.id)
assert set(retrieved_block_ids) == {default_block.id, other_block.id}
def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_user, default_block):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
server.blocks_agents_manager.add_block_to_agent(agent_id=charles_agent.id, block_id=default_block.id, block_label=default_block.label)
agent_ids = server.blocks_agents_manager.list_agent_ids_with_block(block_id=default_block.id)
assert sarah_agent.id in agent_ids
assert charles_agent.id in agent_ids
assert len(agent_ids) == 2
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
block_manager = BlockManager()
block_manager.delete_block(block_id=default_block.id, actor=default_user)
with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'):
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)