fix: refactor Google AI Provider / helper functions and add endpoint test (#1850)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
Authored by Matthew Zhou on 2024-10-08 16:55:11 -07:00, committed by GitHub
parent 51ad4ddac3
commit bc2c0b2482
9 changed files with 127 additions and 125 deletions

View File

@@ -76,7 +76,7 @@ class LettaCredentials:
             "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
             # gemini
             "google_ai_key": get_field(config, "google_ai", "key"),
-            "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
+            # "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
             # anthropic
             "anthropic_key": get_field(config, "anthropic", "key"),
             # cohere
@@ -117,7 +117,7 @@ class LettaCredentials:
         # gemini
         set_field(config, "google_ai", "key", self.google_ai_key)
-        set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)
+        # set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)
         # anthropic
         set_field(config, "anthropic", "key", self.anthropic_key)

View File

@@ -1,9 +1,10 @@
 import uuid
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import requests

 from letta.constants import NON_USER_MSG_PREFIX
+from letta.llm_api.helpers import make_post_request
 from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
 from letta.schemas.openai.chat_completion_request import Tool
@@ -15,27 +16,41 @@ from letta.schemas.openai.chat_completion_response import (
     ToolCall,
     UsageStatistics,
 )
-from letta.utils import get_tool_call_id, get_utc_time
+from letta.utils import get_tool_call_id, get_utc_time, json_dumps

-# from letta.data_types import ToolCall

-SUPPORTED_MODELS = [
-    "gemini-pro",
-]
+def get_gemini_endpoint_and_headers(
+    base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False
+) -> Tuple[str, dict]:
+    """
+    Dynamically generate the model endpoint and headers.
+    """
+    url = f"{base_url}/v1beta/models"

+    # Add the model
+    if model is not None:
+        url += f"/{model}"

-def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
-    from letta.utils import printd
+    # Add extension for generating content if we're hitting the LM
+    if generate_content:
+        url += ":generateContent"

+    # Decide if api key should be in header or not
     # Two ways to pass the key: https://ai.google.dev/tutorials/setup
     if key_in_header:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}"
         headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
     else:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}?key={api_key}"
+        url += f"?key={api_key}"
         headers = {"Content-Type": "application/json"}

+    return url, headers
+
+
+def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
+    from letta.utils import printd
+
+    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)

     try:
         response = requests.get(url, headers=headers)
         printd(f"response = {response}")
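
For reference, a minimal sketch of the URL/header combinations the new helper produces, grounded in the function body above (the key value is illustrative):

    from letta.llm_api.google_ai import get_gemini_endpoint_and_headers

    base = "https://generativelanguage.googleapis.com"

    # Model listing: no model, key in header (the default)
    url, headers = get_gemini_endpoint_and_headers(base, None, "MY_KEY")
    # url == "https://generativelanguage.googleapis.com/v1beta/models"
    # headers == {"Content-Type": "application/json", "x-goog-api-key": "MY_KEY"}

    # Content generation: model set, key in the query string
    url, headers = get_gemini_endpoint_and_headers(
        base, "gemini-1.5-pro-latest", "MY_KEY", key_in_header=False, generate_content=True
    )
    # url ends with "/v1beta/models/gemini-1.5-pro-latest:generateContent?key=MY_KEY"
    # headers == {"Content-Type": "application/json"}
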
@@ -66,25 +81,17 @@ def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str,
         raise e


-def google_ai_get_model_context_window(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> int:
-    model_details = google_ai_get_model_details(
-        service_endpoint=service_endpoint, api_key=api_key, model=model, key_in_header=key_in_header
-    )
+def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
+    model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
     # TODO should this be:
     # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
     return int(model_details["inputTokenLimit"])


-def google_ai_get_model_list(service_endpoint: str, api_key: str, key_in_header: bool = True) -> List[dict]:
+def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
     from letta.utils import printd

-    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
-    if key_in_header:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models"
-        headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
-    else:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models?key={api_key}"
-        headers = {"Content-Type": "application/json"}
+    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)

     try:
         response = requests.get(url, headers=headers)
@@ -396,7 +403,7 @@ def convert_google_ai_response_to_chatcompletion(
 # TODO convert 'data' type to pydantic
 def google_ai_chat_completions_request(
-    service_endpoint: str,
+    base_url: str,
     model: str,
     api_key: str,
     data: dict,
@@ -414,55 +421,23 @@ def google_ai_chat_completions_request(
     This service has the following service endpoint and all URIs below are relative to this service endpoint:
     https://xxx.googleapis.com
     """
-    from letta.utils import printd

-    assert service_endpoint is not None, "Missing service_endpoint when calling Google AI"
     assert api_key is not None, "Missing api_key when calling Google AI"
-    assert model in SUPPORTED_MODELS, f"Model '{model}' not in supported models: {', '.join(SUPPORTED_MODELS)}"

-    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
-    if key_in_header:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent"
-        headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
-    else:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
-        headers = {"Content-Type": "application/json"}
+    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header, generate_content=True)

     # data["contents"][-1]["role"] = "model"
     if add_postfunc_model_messages:
         data["contents"] = add_dummy_model_messages(data["contents"])

-    printd(f"Sending request to {url}")
+    response_json = make_post_request(url, headers, data)
     try:
-        response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}")
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
-        # Convert Google AI response to ChatCompletion style
         return convert_google_ai_response_to_chatcompletion(
-            response_json=response,
-            model=model,
+            response_json=response_json,
+            model=data.get("model"),
             input_messages=data["contents"],
-            pull_inner_thoughts_from_args=inner_thoughts_in_kwargs,
+            pull_inner_thoughts_from_args=data.get("inner_thoughts_in_kwargs", False),
         )
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        printd(f"Got HTTPError, exception={http_err}, payload={data}")
-        # Print the HTTP status code
-        print(f"HTTP Error: {http_err.response.status_code}")
-        # Print the response content (error message from server)
-        print(f"Message: {http_err.response.text}")
-        raise http_err
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        printd(f"Got RequestException, exception={req_err}")
-        raise req_err
-    except Exception as e:
-        # Handle other potential errors
-        printd(f"Got unknown Exception, exception={e}")
-        raise e
+    except Exception as conversion_error:
+        print(f"Error during response conversion: {conversion_error}")
+        raise conversion_error
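
A sketch of the refactored call path, assuming the remaining keyword arguments keep their defaults (model, key, and payload values are illustrative; payload structure per https://ai.google.dev/docs/function_calling):

    from letta.llm_api.google_ai import google_ai_chat_completions_request

    response = google_ai_chat_completions_request(
        base_url="https://generativelanguage.googleapis.com",
        model="gemini-1.5-pro-latest",
        api_key="MY_KEY",
        data={"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]},
        inner_thoughts_in_kwargs=False,
    )
    # make_post_request performs the POST and error handling; the Gemini reply
    # is converted into an OpenAI-style ChatCompletionResponse.
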

View File

@@ -1,14 +1,69 @@
 import copy
 import json
 import warnings
-from typing import List, Union
+from typing import Any, List, Union

 import requests

 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.schemas.enums import OptionState
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
-from letta.utils import json_dumps
+from letta.utils import json_dumps, printd
+
+
+def make_post_request(url: str, headers: dict[str, str], data: dict[str, Any]) -> dict[str, Any]:
+    printd(f"Sending request to {url}")
+    try:
+        # Make the POST request
+        response = requests.post(url, headers=headers, json=data)
+        printd(f"Response status code: {response.status_code}")
+
+        # Raise for 4XX/5XX HTTP errors
+        response.raise_for_status()
+
+        # Ensure the content is JSON before parsing
+        if response.headers.get("Content-Type") == "application/json":
+            response_data = response.json()  # Convert to dict from JSON
+            printd(f"Response JSON: {response_data}")
+        else:
+            error_message = f"Unexpected content type returned: {response.headers.get('Content-Type')}"
+            printd(error_message)
+            raise ValueError(error_message)
+
+        # Return the parsed JSON body
+        return response_data
+
+    except requests.exceptions.HTTPError as http_err:
+        # HTTP errors (4XX, 5XX)
+        error_message = f"HTTP error occurred: {http_err}"
+        if http_err.response is not None:
+            error_message += f" | Status code: {http_err.response.status_code}, Message: {http_err.response.text}"
+        printd(error_message)
+        raise requests.exceptions.HTTPError(error_message) from http_err
+
+    except requests.exceptions.Timeout as timeout_err:
+        # Handle timeout errors
+        error_message = f"Request timed out: {timeout_err}"
+        printd(error_message)
+        raise requests.exceptions.Timeout(error_message) from timeout_err
+
+    except requests.exceptions.RequestException as req_err:
+        # Non-HTTP errors (e.g., connection, SSL errors)
+        error_message = f"Request failed: {req_err}"
+        printd(error_message)
+        raise requests.exceptions.RequestException(error_message) from req_err
+
+    except ValueError as val_err:
+        # Handle content-type or non-JSON response issues
+        error_message = f"ValueError: {val_err}"
+        printd(error_message)
+        raise ValueError(error_message) from val_err
+
+    except Exception as e:
+        # Catch any other unknown exceptions
+        error_message = f"An unexpected error occurred: {e}"
+        printd(error_message)
+        raise Exception(error_message) from e


 # TODO update to use better types
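
A minimal usage sketch of the shared helper (endpoint, key, and payload are illustrative):

    from letta.llm_api.helpers import make_post_request

    url = "https://api.openai.com/v1/chat/completions"
    headers = {"Content-Type": "application/json", "Authorization": "Bearer MY_KEY"}
    data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}

    response_json = make_post_request(url, headers, data)  # dict on success
    # On failure it re-raises HTTPError / Timeout / RequestException / ValueError
    # with an enriched message and the original exception chained as __cause__.
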

View File

@@ -28,7 +28,6 @@ from letta.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
 )
-from letta.providers import GoogleAIProvider
 from letta.schemas.enums import OptionState
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
@@ -231,7 +230,7 @@ def create(
         return google_ai_chat_completions_request(
             inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
-            service_endpoint=GoogleAIProvider(model_settings.gemini_api_key).service_endpoint,
+            base_url=llm_config.model_endpoint,
             model=llm_config.model,
             api_key=model_settings.gemini_api_key,
             # see structure of payload here: https://ai.google.dev/docs/function_calling

View File

@@ -9,7 +9,7 @@ from httpx_sse._exceptions import SSEError
 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
-from letta.llm_api.helpers import add_inner_thoughts_to_functions
+from letta.llm_api.helpers import add_inner_thoughts_to_functions, make_post_request
 from letta.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
@@ -483,58 +483,14 @@ def openai_chat_completions_request(
         data.pop("tools")
         data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")

-    printd(f"Sending request to {url}")
-    try:
-        response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}, response.text = {response.text}")
-        # print(json.dumps(data, indent=4))
-        # raise requests.exceptions.HTTPError
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
-        response = ChatCompletionResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
-        return response
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        printd(f"Got HTTPError, exception={http_err}, payload={data}")
-        raise http_err
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        printd(f"Got RequestException, exception={req_err}")
-        raise req_err
-    except Exception as e:
-        # Handle other potential errors
-        printd(f"Got unknown Exception, exception={e}")
-        raise e
+    response_json = make_post_request(url, headers, data)
+    return ChatCompletionResponse(**response_json)


 def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse:
     """https://platform.openai.com/docs/api-reference/embeddings/create"""
-    from letta.utils import printd

     url = smart_urljoin(url, "embeddings")
     headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-
-    printd(f"Sending request to {url}")
-    try:
-        response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}")
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
-        response = EmbeddingResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
-        return response
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        printd(f"Got HTTPError, exception={http_err}, payload={data}")
-        raise http_err
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        printd(f"Got RequestException, exception={req_err}")
-        raise req_err
-    except Exception as e:
-        # Handle other potential errors
-        printd(f"Got unknown Exception, exception={e}")
-        raise e
+    response_json = make_post_request(url, headers, data)
+    return EmbeddingResponse(**response_json)

View File

@@ -217,14 +217,12 @@ class GroqProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     api_key: str = Field(..., description="API key for the Google AI API.")
-    service_endpoint: str = "generativelanguage"  # TODO: remove once old functions are refactored to just use base_url
     base_url: str = "https://generativelanguage.googleapis.com"

     def list_llm_models(self):
         from letta.llm_api.google_ai import google_ai_get_model_list

-        # TODO: use base_url instead
-        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
         # filter by 'generateContent' models
         model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
@@ -251,7 +249,7 @@ class GoogleAIProvider(Provider):
         from letta.llm_api.google_ai import google_ai_get_model_list

         # TODO: use base_url instead
-        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
         # filter by 'generateContent' models
         model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
@@ -273,8 +271,7 @@ class GoogleAIProvider(Provider):
     def get_model_context_window(self, model_name: str):
         from letta.llm_api.google_ai import google_ai_get_model_context_window

-        # TODO: use base_url instead
-        return google_ai_get_model_context_window(self.service_endpoint, self.api_key, model_name)
+        return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)


 class AzureProvider(Provider):
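
A sketch of the provider after the change (API key illustrative):

    from letta.providers import GoogleAIProvider

    provider = GoogleAIProvider(api_key="MY_KEY")
    print(provider.base_url)  # "https://generativelanguage.googleapis.com"
    # Model listing and context-window lookups now route through base_url:
    print(provider.get_model_context_window("gemini-1.5-pro-latest"))
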

View File

@@ -74,6 +74,9 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     usage: UsageStatistics

+    def __str__(self):
+        return self.model_dump_json(indent=4)
+

 class FunctionCallDelta(BaseModel):
     # arguments: Optional[str] = None
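
The effect, assuming response is any populated ChatCompletionResponse:

    print(response)  # now equivalent to print(response.model_dump_json(indent=4))
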

View File

@@ -0,0 +1,7 @@
+{
+    "context_window": 2097152,
+    "model": "gemini-1.5-pro-latest",
+    "model_endpoint_type": "google_ai",
+    "model_endpoint": "https://generativelanguage.googleapis.com",
+    "model_wrapper": null
+}
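
The context_window of 2097152 corresponds to the 2M-token context advertised for gemini-1.5-pro-latest.
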

View File

@@ -273,3 +273,13 @@ def test_groq_llama31_70b_edit_core_memory():
     response = check_agent_edit_core_memory(filename)
     # Log out successful response
     print(f"Got successful response from client: \n\n{response}")
+
+
+# ======================================================================================================================
+# GEMINI TESTS
+# ======================================================================================================================
+def test_gemini_pro_15_returns_valid_first_message():
+    filename = os.path.join(llm_config_dir, "gemini-pro.json")
+    response = check_first_response_is_valid_for_llm_endpoint(filename)
+    # Log out successful response
+    print(f"Got successful response from client: \n\n{response}")
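
Assuming this file is tests/test_endpoints.py (the diff view does not show file paths), the new test can be run in isolation with: pytest tests/test_endpoints.py::test_gemini_pro_15_returns_valid_first_message -s. It exercises the live endpoint, so a Gemini API key must be configured in the model settings.
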