From c0e1f793cf7b3c42e90eed8a304c8ca3887cd94e Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Wed, 9 Apr 2025 14:58:26 -0700
Subject: [PATCH] feat: Write batch request on base LLM client (#1646)

---
 letta/llm_api/anthropic_client.py |  44 ++++++++++---
 letta/llm_api/llm_client_base.py  |   9 ++-
 tests/conftest.py                 |  24 +++++++
 tests/test_llm_clients.py         | 100 ++++++++++++++++++++----------
 tests/test_managers.py            |  29 +--------
 5 files changed, 131 insertions(+), 75 deletions(-)

diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py
index 9aabc8dd6..ece420169 100644
--- a/letta/llm_api/anthropic_client.py
+++ b/letta/llm_api/anthropic_client.py
@@ -59,25 +59,49 @@ class AnthropicClient(LLMClientBase):
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
 
     @trace_method
-    async def batch_async(self, requests: Dict[str, dict]) -> BetaMessageBatch:
+    async def send_llm_batch_request_async(
+        self,
+        agent_messages_mapping: Dict[str, List[PydanticMessage]],
+        agent_tools_mapping: Dict[str, List[dict]],
+    ) -> BetaMessageBatch:
         """
-        Send a batch of requests to the Anthropic API asynchronously.
+        Sends a batch request to the Anthropic API using the provided agent messages and tools mappings.
 
         Args:
-            requests (Dict[str, dict]): A mapping from custom_id to request parameter dicts.
+            agent_messages_mapping: A dict mapping each agent_id to its list of PydanticMessages.
+            agent_tools_mapping: A dict mapping each agent_id to its list of tool dicts.
 
         Returns:
-            List[dict]: A list of response dictionaries corresponding to each request.
+            BetaMessageBatch: The batch response from the Anthropic API.
+
+        Raises:
+            ValueError: If the sets of agent_ids in the two mappings do not match.
+            Exception: Errors from the underlying API call, transformed by handle_llm_error.
         """
-        client = self._get_anthropic_client(async_client=True)
+        # Validate that both mappings use the same set of agent_ids.
+        if set(agent_messages_mapping.keys()) != set(agent_tools_mapping.keys()):
+            raise ValueError("Agent mappings for messages and tools must use the same agent_ids.")
 
-        anthropic_requests = [
-            Request(custom_id=custom_id, params=MessageCreateParamsNonStreaming(**params)) for custom_id, params in requests.items()
-        ]
+        try:
+            requests = {
+                agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id])
+                for agent_id in agent_messages_mapping
+            }
 
-        batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+            client = self._get_anthropic_client(async_client=True)
 
-        return batch_response
+            anthropic_requests = [
+                Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
+            ]
+
+            batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+
+            return batch_response
+
+        except Exception as e:
+            # Log the full traceback, then re-raise the error transformed by handle_llm_error.
+            logger.error("Error during send_llm_batch_request_async.", exc_info=True)
+            raise self.handle_llm_error(e)
 
     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py
index 710983e2c..4340813bb 100644
--- a/letta/llm_api/llm_client_base.py
+++ b/letta/llm_api/llm_client_base.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from openai import AsyncStream, Stream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -21,7 +21,6 @@ class LLMClientBase:
         self,
         llm_config: LLMConfig,
         put_inner_thoughts_first: Optional[bool] = True,
-        use_structured_output: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
         self.llm_config = llm_config
@@ -67,7 +66,6 @@ class LLMClientBase:
         Otherwise returns a ChatCompletionResponse.
""" request_data = self.build_request_data(messages, tools, force_tool_call) - response_data = {} try: log_event(name="llm_request_sent", attributes=request_data) @@ -81,6 +79,11 @@ class LLMClientBase: return self.convert_response_to_chat_completion(response_data, messages) + async def send_llm_batch_request_async( + self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]] + ): + raise NotImplementedError + @abstractmethod def build_request_data( self, diff --git a/tests/conftest.py b/tests/conftest.py index 220438e20..314e006db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ import logging +from datetime import datetime, timezone from typing import Generator import pytest +from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts from letta.services.organization_manager import OrganizationManager from letta.services.user_manager import UserManager @@ -103,3 +105,25 @@ def print_tool_func(): return message yield print_tool + + +@pytest.fixture +def dummy_beta_message_batch() -> BetaMessageBatch: + return BetaMessageBatch( + id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF", + archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), + cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), + created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), + ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), + expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), + processing_status="in_progress", + request_counts=BetaMessageBatchRequestCounts( + canceled=10, + errored=30, + expired=10, + processing=100, + succeeded=50, + ), + results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results", + type="message_batch", + ) diff --git a/tests/test_llm_clients.py b/tests/test_llm_clients.py index 00493d2a2..8caf726f1 100644 --- a/tests/test_llm_clients.py +++ b/tests/test_llm_clients.py @@ -1,7 +1,13 @@ +from datetime import datetime +from unittest.mock import AsyncMock, patch + import pytest +# Import your AnthropicClient and related types from letta.llm_api.anthropic_client import AnthropicClient +from letta.schemas.enums import MessageRole from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as PydanticMessage @pytest.fixture @@ -11,48 +17,74 @@ def anthropic_client(): model_endpoint_type="anthropic", model_endpoint="https://api.anthropic.com/v1", context_window=32000, - handle=f"anthropic/claude-3-5-sonnet-20241022", + handle="anthropic/claude-3-5-sonnet-20241022", put_inner_thoughts_in_kwargs=False, max_tokens=4096, enable_reasoner=True, max_reasoning_tokens=1024, ) - - yield AnthropicClient(llm_config=llm_config) + return AnthropicClient(llm_config=llm_config) -# ====================================================================================================================== -# AnthropicClient -# ====================================================================================================================== +@pytest.fixture +def mock_agent_messages(): + return { + "agent-1": [ + PydanticMessage( + role=MessageRole.system, content=[{"type": "text", "text": "You are a helpful assistant."}], created_at=datetime.utcnow() + ), + PydanticMessage( + role=MessageRole.user, content=[{"type": "text", "text": "What's the weather like?"}], created_at=datetime.utcnow() + ), + ] + } + + +@pytest.fixture +def mock_agent_tools(): + return { 
+        "agent-1": [
+            {
+                "name": "get_weather",
+                "description": "Fetch current weather data",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string", "description": "The location to get weather for"}},
+                    "required": ["location"],
+                },
+            }
+        ]
+    }
 
 
 @pytest.mark.asyncio
-async def test_batch_async_live(anthropic_client):
-    input_requests = {
-        "my-first-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hello, world",
-                }
-            ],
-        },
-        "my-second-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hi again, friend",
-                }
-            ],
-        },
-    }
+async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent_messages, mock_agent_tools, dummy_beta_message_batch):
+    """Test a successful batch request using mocked Anthropic client responses."""
+    # Patch the _get_anthropic_client method so that it returns a mock client.
+    with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
+        mock_client = AsyncMock()
+        # Set the create method to return the dummy response asynchronously.
+        mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch
+        mock_get_client.return_value = mock_client
 
-    response = await anthropic_client.batch_async(input_requests)
-    assert response.id.startswith("msgbatch_")
-    assert response.processing_status in {"in_progress", "succeeded"}
-    assert response.request_counts.processing + response.request_counts.succeeded == len(input_requests.keys())
-    assert response.created_at < response.expires_at
+        # Call the method under test.
+        response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools)
+
+        # Assert that the response is our dummy response.
+        assert response.id == dummy_beta_message_batch.id
+        # Assert that the mocked create method was called exactly once with the expected payload shape.
+        mock_client.beta.messages.batches.create.assert_called_once()
+        requests_sent = mock_client.beta.messages.batches.create.call_args[1]["requests"]
+        assert isinstance(requests_sent, list)
+        assert all(isinstance(req, dict) and "custom_id" in req and "params" in req for req in requests_sent)
+
+
+@pytest.mark.asyncio
+async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
+    """
+    Verify that a ValueError is raised when the messages and tools mappings
+    do not share the same set of agent_ids.
+    """
+    mismatched_tools = {"agent-2": []}  # Different agent ID than in the messages mapping.
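+    # The key check runs before any request is built, so no API call is attempted.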
+ with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."): + await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools) diff --git a/tests/test_managers.py b/tests/test_managers.py index 83a2a0f33..0bd509aa0 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -7,12 +7,7 @@ from typing import List import pytest from anthropic.types.beta import BetaMessage -from anthropic.types.beta.messages import ( - BetaMessageBatch, - BetaMessageBatchIndividualResponse, - BetaMessageBatchRequestCounts, - BetaMessageBatchSucceededResult, -) +from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction from sqlalchemy.exc import IntegrityError @@ -584,28 +579,6 @@ def agent_with_tags(server: SyncServer, default_user): return [agent1, agent2, agent3] -@pytest.fixture -def dummy_beta_message_batch() -> BetaMessageBatch: - return BetaMessageBatch( - id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF", - archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), - cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), - created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), - ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), - expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc), - processing_status="in_progress", - request_counts=BetaMessageBatchRequestCounts( - canceled=10, - errored=30, - expired=10, - processing=100, - succeeded=50, - ), - results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results", - type="message_batch", - ) - - @pytest.fixture def dummy_llm_config() -> LLMConfig: return LLMConfig.default_config("gpt-4")
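
Usage note: the sketch below shows one way a caller might drive the new
send_llm_batch_request_async API end to end. It is a minimal, illustrative
example, not part of the applied diff: the agent id, message text, and weather
tool schema are copied from the test fixtures above, the LLMConfig kwargs
mirror the anthropic_client fixture, and the polling loop uses the Anthropic
SDK's client.beta.messages.batches.retrieve, whose processing_status settles
at "ended" once the batch completes.

    import asyncio
    from datetime import datetime, timezone

    from anthropic import AsyncAnthropic

    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.schemas.enums import MessageRole
    from letta.schemas.llm_config import LLMConfig
    from letta.schemas.message import Message as PydanticMessage


    async def main() -> None:
        # Config mirrors the anthropic_client test fixture.
        llm_config = LLMConfig(
            model="claude-3-5-sonnet-20241022",
            model_endpoint_type="anthropic",
            model_endpoint="https://api.anthropic.com/v1",
            context_window=32000,
        )
        client = AnthropicClient(llm_config=llm_config)

        # Both mappings must share the same agent_ids; "agent-1" is illustrative.
        agent_messages = {
            "agent-1": [
                PydanticMessage(
                    role=MessageRole.user,
                    content=[{"type": "text", "text": "What's the weather like?"}],
                    created_at=datetime.now(timezone.utc),
                )
            ]
        }
        agent_tools = {
            "agent-1": [
                {
                    "name": "get_weather",
                    "description": "Fetch current weather data",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                }
            ]
        }

        # Each batch entry's custom_id is the agent_id, so results can be
        # routed back to the right agent later.
        batch = await client.send_llm_batch_request_async(agent_messages, agent_tools)

        # Batches complete asynchronously; poll until Anthropic reports "ended",
        # then fetch per-agent results via the SDK's batch results endpoint.
        sdk = AsyncAnthropic()
        while batch.processing_status != "ended":
            await asyncio.sleep(10)
            batch = await sdk.beta.messages.batches.retrieve(batch.id)


    asyncio.run(main())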