Mirror of https://github.com/cpacker/MemGPT.git
feat: Write batch request on base LLM client (#1646)
commit c0e1f793cf
parent 15c04bc28c
@@ -59,25 +59,49 @@ class AnthropicClient(LLMClientBase):
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
 
     @trace_method
-    async def batch_async(self, requests: Dict[str, dict]) -> BetaMessageBatch:
+    async def send_llm_batch_request_async(
+        self,
+        agent_messages_mapping: Dict[str, List[PydanticMessage]],
+        agent_tools_mapping: Dict[str, List[dict]],
+    ) -> BetaMessageBatch:
         """
-        Send a batch of requests to the Anthropic API asynchronously.
+        Sends a batch request to the Anthropic API using the provided agent messages and tools mappings.
 
         Args:
-            requests (Dict[str, dict]): A mapping from custom_id to request parameter dicts.
+            agent_messages_mapping: A dict mapping agent_id to their list of PydanticMessages.
+            agent_tools_mapping: A dict mapping agent_id to their list of tool dicts.
 
         Returns:
-            List[dict]: A list of response dictionaries corresponding to each request.
+            BetaMessageBatch: The batch response from the Anthropic API.
+
+        Raises:
+            ValueError: If the sets of agent_ids in the two mappings do not match.
+            Exception: Transformed errors from the underlying API call.
         """
-        client = self._get_anthropic_client(async_client=True)
+        # Validate that both mappings use the same set of agent_ids.
+        if set(agent_messages_mapping.keys()) != set(agent_tools_mapping.keys()):
+            raise ValueError("Agent mappings for messages and tools must use the same agent_ids.")
 
-        anthropic_requests = [
-            Request(custom_id=custom_id, params=MessageCreateParamsNonStreaming(**params)) for custom_id, params in requests.items()
-        ]
+        try:
+            requests = {
+                agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id])
+                for agent_id in agent_messages_mapping
+            }
 
-        batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+            client = self._get_anthropic_client(async_client=True)
 
-        return batch_response
+            anthropic_requests = [
+                Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
+            ]
+
+            batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+
+            return batch_response
+
+        except Exception as e:
+            # Enhance logging here if additional context is needed
+            logger.error("Error during send_llm_batch_request_async.", exc_info=True)
+            raise self.handle_llm_error(e)
 
     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
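For orientation, here is a minimal caller-side sketch of the new method. This is not part of the commit: the agent ID, message text, and LLMConfig field values are illustrative assumptions (a real config would mirror the test fixture further down).

    import asyncio
    from datetime import datetime

    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.schemas.enums import MessageRole
    from letta.schemas.llm_config import LLMConfig
    from letta.schemas.message import Message as PydanticMessage

    async def main():
        config = LLMConfig(
            model="claude-3-5-sonnet-20241022",  # illustrative values
            model_endpoint_type="anthropic",
            model_endpoint="https://api.anthropic.com/v1",
            context_window=32000,
        )
        client = AnthropicClient(llm_config=config)
        messages = {
            "agent-1": [
                PydanticMessage(
                    role=MessageRole.user,
                    content=[{"type": "text", "text": "Hello, world"}],
                    created_at=datetime.utcnow(),
                )
            ]
        }
        tools = {"agent-1": []}  # must use the same agent_ids as `messages`, or ValueError is raised
        batch = await client.send_llm_batch_request_async(messages, tools)
        print(batch.id, batch.processing_status)

    asyncio.run(main())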
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from openai import AsyncStream, Stream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -21,7 +21,6 @@ class LLMClientBase:
         self,
         llm_config: LLMConfig,
         put_inner_thoughts_first: Optional[bool] = True,
-        use_structured_output: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
         self.llm_config = llm_config
@@ -67,7 +66,6 @@ class LLMClientBase:
         Otherwise returns a ChatCompletionResponse.
         """
         request_data = self.build_request_data(messages, tools, force_tool_call)
-        response_data = {}
 
         try:
             log_event(name="llm_request_sent", attributes=request_data)
@@ -81,6 +79,11 @@ class LLMClientBase:
 
         return self.convert_response_to_chat_completion(response_data, messages)
 
+    async def send_llm_batch_request_async(
+        self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def build_request_data(
         self,
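The base-class hunks establish the extension point: send_llm_batch_request_async raises NotImplementedError on LLMClientBase, and providers that support batching (here, AnthropicClient) override it. Below is a minimal sketch of a conforming override, assuming it lives in the same module as LLMClientBase and Message; the class name is hypothetical.

    from typing import Dict, List

    class DummyBatchClient(LLMClientBase):
        # Hypothetical provider, for illustration of the override contract only.
        async def send_llm_batch_request_async(
            self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
        ):
            # Build one provider-ready payload per agent via the shared
            # build_request_data() hook, then hand them to the provider's
            # batch endpoint (omitted here).
            payloads = {
                agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id])
                for agent_id in agent_messages_mapping
            }
            return payloads  # placeholder; a real client returns the provider's batch handle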
@@ -1,7 +1,9 @@
 import logging
+from datetime import datetime, timezone
 from typing import Generator
 
 import pytest
+from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
 
 from letta.services.organization_manager import OrganizationManager
 from letta.services.user_manager import UserManager
@@ -103,3 +105,25 @@ def print_tool_func():
         return message
 
     yield print_tool
+
+
+@pytest.fixture
+def dummy_beta_message_batch() -> BetaMessageBatch:
+    return BetaMessageBatch(
+        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
+        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        processing_status="in_progress",
+        request_counts=BetaMessageBatchRequestCounts(
+            canceled=10,
+            errored=30,
+            expired=10,
+            processing=100,
+            succeeded=50,
+        ),
+        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
+        type="message_batch",
+    )
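With the fixture relocated to the shared conftest.py, any test module in the suite can consume it by parameter name. A tiny illustrative consumer (the test name is hypothetical; the asserted values come from the dummy fixture above):

    def test_dummy_batch_fixture_shape(dummy_beta_message_batch):
        # pytest resolves the argument against the conftest.py fixture by name.
        assert dummy_beta_message_batch.processing_status == "in_progress"
        assert dummy_beta_message_batch.request_counts.succeeded == 50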
@@ -1,7 +1,13 @@
+from datetime import datetime
+from unittest.mock import AsyncMock, patch
+
 import pytest
 
+# Import your AnthropicClient and related types
 from letta.llm_api.anthropic_client import AnthropicClient
+from letta.schemas.enums import MessageRole
 from letta.schemas.llm_config import LLMConfig
+from letta.schemas.message import Message as PydanticMessage
 
 
 @pytest.fixture
@@ -11,48 +17,74 @@ def anthropic_client():
         model_endpoint_type="anthropic",
         model_endpoint="https://api.anthropic.com/v1",
         context_window=32000,
-        handle=f"anthropic/claude-3-5-sonnet-20241022",
+        handle="anthropic/claude-3-5-sonnet-20241022",
         put_inner_thoughts_in_kwargs=False,
         max_tokens=4096,
         enable_reasoner=True,
         max_reasoning_tokens=1024,
     )
 
-    yield AnthropicClient(llm_config=llm_config)
+    return AnthropicClient(llm_config=llm_config)
 
 
 # ======================================================================================================================
 # AnthropicClient
 # ======================================================================================================================
+@pytest.fixture
+def mock_agent_messages():
+    return {
+        "agent-1": [
+            PydanticMessage(
+                role=MessageRole.system, content=[{"type": "text", "text": "You are a helpful assistant."}], created_at=datetime.utcnow()
+            ),
+            PydanticMessage(
+                role=MessageRole.user, content=[{"type": "text", "text": "What's the weather like?"}], created_at=datetime.utcnow()
+            ),
+        ]
+    }
+
+
+@pytest.fixture
+def mock_agent_tools():
+    return {
+        "agent-1": [
+            {
+                "name": "get_weather",
+                "description": "Fetch current weather data",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string", "description": "The location to get weather for"}},
+                    "required": ["location"],
+                },
+            }
+        ]
+    }
+
+
 @pytest.mark.asyncio
-async def test_batch_async_live(anthropic_client):
-    input_requests = {
-        "my-first-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hello, world",
-                }
-            ],
-        },
-        "my-second-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hi again, friend",
-                }
-            ],
-        },
-    }
+async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent_messages, mock_agent_tools, dummy_beta_message_batch):
+    """Test a successful batch request using mocked Anthropic client responses."""
+    # Patch the _get_anthropic_client method so that it returns a mock client.
+    with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
+        mock_client = AsyncMock()
+        # Set the create method to return the dummy response asynchronously.
+        mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch
+        mock_get_client.return_value = mock_client
 
-    response = await anthropic_client.batch_async(input_requests)
-    assert response.id.startswith("msgbatch_")
-    assert response.processing_status in {"in_progress", "succeeded"}
-    assert response.request_counts.processing + response.request_counts.succeeded == len(input_requests.keys())
-    assert response.created_at < response.expires_at
+        # Call the method under test.
+        response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools)
+
+        # Assert that the response is our dummy response.
+        assert response.id == dummy_beta_message_batch.id
+        # Assert that the mocked create method was called and received the correct request payload.
+        assert mock_client.beta.messages.batches.create.called
+        requests_sent = mock_client.beta.messages.batches.create.call_args[1]["requests"]
+        assert isinstance(requests_sent, list)
+        assert all(isinstance(req, dict) and "custom_id" in req and "params" in req for req in requests_sent)
+
+
+@pytest.mark.asyncio
+async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
+    """
+    This test verifies that if the keys in the messages and tools mappings do not match,
+    a ValueError is raised.
+    """
+    mismatched_tools = {"agent-2": []}  # Different agent ID than in the messages mapping.
+    with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
+        await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools)
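The payload assertions at the end of the success test pass because each item handed to batches.create is a plain mapping (the anthropic SDK's Request and MessageCreateParamsNonStreaming are typed dicts, and the test itself asserts isinstance(req, dict)). A sketch of the expected shape, with illustrative field values that the test does not actually assert:

    expected_shape = [
        {
            "custom_id": "agent-1",  # the agent_id doubles as the batch custom_id
            "params": {
                "model": "claude-3-5-sonnet-20241022",
                "max_tokens": 4096,
                "messages": [{"role": "user", "content": "What's the weather like?"}],
            },
        },
    ]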
@@ -7,12 +7,7 @@ from typing import List
 
 import pytest
 from anthropic.types.beta import BetaMessage
-from anthropic.types.beta.messages import (
-    BetaMessageBatch,
-    BetaMessageBatchIndividualResponse,
-    BetaMessageBatchRequestCounts,
-    BetaMessageBatchSucceededResult,
-)
+from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
 from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
 from sqlalchemy.exc import IntegrityError
@@ -584,28 +579,6 @@ def agent_with_tags(server: SyncServer, default_user):
     return [agent1, agent2, agent3]
 
 
-@pytest.fixture
-def dummy_beta_message_batch() -> BetaMessageBatch:
-    return BetaMessageBatch(
-        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
-        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        processing_status="in_progress",
-        request_counts=BetaMessageBatchRequestCounts(
-            canceled=10,
-            errored=30,
-            expired=10,
-            processing=100,
-            succeeded=50,
-        ),
-        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
-        type="message_batch",
-    )
-
-
 @pytest.fixture
 def dummy_llm_config() -> LLMConfig:
     return LLMConfig.default_config("gpt-4")