From 406a49ab602b9cfafcddf5c3db143f9b7a1e28ad Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 28 Mar 2025 11:50:25 -0700 Subject: [PATCH] feat: Change in context message remapping (#1448) Co-authored-by: Sarah Wooders Co-authored-by: Shubham Naik Co-authored-by: Shubham Naik Co-authored-by: Charles Packer --- letta/serialize_schemas/marshmallow_agent.py | 21 +- .../pydantic_agent_schema.py | 2 +- test_agent_serialization.json | 416 ++++++++++++++++++ tests/test_agent_serialization.py | 42 ++ 4 files changed, 471 insertions(+), 10 deletions(-) create mode 100644 test_agent_serialization.json diff --git a/letta/serialize_schemas/marshmallow_agent.py b/letta/serialize_schemas/marshmallow_agent.py index 309333abb..f38f19219 100644 --- a/letta/serialize_schemas/marshmallow_agent.py +++ b/letta/serialize_schemas/marshmallow_agent.py @@ -27,7 +27,7 @@ class MarshmallowAgentSchema(BaseSchema): FIELD_VERSION = "version" FIELD_MESSAGES = "messages" FIELD_MESSAGE_IDS = "message_ids" - FIELD_IN_CONTEXT = "in_context" + FIELD_IN_CONTEXT_INDICES = "in_context_message_indices" FIELD_ID = "id" llm_config = LLMConfigField() @@ -72,12 +72,11 @@ class MarshmallowAgentSchema(BaseSchema): messages = [] # loop through message in the *same* order is the in-context message IDs - for message in data.get(self.FIELD_MESSAGES, []): + data[self.FIELD_IN_CONTEXT_INDICES] = [] + for i, message in enumerate(data.get(self.FIELD_MESSAGES, [])): # if id matches in-context message ID, add to `messages` if message[self.FIELD_ID] in message_ids: - message[self.FIELD_IN_CONTEXT] = True - else: - message[self.FIELD_IN_CONTEXT] = False + data[self.FIELD_IN_CONTEXT_INDICES].append(i) messages.append(message) # remove ids @@ -111,13 +110,17 @@ class MarshmallowAgentSchema(BaseSchema): Restores `message_ids` by collecting message IDs where `in_context` is True, generates new IDs for all messages, and removes `in_context` from all messages. """ - message_ids = [] - for msg in data.get(self.FIELD_MESSAGES, []): + messages = data.get(self.FIELD_MESSAGES, []) + for msg in messages: msg[self.FIELD_ID] = SerializedMessageSchema.generate_id() # Generate new ID - if msg.pop(self.FIELD_IN_CONTEXT, False): # If it was in-context, track its new ID - message_ids.append(msg[self.FIELD_ID]) + + message_ids = [] + in_context_message_indices = data.pop(self.FIELD_IN_CONTEXT_INDICES) + for idx in in_context_message_indices: + message_ids.append(messages[idx][self.FIELD_ID]) data[self.FIELD_MESSAGE_IDS] = message_ids + return data class Meta(BaseSchema.Meta): diff --git a/letta/serialize_schemas/pydantic_agent_schema.py b/letta/serialize_schemas/pydantic_agent_schema.py index 8d11b7811..63a83765c 100644 --- a/letta/serialize_schemas/pydantic_agent_schema.py +++ b/letta/serialize_schemas/pydantic_agent_schema.py @@ -22,7 +22,6 @@ class CoreMemoryBlockSchema(BaseModel): class MessageSchema(BaseModel): created_at: str group_id: Optional[str] - in_context: bool model: Optional[str] name: Optional[str] role: str @@ -112,6 +111,7 @@ class AgentSchema(BaseModel): embedding_config: EmbeddingConfig llm_config: LLMConfig message_buffer_autoclear: bool + in_context_message_indices: List[int] messages: List[MessageSchema] metadata_: Optional[Dict] = None multi_agent_group: Optional[Any] diff --git a/test_agent_serialization.json b/test_agent_serialization.json new file mode 100644 index 000000000..c59d5c1be --- /dev/null +++ b/test_agent_serialization.json @@ -0,0 +1,416 @@ +{ + "agent_type": "memgpt_agent", + "core_memory": [ + { + "created_at": "2025-03-28T01:11:04.570593+00:00", + "description": "A default test block", + "is_template": false, + "label": "default_label", + "limit": 1000, + "metadata_": { + "type": "test" + }, + "template_name": null, + "updated_at": "2025-03-28T01:11:04.570593+00:00", + "value": "Default Block Content" + }, + { + "created_at": "2025-03-28T01:11:04.609286+00:00", + "description": null, + "is_template": false, + "label": "human", + "limit": 5000, + "metadata_": {}, + "template_name": null, + "updated_at": "2025-03-28T01:11:04.609286+00:00", + "value": "BananaBoy" + }, + { + "created_at": "2025-03-28T01:11:04.612946+00:00", + "description": null, + "is_template": false, + "label": "persona", + "limit": 5000, + "metadata_": {}, + "template_name": null, + "updated_at": "2025-03-28T01:11:04.612946+00:00", + "value": "I am a helpful assistant" + } + ], + "created_at": "2025-03-28T01:11:04.624794+00:00", + "description": "test_description", + "embedding_config": { + "embedding_endpoint_type": "openai", + "embedding_endpoint": "https://api.openai.com/v1", + "embedding_model": "text-embedding-ada-002", + "embedding_dim": 1536, + "embedding_chunk_size": 300, + "handle": null, + "azure_endpoint": null, + "azure_version": null, + "azure_deployment": null + }, + "llm_config": { + "model": "gpt-4o-mini", + "model_endpoint_type": "openai", + "model_endpoint": "https://api.openai.com/v1", + "model_wrapper": null, + "context_window": 128000, + "put_inner_thoughts_in_kwargs": true, + "handle": null, + "temperature": 0.7, + "max_tokens": 4096, + "enable_reasoner": false, + "max_reasoning_tokens": 0 + }, + "message_buffer_autoclear": true, + "in_context_message_indices": [0, 1], + "messages": [ + { + "created_at": "2025-03-28T01:11:04.654912+00:00", + "group_id": null, + "model": "gpt-4o-mini", + "name": null, + "role": "system", + "content": [ + { + "type": "text", + "text": "test system\n### Memory [last modified: 2025-03-27 06:11:04 PM PDT-0700]\n0 previous messages between you and the user are stored in recall memory (use functions to access them)\n0 total memories you created are stored in archival memory (use functions to access them)\n\n\nCore memory shown below (limited in size, additional information stored in archival / recall memory):\n\nDefault Block Content\n\n\nBananaBoy\n\n\nI am a helpful assistant\n" + } + ], + "tool_call_id": null, + "tool_calls": [], + "tool_returns": [], + "updated_at": "2025-03-28T01:11:04.654783+00:00" + }, + { + "created_at": "2025-03-28T01:11:04.654966+00:00", + "group_id": null, + "model": "gpt-4o-mini", + "name": null, + "role": "user", + "content": [ + { + "type": "text", + "text": "{\n \"type\": \"user_message\",\n \"message\": \"hello world\",\n \"time\": \"2025-03-27 06:11:04 PM PDT-0700\"\n}" + } + ], + "tool_call_id": null, + "tool_calls": [], + "tool_returns": [], + "updated_at": "2025-03-28T01:11:04.654783+00:00" + } + ], + "metadata_": { + "test_key": "test_value" + }, + "multi_agent_group": null, + "name": "EffervescentYacht", + "system": "test system", + "tags": [ + { + "tag": "a" + }, + { + "tag": "b" + } + ], + "tool_exec_environment_variables": [ + { + "created_at": "2025-03-28T01:11:04.638338+00:00", + "description": null, + "key": "test_env_var_key_a", + "updated_at": "2025-03-28T01:11:04.638338+00:00", + "value": "" + }, + { + "created_at": "2025-03-28T01:11:04.638338+00:00", + "description": null, + "key": "test_env_var_key_b", + "updated_at": "2025-03-28T01:11:04.638338+00:00", + "value": "" + } + ], + "tool_rules": [ + { + "tool_name": "archival_memory_search", + "type": "continue_loop" + }, + { + "tool_name": "archival_memory_insert", + "type": "continue_loop" + }, + { + "tool_name": "send_message", + "type": "exit_loop" + }, + { + "tool_name": "conversation_search", + "type": "continue_loop" + } + ], + "tools": [ + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.575001+00:00", + "description": "Fetches the current weather for a given location.", + "json_schema": { + "name": "get_weather", + "description": "Fetches the current weather for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": ["location", "request_heartbeat"] + }, + "type": null, + "required": [] + }, + "name": "get_weather", + "return_char_limit": 6000, + "source_code": "def get_weather(location: str) -> str:\n \"\"\"\n Fetches the current weather for a given location.\n\n Parameters:\n location (str): The location to get the weather for.\n\n Returns:\n str: A formatted string describing the weather in the given location.\n\n Raises:\n RuntimeError: If the request to fetch weather data fails.\n \"\"\"\n import requests\n\n url = f\"https://wttr.in/{location}?format=%C+%t\"\n\n response = requests.get(url)\n if response.status_code == 200:\n weather_data = response.text\n return f\"The weather in {location} is {weather_data}.\"\n else:\n raise RuntimeError(f\"Failed to get weather data, status code: {response.status_code}\")\n", + "source_type": "python", + "tags": [], + "tool_type": "custom", + "updated_at": "2025-03-28T01:11:04.575001+00:00", + "metadata_": {} + }, + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.579856+00:00", + "description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.", + "json_schema": { + "name": "archival_memory_insert", + "description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.", + "parameters": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Content to write to the memory. All unicode (including emojis) are supported." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": ["content", "request_heartbeat"] + }, + "type": null, + "required": [] + }, + "name": "archival_memory_insert", + "return_char_limit": 1000000, + "source_code": null, + "source_type": "python", + "tags": ["letta_core"], + "tool_type": "letta_core", + "updated_at": "2025-03-28T01:11:04.579856+00:00", + "metadata_": {} + }, + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.583369+00:00", + "description": "Search archival memory using semantic (embedding-based) search.", + "json_schema": { + "name": "archival_memory_search", + "description": "Search archival memory using semantic (embedding-based) search.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "String to search for." + }, + "page": { + "type": "integer", + "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page)." + }, + "start": { + "type": "integer", + "description": "Starting index for the search results. Defaults to 0." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": ["query", "request_heartbeat"] + }, + "type": null, + "required": [] + }, + "name": "archival_memory_search", + "return_char_limit": 1000000, + "source_code": null, + "source_type": "python", + "tags": ["letta_core"], + "tool_type": "letta_core", + "updated_at": "2025-03-28T01:11:04.583369+00:00", + "metadata_": {} + }, + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.586573+00:00", + "description": "Search prior conversation history using case-insensitive string matching.", + "json_schema": { + "name": "conversation_search", + "description": "Search prior conversation history using case-insensitive string matching.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "String to search for." + }, + "page": { + "type": "integer", + "description": "Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page)." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": ["query", "request_heartbeat"] + }, + "type": null, + "required": [] + }, + "name": "conversation_search", + "return_char_limit": 1000000, + "source_code": null, + "source_type": "python", + "tags": ["letta_core"], + "tool_type": "letta_core", + "updated_at": "2025-03-28T01:11:04.586573+00:00", + "metadata_": {} + }, + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.589876+00:00", + "description": "Append to the contents of core memory.", + "json_schema": { + "name": "core_memory_append", + "description": "Append to the contents of core memory.", + "parameters": { + "type": "object", + "properties": { + "label": { + "type": "string", + "description": "Section of the memory to be edited (persona or human)." + }, + "content": { + "type": "string", + "description": "Content to write to the memory. All unicode (including emojis) are supported." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": ["label", "content", "request_heartbeat"] + }, + "type": null, + "required": [] + }, + "name": "core_memory_append", + "return_char_limit": 1000000, + "source_code": null, + "source_type": "python", + "tags": ["letta_memory_core"], + "tool_type": "letta_memory_core", + "updated_at": "2025-03-28T01:11:04.589876+00:00", + "metadata_": {} + }, + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.593153+00:00", + "description": "Replace the contents of core memory. To delete memories, use an empty string for new_content.", + "json_schema": { + "name": "core_memory_replace", + "description": "Replace the contents of core memory. To delete memories, use an empty string for new_content.", + "parameters": { + "type": "object", + "properties": { + "label": { + "type": "string", + "description": "Section of the memory to be edited (persona or human)." + }, + "old_content": { + "type": "string", + "description": "String to replace. Must be an exact match." + }, + "new_content": { + "type": "string", + "description": "Content to write to the memory. All unicode (including emojis) are supported." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": [ + "label", + "old_content", + "new_content", + "request_heartbeat" + ] + }, + "type": null, + "required": [] + }, + "name": "core_memory_replace", + "return_char_limit": 1000000, + "source_code": null, + "source_type": "python", + "tags": ["letta_memory_core"], + "tool_type": "letta_memory_core", + "updated_at": "2025-03-28T01:11:04.593153+00:00", + "metadata_": {} + }, + { + "args_json_schema": null, + "created_at": "2025-03-28T01:11:04.596458+00:00", + "description": "Sends a message to the human user.", + "json_schema": { + "name": "send_message", + "description": "Sends a message to the human user.", + "parameters": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Message contents. All unicode (including emojis) are supported." + }, + "request_heartbeat": { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." + } + }, + "required": ["message", "request_heartbeat"] + }, + "type": null, + "required": [] + }, + "name": "send_message", + "return_char_limit": 1000000, + "source_code": null, + "source_type": "python", + "tags": ["letta_core"], + "tool_type": "letta_core", + "updated_at": "2025-03-28T01:11:04.596458+00:00", + "metadata_": {} + } + ], + "updated_at": "2025-03-28T01:11:04.680766+00:00", + "version": "0.6.45" +} diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 738091f00..d73bffb17 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -512,6 +512,48 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server, assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0 +def test_agent_serialize_update_blocks(mock_e2b_api_key_none, local_client, server, serialize_test_agent, default_user, other_user): + """Test deserializing JSON into an Agent instance.""" + append_copy_suffix = False + server.send_messages( + actor=default_user, + agent_id=serialize_test_agent.id, + messages=[MessageCreate(role=MessageRole.user, content="Append 'banana' to core_memory.")], + ) + server.send_messages( + actor=default_user, + agent_id=serialize_test_agent.id, + messages=[MessageCreate(role=MessageRole.user, content="What do you think about that?")], + ) + + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) + + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, append_copy_suffix=append_copy_suffix) + + # Get most recent original agent instance + serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user) + + # Compare serialized representations to check for exact match + print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) + assert compare_agent_state(server, serialize_test_agent, agent_copy, append_copy_suffix, default_user, other_user) + + # Make sure both agents can receive messages after + original_agent_response = server.send_messages( + actor=default_user, + agent_id=serialize_test_agent.id, + messages=[MessageCreate(role=MessageRole.user, content="Hi")], + ) + copy_agent_response = server.send_messages( + actor=other_user, + agent_id=agent_copy.id, + messages=[MessageCreate(role=MessageRole.user, content="Hi")], + ) + + assert original_agent_response.completion_tokens > 0 and original_agent_response.step_count > 0 + assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0 + + # FastAPI endpoint tests