Mirror of https://github.com/cpacker/MemGPT.git, synced 2025-06-03 04:30:22 +00:00
feat: add llm config per request (#1866)
parent 51afbcb57e
commit d8d3d89073
@@ -332,13 +332,14 @@ class Agent(BaseAgent):
         log_telemetry(self.logger, "_get_ai_reply create start")
         # New LLM client flow
         llm_client = LLMClient.create(
-            llm_config=self.agent_state.llm_config,
+            provider=self.agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=put_inner_thoughts_first,
         )

         if llm_client and not stream:
             response = llm_client.send_llm_request(
                 messages=message_sequence,
+                llm_config=self.agent_state.llm_config,
                 tools=allowed_functions,
                 stream=stream,
                 force_tool_call=force_tool_call,
@@ -66,7 +66,7 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            llm_config=agent_state.llm_config,
+            provider=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         for step in range(max_steps):
@@ -182,6 +182,7 @@ class LettaAgent(BaseAgent):

             response = await llm_client.send_llm_request_async(
                 messages=in_context_messages,
+                llm_config=agent_state.llm_config,
                 tools=allowed_tools,
                 force_tool_call=force_tool_call,
                 stream=stream,
@@ -156,7 +156,7 @@ class LettaAgentBatch:

         log_event(name="init_llm_client")
         llm_client = LLMClient.create(
-            llm_config=agent_states[0].llm_config,
+            provider=agent_states[0].llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@@ -272,9 +272,14 @@ class LettaAgentBatch:
             request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))

             # translate provider‑specific response → OpenAI‑style tool call (unchanged)
-            llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True)
+            llm_client = LLMClient.create(
+                provider=item.llm_config.model_endpoint_type,
+                put_inner_thoughts_first=True,
+            )
             tool_call = (
-                llm_client.convert_response_to_chat_completion(response_data=pr.message.model_dump(), input_messages=[])
+                llm_client.convert_response_to_chat_completion(
+                    response_data=pr.message.model_dump(), input_messages=[], llm_config=item.llm_config
+                )
                 .choices[0]
                 .message.tool_calls[0]
             )
@@ -43,18 +43,18 @@ logger = get_logger(__name__)

 class AnthropicClient(LLMClientBase):

-    def request(self, request_data: dict) -> dict:
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         client = self._get_anthropic_client(async_client=False)
         response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
         return response.model_dump()

-    async def request_async(self, request_data: dict) -> dict:
+    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
         client = self._get_anthropic_client(async_client=True)
         response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
         return response.model_dump()

     @trace_method
-    async def stream_async(self, request_data: dict) -> AsyncStream[BetaRawMessageStreamEvent]:
+    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
         client = self._get_anthropic_client(async_client=True)
         request_data["stream"] = True
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@@ -310,6 +310,7 @@ class AnthropicClient(LLMClientBase):
         self,
         response_data: dict,
         input_messages: List[PydanticMessage],
+        llm_config: LLMConfig,
     ) -> ChatCompletionResponse:
         """
         Example response from Claude 3:
@@ -411,7 +412,7 @@ class AnthropicClient(LLMClientBase):
                 total_tokens=prompt_tokens + completion_tokens,
             ),
         )
-        if self.llm_config.put_inner_thoughts_in_kwargs:
+        if llm_config.put_inner_thoughts_in_kwargs:
            chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
                response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
            )
@@ -25,15 +25,15 @@ logger = get_logger(__name__)

 class GoogleAIClient(LLMClientBase):

-    def request(self, request_data: dict) -> dict:
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying request to llm and returns raw response.
         """
         # print("[google_ai request]", json.dumps(request_data, indent=2))

         url, headers = get_gemini_endpoint_and_headers(
-            base_url=str(self.llm_config.model_endpoint),
-            model=self.llm_config.model,
+            base_url=str(llm_config.model_endpoint),
+            model=llm_config.model,
             api_key=str(model_settings.gemini_api_key),
             key_in_header=True,
             generate_content=True,
@@ -55,7 +55,7 @@ class GoogleAIClient(LLMClientBase):
             tool_objs = [Tool(**t) for t in tools]
             tool_names = [t.function.name for t in tool_objs]
             # Convert to the exact payload style Google expects
-            tools = self.convert_tools_to_google_ai_format(tool_objs)
+            tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
         else:
             tool_names = []

@@ -88,6 +88,7 @@ class GoogleAIClient(LLMClientBase):
         self,
         response_data: dict,
         input_messages: List[PydanticMessage],
+        llm_config: LLMConfig,
     ) -> ChatCompletionResponse:
         """
         Converts custom response format from llm client into an OpenAI
@@ -150,7 +151,7 @@ class GoogleAIClient(LLMClientBase):
                     assert isinstance(function_args, dict), function_args

                     # NOTE: this also involves stripping the inner monologue out of the function
-                    if self.llm_config.put_inner_thoughts_in_kwargs:
+                    if llm_config.put_inner_thoughts_in_kwargs:
                         from letta.local_llm.constants import INNER_THOUGHTS_KWARG

                         assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
@@ -259,14 +260,14 @@ class GoogleAIClient(LLMClientBase):
             return ChatCompletionResponse(
                 id=response_id,
                 choices=choices,
-                model=self.llm_config.model,  # NOTE: Google API doesn't pass back model in the response
+                model=llm_config.model,  # NOTE: Google API doesn't pass back model in the response
                 created=get_utc_time_int(),
                 usage=usage,
             )
         except KeyError as e:
             raise e

-    def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]:
+    def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
         """
         OpenAI style:
           "tools": [{
@@ -326,7 +327,7 @@ class GoogleAIClient(LLMClientBase):
             # Note: Google AI API used to have weird casing requirements, but not any more

             # Add inner thoughts
-            if self.llm_config.put_inner_thoughts_in_kwargs:
+            if llm_config.put_inner_thoughts_in_kwargs:
                 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION

                 func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
@@ -18,7 +18,7 @@ from letta.utils import get_tool_call_id

 class GoogleVertexClient(GoogleAIClient):

-    def request(self, request_data: dict) -> dict:
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying request to llm and returns raw response.
         """
@@ -29,7 +29,7 @@ class GoogleVertexClient(GoogleAIClient):
             http_options={"api_version": "v1"},
         )
         response = client.models.generate_content(
-            model=self.llm_config.model,
+            model=llm_config.model,
             contents=request_data["contents"],
             config=request_data["config"],
         )
@@ -45,7 +45,7 @@ class GoogleVertexClient(GoogleAIClient):
         """
         Constructs a request object in the expected data format for this client.
         """
-        request_data = super().build_request_data(messages, self.llm_config, tools, force_tool_call)
+        request_data = super().build_request_data(messages, llm_config, tools, force_tool_call)
         request_data["config"] = request_data.pop("generation_config")
         request_data["config"]["tools"] = request_data.pop("tools")

@@ -66,6 +66,7 @@ class GoogleVertexClient(GoogleAIClient):
         self,
         response_data: dict,
         input_messages: List[PydanticMessage],
+        llm_config: LLMConfig,
     ) -> ChatCompletionResponse:
         """
         Converts custom response format from llm client into an OpenAI
@@ -127,7 +128,7 @@ class GoogleVertexClient(GoogleAIClient):
                     assert isinstance(function_args, dict), function_args

                     # NOTE: this also involves stripping the inner monologue out of the function
-                    if self.llm_config.put_inner_thoughts_in_kwargs:
+                    if llm_config.put_inner_thoughts_in_kwargs:
                         from letta.local_llm.constants import INNER_THOUGHTS_KWARG

                         assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
@@ -224,7 +225,7 @@ class GoogleVertexClient(GoogleAIClient):
         return ChatCompletionResponse(
             id=response_id,
             choices=choices,
-            model=self.llm_config.model,  # NOTE: Google API doesn't pass back model in the response
+            model=llm_config.model,  # NOTE: Google API doesn't pass back model in the response
             created=get_utc_time_int(),
             usage=usage,
         )
@@ -1,7 +1,7 @@
 from typing import Optional

 from letta.llm_api.llm_client_base import LLMClientBase
-from letta.schemas.llm_config import LLMConfig
+from letta.schemas.enums import ProviderType


 class LLMClient:
@@ -9,17 +9,15 @@ class LLMClient:

     @staticmethod
     def create(
-        llm_config: LLMConfig,
+        provider: ProviderType,
         put_inner_thoughts_first: bool = True,
     ) -> Optional[LLMClientBase]:
         """
         Create an LLM client based on the model endpoint type.

         Args:
-            llm_config: Configuration for the LLM model
+            provider: The model endpoint type
             put_inner_thoughts_first: Whether to put inner thoughts first in the response
-            use_structured_output: Whether to use structured output
-            use_tool_naming: Whether to use tool naming

         Returns:
             An instance of LLMClientBase subclass
@@ -27,33 +25,29 @@ class LLMClient:
         Raises:
             ValueError: If the model endpoint type is not supported
         """
-        match llm_config.model_endpoint_type:
-            case "google_ai":
+        match provider:
+            case ProviderType.google_ai:
                 from letta.llm_api.google_ai_client import GoogleAIClient

                 return GoogleAIClient(
-                    llm_config=llm_config,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
-            case "google_vertex":
+            case ProviderType.google_vertex:
                 from letta.llm_api.google_vertex_client import GoogleVertexClient

                 return GoogleVertexClient(
-                    llm_config=llm_config,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
-            case "anthropic":
+            case ProviderType.anthropic:
                 from letta.llm_api.anthropic_client import AnthropicClient

                 return AnthropicClient(
-                    llm_config=llm_config,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
-            case "openai":
+            case ProviderType.openai:
                 from letta.llm_api.openai_client import OpenAIClient

                 return OpenAIClient(
-                    llm_config=llm_config,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case _:
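Note: taken together, the factory hunks above mean a client is now selected purely by provider, while the LLMConfig travels with each request. A minimal sketch of the new calling convention follows; agent_state, message_sequence, and allowed_functions are placeholders that mirror the call sites elsewhere in this diff and are not defined here.

from letta.llm_api.llm_client import LLMClient

# Construction is keyed only by provider; no config is stored on the client.
llm_client = LLMClient.create(
    provider=agent_state.llm_config.model_endpoint_type,
    put_inner_thoughts_first=True,
)

# The LLMConfig is supplied per request instead of at construction time.
if llm_client:
    response = llm_client.send_llm_request(
        messages=message_sequence,
        llm_config=agent_state.llm_config,
        tools=allowed_functions,
        stream=False,
    )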
@@ -20,17 +20,16 @@ class LLMClientBase:

     def __init__(
         self,
-        llm_config: LLMConfig,
         put_inner_thoughts_first: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
-        self.llm_config = llm_config
         self.put_inner_thoughts_first = put_inner_thoughts_first
         self.use_tool_naming = use_tool_naming

     def send_llm_request(
         self,
         messages: List[Message],
+        llm_config: LLMConfig,
         tools: Optional[List[dict]] = None,  # TODO: change to Tool object
         stream: bool = False,
         force_tool_call: Optional[str] = None,
@@ -40,23 +39,24 @@ class LLMClientBase:
         If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
         Otherwise returns a ChatCompletionResponse.
         """
-        request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call)
+        request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)

         try:
             log_event(name="llm_request_sent", attributes=request_data)
             if stream:
-                return self.stream(request_data)
+                return self.stream(request_data, llm_config)
             else:
-                response_data = self.request(request_data)
+                response_data = self.request(request_data, llm_config)
                 log_event(name="llm_response_received", attributes=response_data)
         except Exception as e:
             raise self.handle_llm_error(e)

-        return self.convert_response_to_chat_completion(response_data, messages)
+        return self.convert_response_to_chat_completion(response_data, messages, llm_config)

     async def send_llm_request_async(
         self,
         messages: List[Message],
+        llm_config: LLMConfig,
         tools: Optional[List[dict]] = None,  # TODO: change to Tool object
         stream: bool = False,
         force_tool_call: Optional[str] = None,
@@ -66,19 +66,19 @@ class LLMClientBase:
         If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be async iterated over.
         Otherwise returns a ChatCompletionResponse.
         """
-        request_data = self.build_request_data(messages, self.llm_config, tools, force_tool_call)
+        request_data = self.build_request_data(messages, llm_config, tools, force_tool_call)

         try:
             log_event(name="llm_request_sent", attributes=request_data)
             if stream:
-                return await self.stream_async(request_data)
+                return await self.stream_async(request_data, llm_config)
             else:
-                response_data = await self.request_async(request_data)
+                response_data = await self.request_async(request_data, llm_config)
                 log_event(name="llm_response_received", attributes=response_data)
         except Exception as e:
             raise self.handle_llm_error(e)

-        return self.convert_response_to_chat_completion(response_data, messages)
+        return self.convert_response_to_chat_completion(response_data, messages, llm_config)

     async def send_llm_batch_request_async(
         self,
@@ -102,14 +102,14 @@ class LLMClientBase:
         raise NotImplementedError

     @abstractmethod
-    def request(self, request_data: dict) -> dict:
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying request to llm and returns raw response.
         """
         raise NotImplementedError

     @abstractmethod
-    async def request_async(self, request_data: dict) -> dict:
+    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying request to llm and returns raw response.
         """
@@ -120,6 +120,7 @@ class LLMClientBase:
         self,
         response_data: dict,
         input_messages: List[Message],
+        llm_config: LLMConfig,
     ) -> ChatCompletionResponse:
         """
         Converts custom response format from llm client into an OpenAI
@@ -128,18 +129,18 @@ class LLMClientBase:
         raise NotImplementedError

     @abstractmethod
-    def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
+    def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
         """
         Performs underlying streaming request to llm and returns raw response.
         """
-        raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")
+        raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")

     @abstractmethod
-    async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
+    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
         """
         Performs underlying streaming request to llm and returns raw response.
         """
-        raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")
+        raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")

     @abstractmethod
     def handle_llm_error(self, e: Exception) -> Exception:
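Note: with llm_config removed from __init__ and threaded through the per-request methods, a provider subclass now receives the config on every call rather than reading self.llm_config. A hypothetical, stripped-down subclass illustrating the updated contract; EchoClient is not part of the codebase, and build_request_data and the other abstract methods are omitted.

from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.llm_config import LLMConfig


class EchoClient(LLMClientBase):
    """Illustrative only: shows where llm_config now arrives."""

    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        # A real client would call its provider SDK here; llm_config carries the
        # model/endpoint details for this request only.
        return {"model": llm_config.model, "echo": request_data}

    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        return self.request(request_data, llm_config)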
@@ -62,11 +62,11 @@ def supports_parallel_tool_calling(model: str) -> bool:


 class OpenAIClient(LLMClientBase):
-    def _prepare_client_kwargs(self) -> dict:
+    def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
         api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
         # supposedly the openai python client requires a dummy API key
         api_key = api_key or "DUMMY_API_KEY"
-        kwargs = {"api_key": api_key, "base_url": self.llm_config.model_endpoint}
+        kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}

         return kwargs

@@ -115,7 +115,7 @@ class OpenAIClient(LLMClientBase):
         # TODO(matt) move into LLMConfig
         # TODO: This vllm checking is very brittle and is a patch at most
         tool_choice = None
-        if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in self.llm_config.handle):
+        if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle):
             tool_choice = "auto"  # TODO change to "required" once proxy supports it
         elif tools:
             # only set if tools is non-Null
@@ -152,20 +152,20 @@ class OpenAIClient(LLMClientBase):

         return data.model_dump(exclude_unset=True)

-    def request(self, request_data: dict) -> dict:
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying synchronous request to OpenAI API and returns raw response dict.
         """
-        client = OpenAI(**self._prepare_client_kwargs())
+        client = OpenAI(**self._prepare_client_kwargs(llm_config))

         response: ChatCompletion = client.chat.completions.create(**request_data)
         return response.model_dump()

-    async def request_async(self, request_data: dict) -> dict:
+    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying asynchronous request to OpenAI API and returns raw response dict.
         """
-        client = AsyncOpenAI(**self._prepare_client_kwargs())
+        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
         response: ChatCompletion = await client.chat.completions.create(**request_data)
         return response.model_dump()

@@ -173,6 +173,7 @@ class OpenAIClient(LLMClientBase):
         self,
         response_data: dict,
         input_messages: List[PydanticMessage],  # Included for consistency, maybe used later
+        llm_config: LLMConfig,
     ) -> ChatCompletionResponse:
         """
         Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
@@ -183,30 +184,30 @@ class OpenAIClient(LLMClientBase):
         chat_completion_response = ChatCompletionResponse(**response_data)

         # Unpack inner thoughts if they were embedded in function arguments
-        if self.llm_config.put_inner_thoughts_in_kwargs:
+        if llm_config.put_inner_thoughts_in_kwargs:
             chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
                 response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
             )

         # If we used a reasoning model, create a content part for the ommitted reasoning
-        if is_openai_reasoning_model(self.llm_config.model):
+        if is_openai_reasoning_model(llm_config.model):
             chat_completion_response.choices[0].message.ommitted_reasoning_content = True

         return chat_completion_response

-    def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
+    def stream(self, request_data: dict, llm_config: LLMConfig) -> Stream[ChatCompletionChunk]:
         """
         Performs underlying streaming request to OpenAI and returns the stream iterator.
         """
-        client = OpenAI(**self._prepare_client_kwargs())
+        client = OpenAI(**self._prepare_client_kwargs(llm_config))
         response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
         return response_stream

-    async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
+    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
         """
         Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
         """
-        client = AsyncOpenAI(**self._prepare_client_kwargs())
+        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
         response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True)
         return response_stream

@@ -79,7 +79,7 @@ def summarize_messages(
     llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False

     llm_client = LLMClient.create(
-        llm_config=llm_config_no_inner_thoughts,
+        provider=llm_config_no_inner_thoughts.model_endpoint_type,
         put_inner_thoughts_first=False,
     )
     # try to use new client, otherwise fallback to old flow
@@ -87,6 +87,7 @@ def summarize_messages(
     if llm_client:
         response = llm_client.send_llm_request(
             messages=message_sequence,
+            llm_config=llm_config_no_inner_thoughts,
             stream=False,
         )
     else:
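Note: the summarizer hunk above is the motivating case for a per-request config, since it hands a modified copy of the agent's config to the client for a single call. A sketch under stated assumptions; the deepcopy, agent_state, and message_sequence names are illustrative, as the diff only shows the flag being cleared and the request being sent.

from copy import deepcopy

from letta.llm_api.llm_client import LLMClient

# Override a single field for this call only; the agent's own config is untouched.
llm_config_no_inner_thoughts = deepcopy(agent_state.llm_config)
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False

llm_client = LLMClient.create(
    provider=llm_config_no_inner_thoughts.model_endpoint_type,
    put_inner_thoughts_first=False,
)
if llm_client:
    response = llm_client.send_llm_request(
        messages=message_sequence,
        llm_config=llm_config_no_inner_thoughts,
        stream=False,
    )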
@@ -3,6 +3,9 @@ from enum import Enum

 class ProviderType(str, Enum):
     anthropic = "anthropic"
+    google_ai = "google_ai"
+    google_vertex = "google_vertex"
+    openai = "openai"


 class MessageRole(str, Enum):
@@ -104,10 +104,13 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
     messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
     agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)

-    llm_client = LLMClient.create(llm_config=agent_state.llm_config)
+    llm_client = LLMClient.create(
+        provider=agent_state.llm_config.model_endpoint_type,
+    )
     if llm_client:
         response = llm_client.send_llm_request(
             messages=messages,
+            llm_config=agent_state.llm_config,
             tools=[t.json_schema for t in agent.agent_state.tools],
         )
     else:
@@ -26,8 +26,8 @@ def llm_config():


 @pytest.fixture
-def anthropic_client(llm_config):
-    return AnthropicClient(llm_config=llm_config)
+def anthropic_client():
+    return AnthropicClient()


 @pytest.fixture