mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Make pydantic serialized agent object (#1278)
Co-authored-by: Caren Thomas <caren@letta.com>
This commit is contained in:
parent
e43d635d18
commit
052c296ea1
@ -1 +1 @@
|
||||
from letta.serialize_schemas.agent import SerializedAgentSchema
|
||||
from letta.serialize_schemas.marshmallow_agent import MarshmallowAgentSchema
|
||||
|
@ -6,17 +6,17 @@ import letta
|
||||
from letta.orm import Agent
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||
from letta.server.db import SessionLocal
|
||||
|
||||
|
||||
class SerializedAgentSchema(BaseSchema):
|
||||
class MarshmallowAgentSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Agent objects.
|
||||
Excludes relational fields.
|
||||
@ -98,7 +98,6 @@ class SerializedAgentSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
"project_id",
|
||||
"template_id",
|
@ -2,7 +2,7 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedAgentEnvironmentVariableSchema(BaseSchema):
|
@ -1,6 +1,6 @@
|
||||
from letta.orm.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedBlockSchema(BaseSchema):
|
@ -4,8 +4,8 @@ from marshmallow import post_dump, pre_load
|
||||
|
||||
from letta.orm.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import ToolCallField
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_custom_fields import ToolCallField
|
||||
|
||||
|
||||
class SerializedMessageSchema(BaseSchema):
|
@ -3,7 +3,7 @@ from typing import Dict
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedAgentTagSchema(BaseSchema):
|
@ -1,6 +1,6 @@
|
||||
from letta.orm import Tool
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedToolSchema(BaseSchema):
|
110
letta/serialize_schemas/pydantic_agent_schema.py
Normal file
110
letta/serialize_schemas/pydantic_agent_schema.py
Normal file
@ -0,0 +1,110 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
class CoreMemoryBlockSchema(BaseModel):
|
||||
created_at: str
|
||||
description: Optional[str]
|
||||
identities: List[Any]
|
||||
is_deleted: bool
|
||||
is_template: bool
|
||||
label: str
|
||||
limit: int
|
||||
metadata_: Dict[str, Any] = Field(default_factory=dict)
|
||||
template_name: Optional[str]
|
||||
updated_at: str
|
||||
value: str
|
||||
|
||||
|
||||
class MessageSchema(BaseModel):
|
||||
created_at: str
|
||||
group_id: Optional[str]
|
||||
in_context: bool
|
||||
model: Optional[str]
|
||||
name: Optional[str]
|
||||
role: str
|
||||
text: str
|
||||
tool_call_id: Optional[str]
|
||||
tool_calls: List[Any]
|
||||
tool_returns: List[Any]
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TagSchema(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class ToolEnvVarSchema(BaseModel):
|
||||
created_at: str
|
||||
description: Optional[str]
|
||||
is_deleted: bool
|
||||
key: str
|
||||
updated_at: str
|
||||
value: str
|
||||
|
||||
|
||||
class ToolRuleSchema(BaseModel):
|
||||
tool_name: str
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterProperties(BaseModel):
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ParametersSchema(BaseModel):
|
||||
type: Optional[str] = "object"
|
||||
properties: Dict[str, ParameterProperties]
|
||||
required: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolJSONSchema(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: ParametersSchema # <— nested strong typing
|
||||
type: Optional[str] = None # top-level 'type' if it exists
|
||||
required: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolSchema(BaseModel):
|
||||
args_json_schema: Optional[Any]
|
||||
created_at: str
|
||||
description: str
|
||||
is_deleted: bool
|
||||
json_schema: ToolJSONSchema
|
||||
name: str
|
||||
return_char_limit: int
|
||||
source_code: Optional[str]
|
||||
source_type: str
|
||||
tags: List[str]
|
||||
tool_type: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class AgentSchema(BaseModel):
|
||||
agent_type: str
|
||||
core_memory: List[CoreMemoryBlockSchema]
|
||||
created_at: str
|
||||
description: str
|
||||
embedding_config: EmbeddingConfig
|
||||
groups: List[Any]
|
||||
identities: List[Any]
|
||||
is_deleted: bool
|
||||
llm_config: LLMConfig
|
||||
message_buffer_autoclear: bool
|
||||
messages: List[MessageSchema]
|
||||
metadata_: Dict
|
||||
multi_agent_group: Optional[Any]
|
||||
name: str
|
||||
system: str
|
||||
tags: List[TagSchema]
|
||||
tool_exec_environment_variables: List[ToolEnvVarSchema]
|
||||
tool_rules: List[ToolRuleSchema]
|
||||
tools: List[ToolSchema]
|
||||
updated_at: str
|
||||
version: str
|
@ -24,6 +24,7 @@ from letta.schemas.run import Run
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@ -92,26 +93,25 @@ def list_agents(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/download", operation_id="download_agent_serialized")
|
||||
def download_agent_serialized(
|
||||
@router.get("/{agent_id}/export", operation_id="export_agent_serialized", response_model=AgentSchema)
|
||||
def export_agent_serialized(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
) -> AgentSchema:
|
||||
"""
|
||||
Download the serialized JSON representation of an agent.
|
||||
Export the serialized JSON representation of an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
||||
return JSONResponse(content=serialized_agent, media_type="application/json")
|
||||
return server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AgentState, operation_id="upload_agent_serialized")
|
||||
async def upload_agent_serialized(
|
||||
@router.post("/import", response_model=AgentState, operation_id="import_agent_serialized")
|
||||
async def import_agent_serialized(
|
||||
file: UploadFile = File(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@ -123,15 +123,19 @@ async def upload_agent_serialized(
|
||||
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
|
||||
):
|
||||
"""
|
||||
Upload a serialized agent JSON file and recreate the agent in the system.
|
||||
Import a serialized agent file and recreate the agent in the system.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_data = await file.read()
|
||||
agent_json = json.loads(serialized_data)
|
||||
|
||||
# Validate the JSON against AgentSchema before passing it to deserialize
|
||||
agent_schema = AgentSchema.model_validate(agent_json)
|
||||
|
||||
new_agent = server.agent_manager.deserialize(
|
||||
serialized_agent=agent_json,
|
||||
serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
@ -143,7 +147,7 @@ async def upload_agent_serialized(
|
||||
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
|
||||
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}")
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e.errors()}")
|
||||
|
||||
except IntegrityError as e:
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}")
|
||||
@ -151,9 +155,9 @@ async def upload_agent_serialized(
|
||||
except OperationalError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {str(e)}")
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail="An unexpected error occurred while uploading the agent.")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
|
||||
|
@ -36,8 +36,9 @@ from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
|
||||
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.serialize_schemas import MarshmallowAgentSchema
|
||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_apply_filters,
|
||||
@ -464,26 +465,28 @@ class AgentManager:
|
||||
agent.hard_delete(session)
|
||||
|
||||
@enforce_types
|
||||
def serialize(self, agent_id: str, actor: PydanticUser) -> dict:
|
||||
def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
schema = SerializedAgentSchema(session=session, actor=actor)
|
||||
return schema.dump(agent)
|
||||
schema = MarshmallowAgentSchema(session=session, actor=actor)
|
||||
data = schema.dump(agent)
|
||||
return AgentSchema(**data)
|
||||
|
||||
@enforce_types
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_agent: dict,
|
||||
serialized_agent: AgentSchema,
|
||||
actor: PydanticUser,
|
||||
append_copy_suffix: bool = True,
|
||||
override_existing_tools: bool = True,
|
||||
project_id: Optional[str] = None,
|
||||
) -> PydanticAgentState:
|
||||
serialized_agent = serialized_agent.model_dump()
|
||||
tool_data_list = serialized_agent.pop("tools", [])
|
||||
|
||||
with self.session_maker() as session:
|
||||
schema = SerializedAgentSchema(session=session, actor=actor)
|
||||
schema = MarshmallowAgentSchema(session=session, actor=actor)
|
||||
agent = schema.load(serialized_agent, session=session)
|
||||
if append_copy_suffix:
|
||||
agent.name += "_copy"
|
||||
|
@ -21,6 +21,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@ -369,12 +370,12 @@ def test_deserialize_override_existing_tools(
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Extract tools before upload
|
||||
tool_data_list = result.get("tools", [])
|
||||
tool_names = {tool["name"]: tool for tool in tool_data_list}
|
||||
tool_data_list = result.tools
|
||||
tool_names = {tool.name: tool for tool in tool_data_list}
|
||||
|
||||
# Rewrite all the tool source code to the print_tool source code
|
||||
for tool in result["tools"]:
|
||||
tool["source_code"] = print_tool.source_code
|
||||
for tool in result.tools:
|
||||
tool.source_code = print_tool.source_code
|
||||
|
||||
# Deserialize the agent with different override settings
|
||||
server.agent_manager.deserialize(
|
||||
@ -466,8 +467,7 @@ def test_in_context_message_id_remapping(local_client, server, serialize_test_ag
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(result["message_ids"])
|
||||
assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
|
||||
assert len(in_context_messages) == len(serialize_test_agent.message_ids)
|
||||
|
||||
|
||||
# FastAPI endpoint tests
|
||||
@ -485,7 +485,9 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
response = fastapi_client.get(f"/v1/agents/{agent_id}/download", headers={"user_id": default_user.id})
|
||||
assert response.status_code == 200, f"Download failed: {response.text}"
|
||||
|
||||
agent_json = response.json()
|
||||
# Ensure response matches expected schema
|
||||
agent_schema = AgentSchema.model_validate(response.json()) # Validate as Pydantic model
|
||||
agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON
|
||||
|
||||
# Step 2: Upload the serialized agent as a copy
|
||||
agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
|
||||
@ -508,6 +510,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
# Step 3: Retrieve the copied agent
|
||||
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
|
||||
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user