feat: add llm config per request (#1866)

cthomas 2025-04-23 16:37:05 -07:00 committed by GitHub
parent 51afbcb57e
commit d8d3d89073
13 changed files with 82 additions and 69 deletions

View File

@@ -332,13 +332,14 @@ class Agent(BaseAgent):
log_telemetry(self.logger, "_get_ai_reply create start")
# New LLM client flow
llm_client = LLMClient.create(
llm_config=self.agent_state.llm_config,
provider=self.agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=put_inner_thoughts_first,
)
if llm_client and not stream:
response = llm_client.send_llm_request(
messages=message_sequence,
llm_config=self.agent_state.llm_config,
tools=allowed_functions,
stream=stream,
force_tool_call=force_tool_call,
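Pieced together, the hunk above is the synchronous half of the change: LLMClient.create is keyed on the provider type alone, and the LLMConfig becomes a per-request argument to send_llm_request. A minimal sketch of the new calling convention (agent_state, message_sequence, allowed_functions, stream, put_inner_thoughts_first, and force_tool_call are assumed to come from the surrounding _get_ai_reply code, as in the diff):

    from letta.llm_api.llm_client import LLMClient

    # Build the client from the provider type only ...
    llm_client = LLMClient.create(
        provider=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=put_inner_thoughts_first,
    )
    # ... and hand the config to each request instead of baking it into the client.
    if llm_client and not stream:
        response = llm_client.send_llm_request(
            messages=message_sequence,
            llm_config=agent_state.llm_config,
            tools=allowed_functions,
            stream=stream,
            force_tool_call=force_tool_call,
        )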

View File

@@ -66,7 +66,7 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
provider=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
for step in range(max_steps):
@@ -182,6 +182,7 @@ class LettaAgent(BaseAgent):
response = await llm_client.send_llm_request_async(
messages=in_context_messages,
llm_config=agent_state.llm_config,
tools=allowed_tools,
force_tool_call=force_tool_call,
stream=stream,
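The async agent applies the same pattern inside its step loop: the client is created once, before the loop, and the per-step config rides with each send_llm_request_async call. A condensed sketch of the loop shape shown above (agent_state, in_context_messages, allowed_tools, force_tool_call, stream, and max_steps are assumed from the surrounding async method):

    llm_client = LLMClient.create(
        provider=agent_state.llm_config.model_endpoint_type,
        put_inner_thoughts_first=True,
    )
    for step in range(max_steps):
        # The client instance is reused across steps; only the config travels with the call.
        response = await llm_client.send_llm_request_async(
            messages=in_context_messages,
            llm_config=agent_state.llm_config,
            tools=allowed_tools,
            force_tool_call=force_tool_call,
            stream=stream,
        )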

View File

@@ -156,7 +156,7 @@ class LettaAgentBatch:
log_event(name="init_llm_client")
llm_client = LLMClient.create(
llm_config=agent_states[0].llm_config,
provider=agent_states[0].llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@@ -272,9 +272,14 @@ class LettaAgentBatch:
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
# translate provider-specific response → OpenAI-style tool call (unchanged)
llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True)
llm_client = LLMClient.create(
provider=item.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
tool_call = (
llm_client.convert_response_to_chat_completion(response_data=pr.message.model_dump(), input_messages=[])
llm_client.convert_response_to_chat_completion(
response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config
)
.choices[0]
.message.tool_calls[0]
)
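A consequence for the batch path: because the config is now supplied per call, one client per provider type is enough even when batch items carry different configs. The grouping below is an illustration of that property, not code from the diff (items and their provider responses pr are assumed; only calls shown in the diff are used):

    clients = {}  # one LLMClient per provider type
    for item, pr in zip(items, provider_responses):
        provider = item.llm_config.model_endpoint_type
        if provider not in clients:
            clients[provider] = LLMClient.create(provider=provider, put_inner_thoughts_first=True)
        completion = clients[provider].convert_response_to_chat_completion(
            response_data=pr.message.model_dump(),
            input_messages=[],
            llm_config=item.llm_config,  # the item's own config, applied to this call only
        )
        tool_call = completion.choices[0].message.tool_calls[0]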

View File

@@ -43,18 +43,18 @@ logger = get_logger(__name__)
class AnthropicClient(LLMClientBase):
def request(self, request_data: dict) -> dict:
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(async_client=False)
response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
return response.model_dump()
async def request_async(self, request_data: dict) -> dict:
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(async_client=True)
response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
return response.model_dump()
@trace_method
async def stream_async(self, request_data: dict) -> AsyncStream[BetaRawMessageStreamEvent]:
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
client = self._get_anthropic_client(async_client=True)
request_data["stream"] = True
return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@@ -310,6 +310,7 @@ class AnthropicClient(LLMClientBase):
self,
response_data: dict,
input_messages: List[PydanticMessage],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Example response from Claude 3:
@@ -411,7 +412,7 @@ class AnthropicClient(LLMClientBase):
total_tokens=prompt_tokens + completion_tokens,
),
)
if self.llm_config.put_inner_thoughts_in_kwargs:
if llm_config.put_inner_thoughts_in_kwargs:
chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
)
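On the client side, inner-thoughts handling is the visible behavioral consequence: whether the inner monologue is unpacked out of the tool-call kwargs is decided by the config passed with the call, not by state captured at construction. A hedged sketch (anthropic_client, the raw response dict, the input messages, and the two configs are assumed to exist):

    # Same client, same raw Anthropic response, two different configs.
    unpacked = anthropic_client.convert_response_to_chat_completion(
        response_data=raw_response,
        input_messages=input_messages,
        llm_config=cfg_with_inner_thoughts_in_kwargs,  # unpacks INNER_THOUGHTS_KWARG
    )
    as_is = anthropic_client.convert_response_to_chat_completion(
        response_data=raw_response,
        input_messages=input_messages,
        llm_config=cfg_without,                        # leaves tool-call args untouched
    )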

View File

@@ -25,15 +25,15 @@ logger = get_logger(__name__)
class GoogleAIClient(LLMClientBase):
def request(self, request_data: dict) -> dict:
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
# print("[google_ai request]", json.dumps(request_data, indent=2))
url, headers = get_gemini_endpoint_and_headers(
base_url=str(self.llm_config.model_endpoint),
model=self.llm_config.model,
base_url=str(llm_config.model_endpoint),
model=llm_config.model,
api_key=str(model_settings.gemini_api_key),
key_in_header=True,
generate_content=True,
@@ -55,7 +55,7 @@ class GoogleAIClient(LLMClientBase):
tool_objs = [Tool(**t) for t in tools]
tool_names = [t.function.name for t in tool_objs]
# Convert to the exact payload style Google expects
tools = self.convert_tools_to_google_ai_format(tool_objs)
tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
else:
tool_names = []
@@ -88,6 +88,7 @@ class GoogleAIClient(LLMClientBase):
self,
response_data: dict,
input_messages: List[PydanticMessage],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Converts custom response format from llm client into an OpenAI
@@ -150,7 +151,7 @@ class GoogleAIClient(LLMClientBase):
assert isinstance(function_args, dict), function_args
# NOTE: this also involves stripping the inner monologue out of the function
if self.llm_config.put_inner_thoughts_in_kwargs:
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
@@ -259,14 +260,14 @@ class GoogleAIClient(LLMClientBase):
return ChatCompletionResponse(
id=response_id,
choices=choices,
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
model=llm_config.model, # NOTE: Google API doesn't pass back model in the response
created=get_utc_time_int(),
usage=usage,
)
except KeyError as e:
raise e
def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]:
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
"""
OpenAI style:
"tools": [{
@@ -326,7 +327,7 @@ class GoogleAIClient(LLMClientBase):
# Note: Google AI API used to have weird casing requirements, but not any more
# Add inner thoughts
if self.llm_config.put_inner_thoughts_in_kwargs:
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
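The Gemini client threads the same argument through its helpers: endpoint and model are resolved from the request's config, and the tool converter decides from that config whether to inject the inner-thoughts parameter. A hedged sketch of the converter call, mirroring the hunks above (tools here is the OpenAI-style tool list already present in the surrounding method):

    if tools:
        tool_objs = [Tool(**t) for t in tools]
        # The converter receives the per-request config instead of reading self.llm_config.
        # With put_inner_thoughts_in_kwargs set, each function declaration gains an extra
        # INNER_THOUGHTS_KWARG string parameter, which convert_response_to_chat_completion()
        # later strips back out using the same config.
        tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)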

View File

@@ -18,7 +18,7 @@ from letta.utils import get_tool_call_id
class GoogleVertexClient(GoogleAIClient):
def request(self, request_data: dict) -> dict:
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
@@ -29,7 +29,7 @@ class GoogleVertexClient(GoogleAIClient):
http_options={"api_version": "v1"},
)
response = client.models.generate_content(
model=self.llm_config.model,
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
@@ -45,7 +45,7 @@ class GoogleVertexClient(GoogleAIClient):
"""
Constructs a request object in the expected data format for this client.
"""
request_data = super().build_request_data(messages, self.llm_config, tools, force_tool_call)
request_data = super().build_request_data(messages, llm_config, tools, force_tool_call)
request_data["config"] = request_data.pop("generation_config")
request_data["config"]["tools"] = request_data.pop("tools")
@@ -66,6 +66,7 @@ class GoogleVertexClient(GoogleAIClient):
self,
response_data: dict,
input_messages: List[PydanticMessage],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Converts custom response format from llm client into an OpenAI
@@ -127,7 +128,7 @@ class GoogleVertexClient(GoogleAIClient):
assert isinstance(function_args, dict), function_args
# NOTE: this also involves stripping the inner monologue out of the function
if self.llm_config.put_inner_thoughts_in_kwargs:
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
@@ -224,7 +225,7 @@ class GoogleVertexClient(GoogleAIClient):
return ChatCompletionResponse(
id=response_id,
choices=choices,
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
model=llm_config.model, # NOTE: Google API doesn't pass back model in the response
created=get_utc_time_int(),
usage=usage,
)
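For subclasses that delegate to another client, the rule is simply to forward the argument: the Vertex override hands the passed config to its GoogleAIClient parent rather than reading instance state. A schematic of that override shape (the class name and the default on force_tool_call are assumptions; the body mirrors the diff):

    class VertexStyleClient(GoogleAIClient):
        def build_request_data(self, messages, llm_config, tools, force_tool_call=None):
            # Forward the per-request config to the parent implementation.
            request_data = super().build_request_data(messages, llm_config, tools, force_tool_call)
            request_data["config"] = request_data.pop("generation_config")
            request_data["config"]["tools"] = request_data.pop("tools")
            return request_data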

View File

@@ -1,7 +1,7 @@
from typing import Optional
from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.enums import ProviderType
class LLMClient:
@@ -9,17 +9,15 @@ class LLMClient:
@staticmethod
def create(
llm_config: LLMConfig,
provider: ProviderType,
put_inner_thoughts_first: bool = True,
) -> Optional[LLMClientBase]:
"""
Create an LLM client based on the model endpoint type.
Args:
llm_config: Configuration for the LLM model
provider: The model endpoint type
put_inner_thoughts_first: Whether to put inner thoughts first in the response
use_structured_output: Whether to use structured output
use_tool_naming: Whether to use tool naming
Returns:
An instance of LLMClientBase subclass
@@ -27,33 +25,29 @@
Raises:
ValueError: If the model endpoint type is not supported
"""
match llm_config.model_endpoint_type:
case "google_ai":
match provider:
case ProviderType.google_ai:
from letta.llm_api.google_ai_client import GoogleAIClient
return GoogleAIClient(
llm_config=llm_config,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case "google_vertex":
case ProviderType.google_vertex:
from letta.llm_api.google_vertex_client import GoogleVertexClient
return GoogleVertexClient(
llm_config=llm_config,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case "anthropic":
case ProviderType.anthropic:
from letta.llm_api.anthropic_client import AnthropicClient
return AnthropicClient(
llm_config=llm_config,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case "openai":
case ProviderType.openai:
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(
llm_config=llm_config,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case _:
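Callers elsewhere in the diff pass llm_config.model_endpoint_type, which is a plain string; because ProviderType subclasses str, a raw string compares equal to the corresponding enum member, so the factory's match statement accepts either form. A small sketch (assuming the relevant provider extras, e.g. the anthropic package, are installed):

    from letta.llm_api.llm_client import LLMClient
    from letta.schemas.enums import ProviderType

    a = LLMClient.create(provider=ProviderType.anthropic)
    b = LLMClient.create(provider="anthropic")  # str enum: "anthropic" == ProviderType.anthropic
    assert type(a) is type(b)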

View File

@@ -20,17 +20,16 @@ class LLMClientBase:
def __init__(
self,
llm_config: LLMConfig,
put_inner_thoughts_first: Optional[bool] = True,
use_tool_naming: bool = True,
):
self.llm_config = llm_config
self.put_inner_thoughts_first = put_inner_thoughts_first
self.use_tool_naming = use_tool_naming
def send_llm_request(
self,
messages: List[Message],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # TODO: change to Tool object
stream: bool = False,
force_tool_call: Optional[str] = None,
@@ -40,23 +39,24 @@
If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call)
request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)
try:
log_event(name="llm_request_sent", attributes=request_data)
if stream:
return self.stream(request_data)
return self.stream(request_data, llm_config)
else:
response_data = self.request(request_data)
response_data = self.request(request_data, llm_config)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
return self.convert_response_to_chat_completion(response_data, messages)
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
async def send_llm_request_async(
self,
messages: List[Message],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None, # TODO: change to Tool object
stream: bool = False,
force_tool_call: Optional[str] = None,
@@ -66,19 +66,19 @@
If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
Otherwise returns a ChatCompletionResponse.
"""
request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call)
request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)
try:
log_event(name="llm_request_sent", attributes=request_data)
if stream:
return await self.stream_async(request_data)
return await self.stream_async(request_data, llm_config)
else:
response_data = await self.request_async(request_data)
response_data = await self.request_async(request_data, llm_config)
log_event(name="llm_response_received", attributes=response_data)
except Exception as e:
raise self.handle_llm_error(e)
return self.convert_response_to_chat_completion(response_data, messages)
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
async def send_llm_batch_request_async(
self,
@@ -102,14 +102,14 @@
raise NotImplementedError
@abstractmethod
def request(self, request_data: dict) -> dict:
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
raise NotImplementedError
@abstractmethod
async def request_async(self, request_data: dict) -> dict:
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
"""
@@ -120,6 +120,7 @@
self,
response_data: dict,
input_messages: List[Message],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Converts custom response format from llm client into an OpenAI
@@ -128,18 +129,18 @@
raise NotImplementedError
@abstractmethod
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
"""
Performs underlying streaming request to llm and returns raw response.
"""
raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
@abstractmethod
async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""
Performs underlying streaming request to llm and returns raw response.
"""
raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
@abstractmethod
def handle_llm_error(self, e: Exception) -> Exception:
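The base-class change is what makes the rest hang together: no LLMConfig is stored on the instance, so send_llm_request builds the request, performs it, and converts the response using only the config handed to that call. One practical upshot, sketched with assumed configs and messages, is that a single client can serve different models behind the same provider:

    client = LLMClient.create(provider=ProviderType.openai, put_inner_thoughts_first=True)

    # Nothing about either request is cached on `client`; the config decides everything per call.
    draft = client.send_llm_request(messages=messages, llm_config=small_model_config)
    final = client.send_llm_request(messages=messages, llm_config=large_model_config)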

View File

@@ -62,11 +62,11 @@ def supports_parallel_tool_calling(model: str) -> bool:
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self) -> dict:
def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
# supposedly the openai python client requires a dummy API key
api_key = api_key or "DUMMY_API_KEY"
kwargs = {"api_key": api_key, "base_url": self.llm_config.model_endpoint}
kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}
return kwargs
@@ -115,7 +115,7 @@ class OpenAIClient(LLMClientBase):
# TODO(matt) move into LLMConfig
# TODO: This vllm checking is very brittle and is a patch at most
tool_choice = None
if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in self.llm_config.handle):
if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle):
tool_choice = "auto" # TODO change to "required" once proxy supports it
elif tools:
# only set if tools is non-Null
@@ -152,20 +152,20 @@
return data.model_dump(exclude_unset=True)
def request(self, request_data: dict) -> dict:
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
"""
client = OpenAI(**self._prepare_client_kwargs())
client = OpenAI(**self._prepare_client_kwargs(llm_config))
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
async def request_async(self, request_data: dict) -> dict:
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs())
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
@@ -173,6 +173,7 @@ class OpenAIClient(LLMClientBase):
self,
response_data: dict,
input_messages: List[PydanticMessage], # Included for consistency, maybe used later
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
@@ -183,30 +184,30 @@
chat_completion_response = ChatCompletionResponse(**response_data)
# Unpack inner thoughts if they were embedded in function arguments
if self.llm_config.put_inner_thoughts_in_kwargs:
if llm_config.put_inner_thoughts_in_kwargs:
chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
)
# If we used a reasoning model, create a content part for the ommitted reasoning
if is_openai_reasoning_model(self.llm_config.model):
if is_openai_reasoning_model(llm_config.model):
chat_completion_response.choices[0].message.ommitted_reasoning_content = True
return chat_completion_response
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
"""
Performs underlying streaming request to OpenAI and returns the stream iterator.
"""
client = OpenAI(**self._prepare_client_kwargs())
client = OpenAI(**self._prepare_client_kwargs(llm_config))
response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
return response_stream
async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs())
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True)
return response_stream
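End to end, the OpenAI path now looks like this from a caller's point of view; because _prepare_client_kwargs derives base_url from the request's config, the same flow works against any OpenAI-compatible endpoint. The config field values below are illustrative assumptions, and messages is assumed to be a list of letta Message objects:

    from letta.llm_api.llm_client import LLMClient
    from letta.schemas.enums import ProviderType
    from letta.schemas.llm_config import LLMConfig

    cfg = LLMConfig(
        model="gpt-4o-mini",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
    )

    client = LLMClient.create(provider=ProviderType.openai, put_inner_thoughts_first=True)
    completion = client.send_llm_request(messages=messages, llm_config=cfg, stream=False)
    print(completion.choices[0].message)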

View File

@@ -79,7 +79,7 @@ def summarize_messages(
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
llm_client = LLMClient.create(
llm_config=llm_config_no_inner_thoughts,
provider=llm_config_no_inner_thoughts.model_endpoint_type,
put_inner_thoughts_first=False,
)
# try to use new client, otherwise fallback to old flow
@@ -87,6 +87,7 @@ def summarize_messages(
if llm_client:
response = llm_client.send_llm_request(
messages=message_sequence,
llm_config=llm_config_no_inner_thoughts,
stream=False,
)
else:
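The summarizer shows the pattern the per-request config is meant to enable: derive a variant config for a special-purpose call (here, with inner thoughts disabled) and pass it along with that call only. A hedged sketch (llm_config and message_sequence are assumed; model_copy(deep=True) is an assumption about how the variant is created, since the diff only shows the flag being flipped):

    llm_config_no_inner_thoughts = llm_config.model_copy(deep=True)
    llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False

    llm_client = LLMClient.create(
        provider=llm_config_no_inner_thoughts.model_endpoint_type,
        put_inner_thoughts_first=False,
    )
    if llm_client:
        response = llm_client.send_llm_request(
            messages=message_sequence,
            llm_config=llm_config_no_inner_thoughts,  # the variant applies to this request only
            stream=False,
        )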

View File

@@ -3,6 +3,9 @@ from enum import Enum
class ProviderType(str, Enum):
anthropic = "anthropic"
google_ai = "google_ai"
google_vertex = "google_vertex"
openai = "openai"
class MessageRole(str, Enum):
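The ProviderType members line up with the string model_endpoint_type values used throughout the diff; since the enum is str-valued, converting between the two is direct:

    from letta.schemas.enums import ProviderType

    assert ProviderType("google_vertex") is ProviderType.google_vertex
    assert ProviderType.google_ai.value == "google_ai"
    assert ProviderType.openai == "openai"  # str enum members compare equal to their values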

View File

@@ -104,10 +104,13 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
llm_client = LLMClient.create(llm_config=agent_state.llm_config)
llm_client = LLMClient.create(
provider=agent_state.llm_config.model_endpoint_type,
)
if llm_client:
response = llm_client.send_llm_request(
messages=messages,
llm_config=agent_state.llm_config,
tools=[t.json_schema for t in agent.agent_state.tools],
)
else:

View File

@@ -26,8 +26,8 @@ def llm_config():
@pytest.fixture
def anthropic_client(llm_config):
return AnthropicClient(llm_config=llm_config)
def anthropic_client():
return AnthropicClient()
@pytest.fixture