mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
537 lines
19 KiB
Python
537 lines
19 KiB
Python
import pytest
|
|
from sqlalchemy import delete
|
|
|
|
from letta.config import LettaConfig
|
|
from letta.orm import Provider, Step
|
|
from letta.orm.enums import JobType
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.schemas.agent import CreateAgent
|
|
from letta.schemas.block import CreateBlock
|
|
from letta.schemas.enums import JobStatus
|
|
from letta.schemas.group import (
|
|
DynamicManager,
|
|
GroupCreate,
|
|
GroupUpdate,
|
|
ManagerType,
|
|
RoundRobinManager,
|
|
SleeptimeManager,
|
|
SupervisorManager,
|
|
)
|
|
from letta.schemas.message import MessageCreate
|
|
from letta.schemas.run import Run
|
|
from letta.server.server import SyncServer
|
|
|
|
|
|
@pytest.fixture(scope="module")
def server():
    """Module-scoped fixture that builds the ``SyncServer`` used by every test.

    The Letta configuration is loaded, its path is printed (handy when
    debugging which config a test run picked up), and it is saved back
    before the server is constructed.
    """
    letta_config = LettaConfig.load()
    print("CONFIG PATH", letta_config.config_path)
    letta_config.save()

    return SyncServer()
|
|
|
|
|
|
@pytest.fixture(scope="module")
def org_id(server):
    """Yield the id of the default organization and tear it down afterwards.

    Teardown issues unfiltered deletes against the ``Step`` and ``Provider``
    tables, commits, and then deletes the organization itself.
    """
    organization_manager = server.organization_manager
    default_org = organization_manager.create_default_organization()

    yield default_org.id

    # cleanup
    # NOTE(review): the Step/Provider rows are presumably removed first so the
    # organization delete is not blocked by dependent rows -- confirm.
    with organization_manager.session_maker() as db_session:
        for model in (Step, Provider):
            db_session.execute(delete(model))
        db_session.commit()
    organization_manager.delete_organization_by_id(default_org.id)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def actor(server, org_id):
    """Yield the default user that acts in the tests, deleting it on teardown.

    The ``org_id`` value is not used in the body; the fixture is presumably
    requested so that the default organization exists before the user is
    created and outlives it -- TODO confirm.
    """
    user_manager = server.user_manager
    default_user = user_manager.create_default_user()

    yield default_user

    # cleanup
    user_manager.delete_user_by_id(default_user.id)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def participant_agents(server, actor):
    """Create the four chat participants and yield them as a list.

    The agents are created, yielded, and later deleted in the order
    fred, velma, daphne, shaggy. Each one gets a single ``persona`` memory
    block and the same model / embedding handles.
    """
    # (agent name, persona text) pairs, in creation order.
    persona_by_name = (
        (
            "fred",
            "Your name is fred and you like to ski and have been wanting to go on a ski trip soon. You are speaking in a group chat with other agent pals where you participate in friendly banter.",
        ),
        (
            "velma",
            "Your name is velma and you like tropical locations. You are speaking in a group chat with other agent friends and you love to include everyone.",
        ),
        (
            "daphne",
            "Your name is daphne and you love traveling abroad. You are speaking in a group chat with other agent friends and you love to keep in touch with them.",
        ),
        (
            "shaggy",
            "Your name is shaggy and your best friend is your dog, scooby. You are speaking in a group chat with other agent friends and you like to solve mysteries with them.",
        ),
    )

    created_agents = []
    for agent_name, persona_text in persona_by_name:
        create_request = CreateAgent(
            name=agent_name,
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value=persona_text,
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
        )
        created_agents.append(server.create_agent(request=create_request, actor=actor))

    # Hand out a copy so teardown always sees exactly the agents created here.
    yield list(created_agents)

    # cleanup
    for created in created_agents:
        server.agent_manager.delete_agent(created.id, actor=actor)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def manager_agent(server, actor):
    """Create the ``scooby`` agent, yield it, and delete it on teardown.

    The agent is created with a ``persona`` block describing a group-chat
    manager role and an empty ``human`` block.
    """
    persona_block = CreateBlock(
        label="persona",
        value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your job is to get to know the agents in your group and pick who is best suited to speak next in the conversation.",
    )
    human_block = CreateBlock(
        label="human",
        value="",
    )
    create_request = CreateAgent(
        name="scooby",
        memory_blocks=[persona_block, human_block],
        model="openai/gpt-4o-mini",
        embedding="openai/text-embedding-ada-002",
    )
    scooby = server.create_agent(request=create_request, actor=actor)

    yield scooby

    # cleanup
    server.agent_manager.delete_agent(scooby.id, actor=actor)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_empty_group(server, actor):
    """Sending a message to a group with no agents must raise ``ValueError``.

    The group is created with an empty ``agent_ids`` list and the send call is
    expected to fail with a message matching ``"Empty group"``.
    """
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=[],
        ),
        actor=actor,
    )
    try:
        with pytest.raises(ValueError, match="Empty group"):
            await server.send_group_message_to_agent(
                group_id=group.id,
                actor=actor,
                messages=[
                    MessageCreate(
                        role="user",
                        content="what is everyone up to for the holidays?",
                    ),
                ],
                stream_steps=False,
                stream_tokens=False,
            )
    finally:
        # Delete the group even when the expectation above fails, so a failed
        # run does not leave the group behind for later tests.
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_modify_group_pattern(server, actor, participant_agents, manager_agent):
    """Changing an existing group's manager pattern must raise ``ValueError``.

    The group is created without a ``manager_config``; updating it with a
    ``DynamicManager`` config is expected to fail with a message matching
    ``"Cannot change group pattern"``.
    """
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=[agent.id for agent in participant_agents],
        ),
        actor=actor,
    )
    try:
        with pytest.raises(ValueError, match="Cannot change group pattern"):
            server.group_manager.modify_group(
                group_id=group.id,
                group_update=GroupUpdate(
                    manager_config=DynamicManager(
                        manager_type=ManagerType.dynamic,
                        manager_agent_id=manager_agent.id,
                    ),
                ),
                actor=actor,
            )
    finally:
        # Delete the group even when the expectation above fails, so a failed
        # run does not leave the group behind for later tests.
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_round_robin(server, actor, participant_agents):
    """Exercise a group created without a manager config (round robin).

    Phase 1 sends one user message with no ``max_turns`` set and checks that
    the step count equals the number of agents, that the response alternates
    reasoning / assistant messages in participant order, and the per-agent
    and group-wide message counts.

    Phase 2 reverses the agent order, sets ``max_turns`` to 3, resets the
    group messages, and repeats the checks against the reversed order.
    """
    description = (
        "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
    )
    participant_ids = [agent.id for agent in participant_agents]
    group = server.group_manager.create_group(
        group=GroupCreate(
            description=description,
            agent_ids=participant_ids,
        ),
        actor=actor,
    )

    try:
        # verify group creation (inside the try so a failure still deletes the group)
        assert group.manager_type == ManagerType.round_robin
        assert group.description == description
        assert group.agent_ids == participant_ids
        assert group.max_turns is None
        assert group.manager_agent_id is None
        assert group.termination_token is None

        server.group_manager.reset_messages(group_id=group.id, actor=actor)
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="what is everyone up to for the holidays?",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == len(group.agent_ids)
        assert len(response.messages) == response.usage.step_count * 2
        for i, message in enumerate(response.messages):
            # The conditional must be parenthesized: without the parentheses the
            # statement parses as `assert (a == b) if cond else "assistant_message"`,
            # which asserts a non-empty string (always true) on odd indices.
            assert message.message_type == ("reasoning_message" if i % 2 == 0 else "assistant_message")
            assert message.name == participant_agents[i // 2].name

        for agent_id in group.agent_ids:
            agent_messages = server.get_agent_recall(
                user_id=actor.id,
                agent_id=agent_id,
                group_id=group.id,
                reverse=True,
                return_message_object=False,
            )
            assert len(agent_messages) == len(group.agent_ids) + 2  # add one for user message, one for reasoning message

        # TODO: filter this to return a clean conversation history
        messages = server.group_manager.list_group_messages(
            group_id=group.id,
            actor=actor,
        )
        assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids)

        max_turns = 3
        reversed_ids = participant_ids[::-1]
        reversed_agents = participant_agents[::-1]
        group = server.group_manager.modify_group(
            group_id=group.id,
            group_update=GroupUpdate(
                agent_ids=reversed_ids,
                manager_config=RoundRobinManager(
                    max_turns=max_turns,
                ),
            ),
            actor=actor,
        )
        assert group.manager_type == ManagerType.round_robin
        assert group.description == description
        assert group.agent_ids == reversed_ids
        assert group.max_turns == max_turns
        assert group.manager_agent_id is None
        assert group.termination_token is None

        server.group_manager.reset_messages(group_id=group.id, actor=actor)

        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="what is everyone up to for the holidays?",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == max_turns
        assert len(response.messages) == max_turns * 2

        for i, message in enumerate(response.messages):
            # Parenthesized for the same precedence reason as in phase 1.
            assert message.message_type == ("reasoning_message" if i % 2 == 0 else "assistant_message")
            assert message.name == reversed_agents[i // 2].name

        for i, agent_id in enumerate(group.agent_ids):
            agent_messages = server.get_agent_recall(
                user_id=actor.id,
                agent_id=agent_id,
                group_id=group.id,
                reverse=True,
                return_message_object=False,
            )
            # Agents at index >= max_turns are expected to hold one message
            # fewer -- presumably because they never took a turn; confirm.
            expected_message_count = (max_turns + 1) if i >= max_turns else (max_turns + 2)
            assert len(agent_messages) == expected_message_count

    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_supervisor(server, actor, participant_agents):
    """Exercise a group managed by a ``SupervisorManager``.

    A dedicated supervisor agent is created and made manager of a group
    containing all participants. One user message is sent and the response
    is expected to have two steps and five messages: reasoning, a
    ``send_message_to_all_agents_in_group`` tool call, its tool return (one
    entry per participant), reasoning, and a final assistant message.
    """
    # Function-scope import: only this test needs it.
    import ast

    # NOTE(review): the variable is `agent_scrappy` but the agent is named
    # "shaggy", which is also the name of one participant -- looks like a
    # copy-paste slip; left unchanged here, confirm the intended name.
    agent_scrappy = server.create_agent(
        request=CreateAgent(
            name="shaggy",
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your role is to supervise the group, sending messages and aggregating the responses.",
                ),
                CreateBlock(
                    label="human",
                    value="",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

    group = None
    try:
        # Created inside the try so the supervisor agent is still deleted if
        # group creation itself fails.
        group = server.group_manager.create_group(
            group=GroupCreate(
                description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
                agent_ids=[agent.id for agent in participant_agents],
                manager_config=SupervisorManager(
                    manager_agent_id=agent_scrappy.id,
                ),
            ),
            actor=actor,
        )

        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == 2
        assert len(response.messages) == 5

        # verify tool call
        assert response.messages[0].message_type == "reasoning_message"

        tool_call_message = response.messages[1]
        assert tool_call_message.message_type == "tool_call_message"
        assert tool_call_message.tool_call.name == "send_message_to_all_agents_in_group"

        tool_return_message = response.messages[2]
        assert tool_return_message.message_type == "tool_return_message"
        # The tool return is model/tool-produced text, so it is parsed with
        # ast.literal_eval (literals only) rather than eval(), which would
        # execute arbitrary expressions embedded in that text.
        tool_return_items = ast.literal_eval(tool_return_message.tool_return)
        assert len(tool_return_items) == len(participant_agents)

        assert response.messages[3].message_type == "reasoning_message"
        assert response.messages[4].message_type == "assistant_message"

    finally:
        if group is not None:
            server.group_manager.delete_group(group_id=group.id, actor=actor)
        server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=actor)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dynamic_group_chat(server, actor, manager_agent, participant_agents):
    """Exercise group creation and messaging with a ``DynamicManager``.

    Checks, in order:
      1. creating a group whose ``agent_ids`` repeats an id raises
         ``ValueError`` matching ``"Duplicate agent ids"``;
      2. creating a group containing two agents both named "shaggy" raises
         ``ValueError`` matching ``"Duplicate agent names"``;
      3. a valid dynamic group answers one user message with a step count of
         twice the number of participants and two messages per step.
    """
    description = (
        "This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
    )
    participant_ids = [agent.id for agent in participant_agents]

    # error on duplicate agent in participant list
    with pytest.raises(ValueError, match="Duplicate agent ids"):
        server.group_manager.create_group(
            group=GroupCreate(
                description=description,
                agent_ids=participant_ids + [participant_agents[0].id],
                manager_config=DynamicManager(
                    manager_agent_id=manager_agent.id,
                ),
            ),
            actor=actor,
        )

    # error on duplicate agent names
    duplicate_agent_shaggy = server.create_agent(
        request=CreateAgent(
            name="shaggy",
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )
    try:
        with pytest.raises(ValueError, match="Duplicate agent names"):
            server.group_manager.create_group(
                group=GroupCreate(
                    description=description,
                    agent_ids=participant_ids + [duplicate_agent_shaggy.id],
                    manager_config=DynamicManager(
                        manager_agent_id=manager_agent.id,
                    ),
                ),
                actor=actor,
            )
    finally:
        # Remove the temporary agent even if the expectation above fails, so
        # a second "shaggy" is not left behind for later tests.
        server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=actor)

    group = server.group_manager.create_group(
        group=GroupCreate(
            description=description,
            agent_ids=participant_ids,
            manager_config=DynamicManager(
                manager_agent_id=manager_agent.id,
            ),
        ),
        actor=actor,
    )
    try:
        response = await server.send_group_message_to_agent(
            group_id=group.id,
            actor=actor,
            messages=[
                MessageCreate(role="user", content="what is everyone up to for the holidays?"),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert response.usage.step_count == len(participant_agents) * 2
        assert len(response.messages) == response.usage.step_count * 2

    finally:
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sleeptime_group_chat(server, actor):
    """Exercise an agent created with ``enable_sleeptime=True`` end to end.

    Steps:
      0. refresh the base tools;
      1. create the main agent with sleeptime enabled;
      2. set the sleeptime frequency of its group to 2;
      3. check that the ``human`` block is shared by exactly the main agent
         and the sleeptime agent;
      4. send four messages and check the run ids reported on each response
         and the status of the resulting jobs;
      5. delete the main agent and check that both its group and the
         sleeptime agent are gone.
    """
    # 0. Refresh base tools
    server.tool_manager.upsert_base_tools(actor=actor)

    # 1. Create sleeptime agent
    main_agent = server.create_agent(
        request=CreateAgent(
            name="main_agent",
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a personal assistant that helps users with requests.",
                ),
                CreateBlock(
                    label="human",
                    value="",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
            enable_sleeptime=True,
        ),
        actor=actor,
    )

    # Tracks whether step 5 already deleted the agent, so the `finally` below
    # only cleans up when the test aborted before reaching that step.
    main_agent_deleted = False
    try:
        assert main_agent.enable_sleeptime is True

        # 2. Override frequency for test
        group = server.group_manager.modify_group(
            group_id=main_agent.multi_agent_group.id,
            group_update=GroupUpdate(
                manager_config=SleeptimeManager(
                    manager_agent_id=main_agent.id,
                    sleeptime_agent_frequency=2,
                ),
            ),
            actor=actor,
        )

        assert group.manager_type == ManagerType.sleeptime
        assert group.sleeptime_agent_frequency == 2
        assert len(group.agent_ids) == 1

        # 3. Verify shared blocks
        sleeptime_agent_id = group.agent_ids[0]
        shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
        agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
        assert len(agents) == 2
        block_agent_ids = [agent.id for agent in agents]
        assert sleeptime_agent_id in block_agent_ids
        assert main_agent.id in block_agent_ids

        # 4. Send messages and verify run ids
        message_text = [
            "my favorite color is orange",
            "not particularly. today is a good day",
            "actually my favorite color is coral",
            "sorry gotta run",
        ]
        run_ids = []
        for i, text in enumerate(message_text):
            response = await server.send_message_to_agent(
                agent_id=main_agent.id,
                actor=actor,
                messages=[
                    MessageCreate(
                        role="user",
                        content=text,
                    ),
                ],
                stream_steps=False,
                stream_tokens=False,
            )

            assert len(response.messages) > 0
            # With the frequency set to 2 above, exactly one run id is expected
            # on every second message (odd index) and none otherwise.
            assert len(response.usage.run_ids or []) == i % 2
            run_ids.extend(response.usage.run_ids or [])

        jobs = server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)
        runs = [Run.from_job(job) for job in jobs]
        agent_runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] == sleeptime_agent_id]
        assert len(agent_runs) == len(run_ids)

        for run_id in run_ids:
            job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
            assert job.status == JobStatus.completed

        # 5. Delete agent
        server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)
        main_agent_deleted = True

        with pytest.raises(NoResultFound):
            server.group_manager.retrieve_group(group_id=group.id, actor=actor)
        with pytest.raises(NoResultFound):
            server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
    finally:
        if not main_agent_deleted:
            # The test failed before step 5: delete the agent so it is not
            # left behind for later tests.
            server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)
|