chore: Add agent serialization message_ids test (#1451)

This commit is contained in:
Matthew Zhou 2025-03-28 10:31:02 -07:00 committed by GitHub
parent 81fdbf8dab
commit 0c14d94925

View File

@ -251,6 +251,10 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2):
_log_mismatch(key, v1, v2, log)
return False
elif key == "tool_exec_environment_variables":
if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2):
_log_mismatch(key, v1, v2, log)
return False
elif isinstance(v1, Dict) and isinstance(v2, Dict):
if not _compare_agent_state_model_dump(v1, v2):
_log_mismatch(key, v1, v2, log)
@ -266,14 +270,64 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
return True
def compare_agent_state(server, original: AgentState, copy: AgentState, append_copy_suffix: bool, og_user: User, copy_user: User) -> bool:
    """Compare an original agent with its deserialized copy.

    Runs three checks in order:

    1. When ``append_copy_suffix`` is false, the two agents must have the
       same ``name`` (asserted directly).
    2. The in-context messages of both agents are compared via
       ``compare_in_context_message_id_remapping`` (which asserts on failure).
    3. The ``model_dump`` of both agents, with ``name`` excluded, is compared
       via ``_compare_agent_state_model_dump``.

    Args:
        server: Server object exposing ``agent_manager``.
        original: The agent that was serialized.
        copy: The agent produced by deserialization.
        append_copy_suffix: Whether deserialization was asked to rename the
            copy; when true the names are not required to match.
        og_user: Actor that owns ``original``.
        copy_user: Actor that owns ``copy``.

    Returns:
        The result of ``_compare_agent_state_model_dump`` on the two dumps.

    Raises:
        AssertionError: If the name check or any message-level check fails.
    """
    if not append_copy_suffix:
        assert original.name == copy.name

    compare_in_context_message_id_remapping(server, original, copy, og_user, copy_user)

    # `name` is handled by the explicit assertion above, so it is left out of
    # the dump comparison. `exclude` is passed as a set, which is the form
    # pydantic documents for `model_dump`, rather than as a bare string.
    excluded_fields = {"name"}
    return _compare_agent_state_model_dump(
        original.model_dump(exclude=excluded_fields),
        copy.model_dump(exclude=excluded_fields),
    )
def compare_in_context_message_id_remapping(server, og_agent: AgentState, copy_agent: AgentState, og_user, copy_user):
    """
    Assert that a copied agent's in-context messages mirror the original's.

    The in-context messages of both agents are fetched through
    ``server.agent_manager.get_in_context_messages`` and compared pairwise,
    in order. For every pair the copy must have a *different* message ID but
    identical ``role``, ``content``, ``model``, ``name``, ``tool_calls`` and
    ``tool_returns``; each message must also carry the ``agent_id`` of the
    agent it was fetched for.

    Two empty message lists pass trivially.

    Args:
        server: Server object exposing ``agent_manager``.
        og_agent: The original agent.
        copy_agent: The agent created by deserializing ``og_agent``.
        og_user: Actor used to read ``og_agent``'s messages.
        copy_user: Actor used to read ``copy_agent``'s messages.

    Raises:
        AssertionError: On the first mismatch found.
    """
    # Retrieve the in-context messages for both the original and the copy.
    in_context_messages_og = server.agent_manager.get_in_context_messages(agent_id=og_agent.id, actor=og_user)
    in_context_messages_copy = server.agent_manager.get_in_context_messages(agent_id=copy_agent.id, actor=copy_user)

    # 1. Both agents must hold the same number of in-context messages. This
    #    also guarantees the zip() below does not silently drop a tail.
    assert len(in_context_messages_og) == len(
        in_context_messages_copy
    ), f"Original message count ({len(in_context_messages_og)}) differs from copy ({len(in_context_messages_copy)})"

    # 2. Compare messages position by position.
    for i, (msg_og, msg_copy) in enumerate(zip(in_context_messages_og, in_context_messages_copy)):
        # --- ID remapping: the copy must not reuse the original's message ID ---
        assert msg_og.id != msg_copy.id, f"Message ID at index {i} was not remapped: {msg_og.id}"

        # --- Content equivalence (fields that must survive the round trip) ---
        assert msg_og.role == msg_copy.role, f"Mismatch in 'role' at index {i}"
        assert msg_og.content == msg_copy.content, f"Mismatch in 'content' at index {i}"
        assert msg_og.model == msg_copy.model, f"Mismatch in 'model' at index {i}"
        assert msg_og.name == msg_copy.name, f"Mismatch in 'name' at index {i}"
        assert msg_og.tool_calls == msg_copy.tool_calls, f"Mismatch in 'tool_calls' at index {i}"
        assert msg_og.tool_returns == msg_copy.tool_returns, f"Mismatch in 'tool_returns' at index {i}"

        # --- Ownership: each message must point at the agent it was read from ---
        assert msg_copy.agent_id == copy_agent.id, f"Copied message at index {i} has wrong agent_id: {msg_copy.agent_id} != {copy_agent.id}"
        assert msg_og.agent_id == og_agent.id, f"Original message at index {i} has wrong agent_id: {msg_og.agent_id} != {og_agent.id}"
        # TODO(review): if copied messages are expected to carry copy_user's
        # organization_id, assert that here as well.
# Sanity tests for our agent model_dump verifier helpers
@ -351,17 +405,12 @@ def test_append_copy_suffix_simple(local_client, server, serialize_test_agent, d
"""Test deserializing JSON into an Agent instance."""
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
# write file
with open("test_agent_serialization.json", "w") as f:
# write json
f.write(json.dumps(result.model_dump(), indent=4))
# Deserialize the agent
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, append_copy_suffix=append_copy_suffix)
# Compare serialized representations to check for exact match
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
@pytest.mark.parametrize("override_existing_tools", [True, False])
@ -416,7 +465,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
# Compare serialized representations to check for exact match
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
# Make sure both agents can receive messages after
server.send_messages(
@ -445,7 +494,7 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
# Compare serialized representations to check for exact match
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
# Make sure both agents can receive messages after
original_agent_response = server.send_messages(
@ -463,18 +512,6 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
    """Check that a deserialized copy exposes every in-context message.

    The agent is serialized as ``default_user`` and deserialized as
    ``other_user``; the copy must then return as many in-context messages as
    the original agent has ``message_ids``.
    """
    manager = server.agent_manager

    # Round-trip the agent: serialize under one actor, deserialize under another.
    serialized = manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
    copied_agent = manager.deserialize(serialized_agent=serialized, actor=other_user)

    # Every message ID on the original must map to a retrievable message on the copy.
    copied_messages = manager.get_in_context_messages(agent_id=copied_agent.id, actor=other_user)
    expected_count = len(serialize_test_agent.message_ids)
    assert len(copied_messages) == expected_count
# FastAPI endpoint tests
@ -517,7 +554,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user)
# Step 4: Ensure copied agent receives messages correctly
server.send_messages(