diff --git a/letta/schemas/message.py b/letta/schemas/message.py index bc5869f6c..5a5bd17ac 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -25,6 +25,7 @@ from letta.schemas.letta_message import ( ToolReturnMessage, UserMessage, ) +from letta.system import unpack_message from letta.utils import get_utc_time, is_utc_datetime, json_dumps @@ -264,11 +265,12 @@ class Message(BaseMessage): elif self.role == MessageRole.user: # This is type UserMessage assert self.text is not None, self + message_str = unpack_message(self.text) messages.append( UserMessage( id=self.id, date=self.created_at, - content=self.text, + content=message_str or self.text, ) ) elif self.role == MessageRole.system: diff --git a/letta/system.py b/letta/system.py index 9c795704c..a13e36f1e 100644 --- a/letta/system.py +++ b/letta/system.py @@ -1,5 +1,6 @@ import json import uuid +import warnings from typing import Optional from .constants import ( @@ -205,3 +206,22 @@ def get_token_limit_warning(): } return json_dumps(packaged_message) + + +def unpack_message(packed_message) -> str: + """Take a packed message string and attempt to extract the inner message content""" + + try: + message_json = json.loads(packed_message) + except: + warnings.warn(f"Was unable to load message as JSON to unpack: ''{packed_message}") + return packed_message + + if "message" not in message_json: + if "type" in message_json and message_json["type"] in ["login", "heartbeat"]: + # This is a valid user message that the ADE expects, so don't print warning + return packed_message + warnings.warn(f"Was unable to find 'message' field in packed message object: '{packed_message}'") + return packed_message + else: + return message_json.get("message") diff --git a/tests/test_server.py b/tests/test_server.py index 8f7e1e3c8..aef057368 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -26,6 +26,7 @@ from letta.schemas.job import Job as PydanticJob from letta.schemas.message import Message from letta.schemas.source import Source as PydanticSource from letta.server.server import SyncServer +from letta.system import unpack_message from .utils import DummyDataConnector @@ -711,7 +712,7 @@ def _test_get_messages_letta_format( elif message.role == MessageRole.user: assert isinstance(letta_message, UserMessage) - assert message.text == letta_message.content + assert unpack_message(message.text) == letta_message.content letta_message_index += 1 elif message.role == MessageRole.system: @@ -734,8 +735,7 @@ def _test_get_messages_letta_format( def test_get_messages_letta_format(server, user, agent_id): - # for reverse in [False, True]: - for reverse in [False]: + for reverse in [False, True]: _test_get_messages_letta_format(server, user, agent_id, reverse=reverse)