feat: Add callback for jobs (#1776)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Matthew Zhou 2025-04-18 10:48:04 -07:00 committed by GitHub
parent a5e4ebc137
commit 74ec15c97b
9 changed files with 201 additions and 68 deletions

View File

@@ -0,0 +1,35 @@
"""Add callback data to jobs table
Revision ID: a3c7d62e08ca
Revises: 7b189006c97d
Create Date: 2025-04-17 17:40:16.224424
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "a3c7d62e08ca"
down_revision: Union[str, None] = "7b189006c97d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("jobs", sa.Column("callback_url", sa.String(), nullable=True))
    op.add_column("jobs", sa.Column("callback_sent_at", sa.DateTime(), nullable=True))
    op.add_column("jobs", sa.Column("callback_status_code", sa.Integer(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("jobs", "callback_status_code")
    op.drop_column("jobs", "callback_sent_at")
    op.drop_column("jobs", "callback_url")
    # ### end Alembic commands ###
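A quick sanity check after running `alembic upgrade head`, sketched with SQLAlchemy's inspector (the engine URL here is a placeholder, not part of this commit):

import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://localhost/letta")  # placeholder URL
cols = {c["name"] for c in sa.inspect(engine).get_columns("jobs")}
# All three callback columns should now exist on the jobs table
assert {"callback_url", "callback_sent_at", "callback_status_code"} <= cols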

View File

@@ -2,11 +2,14 @@ import asyncio
import datetime
from typing import List
from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.jobs.helpers import map_anthropic_batch_job_status_to_job_status, map_anthropic_individual_batch_item_status_to_job_status
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo
from letta.log import get_logger
from letta.schemas.enums import JobStatus, ProviderType
from letta.schemas.letta_response import LettaBatchResponse
from letta.schemas.llm_batch_job import LLMBatchJob
from letta.schemas.user import User
from letta.server.server import SyncServer
logger = get_logger(__name__)
@@ -156,7 +159,7 @@ async def process_completed_batches(
    return item_updates


async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchResponse]:
"""
Cron job to poll all running LLM batch jobs and update their polling responses in bulk.
@@ -194,6 +197,32 @@ async def poll_running_llm_batches(server: "SyncServer") -> None:
    if item_updates:
        metrics.updated_items_count = len(item_updates)
        server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates)

        # ─── Kick off postprocessing for each batch that just completed ───
        completed = [r for r in batch_results if r.request_status == JobStatus.completed]

        async def _resume(batch_row: LLMBatchJob) -> LettaBatchResponse:
            actor: User = server.user_manager.get_user_by_id(batch_row.created_by_id)
            runner = LettaAgentBatch(
                message_manager=server.message_manager,
                agent_manager=server.agent_manager,
                block_manager=server.block_manager,
                passage_manager=server.passage_manager,
                batch_manager=server.batch_manager,
                sandbox_config_manager=server.sandbox_config_manager,
                job_manager=server.job_manager,
                actor=actor,
            )
            return await runner.resume_step_after_request(
                letta_batch_id=batch_row.letta_batch_job_id,
                llm_batch_id=batch_row.id,
            )

        # launch them all at once
        tasks = [_resume(server.batch_manager.get_llm_batch_job_by_id(bid)) for bid, *_ in completed]
        new_batch_responses = await asyncio.gather(*tasks, return_exceptions=True)
        return new_batch_responses
    else:
        logger.info("[Poll BatchJob] No item-level updates needed.")
        return []
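Because the resume tasks are gathered with return_exceptions=True, the returned list can mix LettaBatchResponse objects with raised exceptions, so a caller is expected to separate the two. A minimal sketch of that pattern (not part of this commit; it reuses the module's logger):

responses = await poll_running_llm_batches(server)
successes = [r for r in responses if isinstance(r, LettaBatchResponse)]
failures = [r for r in responses if isinstance(r, BaseException)]
for exc in failures:
    logger.error(f"[Poll BatchJob] resume_step_after_request failed: {exc!r}")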

View File

@@ -39,6 +39,11 @@ class Job(SqlalchemyBase, UserMixin):
        JSON, nullable=True, doc="The request configuration for the job, stored as JSON."
    )

    # callback related columns
    callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after job completion.")
    callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.")
    callback_status_code: Mapped[Optional[int]] = mapped_column(nullable=True, doc="HTTP status code returned by the callback endpoint.")

    # relationships
    user: Mapped["User"] = relationship("User", back_populates="jobs")
    job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
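The three columns also make delivery results queryable. A hypothetical maintenance query (assuming an open SQLAlchemy session and the Job ORM model above) for callbacks that were attempted but did not get a 2xx back:

from sqlalchemy import select

stmt = select(Job).where(
    Job.callback_status_code.isnot(None),
    ~Job.callback_status_code.between(200, 299),
)
failed_callbacks = session.execute(stmt).scalars().all()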

View File

@@ -16,6 +16,10 @@ class JobBase(OrmMetadataBase):
    metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
    job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")

    callback_url: Optional[str] = Field(None, description="If set, POST to this URL when the job completes.")
    callback_sent_at: Optional[datetime] = Field(None, description="Timestamp when the callback was last attempted.")
    callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.")
class Job(JobBase):
"""

View File

@@ -1,6 +1,6 @@
from typing import List, Optional

from pydantic import BaseModel, Field, HttpUrl
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.message import MessageCreate
@@ -35,3 +35,4 @@ class LettaBatchRequest(LettaRequest):

class CreateBatch(BaseModel):
    requests: List[LettaBatchRequest] = Field(..., description="List of requests to be processed in batch.")
    callback_url: Optional[HttpUrl] = Field(None, description="Optional URL to call via POST when the batch completes.")
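An illustrative request body, assuming LettaBatchRequest carries an agent_id plus the inherited messages field; the agent id and message content are placeholders, and HttpUrl validates the callback target at parse time:

batch = CreateBatch(
    requests=[
        LettaBatchRequest(
            agent_id="agent-123",  # placeholder agent id
            messages=[MessageCreate(role="user", content="hello")],
        )
    ],
    callback_url="https://example.test/hooks/batch-done",
)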

View File

@@ -53,6 +53,7 @@ async def create_messages_batch(
        metadata={
            "job_type": "batch_messages",
        },
        callback_url=str(payload.callback_url) if payload.callback_url else None,
)
# create the batch runner

View File

@@ -60,14 +60,16 @@ class JobManager:
            update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

            for key, value in update_data.items():
                setattr(job, key, value)

            # Automatically update the completion timestamp if status is set to 'completed'
            if update_data.get("status") == JobStatus.completed and not job.completed_at:
                job.completed_at = get_utc_time()
                if job.callback_url:
                    self._dispatch_callback(session, job)

            # Save the updated job to the database
            job.update(db_session=session, actor=actor)

            return job.to_pydantic()
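With this in place, flipping a job to completed is what fires the webhook. Sketched usage (job_id and actor are placeholders; JobUpdate and JobStatus come from letta.schemas):

from letta.schemas.enums import JobStatus
from letta.schemas.job import JobUpdate

updated = server.job_manager.update_job_by_id(job_id, JobUpdate(status=JobStatus.completed), actor=actor)
# If a callback_url was registered, the delivery result is recorded on the job
print(updated.callback_sent_at, updated.callback_status_code)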
@@ -455,3 +457,27 @@ class JobManager:
            job = session.query(JobModel).filter(JobModel.id == run_id).first()
            request_config = job.request_config or LettaRequestConfig()
            return request_config
    def _dispatch_callback(self, session: Session, job: JobModel) -> None:
        """
        POST a standard JSON payload to job.callback_url
        and record the timestamp + HTTP status of the attempt.
        """
        payload = {
            "job_id": job.id,
            "status": job.status,
            "completed_at": job.completed_at.isoformat() if job.completed_at else None,
        }

        try:
            import httpx

            resp = httpx.post(job.callback_url, json=payload, timeout=5.0)
            job.callback_sent_at = get_utc_time()
            job.callback_status_code = resp.status_code
        except Exception:
            # Swallow network errors: a failing callback should not fail the job update itself.
            return

        session.add(job)
        session.commit()
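On the receiving side, any endpoint that accepts a JSON POST will do. A minimal receiver sketch (FastAPI is an assumption here, not something this commit requires; the route path is a placeholder), with field names mirroring the payload built in _dispatch_callback:

from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/webhook/jobs")
async def on_job_complete(request: Request):
    body = await request.json()  # {"job_id": ..., "status": ..., "completed_at": ...}
    print(body["job_id"], body["status"], body["completed_at"])
    return {"ok": True}  # whatever status code is returned here lands in callback_status_code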

View File

@@ -392,78 +392,75 @@ async def test_resume_step_after_request_all_continue(
    mock_results.return_value = MockAsyncIterable(mock_items.copy())  # Using copy to preserve the original list

    with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
        with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
            msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}

            new_batch_responses = await poll_running_llm_batches(server)

            # Verify database records were updated correctly
            llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)

            # Verify job properties
            assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"

            # Verify batch items
            items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
            assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
            assert all(item.request_status == JobStatus.completed for item in items)

            # Verify only one new batch response
            assert len(new_batch_responses) == 1
            post_resume_response = new_batch_responses[0]
            assert (
                post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
            ), "resume_step_after_request is expected to have the same letta_batch_id"
            assert (
                post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
            ), "resume_step_after_request is expected to have a different llm_batch_id"
            assert post_resume_response.status == JobStatus.running
            assert post_resume_response.agent_count == 3

            # New batch items should exist, initialized in the (created, paused) state
            new_items = server.batch_manager.list_llm_batch_items(
                llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user
            )
            assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
            assert {i.request_status for i in new_items} == {JobStatus.created}
            assert {i.step_status for i in new_items} == {AgentStepStatus.paused}

            # Confirm that tool_rules_solver state was preserved correctly:
            # every new item's tool_rules_solver has "get_weather" in its tool_call_history
            assert all(
                "get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
            ), "Expected 'get_weather' in tool_call_history for all new_items"
            # Each new item's step_number should have been incremented to 1
            assert all(
                item.step_state.step_number == 1 for item in new_items
            ), "Expected step_number to be incremented to 1 for all new_items"

            # Old items must have been flipped to completed / finished earlier
            # (sanity: already asserted above, but kept close for clarity)
            old_items = server.batch_manager.list_llm_batch_items(
                llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user
            )
            assert {i.request_status for i in old_items} == {JobStatus.completed}
            assert {i.step_status for i in old_items} == {AgentStepStatus.completed}

            # Tool-call side effects: each agent gets at least 2 extra messages
            for agent in agents:
                before = msg_counts_before[agent.id]  # captured just before resume
                after = server.message_manager.size(actor=default_user, agent_id=agent.id)
                assert after - before >= 2, f"Agent {agent.id} should have an assistant tool-call and tool-response message persisted."

            # Check that agent states have been properly extended with in-context messages
            for agent in agents:
                refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
                assert (
                    len(refreshed_agent.message_ids) == 6
                ), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
@pytest.mark.asyncio

View File

@@ -5,6 +5,7 @@ import time
from datetime import datetime, timedelta, timezone
from typing import List
import httpx
import pytest
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
@@ -4204,6 +4205,40 @@ def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job)
assert jobs[0].id == run.id
def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
    captured = {}

    def fake_post(url, json, timeout):
        captured["url"] = url
        captured["json"] = json

        class FakeResponse:
            status_code = 202

        return FakeResponse()

    monkeypatch.setattr(httpx, "post", fake_post)

    job_in = PydanticJob(status=JobStatus.created, metadata={"foo": "bar"}, callback_url="http://example.test/webhook/jobs")
    created = server.job_manager.create_job(job_in, actor=default_user)
    assert created.callback_url == "http://example.test/webhook/jobs"

    update = JobUpdate(status=JobStatus.completed)
    updated = server.job_manager.update_job_by_id(created.id, update, actor=default_user)

    assert captured["url"] == created.callback_url
    assert captured["json"]["job_id"] == created.id
    assert captured["json"]["status"] == JobStatus.completed.value
    # Normalize the received completed_at before comparing timestamps
    actual_dt = datetime.fromisoformat(captured["json"]["completed_at"]).replace(tzinfo=None)
    expected_dt = updated.completed_at.replace(tzinfo=None)
    assert actual_dt == expected_dt
    assert isinstance(updated.callback_sent_at, datetime)
    assert updated.callback_status_code == 202
# ======================================================================================================================
# JobManager Tests - Messages
# ======================================================================================================================