feat: Write batch request on base LLM client (#1646)

Matthew Zhou 2025-04-09 14:58:26 -07:00 committed by GitHub
parent 15c04bc28c
commit c0e1f793cf
5 changed files with 131 additions and 75 deletions

View File

@@ -59,25 +59,49 @@ class AnthropicClient(LLMClientBase):
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])

     @trace_method
-    async def batch_async(self, requests: Dict[str, dict]) -> BetaMessageBatch:
+    async def send_llm_batch_request_async(
+        self,
+        agent_messages_mapping: Dict[str, List[PydanticMessage]],
+        agent_tools_mapping: Dict[str, List[dict]],
+    ) -> BetaMessageBatch:
         """
-        Send a batch of requests to the Anthropic API asynchronously.
+        Sends a batch request to the Anthropic API using the provided agent messages and tools mappings.

         Args:
-            requests (Dict[str, dict]): A mapping from custom_id to request parameter dicts.
+            agent_messages_mapping: A dict mapping agent_id to their list of PydanticMessages.
+            agent_tools_mapping: A dict mapping agent_id to their list of tool dicts.

         Returns:
-            List[dict]: A list of response dictionaries corresponding to each request.
+            BetaMessageBatch: The batch response from the Anthropic API.
+
+        Raises:
+            ValueError: If the sets of agent_ids in the two mappings do not match.
+            Exception: Transformed errors from the underlying API call.
         """
-        client = self._get_anthropic_client(async_client=True)
-        anthropic_requests = [
-            Request(custom_id=custom_id, params=MessageCreateParamsNonStreaming(**params)) for custom_id, params in requests.items()
-        ]
-        batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
-        return batch_response
+        # Validate that both mappings use the same set of agent_ids.
+        if set(agent_messages_mapping.keys()) != set(agent_tools_mapping.keys()):
+            raise ValueError("Agent mappings for messages and tools must use the same agent_ids.")
+
+        try:
+            requests = {
+                agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id])
+                for agent_id in agent_messages_mapping
+            }
+            client = self._get_anthropic_client(async_client=True)
+            anthropic_requests = [
+                Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
+            ]
+            batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+            return batch_response
+        except Exception as e:
+            # Enhance logging here if additional context is needed
+            logger.error("Error during send_llm_batch_request_async.", exc_info=True)
+            raise self.handle_llm_error(e)

     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
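For orientation, a minimal usage sketch of the new method (not part of this commit). The agent id, message, and tool dict mirror the test fixtures later in this diff; the helper name and the pre-built client argument are assumptions:

    # Hypothetical driver for send_llm_batch_request_async (illustrative only).
    from datetime import datetime, timezone

    from letta.schemas.enums import MessageRole
    from letta.schemas.message import Message as PydanticMessage


    async def submit_weather_batch(client) -> None:
        # client: an AnthropicClient, e.g. built as in the anthropic_client fixture below.
        agent_messages = {
            "agent-1": [
                PydanticMessage(
                    role=MessageRole.user,
                    content=[{"type": "text", "text": "What's the weather like?"}],
                    created_at=datetime.now(timezone.utc),
                ),
            ],
        }
        agent_tools = {
            "agent-1": [
                {
                    "name": "get_weather",
                    "description": "Fetch current weather data",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            ],
        }
        # Both mappings must share the same agent_ids, or a ValueError is raised.
        batch = await client.send_llm_batch_request_async(agent_messages, agent_tools)
        # agent_id doubles as Anthropic's custom_id, so results can later be
        # joined back to the originating agents.
        print(batch.id, batch.processing_status)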

View File

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from openai import AsyncStream, Stream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@ -21,7 +21,6 @@ class LLMClientBase:
self,
llm_config: LLMConfig,
put_inner_thoughts_first: Optional[bool] = True,
use_structured_output: Optional[bool] = True,
use_tool_naming: bool = True,
):
self.llm_config = llm_config
@ -67,7 +66,6 @@ class LLMClientBase:
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(messages, tools, force_tool_call)
response_data = {}
try:
log_event(name="llm_request_sent", attributes=request_data)
@ -81,6 +79,11 @@ class LLMClientBase:
return self.convert_response_to_chat_completion(response_data, messages)
async def send_llm_batch_request_async(
self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
):
raise NotImplementedError
@abstractmethod
def build_request_data(
self,
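Note that the base method is a plain coroutine rather than an @abstractmethod, so subclasses that never batch stay instantiable; only providers that override it (here, AnthropicClient) opt in. A sketch of the feature-detection pattern this enables, with a hypothetical call-site helper:

    # Illustrative call-site sketch (not in this diff): probe for batch support.
    async def run_batch_if_supported(client, agent_messages_mapping, agent_tools_mapping):
        try:
            # AnthropicClient overrides this; other LLMClientBase subclasses
            # inherit the stub and raise NotImplementedError.
            return await client.send_llm_batch_request_async(agent_messages_mapping, agent_tools_mapping)
        except NotImplementedError:
            return None  # Caller falls back to per-agent single requests.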

View File

@@ -1,7 +1,9 @@
 import logging
 from datetime import datetime, timezone
 from typing import Generator

 import pytest
+from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts

 from letta.services.organization_manager import OrganizationManager
 from letta.services.user_manager import UserManager
@@ -103,3 +105,25 @@ def print_tool_func():
         return message

     yield print_tool
+
+
+@pytest.fixture
+def dummy_beta_message_batch() -> BetaMessageBatch:
+    return BetaMessageBatch(
+        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
+        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        processing_status="in_progress",
+        request_counts=BetaMessageBatchRequestCounts(
+            canceled=10,
+            errored=30,
+            expired=10,
+            processing=100,
+            succeeded=50,
+        ),
+        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
+        type="message_batch",
+    )
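The five request_counts buckets are disjoint states, so their sum gives the batch size; a quick sketch of how a test might lean on that invariant (the helper name is illustrative, and the partition property is an assumption about Anthropic's bookkeeping):

    # Illustrative helper: the request_counts buckets partition the batch.
    def total_requests(batch: BetaMessageBatch) -> int:
        c = batch.request_counts
        return c.canceled + c.errored + c.expired + c.processing + c.succeeded

    # For dummy_beta_message_batch above: 10 + 30 + 10 + 100 + 50 == 200.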

View File

@@ -1,7 +1,13 @@
 from datetime import datetime
+from unittest.mock import AsyncMock, patch

 import pytest

+# Import your AnthropicClient and related types
 from letta.llm_api.anthropic_client import AnthropicClient
+from letta.schemas.enums import MessageRole
 from letta.schemas.llm_config import LLMConfig
+from letta.schemas.message import Message as PydanticMessage


 @pytest.fixture
@@ -11,48 +17,74 @@ def anthropic_client():
         model_endpoint_type="anthropic",
         model_endpoint="https://api.anthropic.com/v1",
         context_window=32000,
-        handle=f"anthropic/claude-3-5-sonnet-20241022",
+        handle="anthropic/claude-3-5-sonnet-20241022",
         put_inner_thoughts_in_kwargs=False,
         max_tokens=4096,
         enable_reasoner=True,
         max_reasoning_tokens=1024,
     )
-    yield AnthropicClient(llm_config=llm_config)
+    return AnthropicClient(llm_config=llm_config)


 # ======================================================================================================================
 # AnthropicClient
 # ======================================================================================================================


+@pytest.fixture
+def mock_agent_messages():
+    return {
+        "agent-1": [
+            PydanticMessage(
+                role=MessageRole.system, content=[{"type": "text", "text": "You are a helpful assistant."}], created_at=datetime.utcnow()
+            ),
+            PydanticMessage(
+                role=MessageRole.user, content=[{"type": "text", "text": "What's the weather like?"}], created_at=datetime.utcnow()
+            ),
+        ]
+    }
+
+
+@pytest.fixture
+def mock_agent_tools():
+    return {
+        "agent-1": [
+            {
+                "name": "get_weather",
+                "description": "Fetch current weather data",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string", "description": "The location to get weather for"}},
+                    "required": ["location"],
+                },
+            }
+        ]
+    }
+
+
 @pytest.mark.asyncio
-async def test_batch_async_live(anthropic_client):
-    input_requests = {
-        "my-first-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hello, world",
-                }
-            ],
-        },
-        "my-second-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hi again, friend",
-                }
-            ],
-        },
-    }
-
-    response = await anthropic_client.batch_async(input_requests)
-    assert response.id.startswith("msgbatch_")
-    assert response.processing_status in {"in_progress", "succeeded"}
-    assert response.request_counts.processing + response.request_counts.succeeded == len(input_requests.keys())
-    assert response.created_at < response.expires_at
+async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent_messages, mock_agent_tools, dummy_beta_message_batch):
+    """Test a successful batch request using mocked Anthropic client responses."""
+    # Patch the _get_anthropic_client method so that it returns a mock client.
+    with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
+        mock_client = AsyncMock()
+        # Set the create method to return the dummy response asynchronously.
+        mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch
+        mock_get_client.return_value = mock_client

+        # Call the method under test.
+        response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools)

+        # Assert that the response is our dummy response.
+        assert response.id == dummy_beta_message_batch.id

+        # Assert that the mocked create method was called and received the correct request payload.
+        assert mock_client.beta.messages.batches.create.called
+        requests_sent = mock_client.beta.messages.batches.create.call_args[1]["requests"]
+        assert isinstance(requests_sent, list)
+        assert all(isinstance(req, dict) and "custom_id" in req and "params" in req for req in requests_sent)
+
+
+@pytest.mark.asyncio
+async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
+    """
+    This test verifies that if the keys in the messages and tools mappings do not match,
+    a ValueError is raised.
+    """
+    mismatched_tools = {"agent-2": []}  # Different agent ID than in the messages mapping.
+    with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
+        await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools)
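Since agent_ids become Anthropic custom_ids, the success test could pin that round-trip down with one more assertion, reusing its requests_sent variable (a sketch, assuming the SDK's Request serializes to a plain dict, as the isinstance checks above already rely on):

    # Hypothetical extra assertion for test_send_llm_batch_request_async_success:
    # every custom_id sent to the API should be one of the input agent_ids.
    assert {req["custom_id"] for req in requests_sent} == set(mock_agent_messages.keys())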

View File

@@ -7,12 +7,7 @@ from typing import List

 import pytest
 from anthropic.types.beta import BetaMessage
-from anthropic.types.beta.messages import (
-    BetaMessageBatch,
-    BetaMessageBatchIndividualResponse,
-    BetaMessageBatchRequestCounts,
-    BetaMessageBatchSucceededResult,
-)
+from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
 from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
 from sqlalchemy.exc import IntegrityError
@@ -584,28 +579,6 @@ def agent_with_tags(server: SyncServer, default_user):
     return [agent1, agent2, agent3]


-@pytest.fixture
-def dummy_beta_message_batch() -> BetaMessageBatch:
-    return BetaMessageBatch(
-        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
-        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        processing_status="in_progress",
-        request_counts=BetaMessageBatchRequestCounts(
-            canceled=10,
-            errored=30,
-            expired=10,
-            processing=100,
-            succeeded=50,
-        ),
-        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
-        type="message_batch",
-    )
-
-
 @pytest.fixture
 def dummy_llm_config() -> LLMConfig:
     return LLMConfig.default_config("gpt-4")