mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add content union type for requests (#762)
This commit is contained in:
parent
63ea14d48a
commit
06ca10acb1
@ -46,7 +46,7 @@ response = client.agents.messages.send(
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="hello",
|
content="hello",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -59,7 +59,7 @@ response = client.agents.messages.send(
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="system",
|
role="system",
|
||||||
text="[system] user has logged in. send a friendly message.",
|
content="[system] user has logged in. send a friendly message.",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -29,7 +29,7 @@ response = client.agents.messages.send(
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="hello",
|
content="hello",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -43,7 +43,7 @@ def main():
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="Whats my name?",
|
content="Whats my name?",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -64,7 +64,7 @@ response = client.agents.messages.send(
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="roll a dice",
|
content="roll a dice",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -100,7 +100,7 @@ client.agents.messages.send(
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="search your archival memory",
|
content="search your archival memory",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -246,7 +246,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"Search archival for our company's vacation policies\",\n",
|
" content=\"Search archival for our company's vacation policies\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -528,7 +528,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"When is my birthday?\",\n",
|
" content=\"When is my birthday?\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -814,7 +814,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"Who founded OpenAI?\",\n",
|
" content=\"Who founded OpenAI?\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -952,7 +952,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"Who founded OpenAI?\",\n",
|
" content=\"Who founded OpenAI?\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
@ -169,7 +169,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"hello!\",\n",
|
" content=\"hello!\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -529,7 +529,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"My name is actually Bob\",\n",
|
" content=\"My name is actually Bob\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -682,7 +682,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"In the future, never use emojis to communicate\",\n",
|
" content=\"In the future, never use emojis to communicate\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -870,7 +870,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"Save the information that 'bob loves cats' to archival\",\n",
|
" content=\"Save the information that 'bob loves cats' to archival\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
@ -1039,7 +1039,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"What animals do I like? Search archival.\",\n",
|
" content=\"What animals do I like? Search archival.\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
@ -276,7 +276,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=\"Candidate: Tony Stark\",\n",
|
" content=\"Candidate: Tony Stark\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")"
|
")"
|
||||||
@ -403,7 +403,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=feedback,\n",
|
" content=feedback,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")"
|
")"
|
||||||
@ -423,7 +423,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"user\",\n",
|
" role=\"user\",\n",
|
||||||
" text=feedback,\n",
|
" content=feedback,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")"
|
")"
|
||||||
@ -540,7 +540,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"system\",\n",
|
" role=\"system\",\n",
|
||||||
" text=\"Candidate: Spongebob Squarepants\",\n",
|
" content=\"Candidate: Spongebob Squarepants\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")"
|
")"
|
||||||
@ -758,7 +758,7 @@
|
|||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" MessageCreate(\n",
|
" MessageCreate(\n",
|
||||||
" role=\"system\",\n",
|
" role=\"system\",\n",
|
||||||
" text=\"Run generation\",\n",
|
" content=\"Run generation\",\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
")"
|
")"
|
||||||
|
@ -643,7 +643,7 @@ class RESTClient(AbstractClient):
|
|||||||
) -> Message:
|
) -> Message:
|
||||||
request = MessageUpdate(
|
request = MessageUpdate(
|
||||||
role=role,
|
role=role,
|
||||||
text=text,
|
content=text,
|
||||||
name=name,
|
name=name,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
@ -1015,7 +1015,7 @@ class RESTClient(AbstractClient):
|
|||||||
response (LettaResponse): Response from the agent
|
response (LettaResponse): Response from the agent
|
||||||
"""
|
"""
|
||||||
# TODO: implement include_full_message
|
# TODO: implement include_full_message
|
||||||
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
|
messages = [MessageCreate(role=MessageRole(role), content=message, name=name)]
|
||||||
# TODO: figure out how to handle stream_steps and stream_tokens
|
# TODO: figure out how to handle stream_steps and stream_tokens
|
||||||
|
|
||||||
# When streaming steps is True, stream_tokens must be False
|
# When streaming steps is True, stream_tokens must be False
|
||||||
@ -1062,7 +1062,7 @@ class RESTClient(AbstractClient):
|
|||||||
Returns:
|
Returns:
|
||||||
job (Job): Information about the async job
|
job (Job): Information about the async job
|
||||||
"""
|
"""
|
||||||
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
|
messages = [MessageCreate(role=MessageRole(role), content=message, name=name)]
|
||||||
|
|
||||||
request = LettaRequest(messages=messages)
|
request = LettaRequest(messages=messages)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@ -2442,7 +2442,7 @@ class LocalClient(AbstractClient):
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
request=MessageUpdate(
|
request=MessageUpdate(
|
||||||
role=role,
|
role=role,
|
||||||
text=text,
|
content=text,
|
||||||
name=name,
|
name=name,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
@ -2741,7 +2741,7 @@ class LocalClient(AbstractClient):
|
|||||||
usage = self.server.send_messages(
|
usage = self.server.send_messages(
|
||||||
actor=self.user,
|
actor=self.user,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
|
messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
|
||||||
)
|
)
|
||||||
|
|
||||||
## TODO: need to make sure date/timestamp is propely passed
|
## TODO: need to make sure date/timestamp is propely passed
|
||||||
|
@ -50,7 +50,7 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
|||||||
chunk_data = json.loads(sse.data)
|
chunk_data = json.loads(sse.data)
|
||||||
if "reasoning" in chunk_data:
|
if "reasoning" in chunk_data:
|
||||||
yield ReasoningMessage(**chunk_data)
|
yield ReasoningMessage(**chunk_data)
|
||||||
elif "assistant_message" in chunk_data:
|
elif "message_type" in chunk_data and chunk_data["message_type"] == "assistant_message":
|
||||||
yield AssistantMessage(**chunk_data)
|
yield AssistantMessage(**chunk_data)
|
||||||
elif "tool_call" in chunk_data:
|
elif "tool_call" in chunk_data:
|
||||||
yield ToolCallMessage(**chunk_data)
|
yield ToolCallMessage(**chunk_data)
|
||||||
|
@ -246,7 +246,7 @@ def parse_letta_response_for_assistant_message(
|
|||||||
reasoning_message = ""
|
reasoning_message = ""
|
||||||
for m in letta_response.messages:
|
for m in letta_response.messages:
|
||||||
if isinstance(m, AssistantMessage):
|
if isinstance(m, AssistantMessage):
|
||||||
return m.assistant_message
|
return m.content
|
||||||
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
|
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
|
||||||
try:
|
try:
|
||||||
return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]
|
return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]
|
||||||
@ -290,7 +290,7 @@ async def async_send_message_with_retries(
|
|||||||
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
|
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
|
||||||
for attempt in range(1, max_retries + 1):
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
messages = [MessageCreate(role=MessageRole.user, text=message_text, name=sender_agent.agent_state.name)]
|
messages = [MessageCreate(role=MessageRole.user, content=message_text, name=sender_agent.agent_state.name)]
|
||||||
# Wrap in a timeout
|
# Wrap in a timeout
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
server.send_message_to_agent(
|
server.send_message_to_agent(
|
||||||
|
@ -4,6 +4,8 @@ from typing import Annotated, List, Literal, Optional, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||||
|
|
||||||
|
from letta.schemas.enums import MessageContentType
|
||||||
|
|
||||||
# Letta API style responses (intended to be easier to use vs getting true Message types)
|
# Letta API style responses (intended to be easier to use vs getting true Message types)
|
||||||
|
|
||||||
|
|
||||||
@ -32,18 +34,33 @@ class LettaMessage(BaseModel):
|
|||||||
return dt.isoformat(timespec="seconds")
|
return dt.isoformat(timespec="seconds")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageContent(BaseModel):
|
||||||
|
type: MessageContentType = Field(..., description="The type of the message.")
|
||||||
|
|
||||||
|
|
||||||
|
class TextContent(MessageContent):
|
||||||
|
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||||
|
text: str = Field(..., description="The text content of the message.")
|
||||||
|
|
||||||
|
|
||||||
|
MessageContentUnion = Annotated[
|
||||||
|
Union[TextContent],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(LettaMessage):
|
class SystemMessage(LettaMessage):
|
||||||
"""
|
"""
|
||||||
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
|
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
message (str): The message sent by the system
|
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||||
id (str): The ID of the message
|
id (str): The ID of the message
|
||||||
date (datetime): The date the message was created in ISO format
|
date (datetime): The date the message was created in ISO format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
message_type: Literal["system_message"] = "system_message"
|
message_type: Literal["system_message"] = "system_message"
|
||||||
message: str
|
content: Union[str, List[MessageContentUnion]]
|
||||||
|
|
||||||
|
|
||||||
class UserMessage(LettaMessage):
|
class UserMessage(LettaMessage):
|
||||||
@ -51,13 +68,13 @@ class UserMessage(LettaMessage):
|
|||||||
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
|
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
message (str): The message sent by the user
|
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||||
id (str): The ID of the message
|
id (str): The ID of the message
|
||||||
date (datetime): The date the message was created in ISO format
|
date (datetime): The date the message was created in ISO format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
message_type: Literal["user_message"] = "user_message"
|
message_type: Literal["user_message"] = "user_message"
|
||||||
message: str
|
content: Union[str, List[MessageContentUnion]]
|
||||||
|
|
||||||
|
|
||||||
class ReasoningMessage(LettaMessage):
|
class ReasoningMessage(LettaMessage):
|
||||||
@ -167,7 +184,7 @@ class ToolReturnMessage(LettaMessage):
|
|||||||
|
|
||||||
class AssistantMessage(LettaMessage):
|
class AssistantMessage(LettaMessage):
|
||||||
message_type: Literal["assistant_message"] = "assistant_message"
|
message_type: Literal["assistant_message"] = "assistant_message"
|
||||||
assistant_message: str
|
content: Union[str, List[MessageContentUnion]]
|
||||||
|
|
||||||
|
|
||||||
class LegacyFunctionCallMessage(LettaMessage):
|
class LegacyFunctionCallMessage(LettaMessage):
|
||||||
|
@ -2,7 +2,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||||
@ -15,8 +15,10 @@ from letta.schemas.letta_base import OrmMetadataBase
|
|||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
LettaMessage,
|
LettaMessage,
|
||||||
|
MessageContentUnion,
|
||||||
ReasoningMessage,
|
ReasoningMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
TextContent,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolReturnMessage,
|
ToolReturnMessage,
|
||||||
@ -59,7 +61,7 @@ class MessageCreate(BaseModel):
|
|||||||
MessageRole.user,
|
MessageRole.user,
|
||||||
MessageRole.system,
|
MessageRole.system,
|
||||||
] = Field(..., description="The role of the participant.")
|
] = Field(..., description="The role of the participant.")
|
||||||
text: str = Field(..., description="The text of the message.")
|
content: Union[str, List[MessageContentUnion]] = Field(..., description="The content of the message.")
|
||||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +69,7 @@ class MessageUpdate(BaseModel):
|
|||||||
"""Request to update a message"""
|
"""Request to update a message"""
|
||||||
|
|
||||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||||
text: Optional[str] = Field(None, description="The text of the message.")
|
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
|
||||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||||
@ -79,20 +81,17 @@ class MessageUpdate(BaseModel):
|
|||||||
tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
|
tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
|
||||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||||
|
|
||||||
|
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||||
class MessageContent(BaseModel):
|
data = super().model_dump(**kwargs)
|
||||||
type: MessageContentType = Field(..., description="The type of the message.")
|
if to_orm and "content" in data:
|
||||||
|
if isinstance(data["content"], str):
|
||||||
|
data["text"] = data["content"]
|
||||||
class TextContent(MessageContent):
|
else:
|
||||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
for content in data["content"]:
|
||||||
text: str = Field(..., description="The text content of the message.")
|
if content["type"] == "text":
|
||||||
|
data["text"] = content["text"]
|
||||||
|
del data["content"]
|
||||||
MessageContentUnion = Annotated[
|
return data
|
||||||
Union[TextContent],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseMessage):
|
class Message(BaseMessage):
|
||||||
@ -212,7 +211,7 @@ class Message(BaseMessage):
|
|||||||
AssistantMessage(
|
AssistantMessage(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
date=self.created_at,
|
date=self.created_at,
|
||||||
assistant_message=message_string,
|
content=message_string,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -268,7 +267,7 @@ class Message(BaseMessage):
|
|||||||
UserMessage(
|
UserMessage(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
date=self.created_at,
|
date=self.created_at,
|
||||||
message=self.text,
|
content=self.text,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif self.role == MessageRole.system:
|
elif self.role == MessageRole.system:
|
||||||
@ -278,7 +277,7 @@ class Message(BaseMessage):
|
|||||||
SystemMessage(
|
SystemMessage(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
date=self.created_at,
|
date=self.created_at,
|
||||||
message=self.text,
|
content=self.text,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -472,7 +472,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
processed_chunk = AssistantMessage(
|
processed_chunk = AssistantMessage(
|
||||||
id=message_id,
|
id=message_id,
|
||||||
date=message_date,
|
date=message_date,
|
||||||
assistant_message=cleaned_func_args,
|
content=cleaned_func_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
|
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
|
||||||
@ -613,7 +613,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
processed_chunk = AssistantMessage(
|
processed_chunk = AssistantMessage(
|
||||||
id=message_id,
|
id=message_id,
|
||||||
date=message_date,
|
date=message_date,
|
||||||
assistant_message=combined_chunk,
|
content=combined_chunk,
|
||||||
)
|
)
|
||||||
# Store the ID of the tool call so allow skipping the corresponding response
|
# Store the ID of the tool call so allow skipping the corresponding response
|
||||||
if self.function_id_buffer:
|
if self.function_id_buffer:
|
||||||
@ -627,7 +627,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
processed_chunk = AssistantMessage(
|
processed_chunk = AssistantMessage(
|
||||||
id=message_id,
|
id=message_id,
|
||||||
date=message_date,
|
date=message_date,
|
||||||
assistant_message=updates_main_json,
|
content=updates_main_json,
|
||||||
)
|
)
|
||||||
# Store the ID of the tool call so allow skipping the corresponding response
|
# Store the ID of the tool call so allow skipping the corresponding response
|
||||||
if self.function_id_buffer:
|
if self.function_id_buffer:
|
||||||
@ -959,7 +959,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
processed_chunk = AssistantMessage(
|
processed_chunk = AssistantMessage(
|
||||||
id=msg_obj.id,
|
id=msg_obj.id,
|
||||||
date=msg_obj.created_at,
|
date=msg_obj.created_at,
|
||||||
assistant_message=func_args["message"],
|
content=func_args["message"],
|
||||||
)
|
)
|
||||||
self._push_to_buffer(processed_chunk)
|
self._push_to_buffer(processed_chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -981,7 +981,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
processed_chunk = AssistantMessage(
|
processed_chunk = AssistantMessage(
|
||||||
id=msg_obj.id,
|
id=msg_obj.id,
|
||||||
date=msg_obj.created_at,
|
date=msg_obj.created_at,
|
||||||
assistant_message=func_args[self.assistant_message_tool_kwarg],
|
content=func_args[self.assistant_message_tool_kwarg],
|
||||||
)
|
)
|
||||||
# Store the ID of the tool call so allow skipping the corresponding response
|
# Store the ID of the tool call so allow skipping the corresponding response
|
||||||
self.prev_assistant_message_id = function_call.id
|
self.prev_assistant_message_id = function_call.id
|
||||||
|
@ -721,9 +721,9 @@ class SyncServer(Server):
|
|||||||
|
|
||||||
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
|
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
|
||||||
if message.role == MessageRole.user and wrap_user_message:
|
if message.role == MessageRole.user and wrap_user_message:
|
||||||
message.text = system.package_user_message(user_message=message.text)
|
message.content = system.package_user_message(user_message=message.content)
|
||||||
elif message.role == MessageRole.system and wrap_system_message:
|
elif message.role == MessageRole.system and wrap_system_message:
|
||||||
message.text = system.package_system_message(system_message=message.text)
|
message.content = system.package_system_message(system_message=message.content)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid message role: {message.role}")
|
raise ValueError(f"Invalid message role: {message.role}")
|
||||||
|
|
||||||
@ -732,7 +732,7 @@ class SyncServer(Server):
|
|||||||
Message(
|
Message(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
role=message.role,
|
role=message.role,
|
||||||
content=[TextContent(text=message.text)],
|
content=[TextContent(text=message.content)],
|
||||||
name=message.name,
|
name=message.name,
|
||||||
# assigned later?
|
# assigned later?
|
||||||
model=None,
|
model=None,
|
||||||
|
@ -234,11 +234,11 @@ def package_initial_message_sequence(
|
|||||||
|
|
||||||
if message_create.role == MessageRole.user:
|
if message_create.role == MessageRole.user:
|
||||||
packed_message = system.package_user_message(
|
packed_message = system.package_user_message(
|
||||||
user_message=message_create.text,
|
user_message=message_create.content,
|
||||||
)
|
)
|
||||||
elif message_create.role == MessageRole.system:
|
elif message_create.role == MessageRole.system:
|
||||||
packed_message = system.package_system_message(
|
packed_message = system.package_system_message(
|
||||||
system_message=message_create.text,
|
system_message=message_create.content,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid message role: {message_create.role}")
|
raise ValueError(f"Invalid message role: {message_create.role}")
|
||||||
|
@ -83,7 +83,7 @@ class MessageManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get update dictionary
|
# get update dictionary
|
||||||
update_data = message_update.model_dump(exclude_unset=True, exclude_none=True)
|
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||||
# Remove redundant update fields
|
# Remove redundant update fields
|
||||||
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
|
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class LettaUser(HttpUser):
|
|||||||
|
|
||||||
@task(1)
|
@task(1)
|
||||||
def send_message(self):
|
def send_message(self):
|
||||||
messages = [MessageCreate(role=MessageRole("user"), text="hello")]
|
messages = [MessageCreate(role=MessageRole("user"), content="hello")]
|
||||||
request = LettaRequest(messages=messages)
|
request = LettaRequest(messages=messages)
|
||||||
|
|
||||||
with self.client.post(
|
with self.client.post(
|
||||||
@ -70,7 +70,7 @@ class LettaUser(HttpUser):
|
|||||||
# @task(1)
|
# @task(1)
|
||||||
# def send_message_stream(self):
|
# def send_message_stream(self):
|
||||||
|
|
||||||
# messages = [MessageCreate(role=MessageRole("user"), text="hello")]
|
# messages = [MessageCreate(role=MessageRole("user"), content="hello")]
|
||||||
# request = LettaRequest(messages=messages, stream_steps=True, stream_tokens=True, return_message_object=True)
|
# request = LettaRequest(messages=messages, stream_steps=True, stream_tokens=True, return_message_object=True)
|
||||||
# if stream_tokens or stream_steps:
|
# if stream_tokens or stream_steps:
|
||||||
# from letta.client.streaming import _sse_post
|
# from letta.client.streaming import _sse_post
|
||||||
|
@ -628,7 +628,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
|||||||
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
|
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
|
||||||
cleanup_agents.append(empty_agent_state.id)
|
cleanup_agents.append(empty_agent_state.id)
|
||||||
|
|
||||||
custom_sequence = [MessageCreate(**{"text": "Hello, how are you?", "role": MessageRole.user})]
|
custom_sequence = [MessageCreate(**{"content": "Hello, how are you?", "role": MessageRole.user})]
|
||||||
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
|
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
|
||||||
cleanup_agents.append(custom_agent_state.id)
|
cleanup_agents.append(custom_agent_state.id)
|
||||||
assert custom_agent_state.message_ids is not None
|
assert custom_agent_state.message_ids is not None
|
||||||
@ -637,7 +637,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
|||||||
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
|
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
|
||||||
# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
|
# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
|
||||||
# shoule be contained in second message (after system message)
|
# shoule be contained in second message (after system message)
|
||||||
assert custom_sequence[0].text in client.get_in_context_messages(custom_agent_state.id)[1].text
|
assert custom_sequence[0].content in client.get_in_context_messages(custom_agent_state.id)[1].text
|
||||||
|
|
||||||
|
|
||||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
|
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||||
|
@ -446,7 +446,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
|||||||
description="test_description",
|
description="test_description",
|
||||||
metadata={"test_key": "test_value"},
|
metadata={"test_key": "test_value"},
|
||||||
tool_rules=[InitToolRule(tool_name=print_tool.name)],
|
tool_rules=[InitToolRule(tool_name=print_tool.name)],
|
||||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||||
)
|
)
|
||||||
created_agent = server.agent_manager.create_agent(
|
created_agent = server.agent_manager.create_agent(
|
||||||
@ -548,7 +548,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
|||||||
block_ids=[default_block.id],
|
block_ids=[default_block.id],
|
||||||
tags=["a", "b"],
|
tags=["a", "b"],
|
||||||
description="test_description",
|
description="test_description",
|
||||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||||
)
|
)
|
||||||
agent_state = server.agent_manager.create_agent(
|
agent_state = server.agent_manager.create_agent(
|
||||||
create_agent_request,
|
create_agent_request,
|
||||||
@ -561,7 +561,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
|||||||
assert create_agent_request.memory_blocks[0].value in init_messages[0].text
|
assert create_agent_request.memory_blocks[0].value in init_messages[0].text
|
||||||
# Check that the second message is the passed in initial message seq
|
# Check that the second message is the passed in initial message seq
|
||||||
assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role
|
assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role
|
||||||
assert create_agent_request.initial_message_sequence[0].text in init_messages[1].text
|
assert create_agent_request.initial_message_sequence[0].content in init_messages[1].text
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
|
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
|
||||||
@ -1830,7 +1830,7 @@ def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, defa
|
|||||||
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user):
|
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user):
|
||||||
"""Test updating a message"""
|
"""Test updating a message"""
|
||||||
new_text = "Updated text"
|
new_text = "Updated text"
|
||||||
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(text=new_text), actor=other_user)
|
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(content=new_text), actor=other_user)
|
||||||
assert updated is not None
|
assert updated is not None
|
||||||
assert updated.text == new_text
|
assert updated.text == new_text
|
||||||
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||||
|
@ -96,7 +96,7 @@ def test_shared_blocks(client):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="my name is actually charles",
|
content="my name is actually charles",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -109,7 +109,7 @@ def test_shared_blocks(client):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="whats my name?",
|
content="whats my name?",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -339,7 +339,7 @@ def test_messages(client, agent):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="Test message",
|
content="Test message",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -359,7 +359,7 @@ def test_send_system_message(client, agent):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="system",
|
role="system",
|
||||||
text="Event occurred: The user just logged off.",
|
content="Event occurred: The user just logged off.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -388,7 +388,7 @@ def test_function_return_limit(client, agent):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="call the big_return function",
|
content="call the big_return function",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
config=LettaRequestConfig(use_assistant_message=False),
|
config=LettaRequestConfig(use_assistant_message=False),
|
||||||
@ -424,7 +424,7 @@ def test_function_always_error(client, agent):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text="call the always_error function",
|
content="call the always_error function",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
config=LettaRequestConfig(use_assistant_message=False),
|
config=LettaRequestConfig(use_assistant_message=False),
|
||||||
@ -455,7 +455,7 @@ async def test_send_message_parallel(client, agent):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text=message,
|
content=message,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -490,7 +490,7 @@ def test_send_message_async(client, agent):
|
|||||||
messages=[
|
messages=[
|
||||||
MessageCreate(
|
MessageCreate(
|
||||||
role="user",
|
role="user",
|
||||||
text=test_message,
|
content=test_message,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
config=LettaRequestConfig(use_assistant_message=False),
|
config=LettaRequestConfig(use_assistant_message=False),
|
||||||
|
@ -711,12 +711,12 @@ def _test_get_messages_letta_format(
|
|||||||
|
|
||||||
elif message.role == MessageRole.user:
|
elif message.role == MessageRole.user:
|
||||||
assert isinstance(letta_message, UserMessage)
|
assert isinstance(letta_message, UserMessage)
|
||||||
assert message.text == letta_message.message
|
assert message.text == letta_message.content
|
||||||
letta_message_index += 1
|
letta_message_index += 1
|
||||||
|
|
||||||
elif message.role == MessageRole.system:
|
elif message.role == MessageRole.system:
|
||||||
assert isinstance(letta_message, SystemMessage)
|
assert isinstance(letta_message, SystemMessage)
|
||||||
assert message.text == letta_message.message
|
assert message.text == letta_message.content
|
||||||
letta_message_index += 1
|
letta_message_index += 1
|
||||||
|
|
||||||
elif message.role == MessageRole.tool:
|
elif message.role == MessageRole.tool:
|
||||||
|
@ -324,7 +324,7 @@ def test_get_run_messages(client, mock_sync_server):
|
|||||||
UserMessage(
|
UserMessage(
|
||||||
id=f"message-{i:08x}",
|
id=f"message-{i:08x}",
|
||||||
date=current_time,
|
date=current_time,
|
||||||
message=f"Test message {i}",
|
content=f"Test message {i}",
|
||||||
)
|
)
|
||||||
for i in range(2)
|
for i in range(2)
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user