mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add pagination to listing LLM batch items (#1724)
This commit is contained in:
parent
cc791f7fd1
commit
73361fd931
@ -205,22 +205,47 @@ class LLMBatchManager:
|
||||
|
||||
return item.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
# TODO: Maybe make this paginated?
|
||||
@enforce_types
|
||||
def list_batch_items(
|
||||
self,
|
||||
batch_id: str,
|
||||
limit: Optional[int] = None,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
after: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
request_status: Optional[JobStatus] = None,
|
||||
step_status: Optional[AgentStepStatus] = None,
|
||||
) -> List[PydanticLLMBatchItem]:
|
||||
"""List all batch items for a given batch_id, optionally filtered by organization and limited in count."""
|
||||
"""
|
||||
List all batch items for a given batch_id, optionally filtered by additional criteria and limited in count.
|
||||
|
||||
Optional filters:
|
||||
- after: A cursor string. Only items with an `id` greater than this value are returned.
|
||||
- agent_id: Restrict the result set to a specific agent.
|
||||
- request_status: Filter items based on their request status (e.g., created, completed, expired).
|
||||
- step_status: Filter items based on their step execution status.
|
||||
|
||||
The results are ordered by their id in ascending order.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
query = session.query(LLMBatchItem).filter(LLMBatchItem.batch_id == batch_id)
|
||||
|
||||
if actor is not None:
|
||||
query = query.filter(LLMBatchItem.organization_id == actor.organization_id)
|
||||
|
||||
if limit:
|
||||
# Additional optional filters
|
||||
if agent_id is not None:
|
||||
query = query.filter(LLMBatchItem.agent_id == agent_id)
|
||||
if request_status is not None:
|
||||
query = query.filter(LLMBatchItem.request_status == request_status)
|
||||
if step_status is not None:
|
||||
query = query.filter(LLMBatchItem.step_status == step_status)
|
||||
if after is not None:
|
||||
query = query.filter(LLMBatchItem.id > after)
|
||||
|
||||
query = query.order_by(LLMBatchItem.id.asc())
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
|
@ -4947,6 +4947,65 @@ def test_list_batch_items_limit_and_filter(server, default_user, sarah_agent, du
|
||||
assert len(limited_items) == 2
|
||||
|
||||
|
||||
def test_list_batch_items_pagination(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
# Create a batch job.
|
||||
batch = server.batch_manager.create_batch_job(
|
||||
llm_provider=ProviderType.anthropic,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create 10 batch items.
|
||||
created_items = []
|
||||
for i in range(10):
|
||||
item = server.batch_manager.create_batch_item(
|
||||
batch_id=batch.id,
|
||||
agent_id=sarah_agent.id,
|
||||
llm_config=dummy_llm_config,
|
||||
step_state=dummy_step_state,
|
||||
actor=default_user,
|
||||
)
|
||||
created_items.append(item)
|
||||
|
||||
# Retrieve all items (without pagination).
|
||||
all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user)
|
||||
assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}"
|
||||
|
||||
# Verify the items are ordered ascending by id (based on our implementation).
|
||||
sorted_ids = [item.id for item in sorted(all_items, key=lambda i: i.id)]
|
||||
retrieved_ids = [item.id for item in all_items]
|
||||
assert retrieved_ids == sorted_ids, "Batch items are not ordered in ascending order by id"
|
||||
|
||||
# Choose a cursor: the id of the 5th item.
|
||||
cursor = all_items[4].id
|
||||
|
||||
# Retrieve items after the cursor.
|
||||
paged_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor)
|
||||
|
||||
# All returned items should have an id greater than the cursor.
|
||||
for item in paged_items:
|
||||
assert item.id > cursor, f"Item id {item.id} is not greater than the cursor {cursor}"
|
||||
|
||||
# Count expected remaining items.
|
||||
# Find the index of the cursor in our sorted list.
|
||||
cursor_index = sorted_ids.index(cursor)
|
||||
expected_remaining = len(sorted_ids) - cursor_index - 1
|
||||
assert len(paged_items) == expected_remaining, f"Expected {expected_remaining} items after cursor, got {len(paged_items)}"
|
||||
|
||||
# Test pagination with a limit.
|
||||
limit = 3
|
||||
limited_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=cursor, limit=limit)
|
||||
# If more than 'limit' items remain, we should only get exactly 'limit' items.
|
||||
assert len(limited_page) == min(
|
||||
limit, expected_remaining
|
||||
), f"Expected {min(limit, expected_remaining)} items with limit {limit}, got {len(limited_page)}"
|
||||
|
||||
# Optional: Test with a cursor beyond the last item returns an empty list.
|
||||
last_cursor = sorted_ids[-1]
|
||||
empty_page = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user, after=last_cursor)
|
||||
assert empty_page == [], "Expected an empty list when cursor is after the last item"
|
||||
|
||||
|
||||
def test_bulk_update_batch_items_request_status_by_agent(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user