feat: add content union type for requests (#762)

This commit is contained in:
cthomas 2025-01-23 20:25:00 -08:00 committed by GitHub
parent 63ea14d48a
commit 06ca10acb1
22 changed files with 99 additions and 83 deletions

View File

@ -46,7 +46,7 @@ response = client.agents.messages.send(
messages=[
MessageCreate(
role="user",
text="hello",
content="hello",
)
],
)
@ -59,7 +59,7 @@ response = client.agents.messages.send(
messages=[
MessageCreate(
role="system",
text="[system] user has logged in. send a friendly message.",
content="[system] user has logged in. send a friendly message.",
)
],
)

View File

@ -29,7 +29,7 @@ response = client.agents.messages.send(
messages=[
MessageCreate(
role="user",
text="hello",
content="hello",
)
],
)

View File

@ -43,7 +43,7 @@ def main():
messages=[
MessageCreate(
role="user",
text="Whats my name?",
content="Whats my name?",
)
],
)

View File

@ -64,7 +64,7 @@ response = client.agents.messages.send(
messages=[
MessageCreate(
role="user",
text="roll a dice",
content="roll a dice",
)
],
)
@ -100,7 +100,7 @@ client.agents.messages.send(
messages=[
MessageCreate(
role="user",
text="search your archival memory",
content="search your archival memory",
)
],
)

View File

@ -246,7 +246,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"Search archival for our company's vacation policies\",\n",
" content=\"Search archival for our company's vacation policies\",\n",
" )\n",
" ],\n",
")\n",
@ -528,7 +528,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"When is my birthday?\",\n",
" content=\"When is my birthday?\",\n",
" )\n",
" ],\n",
")\n",
@ -814,7 +814,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"Who founded OpenAI?\",\n",
" content=\"Who founded OpenAI?\",\n",
" )\n",
" ],\n",
")\n",
@ -952,7 +952,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"Who founded OpenAI?\",\n",
" content=\"Who founded OpenAI?\",\n",
" )\n",
" ],\n",
")\n",

View File

@ -169,7 +169,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"hello!\",\n",
" content=\"hello!\",\n",
" )\n",
" ],\n",
")\n",
@ -529,7 +529,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"My name is actually Bob\",\n",
" content=\"My name is actually Bob\",\n",
" )\n",
" ],\n",
")\n",
@ -682,7 +682,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"In the future, never use emojis to communicate\",\n",
" content=\"In the future, never use emojis to communicate\",\n",
" )\n",
" ],\n",
")\n",
@ -870,7 +870,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"Save the information that 'bob loves cats' to archival\",\n",
" content=\"Save the information that 'bob loves cats' to archival\",\n",
" )\n",
" ],\n",
")\n",
@ -1039,7 +1039,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"What animals do I like? Search archival.\",\n",
" content=\"What animals do I like? Search archival.\",\n",
" )\n",
" ],\n",
")\n",

View File

@ -276,7 +276,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=\"Candidate: Tony Stark\",\n",
" content=\"Candidate: Tony Stark\",\n",
" )\n",
" ],\n",
")"
@ -403,7 +403,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=feedback,\n",
" content=feedback,\n",
" )\n",
" ],\n",
")"
@ -423,7 +423,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" text=feedback,\n",
" content=feedback,\n",
" )\n",
" ],\n",
")"
@ -540,7 +540,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"system\",\n",
" text=\"Candidate: Spongebob Squarepants\",\n",
" content=\"Candidate: Spongebob Squarepants\",\n",
" )\n",
" ],\n",
")"
@ -758,7 +758,7 @@
" messages=[\n",
" MessageCreate(\n",
" role=\"system\",\n",
" text=\"Run generation\",\n",
" content=\"Run generation\",\n",
" )\n",
" ],\n",
")"

View File

@ -643,7 +643,7 @@ class RESTClient(AbstractClient):
) -> Message:
request = MessageUpdate(
role=role,
text=text,
content=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
@ -1015,7 +1015,7 @@ class RESTClient(AbstractClient):
response (LettaResponse): Response from the agent
"""
# TODO: implement include_full_message
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
messages = [MessageCreate(role=MessageRole(role), content=message, name=name)]
# TODO: figure out how to handle stream_steps and stream_tokens
# When streaming steps is True, stream_tokens must be False
@ -1062,7 +1062,7 @@ class RESTClient(AbstractClient):
Returns:
job (Job): Information about the async job
"""
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
messages = [MessageCreate(role=MessageRole(role), content=message, name=name)]
request = LettaRequest(messages=messages)
response = requests.post(
@ -2442,7 +2442,7 @@ class LocalClient(AbstractClient):
message_id=message_id,
request=MessageUpdate(
role=role,
text=text,
content=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
@ -2741,7 +2741,7 @@ class LocalClient(AbstractClient):
usage = self.server.send_messages(
actor=self.user,
agent_id=agent_id,
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
)
## TODO: need to make sure date/timestamp is propely passed

View File

@ -50,7 +50,7 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
chunk_data = json.loads(sse.data)
if "reasoning" in chunk_data:
yield ReasoningMessage(**chunk_data)
elif "assistant_message" in chunk_data:
elif "message_type" in chunk_data and chunk_data["message_type"] == "assistant_message":
yield AssistantMessage(**chunk_data)
elif "tool_call" in chunk_data:
yield ToolCallMessage(**chunk_data)

View File

@ -246,7 +246,7 @@ def parse_letta_response_for_assistant_message(
reasoning_message = ""
for m in letta_response.messages:
if isinstance(m, AssistantMessage):
return m.assistant_message
return m.content
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
try:
return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]
@ -290,7 +290,7 @@ async def async_send_message_with_retries(
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
for attempt in range(1, max_retries + 1):
try:
messages = [MessageCreate(role=MessageRole.user, text=message_text, name=sender_agent.agent_state.name)]
messages = [MessageCreate(role=MessageRole.user, content=message_text, name=sender_agent.agent_state.name)]
# Wrap in a timeout
response = await asyncio.wait_for(
server.send_message_to_agent(

View File

@ -4,6 +4,8 @@ from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
from letta.schemas.enums import MessageContentType
# Letta API style responses (intended to be easier to use vs getting true Message types)
@ -32,18 +34,33 @@ class LettaMessage(BaseModel):
return dt.isoformat(timespec="seconds")
class MessageContent(BaseModel):
type: MessageContentType = Field(..., description="The type of the message.")
class TextContent(MessageContent):
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
text: str = Field(..., description="The text content of the message.")
MessageContentUnion = Annotated[
Union[TextContent],
Field(discriminator="type"),
]
class SystemMessage(LettaMessage):
"""
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
Attributes:
message (str): The message sent by the system
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
"""
message_type: Literal["system_message"] = "system_message"
message: str
content: Union[str, List[MessageContentUnion]]
class UserMessage(LettaMessage):
@ -51,13 +68,13 @@ class UserMessage(LettaMessage):
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
Attributes:
message (str): The message sent by the user
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
"""
message_type: Literal["user_message"] = "user_message"
message: str
content: Union[str, List[MessageContentUnion]]
class ReasoningMessage(LettaMessage):
@ -167,7 +184,7 @@ class ToolReturnMessage(LettaMessage):
class AssistantMessage(LettaMessage):
message_type: Literal["assistant_message"] = "assistant_message"
assistant_message: str
content: Union[str, List[MessageContentUnion]]
class LegacyFunctionCallMessage(LettaMessage):

View File

@ -2,7 +2,7 @@ import copy
import json
import warnings
from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
@ -15,8 +15,10 @@ from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
AssistantMessage,
LettaMessage,
MessageContentUnion,
ReasoningMessage,
SystemMessage,
TextContent,
ToolCall,
ToolCallMessage,
ToolReturnMessage,
@ -59,7 +61,7 @@ class MessageCreate(BaseModel):
MessageRole.user,
MessageRole.system,
] = Field(..., description="The role of the participant.")
text: str = Field(..., description="The text of the message.")
content: Union[str, List[MessageContentUnion]] = Field(..., description="The content of the message.")
name: Optional[str] = Field(None, description="The name of the participant.")
@ -67,7 +69,7 @@ class MessageUpdate(BaseModel):
"""Request to update a message"""
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
text: Optional[str] = Field(None, description="The text of the message.")
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
@ -79,20 +81,17 @@ class MessageUpdate(BaseModel):
tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
class MessageContent(BaseModel):
type: MessageContentType = Field(..., description="The type of the message.")
class TextContent(MessageContent):
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
text: str = Field(..., description="The text content of the message.")
MessageContentUnion = Annotated[
Union[TextContent],
Field(discriminator="type"),
]
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
if to_orm and "content" in data:
if isinstance(data["content"], str):
data["text"] = data["content"]
else:
for content in data["content"]:
if content["type"] == "text":
data["text"] = content["text"]
del data["content"]
return data
class Message(BaseMessage):
@ -212,7 +211,7 @@ class Message(BaseMessage):
AssistantMessage(
id=self.id,
date=self.created_at,
assistant_message=message_string,
content=message_string,
)
)
else:
@ -268,7 +267,7 @@ class Message(BaseMessage):
UserMessage(
id=self.id,
date=self.created_at,
message=self.text,
content=self.text,
)
)
elif self.role == MessageRole.system:
@ -278,7 +277,7 @@ class Message(BaseMessage):
SystemMessage(
id=self.id,
date=self.created_at,
message=self.text,
content=self.text,
)
)
else:

View File

@ -472,7 +472,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
assistant_message=cleaned_func_args,
content=cleaned_func_args,
)
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
@ -613,7 +613,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
assistant_message=combined_chunk,
content=combined_chunk,
)
# Store the ID of the tool call so allow skipping the corresponding response
if self.function_id_buffer:
@ -627,7 +627,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
assistant_message=updates_main_json,
content=updates_main_json,
)
# Store the ID of the tool call so allow skipping the corresponding response
if self.function_id_buffer:
@ -959,7 +959,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
assistant_message=func_args["message"],
content=func_args["message"],
)
self._push_to_buffer(processed_chunk)
except Exception as e:
@ -981,7 +981,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
assistant_message=func_args[self.assistant_message_tool_kwarg],
content=func_args[self.assistant_message_tool_kwarg],
)
# Store the ID of the tool call so allow skipping the corresponding response
self.prev_assistant_message_id = function_call.id

View File

@ -721,9 +721,9 @@ class SyncServer(Server):
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
if message.role == MessageRole.user and wrap_user_message:
message.text = system.package_user_message(user_message=message.text)
message.content = system.package_user_message(user_message=message.content)
elif message.role == MessageRole.system and wrap_system_message:
message.text = system.package_system_message(system_message=message.text)
message.content = system.package_system_message(system_message=message.content)
else:
raise ValueError(f"Invalid message role: {message.role}")
@ -732,7 +732,7 @@ class SyncServer(Server):
Message(
agent_id=agent_id,
role=message.role,
content=[TextContent(text=message.text)],
content=[TextContent(text=message.content)],
name=message.name,
# assigned later?
model=None,

View File

@ -234,11 +234,11 @@ def package_initial_message_sequence(
if message_create.role == MessageRole.user:
packed_message = system.package_user_message(
user_message=message_create.text,
user_message=message_create.content,
)
elif message_create.role == MessageRole.system:
packed_message = system.package_system_message(
system_message=message_create.text,
system_message=message_create.content,
)
else:
raise ValueError(f"Invalid message role: {message_create.role}")

View File

@ -83,7 +83,7 @@ class MessageManager:
)
# get update dictionary
update_data = message_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}

View File

@ -55,7 +55,7 @@ class LettaUser(HttpUser):
@task(1)
def send_message(self):
messages = [MessageCreate(role=MessageRole("user"), text="hello")]
messages = [MessageCreate(role=MessageRole("user"), content="hello")]
request = LettaRequest(messages=messages)
with self.client.post(
@ -70,7 +70,7 @@ class LettaUser(HttpUser):
# @task(1)
# def send_message_stream(self):
# messages = [MessageCreate(role=MessageRole("user"), text="hello")]
# messages = [MessageCreate(role=MessageRole("user"), content="hello")]
# request = LettaRequest(messages=messages, stream_steps=True, stream_tokens=True, return_message_object=True)
# if stream_tokens or stream_steps:
# from letta.client.streaming import _sse_post

View File

@ -628,7 +628,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
cleanup_agents.append(empty_agent_state.id)
custom_sequence = [MessageCreate(**{"text": "Hello, how are you?", "role": MessageRole.user})]
custom_sequence = [MessageCreate(**{"content": "Hello, how are you?", "role": MessageRole.user})]
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
cleanup_agents.append(custom_agent_state.id)
assert custom_agent_state.message_ids is not None
@ -637,7 +637,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
# shoule be contained in second message (after system message)
assert custom_sequence[0].text in client.get_in_context_messages(custom_agent_state.id)[1].text
assert custom_sequence[0].content in client.get_in_context_messages(custom_agent_state.id)[1].text
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):

View File

@ -446,7 +446,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
description="test_description",
metadata={"test_key": "test_value"},
tool_rules=[InitToolRule(tool_name=print_tool.name)],
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
)
created_agent = server.agent_manager.create_agent(
@ -548,7 +548,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
block_ids=[default_block.id],
tags=["a", "b"],
description="test_description",
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
)
agent_state = server.agent_manager.create_agent(
create_agent_request,
@ -561,7 +561,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
assert create_agent_request.memory_blocks[0].value in init_messages[0].text
# Check that the second message is the passed in initial message seq
assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role
assert create_agent_request.initial_message_sequence[0].text in init_messages[1].text
assert create_agent_request.initial_message_sequence[0].content in init_messages[1].text
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
@ -1830,7 +1830,7 @@ def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, defa
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user):
"""Test updating a message"""
new_text = "Updated text"
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(text=new_text), actor=other_user)
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(content=new_text), actor=other_user)
assert updated is not None
assert updated.text == new_text
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)

View File

@ -96,7 +96,7 @@ def test_shared_blocks(client):
messages=[
MessageCreate(
role="user",
text="my name is actually charles",
content="my name is actually charles",
)
],
)
@ -109,7 +109,7 @@ def test_shared_blocks(client):
messages=[
MessageCreate(
role="user",
text="whats my name?",
content="whats my name?",
)
],
)
@ -339,7 +339,7 @@ def test_messages(client, agent):
messages=[
MessageCreate(
role="user",
text="Test message",
content="Test message",
),
],
)
@ -359,7 +359,7 @@ def test_send_system_message(client, agent):
messages=[
MessageCreate(
role="system",
text="Event occurred: The user just logged off.",
content="Event occurred: The user just logged off.",
),
],
)
@ -388,7 +388,7 @@ def test_function_return_limit(client, agent):
messages=[
MessageCreate(
role="user",
text="call the big_return function",
content="call the big_return function",
),
],
config=LettaRequestConfig(use_assistant_message=False),
@ -424,7 +424,7 @@ def test_function_always_error(client, agent):
messages=[
MessageCreate(
role="user",
text="call the always_error function",
content="call the always_error function",
),
],
config=LettaRequestConfig(use_assistant_message=False),
@ -455,7 +455,7 @@ async def test_send_message_parallel(client, agent):
messages=[
MessageCreate(
role="user",
text=message,
content=message,
),
],
)
@ -490,7 +490,7 @@ def test_send_message_async(client, agent):
messages=[
MessageCreate(
role="user",
text=test_message,
content=test_message,
),
],
config=LettaRequestConfig(use_assistant_message=False),

View File

@ -711,12 +711,12 @@ def _test_get_messages_letta_format(
elif message.role == MessageRole.user:
assert isinstance(letta_message, UserMessage)
assert message.text == letta_message.message
assert message.text == letta_message.content
letta_message_index += 1
elif message.role == MessageRole.system:
assert isinstance(letta_message, SystemMessage)
assert message.text == letta_message.message
assert message.text == letta_message.content
letta_message_index += 1
elif message.role == MessageRole.tool:

View File

@ -324,7 +324,7 @@ def test_get_run_messages(client, mock_sync_server):
UserMessage(
id=f"message-{i:08x}",
date=current_time,
message=f"Test message {i}",
content=f"Test message {i}",
)
for i in range(2)
]