feat: extend message model to support more content types (#756)

This commit is contained in:
cthomas 2025-01-23 17:24:52 -08:00 committed by GitHub
parent 2b8bf8262c
commit dfce6a1552
9 changed files with 99 additions and 28 deletions

View File

@@ -6,7 +6,7 @@ import requests
 from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
 from letta.llm_api.llm_api_tools import create
-from letta.schemas.message import Message
+from letta.schemas.message import Message, TextContent
 from letta.utils import json_dumps, json_loads
@@ -23,8 +23,13 @@ def message_chatgpt(self, message: str):
     dummy_user_id = uuid.uuid4()
     dummy_agent_id = uuid.uuid4()
     message_sequence = [
-        Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
-        Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
+        Message(
+            user_id=dummy_user_id,
+            agent_id=dummy_agent_id,
+            role="system",
+            content=[TextContent(text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE)],
+        ),
+        Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", content=[TextContent(text=str(message))]),
     ]
     # TODO: this will error without an LLMConfig
     response = create(

View File

@@ -6,7 +6,7 @@ from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
 from letta.schemas.agent import AgentState
 from letta.schemas.enums import MessageRole
 from letta.schemas.memory import Memory
-from letta.schemas.message import Message
+from letta.schemas.message import Message, TextContent
 from letta.settings import summarizer_settings
 from letta.utils import count_tokens, printd
@@ -60,9 +60,9 @@ def summarize_messages(
     dummy_agent_id = agent_state.id
     message_sequence = [
-        Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt),
-        Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK),
-        Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input),
+        Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
+        Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
+        Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
     ]
     # TODO: We need to eventually have a separate LLM config for the summarizer LLM

View File

@@ -8,6 +8,7 @@ from letta.orm.custom_columns import ToolCallColumn
 from letta.orm.mixins import AgentMixin, OrganizationMixin
 from letta.orm.sqlalchemy_base import SqlalchemyBase
 from letta.schemas.message import Message as PydanticMessage
+from letta.schemas.message import TextContent as PydanticTextContent


 class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
@@ -45,3 +46,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
     def job(self) -> Optional["Job"]:
         """Get the job associated with this message, if any."""
         return self.job_message.job if self.job_message else None
+
+    def to_pydantic(self) -> PydanticMessage:
+        """custom pydantic conversion for message content mapping"""
+        model = self.__pydantic_model__.model_validate(self)
+        if self.text:
+            model.content = [PydanticTextContent(text=self.text)]
+        return model

View File

@@ -9,6 +9,10 @@ class MessageRole(str, Enum):
     system = "system"


+class MessageContentType(str, Enum):
+    text = "text"
+
+
 class OptionState(str, Enum):
     """Useful for kwargs that are bool + default option"""

View File

@@ -2,15 +2,15 @@ import copy
 import json
 import warnings
 from datetime import datetime, timezone
-from typing import List, Literal, Optional
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union

 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
 from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator

 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
-from letta.schemas.enums import MessageRole
+from letta.schemas.enums import MessageContentType, MessageRole
 from letta.schemas.letta_base import OrmMetadataBase
 from letta.schemas.letta_message import (
     AssistantMessage,
@@ -80,6 +80,21 @@ class MessageUpdate(BaseModel):
     tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")


+class MessageContent(BaseModel):
+    type: MessageContentType = Field(..., description="The type of the message.")
+
+
+class TextContent(MessageContent):
+    type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
+    text: str = Field(..., description="The text content of the message.")
+
+
+MessageContentUnion = Annotated[
+    Union[TextContent],
+    Field(discriminator="type"),
+]
+
+
 class Message(BaseMessage):
     """
     Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.
@@ -100,7 +115,7 @@ class Message(BaseMessage):
     id: str = BaseMessage.generate_id_field()
     role: MessageRole = Field(..., description="The role of the participant.")
-    text: Optional[str] = Field(None, description="The text of the message.")
+    content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.")
     organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
     agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
     model: Optional[str] = Field(None, description="The model used to make the function call.")
@@ -108,6 +123,7 @@ class Message(BaseMessage):
     tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
     tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
+    step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")

     # This overrides the optional base orm schema, created_at MUST exist on all messages objects
     created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
@@ -118,6 +134,37 @@ class Message(BaseMessage):
         assert v in roles, f"Role must be one of {roles}"
         return v

+    @model_validator(mode="before")
+    @classmethod
+    def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]:
+        if isinstance(data, dict):
+            if "text" in data and "content" not in data:
+                data["content"] = [TextContent(text=data["text"])]
+                del data["text"]
+        return data
+
+    def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
+        data = super().model_dump(**kwargs)
+        if to_orm:
+            for content in data["content"]:
+                if content["type"] == "text":
+                    data["text"] = content["text"]
+            del data["content"]
+        return data
+
+    @property
+    def text(self) -> Optional[str]:
+        """
+        Retrieve the first text content's text.
+
+        Returns:
+            str: The text content, or None if no text content exists
+        """
+        if not self.content:
+            return None
+        text_content = [content.text for content in self.content if content.type == MessageContentType.text]
+        return text_content[0] if text_content else None
+
     def to_json(self):
         json_message = vars(self)
         if json_message["tool_calls"] is not None:
@@ -283,7 +330,7 @@ class Message(BaseMessage):
                 model=model,
                 # standard fields expected in an OpenAI ChatCompletion message object
                 role=MessageRole.tool,  # NOTE
-                text=openai_message_dict["content"],
+                content=[TextContent(text=openai_message_dict["content"])],
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
                 tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
@@ -296,7 +343,7 @@ class Message(BaseMessage):
                 model=model,
                 # standard fields expected in an OpenAI ChatCompletion message object
                 role=MessageRole.tool,  # NOTE
-                text=openai_message_dict["content"],
+                content=[TextContent(text=openai_message_dict["content"])],
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
                 tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
@@ -328,7 +375,7 @@ class Message(BaseMessage):
                 model=model,
                 # standard fields expected in an OpenAI ChatCompletion message object
                 role=MessageRole(openai_message_dict["role"]),
-                text=openai_message_dict["content"],
+                content=[TextContent(text=openai_message_dict["content"])],
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=tool_calls,
                 tool_call_id=None,  # NOTE: None, since this field is only non-null for role=='tool'
@@ -341,7 +388,7 @@ class Message(BaseMessage):
                 model=model,
                 # standard fields expected in an OpenAI ChatCompletion message object
                 role=MessageRole(openai_message_dict["role"]),
-                text=openai_message_dict["content"],
+                content=[TextContent(text=openai_message_dict["content"])],
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=tool_calls,
                 tool_call_id=None,  # NOTE: None, since this field is only non-null for role=='tool'
@@ -373,7 +420,7 @@ class Message(BaseMessage):
                 model=model,
                 # standard fields expected in an OpenAI ChatCompletion message object
                 role=MessageRole(openai_message_dict["role"]),
-                text=openai_message_dict["content"],
+                content=[TextContent(text=openai_message_dict["content"])],
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=tool_calls,
                 tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
@@ -386,7 +433,7 @@ class Message(BaseMessage):
                 model=model,
                 # standard fields expected in an OpenAI ChatCompletion message object
                 role=MessageRole(openai_message_dict["role"]),
-                text=openai_message_dict["content"],
+                content=[TextContent(text=openai_message_dict["content"] or "")],
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=tool_calls,
                 tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,

View File

@@ -39,7 +39,7 @@ from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolRe
 from letta.schemas.letta_response import LettaResponse
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
-from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
+from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent
 from letta.schemas.organization import Organization
 from letta.schemas.passage import Passage
 from letta.schemas.providers import (
@@ -617,14 +617,14 @@ class SyncServer(Server):
             message = Message(
                 agent_id=agent_id,
                 role="user",
-                text=packaged_user_message,
+                content=[TextContent(text=packaged_user_message)],
                 created_at=timestamp,
             )
         else:
             message = Message(
                 agent_id=agent_id,
                 role="user",
-                text=packaged_user_message,
+                content=[TextContent(text=packaged_user_message)],
             )

         # Run the agent state forward
@@ -667,14 +667,14 @@ class SyncServer(Server):
             message = Message(
                 agent_id=agent_id,
                 role="system",
-                text=packaged_system_message,
+                content=[TextContent(text=packaged_system_message)],
                 created_at=timestamp,
             )
         else:
             message = Message(
                 agent_id=agent_id,
                 role="system",
-                text=packaged_system_message,
+                content=[TextContent(text=packaged_system_message)],
             )

         if isinstance(message, Message):
@@ -732,7 +732,7 @@ class SyncServer(Server):
                 Message(
                     agent_id=agent_id,
                     role=message.role,
-                    text=message.text,
+                    content=[TextContent(text=message.text)],
                     name=message.name,
                     # assigned later?
                     model=None,

View File

@@ -11,7 +11,7 @@ from letta.prompts import gpt_system
 from letta.schemas.agent import AgentState, AgentType
 from letta.schemas.enums import MessageRole
 from letta.schemas.memory import Memory
-from letta.schemas.message import Message, MessageCreate
+from letta.schemas.message import Message, MessageCreate, TextContent
 from letta.schemas.tool_rule import ToolRule
 from letta.schemas.user import User
 from letta.system import get_initial_boot_messages, get_login_event
@@ -244,7 +244,14 @@ def package_initial_message_sequence(
             raise ValueError(f"Invalid message role: {message_create.role}")

         init_messages.append(
-            Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model)
+            Message(
+                role=message_create.role,
+                content=[TextContent(text=packed_message)],
+                name=message_create.name,
+                organization_id=actor.organization_id,
+                agent_id=agent_id,
+                model=model,
+            )
         )

     return init_messages

View File

@@ -49,7 +49,7 @@ class MessageManager:
         with self.session_maker() as session:
             # Set the organization id of the Pydantic message
             pydantic_msg.organization_id = actor.organization_id
-            msg_data = pydantic_msg.model_dump()
+            msg_data = pydantic_msg.model_dump(to_orm=True)
             msg = MessageModel(**msg_data)
             msg.create(session, actor=actor)  # Persist to database
             return msg.to_pydantic()

View File

@@ -14,7 +14,7 @@ from letta.llm_api.helpers import calculate_summarizer_cutoff
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole
 from letta.schemas.llm_config import LLMConfig
-from letta.schemas.message import Message
+from letta.schemas.message import Message, TextContent
 from letta.settings import summarizer_settings
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
@@ -55,7 +55,7 @@ def generate_message(role: str, text: str = None, tool_calls: List = None) -> Me
     return Message(
         id="message-" + str(uuid.uuid4()),
         role=MessageRole(role),
-        text=text or f"{role} message text",
+        content=[TextContent(text=text or f"{role} message text")],
         created_at=datetime.utcnow(),
         tool_calls=tool_calls or [],
     )