Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

fix: refactor Google AI Provider / helper functions and add endpoint test (#1850)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>

This commit is contained in: parent 51ad4ddac3, commit bc2c0b2482
@@ -76,7 +76,7 @@ class LettaCredentials:
             "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
             # gemini
             "google_ai_key": get_field(config, "google_ai", "key"),
-            "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
+            # "google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
             # anthropic
             "anthropic_key": get_field(config, "anthropic", "key"),
             # cohere
@@ -117,7 +117,7 @@ class LettaCredentials:

        # gemini
        set_field(config, "google_ai", "key", self.google_ai_key)
-        set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)
+        # set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)

        # anthropic
        set_field(config, "anthropic", "key", self.anthropic_key)
@@ -1,9 +1,10 @@
 import uuid
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import requests

 from letta.constants import NON_USER_MSG_PREFIX
+from letta.llm_api.helpers import make_post_request
 from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
 from letta.schemas.openai.chat_completion_request import Tool
@@ -15,27 +16,41 @@ from letta.schemas.openai.chat_completion_response import (
     ToolCall,
     UsageStatistics,
 )
-from letta.utils import get_tool_call_id, get_utc_time
+from letta.utils import get_tool_call_id, get_utc_time, json_dumps

 # from letta.data_types import ToolCall


-SUPPORTED_MODELS = [
-    "gemini-pro",
-]
-
-
-def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
-    from letta.utils import printd
-
+def get_gemini_endpoint_and_headers(
+    base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False
+) -> Tuple[str, dict]:
+    """
+    Dynamically generate the model endpoint and headers.
+    """
+    url = f"{base_url}/v1beta/models"
+
+    # Add the model
+    if model is not None:
+        url += f"/{model}"
+
+    # Add extension for generating content if we're hitting the LM
+    if generate_content:
+        url += ":generateContent"
+
+    # Decide if api key should be in header or not
     # Two ways to pass the key: https://ai.google.dev/tutorials/setup
     if key_in_header:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}"
         headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
     else:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}?key={api_key}"
+        url += f"?key={api_key}"
         headers = {"Content-Type": "application/json"}

+    return url, headers
+
+
+def google_ai_get_model_details(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> List[dict]:
+    from letta.utils import printd
+
+    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header)
+
     try:
         response = requests.get(url, headers=headers)
         printd(f"response = {response}")
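For reference, tracing the new helper above yields endpoints like these (a minimal sketch; the import path is taken from this diff, the key value is a placeholder):

    from letta.llm_api.google_ai import get_gemini_endpoint_and_headers

    # Model list endpoint, key passed in the header (the defaults):
    url, headers = get_gemini_endpoint_and_headers(
        base_url="https://generativelanguage.googleapis.com",
        model=None,
        api_key="MY_KEY",
    )
    # url == "https://generativelanguage.googleapis.com/v1beta/models"
    # headers == {"Content-Type": "application/json", "x-goog-api-key": "MY_KEY"}

    # Chat endpoint for one model, key passed as a query parameter instead:
    url, headers = get_gemini_endpoint_and_headers(
        base_url="https://generativelanguage.googleapis.com",
        model="gemini-1.5-pro-latest",
        api_key="MY_KEY",
        key_in_header=False,
        generate_content=True,
    )
    # url == "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=MY_KEY"
    # headers == {"Content-Type": "application/json"}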
@@ -66,25 +81,17 @@ def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str,
         raise e


-def google_ai_get_model_context_window(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> int:
-    model_details = google_ai_get_model_details(
-        service_endpoint=service_endpoint, api_key=api_key, model=model, key_in_header=key_in_header
-    )
+def google_ai_get_model_context_window(base_url: str, api_key: str, model: str, key_in_header: bool = True) -> int:
+    model_details = google_ai_get_model_details(base_url=base_url, api_key=api_key, model=model, key_in_header=key_in_header)
     # TODO should this be:
     # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
     return int(model_details["inputTokenLimit"])


-def google_ai_get_model_list(service_endpoint: str, api_key: str, key_in_header: bool = True) -> List[dict]:
+def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]:
     from letta.utils import printd

-    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
-    if key_in_header:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models"
-        headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
-    else:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models?key={api_key}"
-        headers = {"Content-Type": "application/json"}
+    url, headers = get_gemini_endpoint_and_headers(base_url, None, api_key, key_in_header)

     try:
         response = requests.get(url, headers=headers)
@@ -396,7 +403,7 @@ def convert_google_ai_response_to_chatcompletion(

 # TODO convert 'data' type to pydantic
 def google_ai_chat_completions_request(
-    service_endpoint: str,
+    base_url: str,
     model: str,
     api_key: str,
     data: dict,
@@ -414,55 +421,23 @@ def google_ai_chat_completions_request(
     This service has the following service endpoint and all URIs below are relative to this service endpoint:
     https://xxx.googleapis.com
     """
-    from letta.utils import printd

-    assert service_endpoint is not None, "Missing service_endpoint when calling Google AI"
     assert api_key is not None, "Missing api_key when calling Google AI"
-    assert model in SUPPORTED_MODELS, f"Model '{model}' not in supported models: {', '.join(SUPPORTED_MODELS)}"

-    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
-    if key_in_header:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent"
-        headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
-    else:
-        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
-        headers = {"Content-Type": "application/json"}
+    url, headers = get_gemini_endpoint_and_headers(base_url, model, api_key, key_in_header, generate_content=True)

     # data["contents"][-1]["role"] = "model"
     if add_postfunc_model_messages:
         data["contents"] = add_dummy_model_messages(data["contents"])

-    printd(f"Sending request to {url}")
+    response_json = make_post_request(url, headers, data)
     try:
-        response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}")
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
-
-        # Convert Google AI response to ChatCompletion style
         return convert_google_ai_response_to_chatcompletion(
-            response_json=response,
-            model=model,
+            response_json=response_json,
+            model=data.get("model"),
             input_messages=data["contents"],
-            pull_inner_thoughts_from_args=inner_thoughts_in_kwargs,
+            pull_inner_thoughts_from_args=data.get("inner_thoughts_in_kwargs", False),
         )
-
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        printd(f"Got HTTPError, exception={http_err}, payload={data}")
-        # Print the HTTP status code
-        print(f"HTTP Error: {http_err.response.status_code}")
-        # Print the response content (error message from server)
-        print(f"Message: {http_err.response.text}")
-        raise http_err
-
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        printd(f"Got RequestException, exception={req_err}")
-        raise req_err
-
-    except Exception as e:
-        # Handle other potential errors
-        printd(f"Got unknown Exception, exception={e}")
-        raise e
+    except Exception as conversion_error:
+        print(f"Error during response conversion: {conversion_error}")
+        raise conversion_error
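The `data` dict passed to the refactored function follows the Gemini generateContent schema (see the function-calling link in a later hunk of this diff). A minimal illustrative call, assuming the remaining keyword arguments keep their defaults; the message text and key are placeholders:

    data = {
        "contents": [
            {"role": "user", "parts": [{"text": "Hello, who are you?"}]},
        ],
    }
    chat_completion = google_ai_chat_completions_request(
        base_url="https://generativelanguage.googleapis.com",
        model="gemini-1.5-pro-latest",
        api_key="MY_KEY",
        data=data,
    )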
@@ -1,14 +1,69 @@
 import copy
 import json
 import warnings
-from typing import List, Union
+from typing import Any, List, Union

 import requests

 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.schemas.enums import OptionState
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
-from letta.utils import json_dumps
+from letta.utils import json_dumps, printd
+
+
+def make_post_request(url: str, headers: dict[str, str], data: dict[str, Any]) -> dict[str, Any]:
+    printd(f"Sending request to {url}")
+    try:
+        # Make the POST request
+        response = requests.post(url, headers=headers, json=data)
+        printd(f"Response status code: {response.status_code}")
+
+        # Raise for 4XX/5XX HTTP errors
+        response.raise_for_status()
+
+        # Ensure the content is JSON before parsing
+        if response.headers.get("Content-Type") == "application/json":
+            response_data = response.json()  # Convert to dict from JSON
+            printd(f"Response JSON: {response_data}")
+        else:
+            error_message = f"Unexpected content type returned: {response.headers.get('Content-Type')}"
+            printd(error_message)
+            raise ValueError(error_message)
+
+        # Process the response using the callback function
+        return response_data
+
+    except requests.exceptions.HTTPError as http_err:
+        # HTTP errors (4XX, 5XX)
+        error_message = f"HTTP error occurred: {http_err}"
+        if http_err.response is not None:
+            error_message += f" | Status code: {http_err.response.status_code}, Message: {http_err.response.text}"
+        printd(error_message)
+        raise requests.exceptions.HTTPError(error_message) from http_err
+
+    except requests.exceptions.Timeout as timeout_err:
+        # Handle timeout errors
+        error_message = f"Request timed out: {timeout_err}"
+        printd(error_message)
+        raise requests.exceptions.Timeout(error_message) from timeout_err
+
+    except requests.exceptions.RequestException as req_err:
+        # Non-HTTP errors (e.g., connection, SSL errors)
+        error_message = f"Request failed: {req_err}"
+        printd(error_message)
+        raise requests.exceptions.RequestException(error_message) from req_err
+
+    except ValueError as val_err:
+        # Handle content-type or non-JSON response issues
+        error_message = f"ValueError: {val_err}"
+        printd(error_message)
+        raise ValueError(error_message) from val_err
+
+    except Exception as e:
+        # Catch any other unknown exceptions
+        error_message = f"An unexpected error occurred: {e}"
+        printd(error_message)
+        raise Exception(error_message) from e
+
+
 # TODO update to use better types
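Every provider module touched by this diff now funnels its POSTs through this single helper. A sketch of a call site (the endpoint and payload here are placeholders; the real callers appear in the google_ai and openai hunks):

    from letta.llm_api.helpers import make_post_request

    url = "https://api.openai.com/v1/chat/completions"
    headers = {"Content-Type": "application/json", "Authorization": "Bearer MY_KEY"}
    data = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hello"}]}

    # Returns the parsed JSON dict on success; raises HTTPError / Timeout /
    # RequestException / ValueError with an enriched message otherwise.
    response_json = make_post_request(url, headers, data)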
@@ -28,7 +28,6 @@ from letta.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
 )
-from letta.providers import GoogleAIProvider
 from letta.schemas.enums import OptionState
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
@@ -231,7 +230,7 @@ def create(

        return google_ai_chat_completions_request(
            inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
-            service_endpoint=GoogleAIProvider(model_settings.gemini_api_key).service_endpoint,
+            base_url=llm_config.model_endpoint,
            model=llm_config.model,
            api_key=model_settings.gemini_api_key,
            # see structure of payload here: https://ai.google.dev/docs/function_calling
@@ -9,7 +9,7 @@ from httpx_sse._exceptions import SSEError

 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
-from letta.llm_api.helpers import add_inner_thoughts_to_functions
+from letta.llm_api.helpers import add_inner_thoughts_to_functions, make_post_request
 from letta.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
@@ -483,58 +483,14 @@ def openai_chat_completions_request(
         data.pop("tools")
         data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")

-    printd(f"Sending request to {url}")
-    try:
-        response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}, response.text = {response.text}")
-        # print(json.dumps(data, indent=4))
-        # raise requests.exceptions.HTTPError
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
-
-        response = ChatCompletionResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
-        return response
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        printd(f"Got HTTPError, exception={http_err}, payload={data}")
-        raise http_err
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        printd(f"Got RequestException, exception={req_err}")
-        raise req_err
-    except Exception as e:
-        # Handle other potential errors
-        printd(f"Got unknown Exception, exception={e}")
-        raise e
+    response_json = make_post_request(url, headers, data)
+    return ChatCompletionResponse(**response_json)


 def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse:
     """https://platform.openai.com/docs/api-reference/embeddings/create"""
-    from letta.utils import printd
-
     url = smart_urljoin(url, "embeddings")
     headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-
-    printd(f"Sending request to {url}")
-    try:
-        response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}")
-        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
-        response = response.json()  # convert to dict from string
-        printd(f"response.json = {response}")
-        response = EmbeddingResponse(**response)  # convert to 'dot-dict' style which is the openai python client default
-        return response
-    except requests.exceptions.HTTPError as http_err:
-        # Handle HTTP errors (e.g., response 4XX, 5XX)
-        printd(f"Got HTTPError, exception={http_err}, payload={data}")
-        raise http_err
-    except requests.exceptions.RequestException as req_err:
-        # Handle other requests-related errors (e.g., connection error)
-        printd(f"Got RequestException, exception={req_err}")
-        raise req_err
-    except Exception as e:
-        # Handle other potential errors
-        printd(f"Got unknown Exception, exception={e}")
-        raise e
+    response_json = make_post_request(url, headers, data)
+    return EmbeddingResponse(**response_json)
@@ -217,14 +217,12 @@ class GroqProvider(OpenAIProvider):
 class GoogleAIProvider(Provider):
     # gemini
     api_key: str = Field(..., description="API key for the Google AI API.")
-    service_endpoint: str = "generativelanguage"  # TODO: remove once old functions are refactored to just use base_url
     base_url: str = "https://generativelanguage.googleapis.com"

     def list_llm_models(self):
         from letta.llm_api.google_ai import google_ai_get_model_list

-        # TODO: use base_url instead
-        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
         # filter by 'generateContent' models
         model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
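With the service_endpoint field gone, the provider can be exercised with only a key (a sketch; `GoogleAIProvider` and its fields are as declared in this hunk, the key is a placeholder, and this assumes any remaining base Provider fields have defaults):

    from letta.providers import GoogleAIProvider

    provider = GoogleAIProvider(api_key="MY_KEY")  # base_url defaults to https://generativelanguage.googleapis.com
    print(provider.list_llm_models())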
@@ -251,7 +249,7 @@ class GoogleAIProvider(Provider):
         from letta.llm_api.google_ai import google_ai_get_model_list

         # TODO: use base_url instead
-        model_options = google_ai_get_model_list(service_endpoint=self.service_endpoint, api_key=self.api_key)
+        model_options = google_ai_get_model_list(base_url=self.base_url, api_key=self.api_key)
         # filter by 'generateContent' models
         model_options = [mo for mo in model_options if "embedContent" in mo["supportedGenerationMethods"]]
         model_options = [str(m["name"]) for m in model_options]
@@ -273,8 +271,7 @@ class GoogleAIProvider(Provider):
     def get_model_context_window(self, model_name: str):
         from letta.llm_api.google_ai import google_ai_get_model_context_window

-        # TODO: use base_url instead
-        return google_ai_get_model_context_window(self.service_endpoint, self.api_key, model_name)
+        return google_ai_get_model_context_window(self.base_url, self.api_key, model_name)


 class AzureProvider(Provider):
@@ -74,6 +74,9 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     usage: UsageStatistics

+    def __str__(self):
+        return self.model_dump_json(indent=4)
+

 class FunctionCallDelta(BaseModel):
     # arguments: Optional[str] = None
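The new `__str__` means printing a ChatCompletionResponse now emits indented JSON via pydantic's `model_dump_json` (a sketch, assuming `response` is any ChatCompletionResponse instance):

    print(response)  # equivalent to print(response.model_dump_json(indent=4))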
tests/configs/llm_model_configs/gemini-pro.json (new file, +7 lines)
@@ -0,0 +1,7 @@
+{
+    "context_window": 2097152,
+    "model": "gemini-1.5-pro-latest",
+    "model_endpoint_type": "google_ai",
+    "model_endpoint": "https://generativelanguage.googleapis.com",
+    "model_wrapper": null
+}
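A sketch of how a test could hydrate this file into the `LLMConfig` schema imported elsewhere in this diff (the repo's actual test helpers may load it differently):

    import json

    from letta.schemas.llm_config import LLMConfig

    with open("tests/configs/llm_model_configs/gemini-pro.json") as f:
        llm_config = LLMConfig(**json.load(f))

    assert llm_config.model_endpoint_type == "google_ai"
    assert llm_config.context_window == 2097152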
@@ -273,3 +273,13 @@ def test_groq_llama31_70b_edit_core_memory():
     response = check_agent_edit_core_memory(filename)
     # Log out successful response
     print(f"Got successful response from client: \n\n{response}")
+
+
+# ======================================================================================================================
+# GEMINI TESTS
+# ======================================================================================================================
+def test_gemini_pro_15_returns_valid_first_message():
+    filename = os.path.join(llm_config_dir, "gemini-pro.json")
+    response = check_first_response_is_valid_for_llm_endpoint(filename)
+    # Log out successful response
+    print(f"Got successful response from client: \n\n{response}")
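Like the other endpoint tests in this file, the new Gemini test exercises the live API, so it needs a valid Google AI key configured; the key-loading mechanism sits outside this hunk.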