mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add callback for jobs (#1776)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
parent
a5e4ebc137
commit
74ec15c97b
@ -0,0 +1,35 @@
|
||||
"""Add callback data to jobs table
|
||||
|
||||
Revision ID: a3c7d62e08ca
|
||||
Revises: 7b189006c97d
|
||||
Create Date: 2025-04-17 17:40:16.224424
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a3c7d62e08ca"
|
||||
down_revision: Union[str, None] = "7b189006c97d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply: add the callback-tracking columns to the ``jobs`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # All three columns are nullable so existing rows need no backfill.
    for column in (
        sa.Column("callback_url", sa.String(), nullable=True),
        sa.Column("callback_sent_at", sa.DateTime(), nullable=True),
        sa.Column("callback_status_code", sa.Integer(), nullable=True),
    ):
        op.add_column("jobs", column)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert: drop the callback-tracking columns from the ``jobs`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Dropped in reverse order of creation, mirroring upgrade().
    for column_name in ("callback_status_code", "callback_sent_at", "callback_url"):
        op.drop_column("jobs", column_name)
    # ### end Alembic commands ###
|
@ -2,11 +2,14 @@ import asyncio
|
||||
import datetime
|
||||
from typing import List
|
||||
|
||||
from letta.agents.letta_agent_batch import LettaAgentBatch
|
||||
from letta.jobs.helpers import map_anthropic_batch_job_status_to_job_status, map_anthropic_individual_batch_item_status_to_job_status
|
||||
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import JobStatus, ProviderType
|
||||
from letta.schemas.letta_response import LettaBatchResponse
|
||||
from letta.schemas.llm_batch_job import LLMBatchJob
|
||||
from letta.schemas.user import User
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -156,7 +159,7 @@ async def process_completed_batches(
|
||||
return item_updates
|
||||
|
||||
|
||||
async def poll_running_llm_batches(server: "SyncServer") -> None:
|
||||
async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchResponse]:
|
||||
"""
|
||||
Cron job to poll all running LLM batch jobs and update their polling responses in bulk.
|
||||
|
||||
@ -194,6 +197,32 @@ async def poll_running_llm_batches(server: "SyncServer") -> None:
|
||||
if item_updates:
|
||||
metrics.updated_items_count = len(item_updates)
|
||||
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates)
|
||||
|
||||
# ─── Kick off post‑processing for each batch that just completed ───
|
||||
completed = [r for r in batch_results if r.request_status == JobStatus.completed]
|
||||
|
||||
async def _resume(batch_row: LLMBatchJob) -> LettaBatchResponse:
|
||||
actor: User = server.user_manager.get_user_by_id(batch_row.created_by_id)
|
||||
runner = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
)
|
||||
return await runner.resume_step_after_request(
|
||||
letta_batch_id=batch_row.letta_batch_job_id,
|
||||
llm_batch_id=batch_row.id,
|
||||
)
|
||||
|
||||
# launch them all at once
|
||||
tasks = [_resume(server.batch_manager.get_llm_batch_job_by_id(bid)) for bid, *_ in completed]
|
||||
new_batch_responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
return new_batch_responses
|
||||
else:
|
||||
logger.info("[Poll BatchJob] No item-level updates needed.")
|
||||
|
||||
|
@ -39,6 +39,11 @@ class Job(SqlalchemyBase, UserMixin):
|
||||
JSON, nullable=True, doc="The request configuration for the job, stored as JSON."
|
||||
)
|
||||
|
||||
# callback related columns
|
||||
callback_url: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="When set, POST to this URL after job completion.")
|
||||
callback_sent_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="Timestamp when the callback was last attempted.")
|
||||
callback_status_code: Mapped[Optional[int]] = mapped_column(nullable=True, doc="HTTP status code returned by the callback endpoint.")
|
||||
|
||||
# relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="jobs")
|
||||
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
|
||||
|
@ -16,6 +16,10 @@ class JobBase(OrmMetadataBase):
|
||||
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
|
||||
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
||||
|
||||
callback_url: Optional[str] = Field(None, description="If set, POST to this URL when the job completes.")
|
||||
callback_sent_at: Optional[datetime] = Field(None, description="Timestamp when the callback was last attempted.")
|
||||
callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.")
|
||||
|
||||
|
||||
class Job(JobBase):
|
||||
"""
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.message import MessageCreate
|
||||
@ -35,3 +35,4 @@ class LettaBatchRequest(LettaRequest):
|
||||
|
||||
class CreateBatch(BaseModel):
    """Request body for creating a batch of agent message requests.

    ``callback_url`` is validated as an ``HttpUrl`` here at the API boundary;
    downstream storage keeps it as a plain string.
    """

    requests: List[LettaBatchRequest] = Field(..., description="List of requests to be processed in batch.")
    callback_url: Optional[HttpUrl] = Field(None, description="Optional URL to call via POST when the batch completes.")
|
||||
|
@ -53,6 +53,7 @@ async def create_messages_batch(
|
||||
metadata={
|
||||
"job_type": "batch_messages",
|
||||
},
|
||||
callback_url=str(payload.callback_url),
|
||||
)
|
||||
|
||||
# create the batch runner
|
||||
|
@ -60,14 +60,16 @@ class JobManager:
|
||||
update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Automatically update the completion timestamp if status is set to 'completed'
|
||||
if update_data.get("status") == JobStatus.completed and not job.completed_at:
|
||||
job.completed_at = get_utc_time()
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(job, key, value)
|
||||
|
||||
if update_data.get("status") == JobStatus.completed and not job.completed_at:
|
||||
job.completed_at = get_utc_time()
|
||||
if job.callback_url:
|
||||
self._dispatch_callback(session, job)
|
||||
|
||||
# Save the updated job to the database
|
||||
job.update(db_session=session) # TODO: Add this later , actor=actor)
|
||||
job.update(db_session=session, actor=actor)
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@ -455,3 +457,27 @@ class JobManager:
|
||||
job = session.query(JobModel).filter(JobModel.id == run_id).first()
|
||||
request_config = job.request_config or LettaRequestConfig()
|
||||
return request_config
|
||||
|
||||
def _dispatch_callback(self, session: Session, job: JobModel) -> None:
    """POST a standard JSON payload to ``job.callback_url`` and record the attempt.

    Mutates ``job`` in place — ``callback_sent_at`` and ``callback_status_code`` —
    and commits the change on success. Delivery is best-effort by design: any
    failure (network error, bad URL, missing httpx) is swallowed so a broken
    callback endpoint can never fail the job update itself.

    Args:
        session: Open database session used to persist the callback result.
        job: Job row that just reached a terminal status; ``callback_url`` is set
            (the caller checks this before dispatching).
    """
    payload = {
        "job_id": job.id,
        "status": job.status,
        # Guard against a missing completion timestamp instead of raising
        # AttributeError on ``None.isoformat()``.
        "completed_at": job.completed_at.isoformat() if job.completed_at else None,
    }
    try:
        # Imported lazily so environments without httpx installed can still
        # update jobs; an ImportError is treated like any other callback failure.
        import httpx

        resp = httpx.post(job.callback_url, json=payload, timeout=5.0)
        job.callback_sent_at = get_utc_time()
        job.callback_status_code = resp.status_code

    except Exception:
        # Best-effort: nothing is recorded for a failed attempt.
        # NOTE(review): consider logging the failure for observability.
        return

    session.add(job)
    session.commit()
|
||||
|
@ -392,78 +392,75 @@ async def test_resume_step_after_request_all_continue(
|
||||
mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list
|
||||
|
||||
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
|
||||
await poll_running_llm_batches(server)
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
|
||||
# Verify database records were updated correctly
|
||||
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
|
||||
new_batch_responses = await poll_running_llm_batches(server)
|
||||
|
||||
# Verify job properties
|
||||
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
# Verify database records were updated correctly
|
||||
llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user)
|
||||
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
assert all([item.request_status == JobStatus.completed for item in items])
|
||||
# Verify job properties
|
||||
assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'"
|
||||
|
||||
# 3. Call resume_step_after_request
|
||||
letta_batch_agent = LettaAgentBatch(
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
batch_manager=server.batch_manager,
|
||||
sandbox_config_manager=server.sandbox_config_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=default_user,
|
||||
)
|
||||
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
|
||||
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
|
||||
# Verify batch items
|
||||
items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user)
|
||||
assert len(items) == 3, f"Expected 3 batch items, got {len(items)}"
|
||||
assert all([item.request_status == JobStatus.completed for item in items])
|
||||
|
||||
post_resume_response = await letta_batch_agent.resume_step_after_request(
|
||||
letta_batch_id=pre_resume_response.letta_batch_id, llm_batch_id=llm_batch_job.id
|
||||
)
|
||||
# Verify only one new batch response
|
||||
assert len(new_batch_responses) == 1
|
||||
post_resume_response = new_batch_responses[0]
|
||||
|
||||
assert (
|
||||
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
|
||||
), "resume_step_after_request is expected to have the same letta_batch_id"
|
||||
assert (
|
||||
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
|
||||
), "resume_step_after_request is expected to have different llm_batch_id."
|
||||
assert post_resume_response.status == JobStatus.running
|
||||
assert post_resume_response.agent_count == 3
|
||||
assert (
|
||||
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
|
||||
), "resume_step_after_request is expected to have the same letta_batch_id"
|
||||
assert (
|
||||
post_resume_response.last_llm_batch_id != pre_resume_response.last_llm_batch_id
|
||||
), "resume_step_after_request is expected to have different llm_batch_id."
|
||||
assert post_resume_response.status == JobStatus.running
|
||||
assert post_resume_response.agent_count == 3
|
||||
|
||||
# New batch‑items should exist, initialised in (created, paused) state
|
||||
new_items = server.batch_manager.list_llm_batch_items(llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user)
|
||||
assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
|
||||
assert {i.request_status for i in new_items} == {JobStatus.created}
|
||||
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
|
||||
# New batch‑items should exist, initialised in (created, paused) state
|
||||
new_items = server.batch_manager.list_llm_batch_items(
|
||||
llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user
|
||||
)
|
||||
assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}"
|
||||
assert {i.request_status for i in new_items} == {JobStatus.created}
|
||||
assert {i.step_status for i in new_items} == {AgentStepStatus.paused}
|
||||
|
||||
# Confirm that tool_rules_solver state was preserved correctly
|
||||
# Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history
|
||||
assert all(
|
||||
"get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
|
||||
), "Expected 'get_weather' in tool_call_history for all new_items"
|
||||
# Assert that each new item's step_number was incremented to 1
|
||||
assert all(item.step_state.step_number == 1 for item in new_items), "Expected step_number to be incremented to 1 for all new_items"
|
||||
# Confirm that tool_rules_solver state was preserved correctly
|
||||
# Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history
|
||||
assert all(
|
||||
"get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items
|
||||
), "Expected 'get_weather' in tool_call_history for all new_items"
|
||||
# Assert that each new item's step_number was incremented to 1
|
||||
assert all(
|
||||
item.step_state.step_number == 1 for item in new_items
|
||||
), "Expected step_number to be incremented to 1 for all new_items"
|
||||
|
||||
# Old items must have been flipped to completed / finished earlier
|
||||
# (sanity – we already asserted this above, but we keep it close for clarity)
|
||||
old_items = server.batch_manager.list_llm_batch_items(llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user)
|
||||
assert {i.request_status for i in old_items} == {JobStatus.completed}
|
||||
assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
|
||||
# Old items must have been flipped to completed / finished earlier
|
||||
# (sanity – we already asserted this above, but we keep it close for clarity)
|
||||
old_items = server.batch_manager.list_llm_batch_items(
|
||||
llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user
|
||||
)
|
||||
assert {i.request_status for i in old_items} == {JobStatus.completed}
|
||||
assert {i.step_status for i in old_items} == {AgentStepStatus.completed}
|
||||
|
||||
# Tool‑call side‑effects – each agent gets at least 2 extra messages
|
||||
for agent in agents:
|
||||
before = msg_counts_before[agent.id] # captured just before resume
|
||||
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
|
||||
assert after - before >= 2, f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted."
|
||||
# Tool‑call side‑effects – each agent gets at least 2 extra messages
|
||||
for agent in agents:
|
||||
before = msg_counts_before[agent.id] # captured just before resume
|
||||
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
|
||||
assert after - before >= 2, (
|
||||
f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted."
|
||||
)
|
||||
|
||||
# Check that agent states have been properly modified to have extended in-context messages
|
||||
for agent in agents:
|
||||
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
|
||||
assert (
|
||||
len(refreshed_agent.message_ids) == 6
|
||||
), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
# Check that agent states have been properly modified to have extended in-context messages
|
||||
for agent in agents:
|
||||
refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user)
|
||||
assert (
|
||||
len(refreshed_agent.message_ids) == 6
|
||||
), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -5,6 +5,7 @@ import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
|
||||
@ -4204,6 +4205,40 @@ def test_list_jobs_filter_by_type(server: SyncServer, default_user, default_job)
|
||||
assert jobs[0].id == run.id
|
||||
|
||||
|
||||
def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
    """End-to-end: completing a job with a callback_url POSTs the payload and records the result."""
    seen = {}

    def _fake_post(url, json, timeout):
        # Capture the outgoing request instead of touching the network.
        seen["url"] = url
        seen["json"] = json

        class _Resp:
            status_code = 202

        return _Resp()

    monkeypatch.setattr(httpx, "post", _fake_post)

    job_in = PydanticJob(status=JobStatus.created, metadata={"foo": "bar"}, callback_url="http://example.test/webhook/jobs")
    created = server.job_manager.create_job(job_in, actor=default_user)
    assert created.callback_url == "http://example.test/webhook/jobs"

    updated = server.job_manager.update_job_by_id(created.id, JobUpdate(status=JobStatus.completed), actor=default_user)

    assert seen["url"] == created.callback_url
    assert seen["json"]["job_id"] == created.id
    assert seen["json"]["status"] == JobStatus.completed.value

    # Normalize the received completed_at to compare properly
    received_dt = datetime.fromisoformat(seen["json"]["completed_at"]).replace(tzinfo=None)
    assert received_dt == updated.completed_at.replace(tzinfo=None)

    assert isinstance(updated.callback_sent_at, datetime)
    assert updated.callback_status_code == 202
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# JobManager Tests - Messages
|
||||
# ======================================================================================================================
|
||||
|
Loading…
Reference in New Issue
Block a user