MemGPT/tests/test_providers.py
Kevin Lin 7828d77616 feat: support deepseek models (#821)
Co-authored-by: Charles Packer <packercharles@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Shubham Naik <shub@letta.com>
2025-02-18 15:28:01 -08:00

121 lines
3.4 KiB
Python

import os
from letta.schemas.providers import (
AnthropicBedrockProvider,
AnthropicProvider,
AzureProvider,
DeepSeekProvider,
GoogleAIProvider,
GoogleVertexProvider,
GroqProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
TogetherProvider,
)
from letta.settings import model_settings
def test_openai():
api_key = os.getenv("OPENAI_API_KEY")
assert api_key is not None
provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
models = provider.list_llm_models()
print(models)
def test_deepseek():
api_key = os.getenv("DEEPSEEK_API_KEY")
assert api_key is not None
provider = DeepSeekProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)
def test_anthropic():
api_key = os.getenv("ANTHROPIC_API_KEY")
assert api_key is not None
provider = AnthropicProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)
def test_groq():
provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
models = provider.list_llm_models()
print(models)
def test_azure():
provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
models = provider.list_llm_models()
print([m.model for m in models])
embed_models = provider.list_embedding_models()
print([m.embedding_model for m in embed_models])
def test_ollama():
base_url = os.getenv("OLLAMA_BASE_URL")
assert base_url is not None
provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
models = provider.list_llm_models()
print(models)
embedding_models = provider.list_embedding_models()
print(embedding_models)
def test_googleai():
api_key = os.getenv("GEMINI_API_KEY")
assert api_key is not None
provider = GoogleAIProvider(api_key=api_key)
models = provider.list_llm_models()
print(models)
provider.list_embedding_models()
def test_google_vertex():
provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
models = provider.list_llm_models()
print(models)
print([m.model for m in models])
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
def test_mistral():
provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
models = provider.list_llm_models()
print([m.model for m in models])
def test_together():
provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
models = provider.list_llm_models()
print([m.model for m in models])
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
def test_anthropic_bedrock():
from letta.settings import model_settings
provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
models = provider.list_llm_models()
print([m.model for m in models])
embedding_models = provider.list_embedding_models()
print([m.embedding_model for m in embedding_models])
# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()
# print(models)
#
# provider.list_embedding_models()