mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
chore: Add more comprehensive testing around blocks (#1419)
This commit is contained in:
parent
8098bad21e
commit
039345d7f5
@ -1,4 +1,6 @@
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@ -31,6 +33,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.organization import OrganizationUpdate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
@ -80,6 +83,13 @@ def default_organization(server: SyncServer):
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta"))
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
@ -94,6 +104,13 @@ def other_user(server: SyncServer, default_organization):
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user_different_org(server: SyncServer, other_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=other_organization.id))
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_source(server: SyncServer, default_user):
|
||||
source_pydantic = PydanticSource(
|
||||
@ -2398,6 +2415,61 @@ def test_get_blocks(server, default_user):
|
||||
assert persona_blocks[0].label == "persona"
|
||||
|
||||
|
||||
def test_get_blocks_comprehensive(server, default_user, other_user_different_org):
|
||||
def random_label(prefix="label"):
|
||||
return f"{prefix}_{''.join(random.choices(string.ascii_lowercase, k=6))}"
|
||||
|
||||
def random_value():
|
||||
return "".join(random.choices(string.ascii_letters + string.digits, k=12))
|
||||
|
||||
block_manager = BlockManager()
|
||||
|
||||
# Create 10 blocks for default_user
|
||||
default_user_blocks = []
|
||||
for _ in range(10):
|
||||
label = random_label("default")
|
||||
value = random_value()
|
||||
block_manager.create_or_update_block(PydanticBlock(label=label, value=value), actor=default_user)
|
||||
default_user_blocks.append((label, value))
|
||||
|
||||
# Create 3 blocks for other_user
|
||||
other_user_blocks = []
|
||||
for _ in range(3):
|
||||
label = random_label("other")
|
||||
value = random_value()
|
||||
block_manager.create_or_update_block(PydanticBlock(label=label, value=value), actor=other_user_different_org)
|
||||
other_user_blocks.append((label, value))
|
||||
|
||||
# Check default_user sees only their blocks
|
||||
retrieved_default_blocks = block_manager.get_blocks(actor=default_user)
|
||||
assert len(retrieved_default_blocks) == 10
|
||||
retrieved_labels = {b.label for b in retrieved_default_blocks}
|
||||
for label, value in default_user_blocks:
|
||||
assert label in retrieved_labels
|
||||
|
||||
# Check individual filtering for default_user
|
||||
for label, value in default_user_blocks:
|
||||
filtered = block_manager.get_blocks(actor=default_user, label=label)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0].label == label
|
||||
assert filtered[0].value == value
|
||||
|
||||
# Check other_user sees only their blocks
|
||||
retrieved_other_blocks = block_manager.get_blocks(actor=other_user_different_org)
|
||||
assert len(retrieved_other_blocks) == 3
|
||||
retrieved_labels = {b.label for b in retrieved_other_blocks}
|
||||
for label, value in other_user_blocks:
|
||||
assert label in retrieved_labels
|
||||
|
||||
# Other user shouldn't see default_user's blocks
|
||||
for label, _ in default_user_blocks:
|
||||
assert block_manager.get_blocks(actor=other_user_different_org, label=label) == []
|
||||
|
||||
# Default user shouldn't see other_user's blocks
|
||||
for label, _ in other_user_blocks:
|
||||
assert block_manager.get_blocks(actor=default_user, label=label) == []
|
||||
|
||||
|
||||
def test_update_block(server: SyncServer, default_user):
|
||||
block_manager = BlockManager()
|
||||
block = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)
|
||||
|
Loading…
Reference in New Issue
Block a user