From 052c296ea17fefda1d1db434533aeba0010e15e3 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 13 Mar 2025 17:50:19 -0700 Subject: [PATCH] feat: Make pydantic serialized agent object (#1278) Co-authored-by: Caren Thomas --- letta/serialize_schemas/__init__.py | 2 +- .../{agent.py => marshmallow_agent.py} | 17 ++- ...marshmallow_agent_environment_variable.py} | 2 +- .../{base.py => marshmallow_base.py} | 0 .../{block.py => marshmallow_block.py} | 2 +- ...fields.py => marshmallow_custom_fields.py} | 0 .../{message.py => marshmallow_message.py} | 4 +- .../{tag.py => marshmallow_tag.py} | 2 +- .../{tool.py => marshmallow_tool.py} | 2 +- .../pydantic_agent_schema.py | 110 ++++++++++++++++++ letta/server/rest_api/routers/v1/agents.py | 30 ++--- letta/services/agent_manager.py | 17 +-- tests/test_agent_serialization.py | 17 +-- 13 files changed, 162 insertions(+), 43 deletions(-) rename letta/serialize_schemas/{agent.py => marshmallow_agent.py} (85%) rename letta/serialize_schemas/{agent_environment_variable.py => marshmallow_agent_environment_variable.py} (90%) rename letta/serialize_schemas/{base.py => marshmallow_base.py} (100%) rename letta/serialize_schemas/{block.py => marshmallow_block.py} (85%) rename letta/serialize_schemas/{custom_fields.py => marshmallow_custom_fields.py} (100%) rename letta/serialize_schemas/{message.py => marshmallow_message.py} (89%) rename letta/serialize_schemas/{tag.py => marshmallow_tag.py} (90%) rename letta/serialize_schemas/{tool.py => marshmallow_tool.py} (84%) create mode 100644 letta/serialize_schemas/pydantic_agent_schema.py diff --git a/letta/serialize_schemas/__init__.py b/letta/serialize_schemas/__init__.py index d0e09d6d8..1f6be200d 100644 --- a/letta/serialize_schemas/__init__.py +++ b/letta/serialize_schemas/__init__.py @@ -1 +1 @@ -from letta.serialize_schemas.agent import SerializedAgentSchema +from letta.serialize_schemas.marshmallow_agent import MarshmallowAgentSchema diff --git a/letta/serialize_schemas/agent.py b/letta/serialize_schemas/marshmallow_agent.py similarity index 85% rename from letta/serialize_schemas/agent.py rename to letta/serialize_schemas/marshmallow_agent.py index 9a8b9e5f4..182fb559c 100644 --- a/letta/serialize_schemas/agent.py +++ b/letta/serialize_schemas/marshmallow_agent.py @@ -6,17 +6,17 @@ import letta from letta.orm import Agent from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.user import User -from letta.serialize_schemas.agent_environment_variable import SerializedAgentEnvironmentVariableSchema -from letta.serialize_schemas.base import BaseSchema -from letta.serialize_schemas.block import SerializedBlockSchema -from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField -from letta.serialize_schemas.message import SerializedMessageSchema -from letta.serialize_schemas.tag import SerializedAgentTagSchema -from letta.serialize_schemas.tool import SerializedToolSchema +from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema +from letta.serialize_schemas.marshmallow_base import BaseSchema +from letta.serialize_schemas.marshmallow_block import SerializedBlockSchema +from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField +from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema +from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema +from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema from letta.server.db import SessionLocal -class SerializedAgentSchema(BaseSchema): +class MarshmallowAgentSchema(BaseSchema): """ Marshmallow schema for serializing/deserializing Agent objects. Excludes relational fields. @@ -98,7 +98,6 @@ class SerializedAgentSchema(BaseSchema): class Meta(BaseSchema.Meta): model = Agent - # TODO: Serialize these as well... exclude = BaseSchema.Meta.exclude + ( "project_id", "template_id", diff --git a/letta/serialize_schemas/agent_environment_variable.py b/letta/serialize_schemas/marshmallow_agent_environment_variable.py similarity index 90% rename from letta/serialize_schemas/agent_environment_variable.py rename to letta/serialize_schemas/marshmallow_agent_environment_variable.py index f8c606f49..371614a89 100644 --- a/letta/serialize_schemas/agent_environment_variable.py +++ b/letta/serialize_schemas/marshmallow_agent_environment_variable.py @@ -2,7 +2,7 @@ import uuid from typing import Optional from letta.orm.sandbox_config import AgentEnvironmentVariable -from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.marshmallow_base import BaseSchema class SerializedAgentEnvironmentVariableSchema(BaseSchema): diff --git a/letta/serialize_schemas/base.py b/letta/serialize_schemas/marshmallow_base.py similarity index 100% rename from letta/serialize_schemas/base.py rename to letta/serialize_schemas/marshmallow_base.py diff --git a/letta/serialize_schemas/block.py b/letta/serialize_schemas/marshmallow_block.py similarity index 85% rename from letta/serialize_schemas/block.py rename to letta/serialize_schemas/marshmallow_block.py index 411391211..6432f6423 100644 --- a/letta/serialize_schemas/block.py +++ b/letta/serialize_schemas/marshmallow_block.py @@ -1,6 +1,6 @@ from letta.orm.block import Block from letta.schemas.block import Block as PydanticBlock -from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.marshmallow_base import BaseSchema class SerializedBlockSchema(BaseSchema): diff --git a/letta/serialize_schemas/custom_fields.py b/letta/serialize_schemas/marshmallow_custom_fields.py similarity index 100% rename from letta/serialize_schemas/custom_fields.py rename to letta/serialize_schemas/marshmallow_custom_fields.py diff --git a/letta/serialize_schemas/message.py b/letta/serialize_schemas/marshmallow_message.py similarity index 89% rename from letta/serialize_schemas/message.py rename to letta/serialize_schemas/marshmallow_message.py index 187b8f88c..664947f84 100644 --- a/letta/serialize_schemas/message.py +++ b/letta/serialize_schemas/marshmallow_message.py @@ -4,8 +4,8 @@ from marshmallow import post_dump, pre_load from letta.orm.message import Message from letta.schemas.message import Message as PydanticMessage -from letta.serialize_schemas.base import BaseSchema -from letta.serialize_schemas.custom_fields import ToolCallField +from letta.serialize_schemas.marshmallow_base import BaseSchema +from letta.serialize_schemas.marshmallow_custom_fields import ToolCallField class SerializedMessageSchema(BaseSchema): diff --git a/letta/serialize_schemas/tag.py b/letta/serialize_schemas/marshmallow_tag.py similarity index 90% rename from letta/serialize_schemas/tag.py rename to letta/serialize_schemas/marshmallow_tag.py index 38c5e97cd..be19b90cc 100644 --- a/letta/serialize_schemas/tag.py +++ b/letta/serialize_schemas/marshmallow_tag.py @@ -3,7 +3,7 @@ from typing import Dict from marshmallow import fields, post_dump, pre_load from letta.orm.agents_tags import AgentsTags -from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.marshmallow_base import BaseSchema class SerializedAgentTagSchema(BaseSchema): diff --git a/letta/serialize_schemas/tool.py b/letta/serialize_schemas/marshmallow_tool.py similarity index 84% rename from letta/serialize_schemas/tool.py rename to letta/serialize_schemas/marshmallow_tool.py index fe2debe85..3ae65ceb9 100644 --- a/letta/serialize_schemas/tool.py +++ b/letta/serialize_schemas/marshmallow_tool.py @@ -1,6 +1,6 @@ from letta.orm import Tool from letta.schemas.tool import Tool as PydanticTool -from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.marshmallow_base import BaseSchema class SerializedToolSchema(BaseSchema): diff --git a/letta/serialize_schemas/pydantic_agent_schema.py b/letta/serialize_schemas/pydantic_agent_schema.py new file mode 100644 index 000000000..510871254 --- /dev/null +++ b/letta/serialize_schemas/pydantic_agent_schema.py @@ -0,0 +1,110 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig + + +class CoreMemoryBlockSchema(BaseModel): + created_at: str + description: Optional[str] + identities: List[Any] + is_deleted: bool + is_template: bool + label: str + limit: int + metadata_: Dict[str, Any] = Field(default_factory=dict) + template_name: Optional[str] + updated_at: str + value: str + + +class MessageSchema(BaseModel): + created_at: str + group_id: Optional[str] + in_context: bool + model: Optional[str] + name: Optional[str] + role: str + text: str + tool_call_id: Optional[str] + tool_calls: List[Any] + tool_returns: List[Any] + updated_at: str + + +class TagSchema(BaseModel): + tag: str + + +class ToolEnvVarSchema(BaseModel): + created_at: str + description: Optional[str] + is_deleted: bool + key: str + updated_at: str + value: str + + +class ToolRuleSchema(BaseModel): + tool_name: str + type: str + + +class ParameterProperties(BaseModel): + type: str + description: Optional[str] = None + + +class ParametersSchema(BaseModel): + type: Optional[str] = "object" + properties: Dict[str, ParameterProperties] + required: List[str] = Field(default_factory=list) + + +class ToolJSONSchema(BaseModel): + name: str + description: str + parameters: ParametersSchema # <— nested strong typing + type: Optional[str] = None # top-level 'type' if it exists + required: Optional[List[str]] = Field(default_factory=list) + + +class ToolSchema(BaseModel): + args_json_schema: Optional[Any] + created_at: str + description: str + is_deleted: bool + json_schema: ToolJSONSchema + name: str + return_char_limit: int + source_code: Optional[str] + source_type: str + tags: List[str] + tool_type: str + updated_at: str + + +class AgentSchema(BaseModel): + agent_type: str + core_memory: List[CoreMemoryBlockSchema] + created_at: str + description: str + embedding_config: EmbeddingConfig + groups: List[Any] + identities: List[Any] + is_deleted: bool + llm_config: LLMConfig + message_buffer_autoclear: bool + messages: List[MessageSchema] + metadata_: Dict + multi_agent_group: Optional[Any] + name: str + system: str + tags: List[TagSchema] + tool_exec_environment_variables: List[ToolEnvVarSchema] + tool_rules: List[ToolRuleSchema] + tools: List[ToolSchema] + updated_at: str + version: str diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 247dbe225..7233024b5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -24,6 +24,7 @@ from letta.schemas.run import Run from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.user import User +from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -92,26 +93,25 @@ def list_agents( ) -@router.get("/{agent_id}/download", operation_id="download_agent_serialized") -def download_agent_serialized( +@router.get("/{agent_id}/export", operation_id="export_agent_serialized", response_model=AgentSchema) +def export_agent_serialized( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), -): +) -> AgentSchema: """ - Download the serialized JSON representation of an agent. + Export the serialized JSON representation of an agent. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - serialized_agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor) - return JSONResponse(content=serialized_agent, media_type="application/json") + return server.agent_manager.serialize(agent_id=agent_id, actor=actor) except NoResultFound: raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.") -@router.post("/upload", response_model=AgentState, operation_id="upload_agent_serialized") -async def upload_agent_serialized( +@router.post("/import", response_model=AgentState, operation_id="import_agent_serialized") +async def import_agent_serialized( file: UploadFile = File(...), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), @@ -123,15 +123,19 @@ async def upload_agent_serialized( project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."), ): """ - Upload a serialized agent JSON file and recreate the agent in the system. + Import a serialized agent file and recreate the agent in the system. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: serialized_data = await file.read() agent_json = json.loads(serialized_data) + + # Validate the JSON against AgentSchema before passing it to deserialize + agent_schema = AgentSchema.model_validate(agent_json) + new_agent = server.agent_manager.deserialize( - serialized_agent=agent_json, + serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools, @@ -143,7 +147,7 @@ async def upload_agent_serialized( raise HTTPException(status_code=400, detail="Corrupted agent file format.") except ValidationError as e: - raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}") + raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e.errors()}") except IntegrityError as e: raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}") @@ -151,9 +155,9 @@ async def upload_agent_serialized( except OperationalError as e: raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {str(e)}") - except Exception: + except Exception as e: traceback.print_exc() - raise HTTPException(status_code=500, detail="An unexpected error occurred while uploading the agent.") + raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {str(e)}") @router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 8d2191797..51a35b141 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -36,8 +36,9 @@ from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser -from letta.serialize_schemas import SerializedAgentSchema -from letta.serialize_schemas.tool import SerializedToolSchema +from letta.serialize_schemas import MarshmallowAgentSchema +from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema +from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( _apply_filters, @@ -464,26 +465,28 @@ class AgentManager: agent.hard_delete(session) @enforce_types - def serialize(self, agent_id: str, actor: PydanticUser) -> dict: + def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema: with self.session_maker() as session: # Retrieve the agent agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - schema = SerializedAgentSchema(session=session, actor=actor) - return schema.dump(agent) + schema = MarshmallowAgentSchema(session=session, actor=actor) + data = schema.dump(agent) + return AgentSchema(**data) @enforce_types def deserialize( self, - serialized_agent: dict, + serialized_agent: AgentSchema, actor: PydanticUser, append_copy_suffix: bool = True, override_existing_tools: bool = True, project_id: Optional[str] = None, ) -> PydanticAgentState: + serialized_agent = serialized_agent.model_dump() tool_data_list = serialized_agent.pop("tools", []) with self.session_maker() as session: - schema = SerializedAgentSchema(session=session, actor=actor) + schema = MarshmallowAgentSchema(session=session, actor=actor) agent = schema.load(serialized_agent, session=session) if append_copy_suffix: agent.name += "_copy" diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index 9ed1a4bc3..0fe873291 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -21,6 +21,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate from letta.schemas.organization import Organization from letta.schemas.user import User +from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.app import app from letta.server.server import SyncServer @@ -369,12 +370,12 @@ def test_deserialize_override_existing_tools( result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) # Extract tools before upload - tool_data_list = result.get("tools", []) - tool_names = {tool["name"]: tool for tool in tool_data_list} + tool_data_list = result.tools + tool_names = {tool.name: tool for tool in tool_data_list} # Rewrite all the tool source code to the print_tool source code - for tool in result["tools"]: - tool["source_code"] = print_tool.source_code + for tool in result.tools: + tool.source_code = print_tool.source_code # Deserialize the agent with different override settings server.agent_manager.deserialize( @@ -466,8 +467,7 @@ def test_in_context_message_id_remapping(local_client, server, serialize_test_ag # Make sure all the messages are able to be retrieved in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user) - assert len(in_context_messages) == len(result["message_ids"]) - assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"]) + assert len(in_context_messages) == len(serialize_test_agent.message_ids) # FastAPI endpoint tests @@ -485,7 +485,9 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent response = fastapi_client.get(f"/v1/agents/{agent_id}/download", headers={"user_id": default_user.id}) assert response.status_code == 200, f"Download failed: {response.text}" - agent_json = response.json() + # Ensure response matches expected schema + agent_schema = AgentSchema.model_validate(response.json()) # Validate as Pydantic model + agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON # Step 2: Upload the serialized agent as a copy agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8")) @@ -508,6 +510,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent # Step 3: Retrieve the copied agent serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user) agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user) + print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)