From d8d3d890732b1f0d27dfc62c1ee863c46f228a3f Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 23 Apr 2025 16:37:05 -0700 Subject: [PATCH] feat: add llm config per request (#1866) --- letta/agent.py | 3 ++- letta/agents/letta_agent.py | 3 ++- letta/agents/letta_agent_batch.py | 11 ++++++--- letta/llm_api/anthropic_client.py | 9 ++++---- letta/llm_api/google_ai_client.py | 17 +++++++------- letta/llm_api/google_vertex_client.py | 11 +++++---- letta/llm_api/llm_client.py | 22 +++++++----------- letta/llm_api/llm_client_base.py | 33 ++++++++++++++------------- letta/llm_api/openai_client.py | 27 +++++++++++----------- letta/memory.py | 3 ++- letta/schemas/enums.py | 3 +++ tests/helpers/endpoints_helper.py | 5 +++- tests/test_llm_clients.py | 4 ++-- 13 files changed, 82 insertions(+), 69 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 1d54ac278..a6742b257 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -332,13 +332,14 @@ class Agent(BaseAgent): log_telemetry(self.logger, "_get_ai_reply create start") # New LLM client flow llm_client = LLMClient.create( - llm_config=self.agent_state.llm_config, + provider=self.agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=put_inner_thoughts_first, ) if llm_client and not stream: response = llm_client.send_llm_request( messages=message_sequence, + llm_config=self.agent_state.llm_config, tools=allowed_functions, stream=stream, force_tool_call=force_tool_call, diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 5751acd01..35aad8116 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -66,7 +66,7 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - llm_config=agent_state.llm_config, + provider=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, ) for step in range(max_steps): @@ -182,6 +182,7 @@ class LettaAgent(BaseAgent): response = await llm_client.send_llm_request_async( messages=in_context_messages, + llm_config=agent_state.llm_config, tools=allowed_tools, force_tool_call=force_tool_call, stream=stream, diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 0f742fc14..a6d31a09c 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -156,7 +156,7 @@ class LettaAgentBatch: log_event(name="init_llm_client") llm_client = LLMClient.create( - llm_config=agent_states[0].llm_config, + provider=agent_states[0].llm_config.model_endpoint_type, put_inner_thoughts_first=True, ) agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states} @@ -272,9 +272,14 @@ class LettaAgentBatch: request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) # translate provider‑specific response → OpenAI‑style tool call (unchanged) - llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True) + llm_client = LLMClient.create( + provider=item.llm_config.model_endpoint_type, + put_inner_thoughts_first=True, + ) tool_call = ( - llm_client.convert_response_to_chat_completion(response_data=pr.message.model_dump(), input_messages=[]) + llm_client.convert_response_to_chat_completion( + response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config + ) .choices[0] .message.tool_calls[0] ) diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 4c79cb688..863fcef0d 100644 --- 
a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -43,18 +43,18 @@ logger = get_logger(__name__) class AnthropicClient(LLMClientBase): - def request(self, request_data: dict) -> dict: + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: client = self._get_anthropic_client(async_client=False) response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) return response.model_dump() - async def request_async(self, request_data: dict) -> dict: + async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: client = self._get_anthropic_client(async_client=True) response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) return response.model_dump() @trace_method - async def stream_async(self, request_data: dict) -> AsyncStream[BetaRawMessageStreamEvent]: + async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]: client = self._get_anthropic_client(async_client=True) request_data["stream"] = True return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) @@ -310,6 +310,7 @@ class AnthropicClient(LLMClientBase): self, response_data: dict, input_messages: List[PydanticMessage], + llm_config: LLMConfig, ) -> ChatCompletionResponse: """ Example response from Claude 3: @@ -411,7 +412,7 @@ class AnthropicClient(LLMClientBase): total_tokens=prompt_tokens + completion_tokens, ), ) - if self.llm_config.put_inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: chat_completion_response = unpack_all_inner_thoughts_from_kwargs( response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG ) diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index e471cb85e..dc20f1f45 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -25,15 +25,15 @@ logger = get_logger(__name__) class GoogleAIClient(LLMClientBase): - def request(self, request_data: dict) -> dict: + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. 
""" # print("[google_ai request]", json.dumps(request_data, indent=2)) url, headers = get_gemini_endpoint_and_headers( - base_url=str(self.llm_config.model_endpoint), - model=self.llm_config.model, + base_url=str(llm_config.model_endpoint), + model=llm_config.model, api_key=str(model_settings.gemini_api_key), key_in_header=True, generate_content=True, @@ -55,7 +55,7 @@ class GoogleAIClient(LLMClientBase): tool_objs = [Tool(**t) for t in tools] tool_names = [t.function.name for t in tool_objs] # Convert to the exact payload style Google expects - tools = self.convert_tools_to_google_ai_format(tool_objs) + tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config) else: tool_names = [] @@ -88,6 +88,7 @@ class GoogleAIClient(LLMClientBase): self, response_data: dict, input_messages: List[PydanticMessage], + llm_config: LLMConfig, ) -> ChatCompletionResponse: """ Converts custom response format from llm client into an OpenAI @@ -150,7 +151,7 @@ class GoogleAIClient(LLMClientBase): assert isinstance(function_args, dict), function_args # NOTE: this also involves stripping the inner monologue out of the function - if self.llm_config.put_inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: from letta.local_llm.constants import INNER_THOUGHTS_KWARG assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}" @@ -259,14 +260,14 @@ class GoogleAIClient(LLMClientBase): return ChatCompletionResponse( id=response_id, choices=choices, - model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response + model=llm_config.model, # NOTE: Google API doesn't pass back model in the response created=get_utc_time_int(), usage=usage, ) except KeyError as e: raise e - def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]: + def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]: """ OpenAI style: "tools": [{ @@ -326,7 +327,7 @@ class GoogleAIClient(LLMClientBase): # Note: Google AI API used to have weird casing requirements, but not any more # Add inner thoughts - if self.llm_config.put_inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = { diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index f22faebad..76a84daf3 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -18,7 +18,7 @@ from letta.utils import get_tool_call_id class GoogleVertexClient(GoogleAIClient): - def request(self, request_data: dict) -> dict: + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. """ @@ -29,7 +29,7 @@ class GoogleVertexClient(GoogleAIClient): http_options={"api_version": "v1"}, ) response = client.models.generate_content( - model=self.llm_config.model, + model=llm_config.model, contents=request_data["contents"], config=request_data["config"], ) @@ -45,7 +45,7 @@ class GoogleVertexClient(GoogleAIClient): """ Constructs a request object in the expected data format for this client. 
""" - request_data = super().build_request_data(messages, self.llm_config, tools, force_tool_call) + request_data = super().build_request_data(messages, llm_config, tools, force_tool_call) request_data["config"] = request_data.pop("generation_config") request_data["config"]["tools"] = request_data.pop("tools") @@ -66,6 +66,7 @@ class GoogleVertexClient(GoogleAIClient): self, response_data: dict, input_messages: List[PydanticMessage], + llm_config: LLMConfig, ) -> ChatCompletionResponse: """ Converts custom response format from llm client into an OpenAI @@ -127,7 +128,7 @@ class GoogleVertexClient(GoogleAIClient): assert isinstance(function_args, dict), function_args # NOTE: this also involves stripping the inner monologue out of the function - if self.llm_config.put_inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: from letta.local_llm.constants import INNER_THOUGHTS_KWARG assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}" @@ -224,7 +225,7 @@ class GoogleVertexClient(GoogleAIClient): return ChatCompletionResponse( id=response_id, choices=choices, - model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response + model=llm_config.model, # NOTE: Google API doesn't pass back model in the response created=get_utc_time_int(), usage=usage, ) diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index e4f07db07..674f94974 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -1,7 +1,7 @@ from typing import Optional from letta.llm_api.llm_client_base import LLMClientBase -from letta.schemas.llm_config import LLMConfig +from letta.schemas.enums import ProviderType class LLMClient: @@ -9,17 +9,15 @@ class LLMClient: @staticmethod def create( - llm_config: LLMConfig, + provider: ProviderType, put_inner_thoughts_first: bool = True, ) -> Optional[LLMClientBase]: """ Create an LLM client based on the model endpoint type. 
Args: - llm_config: Configuration for the LLM model + provider: The model endpoint type put_inner_thoughts_first: Whether to put inner thoughts first in the response - use_structured_output: Whether to use structured output - use_tool_naming: Whether to use tool naming Returns: An instance of LLMClientBase subclass @@ -27,33 +25,29 @@ class LLMClient: Raises: ValueError: If the model endpoint type is not supported """ - match llm_config.model_endpoint_type: - case "google_ai": + match provider: + case ProviderType.google_ai: from letta.llm_api.google_ai_client import GoogleAIClient return GoogleAIClient( - llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, ) - case "google_vertex": + case ProviderType.google_vertex: from letta.llm_api.google_vertex_client import GoogleVertexClient return GoogleVertexClient( - llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, ) - case "anthropic": + case ProviderType.anthropic: from letta.llm_api.anthropic_client import AnthropicClient return AnthropicClient( - llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, ) - case "openai": + case ProviderType.openai: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( - llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, ) case _: diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 12cf2fec2..5c7dcab9e 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -20,17 +20,16 @@ class LLMClientBase: def __init__( self, - llm_config: LLMConfig, put_inner_thoughts_first: Optional[bool] = True, use_tool_naming: bool = True, ): - self.llm_config = llm_config self.put_inner_thoughts_first = put_inner_thoughts_first self.use_tool_naming = use_tool_naming def send_llm_request( self, messages: List[Message], + llm_config: LLMConfig, tools: Optional[List[dict]] = None, # TODO: change to Tool object stream: bool = False, force_tool_call: Optional[str] = None, @@ -40,23 +39,24 @@ class LLMClientBase: If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over. Otherwise returns a ChatCompletionResponse. """ - request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call) + request_data = self.build_request_data(messages, llm_config, tools, force_tool_call) try: log_event(name="llm_request_sent", attributes=request_data) if stream: - return self.stream(request_data) + return self.stream(request_data, llm_config) else: - response_data = self.request(request_data) + response_data = self.request(request_data, llm_config) log_event(name="llm_response_received", attributes=response_data) except Exception as e: raise self.handle_llm_error(e) - return self.convert_response_to_chat_completion(response_data, messages) + return self.convert_response_to_chat_completion(response_data, messages, llm_config) async def send_llm_request_async( self, messages: List[Message], + llm_config: LLMConfig, tools: Optional[List[dict]] = None, # TODO: change to Tool object stream: bool = False, force_tool_call: Optional[str] = None, @@ -66,19 +66,19 @@ class LLMClientBase: If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over. Otherwise returns a ChatCompletionResponse. 
""" - request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call) + request_data = self.build_request_data(messages, llm_config, tools, force_tool_call) try: log_event(name="llm_request_sent", attributes=request_data) if stream: - return await self.stream_async(request_data) + return await self.stream_async(request_data, llm_config) else: - response_data = await self.request_async(request_data) + response_data = await self.request_async(request_data, llm_config) log_event(name="llm_response_received", attributes=response_data) except Exception as e: raise self.handle_llm_error(e) - return self.convert_response_to_chat_completion(response_data, messages) + return self.convert_response_to_chat_completion(response_data, messages, llm_config) async def send_llm_batch_request_async( self, @@ -102,14 +102,14 @@ class LLMClientBase: raise NotImplementedError @abstractmethod - def request(self, request_data: dict) -> dict: + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. """ raise NotImplementedError @abstractmethod - async def request_async(self, request_data: dict) -> dict: + async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. """ @@ -120,6 +120,7 @@ class LLMClientBase: self, response_data: dict, input_messages: List[Message], + llm_config: LLMConfig, ) -> ChatCompletionResponse: """ Converts custom response format from llm client into an OpenAI @@ -128,18 +129,18 @@ class LLMClientBase: raise NotImplementedError @abstractmethod - def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]: + def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]: """ Performs underlying streaming request to llm and returns raw response. """ - raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}") + raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}") @abstractmethod - async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]: + async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: """ Performs underlying streaming request to llm and returns raw response. 
""" - raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}") + raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}") @abstractmethod def handle_llm_error(self, e: Exception) -> Exception: diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index ada788caa..5639f884b 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -62,11 +62,11 @@ def supports_parallel_tool_calling(model: str) -> bool: class OpenAIClient(LLMClientBase): - def _prepare_client_kwargs(self) -> dict: + def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict: api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") # supposedly the openai python client requires a dummy API key api_key = api_key or "DUMMY_API_KEY" - kwargs = {"api_key": api_key, "base_url": self.llm_config.model_endpoint} + kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint} return kwargs @@ -115,7 +115,7 @@ class OpenAIClient(LLMClientBase): # TODO(matt) move into LLMConfig # TODO: This vllm checking is very brittle and is a patch at most tool_choice = None - if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in self.llm_config.handle): + if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle): tool_choice = "auto" # TODO change to "required" once proxy supports it elif tools: # only set if tools is non-Null @@ -152,20 +152,20 @@ class OpenAIClient(LLMClientBase): return data.model_dump(exclude_unset=True) - def request(self, request_data: dict) -> dict: + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying synchronous request to OpenAI API and returns raw response dict. """ - client = OpenAI(**self._prepare_client_kwargs()) + client = OpenAI(**self._prepare_client_kwargs(llm_config)) response: ChatCompletion = client.chat.completions.create(**request_data) return response.model_dump() - async def request_async(self, request_data: dict) -> dict: + async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying asynchronous request to OpenAI API and returns raw response dict. """ - client = AsyncOpenAI(**self._prepare_client_kwargs()) + client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config)) response: ChatCompletion = await client.chat.completions.create(**request_data) return response.model_dump() @@ -173,6 +173,7 @@ class OpenAIClient(LLMClientBase): self, response_data: dict, input_messages: List[PydanticMessage], # Included for consistency, maybe used later + llm_config: LLMConfig, ) -> ChatCompletionResponse: """ Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model. 
@@ -183,30 +184,30 @@ class OpenAIClient(LLMClientBase): chat_completion_response = ChatCompletionResponse(**response_data) # Unpack inner thoughts if they were embedded in function arguments - if self.llm_config.put_inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: chat_completion_response = unpack_all_inner_thoughts_from_kwargs( response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG ) # If we used a reasoning model, create a content part for the ommitted reasoning - if is_openai_reasoning_model(self.llm_config.model): + if is_openai_reasoning_model(llm_config.model): chat_completion_response.choices[0].message.ommitted_reasoning_content = True return chat_completion_response - def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]: + def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]: """ Performs underlying streaming request to OpenAI and returns the stream iterator. """ - client = OpenAI(**self._prepare_client_kwargs()) + client = OpenAI(**self._prepare_client_kwargs(llm_config)) response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True) return response_stream - async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]: + async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: """ Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator. """ - client = AsyncOpenAI(**self._prepare_client_kwargs()) + client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config)) response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True) return response_stream diff --git a/letta/memory.py b/letta/memory.py index 8554709e7..6d29963f0 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -79,7 +79,7 @@ def summarize_messages( llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False llm_client = LLMClient.create( - llm_config=llm_config_no_inner_thoughts, + provider=llm_config_no_inner_thoughts.model_endpoint_type, put_inner_thoughts_first=False, ) # try to use new client, otherwise fallback to old flow @@ -87,6 +87,7 @@ def summarize_messages( if llm_client: response = llm_client.send_llm_request( messages=message_sequence, + llm_config=llm_config_no_inner_thoughts, stream=False, ) else: diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index f4d1aef61..c1d54d776 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -3,6 +3,9 @@ from enum import Enum class ProviderType(str, Enum): anthropic = "anthropic" + google_ai = "google_ai" + google_vertex = "google_vertex" + openai = "openai" class MessageRole(str, Enum): diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index ebe7c2b82..9b8f9a9f1 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -104,10 +104,13 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user) agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) - llm_client = LLMClient.create(llm_config=agent_state.llm_config) + llm_client = LLMClient.create( + provider=agent_state.llm_config.model_endpoint_type, + ) if llm_client: response = llm_client.send_llm_request( messages=messages, + 
llm_config=agent_state.llm_config, tools=[t.json_schema for t in agent.agent_state.tools], ) else: diff --git a/tests/test_llm_clients.py b/tests/test_llm_clients.py index c33468dfc..7eabb864e 100644 --- a/tests/test_llm_clients.py +++ b/tests/test_llm_clients.py @@ -26,8 +26,8 @@ def llm_config(): @pytest.fixture -def anthropic_client(llm_config): - return AnthropicClient(llm_config=llm_config) +def anthropic_client(): + return AnthropicClient() @pytest.fixture
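
For reviewers, a minimal usage sketch of the API shape after this patch: LLMClient.create is now keyed by provider only, and the LLMConfig travels with each request instead of being stored on the client. The config values and the empty message list below are placeholders for illustration, not taken from this diff, and the exact set of LLMConfig fields may differ.

    from letta.llm_api.llm_client import LLMClient
    from letta.schemas.enums import ProviderType
    from letta.schemas.llm_config import LLMConfig

    # Placeholder config; in the agent code this comes from agent_state.llm_config.
    config = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
    )

    # The client carries no per-agent config, only provider-level behavior.
    llm_client = LLMClient.create(
        provider=ProviderType.openai,
        put_inner_thoughts_first=True,
    )

    if llm_client:
        response = llm_client.send_llm_request(
            messages=[],          # List[Message]; the agent passes its in-context messages here
            llm_config=config,    # per-request config, introduced by this patch
            tools=None,           # optional JSON-schema tool definitions
            stream=False,
        )

Because the client no longer holds an LLMConfig, a single instance can serve agents with different configs, which is what LettaAgentBatch relies on when it builds the agent-id-to-llm_config mapping and converts each batch item's response with llm_config=item.llm_config.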
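
For anyone adding a new provider after this change, a skeletal subclass under the updated stateless signatures might look like the sketch below. Only the method signatures follow this diff; the class name, import paths, and payload shape are assumptions for illustration.

    from typing import List, Optional

    from letta.llm_api.llm_client_base import LLMClientBase
    from letta.schemas.llm_config import LLMConfig
    from letta.schemas.message import Message
    from letta.schemas.openai.chat_completion_response import ChatCompletionResponse

    class ExampleProviderClient(LLMClientBase):
        # No LLMConfig is stored on the instance anymore; every hook that needs
        # provider/model details receives the config as an argument.

        def build_request_data(self, messages: List[Message], llm_config: LLMConfig,
                               tools: Optional[List[dict]] = None,
                               force_tool_call: Optional[str] = None) -> dict:
            # Build the provider payload from the per-request config (hypothetical shape).
            return {"model": llm_config.model, "messages": [m.model_dump() for m in messages]}

        def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
            # Synchronous call against the endpoint taken from llm_config.model_endpoint.
            raise NotImplementedError

        async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
            raise NotImplementedError

        def convert_response_to_chat_completion(self, response_data: dict,
                                                input_messages: List[Message],
                                                llm_config: LLMConfig) -> ChatCompletionResponse:
            # Map the raw response to the OpenAI-style schema, consulting llm_config
            # (e.g. put_inner_thoughts_in_kwargs) the same way the built-in clients do.
            raise NotImplementedError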