mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add content union type for requests (#762)
This commit is contained in:
parent
63ea14d48a
commit
06ca10acb1
@ -46,7 +46,7 @@ response = client.agents.messages.send(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="hello",
|
||||
content="hello",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -59,7 +59,7 @@ response = client.agents.messages.send(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="system",
|
||||
text="[system] user has logged in. send a friendly message.",
|
||||
content="[system] user has logged in. send a friendly message.",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ response = client.agents.messages.send(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="hello",
|
||||
content="hello",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ def main():
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="Whats my name?",
|
||||
content="Whats my name?",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -64,7 +64,7 @@ response = client.agents.messages.send(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="roll a dice",
|
||||
content="roll a dice",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -100,7 +100,7 @@ client.agents.messages.send(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="search your archival memory",
|
||||
content="search your archival memory",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -246,7 +246,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"Search archival for our company's vacation policies\",\n",
|
||||
" content=\"Search archival for our company's vacation policies\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -528,7 +528,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"When is my birthday?\",\n",
|
||||
" content=\"When is my birthday?\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -814,7 +814,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"Who founded OpenAI?\",\n",
|
||||
" content=\"Who founded OpenAI?\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -952,7 +952,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"Who founded OpenAI?\",\n",
|
||||
" content=\"Who founded OpenAI?\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
|
@ -169,7 +169,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"hello!\",\n",
|
||||
" content=\"hello!\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -529,7 +529,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"My name is actually Bob\",\n",
|
||||
" content=\"My name is actually Bob\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -682,7 +682,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"In the future, never use emojis to communicate\",\n",
|
||||
" content=\"In the future, never use emojis to communicate\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -870,7 +870,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"Save the information that 'bob loves cats' to archival\",\n",
|
||||
" content=\"Save the information that 'bob loves cats' to archival\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
@ -1039,7 +1039,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"What animals do I like? Search archival.\",\n",
|
||||
" content=\"What animals do I like? Search archival.\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
|
@ -276,7 +276,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=\"Candidate: Tony Stark\",\n",
|
||||
" content=\"Candidate: Tony Stark\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")"
|
||||
@ -403,7 +403,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=feedback,\n",
|
||||
" content=feedback,\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")"
|
||||
@ -423,7 +423,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"user\",\n",
|
||||
" text=feedback,\n",
|
||||
" content=feedback,\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")"
|
||||
@ -540,7 +540,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"system\",\n",
|
||||
" text=\"Candidate: Spongebob Squarepants\",\n",
|
||||
" content=\"Candidate: Spongebob Squarepants\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")"
|
||||
@ -758,7 +758,7 @@
|
||||
" messages=[\n",
|
||||
" MessageCreate(\n",
|
||||
" role=\"system\",\n",
|
||||
" text=\"Run generation\",\n",
|
||||
" content=\"Run generation\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
")"
|
||||
|
@ -643,7 +643,7 @@ class RESTClient(AbstractClient):
|
||||
) -> Message:
|
||||
request = MessageUpdate(
|
||||
role=role,
|
||||
text=text,
|
||||
content=text,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
@ -1015,7 +1015,7 @@ class RESTClient(AbstractClient):
|
||||
response (LettaResponse): Response from the agent
|
||||
"""
|
||||
# TODO: implement include_full_message
|
||||
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
|
||||
messages = [MessageCreate(role=MessageRole(role), content=message, name=name)]
|
||||
# TODO: figure out how to handle stream_steps and stream_tokens
|
||||
|
||||
# When streaming steps is True, stream_tokens must be False
|
||||
@ -1062,7 +1062,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
job (Job): Information about the async job
|
||||
"""
|
||||
messages = [MessageCreate(role=MessageRole(role), text=message, name=name)]
|
||||
messages = [MessageCreate(role=MessageRole(role), content=message, name=name)]
|
||||
|
||||
request = LettaRequest(messages=messages)
|
||||
response = requests.post(
|
||||
@ -2442,7 +2442,7 @@ class LocalClient(AbstractClient):
|
||||
message_id=message_id,
|
||||
request=MessageUpdate(
|
||||
role=role,
|
||||
text=text,
|
||||
content=text,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
@ -2741,7 +2741,7 @@ class LocalClient(AbstractClient):
|
||||
usage = self.server.send_messages(
|
||||
actor=self.user,
|
||||
agent_id=agent_id,
|
||||
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
|
||||
messages=[MessageCreate(role=MessageRole(role), content=message, name=name)],
|
||||
)
|
||||
|
||||
## TODO: need to make sure date/timestamp is propely passed
|
||||
|
@ -50,7 +50,7 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
||||
chunk_data = json.loads(sse.data)
|
||||
if "reasoning" in chunk_data:
|
||||
yield ReasoningMessage(**chunk_data)
|
||||
elif "assistant_message" in chunk_data:
|
||||
elif "message_type" in chunk_data and chunk_data["message_type"] == "assistant_message":
|
||||
yield AssistantMessage(**chunk_data)
|
||||
elif "tool_call" in chunk_data:
|
||||
yield ToolCallMessage(**chunk_data)
|
||||
|
@ -246,7 +246,7 @@ def parse_letta_response_for_assistant_message(
|
||||
reasoning_message = ""
|
||||
for m in letta_response.messages:
|
||||
if isinstance(m, AssistantMessage):
|
||||
return m.assistant_message
|
||||
return m.content
|
||||
elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
|
||||
try:
|
||||
return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]
|
||||
@ -290,7 +290,7 @@ async def async_send_message_with_retries(
|
||||
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
messages = [MessageCreate(role=MessageRole.user, text=message_text, name=sender_agent.agent_state.name)]
|
||||
messages = [MessageCreate(role=MessageRole.user, content=message_text, name=sender_agent.agent_state.name)]
|
||||
# Wrap in a timeout
|
||||
response = await asyncio.wait_for(
|
||||
server.send_message_to_agent(
|
||||
|
@ -4,6 +4,8 @@ from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
from letta.schemas.enums import MessageContentType
|
||||
|
||||
# Letta API style responses (intended to be easier to use vs getting true Message types)
|
||||
|
||||
|
||||
@ -32,18 +34,33 @@ class LettaMessage(BaseModel):
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
MessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class SystemMessage(LettaMessage):
|
||||
"""
|
||||
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
message (str): The message sent by the system
|
||||
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
message: str
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
|
||||
|
||||
class UserMessage(LettaMessage):
|
||||
@ -51,13 +68,13 @@ class UserMessage(LettaMessage):
|
||||
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
message (str): The message sent by the user
|
||||
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
message: str
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
|
||||
|
||||
class ReasoningMessage(LettaMessage):
|
||||
@ -167,7 +184,7 @@ class ToolReturnMessage(LettaMessage):
|
||||
|
||||
class AssistantMessage(LettaMessage):
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
assistant_message: str
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
|
||||
|
||||
class LegacyFunctionCallMessage(LettaMessage):
|
||||
|
@ -2,7 +2,7 @@ import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
@ -15,8 +15,10 @@ from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
LettaMessage,
|
||||
MessageContentUnion,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
TextContent,
|
||||
ToolCall,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
@ -59,7 +61,7 @@ class MessageCreate(BaseModel):
|
||||
MessageRole.user,
|
||||
MessageRole.system,
|
||||
] = Field(..., description="The role of the participant.")
|
||||
text: str = Field(..., description="The text of the message.")
|
||||
content: Union[str, List[MessageContentUnion]] = Field(..., description="The content of the message.")
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
|
||||
|
||||
@ -67,7 +69,7 @@ class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
text: Optional[str] = Field(None, description="The text of the message.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
@ -79,20 +81,17 @@ class MessageUpdate(BaseModel):
|
||||
tool_calls: Optional[List[OpenAIToolCall,]] = Field(None, description="The list of tool calls requested.")
|
||||
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
MessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm and "content" in data:
|
||||
if isinstance(data["content"], str):
|
||||
data["text"] = data["content"]
|
||||
else:
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
data["text"] = content["text"]
|
||||
del data["content"]
|
||||
return data
|
||||
|
||||
|
||||
class Message(BaseMessage):
|
||||
@ -212,7 +211,7 @@ class Message(BaseMessage):
|
||||
AssistantMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
assistant_message=message_string,
|
||||
content=message_string,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -268,7 +267,7 @@ class Message(BaseMessage):
|
||||
UserMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
message=self.text,
|
||||
content=self.text,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.system:
|
||||
@ -278,7 +277,7 @@ class Message(BaseMessage):
|
||||
SystemMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
message=self.text,
|
||||
content=self.text,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -472,7 +472,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
assistant_message=cleaned_func_args,
|
||||
content=cleaned_func_args,
|
||||
)
|
||||
|
||||
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
|
||||
@ -613,7 +613,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
assistant_message=combined_chunk,
|
||||
content=combined_chunk,
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
@ -627,7 +627,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
assistant_message=updates_main_json,
|
||||
content=updates_main_json,
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
@ -959,7 +959,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = AssistantMessage(
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
assistant_message=func_args["message"],
|
||||
content=func_args["message"],
|
||||
)
|
||||
self._push_to_buffer(processed_chunk)
|
||||
except Exception as e:
|
||||
@ -981,7 +981,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = AssistantMessage(
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
assistant_message=func_args[self.assistant_message_tool_kwarg],
|
||||
content=func_args[self.assistant_message_tool_kwarg],
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
self.prev_assistant_message_id = function_call.id
|
||||
|
@ -721,9 +721,9 @@ class SyncServer(Server):
|
||||
|
||||
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
|
||||
if message.role == MessageRole.user and wrap_user_message:
|
||||
message.text = system.package_user_message(user_message=message.text)
|
||||
message.content = system.package_user_message(user_message=message.content)
|
||||
elif message.role == MessageRole.system and wrap_system_message:
|
||||
message.text = system.package_system_message(system_message=message.text)
|
||||
message.content = system.package_system_message(system_message=message.content)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
@ -732,7 +732,7 @@ class SyncServer(Server):
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.text)],
|
||||
content=[TextContent(text=message.content)],
|
||||
name=message.name,
|
||||
# assigned later?
|
||||
model=None,
|
||||
|
@ -234,11 +234,11 @@ def package_initial_message_sequence(
|
||||
|
||||
if message_create.role == MessageRole.user:
|
||||
packed_message = system.package_user_message(
|
||||
user_message=message_create.text,
|
||||
user_message=message_create.content,
|
||||
)
|
||||
elif message_create.role == MessageRole.system:
|
||||
packed_message = system.package_system_message(
|
||||
system_message=message_create.text,
|
||||
system_message=message_create.content,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message_create.role}")
|
||||
|
@ -83,7 +83,7 @@ class MessageManager:
|
||||
)
|
||||
|
||||
# get update dictionary
|
||||
update_data = message_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
# Remove redundant update fields
|
||||
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
|
||||
|
||||
|
@ -55,7 +55,7 @@ class LettaUser(HttpUser):
|
||||
|
||||
@task(1)
|
||||
def send_message(self):
|
||||
messages = [MessageCreate(role=MessageRole("user"), text="hello")]
|
||||
messages = [MessageCreate(role=MessageRole("user"), content="hello")]
|
||||
request = LettaRequest(messages=messages)
|
||||
|
||||
with self.client.post(
|
||||
@ -70,7 +70,7 @@ class LettaUser(HttpUser):
|
||||
# @task(1)
|
||||
# def send_message_stream(self):
|
||||
|
||||
# messages = [MessageCreate(role=MessageRole("user"), text="hello")]
|
||||
# messages = [MessageCreate(role=MessageRole("user"), content="hello")]
|
||||
# request = LettaRequest(messages=messages, stream_steps=True, stream_tokens=True, return_message_object=True)
|
||||
# if stream_tokens or stream_steps:
|
||||
# from letta.client.streaming import _sse_post
|
||||
|
@ -628,7 +628,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
|
||||
cleanup_agents.append(empty_agent_state.id)
|
||||
|
||||
custom_sequence = [MessageCreate(**{"text": "Hello, how are you?", "role": MessageRole.user})]
|
||||
custom_sequence = [MessageCreate(**{"content": "Hello, how are you?", "role": MessageRole.user})]
|
||||
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
|
||||
cleanup_agents.append(custom_agent_state.id)
|
||||
assert custom_agent_state.message_ids is not None
|
||||
@ -637,7 +637,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
|
||||
# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
|
||||
# shoule be contained in second message (after system message)
|
||||
assert custom_sequence[0].text in client.get_in_context_messages(custom_agent_state.id)[1].text
|
||||
assert custom_sequence[0].content in client.get_in_context_messages(custom_agent_state.id)[1].text
|
||||
|
||||
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
@ -446,7 +446,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
description="test_description",
|
||||
metadata={"test_key": "test_value"},
|
||||
tool_rules=[InitToolRule(tool_name=print_tool.name)],
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||
)
|
||||
created_agent = server.agent_manager.create_agent(
|
||||
@ -548,7 +548,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
||||
block_ids=[default_block.id],
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@ -561,7 +561,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
||||
assert create_agent_request.memory_blocks[0].value in init_messages[0].text
|
||||
# Check that the second message is the passed in initial message seq
|
||||
assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role
|
||||
assert create_agent_request.initial_message_sequence[0].text in init_messages[1].text
|
||||
assert create_agent_request.initial_message_sequence[0].content in init_messages[1].text
|
||||
|
||||
|
||||
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
|
||||
@ -1830,7 +1830,7 @@ def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, defa
|
||||
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user, other_user):
|
||||
"""Test updating a message"""
|
||||
new_text = "Updated text"
|
||||
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(text=new_text), actor=other_user)
|
||||
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, MessageUpdate(content=new_text), actor=other_user)
|
||||
assert updated is not None
|
||||
assert updated.text == new_text
|
||||
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||
|
@ -96,7 +96,7 @@ def test_shared_blocks(client):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="my name is actually charles",
|
||||
content="my name is actually charles",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -109,7 +109,7 @@ def test_shared_blocks(client):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="whats my name?",
|
||||
content="whats my name?",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -339,7 +339,7 @@ def test_messages(client, agent):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="Test message",
|
||||
content="Test message",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -359,7 +359,7 @@ def test_send_system_message(client, agent):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="system",
|
||||
text="Event occurred: The user just logged off.",
|
||||
content="Event occurred: The user just logged off.",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -388,7 +388,7 @@ def test_function_return_limit(client, agent):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="call the big_return function",
|
||||
content="call the big_return function",
|
||||
),
|
||||
],
|
||||
config=LettaRequestConfig(use_assistant_message=False),
|
||||
@ -424,7 +424,7 @@ def test_function_always_error(client, agent):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text="call the always_error function",
|
||||
content="call the always_error function",
|
||||
),
|
||||
],
|
||||
config=LettaRequestConfig(use_assistant_message=False),
|
||||
@ -455,7 +455,7 @@ async def test_send_message_parallel(client, agent):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text=message,
|
||||
content=message,
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -490,7 +490,7 @@ def test_send_message_async(client, agent):
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
text=test_message,
|
||||
content=test_message,
|
||||
),
|
||||
],
|
||||
config=LettaRequestConfig(use_assistant_message=False),
|
||||
|
@ -711,12 +711,12 @@ def _test_get_messages_letta_format(
|
||||
|
||||
elif message.role == MessageRole.user:
|
||||
assert isinstance(letta_message, UserMessage)
|
||||
assert message.text == letta_message.message
|
||||
assert message.text == letta_message.content
|
||||
letta_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.system:
|
||||
assert isinstance(letta_message, SystemMessage)
|
||||
assert message.text == letta_message.message
|
||||
assert message.text == letta_message.content
|
||||
letta_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.tool:
|
||||
|
@ -324,7 +324,7 @@ def test_get_run_messages(client, mock_sync_server):
|
||||
UserMessage(
|
||||
id=f"message-{i:08x}",
|
||||
date=current_time,
|
||||
message=f"Test message {i}",
|
||||
content=f"Test message {i}",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user