feat: Write batch request on base LLM client (#1646)

Matthew Zhou 2025-04-09 14:58:26 -07:00 committed by GitHub
parent 15c04bc28c
commit c0e1f793cf
5 changed files with 131 additions and 75 deletions

View File

@@ -59,25 +59,49 @@ class AnthropicClient(LLMClientBase):
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
 
     @trace_method
-    async def batch_async(self, requests: Dict[str, dict]) -> BetaMessageBatch:
+    async def send_llm_batch_request_async(
+        self,
+        agent_messages_mapping: Dict[str, List[PydanticMessage]],
+        agent_tools_mapping: Dict[str, List[dict]],
+    ) -> BetaMessageBatch:
         """
-        Send a batch of requests to the Anthropic API asynchronously.
+        Sends a batch request to the Anthropic API using the provided agent messages and tools mappings.
 
         Args:
-            requests (Dict[str, dict]): A mapping from custom_id to request parameter dicts.
+            agent_messages_mapping: A dict mapping agent_id to their list of PydanticMessages.
+            agent_tools_mapping: A dict mapping agent_id to their list of tool dicts.
 
         Returns:
-            List[dict]: A list of response dictionaries corresponding to each request.
+            BetaMessageBatch: The batch response from the Anthropic API.
+
+        Raises:
+            ValueError: If the sets of agent_ids in the two mappings do not match.
+            Exception: Transformed errors from the underlying API call.
         """
-        client = self._get_anthropic_client(async_client=True)
+        # Validate that both mappings use the same set of agent_ids.
+        if set(agent_messages_mapping.keys()) != set(agent_tools_mapping.keys()):
+            raise ValueError("Agent mappings for messages and tools must use the same agent_ids.")
 
-        anthropic_requests = [
-            Request(custom_id=custom_id, params=MessageCreateParamsNonStreaming(**params)) for custom_id, params in requests.items()
-        ]
+        try:
+            requests = {
+                agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id])
+                for agent_id in agent_messages_mapping
+            }
 
-        batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
-        return batch_response
+            client = self._get_anthropic_client(async_client=True)
+            anthropic_requests = [
+                Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
+            ]
+            batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+            return batch_response
+        except Exception as e:
+            # Enhance logging here if additional context is needed
+            logger.error("Error during send_llm_batch_request_async.", exc_info=True)
+            raise self.handle_llm_error(e)
 
     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
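For orientation, a hedged caller-side sketch of the new entry point. The LLMConfig fields and the message/tool payloads mirror the test fixtures further down; the model= field is an assumption (the fixture's first lines fall outside the hunk), and a real run needs an Anthropic API key since it submits a live batch.

import asyncio
from datetime import datetime

from letta.llm_api.anthropic_client import AnthropicClient
from letta.schemas.enums import MessageRole
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage


async def main():
    llm_config = LLMConfig(
        model="claude-3-5-sonnet-20241022",  # assumed field; other kwargs mirror the test fixture
        model_endpoint_type="anthropic",
        model_endpoint="https://api.anthropic.com/v1",
        context_window=32000,
    )
    client = AnthropicClient(llm_config=llm_config)

    # Both mappings must share the same agent_ids, or the method raises ValueError.
    messages = {
        "agent-1": [
            PydanticMessage(
                role=MessageRole.user,
                content=[{"type": "text", "text": "What's the weather like?"}],
                created_at=datetime.utcnow(),
            )
        ]
    }
    tools = {
        "agent-1": [
            {
                "name": "get_weather",
                "description": "Fetch current weather data",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            }
        ]
    }

    batch = await client.send_llm_batch_request_async(messages, tools)
    # Batches are processed asynchronously server-side: the call returns
    # immediately with an in-progress batch whose status can be polled later.
    print(batch.id, batch.processing_status)


asyncio.run(main())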

View File

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from openai import AsyncStream, Stream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -21,7 +21,6 @@ class LLMClientBase:
         self,
         llm_config: LLMConfig,
         put_inner_thoughts_first: Optional[bool] = True,
-        use_structured_output: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
         self.llm_config = llm_config
@@ -67,7 +66,6 @@ class LLMClientBase:
         Otherwise returns a ChatCompletionResponse.
         """
         request_data = self.build_request_data(messages, tools, force_tool_call)
-        response_data = {}
 
         try:
             log_event(name="llm_request_sent", attributes=request_data)
@@ -81,6 +79,11 @@ class LLMClientBase:
         return self.convert_response_to_chat_completion(response_data, messages)
 
+    async def send_llm_batch_request_async(
+        self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def build_request_data(
         self,
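The base class gains a non-abstract default that raises NotImplementedError, so batch support stays opt-in per provider: only subclasses backed by a batch API (here, AnthropicClient) override the hook. A minimal standalone sketch of that pattern; EchoBatchClient is hypothetical and not part of this commit.

import asyncio
from typing import Dict, List


class BatchHookBase:
    """Stand-in for LLMClientBase, reduced to the batch hook."""

    async def send_llm_batch_request_async(
        self, agent_messages_mapping: Dict[str, List[dict]], agent_tools_mapping: Dict[str, List[dict]]
    ):
        # Providers without a batch API inherit this default and fail loudly.
        raise NotImplementedError


class EchoBatchClient(BatchHookBase):
    """Hypothetical provider that does support batching."""

    async def send_llm_batch_request_async(self, agent_messages_mapping, agent_tools_mapping):
        # A real provider would build per-agent requests and submit a single batch.
        return {"batch_id": "echo-batch", "num_requests": len(agent_messages_mapping)}


print(asyncio.run(EchoBatchClient().send_llm_batch_request_async({"agent-1": []}, {"agent-1": []})))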

View File

@ -1,7 +1,9 @@
import logging import logging
from datetime import datetime, timezone
from typing import Generator from typing import Generator
import pytest import pytest
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
from letta.services.organization_manager import OrganizationManager from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager from letta.services.user_manager import UserManager
@@ -103,3 +105,25 @@ def print_tool_func():
         return message
 
     yield print_tool
+
+
+@pytest.fixture
+def dummy_beta_message_batch() -> BetaMessageBatch:
+    return BetaMessageBatch(
+        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
+        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        processing_status="in_progress",
+        request_counts=BetaMessageBatchRequestCounts(
+            canceled=10,
+            errored=30,
+            expired=10,
+            processing=100,
+            succeeded=50,
+        ),
+        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
+        type="message_batch",
+    )
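As a quick sanity check on the fixture's shape, a standalone sketch assuming only that the anthropic SDK is installed: the per-status counts describe a 200-request batch.

from anthropic.types.beta.messages import BetaMessageBatchRequestCounts

# Mirror the fixture's counts and confirm they total 200 requests.
counts = BetaMessageBatchRequestCounts(canceled=10, errored=30, expired=10, processing=100, succeeded=50)
total = counts.canceled + counts.errored + counts.expired + counts.processing + counts.succeeded
assert total == 200, f"unexpected total: {total}"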

View File

@ -1,7 +1,13 @@
from datetime import datetime
from unittest.mock import AsyncMock, patch
import pytest import pytest
# Import your AnthropicClient and related types
from letta.llm_api.anthropic_client import AnthropicClient from letta.llm_api.anthropic_client import AnthropicClient
from letta.schemas.enums import MessageRole
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
@pytest.fixture @pytest.fixture
@@ -11,48 +17,74 @@ def anthropic_client():
         model_endpoint_type="anthropic",
         model_endpoint="https://api.anthropic.com/v1",
         context_window=32000,
-        handle=f"anthropic/claude-3-5-sonnet-20241022",
+        handle="anthropic/claude-3-5-sonnet-20241022",
         put_inner_thoughts_in_kwargs=False,
         max_tokens=4096,
         enable_reasoner=True,
         max_reasoning_tokens=1024,
     )
-    return AnthropicClient(llm_config=llm_config)
+    yield AnthropicClient(llm_config=llm_config)
 
 
-# ======================================================================================================================
-# AnthropicClient
-# ======================================================================================================================
+@pytest.fixture
+def mock_agent_messages():
+    return {
+        "agent-1": [
+            PydanticMessage(
+                role=MessageRole.system, content=[{"type": "text", "text": "You are a helpful assistant."}], created_at=datetime.utcnow()
+            ),
+            PydanticMessage(
+                role=MessageRole.user, content=[{"type": "text", "text": "What's the weather like?"}], created_at=datetime.utcnow()
+            ),
+        ]
+    }
+
+
+@pytest.fixture
+def mock_agent_tools():
+    return {
+        "agent-1": [
+            {
+                "name": "get_weather",
+                "description": "Fetch current weather data",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string", "description": "The location to get weather for"}},
+                    "required": ["location"],
+                },
+            }
+        ]
+    }
 
 
 @pytest.mark.asyncio
-async def test_batch_async_live(anthropic_client):
-    input_requests = {
-        "my-first-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hello, world",
-                }
-            ],
-        },
-        "my-second-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hi again, friend",
-                }
-            ],
-        },
-    }
+async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent_messages, mock_agent_tools, dummy_beta_message_batch):
+    """Test a successful batch request using mocked Anthropic client responses."""
+    # Patch the _get_anthropic_client method so that it returns a mock client.
+    with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
+        mock_client = AsyncMock()
+        # Set the create method to return the dummy response asynchronously.
+        mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch
+        mock_get_client.return_value = mock_client
 
-    response = await anthropic_client.batch_async(input_requests)
-    assert response.id.startswith("msgbatch_")
-    assert response.processing_status in {"in_progress", "succeeded"}
-    assert response.request_counts.processing + response.request_counts.succeeded == len(input_requests.keys())
-    assert response.created_at < response.expires_at
+        # Call the method under test.
+        response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools)
+
+        # Assert that the response is our dummy response.
+        assert response.id == dummy_beta_message_batch.id
+
+        # Assert that the mocked create method was called and received the correct request payload.
+        assert mock_client.beta.messages.batches.create.called
+        requests_sent = mock_client.beta.messages.batches.create.call_args[1]["requests"]
+        assert isinstance(requests_sent, list)
+        assert all(isinstance(req, dict) and "custom_id" in req and "params" in req for req in requests_sent)
+
+
+@pytest.mark.asyncio
+async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
+    """
+    This test verifies that if the keys in the messages and tools mappings do not match,
+    a ValueError is raised.
+    """
+    mismatched_tools = {"agent-2": []}  # Different agent ID than in the messages mapping.
+    with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
+        await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools)
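The success test leans on two AsyncMock behaviors: attribute access auto-creates nested async mocks (so the full beta.messages.batches.create chain "just works"), and call_args[1] exposes the kwargs of the last call. A standalone illustration of that pattern, with an illustrative payload rather than the real request shape:

import asyncio
from unittest.mock import AsyncMock

mock_client = AsyncMock()


async def exercise():
    # Each attribute access returns a child AsyncMock, so this chain needs no setup.
    await mock_client.beta.messages.batches.create(requests=[{"custom_id": "agent-1", "params": {}}])


asyncio.run(exercise())

# call_args[1] is the kwargs dict of the most recent call.
sent = mock_client.beta.messages.batches.create.call_args[1]["requests"]
assert mock_client.beta.messages.batches.create.called
assert sent[0]["custom_id"] == "agent-1"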

View File

@@ -7,12 +7,7 @@ from typing import List
 
 import pytest
 from anthropic.types.beta import BetaMessage
-from anthropic.types.beta.messages import (
-    BetaMessageBatch,
-    BetaMessageBatchIndividualResponse,
-    BetaMessageBatchRequestCounts,
-    BetaMessageBatchSucceededResult,
-)
+from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
 from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
 from sqlalchemy.exc import IntegrityError
@@ -584,28 +579,6 @@ def agent_with_tags(server: SyncServer, default_user):
     return [agent1, agent2, agent3]
 
 
-@pytest.fixture
-def dummy_beta_message_batch() -> BetaMessageBatch:
-    return BetaMessageBatch(
-        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
-        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        processing_status="in_progress",
-        request_counts=BetaMessageBatchRequestCounts(
-            canceled=10,
-            errored=30,
-            expired=10,
-            processing=100,
-            succeeded=50,
-        ),
-        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
-        type="message_batch",
-    )
-
-
 @pytest.fixture
 def dummy_llm_config() -> LLMConfig:
     return LLMConfig.default_config("gpt-4")
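The dummy_beta_message_batch fixture is deleted here because it moved verbatim into the shared fixtures file above: pytest auto-discovers fixtures defined in conftest.py, so any test module in the package can request them by parameter name with no import. A tiny standalone illustration of that mechanism (file and fixture names are hypothetical):

# conftest.py
import pytest


@pytest.fixture
def shared_value():
    return 42


# test_example.py -- no import of conftest needed
def test_uses_shared_fixture(shared_value):
    assert shared_value == 42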