mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
371 lines
15 KiB
Python
371 lines
15 KiB
Python
import json
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Annotated, List, Literal, Optional, Union
|
|
|
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
|
|
from letta.schemas.letta_message_content import (
|
|
LettaAssistantMessageContentUnion,
|
|
LettaUserMessageContentUnion,
|
|
get_letta_assistant_message_content_union_str_json_schema,
|
|
get_letta_user_message_content_union_str_json_schema,
|
|
)
|
|
|
|
# ---------------------------
|
|
# Letta API Messaging Schemas
|
|
# ---------------------------
|
|
|
|
|
|
class MessageType(str, Enum):
    """
    Discriminator values for the Letta API message models.

    The enum mixes in ``str`` and every member's value equals its own name,
    so a member compares equal to the corresponding plain string.
    """

    # Messages that carry conversation content.
    system_message = "system_message"
    user_message = "user_message"
    assistant_message = "assistant_message"

    # Agent reasoning, either visible or hidden from the response.
    reasoning_message = "reasoning_message"
    hidden_reasoning_message = "hidden_reasoning_message"

    # Tool invocation request and its result.
    tool_call_message = "tool_call_message"
    tool_return_message = "tool_return_message"
|
|
|
|
|
|
class LettaMessage(BaseModel):
    """
    Base class for simplified Letta message response type. This is intended to be used for developers
    who want the internal monologue, tool calls, and tool returns in a simplified format that does not
    include additional information other than the content and timestamp.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (MessageType): The type of the message
        otid (Optional[str]): The offline threading id associated with this message
        sender_id (Optional[str]): The id of the sender of the message, can be an identity id or agent id
    """

    id: str
    date: datetime
    name: Optional[str] = None
    message_type: MessageType = Field(..., description="The type of the message.")
    otid: Optional[str] = None
    sender_id: Optional[str] = None

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        """
        Render ``date`` as an ISO-8601 string truncated to whole seconds.

        Microseconds are dropped since it seems like we're inconsistent with getting them.
        A datetime without a usable UTC offset is stamped as UTC before formatting.

        TODO: figure out why we don't always get microseconds (get_utc_time() does)
        """
        # "Naive" covers both a missing tzinfo and a tzinfo that reports no offset.
        is_naive = dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None
        aware = dt.replace(tzinfo=timezone.utc) if is_naive else dt
        return aware.isoformat(timespec="seconds")
|
|
|
|
|
|
class SystemMessage(LettaMessage):
    """
    A message generated by the system. Never streamed back on a response, only used for cursor pagination.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (Literal[MessageType.system_message]): The type of the message
        content (str): The message content sent by the system
    """

    # Pinned to a single enum member so the field can act as a union discriminator.
    message_type: Literal[MessageType.system_message] = Field(
        MessageType.system_message,
        description="The type of the message.",
    )
    content: str = Field(
        ...,
        description="The message content sent by the system",
    )
|
|
|
|
|
|
class UserMessage(LettaMessage):
    """
    A message sent by the user. Never streamed back on a response, only used for cursor pagination.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (Literal[MessageType.user_message]): The type of the message
        content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts)
    """

    # Pinned to a single enum member so the field can act as a union discriminator.
    message_type: Literal[MessageType.user_message] = Field(
        MessageType.user_message,
        description="The type of the message.",
    )
    # The JSON schema for this field is supplemented by the imported helper.
    content: Union[str, List[LettaUserMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
        json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
    )
|
|
|
|
|
|
class ReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (Literal[MessageType.reasoning_message]): The type of the message
        source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning
            content was generated natively by a reasoner model or derived via prompting
        reasoning (str): The internal reasoning of the agent
        signature (Optional[str]): The model-generated signature of the reasoning step
    """

    # Pinned to a single enum member so the field can act as a union discriminator.
    message_type: Literal[MessageType.reasoning_message] = Field(
        MessageType.reasoning_message,
        description="The type of the message.",
    )
    # Defaults to the prompted (non-native) variant when the caller does not say otherwise.
    source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model"
    reasoning: str
    signature: Optional[str] = None
|
|
|
|
|
|
class HiddenReasoningMessage(LettaMessage):
    """
    Representation of an agent's internal reasoning where reasoning content
    has been hidden from the response.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (Literal[MessageType.hidden_reasoning_message]): The type of the message
        state (Literal["redacted", "omitted"]): Whether the reasoning
            content was redacted by the provider or simply omitted by the API
        hidden_reasoning (Optional[str]): The internal reasoning of the agent
    """

    # Pinned to a single enum member so the field can act as a union discriminator.
    message_type: Literal[MessageType.hidden_reasoning_message] = Field(
        MessageType.hidden_reasoning_message,
        description="The type of the message.",
    )
    # Required: there is no default for how the reasoning was hidden.
    state: Literal["redacted", "omitted"]
    hidden_reasoning: Optional[str] = None
|
|
|
|
|
|
class ToolCall(BaseModel):
    # A fully specified tool call: every field is required.
    # `arguments` is typed as a plain string; this model does not parse it.
    name: str
    arguments: str
    tool_call_id: str
|
|
|
|
|
|
class ToolCallDelta(BaseModel):
    # Partial counterpart of ToolCall: every field is optional and defaults to None.
    name: Optional[str] = None
    arguments: Optional[str] = None
    tool_call_id: Optional[str] = None

    def model_dump(self, *args, **kwargs):
        """
        Dump the model with ``exclude_none`` always forced on.

        This is a workaround to exclude None values from the JSON dump since the
        OpenAI style of returning chunks doesn't include keys with null values.
        Any ``exclude_none`` value supplied by the caller is overridden.
        """
        kwargs.update(exclude_none=True)
        return super().model_dump(*args, **kwargs)

    def json(self, *args, **kwargs):
        """
        Return a JSON string of the non-None fields.

        Positional and keyword arguments are forwarded to ``json.dumps``.
        """
        payload = self.model_dump(exclude_none=True)
        return json.dumps(payload, *args, **kwargs)
|
|
|
|
|
|
class ToolCallMessage(LettaMessage):
    """
    A message representing a request to call a tool (generated by the LLM to trigger tool execution).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        tool_call (Union[ToolCall, ToolCallDelta]): The tool call
    """

    message_type: Literal[MessageType.tool_call_message] = Field(MessageType.tool_call_message, description="The type of the message.")
    tool_call: Union[ToolCall, ToolCallDelta]

    def model_dump(self, *args, **kwargs):
        """
        Handling for the ToolCallDelta exclude_none to work correctly.

        ``exclude_none`` is always forced on (overriding any caller value), and
        None-valued keys are additionally stripped from the dumped ``tool_call``.

        Returns:
            dict: The dumped message without None values.
        """
        kwargs["exclude_none"] = True
        data = super().model_dump(*args, **kwargs)
        # Use .get(): the caller may have dropped the field through
        # `include`/`exclude`, in which case there is nothing to clean up.
        tool_call = data.get("tool_call")
        if isinstance(tool_call, dict):
            data["tool_call"] = {k: v for k, v in tool_call.items() if v is not None}
        return data

    # NOTE(review): class-based `Config` with `json_encoders` looks like the legacy
    # Pydantic configuration style; confirm whether it should move to `model_config`.
    class Config:
        json_encoders = {
            ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
            ToolCall: lambda v: v.model_dump(exclude_none=True),
        }

    @field_validator("tool_call", mode="before")
    @classmethod
    def validate_tool_call(cls, v):
        """
        Casts dicts into ToolCall or ToolCallDelta objects. Without this extra validator, Pydantic
        will throw an error if 'name' or 'arguments' are None instead of properly casting to
        ToolCallDelta instead of ToolCall.

        A dict becomes a ToolCall only when all of 'name', 'arguments' and 'tool_call_id' are
        present with non-None values; any other dict carrying at least one of those keys
        becomes a ToolCallDelta. Non-dict values are returned unchanged.

        Raises:
            ValueError: If the dict contains none of the three recognised keys.
        """
        if not isinstance(v, dict):
            return v

        keys = ("name", "arguments", "tool_call_id")
        if not any(key in v for key in keys):
            raise ValueError("tool_call must contain either 'name' or 'arguments'")

        values = {key: v.get(key) for key in keys}
        # A key that is present but None cannot satisfy ToolCall's required `str`
        # fields, so it is treated the same as a missing key.
        if all(value is not None for value in values.values()):
            return ToolCall(**values)
        return ToolCallDelta(**values)
|
|
|
|
|
|
class ToolReturnMessage(LettaMessage):
    """
    A message representing the return value of a tool call (generated by Letta executing the requested tool).

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (Literal[MessageType.tool_return_message]): The type of the message
        tool_return (str): The return value of the tool
        status (Literal["success", "error"]): The status of the tool call
        tool_call_id (str): A unique identifier for the tool call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the tool invocation
        stderr (Optional[List[str]]): Captured stderr from the tool invocation
    """

    # Pinned to a single enum member so the field can act as a union discriminator.
    message_type: Literal[MessageType.tool_return_message] = Field(
        MessageType.tool_return_message,
        description="The type of the message.",
    )

    # Result of the invocation and the call it belongs to.
    tool_return: str
    status: Literal["success", "error"]
    tool_call_id: str

    # Captured output streams; None when not provided.
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
|
|
|
|
|
|
class AssistantMessage(LettaMessage):
    """
    A message sent by the LLM in response to user input. Used in the LLM context.

    Args:
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        name (Optional[str]): The name of the sender of the message
        message_type (Literal[MessageType.assistant_message]): The type of the message
        content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
    """

    # Pinned to a single enum member so the field can act as a union discriminator.
    message_type: Literal[MessageType.assistant_message] = Field(
        MessageType.assistant_message,
        description="The type of the message.",
    )
    # The JSON schema for this field is supplemented by the imported helper.
    content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the agent (can be a string or an array of content parts)",
        json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
    )
|
|
|
|
|
|
# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
# The concrete model is selected by the value of each payload's `message_type` field.
LettaMessageUnion = Annotated[
    Union[
        SystemMessage,
        UserMessage,
        ReasoningMessage,
        HiddenReasoningMessage,
        ToolCallMessage,
        ToolReturnMessage,
        AssistantMessage,
    ],
    Field(discriminator="message_type"),
]
|
|
|
|
|
|
def create_letta_message_union_schema():
    """
    Build the schema dict describing the Letta message union.

    Returns:
        dict: A mapping with a ``oneOf`` list of ``$ref`` entries and a
        ``discriminator`` keyed on ``message_type``. Both are derived from the
        same table, so they list the same seven schemas in the same order.
        A new dict is built on every call.
    """
    # (discriminator value, component schema name), in output order.
    schema_names = (
        ("system_message", "SystemMessage"),
        ("user_message", "UserMessage"),
        ("reasoning_message", "ReasoningMessage"),
        ("hidden_reasoning_message", "HiddenReasoningMessage"),
        ("tool_call_message", "ToolCallMessage"),
        ("tool_return_message", "ToolReturnMessage"),
        ("assistant_message", "AssistantMessage"),
    )
    mapping = {message_type: f"#/components/schemas/{schema}" for message_type, schema in schema_names}
    return {
        "oneOf": [{"$ref": ref} for ref in mapping.values()],
        "discriminator": {
            "propertyName": "message_type",
            "mapping": mapping,
        },
    }
|
|
|
|
|
|
# --------------------------
|
|
# Message Update API Schemas
|
|
# --------------------------
|
|
|
|
|
|
class UpdateSystemMessage(BaseModel):
    # Update payload for a system message; `message_type` is fixed so the model
    # can be selected by discriminator.
    message_type: Literal["system_message"] = "system_message"
    # The field is a plain `str`, so the description must not advertise an
    # array of multi-modal content parts (that wording applies to user messages).
    content: str = Field(
        ...,
        description="The message content sent by the system",
    )
|
|
|
|
|
|
class UpdateUserMessage(BaseModel):
    # Update payload for a user message; `message_type` is fixed so the model
    # can be selected by discriminator.
    message_type: Literal["user_message"] = "user_message"
    # Accepts either a plain string or a list of user content parts.
    content: Union[str, List[LettaUserMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
        json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
    )
|
|
|
|
|
|
class UpdateReasoningMessage(BaseModel):
    # Update payload for a reasoning message: the replacement reasoning text
    # (required) plus a fixed `message_type` used as the discriminator.
    reasoning: str
    message_type: Literal["reasoning_message"] = "reasoning_message"
|
|
|
|
|
|
class UpdateAssistantMessage(BaseModel):
    # Update payload for an assistant message; `message_type` is fixed so the
    # model can be selected by discriminator.
    message_type: Literal["assistant_message"] = "assistant_message"
    # Accepts either a plain string or a list of assistant content parts.
    content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
        ...,
        description="The message content sent by the assistant (can be a string or an array of content parts)",
        json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
    )
|
|
|
|
|
|
# Discriminated union of the update payloads; the concrete model is selected by
# the value of the `message_type` field.
LettaMessageUpdateUnion = Annotated[
    Union[
        UpdateSystemMessage,
        UpdateUserMessage,
        UpdateReasoningMessage,
        UpdateAssistantMessage,
    ],
    Field(discriminator="message_type"),
]
|
|
|
|
|
|
# --------------------------
|
|
# Deprecated Message Schemas
|
|
# --------------------------
|
|
|
|
|
|
class LegacyFunctionCallMessage(LettaMessage):
    # Deprecated function-call message.
    # NOTE(review): unlike the other legacy models, this one does not override
    # `message_type`, so the required field inherited from LettaMessage applies as-is.
    function_call: str
|
|
|
|
|
|
class LegacyFunctionReturn(LettaMessage):
    """
    A message representing the return value of a function call (generated by Letta executing the requested function).

    Args:
        function_return (str): The return value of the function
        status (Literal["success", "error"]): The status of the function call
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
        function_call_id (str): A unique identifier for the function call that generated this message
        stdout (Optional[List[str]]): Captured stdout (e.g. prints, logs) from the function invocation
        stderr (Optional[List[str]]): Captured stderr from the function invocation
    """

    # Overrides the inherited field with a fixed legacy string value.
    message_type: Literal["function_return"] = "function_return"

    # Result of the invocation and the call it belongs to.
    function_return: str
    status: Literal["success", "error"]
    function_call_id: str

    # Captured output streams; None when not provided.
    stdout: Optional[List[str]] = None
    stderr: Optional[List[str]] = None
|
|
|
|
|
|
class LegacyInternalMonologue(LettaMessage):
    """
    Representation of an agent's internal monologue.

    Args:
        internal_monologue (str): The internal monologue of the agent
        id (str): The ID of the message
        date (datetime): The date the message was created in ISO format
    """

    # Overrides the inherited field with a fixed legacy string value.
    message_type: Literal["internal_monologue"] = "internal_monologue"
    internal_monologue: str
|
|
|
|
|
|
# Plain (non-discriminated) union of the deprecated message models plus AssistantMessage.
LegacyLettaMessage = Union[
    LegacyInternalMonologue,
    AssistantMessage,
    LegacyFunctionCallMessage,
    LegacyFunctionReturn,
]
|