feat: Add message listing for a letta batch (#1982)

Matthew Zhou, 2025-05-02 11:14:03 -07:00, committed by GitHub
parent ee1f3b54c6
commit 6590786bd3
13 changed files with 289 additions and 48 deletions

View File

@@ -0,0 +1,31 @@
"""Add batch_item_id to messages
Revision ID: 0335b1eb9c40
Revises: 373dabcba6cf
Create Date: 2025-05-02 10:30:08.156190
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0335b1eb9c40"
down_revision: Union[str, None] = "373dabcba6cf"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("messages", sa.Column("batch_item_id", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("messages", "batch_item_id")
# ### end Alembic commands ###
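If you're applying this revision by hand, the standard Alembic entry points cover it. A minimal sketch, assuming the project's alembic.ini sits at the repository root (that location is an assumption, not shown in this diff):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")              # assumption: default Alembic config path
command.upgrade(cfg, "0335b1eb9c40")     # adds the nullable messages.batch_item_id column
# command.downgrade(cfg, "373dabcba6cf") # would drop the column again

Since the column is nullable, existing message rows need no backfill.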

View File

@@ -21,7 +21,10 @@ def _create_letta_response(new_in_context_messages: list[Message], use_assistant
def _prepare_in_context_messages(
input_messages: List[MessageCreate], agent_state: AgentState, message_manager: MessageManager, actor: User
input_messages: List[MessageCreate],
agent_state: AgentState,
message_manager: MessageManager,
actor: User,
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.

View File

@@ -137,21 +137,37 @@ class LettaAgentBatch:
log_event(name="load_and_prepare_agents")
agent_messages_mapping: Dict[str, List[Message]] = {}
agent_tools_mapping: Dict[str, List[dict]] = {}
# TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half-formed Pydantic object
agent_batch_item_mapping: Dict[str, LLMBatchItem] = {}
agent_states = []
for batch_request in batch_requests:
agent_id = batch_request.agent_id
agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
agent_states.append(agent_state)
agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
agent_state=agent_state, input_messages=batch_request.messages
)
if agent_id not in agent_step_state_mapping:
agent_step_state_mapping[agent_id] = AgentStepState(
step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
)
llm_batch_item = LLMBatchItem(
llm_batch_id="", # TODO: This is hacky, it gets filled in later
agent_id=agent_state.id,
llm_config=agent_state.llm_config,
request_status=JobStatus.created,
step_status=AgentStepStatus.paused,
step_state=agent_step_state_mapping[agent_id],
)
agent_batch_item_mapping[agent_id] = llm_batch_item
# Fill in the batch_item_id for each message
for msg in batch_request.messages:
msg.batch_item_id = llm_batch_item.id
agent_messages_mapping[agent_id] = self._prepare_in_context_messages_per_agent(
agent_state=agent_state, input_messages=batch_request.messages
)
agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver)
log_event(name="init_llm_client")
@@ -182,21 +198,14 @@ class LettaAgentBatch:
log_event(name="prepare_batch_items")
batch_items = []
for state in agent_states:
step_state = agent_step_state_mapping[state.id]
batch_items.append(
LLMBatchItem(
llm_batch_id=llm_batch_job.id,
agent_id=state.id,
llm_config=state.llm_config,
request_status=JobStatus.created,
step_status=AgentStepStatus.paused,
step_state=step_state,
)
)
llm_batch_item = agent_batch_item_mapping[state.id]
# TODO This is hacky
llm_batch_item.llm_batch_id = llm_batch_job.id
batch_items.append(llm_batch_item)
if batch_items:
log_event(name="bulk_create_batch_items")
self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
batch_items_persisted = self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
log_event(name="return_batch_response")
return LettaBatchResponse(
@@ -335,9 +344,14 @@ class LettaAgentBatch:
exec_results: Sequence[Tuple[str, Tuple[str, bool]]],
ctx: _ResumeContext,
) -> Dict[str, List[Message]]:
# TODO: This is redundant, we should have this ready on the ctx
# TODO: I am doing it quick and dirty for now
agent_item_map: Dict[str, LLMBatchItem] = {item.agent_id: item for item in ctx.batch_items}
msg_map: Dict[str, List[Message]] = {}
for aid, (tool_res, success) in exec_results:
msgs = self._create_tool_call_messages(
llm_batch_item_id=agent_item_map[aid].id,
agent_state=ctx.agent_state_map[aid],
tool_call_name=ctx.tool_call_name_map[aid],
tool_call_args=ctx.tool_call_args_map[aid],
@@ -399,6 +413,7 @@ class LettaAgentBatch:
def _create_tool_call_messages(
self,
llm_batch_item_id: str,
agent_state: AgentState,
tool_call_name: str,
tool_call_args: Dict[str, Any],
@@ -421,6 +436,7 @@ class LettaAgentBatch:
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=None,
pre_computed_tool_message_id=None,
llm_batch_item_id=llm_batch_item_id,
)
return tool_call_messages
@@ -477,7 +493,7 @@ class LettaAgentBatch:
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
def _get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
def _prepare_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_messages, agent_state, self.message_manager, self.actor
)
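Reading the hunks above together: each LLMBatchItem is now constructed up front with an already-valid client-generated id, the incoming messages are stamped with that id immediately, and only llm_batch_id is back-filled once the batch job row exists. A compressed sketch of that ordering, with hypothetical names standing in for the Letta types:

from dataclasses import dataclass, field
from itertools import count

_seq = count(1)

@dataclass
class BatchItemSketch:
    agent_id: str
    id: str = field(default_factory=lambda: f"batch_item-{next(_seq)}")
    llm_batch_id: str = ""                  # unknown until the batch job is created

item = BatchItemSketch(agent_id="agent-1")
incoming = [{"text": "hi"}]
for msg in incoming:
    msg["batch_item_id"] = item.id          # stamp before anything is persisted
# ... submit the provider batch and create the llm_batch_job row, then:
item.llm_batch_id = "llm_batch-123"         # back-fill and bulk-insert the items

This ordering is also why create_llm_batch_items_bulk now returns the persisted items, and why the ORM insert passes the pre-generated id through (see the llm_batch_manager hunks below).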

View File

@@ -5,57 +5,58 @@ from letta.schemas.message import Message, MessageCreate
def convert_message_creates_to_messages(
messages: list[MessageCreate],
message_creates: list[MessageCreate],
agent_id: str,
wrap_user_message: bool = True,
wrap_system_message: bool = True,
) -> list[Message]:
return [
_convert_message_create_to_message(
message=message,
message_create=create,
agent_id=agent_id,
wrap_user_message=wrap_user_message,
wrap_system_message=wrap_system_message,
)
for message in messages
for create in message_creates
]
def _convert_message_create_to_message(
message: MessageCreate,
message_create: MessageCreate,
agent_id: str,
wrap_user_message: bool = True,
wrap_system_message: bool = True,
) -> Message:
"""Converts a MessageCreate object into a Message object, applying wrapping if needed."""
# TODO: This seems like extra boilerplate with little benefit
assert isinstance(message, MessageCreate)
assert isinstance(message_create, MessageCreate)
# Extract message content
if isinstance(message.content, str):
message_content = message.content
elif message.content and len(message.content) > 0 and isinstance(message.content[0], TextContent):
message_content = message.content[0].text
if isinstance(message_create.content, str):
message_content = message_create.content
elif message_create.content and len(message_create.content) > 0 and isinstance(message_create.content[0], TextContent):
message_content = message_create.content[0].text
else:
raise ValueError("Message content is empty or invalid")
# Apply wrapping if needed
if message.role not in {MessageRole.user, MessageRole.system}:
raise ValueError(f"Invalid message role: {message.role}")
elif message.role == MessageRole.user and wrap_user_message:
if message_create.role not in {MessageRole.user, MessageRole.system}:
raise ValueError(f"Invalid message role: {message_create.role}")
elif message_create.role == MessageRole.user and wrap_user_message:
message_content = system.package_user_message(user_message=message_content)
elif message.role == MessageRole.system and wrap_system_message:
elif message_create.role == MessageRole.system and wrap_system_message:
message_content = system.package_system_message(system_message=message_content)
return Message(
agent_id=agent_id,
role=message.role,
role=message_create.role,
content=[TextContent(text=message_content)] if message_content else [],
name=message.name,
name=message_create.name,
model=None, # assigned later?
tool_calls=None, # irrelevant
tool_call_id=None,
otid=message.otid,
sender_id=message.sender_id,
group_id=message.group_id,
otid=message_create.otid,
sender_id=message_create.sender_id,
group_id=message_create.group_id,
batch_item_id=message_create.batch_item_id,
)
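To make the renamed converter's semantics concrete: each MessageCreate fans out to exactly one Message, and batch_item_id now rides along unchanged. A standalone re-implementation of just that fan-out (not Letta's code; role validation and user/system wrapping omitted):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MessageCreateSketch:
    role: str
    content: str
    batch_item_id: Optional[str] = None

@dataclass
class MessageSketch:
    agent_id: str
    role: str
    content: str
    batch_item_id: Optional[str]

def convert(creates: List[MessageCreateSketch], agent_id: str) -> List[MessageSketch]:
    # one Message per MessageCreate, copying batch_item_id through unchanged
    return [MessageSketch(agent_id, c.role, c.content, c.batch_item_id) for c in creates]

out = convert([MessageCreateSketch("user", "hello", "batch_item-abc")], "agent-1")
assert out[0].batch_item_id == "batch_item-abc"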

View File

@@ -44,6 +44,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
sender_id: Mapped[Optional[str]] = mapped_column(
nullable=True, doc="The id of the sender of the message, can be an identity id or agent id"
)
batch_item_id: Mapped[Optional[str]] = mapped_column(
nullable=True,
doc="The id of the LLMBatchItem that this message is associated with",
)
# Monotonically increasing sequence for efficient/correct listing
sequence_id: Mapped[int] = mapped_column(

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
from letta.helpers.json_helpers import json_dumps
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
# TODO: consider moving into own file
@@ -175,3 +176,7 @@ class LettaBatchResponse(BaseModel):
agent_count: int = Field(..., description="The number of agents in the batch request.")
last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")
created_at: datetime = Field(..., description="The timestamp when the batch request was created.")
class LettaBatchMessages(BaseModel):
messages: List[Message]

View File

@@ -10,16 +10,18 @@ from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.llm_config import LLMConfig
class LLMBatchItem(OrmMetadataBase, validate_assignment=True):
class LLMBatchItemBase(OrmMetadataBase, validate_assignment=True):
__id_prefix__ = "batch_item"
class LLMBatchItem(LLMBatchItemBase, validate_assignment=True):
"""
Represents a single agent's LLM request within a batch.
This object captures the configuration, execution status, and eventual result of one agent's request within a larger LLM batch job.
"""
__id_prefix__ = "batch_item"
id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.")
id: str = LLMBatchItemBase.generate_id_field()
llm_batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
agent_id: str = Field(..., description="The id of the agent associated with this LLM request.")
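The refactor above moves __id_prefix__ into a small base class so the id can be generated client-side at construction time instead of being assigned by the database. A minimal sketch of the prefixed-ID pattern in plain Pydantic (illustrative only; OrmMetadataBase.generate_id_field is Letta's actual mechanism):

import uuid
from typing import ClassVar
from pydantic import BaseModel, Field

class BatchItemBaseSketch(BaseModel):
    ID_PREFIX: ClassVar[str] = "batch_item"   # stand-in for __id_prefix__

class BatchItemSketch(BatchItemBaseSketch):
    # valid before any INSERT, so other rows (e.g. messages.batch_item_id) can reference it
    id: str = Field(default_factory=lambda: f"{BatchItemBaseSketch.ID_PREFIX}-{uuid.uuid4()}")

item = BatchItemSketch()
assert item.id.startswith("batch_item-")

This is what lets agent_batch.py stamp messages with batch_item_id before anything has been persisted.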

View File

@@ -85,6 +85,7 @@ class MessageCreate(BaseModel):
name: Optional[str] = Field(None, description="The name of the participant.")
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with")
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
@@ -168,6 +169,7 @@ class Message(BaseMessage):
tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with")
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

View File

@@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header, status
from fastapi import APIRouter, Body, Depends, Header, Query, status
from fastapi.exceptions import HTTPException
from starlette.requests import Request
@@ -9,6 +9,7 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
from letta.schemas.letta_request import CreateBatch
from letta.schemas.letta_response import LettaBatchMessages
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
@@ -123,6 +124,50 @@ async def list_batch_runs(
return [BatchJob.from_job(job) for job in jobs]
@router.get(
"/batches/{batch_id}/messages",
response_model=LettaBatchMessages,
operation_id="list_batch_messages",
)
async def list_batch_messages(
batch_id: str,
limit: int = Query(100, description="Maximum number of messages to return"),
cursor: Optional[str] = Query(
None, description="Message ID to use as pagination cursor (get messages before/after this ID) depending on sort_descending."
),
agent_id: Optional[str] = Query(None, description="Filter messages by agent ID"),
sort_descending: bool = Query(True, description="Sort messages by creation time (true=newest first)"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: SyncServer = Depends(get_letta_server),
):
"""
Get messages for a specific batch job.
Returns messages associated with the batch, newest first by default (controlled by sort_descending).
Pagination:
- For the first page, omit the cursor parameter
- For subsequent pages, use the ID of the last message from the previous response as the cursor
- Results will include messages before/after the cursor based on sort_descending
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# First, verify the batch job exists and the user has access to it
try:
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
BatchJob.from_job(job)
except NoResultFound:
raise HTTPException(status_code=404, detail="Batch not found")
# Fetch the messages with keyset pagination, using a message ID as the cursor
messages = server.batch_manager.get_messages_for_letta_batch(
letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor
)
return LettaBatchMessages(messages=messages)
@router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run")
async def cancel_batch_run(
batch_id: str,
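From a client's perspective, the cursor protocol is the loop below; a hedged sketch, assuming a local server on the default port and the requests package (the full URL prefix depends on how this router is mounted, which this diff doesn't show):

import requests

BASE = "http://localhost:8283/v1"       # assumption: default local Letta server and /v1 prefix
batch_id = "batch-123"                  # hypothetical batch job id
headers = {"user_id": "user-123"}       # actor header, per Header(alias="user_id") above

cursor, messages = None, []
while True:
    params = {"limit": 100}
    if cursor:
        params["cursor"] = cursor
    resp = requests.get(f"{BASE}/batches/{batch_id}/messages", params=params, headers=headers)
    resp.raise_for_status()
    page = resp.json()["messages"]
    messages.extend(page)
    if len(page) < 100:                 # a short page means the batch is drained
        break
    cursor = page[-1]["id"]             # last ID of this page seeds the next request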

View File

@@ -168,6 +168,7 @@ def create_letta_messages_from_llm_response(
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
llm_batch_item_id: Optional[str] = None,
) -> List[Message]:
messages = []
@@ -192,6 +193,7 @@
tool_calls=[tool_call],
tool_call_id=tool_call_id,
created_at=get_utc_time(),
batch_item_id=llm_batch_item_id,
)
if pre_computed_assistant_message_id:
assistant_message.id = pre_computed_assistant_message_id
@@ -209,6 +211,7 @@
tool_call_id=tool_call_id,
created_at=get_utc_time(),
name=function_name,
batch_item_id=llm_batch_item_id,
)
if pre_computed_tool_message_id:
tool_message.id = pre_computed_tool_message_id
@@ -216,7 +219,7 @@
if add_heartbeat_request_system_message:
heartbeat_system_message = create_heartbeat_system_message(
agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor
agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor, llm_batch_item_id=llm_batch_item_id
)
messages.append(heartbeat_system_message)
@@ -224,10 +227,7 @@
def create_heartbeat_system_message(
agent_id: str,
model: str,
function_call_success: bool,
actor: User,
agent_id: str, model: str, function_call_success: bool, actor: User, llm_batch_item_id: Optional[str] = None
) -> Message:
text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE
heartbeat_system_message = Message(
@@ -239,6 +239,7 @@ def create_heartbeat_system_message(
tool_calls=[],
tool_call_id=None,
created_at=get_utc_time(),
batch_item_id=llm_batch_item_id,
)
return heartbeat_system_message
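Net effect of these hunks: every message a batch step produces is stamped with the same batch item id, including the optional heartbeat. For a successful tool-call step, the chronological shape is roughly the sketch below (field subset only; roles match the assertions in the tests later in this diff):

item_id = "batch_item-abc"  # hypothetical client-generated id
step = [
    {"role": "user",      "batch_item_id": item_id},  # original input message
    {"role": "assistant", "batch_item_id": item_id},  # tool call chosen by the LLM
    {"role": "tool",      "batch_item_id": item_id},  # tool execution result
    {"role": "user",      "batch_item_id": item_id},  # heartbeat, emitted only when
]                                                     # add_heartbeat_request_system_message is set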

View File

@@ -2,10 +2,11 @@ import datetime
from typing import Any, Dict, List, Optional, Tuple
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
from sqlalchemy import func, tuple_
from sqlalchemy import desc, func, tuple_
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
from letta.log import get_logger
from letta.orm import Message as MessageModel
from letta.orm.llm_batch_items import LLMBatchItem
from letta.orm.llm_batch_job import LLMBatchJob
from letta.schemas.agent import AgentStepState
@@ -13,6 +14,7 @@ from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
@@ -142,6 +144,62 @@ class LLMBatchManager:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.hard_delete(db_session=session, actor=actor)
@enforce_types
def get_messages_for_letta_batch(
self,
letta_batch_job_id: str,
limit: int = 100,
actor: Optional[PydanticUser] = None,
agent_id: Optional[str] = None,
sort_descending: bool = True,
cursor: Optional[str] = None, # Message ID as cursor
) -> List[PydanticMessage]:
"""
Retrieve messages across all LLM batch jobs associated with a Letta batch job.
Optimized for PostgreSQL performance using ID-based keyset pagination.
"""
with self.session_maker() as session:
# If cursor is provided, get sequence_id for that message
cursor_sequence_id = None
if cursor:
cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1)
cursor_result = cursor_query.first()
if cursor_result:
cursor_sequence_id = cursor_result[0]
else:
# If cursor message doesn't exist, ignore it
pass
query = (
session.query(MessageModel)
.join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id)
.join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id)
.filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id)
)
if actor is not None:
query = query.filter(MessageModel.organization_id == actor.organization_id)
if agent_id is not None:
query = query.filter(MessageModel.agent_id == agent_id)
# Apply cursor-based pagination if cursor exists
if cursor_sequence_id is not None:
if sort_descending:
query = query.filter(MessageModel.sequence_id < cursor_sequence_id)
else:
query = query.filter(MessageModel.sequence_id > cursor_sequence_id)
if sort_descending:
query = query.order_by(desc(MessageModel.sequence_id))
else:
query = query.order_by(MessageModel.sequence_id)
query = query.limit(limit)
results = query.all()
return [message.to_pydantic() for message in results]
@enforce_types
def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
@@ -196,6 +254,7 @@ class LLMBatchManager:
orm_items = []
for item in llm_batch_items:
orm_item = LLMBatchItem(
id=item.id,
llm_batch_id=item.llm_batch_id,
agent_id=item.agent_id,
llm_config=item.llm_config,

View File

@@ -73,6 +73,7 @@ class MessageManager:
Returns:
List of created Pydantic message models
"""
if not pydantic_msgs:
return []

View File

@@ -23,7 +23,7 @@ from letta.helpers import ToolRulesSolver
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.orm import Base
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
from letta.schemas.job import BatchJob
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import LettaBatchRequest
@@ -589,6 +589,26 @@ async def test_partial_error_from_anthropic_batch(
len(refreshed_agent.message_ids) == 6
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
# Check the total list of messages
messages = server.batch_manager.get_messages_for_letta_batch(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == (len(agents) - 1) * 4 + 1
assert_descending_order(messages)
# Check that each agent is represented
for agent in agents_continue:
agent_messages = [m for m in messages if m.agent_id == agent.id]
assert len(agent_messages) == 4
assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
for agent in agents_failed:
agent_messages = [m for m in messages if m.agent_id == agent.id]
assert len(agent_messages) == 1
assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
@pytest.mark.asyncio
async def test_resume_step_some_stop(
@@ -718,6 +738,42 @@ async def test_resume_step_some_stop(
len(refreshed_agent.message_ids) == 6
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
# Check the total list of messages
messages = server.batch_manager.get_messages_for_letta_batch(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == len(agents) * 3 + 1
assert_descending_order(messages)
# Check that each agent is represented
for agent in agents_continue:
agent_messages = [m for m in messages if m.agent_id == agent.id]
assert len(agent_messages) == 4
assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
for agent in agents_finish:
agent_messages = [m for m in messages if m.agent_id == agent.id]
assert len(agent_messages) == 3
assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
def assert_descending_order(messages):
"""Assert messages are in descending order by created_at timestamps."""
if len(messages) <= 1:
return True
for i in range(1, len(messages)):
assert messages[i].created_at <= messages[i - 1].created_at, (
f"Order violation: {messages[i - 1].id} ({messages[i - 1].created_at}) "
f"followed by {messages[i].id} ({messages[i].created_at})"
)
return True
@pytest.mark.asyncio
async def test_resume_step_after_request_all_continue(
@@ -841,6 +897,21 @@ async def test_resume_step_after_request_all_continue(
len(refreshed_agent.message_ids) == 6
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
# Check the total list of messages
messages = server.batch_manager.get_messages_for_letta_batch(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == len(agents) * 4
assert_descending_order(messages)
# Check that each agent is represented
for agent in agents:
agent_messages = [m for m in messages if m.agent_id == agent.id]
assert len(agent_messages) == 4
assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
@pytest.mark.asyncio
async def test_step_until_request_prepares_and_submits_batch_correctly(