mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: extend message model to support more content types (#756)
This commit is contained in:
parent
2b8bf8262c
commit
dfce6a1552
@ -6,7 +6,7 @@ import requests
|
||||
|
||||
from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
|
||||
@ -23,8 +23,13 @@ def message_chatgpt(self, message: str):
|
||||
dummy_user_id = uuid.uuid4()
|
||||
dummy_agent_id = uuid.uuid4()
|
||||
message_sequence = [
|
||||
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
|
||||
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
|
||||
Message(
|
||||
user_id=dummy_user_id,
|
||||
agent_id=dummy_agent_id,
|
||||
role="system",
|
||||
content=[TextContent(text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE)],
|
||||
),
|
||||
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", content=[TextContent(text=str(message))]),
|
||||
]
|
||||
# TODO: this will error without an LLMConfig
|
||||
response = create(
|
||||
|
@ -6,7 +6,7 @@ from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
@ -60,9 +60,9 @@ def summarize_messages(
|
||||
|
||||
dummy_agent_id = agent_state.id
|
||||
message_sequence = [
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
|
||||
]
|
||||
|
||||
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
||||
|
@ -8,6 +8,7 @@ from letta.orm.custom_columns import ToolCallColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import TextContent as PydanticTextContent
|
||||
|
||||
|
||||
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
@ -45,3 +46,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
def job(self) -> Optional["Job"]:
|
||||
"""Get the job associated with this message, if any."""
|
||||
return self.job_message.job if self.job_message else None
|
||||
|
||||
def to_pydantic(self) -> PydanticMessage:
|
||||
"""custom pydantic conversion for message content mapping"""
|
||||
model = self.__pydantic_model__.model_validate(self)
|
||||
if self.text:
|
||||
model.content = [PydanticTextContent(text=self.text)]
|
||||
return model
|
||||
|
@ -9,6 +9,10 @@ class MessageRole(str, Enum):
|
||||
system = "system"
|
||||
|
||||
|
||||
class MessageContentType(str, Enum):
|
||||
text = "text"
|
||||
|
||||
|
||||
class OptionState(str, Enum):
|
||||
"""Useful for kwargs that are bool + default option"""
|
||||
|
||||
|
@ -2,15 +2,15 @@ import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Literal, Optional
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageContentType, MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
@ -80,6 +80,21 @@ class MessageUpdate(BaseModel):
|
||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
MessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class Message(BaseMessage):
|
||||
"""
|
||||
Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.
|
||||
@ -100,7 +115,7 @@ class Message(BaseMessage):
|
||||
|
||||
id: str = BaseMessage.generate_id_field()
|
||||
role: MessageRole = Field(..., description="The role of the participant.")
|
||||
text: Optional[str] = Field(None, description="The text of the message.")
|
||||
content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
||||
@ -108,6 +123,7 @@ class Message(BaseMessage):
|
||||
tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
|
||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
|
||||
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@ -118,6 +134,37 @@ class Message(BaseMessage):
|
||||
assert v in roles, f"Role must be one of {roles}"
|
||||
return v
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if isinstance(data, dict):
|
||||
if "text" in data and "content" not in data:
|
||||
data["content"] = [TextContent(text=data["text"])]
|
||||
del data["text"]
|
||||
return data
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm:
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
data["text"] = content["text"]
|
||||
del data["content"]
|
||||
return data
|
||||
|
||||
@property
|
||||
def text(self) -> Optional[str]:
|
||||
"""
|
||||
Retrieve the first text content's text.
|
||||
|
||||
Returns:
|
||||
str: The text content, or None if no text content exists
|
||||
"""
|
||||
if not self.content:
|
||||
return None
|
||||
text_content = [content.text for content in self.content if content.type == MessageContentType.text]
|
||||
return text_content[0] if text_content else None
|
||||
|
||||
def to_json(self):
|
||||
json_message = vars(self)
|
||||
if json_message["tool_calls"] is not None:
|
||||
@ -283,7 +330,7 @@ class Message(BaseMessage):
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole.tool, # NOTE
|
||||
text=openai_message_dict["content"],
|
||||
content=[TextContent(text=openai_message_dict["content"])],
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
@ -296,7 +343,7 @@ class Message(BaseMessage):
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole.tool, # NOTE
|
||||
text=openai_message_dict["content"],
|
||||
content=[TextContent(text=openai_message_dict["content"])],
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
@ -328,7 +375,7 @@ class Message(BaseMessage):
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
text=openai_message_dict["content"],
|
||||
content=[TextContent(text=openai_message_dict["content"])],
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
|
||||
@ -341,7 +388,7 @@ class Message(BaseMessage):
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
text=openai_message_dict["content"],
|
||||
content=[TextContent(text=openai_message_dict["content"])],
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
|
||||
@ -373,7 +420,7 @@ class Message(BaseMessage):
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
text=openai_message_dict["content"],
|
||||
content=[TextContent(text=openai_message_dict["content"])],
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
@ -386,7 +433,7 @@ class Message(BaseMessage):
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
text=openai_message_dict["content"],
|
||||
content=[TextContent(text=openai_message_dict["content"] or "")],
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
|
@ -39,7 +39,7 @@ from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolRe
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.providers import (
|
||||
@ -617,14 +617,14 @@ class SyncServer(Server):
|
||||
message = Message(
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
text=packaged_user_message,
|
||||
content=[TextContent(text=packaged_user_message)],
|
||||
created_at=timestamp,
|
||||
)
|
||||
else:
|
||||
message = Message(
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
text=packaged_user_message,
|
||||
content=[TextContent(text=packaged_user_message)],
|
||||
)
|
||||
|
||||
# Run the agent state forward
|
||||
@ -667,14 +667,14 @@ class SyncServer(Server):
|
||||
message = Message(
|
||||
agent_id=agent_id,
|
||||
role="system",
|
||||
text=packaged_system_message,
|
||||
content=[TextContent(text=packaged_system_message)],
|
||||
created_at=timestamp,
|
||||
)
|
||||
else:
|
||||
message = Message(
|
||||
agent_id=agent_id,
|
||||
role="system",
|
||||
text=packaged_system_message,
|
||||
content=[TextContent(text=packaged_system_message)],
|
||||
)
|
||||
|
||||
if isinstance(message, Message):
|
||||
@ -732,7 +732,7 @@ class SyncServer(Server):
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
text=message.text,
|
||||
content=[TextContent(text=message.text)],
|
||||
name=message.name,
|
||||
# assigned later?
|
||||
model=None,
|
||||
|
@ -11,7 +11,7 @@ from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.system import get_initial_boot_messages, get_login_event
|
||||
@ -244,7 +244,14 @@ def package_initial_message_sequence(
|
||||
raise ValueError(f"Invalid message role: {message_create.role}")
|
||||
|
||||
init_messages.append(
|
||||
Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model)
|
||||
Message(
|
||||
role=message_create.role,
|
||||
content=[TextContent(text=packed_message)],
|
||||
name=message_create.name,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
)
|
||||
)
|
||||
return init_messages
|
||||
|
||||
|
@ -49,7 +49,7 @@ class MessageManager:
|
||||
with self.session_maker() as session:
|
||||
# Set the organization id of the Pydantic message
|
||||
pydantic_msg.organization_id = actor.organization_id
|
||||
msg_data = pydantic_msg.model_dump()
|
||||
msg_data = pydantic_msg.model_dump(to_orm=True)
|
||||
msg = MessageModel(**msg_data)
|
||||
msg.create(session, actor=actor) # Persist to database
|
||||
return msg.to_pydantic()
|
||||
|
@ -14,7 +14,7 @@ from letta.llm_api.helpers import calculate_summarizer_cutoff
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
|
||||
@ -55,7 +55,7 @@ def generate_message(role: str, text: str = None, tool_calls: List = None) -> Me
|
||||
return Message(
|
||||
id="message-" + str(uuid.uuid4()),
|
||||
role=MessageRole(role),
|
||||
text=text or f"{role} message text",
|
||||
content=[TextContent(text=text or f"{role} message text")],
|
||||
created_at=datetime.utcnow(),
|
||||
tool_calls=tool_calls or [],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user