mirror of https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

feat: Write batch request on base LLM client (#1646)

parent 15c04bc28c
commit c0e1f793cf
@@ -59,25 +59,49 @@ class AnthropicClient(LLMClientBase):
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])

     @trace_method
-    async def batch_async(self, requests: Dict[str, dict]) -> BetaMessageBatch:
+    async def send_llm_batch_request_async(
+        self,
+        agent_messages_mapping: Dict[str, List[PydanticMessage]],
+        agent_tools_mapping: Dict[str, List[dict]],
+    ) -> BetaMessageBatch:
         """
-        Send a batch of requests to the Anthropic API asynchronously.
+        Sends a batch request to the Anthropic API using the provided agent messages and tools mappings.

         Args:
-            requests (Dict[str, dict]): A mapping from custom_id to request parameter dicts.
+            agent_messages_mapping: A dict mapping agent_id to their list of PydanticMessages.
+            agent_tools_mapping: A dict mapping agent_id to their list of tool dicts.

         Returns:
-            List[dict]: A list of response dictionaries corresponding to each request.
+            BetaMessageBatch: The batch response from the Anthropic API.
+
+        Raises:
+            ValueError: If the sets of agent_ids in the two mappings do not match.
+            Exception: Transformed errors from the underlying API call.
         """
-        client = self._get_anthropic_client(async_client=True)
+        # Validate that both mappings use the same set of agent_ids.
+        if set(agent_messages_mapping.keys()) != set(agent_tools_mapping.keys()):
+            raise ValueError("Agent mappings for messages and tools must use the same agent_ids.")

-        anthropic_requests = [
-            Request(custom_id=custom_id, params=MessageCreateParamsNonStreaming(**params)) for custom_id, params in requests.items()
-        ]
+        try:
+            requests = {
+                agent_id: self.build_request_data(messages=agent_messages_mapping[agent_id], tools=agent_tools_mapping[agent_id])
+                for agent_id in agent_messages_mapping
+            }

-        batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+            client = self._get_anthropic_client(async_client=True)

-        return batch_response
+            anthropic_requests = [
+                Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
+            ]
+
+            batch_response = await client.beta.messages.batches.create(requests=anthropic_requests)
+
+            return batch_response
+        except Exception as e:
+            # Enhance logging here if additional context is needed
+            logger.error("Error during send_llm_batch_request_async.", exc_info=True)
+            raise self.handle_llm_error(e)

     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
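For orientation, a minimal sketch of how a caller might drive the new entry point. This is an illustration only: the agent id, message content, and LLMConfig field values are made up (the config fields mirror the anthropic_client test fixture further down), and error handling is omitted.

    import asyncio
    from datetime import datetime

    from letta.llm_api.anthropic_client import AnthropicClient
    from letta.schemas.enums import MessageRole
    from letta.schemas.llm_config import LLMConfig
    from letta.schemas.message import Message as PydanticMessage


    async def main():
        # Hypothetical config; field values copied from the test fixture below.
        llm_config = LLMConfig(
            model="claude-3-5-sonnet-20241022",
            model_endpoint_type="anthropic",
            model_endpoint="https://api.anthropic.com/v1",
            context_window=32000,
        )
        client = AnthropicClient(llm_config=llm_config)

        # Both mappings must share the same agent_ids, or a ValueError is raised.
        messages = {
            "agent-1": [
                PydanticMessage(
                    role=MessageRole.user, content=[{"type": "text", "text": "Hello"}], created_at=datetime.utcnow()
                )
            ]
        }
        tools = {"agent-1": []}

        batch = await client.send_llm_batch_request_async(messages, tools)
        print(batch.id, batch.processing_status)


    asyncio.run(main())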
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 from openai import AsyncStream, Stream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -21,7 +21,6 @@ class LLMClientBase:
         self,
         llm_config: LLMConfig,
         put_inner_thoughts_first: Optional[bool] = True,
-        use_structured_output: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
         self.llm_config = llm_config
@@ -67,7 +66,6 @@ class LLMClientBase:
         Otherwise returns a ChatCompletionResponse.
         """
         request_data = self.build_request_data(messages, tools, force_tool_call)
-        response_data = {}

         try:
             log_event(name="llm_request_sent", attributes=request_data)
@@ -81,6 +79,11 @@ class LLMClientBase:

         return self.convert_response_to_chat_completion(response_data, messages)

+    async def send_llm_batch_request_async(
+        self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def build_request_data(
         self,
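Note the base method is a plain stub rather than an @abstractmethod, so existing subclasses keep working and providers can opt in one at a time. A hypothetical override, just to show the expected shape (MyProviderClient and its body are invented, not part of this commit):

    from typing import Dict, List


    class MyProviderClient(LLMClientBase):  # hypothetical subclass for illustration
        async def send_llm_batch_request_async(
            self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
        ):
            # Build one request per agent_id and submit them in a single
            # provider-side batch call, as AnthropicClient does above.
            ...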
@@ -1,7 +1,9 @@
 import logging
+from datetime import datetime, timezone
 from typing import Generator

 import pytest
+from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts

 from letta.services.organization_manager import OrganizationManager
 from letta.services.user_manager import UserManager
@@ -103,3 +105,25 @@ def print_tool_func():
         return message

     yield print_tool
+
+
+@pytest.fixture
+def dummy_beta_message_batch() -> BetaMessageBatch:
+    return BetaMessageBatch(
+        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
+        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
+        processing_status="in_progress",
+        request_counts=BetaMessageBatchRequestCounts(
+            canceled=10,
+            errored=30,
+            expired=10,
+            processing=100,
+            succeeded=50,
+        ),
+        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
+        type="message_batch",
+    )
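The dummy_beta_message_batch fixture above mimics the shape of a real batch object, which makes it handy for exercising progress checks without network access. A small sketch using only the fields the fixture sets (batch_is_settled is an invented helper, not part of this commit):

    from anthropic.types.beta.messages import BetaMessageBatch


    def batch_is_settled(batch: BetaMessageBatch) -> bool:
        # Settled means the API has stopped processing and no requests remain in flight.
        return batch.processing_status != "in_progress" and batch.request_counts.processing == 0

    # With the fixture above: processing_status == "in_progress" and
    # request_counts.processing == 100, so batch_is_settled(batch) is False.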
@@ -1,7 +1,13 @@
+from datetime import datetime
+from unittest.mock import AsyncMock, patch
+
 import pytest

+# Import your AnthropicClient and related types
 from letta.llm_api.anthropic_client import AnthropicClient
+from letta.schemas.enums import MessageRole
 from letta.schemas.llm_config import LLMConfig
+from letta.schemas.message import Message as PydanticMessage


 @pytest.fixture
@@ -11,48 +17,74 @@ def anthropic_client():
         model_endpoint_type="anthropic",
         model_endpoint="https://api.anthropic.com/v1",
         context_window=32000,
-        handle=f"anthropic/claude-3-5-sonnet-20241022",
+        handle="anthropic/claude-3-5-sonnet-20241022",
         put_inner_thoughts_in_kwargs=False,
         max_tokens=4096,
         enable_reasoner=True,
         max_reasoning_tokens=1024,
     )
-    yield AnthropicClient(llm_config=llm_config)
+    return AnthropicClient(llm_config=llm_config)


-# ======================================================================================================================
-# AnthropicClient
-# ======================================================================================================================
+@pytest.fixture
+def mock_agent_messages():
+    return {
+        "agent-1": [
+            PydanticMessage(
+                role=MessageRole.system, content=[{"type": "text", "text": "You are a helpful assistant."}], created_at=datetime.utcnow()
+            ),
+            PydanticMessage(
+                role=MessageRole.user, content=[{"type": "text", "text": "What's the weather like?"}], created_at=datetime.utcnow()
+            ),
+        ]
+    }
+
+
+@pytest.fixture
+def mock_agent_tools():
+    return {
+        "agent-1": [
+            {
+                "name": "get_weather",
+                "description": "Fetch current weather data",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string", "description": "The location to get weather for"}},
+                    "required": ["location"],
+                },
+            }
+        ]
+    }


 @pytest.mark.asyncio
-async def test_batch_async_live(anthropic_client):
-    input_requests = {
-        "my-first-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hello, world",
-                }
-            ],
-        },
-        "my-second-request": {
-            "model": "claude-3-7-sonnet-20250219",
-            "max_tokens": 1024,
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "Hi again, friend",
-                }
-            ],
-        },
-    }
-
-    response = await anthropic_client.batch_async(input_requests)
-    assert response.id.startswith("msgbatch_")
-    assert response.processing_status in {"in_progress", "succeeded"}
-    assert response.request_counts.processing + response.request_counts.succeeded == len(input_requests.keys())
-    assert response.created_at < response.expires_at
+async def test_send_llm_batch_request_async_success(anthropic_client, mock_agent_messages, mock_agent_tools, dummy_beta_message_batch):
+    """Test a successful batch request using mocked Anthropic client responses."""
+    # Patch the _get_anthropic_client method so that it returns a mock client.
+    with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
+        mock_client = AsyncMock()
+        # Set the create method to return the dummy response asynchronously.
+        mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch
+        mock_get_client.return_value = mock_client
+
+        # Call the method under test.
+        response = await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mock_agent_tools)
+
+        # Assert that the response is our dummy response.
+        assert response.id == dummy_beta_message_batch.id
+        # Assert that the mocked create method was called and received the correct request payload.
+        assert mock_client.beta.messages.batches.create.called
+        requests_sent = mock_client.beta.messages.batches.create.call_args[1]["requests"]
+        assert isinstance(requests_sent, list)
+        assert all(isinstance(req, dict) and "custom_id" in req and "params" in req for req in requests_sent)
+
+
+@pytest.mark.asyncio
+async def test_send_llm_batch_request_async_mismatched_keys(anthropic_client, mock_agent_messages):
+    """
+    This test verifies that if the keys in the messages and tools mappings do not match,
+    a ValueError is raised.
+    """
+    mismatched_tools = {"agent-2": []}  # Different agent ID than in the messages mapping.
+    with pytest.raises(ValueError, match="Agent mappings for messages and tools must use the same agent_ids."):
+        await anthropic_client.send_llm_batch_request_async(mock_agent_messages, mismatched_tools)
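A note on the design: the removed test_batch_async_live hit the real Anthropic API, so it needed credentials and network access; the replacement patches _get_anthropic_client so the batches.create call never leaves the process. If the round-trip of agent ids mattered, one more assertion could slot in at the end of the success test, along these lines (hypothetical, not part of the diff):

        # custom_ids should round-trip the agent_ids used to build the batch.
        custom_ids = {req["custom_id"] for req in requests_sent}
        assert custom_ids == set(mock_agent_messages.keys())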
@@ -7,12 +7,7 @@ from typing import List

 import pytest
 from anthropic.types.beta import BetaMessage
-from anthropic.types.beta.messages import (
-    BetaMessageBatch,
-    BetaMessageBatchIndividualResponse,
-    BetaMessageBatchRequestCounts,
-    BetaMessageBatchSucceededResult,
-)
+from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
 from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
 from sqlalchemy.exc import IntegrityError
@@ -584,28 +579,6 @@ def agent_with_tags(server: SyncServer, default_user):
     return [agent1, agent2, agent3]


-@pytest.fixture
-def dummy_beta_message_batch() -> BetaMessageBatch:
-    return BetaMessageBatch(
-        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
-        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
-        processing_status="in_progress",
-        request_counts=BetaMessageBatchRequestCounts(
-            canceled=10,
-            errored=30,
-            expired=10,
-            processing=100,
-            succeeded=50,
-        ),
-        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
-        type="message_batch",
-    )
-
-
 @pytest.fixture
 def dummy_llm_config() -> LLMConfig:
     return LLMConfig.default_config("gpt-4")