feat: extend message model to support more content types (#756)

cthomas 2025-01-23 17:24:52 -08:00 committed by GitHub
parent 2b8bf8262c
commit dfce6a1552
9 changed files with 99 additions and 28 deletions
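
The heart of the change, before the per-file diffs: Message drops its flat text field in favor of a typed content list, with TextContent as the first (and so far only) content type. A minimal sketch of the call-style change, using simplified stand-ins rather than the real letta.schemas.message classes:

from typing import List, Literal, Optional
from pydantic import BaseModel

class TextContent(BaseModel):
    type: Literal["text"] = "text"
    text: str

class Message(BaseModel):
    role: str
    content: Optional[List[TextContent]] = None

# Before this commit: Message(role="user", text="hello")
# After this commit:
msg = Message(role="user", content=[TextContent(text="hello")])
assert msg.content[0].text == "hello"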

View File

@@ -6,7 +6,7 @@ import requests
from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
from letta.llm_api.llm_api_tools import create
from letta.schemas.message import Message
from letta.schemas.message import Message, TextContent
from letta.utils import json_dumps, json_loads
@@ -23,8 +23,13 @@ def message_chatgpt(self, message: str):
dummy_user_id = uuid.uuid4()
dummy_agent_id = uuid.uuid4()
message_sequence = [
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
Message(
user_id=dummy_user_id,
agent_id=dummy_agent_id,
role="system",
content=[TextContent(text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE)],
),
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", content=[TextContent(text=str(message))]),
]
# TODO: this will error without an LLMConfig
response = create(

View File

@@ -6,7 +6,7 @@ from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.memory import Memory
from letta.schemas.message import Message
from letta.schemas.message import Message, TextContent
from letta.settings import summarizer_settings
from letta.utils import count_tokens, printd
@@ -60,9 +60,9 @@ def summarize_messages(
dummy_agent_id = agent_state.id
message_sequence = [
Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt),
Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK),
Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input),
Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
]
# TODO: We need to eventually have a separate LLM config for the summarizer LLM

View File

@@ -8,6 +8,7 @@ from letta.orm.custom_columns import ToolCallColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import TextContent as PydanticTextContent
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
@@ -45,3 +46,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
def job(self) -> Optional["Job"]:
"""Get the job associated with this message, if any."""
return self.job_message.job if self.job_message else None
def to_pydantic(self) -> PydanticMessage:
"""custom pydantic conversion for message content mapping"""
model = self.__pydantic_model__.model_validate(self)
if self.text:
model.content = [PydanticTextContent(text=self.text)]
return model
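
The override above is the read path: rows keep a single flat text column, and to_pydantic wraps it into a one-element content list on the way out. A rough, self-contained sketch of that mapping, with stand-in types rather than the real SQLAlchemy model:

from typing import List, Literal, Optional
from pydantic import BaseModel

class TextContent(BaseModel):
    type: Literal["text"] = "text"
    text: str

class FakeRow:
    """Stand-in for the ORM Message row; only the text column matters here."""
    def __init__(self, text: Optional[str] = None):
        self.text = text

def wrap_text_column(row: FakeRow) -> Optional[List[TextContent]]:
    # Mirrors to_pydantic: a non-empty text column becomes a content list.
    return [TextContent(text=row.text)] if row.text else None

assert wrap_text_column(FakeRow("hi"))[0].text == "hi"
assert wrap_text_column(FakeRow()) is None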

View File

@@ -9,6 +9,10 @@ class MessageRole(str, Enum):
system = "system"
class MessageContentType(str, Enum):
text = "text"
class OptionState(str, Enum):
"""Useful for kwargs that are bool + default option"""

View File

@@ -2,15 +2,15 @@ import copy
import json
import warnings
from datetime import datetime, timezone
from typing import List, Literal, Optional
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageContentType, MessageRole
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
AssistantMessage,
@@ -80,6 +80,21 @@ class MessageUpdate(BaseModel):
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
class MessageContent(BaseModel):
type: MessageContentType = Field(..., description="The type of the message content.")
class TextContent(MessageContent):
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message content.")
text: str = Field(..., description="The text content of the message.")
MessageContentUnion = Annotated[
Union[TextContent],
Field(discriminator="type"),
]
class Message(BaseMessage):
"""
Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.
@@ -100,7 +115,7 @@ class Message(BaseMessage):
id: str = BaseMessage.generate_id_field()
role: MessageRole = Field(..., description="The role of the participant.")
text: Optional[str] = Field(None, description="The text of the message.")
content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.")
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")
@@ -108,6 +123,7 @@ class Message(BaseMessage):
tool_calls: Optional[List[OpenAIToolCall]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
# This overrides the optional base ORM schema; created_at MUST exist on all message objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -118,6 +134,37 @@
assert v in roles, f"Role must be one of {roles}"
return v
@model_validator(mode="before")
@classmethod
def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(data, dict):
if "text" in data and "content" not in data:
data["content"] = [TextContent(text=data["text"])]
del data["text"]
return data
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
if to_orm:
for content in data["content"] or []:  # guard: content may be None
if content["type"] == "text":
data["text"] = content["text"]
del data["content"]
return data
@property
def text(self) -> Optional[str]:
"""
Retrieve the first text content's text.
Returns:
str: The text content, or None if no text content exists
"""
if not self.content:
return None
text_content = [content.text for content in self.content if content.type == MessageContentType.text]
return text_content[0] if text_content else None
def to_json(self):
json_message = vars(self)
if json_message["tool_calls"] is not None:
@@ -283,7 +330,7 @@
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole.tool, # NOTE
text=openai_message_dict["content"],
content=[TextContent(text=openai_message_dict["content"])],
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
@@ -296,7 +343,7 @@
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole.tool, # NOTE
text=openai_message_dict["content"],
content=[TextContent(text=openai_message_dict["content"])],
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
@@ -328,7 +375,7 @@
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
text=openai_message_dict["content"],
content=[TextContent(text=openai_message_dict["content"])],
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=tool_calls,
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
@@ -341,7 +388,7 @@
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
text=openai_message_dict["content"],
content=[TextContent(text=openai_message_dict["content"])],
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=tool_calls,
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
@@ -373,7 +420,7 @@
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
text=openai_message_dict["content"],
content=[TextContent(text=openai_message_dict["content"])],
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=tool_calls,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
@@ -386,7 +433,7 @@
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
text=openai_message_dict["content"],
content=[TextContent(text=openai_message_dict["content"] or "")],
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=tool_calls,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
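
Taken together, the schemas/message.py changes form one pattern: a discriminated union of content types, a mode="before" validator so legacy text= callers keep working, and a text property that recovers the first text block. A self-contained sketch of that pattern follows; ImageContent is hypothetical (this commit ships only TextContent) and the real Message carries many more fields:

from typing import Annotated, Any, List, Literal, Optional, Union
from pydantic import BaseModel, Field, model_validator

class TextContent(BaseModel):
    type: Literal["text"] = "text"
    text: str

class ImageContent(BaseModel):
    # Hypothetical second member, showing how the union is meant to grow.
    type: Literal["image"] = "image"
    url: str

MessageContentUnion = Annotated[
    Union[TextContent, ImageContent],
    Field(discriminator="type"),
]

class Message(BaseModel):
    role: str
    content: Optional[List[MessageContentUnion]] = None

    @model_validator(mode="before")
    @classmethod
    def convert_text(cls, data: Any) -> Any:
        # Legacy callers pass text=...; rewrite it into a content list.
        if isinstance(data, dict) and "text" in data and "content" not in data:
            data["content"] = [TextContent(text=data.pop("text"))]
        return data

    @property
    def text(self) -> Optional[str]:
        # First text block, if any, mirroring the property added above.
        if not self.content:
            return None
        texts = [c.text for c in self.content if c.type == "text"]
        return texts[0] if texts else None

legacy = Message(role="user", text="hi")  # old call style still validates
assert legacy.text == "hi" and legacy.content[0].type == "text"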

View File

@@ -39,7 +39,7 @@ from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolRe
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.providers import (
@@ -617,14 +617,14 @@ class SyncServer(Server):
message = Message(
agent_id=agent_id,
role="user",
text=packaged_user_message,
content=[TextContent(text=packaged_user_message)],
created_at=timestamp,
)
else:
message = Message(
agent_id=agent_id,
role="user",
text=packaged_user_message,
content=[TextContent(text=packaged_user_message)],
)
# Run the agent state forward
@@ -667,14 +667,14 @@ class SyncServer(Server):
message = Message(
agent_id=agent_id,
role="system",
text=packaged_system_message,
content=[TextContent(text=packaged_system_message)],
created_at=timestamp,
)
else:
message = Message(
agent_id=agent_id,
role="system",
text=packaged_system_message,
content=[TextContent(text=packaged_system_message)],
)
if isinstance(message, Message):
@@ -732,7 +732,7 @@ class SyncServer(Server):
Message(
agent_id=agent_id,
role=message.role,
text=message.text,
content=[TextContent(text=message.text)],
name=message.name,
# assigned later?
model=None,

View File

@@ -11,7 +11,7 @@ from letta.prompts import gpt_system
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import Message, MessageCreate, TextContent
from letta.schemas.tool_rule import ToolRule
from letta.schemas.user import User
from letta.system import get_initial_boot_messages, get_login_event
@@ -244,7 +244,14 @@ def package_initial_message_sequence(
raise ValueError(f"Invalid message role: {message_create.role}")
init_messages.append(
Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model)
Message(
role=message_create.role,
content=[TextContent(text=packed_message)],
name=message_create.name,
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
)
)
return init_messages

View File

@@ -49,7 +49,7 @@ class MessageManager:
with self.session_maker() as session:
# Set the organization id of the Pydantic message
pydantic_msg.organization_id = actor.organization_id
msg_data = pydantic_msg.model_dump()
msg_data = pydantic_msg.model_dump(to_orm=True)
msg = MessageModel(**msg_data)
msg.create(session, actor=actor) # Persist to database
return msg.to_pydantic()
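
model_dump(to_orm=True) above is the write-path mirror of the ORM read path: it flattens the content list back into the flat text column before the row is constructed. A rough sketch of the flattening over plain dicts (a stand-in for the real model_dump output, not letta's actual method):

def flatten_for_orm(msg_data: dict) -> dict:
    # Mimics Message.model_dump(to_orm=True): copy each text block's text
    # onto the flat `text` key (last one wins), then drop `content`.
    for content in msg_data.get("content") or []:
        if content["type"] == "text":
            msg_data["text"] = content["text"]
    msg_data.pop("content", None)
    return msg_data

row = flatten_for_orm({"role": "user", "content": [{"type": "text", "text": "hi"}]})
assert row == {"role": "user", "text": "hi"}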

View File

@@ -14,7 +14,7 @@ from letta.llm_api.helpers import calculate_summarizer_cutoff
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.message import Message, TextContent
from letta.settings import summarizer_settings
from letta.streaming_interface import StreamingRefreshCLIInterface
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
@@ -55,7 +55,7 @@ def generate_message(role: str, text: str = None, tool_calls: List = None) -> Message:
return Message(
id="message-" + str(uuid.uuid4()),
role=MessageRole(role),
text=text or f"{role} message text",
content=[TextContent(text=text or f"{role} message text")],
created_at=datetime.utcnow(),
tool_calls=tool_calls or [],
)