feat: Add pagination to listing LLM batch items (#1724)

This commit is contained in:
Matthew Zhou 2025-04-15 15:05:04 -07:00 committed by GitHub
parent cc791f7fd1
commit 73361fd931
2 changed files with 87 additions and 3 deletions

View File

@ -205,22 +205,47 @@ class LLMBatchManager:
return item.update(db_session=session, actor=actor).to_pydantic()
# TODO: Maybe make this paginated?
@enforce_types
def list_batch_items(
self,
batch_id: str,
limit: Optional[int] = None,
actor: Optional[PydanticUser] = None,
after: Optional[str] = None,
agent_id: Optional[str] = None,
request_status: Optional[JobStatus] = None,
step_status: Optional[AgentStepStatus] = None,
) -> List[PydanticLLMBatchItem]:
"""List all batch items for a given batch_id, optionally filtered by organization and limited in count."""
"""
List all batch items for a given batch_id, optionally filtered by additional criteria and limited in count.
Optional filters:
- after: A cursor string. Only items with an `id` greater than this value are returned.
- agent_id: Restrict the result set to a specific agent.
- request_status: Filter items based on their request status (e.g., created, completed, expired).
- step_status: Filter items based on their step execution status.
The results are ordered by their id in ascending order.
"""
with self.session_maker() as session:
query = session.query(LLMBatchItem).filter(LLMBatchItem.batch_id == batch_id)
if actor is not None:
query = query.filter(LLMBatchItem.organization_id == actor.organization_id)
if limit:
# Additional optional filters
if agent_id is not None:
query = query.filter(LLMBatchItem.agent_id == agent_id)
if request_status is not None:
query = query.filter(LLMBatchItem.request_status == request_status)
if step_status is not None:
query = query.filter(LLMBatchItem.step_status == step_status)
if after is not None:
query = query.filter(LLMBatchItem.id > after)
query = query.order_by(LLMBatchItem.id.asc())
if limit is not None:
query = query.limit(limit)
results = query.all()

View File

@ -4947,6 +4947,65 @@ def test_list_batch_items_limit_and_filter(server, default_user, sarah_agent, du
assert len(limited_items) == 2
def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """Exercise cursor-based (`after`) pagination of `batch_manager.list_batch_items`.

    Ten items are created in a newly created batch, then the test checks that:
      * the unpaginated listing contains every created item, ordered by ascending id;
      * `after=<cursor>` returns exactly the items that follow the cursor, in order;
      * `after` combined with `limit` returns the first `limit` of those items;
      * a cursor equal to the last id yields an empty list.
    """
    # Create a batch job.
    batch = server.batch_manager.create_batch_job(
        llm_provider=ProviderType.anthropic,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    # Create 10 batch items.
    created_items = [
        server.batch_manager.create_batch_item(
            batch_id=batch.id,
            agent_id=sarah_agent.id,
            llm_config=dummy_llm_config,
            step_state=dummy_step_state,
            actor=default_user,
        )
        for _ in range(10)
    ]

    # Retrieve all items (without pagination).
    all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user)
    assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}"

    # Verify the items are ordered ascending by id (based on our implementation).
    retrieved_ids = [item.id for item in all_items]
    sorted_ids = sorted(retrieved_ids)
    assert retrieved_ids == sorted_ids, "Batch items are not ordered in ascending order by id"

    # Every item created above must show up in the unpaginated listing.
    missing_ids = {item.id for item in created_items} - set(retrieved_ids)
    assert not missing_ids, f"Created items missing from listing: {sorted(missing_ids)}"

    # Choose a cursor: the id of the 5th item.
    cursor = all_items[4].id

    # Retrieve items after the cursor.
    paged_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor)

    # All returned items should have an id greater than the cursor.
    for item in paged_items:
        assert item.id > cursor, f"Item id {item.id} is not greater than the cursor {cursor}"

    # Count expected remaining items.
    # Find the index of the cursor in our sorted list.
    cursor_index = sorted_ids.index(cursor)
    expected_remaining = len(sorted_ids) - cursor_index - 1
    assert len(paged_items) == expected_remaining, f"Expected {expected_remaining} items after cursor, got {len(paged_items)}"

    # A length check alone would accept a page holding the wrong items, so pin
    # the exact ids (and their order) against the unpaginated listing.
    expected_after_ids = sorted_ids[cursor_index + 1 :]
    paged_ids = [item.id for item in paged_items]
    assert paged_ids == expected_after_ids, f"Expected ids {expected_after_ids} after cursor, got {paged_ids}"

    # Test pagination with a limit.
    limit = 3
    limited_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor, limit=limit)

    # If more than 'limit' items remain, we should only get exactly 'limit' items.
    assert len(limited_page) == min(
        limit, expected_remaining
    ), f"Expected {min(limit, expected_remaining)} items with limit {limit}, got {len(limited_page)}"

    # The limited page must be the leading slice of the items after the cursor.
    limited_ids = [item.id for item in limited_page]
    assert limited_ids == expected_after_ids[:limit], f"Expected ids {expected_after_ids[:limit]} with limit {limit}, got {limited_ids}"

    # A cursor equal to the last id leaves nothing after it, so the result is empty.
    last_cursor = sorted_ids[-1]
    empty_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=last_cursor)
    assert empty_page == [], "Expected an empty list when cursor is after the last item"
def test_bulk_update_batch_items_request_status_by_agent(
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state
):