mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Make blocks agents mapping table (#2103)
This commit is contained in:
parent
90ec1d860f
commit
251619eb16
@ -0,0 +1,52 @@
|
|||||||
|
"""Make an blocks agents mapping table
|
||||||
|
|
||||||
|
Revision ID: 1c8880d671ee
|
||||||
|
Revises: f81ceea2c08d
|
||||||
|
Create Date: 2024-11-22 15:42:47.209229
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "1c8880d671ee"
|
||||||
|
down_revision: Union[str, None] = "f81ceea2c08d"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_unique_constraint("unique_block_id_label", "block", ["id", "label"])
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"blocks_agents",
|
||||||
|
sa.Column("agent_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("block_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("block_label", sa.String(), nullable=False),
|
||||||
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||||
|
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||||
|
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||||
|
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["agent_id"],
|
||||||
|
["agents.id"],
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label"),
|
||||||
|
sa.PrimaryKeyConstraint("agent_id", "block_id", "block_label", "id"),
|
||||||
|
sa.UniqueConstraint("agent_id", "block_label", name="unique_label_per_agent"),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_constraint("unique_block_id_label", "block", type_="unique")
|
||||||
|
op.drop_table("blocks_agents")
|
||||||
|
# ### end Alembic commands ###
|
@ -1,5 +1,6 @@
|
|||||||
from letta.orm.base import Base
|
from letta.orm.base import Base
|
||||||
from letta.orm.block import Block
|
from letta.orm.block import Block
|
||||||
|
from letta.orm.blocks_agents import BlocksAgents
|
||||||
from letta.orm.file import FileMetadata
|
from letta.orm.file import FileMetadata
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING, Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
from sqlalchemy import JSON, BigInteger, Integer
|
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||||
@ -18,6 +18,8 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
|||||||
|
|
||||||
__tablename__ = "block"
|
__tablename__ = "block"
|
||||||
__pydantic_model__ = PydanticBlock
|
__pydantic_model__ = PydanticBlock
|
||||||
|
# This may seem redundant, but is necessary for the BlocksAgents composite FK relationship
|
||||||
|
__table_args__ = (UniqueConstraint("id", "label", name="unique_block_id_label"),)
|
||||||
|
|
||||||
template_name: Mapped[Optional[str]] = mapped_column(
|
template_name: Mapped[Optional[str]] = mapped_column(
|
||||||
nullable=True, doc="the unique name that identifies a block in a human-readable way"
|
nullable=True, doc="the unique name that identifies a block in a human-readable way"
|
||||||
|
29
letta/orm/blocks_agents.py
Normal file
29
letta/orm/blocks_agents.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
|
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
|
||||||
|
|
||||||
|
|
||||||
|
class BlocksAgents(SqlalchemyBase):
|
||||||
|
"""Agents must have one or many blocks to make up their core memory."""
|
||||||
|
|
||||||
|
__tablename__ = "blocks_agents"
|
||||||
|
__pydantic_model__ = PydanticBlocksAgents
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"agent_id",
|
||||||
|
"block_label",
|
||||||
|
name="unique_label_per_agent",
|
||||||
|
),
|
||||||
|
ForeignKeyConstraint(
|
||||||
|
["block_id", "block_label"],
|
||||||
|
["block.id", "block.label"],
|
||||||
|
name="fk_block_id_label",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# unique agent + block label
|
||||||
|
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||||
|
block_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||||
|
block_label: Mapped[str] = mapped_column(String, primary_key=True)
|
32
letta/schemas/blocks_agents.py
Normal file
32
letta/schemas/blocks_agents.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from letta.schemas.letta_base import LettaBase
|
||||||
|
|
||||||
|
|
||||||
|
class BlocksAgentsBase(LettaBase):
|
||||||
|
__id_prefix__ = "blocks_agents"
|
||||||
|
|
||||||
|
|
||||||
|
class BlocksAgents(BlocksAgentsBase):
|
||||||
|
"""
|
||||||
|
Schema representing the relationship between blocks and agents.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
agent_id (str): The ID of the associated agent.
|
||||||
|
block_id (str): The ID of the associated block.
|
||||||
|
block_label (str): The label of the block.
|
||||||
|
created_at (datetime): The date this relationship was created.
|
||||||
|
updated_at (datetime): The date this relationship was last updated.
|
||||||
|
is_deleted (bool): Whether this block-agent relationship is deleted or not.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str = BlocksAgentsBase.generate_id_field()
|
||||||
|
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||||
|
block_id: str = Field(..., description="The ID of the associated block.")
|
||||||
|
block_label: str = Field(..., description="The label of the block.")
|
||||||
|
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||||
|
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
|
||||||
|
is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.")
|
@ -77,6 +77,7 @@ from letta.schemas.usage import LettaUsageStatistics
|
|||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||||
from letta.services.block_manager import BlockManager
|
from letta.services.block_manager import BlockManager
|
||||||
|
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||||
from letta.services.organization_manager import OrganizationManager
|
from letta.services.organization_manager import OrganizationManager
|
||||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||||
from letta.services.source_manager import SourceManager
|
from letta.services.source_manager import SourceManager
|
||||||
@ -248,6 +249,7 @@ class SyncServer(Server):
|
|||||||
self.block_manager = BlockManager()
|
self.block_manager = BlockManager()
|
||||||
self.source_manager = SourceManager()
|
self.source_manager = SourceManager()
|
||||||
self.agents_tags_manager = AgentsTagsManager()
|
self.agents_tags_manager = AgentsTagsManager()
|
||||||
|
self.blocks_agents_manager = BlocksAgentsManager()
|
||||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||||
|
|
||||||
# Make default user and org
|
# Make default user and org
|
||||||
|
@ -36,7 +36,7 @@ class BlockManager:
|
|||||||
return block.to_pydantic()
|
return block.to_pydantic()
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser, limit: Optional[int] = None) -> PydanticBlock:
|
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||||
"""Update a block by its ID with the given BlockUpdate object."""
|
"""Update a block by its ID with the given BlockUpdate object."""
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||||
|
84
letta/services/blocks_agents_manager.py
Normal file
84
letta/services/blocks_agents_manager.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel
|
||||||
|
from letta.orm.errors import NoResultFound
|
||||||
|
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
|
||||||
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: DELETE THIS ASAP
|
||||||
|
# TODO: So we have a patch where we manually specify CRUD operations
|
||||||
|
# TODO: This is because Agent is NOT migrated to the ORM yet
|
||||||
|
# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers
|
||||||
|
class BlocksAgentsManager:
|
||||||
|
"""Manager class to handle business logic related to Blocks and Agents."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from letta.server.server import db_context
|
||||||
|
|
||||||
|
self.session_maker = db_context
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents:
|
||||||
|
"""Add a block to an agent. If the label already exists on that agent, this will error."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
# Check if the block-label combination already exists for this agent
|
||||||
|
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||||
|
warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.")
|
||||||
|
except NoResultFound:
|
||||||
|
blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label)
|
||||||
|
blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True))
|
||||||
|
blocks_agents_record.create(session)
|
||||||
|
|
||||||
|
return blocks_agents_record.to_pydantic()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents:
|
||||||
|
"""Remove a block with a label from an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
# Find and delete the block-label association for the agent
|
||||||
|
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||||
|
blocks_agents_record.hard_delete(session)
|
||||||
|
return blocks_agents_record.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents:
|
||||||
|
"""Remove a block with a label from an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
# Find and delete the block-label association for the agent
|
||||||
|
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id)
|
||||||
|
blocks_agents_record.hard_delete(session)
|
||||||
|
return blocks_agents_record.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.")
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents:
|
||||||
|
"""Update the block ID for a specific block label for an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
try:
|
||||||
|
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||||
|
blocks_agents_record.block_id = new_block_id
|
||||||
|
return blocks_agents_record.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
|
||||||
|
"""List all blocks associated with a specific agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||||
|
return [record.block_id for record in blocks_agents_record]
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
|
||||||
|
"""List all agents associated with a specific block."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id)
|
||||||
|
return [record.agent_id for record in blocks_agents_record]
|
@ -1,10 +1,13 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
|
from sqlalchemy.exc import DBAPIError
|
||||||
|
|
||||||
import letta.utils as utils
|
import letta.utils as utils
|
||||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||||
|
from letta.metadata import AgentModel
|
||||||
from letta.orm import (
|
from letta.orm import (
|
||||||
Block,
|
Block,
|
||||||
|
BlocksAgents,
|
||||||
FileMetadata,
|
FileMetadata,
|
||||||
Organization,
|
Organization,
|
||||||
SandboxConfig,
|
SandboxConfig,
|
||||||
@ -13,6 +16,7 @@ from letta.orm import (
|
|||||||
Tool,
|
Tool,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
from letta.orm.agents_tags import AgentsTags
|
||||||
from letta.schemas.agent import CreateAgent
|
from letta.schemas.agent import CreateAgent
|
||||||
from letta.schemas.block import Block as PydanticBlock
|
from letta.schemas.block import Block as PydanticBlock
|
||||||
from letta.schemas.block import BlockUpdate
|
from letta.schemas.block import BlockUpdate
|
||||||
@ -60,12 +64,15 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
|
|||||||
def clear_tables(server: SyncServer):
|
def clear_tables(server: SyncServer):
|
||||||
"""Fixture to clear the organization table before each test."""
|
"""Fixture to clear the organization table before each test."""
|
||||||
with server.organization_manager.session_maker() as session:
|
with server.organization_manager.session_maker() as session:
|
||||||
|
session.execute(delete(BlocksAgents))
|
||||||
|
session.execute(delete(AgentsTags))
|
||||||
session.execute(delete(SandboxEnvironmentVariable))
|
session.execute(delete(SandboxEnvironmentVariable))
|
||||||
session.execute(delete(SandboxConfig))
|
session.execute(delete(SandboxConfig))
|
||||||
session.execute(delete(Block))
|
session.execute(delete(Block))
|
||||||
session.execute(delete(FileMetadata))
|
session.execute(delete(FileMetadata))
|
||||||
session.execute(delete(Source))
|
session.execute(delete(Source))
|
||||||
session.execute(delete(Tool)) # Clear all records from the Tool table
|
session.execute(delete(Tool)) # Clear all records from the Tool table
|
||||||
|
session.execute(delete(AgentModel))
|
||||||
session.execute(delete(User)) # Clear all records from the user table
|
session.execute(delete(User)) # Clear all records from the user table
|
||||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||||
session.commit() # Commit the deletion
|
session.commit() # Commit the deletion
|
||||||
@ -121,8 +128,6 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
|
|||||||
)
|
)
|
||||||
yield agent_state
|
yield agent_state
|
||||||
|
|
||||||
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def charles_agent(server: SyncServer, default_user, default_organization):
|
def charles_agent(server: SyncServer, default_user, default_organization):
|
||||||
@ -141,8 +146,6 @@ def charles_agent(server: SyncServer, default_user, default_organization):
|
|||||||
)
|
)
|
||||||
yield agent_state
|
yield agent_state
|
||||||
|
|
||||||
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tool_fixture(server: SyncServer, default_user, default_organization):
|
def tool_fixture(server: SyncServer, default_user, default_organization):
|
||||||
@ -200,6 +203,36 @@ def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, default_
|
|||||||
yield created_env_var
|
yield created_env_var
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def default_block(server: SyncServer, default_user):
|
||||||
|
"""Fixture to create and return a default block."""
|
||||||
|
block_manager = BlockManager()
|
||||||
|
block_data = PydanticBlock(
|
||||||
|
label="default_label",
|
||||||
|
value="Default Block Content",
|
||||||
|
description="A default test block",
|
||||||
|
limit=1000,
|
||||||
|
metadata_={"type": "test"},
|
||||||
|
)
|
||||||
|
block = block_manager.create_or_update_block(block_data, actor=default_user)
|
||||||
|
yield block
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def other_block(server: SyncServer, default_user):
|
||||||
|
"""Fixture to create and return another block."""
|
||||||
|
block_manager = BlockManager()
|
||||||
|
block_data = PydanticBlock(
|
||||||
|
label="other_label",
|
||||||
|
value="Other Block Content",
|
||||||
|
description="Another test block",
|
||||||
|
limit=500,
|
||||||
|
metadata_={"type": "test"},
|
||||||
|
)
|
||||||
|
block = block_manager.create_or_update_block(block_data, actor=default_user)
|
||||||
|
yield block
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
config = LettaConfig.load()
|
config = LettaConfig.load()
|
||||||
@ -561,7 +594,7 @@ def test_update_block_limit(server: SyncServer, default_user):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user, limit=limit)
|
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
|
||||||
# Retrieve the updated block
|
# Retrieve the updated block
|
||||||
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
|
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
|
||||||
# Assertions to verify the update
|
# Assertions to verify the update
|
||||||
@ -1018,3 +1051,85 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
|
|||||||
# Assertions to verify correct retrieval
|
# Assertions to verify correct retrieval
|
||||||
assert retrieved_env_var.id == sandbox_env_var_fixture.id
|
assert retrieved_env_var.id == sandbox_env_var_fixture.id
|
||||||
assert retrieved_env_var.key == sandbox_env_var_fixture.key
|
assert retrieved_env_var.key == sandbox_env_var_fixture.key
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================================================================
|
||||||
|
# BlocksAgentsManager Tests
|
||||||
|
# ======================================================================================================================
|
||||||
|
def test_add_block_to_agent(server, sarah_agent, default_user, default_block):
|
||||||
|
block_association = server.blocks_agents_manager.add_block_to_agent(
|
||||||
|
agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label
|
||||||
|
)
|
||||||
|
|
||||||
|
assert block_association.agent_id == sarah_agent.id
|
||||||
|
assert block_association.block_id == default_block.id
|
||||||
|
assert block_association.block_label == default_block.label
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
|
||||||
|
with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(
|
||||||
|
agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_block_to_agent_duplicate_label(server, sarah_agent, default_user, default_block, other_block):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
with pytest.warns(UserWarning, match=f"Block label '{default_block.label}' already exists for agent '{sarah_agent.id}'"):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_block_with_label_from_agent(server, sarah_agent, default_user, default_block):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
removed_block = server.blocks_agents_manager.remove_block_with_label_from_agent(
|
||||||
|
agent_id=sarah_agent.id, block_label=default_block.label
|
||||||
|
)
|
||||||
|
|
||||||
|
assert removed_block.block_label == default_block.label
|
||||||
|
assert removed_block.block_id == default_block.id
|
||||||
|
assert removed_block.agent_id == sarah_agent.id
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=f"Block label '{default_block.label}' not found for agent '{sarah_agent.id}'"):
|
||||||
|
server.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=sarah_agent.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_block_id_for_agent(server, sarah_agent, default_user, default_block, other_block):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
updated_block = server.blocks_agents_manager.update_block_id_for_agent(
|
||||||
|
agent_id=sarah_agent.id, block_label=default_block.label, new_block_id=other_block.id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated_block.block_id == other_block.id
|
||||||
|
assert updated_block.block_label == default_block.label
|
||||||
|
assert updated_block.agent_id == sarah_agent.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_block_ids_for_agent(server, sarah_agent, default_user, default_block, other_block):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=other_block.label)
|
||||||
|
|
||||||
|
retrieved_block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=sarah_agent.id)
|
||||||
|
|
||||||
|
assert set(retrieved_block_ids) == {default_block.id, other_block.id}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_user, default_block):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=charles_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
|
||||||
|
agent_ids = server.blocks_agents_manager.list_agent_ids_with_block(block_id=default_block.id)
|
||||||
|
|
||||||
|
assert sarah_agent.id in agent_ids
|
||||||
|
assert charles_agent.id in agent_ids
|
||||||
|
assert len(agent_ids) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
|
||||||
|
block_manager = BlockManager()
|
||||||
|
block_manager.delete_block(block_id=default_block.id, actor=default_user)
|
||||||
|
|
||||||
|
with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'):
|
||||||
|
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||||
|
Loading…
Reference in New Issue
Block a user