feat: Finish step_until_request in new batch agent loop (#1656)

Matthew Zhou 2025-04-10 10:19:06 -07:00 committed by GitHub
parent 2636e3d384
commit c1f9d3c2b7
8 changed files with 193 additions and 10 deletions

View File

@@ -0,0 +1,164 @@
from typing import Dict, List
from letta.agents.helpers import _prepare_in_context_messages
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import JobStatus, ProviderType
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.letta_response import LettaBatchResponse
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.llm_batch_manager import LLMBatchManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.utils import united_diff
logger = get_logger(__name__)
# TODO: Limitations ->
# TODO: Only works with anthropic for now
class LettaAgentBatch:
def __init__(
self,
batch_id: str,
message_manager: MessageManager,
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
batch_manager: LLMBatchManager,
actor: User,
use_assistant_message: bool = True,
max_steps: int = 10,
):
self.batch_id = batch_id
self.message_manager = message_manager
self.agent_manager = agent_manager
self.block_manager = block_manager
self.passage_manager = passage_manager
self.batch_manager = batch_manager
self.use_assistant_message = use_assistant_message
self.actor = actor
self.max_steps = max_steps
async def step_until_request(
self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Dict[str, AgentStepState]
) -> LettaBatchResponse:
agent_messages_mapping: Dict[str, List[Message]] = {}
agent_tools_mapping: Dict[str, List[dict]] = {}
agent_states = []
for batch_request in batch_requests:
agent_id = batch_request.agent_id
agent_state = self.agent_manager.get_agent_by_id(agent_id)
agent_states.append(agent_state)
agent_messages_mapping[agent_id] = self.get_in_context_messages_per_agent(
agent_state=agent_state, input_messages=batch_request.messages
)
agent_tools_mapping[agent_id] = self.prepare_tools_per_agent(
agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver
)
# TODO: This is a hack: the LLM client expects an LLM config,
# TODO: but that doesn't really work in batch land
# TODO: @caren will factor this out
llm_client = LLMClient.create(
llm_config=agent_states[0].llm_config,
put_inner_thoughts_first=True,
)
agent_llm_config_mapping = {agent_state.id: agent_state.llm_config for agent_state in agent_states}
batch_response = await llm_client.send_llm_batch_request_async(
agent_messages_mapping=agent_messages_mapping,
agent_tools_mapping=agent_tools_mapping,
agent_llm_config_mapping=agent_llm_config_mapping,
)
# Write the response into the jobs table, where it will get picked up by the next cron run
batch_job = self.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic, # TODO: Expand to more
create_batch_response=batch_response,
actor=self.actor,
status=JobStatus.running,
)
# TODO: Make this much more efficient by doing creates in bulk
for agent_state in agent_states:
agent_step_state = agent_step_state_mapping.get(agent_state.id)
self.batch_manager.create_batch_item(
batch_id=batch_job.id,
agent_id=agent_state.id,
llm_config=agent_state.llm_config,
actor=self.actor,
step_state=agent_step_state,
)
return LettaBatchResponse(
batch_id=batch_job.id, status=batch_job.status, last_polled_at=batch_job.last_polled_at, created_at=batch_job.created_at
)
async def resume_step_after_request(self, batch_id: str):
pass
def prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}]
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
def get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_messages, agent_state, self.message_manager, self.actor
)
in_context_messages = self._rebuild_memory(current_in_context_messages + new_in_context_messages, agent_state)
return in_context_messages
# TODO: Make this a bulk function
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
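Taken together, the class above is driven roughly as follows. A minimal usage sketch, assuming the managers, actor, and agent states already exist elsewhere, and assuming the AgentStepState and ToolRulesSolver constructors shown here (neither is defined in this diff):

# Hypothetical driver; every name not introduced by this diff is an assumption.
async def drive_batch(agent_states, message_manager, agent_manager, block_manager,
                      passage_manager, batch_manager, actor):
    agent_batch = LettaAgentBatch(
        batch_id="batch-123",  # placeholder id
        message_manager=message_manager,
        agent_manager=agent_manager,
        block_manager=block_manager,
        passage_manager=passage_manager,
        batch_manager=batch_manager,
        actor=actor,
    )

    # One request per agent: the agent id plus that agent's new input messages.
    batch_requests = [
        LettaBatchRequest(
            agent_id=agent_state.id,
            messages=[MessageCreate(role="user", content="Hello from the batch loop")],  # assumed MessageCreate shape
        )
        for agent_state in agent_states
    ]

    # Per-agent step state, keyed by agent id, carrying that agent's tool rules solver.
    step_states = {
        agent_state.id: AgentStepState(
            step_number=0,  # assumed field
            tool_rules_solver=ToolRulesSolver(tool_rules=agent_state.tool_rules),  # assumed constructor
        )
        for agent_state in agent_states
    }

    response = await agent_batch.step_until_request(
        batch_requests=batch_requests,
        agent_step_state_mapping=step_states,
    )
    return response  # LettaBatchResponse with batch_id, status, and timestamps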

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Dict, List, Optional, Union
from anthropic.types.beta.messages import BetaMessageBatch
from openai import AsyncStream, Stream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -80,8 +81,11 @@ class LLMClientBase:
return self.convert_response_to_chat_completion(response_data, messages)
async def send_llm_batch_request_async(
-self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
-):
+self,
+agent_messages_mapping: Dict[str, List[Message]],
+agent_tools_mapping: Dict[str, List[dict]],
+agent_llm_config_mapping: Dict[str, LLMConfig],
+) -> Union[BetaMessageBatch]:
raise NotImplementedError
@abstractmethod
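For context, here is a sketch of how an Anthropic-backed subclass could satisfy the new signature using the Anthropic SDK's message batches beta endpoint. The subclass, the _to_anthropic_params helper, and the request-building details are illustrative assumptions, not part of this commit:

class AnthropicBatchClientSketch(LLMClientBase):  # illustrative only
    async def send_llm_batch_request_async(
        self,
        agent_messages_mapping: Dict[str, List[Message]],
        agent_tools_mapping: Dict[str, List[dict]],
        agent_llm_config_mapping: Dict[str, LLMConfig],
    ) -> BetaMessageBatch:
        from anthropic import AsyncAnthropic

        client = AsyncAnthropic()
        requests = [
            {
                "custom_id": agent_id,  # lets the poller map each result back to its agent
                # _to_anthropic_params is a hypothetical helper that would convert Letta
                # messages, tool schemas, and the LLM config into Anthropic Messages API params.
                "params": self._to_anthropic_params(
                    messages=messages,
                    tools=agent_tools_mapping[agent_id],
                    llm_config=agent_llm_config_mapping[agent_id],
                ),
            }
            for agent_id, messages in agent_messages_mapping.items()
        ]
        # A single call submits the whole batch; the returned BetaMessageBatch is what
        # step_until_request hands to LLMBatchManager.create_batch_job for later polling.
        return await client.beta.messages.batches.create(requests=requests)

Whether the batches API is reached exactly this way depends on the installed anthropic SDK version; the point is only that one async call submits every agent's request and returns a single BetaMessageBatch handle.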

View File

@@ -27,3 +27,7 @@ class LettaStreamingRequest(LettaRequest):
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)
class LettaBatchRequest(LettaRequest):
agent_id: str = Field(..., description="The ID of the agent to send this batch request for")
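A batch request is therefore just a per-agent LettaRequest: the target agent plus the input messages it inherits from LettaRequest (read as batch_request.messages in the new agent loop). A hedged construction example (the MessageCreate shape is assumed, not shown in this diff):

req = LettaBatchRequest(
    agent_id="agent-xyz",  # placeholder id
    messages=[MessageCreate(role="user", content="What's on my calendar today?")],  # assumed shape
)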

View File

@@ -1,12 +1,13 @@
import html
import json
import re
from datetime import datetime
from typing import List, Union
from pydantic import BaseModel, Field
from letta.helpers.json_helpers import json_dumps
-from letta.schemas.enums import MessageStreamStatus
+from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.usage import LettaUsageStatistics
@@ -165,3 +166,10 @@ class LettaResponse(BaseModel):
# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
class LettaBatchResponse(BaseModel):
batch_id: str = Field(..., description="A unique identifier for this batch request.")
status: JobStatus = Field(..., description="The current status of the batch request.")
last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")
created_at: datetime = Field(..., description="The timestamp when the batch request was created.")
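This is the object step_until_request returns after writing the job row; all four fields are copied straight off the created batch job. A small construction example using only the fields defined above and the standard library:

from datetime import datetime, timezone

resp = LettaBatchResponse(
    batch_id="batch-abc",  # placeholder id
    status=JobStatus.running,
    last_polled_at=datetime.now(timezone.utc),
    created_at=datetime.now(timezone.utc),
)
print(resp.batch_id, resp.status)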

View File

@@ -658,6 +658,9 @@ class AgentManager:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
# TODO: This is duplicated below
# TODO: This is legacy code and should be cleaned up
# TODO: A lot of the memory "compilation" should be offset to a separate class
@enforce_types
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
"""Rebuilds the system message with the latest memory object and any shared memory block updates

View File

@@ -28,7 +28,7 @@ class LLMBatchManager:
self.session_maker = db_context
@enforce_types
-def create_batch_request(
+def create_batch_job(
self,
llm_provider: ProviderType,
create_batch_response: BetaMessageBatch,

View File

@@ -147,7 +147,7 @@ def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"
def create_test_batch_job(server, batch_response, default_user):
"""Create a test batch job with the given batch response."""
-return server.batch_manager.create_batch_request(
+return server.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic,
create_batch_response=batch_response,
actor=default_user,

View File

@@ -4680,7 +4680,7 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
-batch = server.batch_manager.create_batch_request(
+batch = server.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
@@ -4693,7 +4693,7 @@ def test_create_and_get_batch_request(server, default_user, dummy_beta_message_b
def test_update_batch_status(server, default_user, dummy_beta_message_batch):
-batch = server.batch_manager.create_batch_request(
+batch = server.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
@@ -4715,7 +4715,7 @@ def test_update_batch_status(server, default_user, dummy_beta_message_batch):
def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
-batch = server.batch_manager.create_batch_request(
+batch = server.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
@@ -4741,7 +4741,7 @@ def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta
def test_update_batch_item(
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
):
-batch = server.batch_manager.create_batch_request(
+batch = server.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,
@@ -4773,7 +4773,7 @@ def test_update_batch_item(
def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
-batch = server.batch_manager.create_batch_request(
+batch = server.batch_manager.create_batch_job(
llm_provider=ProviderType.anthropic,
status=JobStatus.created,
create_batch_response=dummy_beta_message_batch,