mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
chore: Add agent serialization message_ids test (#1451)
This commit is contained in:
parent
81fdbf8dab
commit
0c14d94925
@ -251,6 +251,10 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
|
||||
if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
elif key == "tool_exec_environment_variables":
|
||||
if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
return False
|
||||
elif isinstance(v1, Dict) and isinstance(v2, Dict):
|
||||
if not _compare_agent_state_model_dump(v1, v2):
|
||||
_log_mismatch(key, v1, v2, log)
|
||||
@ -266,14 +270,64 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
|
||||
return True
|
||||
|
||||
|
||||
def compare_agent_state(original: AgentState, copy: AgentState, append_copy_suffix: bool) -> bool:
|
||||
def compare_agent_state(server, original: AgentState, copy: AgentState, append_copy_suffix: bool, og_user: User, copy_user: User) -> bool:
|
||||
"""Wrapper function that provides a default set of ignored prefix fields."""
|
||||
if not append_copy_suffix:
|
||||
assert original.name == copy.name
|
||||
|
||||
compare_in_context_message_id_remapping(server, original, copy, og_user, copy_user)
|
||||
|
||||
return _compare_agent_state_model_dump(original.model_dump(exclude="name"), copy.model_dump(exclude="name"))
|
||||
|
||||
|
||||
def compare_in_context_message_id_remapping(server, og_agent: AgentState, copy_agent: AgentState, og_user, copy_user):
|
||||
"""
|
||||
Test deserializing JSON into an Agent instance results in messages with
|
||||
remapped IDs but identical relevant content and order.
|
||||
"""
|
||||
# Serialize the original agent state
|
||||
result = server.agent_manager.serialize(agent_id=og_agent.id, actor=og_user)
|
||||
|
||||
# Retrieve the in-context messages for both the original and the copy
|
||||
# Corrected typo: agent_id instead of agent_id
|
||||
in_context_messages_og = server.agent_manager.get_in_context_messages(agent_id=og_agent.id, actor=og_user)
|
||||
in_context_messages_copy = server.agent_manager.get_in_context_messages(agent_id=copy_agent.id, actor=copy_user)
|
||||
|
||||
# 1. Check if the number of messages is the same
|
||||
assert len(in_context_messages_og) == len(
|
||||
in_context_messages_copy
|
||||
), f"Original message count ({len(in_context_messages_og)}) differs from copy ({len(in_context_messages_copy)})"
|
||||
|
||||
# 2. Iterate and compare messages by order, checking content equality and ID difference
|
||||
if not in_context_messages_og:
|
||||
# If there are no messages, the test passes trivially for message comparison.
|
||||
# Depending on the test case, you might want to assert that messages *should* exist.
|
||||
# pytest.fail("Expected messages to exist for comparison, but none were found.")
|
||||
pass # Or skip if empty lists are valid outcomes
|
||||
|
||||
for i, (msg_og, msg_copy) in enumerate(zip(in_context_messages_og, in_context_messages_copy)):
|
||||
# --- Assert ID Remapping ---
|
||||
assert msg_og.id != msg_copy.id, f"Message ID at index {i} was not remapped: {msg_og.id}"
|
||||
|
||||
# --- Assert Content Equivalence (excluding fields expected to change) ---
|
||||
# Fields defining the core message content/intent:
|
||||
assert msg_og.role == msg_copy.role, f"Mismatch in 'role' at index {i}"
|
||||
assert msg_og.content == msg_copy.content, f"Mismatch in 'content' at index {i}"
|
||||
assert msg_og.model == msg_copy.model, f"Mismatch in 'model' at index {i}"
|
||||
assert msg_og.name == msg_copy.name, f"Mismatch in 'name' at index {i}" # Name might be role-based
|
||||
assert msg_og.tool_calls == msg_copy.tool_calls, f"Mismatch in 'tool_calls' at index {i}"
|
||||
assert msg_og.tool_returns == msg_copy.tool_returns, f"Mismatch in 'tool_returns' at index {i}"
|
||||
# Add other fields here if they should be identical across copies
|
||||
|
||||
# --- Assert Context/Ownership Fields (verify they point to the *new* context) ---
|
||||
assert msg_copy.agent_id == copy_agent.id, f"Copied message at index {i} has wrong agent_id: {msg_copy.agent_id} != {copy_agent.id}"
|
||||
# Assuming organization_id should belong to the 'other_user' context if applicable
|
||||
# assert msg_copy.organization_id == other_user.organization_id # If relevant/expected
|
||||
|
||||
# --- Optionally Assert Original Context Fields (verify they point to the *old* context) ---
|
||||
assert msg_og.agent_id == og_agent.id, f"Original message at index {i} has wrong agent_id: {msg_og.agent_id} != {og_agent.id}"
|
||||
|
||||
|
||||
# Sanity tests for our agent model_dump verifier helpers
|
||||
|
||||
|
||||
@ -351,17 +405,12 @@ def test_append_copy_suffix_simple(local_client, server, serialize_test_agent, d
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# write file
|
||||
with open("test_agent_serialization.json", "w") as f:
|
||||
# write json
|
||||
f.write(json.dumps(result.model_dump(), indent=4))
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, append_copy_suffix=append_copy_suffix)
|
||||
|
||||
# Compare serialized representations to check for exact match
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
|
||||
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("override_existing_tools", [True, False])
|
||||
@ -416,7 +465,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
|
||||
|
||||
# Compare serialized representations to check for exact match
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
|
||||
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
|
||||
|
||||
# Make sure both agents can receive messages after
|
||||
server.send_messages(
|
||||
@ -445,7 +494,7 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
|
||||
|
||||
# Compare serialized representations to check for exact match
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
|
||||
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
|
||||
|
||||
# Make sure both agents can receive messages after
|
||||
original_agent_response = server.send_messages(
|
||||
@ -463,18 +512,6 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
|
||||
assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0
|
||||
|
||||
|
||||
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(serialize_test_agent.message_ids)
|
||||
|
||||
|
||||
# FastAPI endpoint tests
|
||||
|
||||
|
||||
@ -517,7 +554,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
|
||||
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
|
||||
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
|
||||
|
||||
# Step 4: Ensure copied agent receives messages correctly
|
||||
server.send_messages(
|
||||
|
Loading…
Reference in New Issue
Block a user