From dfce6a155205aa7bf4fa3766299e95bd81f7457f Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 23 Jan 2025 17:24:52 -0800 Subject: [PATCH] feat: extend message model to support more content types (#756) --- letta/functions/function_sets/extras.py | 11 ++- letta/memory.py | 8 +-- letta/orm/message.py | 8 +++ letta/schemas/enums.py | 4 ++ letta/schemas/message.py | 67 ++++++++++++++++--- letta/server/server.py | 12 ++-- .../services/helpers/agent_manager_helper.py | 11 ++- letta/services/message_manager.py | 2 +- tests/integration_test_summarizer.py | 4 +- 9 files changed, 99 insertions(+), 28 deletions(-) diff --git a/letta/functions/function_sets/extras.py b/letta/functions/function_sets/extras.py index d5d21644f..65652b91b 100644 --- a/letta/functions/function_sets/extras.py +++ b/letta/functions/function_sets/extras.py @@ -6,7 +6,7 @@ import requests from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE from letta.llm_api.llm_api_tools import create -from letta.schemas.message import Message +from letta.schemas.message import Message, TextContent from letta.utils import json_dumps, json_loads @@ -23,8 +23,13 @@ def message_chatgpt(self, message: str): dummy_user_id = uuid.uuid4() dummy_agent_id = uuid.uuid4() message_sequence = [ - Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE), - Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)), + Message( + user_id=dummy_user_id, + agent_id=dummy_agent_id, + role="system", + content=[TextContent(text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE)], + ), + Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", content=[TextContent(text=str(message))]), ] # TODO: this will error without an LLMConfig response = create( diff --git a/letta/memory.py b/letta/memory.py index b81e5e1da..e997be617 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -6,7 +6,7 @@ from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole from letta.schemas.memory import Memory -from letta.schemas.message import Message +from letta.schemas.message import Message, TextContent from letta.settings import summarizer_settings from letta.utils import count_tokens, printd @@ -60,9 +60,9 @@ def summarize_messages( dummy_agent_id = agent_state.id message_sequence = [ - Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt), - Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK), - Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input), + Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]), + Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]), + Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]), ] # TODO: We need to eventually have a separate LLM config for the summarizer LLM diff --git a/letta/orm/message.py b/letta/orm/message.py index 0591528c4..9183c4ae1 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -8,6 +8,7 @@ from letta.orm.custom_columns import ToolCallColumn from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.message import Message as PydanticMessage +from letta.schemas.message import TextContent as PydanticTextContent class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): @@ -45,3 +46,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): def job(self) -> Optional["Job"]: """Get the job associated with this message, if any.""" return self.job_message.job if self.job_message else None + + def to_pydantic(self) -> PydanticMessage: + """custom pydantic conversion for message content mapping""" + model = self.__pydantic_model__.model_validate(self) + if self.text: + model.content = [PydanticTextContent(text=self.text)] + return model diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index e0bb485ed..9a3076aea 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -9,6 +9,10 @@ class MessageRole(str, Enum): system = "system" +class MessageContentType(str, Enum): + text = "text" + + class OptionState(str, Enum): """Useful for kwargs that are bool + default option""" diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 9b84ce5ae..84781fdb8 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -2,15 +2,15 @@ import copy import json import warnings from datetime import datetime, timezone -from typing import List, Literal, Optional +from typing import Annotated, Any, Dict, List, Literal, Optional, Union from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN from letta.local_llm.constants import INNER_THOUGHTS_KWARG -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageContentType, MessageRole from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( AssistantMessage, @@ -80,6 +80,21 @@ class MessageUpdate(BaseModel): tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") +class MessageContent(BaseModel): + type: MessageContentType = Field(..., description="The type of the message.") + + +class TextContent(MessageContent): + type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.") + text: str = Field(..., description="The text content of the message.") + + +MessageContentUnion = Annotated[ + Union[TextContent], + Field(discriminator="type"), +] + + class Message(BaseMessage): """ Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats. @@ -100,7 +115,7 @@ class Message(BaseMessage): id: str = BaseMessage.generate_id_field() role: MessageRole = Field(..., description="The role of the participant.") - text: Optional[str] = Field(None, description="The text of the message.") + content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.") organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.") agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.") model: Optional[str] = Field(None, description="The model used to make the function call.") @@ -108,6 +123,7 @@ class Message(BaseMessage): tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.") tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.") + # This overrides the optional base orm schema, created_at MUST exist on all messages objects created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") @@ -118,6 +134,37 @@ class Message(BaseMessage): assert v in roles, f"Role must be one of {roles}" return v + @model_validator(mode="before") + @classmethod + def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(data, dict): + if "text" in data and "content" not in data: + data["content"] = [TextContent(text=data["text"])] + del data["text"] + return data + + def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: + data = super().model_dump(**kwargs) + if to_orm: + for content in data["content"]: + if content["type"] == "text": + data["text"] = content["text"] + del data["content"] + return data + + @property + def text(self) -> Optional[str]: + """ + Retrieve the first text content's text. + + Returns: + str: The text content, or None if no text content exists + """ + if not self.content: + return None + text_content = [content.text for content in self.content if content.type == MessageContentType.text] + return text_content[0] if text_content else None + def to_json(self): json_message = vars(self) if json_message["tool_calls"] is not None: @@ -283,7 +330,7 @@ class Message(BaseMessage): model=model, # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole.tool, # NOTE - text=openai_message_dict["content"], + content=[TextContent(text=openai_message_dict["content"])], name=openai_message_dict["name"] if "name" in openai_message_dict else None, tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, @@ -296,7 +343,7 @@ class Message(BaseMessage): model=model, # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole.tool, # NOTE - text=openai_message_dict["content"], + content=[TextContent(text=openai_message_dict["content"])], name=openai_message_dict["name"] if "name" in openai_message_dict else None, tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, @@ -328,7 +375,7 @@ class Message(BaseMessage): model=model, # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), - text=openai_message_dict["content"], + content=[TextContent(text=openai_message_dict["content"])], name=openai_message_dict["name"] if "name" in openai_message_dict else None, tool_calls=tool_calls, tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool' @@ -341,7 +388,7 @@ class Message(BaseMessage): model=model, # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), - text=openai_message_dict["content"], + content=[TextContent(text=openai_message_dict["content"])], name=openai_message_dict["name"] if "name" in openai_message_dict else None, tool_calls=tool_calls, tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool' @@ -373,7 +420,7 @@ class Message(BaseMessage): model=model, # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), - text=openai_message_dict["content"], + content=[TextContent(text=openai_message_dict["content"])], name=openai_message_dict["name"] if "name" in openai_message_dict else None, tool_calls=tool_calls, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, @@ -386,7 +433,7 @@ class Message(BaseMessage): model=model, # standard fields expected in an OpenAI ChatCompletion message object role=MessageRole(openai_message_dict["role"]), - text=openai_message_dict["content"], + content=[TextContent(text=openai_message_dict["content"] or "")], name=openai_message_dict["name"] if "name" in openai_message_dict else None, tool_calls=tool_calls, tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, diff --git a/letta/server/server.py b/letta/server/server.py index 02e66e5c0..f79875fc8 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -39,7 +39,7 @@ from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolRe from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary -from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate +from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.providers import ( @@ -617,14 +617,14 @@ class SyncServer(Server): message = Message( agent_id=agent_id, role="user", - text=packaged_user_message, + content=[TextContent(text=packaged_user_message)], created_at=timestamp, ) else: message = Message( agent_id=agent_id, role="user", - text=packaged_user_message, + content=[TextContent(text=packaged_user_message)], ) # Run the agent state forward @@ -667,14 +667,14 @@ class SyncServer(Server): message = Message( agent_id=agent_id, role="system", - text=packaged_system_message, + content=[TextContent(text=packaged_system_message)], created_at=timestamp, ) else: message = Message( agent_id=agent_id, role="system", - text=packaged_system_message, + content=[TextContent(text=packaged_system_message)], ) if isinstance(message, Message): @@ -732,7 +732,7 @@ class SyncServer(Server): Message( agent_id=agent_id, role=message.role, - text=message.text, + content=[TextContent(text=message.text)], name=message.name, # assigned later? model=None, diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 0846a0c74..400dfd31a 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -11,7 +11,7 @@ from letta.prompts import gpt_system from letta.schemas.agent import AgentState, AgentType from letta.schemas.enums import MessageRole from letta.schemas.memory import Memory -from letta.schemas.message import Message, MessageCreate +from letta.schemas.message import Message, MessageCreate, TextContent from letta.schemas.tool_rule import ToolRule from letta.schemas.user import User from letta.system import get_initial_boot_messages, get_login_event @@ -244,7 +244,14 @@ def package_initial_message_sequence( raise ValueError(f"Invalid message role: {message_create.role}") init_messages.append( - Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model) + Message( + role=message_create.role, + content=[TextContent(text=packed_message)], + name=message_create.name, + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + ) ) return init_messages diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 5d813225d..59b0a3b6d 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -49,7 +49,7 @@ class MessageManager: with self.session_maker() as session: # Set the organization id of the Pydantic message pydantic_msg.organization_id = actor.organization_id - msg_data = pydantic_msg.model_dump() + msg_data = pydantic_msg.model_dump(to_orm=True) msg = MessageModel(**msg_data) msg.create(session, actor=actor) # Persist to database return msg.to_pydantic() diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 07b0e90a2..606600aa3 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -14,7 +14,7 @@ from letta.llm_api.helpers import calculate_summarizer_cutoff from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message +from letta.schemas.message import Message, TextContent from letta.settings import summarizer_settings from letta.streaming_interface import StreamingRefreshCLIInterface from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH @@ -55,7 +55,7 @@ def generate_message(role: str, text: str = None, tool_calls: List = None) -> Me return Message( id="message-" + str(uuid.uuid4()), role=MessageRole(role), - text=text or f"{role} message text", + content=[TextContent(text=text or f"{role} message text")], created_at=datetime.utcnow(), tool_calls=tool_calls or [], )