feat: Add message listing for a letta batch (#1982)
This commit is contained in:
parent ee1f3b54c6
commit 6590786bd3
@@ -0,0 +1,31 @@
+"""Add batch_item_id to messages
+
+Revision ID: 0335b1eb9c40
+Revises: 373dabcba6cf
+Create Date: 2025-05-02 10:30:08.156190
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "0335b1eb9c40"
+down_revision: Union[str, None] = "373dabcba6cf"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("messages", sa.Column("batch_item_id", sa.String(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("messages", "batch_item_id")
+    # ### end Alembic commands ###
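This is the stock Alembic template: `alembic upgrade head` runs the `upgrade()` above and adds the `messages.batch_item_id` column, and `alembic downgrade -1` reverses it via `drop_column`. Because the column is nullable, existing message rows need no backfill.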
@@ -21,7 +21,10 @@ def _create_letta_response(new_in_context_messages: list[Message], use_assistant
 
 
 def _prepare_in_context_messages(
-    input_messages: List[MessageCreate], agent_state: AgentState, message_manager: MessageManager, actor: User
+    input_messages: List[MessageCreate],
+    agent_state: AgentState,
+    message_manager: MessageManager,
+    actor: User,
 ) -> Tuple[List[Message], List[Message]]:
     """
     Prepares in-context messages for an agent, based on the current state and a new user input.
@@ -137,21 +137,37 @@ class LettaAgentBatch:
         log_event(name="load_and_prepare_agents")
         agent_messages_mapping: Dict[str, List[Message]] = {}
         agent_tools_mapping: Dict[str, List[dict]] = {}
+        # TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half formed pydantic object
+        agent_batch_item_mapping: Dict[str, LLMBatchItem] = {}
        agent_states = []
         for batch_request in batch_requests:
             agent_id = batch_request.agent_id
             agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
             agent_states.append(agent_state)
 
-            agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent(
-                agent_state=agent_state, input_messages=batch_request.messages
-            )
-
             if agent_id not in agent_step_state_mapping:
                 agent_step_state_mapping[agent_id] = AgentStepState(
                     step_number=0, tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules)
                 )
 
+            llm_batch_item = LLMBatchItem(
+                llm_batch_id="",  # TODO: This is hacky, it gets filled in later
+                agent_id=agent_state.id,
+                llm_config=agent_state.llm_config,
+                request_status=JobStatus.created,
+                step_status=AgentStepStatus.paused,
+                step_state=agent_step_state_mapping[agent_id],
+            )
+            agent_batch_item_mapping[agent_id] = llm_batch_item
+
+            # Fill in the batch_item_id for the message
+            for msg in batch_request.messages:
+                msg.batch_item_id = llm_batch_item.id
+
+            agent_messages_mapping[agent_id] = self._prepare_in_context_messages_per_agent(
+                agent_state=agent_state, input_messages=batch_request.messages
+            )
+
             agent_tools_mapping[agent_id] = self._prepare_tools_per_agent(agent_state, agent_step_state_mapping[agent_id].tool_rules_solver)
 
         log_event(name="init_llm_client")
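The ordering above only works because `LLMBatchItem` ids are now generated client-side (see the `generate_id_field()` schema change further down), so a message can reference its batch item before either object is persisted. A minimal, self-contained sketch of that pattern, using hypothetical stand-in classes rather than the real Letta schemas:

import uuid
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BatchItem:
    # Client-side id generation: the id exists before any database insert.
    id: str = field(default_factory=lambda: f"batch_item-{uuid.uuid4()}")
    llm_batch_id: str = ""  # like the diff, filled in once the parent job exists


@dataclass
class Msg:
    text: str
    batch_item_id: Optional[str] = None


item = BatchItem()
messages = [Msg("hello"), Msg("world")]
for m in messages:
    m.batch_item_id = item.id  # link messages to the not-yet-persisted item

assert all(m.batch_item_id == item.id for m in messages)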
@@ -182,21 +198,14 @@ class LettaAgentBatch:
         log_event(name="prepare_batch_items")
         batch_items = []
         for state in agent_states:
-            step_state = agent_step_state_mapping[state.id]
-            batch_items.append(
-                LLMBatchItem(
-                    llm_batch_id=llm_batch_job.id,
-                    agent_id=state.id,
-                    llm_config=state.llm_config,
-                    request_status=JobStatus.created,
-                    step_status=AgentStepStatus.paused,
-                    step_state=step_state,
-                )
-            )
+            llm_batch_item = agent_batch_item_mapping[state.id]
+            # TODO: This is hacky
+            llm_batch_item.llm_batch_id = llm_batch_job.id
+            batch_items.append(llm_batch_item)
 
         if batch_items:
             log_event(name="bulk_create_batch_items")
-            self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
+            batch_items_persisted = self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor)
 
         log_event(name="return_batch_response")
         return LettaBatchResponse(
@@ -335,9 +344,14 @@ class LettaAgentBatch:
         exec_results: Sequence[Tuple[str, Tuple[str, bool]]],
         ctx: _ResumeContext,
     ) -> Dict[str, List[Message]]:
+        # TODO: This is redundant, we should have this ready on the ctx
+        # TODO: I am doing it quick and dirty for now
+        agent_item_map: Dict[str, LLMBatchItem] = {item.agent_id: item for item in ctx.batch_items}
+
         msg_map: Dict[str, List[Message]] = {}
         for aid, (tool_res, success) in exec_results:
             msgs = self._create_tool_call_messages(
+                llm_batch_item_id=agent_item_map[aid].id,
                 agent_state=ctx.agent_state_map[aid],
                 tool_call_name=ctx.tool_call_name_map[aid],
                 tool_call_args=ctx.tool_call_args_map[aid],
@@ -399,6 +413,7 @@ class LettaAgentBatch:
 
     def _create_tool_call_messages(
         self,
+        llm_batch_item_id: str,
         agent_state: AgentState,
         tool_call_name: str,
         tool_call_args: Dict[str, Any],
@@ -421,6 +436,7 @@ class LettaAgentBatch:
             reasoning_content=reasoning_content,
             pre_computed_assistant_message_id=None,
             pre_computed_tool_message_id=None,
+            llm_batch_item_id=llm_batch_item_id,
         )
 
         return tool_call_messages
@@ -477,7 +493,7 @@ class LettaAgentBatch:
         valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
         return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
 
-    def _get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
+    def _prepare_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
         current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
             input_messages, agent_state, self.message_manager, self.actor
         )
@@ -5,57 +5,58 @@ from letta.schemas.message import Message, MessageCreate
 
 
 def convert_message_creates_to_messages(
-    messages: list[MessageCreate],
+    message_creates: list[MessageCreate],
     agent_id: str,
     wrap_user_message: bool = True,
     wrap_system_message: bool = True,
 ) -> list[Message]:
     return [
         _convert_message_create_to_message(
-            message=message,
+            message_create=create,
             agent_id=agent_id,
             wrap_user_message=wrap_user_message,
             wrap_system_message=wrap_system_message,
         )
-        for message in messages
+        for create in message_creates
     ]
 
 
 def _convert_message_create_to_message(
-    message: MessageCreate,
+    message_create: MessageCreate,
     agent_id: str,
     wrap_user_message: bool = True,
     wrap_system_message: bool = True,
 ) -> Message:
     """Converts a MessageCreate object into a Message object, applying wrapping if needed."""
     # TODO: This seems like extra boilerplate with little benefit
-    assert isinstance(message, MessageCreate)
+    assert isinstance(message_create, MessageCreate)
 
     # Extract message content
-    if isinstance(message.content, str):
-        message_content = message.content
-    elif message.content and len(message.content) > 0 and isinstance(message.content[0], TextContent):
-        message_content = message.content[0].text
+    if isinstance(message_create.content, str):
+        message_content = message_create.content
+    elif message_create.content and len(message_create.content) > 0 and isinstance(message_create.content[0], TextContent):
+        message_content = message_create.content[0].text
     else:
         raise ValueError("Message content is empty or invalid")
 
     # Apply wrapping if needed
-    if message.role not in {MessageRole.user, MessageRole.system}:
-        raise ValueError(f"Invalid message role: {message.role}")
-    elif message.role == MessageRole.user and wrap_user_message:
+    if message_create.role not in {MessageRole.user, MessageRole.system}:
+        raise ValueError(f"Invalid message role: {message_create.role}")
+    elif message_create.role == MessageRole.user and wrap_user_message:
         message_content = system.package_user_message(user_message=message_content)
-    elif message.role == MessageRole.system and wrap_system_message:
+    elif message_create.role == MessageRole.system and wrap_system_message:
         message_content = system.package_system_message(system_message=message_content)
 
     return Message(
         agent_id=agent_id,
-        role=message.role,
+        role=message_create.role,
         content=[TextContent(text=message_content)] if message_content else [],
-        name=message.name,
+        name=message_create.name,
         model=None,  # assigned later?
         tool_calls=None,  # irrelevant
         tool_call_id=None,
-        otid=message.otid,
-        sender_id=message.sender_id,
-        group_id=message.group_id,
+        otid=message_create.otid,
+        sender_id=message_create.sender_id,
+        group_id=message_create.group_id,
+        batch_item_id=message_create.batch_item_id,
     )
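For reference, a hedged usage sketch of the renamed converter; the import path for the helper is an assumption, while the parameter names and `MessageCreate` fields come from this diff:

from letta.schemas.enums import MessageRole
from letta.schemas.message import MessageCreate

# Assumed import path for the converter shown above.
from letta.server.rest_api.utils import convert_message_creates_to_messages

creates = [MessageCreate(role=MessageRole.user, content="What's the weather today?")]
messages = convert_message_creates_to_messages(
    message_creates=creates,
    agent_id="agent-123",    # illustrative agent id
    wrap_user_message=True,  # content is wrapped via system.package_user_message()
    wrap_system_message=True,
)
assert messages[0].role == MessageRole.user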
@@ -44,6 +44,10 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
     sender_id: Mapped[Optional[str]] = mapped_column(
         nullable=True, doc="The id of the sender of the message, can be an identity id or agent id"
     )
+    batch_item_id: Mapped[Optional[str]] = mapped_column(
+        nullable=True,
+        doc="The id of the LLMBatchItem that this message is associated with",
+    )
 
     # Monotonically increasing sequence for efficient/correct listing
     sequence_id: Mapped[int] = mapped_column(
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
 from letta.helpers.json_helpers import json_dumps
 from letta.schemas.enums import JobStatus, MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
+from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
 
 # TODO: consider moving into own file
@@ -175,3 +176,7 @@ class LettaBatchResponse(BaseModel):
     agent_count: int = Field(..., description="The number of agents in the batch request.")
     last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")
     created_at: datetime = Field(..., description="The timestamp when the batch request was created.")
+
+
+class LettaBatchMessages(BaseModel):
+    messages: List[Message]
@@ -10,16 +10,18 @@ from letta.schemas.letta_base import OrmMetadataBase
 from letta.schemas.llm_config import LLMConfig
 
 
-class LLMBatchItem(OrmMetadataBase, validate_assignment=True):
+class LLMBatchItemBase(OrmMetadataBase, validate_assignment=True):
+    __id_prefix__ = "batch_item"
+
+
+class LLMBatchItem(LLMBatchItemBase, validate_assignment=True):
     """
     Represents a single agent's LLM request within a batch.
 
     This object captures the configuration, execution status, and eventual result of one agent's request within a larger LLM batch job.
     """
 
-    __id_prefix__ = "batch_item"
-
-    id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.")
+    id: str = LLMBatchItemBase.generate_id_field()
     llm_batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
     agent_id: str = Field(..., description="The id of the agent associated with this LLM request.")
 
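Design note: hoisting `__id_prefix__` into `LLMBatchItemBase` and replacing the database-assigned `Optional[str]` id with `generate_id_field()` means every `LLMBatchItem` carries a usable `batch_item-...` id from the moment it is constructed. That is what lets the batch step logic above stamp `msg.batch_item_id = llm_batch_item.id` before anything has been persisted.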
@@ -85,6 +85,7 @@ class MessageCreate(BaseModel):
     name: Optional[str] = Field(None, description="The name of the participant.")
     otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
     sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
+    batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with")
     group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
 
     def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
@@ -168,6 +169,7 @@ class Message(BaseMessage):
     tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
     group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
     sender_id: Optional[str] = Field(None, description="The id of the sender of the message, can be an identity id or agent id")
+    batch_item_id: Optional[str] = Field(None, description="The id of the LLMBatchItem that this message is associated with")
     # This overrides the optional base orm schema, created_at MUST exist on all messages objects
     created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
 
@@ -1,6 +1,6 @@
 from typing import List, Optional
 
-from fastapi import APIRouter, Body, Depends, Header, status
+from fastapi import APIRouter, Body, Depends, Header, Query, status
 from fastapi.exceptions import HTTPException
 from starlette.requests import Request
 
@@ -9,6 +9,7 @@ from letta.log import get_logger
 from letta.orm.errors import NoResultFound
 from letta.schemas.job import BatchJob, JobStatus, JobType, JobUpdate
 from letta.schemas.letta_request import CreateBatch
+from letta.schemas.letta_response import LettaBatchMessages
 from letta.server.rest_api.utils import get_letta_server
 from letta.server.server import SyncServer
 from letta.settings import settings
@@ -123,6 +124,50 @@ async def list_batch_runs(
     return [BatchJob.from_job(job) for job in jobs]
 
 
+@router.get(
+    "/batches/{batch_id}/messages",
+    response_model=LettaBatchMessages,
+    operation_id="list_batch_messages",
+)
+async def list_batch_messages(
+    batch_id: str,
+    limit: int = Query(100, description="Maximum number of messages to return"),
+    cursor: Optional[str] = Query(
+        None, description="Message ID to use as pagination cursor (get messages before/after this ID) depending on sort_descending."
+    ),
+    agent_id: Optional[str] = Query(None, description="Filter messages by agent ID"),
+    sort_descending: bool = Query(True, description="Sort messages by creation time (true=newest first)"),
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    server: SyncServer = Depends(get_letta_server),
+):
+    """
+    Get messages for a specific batch job.
+
+    Returns messages associated with the batch in chronological order.
+
+    Pagination:
+    - For the first page, omit the cursor parameter
+    - For subsequent pages, use the ID of the last message from the previous response as the cursor
+    - Results will include messages before/after the cursor based on sort_descending
+    """
+    actor = server.user_manager.get_user_or_default(user_id=actor_id)
+
+    # First, verify the batch job exists and the user has access to it
+    try:
+        job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
+        BatchJob.from_job(job)
+    except NoResultFound:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    # Get messages directly using our efficient method
+    # We'll need to update the underlying implementation to use message_id as cursor
+    messages = server.batch_manager.get_messages_for_letta_batch(
+        letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor
+    )
+
+    return LettaBatchMessages(messages=messages)
+
+
 @router.patch("/batches/{batch_id}/cancel", operation_id="cancel_batch_run")
 async def cancel_batch_run(
     batch_id: str,
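A hedged sketch of how a client might walk the new endpoint page by page; the base URL and actor id are assumptions, while the path, query parameters, and `user_id` header come from the route above:

import requests

BASE_URL = "http://localhost:8283/v1"  # assumed server address and route prefix
HEADERS = {"user_id": "user-123"}      # header name from the route; value illustrative


def iter_batch_messages(batch_id: str, page_size: int = 100):
    """Yield every message in a batch, newest first, using the message-id cursor."""
    cursor = None
    while True:
        params = {"limit": page_size, "sort_descending": True}
        if cursor:
            params["cursor"] = cursor
        resp = requests.get(f"{BASE_URL}/batches/{batch_id}/messages", params=params, headers=HEADERS)
        resp.raise_for_status()
        page = resp.json()["messages"]
        yield from page
        if len(page) < page_size:
            break  # a short page means there is nothing left to fetch
        cursor = page[-1]["id"]  # last id on the page becomes the next cursor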
@@ -168,6 +168,7 @@ def create_letta_messages_from_llm_response(
     reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
     pre_computed_assistant_message_id: Optional[str] = None,
     pre_computed_tool_message_id: Optional[str] = None,
+    llm_batch_item_id: Optional[str] = None,
 ) -> List[Message]:
     messages = []
 
@@ -192,6 +193,7 @@ def create_letta_messages_from_llm_response(
         tool_calls=[tool_call],
         tool_call_id=tool_call_id,
         created_at=get_utc_time(),
+        batch_item_id=llm_batch_item_id,
     )
     if pre_computed_assistant_message_id:
         assistant_message.id = pre_computed_assistant_message_id
@@ -209,6 +211,7 @@ def create_letta_messages_from_llm_response(
         tool_call_id=tool_call_id,
         created_at=get_utc_time(),
         name=function_name,
+        batch_item_id=llm_batch_item_id,
     )
     if pre_computed_tool_message_id:
         tool_message.id = pre_computed_tool_message_id
@@ -216,7 +219,7 @@ def create_letta_messages_from_llm_response(
 
     if add_heartbeat_request_system_message:
         heartbeat_system_message = create_heartbeat_system_message(
-            agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor
+            agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor, llm_batch_item_id=llm_batch_item_id
         )
         messages.append(heartbeat_system_message)
 
@@ -224,10 +227,7 @@ def create_letta_messages_from_llm_response(
 
 
 def create_heartbeat_system_message(
-    agent_id: str,
-    model: str,
-    function_call_success: bool,
-    actor: User,
+    agent_id: str, model: str, function_call_success: bool, actor: User, llm_batch_item_id: Optional[str] = None
 ) -> Message:
     text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE
     heartbeat_system_message = Message(
@@ -239,6 +239,7 @@ def create_heartbeat_system_message(
         tool_calls=[],
         tool_call_id=None,
         created_at=get_utc_time(),
+        batch_item_id=llm_batch_item_id,
     )
     return heartbeat_system_message
 
@@ -2,10 +2,11 @@ import datetime
 from typing import Any, Dict, List, Optional, Tuple
 
 from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
-from sqlalchemy import func, tuple_
+from sqlalchemy import desc, func, tuple_
 
 from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
 from letta.log import get_logger
+from letta.orm import Message as MessageModel
 from letta.orm.llm_batch_items import LLMBatchItem
 from letta.orm.llm_batch_job import LLMBatchJob
 from letta.schemas.agent import AgentStepState
@@ -13,6 +14,7 @@ from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
 from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
 from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
 from letta.schemas.llm_config import LLMConfig
+from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.user import User as PydanticUser
 from letta.utils import enforce_types
 
@@ -142,6 +144,62 @@ class LLMBatchManager:
             batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
             batch.hard_delete(db_session=session, actor=actor)
 
+    @enforce_types
+    def get_messages_for_letta_batch(
+        self,
+        letta_batch_job_id: str,
+        limit: int = 100,
+        actor: Optional[PydanticUser] = None,
+        agent_id: Optional[str] = None,
+        sort_descending: bool = True,
+        cursor: Optional[str] = None,  # Message ID as cursor
+    ) -> List[PydanticMessage]:
+        """
+        Retrieve messages across all LLM batch jobs associated with a Letta batch job.
+        Optimized for PostgreSQL performance using ID-based keyset pagination.
+        """
+        with self.session_maker() as session:
+            # If cursor is provided, get sequence_id for that message
+            cursor_sequence_id = None
+            if cursor:
+                cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1)
+                cursor_result = cursor_query.first()
+                if cursor_result:
+                    cursor_sequence_id = cursor_result[0]
+                else:
+                    # If cursor message doesn't exist, ignore it
+                    pass
+
+            query = (
+                session.query(MessageModel)
+                .join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id)
+                .join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id)
+                .filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id)
+            )
+
+            if actor is not None:
+                query = query.filter(MessageModel.organization_id == actor.organization_id)
+
+            if agent_id is not None:
+                query = query.filter(MessageModel.agent_id == agent_id)
+
+            # Apply cursor-based pagination if cursor exists
+            if cursor_sequence_id is not None:
+                if sort_descending:
+                    query = query.filter(MessageModel.sequence_id < cursor_sequence_id)
+                else:
+                    query = query.filter(MessageModel.sequence_id > cursor_sequence_id)
+
+            if sort_descending:
+                query = query.order_by(desc(MessageModel.sequence_id))
+            else:
+                query = query.order_by(MessageModel.sequence_id)
+
+            query = query.limit(limit)
+
+            results = query.all()
+            return [message.to_pydantic() for message in results]
+
     @enforce_types
     def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
         """Return all running LLM batch jobs, optionally filtered by actor's organization."""
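Design note: the method paginates on the indexed, monotonically increasing `sequence_id` (resolved once from the message-id cursor) rather than `OFFSET`, so each page costs roughly the same however deep the caller walks into a large batch. It also avoids the duplicated or skipped rows that a timestamp cursor can produce when `created_at` values tie.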
@@ -196,6 +254,7 @@ class LLMBatchManager:
         orm_items = []
         for item in llm_batch_items:
             orm_item = LLMBatchItem(
+                id=item.id,
                 llm_batch_id=item.llm_batch_id,
                 agent_id=item.agent_id,
                 llm_config=item.llm_config,
@@ -73,6 +73,7 @@ class MessageManager:
         Returns:
             List of created Pydantic message models
         """
+
         if not pydantic_msgs:
             return []
 
@@ -23,7 +23,7 @@ from letta.helpers import ToolRulesSolver
 from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
 from letta.orm import Base
 from letta.schemas.agent import AgentState, AgentStepState
-from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
+from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
 from letta.schemas.job import BatchJob
 from letta.schemas.letta_message_content import TextContent
 from letta.schemas.letta_request import LettaBatchRequest
@@ -589,6 +589,26 @@ async def test_partial_error_from_anthropic_batch(
         len(refreshed_agent.message_ids) == 6
     ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
 
+    # Check the total list of messages
+    messages = server.batch_manager.get_messages_for_letta_batch(
+        letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
+    )
+    assert len(messages) == (len(agents) - 1) * 4 + 1
+    assert_descending_order(messages)
+    # Check that each agent is represented
+    for agent in agents_continue:
+        agent_messages = [m for m in messages if m.agent_id == agent.id]
+        assert len(agent_messages) == 4
+        assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
+        assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
+        assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
+        assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
+
+    for agent in agents_failed:
+        agent_messages = [m for m in messages if m.agent_id == agent.id]
+        assert len(agent_messages) == 1
+        assert agent_messages[0].role == MessageRole.user, "Expected initial user message"
+
 
 @pytest.mark.asyncio
 async def test_resume_step_some_stop(
@@ -718,6 +738,42 @@ async def test_resume_step_some_stop(
         len(refreshed_agent.message_ids) == 6
     ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
 
+    # Check the total list of messages
+    messages = server.batch_manager.get_messages_for_letta_batch(
+        letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
+    )
+    assert len(messages) == len(agents) * 3 + 1
+    assert_descending_order(messages)
+    # Check that each agent is represented
+    for agent in agents_continue:
+        agent_messages = [m for m in messages if m.agent_id == agent.id]
+        assert len(agent_messages) == 4
+        assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
+        assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
+        assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
+        assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
+
+    for agent in agents_finish:
+        agent_messages = [m for m in messages if m.agent_id == agent.id]
+        assert len(agent_messages) == 3
+        assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
+        assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
+        assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
+
+
+def assert_descending_order(messages):
+    """Assert messages are in descending order by created_at timestamps."""
+    if len(messages) <= 1:
+        return True
+
+    for i in range(1, len(messages)):
+        assert messages[i].created_at <= messages[i - 1].created_at, (
+            f"Order violation: {messages[i - 1].id} ({messages[i - 1].created_at}) "
+            f"followed by {messages[i].id} ({messages[i].created_at})"
+        )
+
+    return True
+
 
 @pytest.mark.asyncio
 async def test_resume_step_after_request_all_continue(
@@ -841,6 +897,21 @@ async def test_resume_step_after_request_all_continue(
         len(refreshed_agent.message_ids) == 6
     ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
 
+    # Check the total list of messages
+    messages = server.batch_manager.get_messages_for_letta_batch(
+        letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
+    )
+    assert len(messages) == len(agents) * 4
+    assert_descending_order(messages)
+    # Check that each agent is represented
+    for agent in agents:
+        agent_messages = [m for m in messages if m.agent_id == agent.id]
+        assert len(agent_messages) == 4
+        assert agent_messages[-1].role == MessageRole.user, "Expected initial user message"
+        assert agent_messages[-2].role == MessageRole.assistant, "Expected assistant tool call after user message"
+        assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
+        assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message"
+
 
 @pytest.mark.asyncio
 async def test_step_until_request_prepares_and_submits_batch_correctly(