mirror of https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

feat: byok 2.0 (#1963)

parent e3819cf066
commit 835792d5e0
@@ -0,0 +1,35 @@
+"""add byok fields and unique constraint
+
+Revision ID: 373dabcba6cf
+Revises: c56081a05371
+Create Date: 2025-04-30 19:38:25.010856
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "373dabcba6cf"
+down_revision: Union[str, None] = "c56081a05371"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("providers", sa.Column("provider_type", sa.String(), nullable=True))
+    op.add_column("providers", sa.Column("base_url", sa.String(), nullable=True))
+    op.create_unique_constraint("unique_name_organization_id", "providers", ["name", "organization_id"])
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint("unique_name_organization_id", "providers", type_="unique")
+    op.drop_column("providers", "base_url")
+    op.drop_column("providers", "provider_type")
+    # ### end Alembic commands ###
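This migration adds the two BYOK columns and a per-organization uniqueness guarantee on provider names. A minimal sketch of applying it through Alembic's Python API (equivalent to `alembic upgrade head`; the `alembic.ini` path and the DSN are assumptions about your deployment):

import sqlalchemy as sa
from alembic import command
from alembic.config import Config

# Pre-flight check (sketch): creating the unique constraint fails if duplicate
# (name, organization_id) pairs already exist in the providers table.
engine = sa.create_engine("postgresql://user:pass@localhost/letta")  # placeholder DSN
with engine.connect() as conn:
    dupes = conn.execute(
        sa.text(
            "SELECT name, organization_id, COUNT(*) AS n FROM providers "
            "GROUP BY name, organization_id HAVING COUNT(*) > 1"
        )
    ).fetchall()
assert not dupes, f"resolve duplicate provider names before migrating: {dupes}"

command.upgrade(Config("alembic.ini"), "head")  # runs the upgrade() above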
@@ -331,7 +331,8 @@ class Agent(BaseAgent):
         log_telemetry(self.logger, "_get_ai_reply create start")
         # New LLM client flow
         llm_client = LLMClient.create(
-            provider=self.agent_state.llm_config.model_endpoint_type,
+            provider_name=self.agent_state.llm_config.provider_name,
+            provider_type=self.agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=put_inner_thoughts_first,
         )

@@ -941,12 +942,7 @@ class Agent(BaseAgent):
                 model_endpoint=self.agent_state.llm_config.model_endpoint,
                 context_window_limit=self.agent_state.llm_config.context_window,
                 usage=response.usage,
-                # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature
-                provider_id=(
-                    self.provider_manager.get_anthropic_override_provider_id()
-                    if self.agent_state.llm_config.model_endpoint_type == "anthropic"
-                    else None
-                ),
+                provider_id=self.provider_manager.get_provider_id_from_name(self.agent_state.llm_config.provider_name),
                 job_id=job_id,
             )
             for message in all_new_messages:
@@ -67,7 +67,8 @@ class LettaAgent(BaseAgent):
         )
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
-            provider=agent_state.llm_config.model_endpoint_type,
+            provider_name=agent_state.llm_config.provider_name,
+            provider_type=agent_state.llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         for step in range(max_steps):
@@ -156,7 +156,8 @@ class LettaAgentBatch:

         log_event(name="init_llm_client")
         llm_client = LLMClient.create(
-            provider=agent_states[0].llm_config.model_endpoint_type,
+            provider_name=agent_states[0].llm_config.provider_name,
+            provider_type=agent_states[0].llm_config.model_endpoint_type,
             put_inner_thoughts_first=True,
         )
         agent_llm_config_mapping = {s.id: s.llm_config for s in agent_states}
@@ -273,7 +274,8 @@ class LettaAgentBatch:

        # translate provider‑specific response → OpenAI‑style tool call (unchanged)
        llm_client = LLMClient.create(
-           provider=item.llm_config.model_endpoint_type,
+           provider_name=item.llm_config.provider_name,
+           provider_type=item.llm_config.model_endpoint_type,
            put_inner_thoughts_first=True,
        )
        tool_call = (
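Every call site above moves from a single positional `provider` to the keyword pair `provider_name`/`provider_type`, so the BYOK provider name travels alongside the endpoint type. A minimal sketch of the new call shape (assumes an `agent_state` with the updated `LLMConfig` in scope):

# Sketch: updated LLMClient.create call (agent_state is assumed in scope).
llm_client = LLMClient.create(
    provider_name=agent_state.llm_config.provider_name,  # BYOK name; None for built-in providers
    provider_type=agent_state.llm_config.model_endpoint_type,  # e.g. "anthropic", "openai"
    put_inner_thoughts_first=True,
)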
@@ -26,6 +26,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.message import Message as _Message
 from letta.schemas.message import MessageRole as _MessageRole
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@@ -128,11 +129,12 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
     # NOTE: currently there is no GET /models, so we need to hardcode
     # return MODEL_LIST

-    anthropic_override_key = ProviderManager().get_anthropic_override_key()
-    if anthropic_override_key:
-        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    if api_key:
+        anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
+    else:
+        raise ValueError("No API key provided")

     models = anthropic_client.models.list()
     models_json = models.model_dump()
@@ -738,13 +740,14 @@ def anthropic_chat_completions_request(
     put_inner_thoughts_in_kwargs: bool = False,
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
+    provider_name: Optional[str] = None,
     betas: List[str] = ["tools-2024-04-04"],
 ) -> ChatCompletionResponse:
     """https://docs.anthropic.com/claude/docs/tool-use"""
     anthropic_client = None
-    anthropic_override_key = ProviderManager().get_anthropic_override_key()
-    if anthropic_override_key:
-        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    if provider_name and provider_name != ProviderType.anthropic.value:
+        api_key = ProviderManager().get_override_key(provider_name)
+        anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()
     else:
@@ -796,6 +799,7 @@ def anthropic_chat_completions_request_stream(
     put_inner_thoughts_in_kwargs: bool = False,
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
+    provider_name: Optional[str] = None,
     betas: List[str] = ["tools-2024-04-04"],
 ) -> Generator[ChatCompletionChunkResponse, None, None]:
     """Stream chat completions from Anthropic API.
@@ -810,10 +814,9 @@ def anthropic_chat_completions_request_stream(
         extended_thinking=extended_thinking,
         max_reasoning_tokens=max_reasoning_tokens,
     )
-
-    anthropic_override_key = ProviderManager().get_anthropic_override_key()
-    if anthropic_override_key:
-        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
+    if provider_name and provider_name != ProviderType.anthropic.value:
+        api_key = ProviderManager().get_override_key(provider_name)
+        anthropic_client = anthropic.Anthropic(api_key=api_key)
     elif model_settings.anthropic_api_key:
         anthropic_client = anthropic.Anthropic()

@@ -860,6 +863,7 @@ def anthropic_chat_completions_process_stream(
     put_inner_thoughts_in_kwargs: bool = False,
     extended_thinking: bool = False,
     max_reasoning_tokens: Optional[int] = None,
+    provider_name: Optional[str] = None,
     create_message_id: bool = True,
     create_message_datetime: bool = True,
     betas: List[str] = ["tools-2024-04-04"],
@@ -944,6 +948,7 @@ def anthropic_chat_completions_process_stream(
             put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
             extended_thinking=extended_thinking,
             max_reasoning_tokens=max_reasoning_tokens,
+            provider_name=provider_name,
             betas=betas,
         )
     ):
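Each Anthropic entry point now resolves credentials the same way: a non-default `provider_name` takes the stored BYOK key, otherwise the server-level key applies. A hedged sketch of that shared resolution order (the helper `resolve_anthropic_client` is hypothetical and not part of the diff):

from typing import Optional

import anthropic

from letta.schemas.enums import ProviderType
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings


def resolve_anthropic_client(provider_name: Optional[str]) -> anthropic.Anthropic:
    # Hypothetical helper mirroring the resolution order used in the diff.
    if provider_name and provider_name != ProviderType.anthropic.value:
        # BYOK: use the key stored for this named provider
        return anthropic.Anthropic(api_key=ProviderManager().get_override_key(provider_name))
    if model_settings.anthropic_api_key:
        # server-level key; the SDK reads ANTHROPIC_API_KEY from the environment
        return anthropic.Anthropic()
    raise ValueError("No API key provided")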
@@ -27,6 +27,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_in
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
@@ -112,7 +113,10 @@ class AnthropicClient(LLMClientBase):

     @trace_method
     def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
-        override_key = ProviderManager().get_anthropic_override_key()
+        override_key = None
+        if self.provider_name and self.provider_name != ProviderType.anthropic.value:
+            override_key = ProviderManager().get_override_key(self.provider_name)
+
         if async_client:
             return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
         return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
@@ -24,6 +24,7 @@ from letta.llm_api.openai import (
 from letta.local_llm.chat_completion_proxy import get_chat_completion
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, cast_message_to_subtype
@@ -171,6 +172,10 @@ def create(
         if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # only is a problem if we are *not* using an openai proxy
             raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"])
+        elif llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+            from letta.services.provider_manager import ProviderManager
+
+            api_key = ProviderManager().get_override_key(llm_config.provider_name)
         elif model_settings.openai_api_key is None:
             # the openai python client requires a dummy API key
             api_key = "DUMMY_API_KEY"
@@ -373,6 +378,7 @@ def create(
             stream_interface=stream_interface,
             extended_thinking=llm_config.enable_reasoner,
             max_reasoning_tokens=llm_config.max_reasoning_tokens,
+            provider_name=llm_config.provider_name,
             name=name,
         )

@@ -383,6 +389,7 @@ def create(
             put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
             extended_thinking=llm_config.enable_reasoner,
             max_reasoning_tokens=llm_config.max_reasoning_tokens,
+            provider_name=llm_config.provider_name,
         )

         if llm_config.put_inner_thoughts_in_kwargs:
@@ -9,7 +9,8 @@ class LLMClient:

     @staticmethod
     def create(
-        provider: ProviderType,
+        provider_type: ProviderType,
+        provider_name: Optional[str] = None,
         put_inner_thoughts_first: bool = True,
     ) -> Optional[LLMClientBase]:
         """
@@ -25,29 +26,33 @@ class LLMClient:
         Raises:
             ValueError: If the model endpoint type is not supported
         """
-        match provider:
+        match provider_type:
             case ProviderType.google_ai:
                 from letta.llm_api.google_ai_client import GoogleAIClient

                 return GoogleAIClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case ProviderType.google_vertex:
                 from letta.llm_api.google_vertex_client import GoogleVertexClient

                 return GoogleVertexClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case ProviderType.anthropic:
                 from letta.llm_api.anthropic_client import AnthropicClient

                 return AnthropicClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case ProviderType.openai:
                 from letta.llm_api.openai_client import OpenAIClient

                 return OpenAIClient(
+                    provider_name=provider_name,
                     put_inner_thoughts_first=put_inner_thoughts_first,
                 )
             case _:
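The factory keyword rename (`provider` to `provider_type`) plus the threaded `provider_name` is the whole public-facing change. A minimal usage sketch (the import path is assumed from the module layout implied above; the provider name is hypothetical):

from letta.llm_api.llm_client import LLMClient  # import path assumed
from letta.schemas.enums import ProviderType

client = LLMClient.create(
    provider_type=ProviderType.anthropic,
    provider_name="my-anthropic",  # hypothetical BYOK provider name
    put_inner_thoughts_first=True,
)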
@@ -20,9 +20,11 @@ class LLMClientBase:

     def __init__(
         self,
+        provider_name: Optional[str] = None,
         put_inner_thoughts_first: Optional[bool] = True,
         use_tool_naming: bool = True,
     ):
+        self.provider_name = provider_name
         self.put_inner_thoughts_first = put_inner_thoughts_first
         self.use_tool_naming = use_tool_naming

@@ -22,6 +22,7 @@ from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_st
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
 from letta.log import get_logger
+from letta.schemas.enums import ProviderType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
@@ -64,7 +65,14 @@ def supports_parallel_tool_calling(model: str) -> bool:

 class OpenAIClient(LLMClientBase):
     def _prepare_client_kwargs(self, llm_config: LLMConfig) -> dict:
-        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
+        api_key = None
+        if llm_config.provider_name and llm_config.provider_name != ProviderType.openai.value:
+            from letta.services.provider_manager import ProviderManager
+
+            api_key = ProviderManager().get_override_key(llm_config.provider_name)
+
+        if not api_key:
+            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
         # supposedly the openai python client requires a dummy API key
         api_key = api_key or "DUMMY_API_KEY"
         kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}
@@ -79,7 +79,8 @@ def summarize_messages(
     llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False

     llm_client = LLMClient.create(
-        provider=llm_config_no_inner_thoughts.model_endpoint_type,
+        provider_name=llm_config_no_inner_thoughts.provider_name,
+        provider_type=llm_config_no_inner_thoughts.model_endpoint_type,
         put_inner_thoughts_first=False,
     )
     # try to use new client, otherwise fallback to old flow
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING

+from sqlalchemy import UniqueConstraint
 from sqlalchemy.orm import Mapped, mapped_column, relationship

 from letta.orm.mixins import OrganizationMixin
@@ -15,9 +16,18 @@ class Provider(SqlalchemyBase, OrganizationMixin):

     __tablename__ = "providers"
     __pydantic_model__ = PydanticProvider
+    __table_args__ = (
+        UniqueConstraint(
+            "name",
+            "organization_id",
+            name="unique_name_organization_id",
+        ),
+    )

     name: Mapped[str] = mapped_column(nullable=False, doc="The name of the provider")
+    provider_type: Mapped[str] = mapped_column(nullable=True, doc="The type of the provider")
     api_key: Mapped[str] = mapped_column(nullable=True, doc="API key used for requests to the provider.")
+    base_url: Mapped[str] = mapped_column(nullable=True, doc="Base URL for the provider.")

     # relationships
     organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
@@ -6,6 +6,17 @@ class ProviderType(str, Enum):
     google_ai = "google_ai"
     google_vertex = "google_vertex"
     openai = "openai"
+    letta = "letta"
+    deepseek = "deepseek"
+    lmstudio_openai = "lmstudio_openai"
+    xai = "xai"
+    mistral = "mistral"
+    ollama = "ollama"
+    groq = "groq"
+    together = "together"
+    azure = "azure"
+    vllm = "vllm"
+    bedrock = "bedrock"


 class MessageRole(str, Enum):
@@ -50,6 +50,7 @@ class LLMConfig(BaseModel):
         "xai",
     ] = Field(..., description="The endpoint type for the model.")
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
+    provider_name: Optional[str] = Field(None, description="The provider name for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
     put_inner_thoughts_in_kwargs: Optional[bool] = Field(
@@ -2,8 +2,8 @@ from typing import Dict

 LLM_HANDLE_OVERRIDES: Dict[str, Dict[str, str]] = {
     "anthropic": {
-        "claude-3-5-haiku-20241022": "claude-3.5-haiku",
-        "claude-3-5-sonnet-20241022": "claude-3.5-sonnet",
+        "claude-3-5-haiku-20241022": "claude-3-5-haiku",
+        "claude-3-5-sonnet-20241022": "claude-3-5-sonnet",
         "claude-3-opus-20240229": "claude-3-opus",
     },
     "openai": {
@@ -1,6 +1,6 @@
 import warnings
 from datetime import datetime
-from typing import List, Optional
+from typing import List, Literal, Optional

 from pydantic import Field, model_validator

@@ -9,9 +9,11 @@ from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_
 from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
+from letta.schemas.enums import ProviderType
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
+from letta.settings import model_settings


 class ProviderBase(LettaBase):
@@ -21,10 +23,18 @@ class ProviderBase(LettaBase):
 class Provider(ProviderBase):
     id: Optional[str] = Field(None, description="The id of the provider, lazily created by the database manager.")
     name: str = Field(..., description="The name of the provider")
+    provider_type: ProviderType = Field(..., description="The type of the provider")
     api_key: Optional[str] = Field(None, description="API key used for requests to the provider.")
+    base_url: Optional[str] = Field(None, description="Base URL for the provider.")
     organization_id: Optional[str] = Field(None, description="The organization id of the user")
     updated_at: Optional[datetime] = Field(None, description="The last update timestamp of the provider.")

+    @model_validator(mode="after")
+    def default_base_url(self):
+        if self.provider_type == ProviderType.openai and self.base_url is None:
+            self.base_url = model_settings.openai_api_base
+        return self
+
     def resolve_identifier(self):
         if not self.id:
             self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__)
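With the `default_base_url` validator, an OpenAI-type provider created without an explicit `base_url` inherits `model_settings.openai_api_base` after validation. A small sketch (the provider name and key are placeholders):

provider = Provider(
    name="my-openai",  # hypothetical BYOK provider
    provider_type=ProviderType.openai,
    api_key="sk-...",  # placeholder
)
assert provider.base_url == model_settings.openai_api_base  # filled in by the validator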
@@ -59,9 +69,41 @@ class Provider(ProviderBase):

         return f"{self.name}/{model_name}"

+    def cast_to_subtype(self):
+        match (self.provider_type):
+            case ProviderType.letta:
+                return LettaProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.openai:
+                return OpenAIProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.anthropic:
+                return AnthropicProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.anthropic_bedrock:
+                return AnthropicBedrockProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.ollama:
+                return OllamaProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.google_ai:
+                return GoogleAIProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.google_vertex:
+                return GoogleVertexProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.azure:
+                return AzureProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.groq:
+                return GroqProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.together:
+                return TogetherProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.vllm_chat_completions:
+                return VLLMChatCompletionsProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.vllm_completions:
+                return VLLMCompletionsProvider(**self.model_dump(exclude_none=True))
+            case ProviderType.xai:
+                return XAIProvider(**self.model_dump(exclude_none=True))
+            case _:
+                raise ValueError(f"Unknown provider type: {self.provider_type}")
+
+
 class ProviderCreate(ProviderBase):
     name: str = Field(..., description="The name of the provider.")
+    provider_type: ProviderType = Field(..., description="The type of the provider.")
     api_key: str = Field(..., description="API key used for requests to the provider.")

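`cast_to_subtype` lets a generic `Provider` row loaded from the database be re-hydrated as its concrete subclass, so the type-specific `list_llm_models` logic applies. A usage sketch (the name and key are placeholders):

generic = Provider(
    name="my-groq",  # hypothetical
    provider_type=ProviderType.groq,
    api_key="gsk_...",  # placeholder
)
groq = generic.cast_to_subtype()
assert isinstance(groq, GroqProvider)
models = groq.list_llm_models()  # uses the Groq-specific listing logic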
@@ -70,8 +112,7 @@ class ProviderUpdate(ProviderBase):


 class LettaProvider(Provider):
-    name: str = "letta"
+    provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.")

     def list_llm_models(self) -> List[LLMConfig]:
         return [
@@ -81,6 +122,7 @@ class LettaProvider(Provider):
                 model_endpoint=LETTA_MODEL_ENDPOINT,
                 context_window=8192,
                 handle=self.get_handle("letta-free"),
+                provider_name=self.name,
             )
         ]

@@ -98,7 +140,7 @@ class LettaProvider(Provider):


 class OpenAIProvider(Provider):
-    name: str = "openai"
+    provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the OpenAI API.")
     base_url: str = Field(..., description="Base URL for the OpenAI API.")

@@ -180,6 +222,7 @@ class OpenAIProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )

@@ -235,7 +278,7 @@ class DeepSeekProvider(OpenAIProvider):
     * It also does not support native function calling
     """

-    name: str = "deepseek"
+    provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.")
     base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.")
     api_key: str = Field(..., description="API key for the DeepSeek API.")

@@ -286,6 +329,7 @@ class DeepSeekProvider(OpenAIProvider):
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
                     put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
+                    provider_name=self.name,
                 )
             )

@@ -297,7 +341,7 @@ class DeepSeekProvider(OpenAIProvider):


 class LMStudioOpenAIProvider(OpenAIProvider):
-    name: str = "lmstudio-openai"
+    provider_type: Literal[ProviderType.lmstudio_openai] = Field(ProviderType.lmstudio_openai, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
     api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")

@@ -423,7 +467,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
 class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""

-    name: str = "xai"
+    provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the xAI/Grok API.")
     base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")

@@ -476,6 +520,7 @@ class XAIProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )

@@ -487,7 +532,7 @@ class XAIProvider(OpenAIProvider):


 class AnthropicProvider(Provider):
-    name: str = "anthropic"
+    provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Anthropic API.")
     base_url: str = "https://api.anthropic.com/v1"

@@ -563,6 +608,7 @@ class AnthropicProvider(Provider):
                     handle=self.get_handle(model["id"]),
                     put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     max_tokens=max_tokens,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -572,7 +618,7 @@ class AnthropicProvider(Provider):


 class MistralProvider(Provider):
-    name: str = "mistral"
+    provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Mistral API.")
     base_url: str = "https://api.mistral.ai/v1"

@@ -596,6 +642,7 @@ class MistralProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=model["max_context_length"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )

@@ -622,7 +669,7 @@ class OllamaProvider(OpenAIProvider):
     See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
     """

-    name: str = "ollama"
+    provider_type: Literal[ProviderType.ollama] = Field(ProviderType.ollama, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the Ollama API.")
     api_key: Optional[str] = Field(None, description="API key for the Ollama API (default: `None`).")
     default_prompt_formatter: str = Field(
@@ -652,6 +699,7 @@ class OllamaProvider(OpenAIProvider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
                     handle=self.get_handle(model["name"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -734,7 +782,7 @@ class OllamaProvider(OpenAIProvider):


 class GroqProvider(OpenAIProvider):
-    name: str = "groq"
+    provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.")
     base_url: str = "https://api.groq.com/openai/v1"
     api_key: str = Field(..., description="API key for the Groq API.")

@@ -753,6 +801,7 @@ class GroqProvider(OpenAIProvider):
                     model_endpoint=self.base_url,
                     context_window=model["context_window"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -773,7 +822,7 @@ class TogetherProvider(OpenAIProvider):
     function calling support is limited.
     """

-    name: str = "together"
+    provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.")
     base_url: str = "https://api.together.ai/v1"
     api_key: str = Field(..., description="API key for the TogetherAI API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
@@ -821,6 +870,7 @@ class TogetherProvider(OpenAIProvider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 )
             )

@@ -874,7 +924,7 @@ class TogetherProvider(OpenAIProvider):

 class GoogleAIProvider(Provider):
     # gemini
-    name: str = "google_ai"
+    provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.")
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"

@@ -889,7 +939,6 @@ class GoogleAIProvider(Provider):
         # filter by model names
         model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]

-        # TODO remove manual filtering for gemini-pro
         # Add support for all gemini models
         model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]

@@ -903,6 +952,7 @@ class GoogleAIProvider(Provider):
                     context_window=self.get_model_context_window(model),
                     handle=self.get_handle(model),
                     max_tokens=8192,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -938,7 +988,7 @@ class GoogleAIProvider(Provider):


 class GoogleVertexProvider(Provider):
-    name: str = "google_vertex"
+    provider_type: Literal[ProviderType.google_vertex] = Field(ProviderType.google_vertex, description="The type of the provider.")
     google_cloud_project: str = Field(..., description="GCP project ID for the Google Vertex API.")
     google_cloud_location: str = Field(..., description="GCP region for the Google Vertex API.")

@@ -955,6 +1005,7 @@ class GoogleVertexProvider(Provider):
                     context_window=context_length,
                     handle=self.get_handle(model),
                     max_tokens=8192,
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -978,7 +1029,7 @@ class GoogleVertexProvider(Provider):


 class AzureProvider(Provider):
-    name: str = "azure"
+    provider_type: Literal[ProviderType.azure] = Field(ProviderType.azure, description="The type of the provider.")
     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
     base_url: str = Field(
         ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`."
@@ -1011,6 +1062,7 @@ class AzureProvider(Provider):
                     model_endpoint=model_endpoint,
                     context_window=context_window_size,
                     handle=self.get_handle(model_name),
+                    provider_name=self.name,
                 ),
             )
         return configs
@@ -1051,7 +1103,7 @@ class VLLMChatCompletionsProvider(Provider):
     """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""

     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    name: str = "vllm"
+    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the vLLM API.")

     def list_llm_models(self) -> List[LLMConfig]:
@@ -1070,6 +1122,7 @@ class VLLMChatCompletionsProvider(Provider):
                     model_endpoint=self.base_url,
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1083,7 +1136,7 @@ class VLLMCompletionsProvider(Provider):
     """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""

     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
-    name: str = "vllm"
+    provider_type: Literal[ProviderType.vllm] = Field(ProviderType.vllm, description="The type of the provider.")
     base_url: str = Field(..., description="Base URL for the vLLM API.")
     default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

@@ -1103,6 +1156,7 @@ class VLLMCompletionsProvider(Provider):
                     model_wrapper=self.default_prompt_formatter,
                     context_window=model["max_model_len"],
                     handle=self.get_handle(model["id"]),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1117,7 +1171,7 @@ class CohereProvider(OpenAIProvider):


 class AnthropicBedrockProvider(Provider):
-    name: str = "bedrock"
+    provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
     aws_region: str = Field(..., description="AWS region for Bedrock")

     def list_llm_models(self):
@@ -1131,10 +1185,11 @@ class AnthropicBedrockProvider(Provider):
             configs.append(
                 LLMConfig(
                     model=model_arn,
-                    model_endpoint_type=self.name,
+                    model_endpoint_type=self.provider_type.value,
                     model_endpoint=None,
                     context_window=self.get_model_context_window(model_arn),
                     handle=self.get_handle(model_arn),
+                    provider_name=self.name,
                 )
             )
         return configs
@@ -1,6 +1,6 @@
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Query

 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
@@ -14,10 +14,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"])

 @router.get("/", response_model=List[LLMConfig], operation_id="list_models")
 def list_llm_models(
+    byok_only: Optional[bool] = Query(None),
     server: "SyncServer" = Depends(get_letta_server),
 ):

-    models = server.list_llm_models()
+    models = server.list_llm_models(byok_only=byok_only)
     # print(models)
     return models

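A hedged example of the new `byok_only` filter (the host, port, and `/v1` mount path are assumptions about a local deployment):

import requests

resp = requests.get("http://localhost:8283/v1/models/", params={"byok_only": True})
resp.raise_for_status()
print(resp.json())  # only models from BYOK providers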
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, List, Optional

 from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query

+from letta.schemas.enums import ProviderType
 from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate
 from letta.server.rest_api.utils import get_letta_server

@@ -13,6 +14,8 @@ router = APIRouter(prefix="/providers", tags=["providers"])

 @router.get("/", response_model=List[Provider], operation_id="list_providers")
 def list_providers(
+    name: Optional[str] = Query(None),
+    provider_type: Optional[ProviderType] = Query(None),
     after: Optional[str] = Query(None),
     limit: Optional[int] = Query(50),
     actor_id: Optional[str] = Header(None, alias="user_id"),
@@ -23,7 +26,7 @@ def list_providers(
     """
     try:
         actor = server.user_manager.get_user_or_default(user_id=actor_id)
-        providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
+        providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor, name=name, provider_type=provider_type)
     except HTTPException:
         raise
     except Exception as e:
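The provider listing accepts the same style of optional filters; a hedged example (base URL and mount path assumed as above, the provider name is hypothetical):

import requests

resp = requests.get(
    "http://localhost:8283/v1/providers/",
    params={"name": "my-anthropic", "provider_type": "anthropic"},
)
resp.raise_for_status()
print(resp.json())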
@@ -268,10 +268,11 @@ class SyncServer(Server):
         )

         # collect providers (always has Letta as a default)
-        self._enabled_providers: List[Provider] = [LettaProvider()]
+        self._enabled_providers: List[Provider] = [LettaProvider(name="letta")]
         if model_settings.openai_api_key:
             self._enabled_providers.append(
                 OpenAIProvider(
+                    name="openai",
                     api_key=model_settings.openai_api_key,
                     base_url=model_settings.openai_api_base,
                 )
@@ -279,12 +280,14 @@ class SyncServer(Server):
         if model_settings.anthropic_api_key:
             self._enabled_providers.append(
                 AnthropicProvider(
+                    name="anthropic",
                     api_key=model_settings.anthropic_api_key,
                 )
             )
         if model_settings.ollama_base_url:
             self._enabled_providers.append(
                 OllamaProvider(
+                    name="ollama",
                     base_url=model_settings.ollama_base_url,
                     api_key=None,
                     default_prompt_formatter=model_settings.default_prompt_formatter,
@@ -293,12 +296,14 @@ class SyncServer(Server):
         if model_settings.gemini_api_key:
             self._enabled_providers.append(
                 GoogleAIProvider(
+                    name="google_ai",
                     api_key=model_settings.gemini_api_key,
                 )
             )
         if model_settings.google_cloud_location and model_settings.google_cloud_project:
             self._enabled_providers.append(
                 GoogleVertexProvider(
+                    name="google_vertex",
                     google_cloud_project=model_settings.google_cloud_project,
                     google_cloud_location=model_settings.google_cloud_location,
                 )
@@ -307,6 +312,7 @@ class SyncServer(Server):
             assert model_settings.azure_api_version, "AZURE_API_VERSION is required"
             self._enabled_providers.append(
                 AzureProvider(
+                    name="azure",
                     api_key=model_settings.azure_api_key,
                     base_url=model_settings.azure_base_url,
                     api_version=model_settings.azure_api_version,
@ -315,12 +321,14 @@ class SyncServer(Server):
|
|||||||
if model_settings.groq_api_key:
|
if model_settings.groq_api_key:
|
||||||
self._enabled_providers.append(
|
self._enabled_providers.append(
|
||||||
GroqProvider(
|
GroqProvider(
|
||||||
|
name="groq",
|
||||||
api_key=model_settings.groq_api_key,
|
api_key=model_settings.groq_api_key,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if model_settings.together_api_key:
|
if model_settings.together_api_key:
|
||||||
self._enabled_providers.append(
|
self._enabled_providers.append(
|
||||||
TogetherProvider(
|
TogetherProvider(
|
||||||
|
name="together",
|
||||||
api_key=model_settings.together_api_key,
|
api_key=model_settings.together_api_key,
|
||||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||||
)
|
)
|
||||||
@ -329,6 +337,7 @@ class SyncServer(Server):
|
|||||||
# vLLM exposes both a /chat/completions and a /completions endpoint
|
# vLLM exposes both a /chat/completions and a /completions endpoint
|
||||||
self._enabled_providers.append(
|
self._enabled_providers.append(
|
||||||
VLLMCompletionsProvider(
|
VLLMCompletionsProvider(
|
||||||
|
name="vllm",
|
||||||
base_url=model_settings.vllm_api_base,
|
base_url=model_settings.vllm_api_base,
|
||||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||||
)
|
)
|
||||||
@ -338,12 +347,14 @@ class SyncServer(Server):
|
|||||||
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
||||||
self._enabled_providers.append(
|
self._enabled_providers.append(
|
||||||
VLLMChatCompletionsProvider(
|
VLLMChatCompletionsProvider(
|
||||||
|
name="vllm",
|
||||||
base_url=model_settings.vllm_api_base,
|
base_url=model_settings.vllm_api_base,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
|
if model_settings.aws_access_key and model_settings.aws_secret_access_key and model_settings.aws_region:
|
||||||
self._enabled_providers.append(
|
self._enabled_providers.append(
|
||||||
AnthropicBedrockProvider(
|
AnthropicBedrockProvider(
|
||||||
|
name="bedrock",
|
||||||
aws_region=model_settings.aws_region,
|
aws_region=model_settings.aws_region,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -355,11 +366,11 @@ class SyncServer(Server):
|
|||||||
if model_settings.lmstudio_base_url.endswith("/v1")
|
if model_settings.lmstudio_base_url.endswith("/v1")
|
||||||
else model_settings.lmstudio_base_url + "/v1"
|
else model_settings.lmstudio_base_url + "/v1"
|
||||||
)
|
)
|
||||||
self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url))
|
self._enabled_providers.append(LMStudioOpenAIProvider(name="lmstudio_openai", base_url=lmstudio_url))
|
||||||
if model_settings.deepseek_api_key:
|
if model_settings.deepseek_api_key:
|
||||||
self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
|
self._enabled_providers.append(DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key))
|
||||||
if model_settings.xai_api_key:
|
if model_settings.xai_api_key:
|
||||||
self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))
|
self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
|
||||||
|
|
||||||
# For MCP
|
# For MCP
|
||||||
"""Initialize the MCP clients (there may be multiple)"""
|
"""Initialize the MCP clients (there may be multiple)"""
|
||||||
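Every environment-backed provider is now constructed with an explicit `name`, and that name becomes the namespace for model handles. A minimal sketch of the resulting handle scheme, assuming the provider classes live in letta.schemas.providers as elsewhere in this diff and the `handle` format asserted in the tests further down:

    import os

    from letta.schemas.providers import AnthropicProvider

    provider = AnthropicProvider(name="anthropic", api_key=os.getenv("ANTHROPIC_API_KEY"))
    models = provider.list_llm_models()
    # Handles are "<provider name>/<model name>", e.g. "anthropic/claude-3-opus-20240229",
    # so two providers of the same type can coexist as long as their names differ.
    print(models[0].handle)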
@@ -1184,10 +1195,10 @@ class SyncServer(Server):
         except NoResultFound:
             raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")

-    def list_llm_models(self) -> List[LLMConfig]:
+    def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
         """List available models"""
         llm_models = []
-        for provider in self.get_enabled_providers():
+        for provider in self.get_enabled_providers(byok_only=byok_only):
             try:
                 llm_models.extend(provider.list_llm_models())
             except Exception as e:
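The new byok_only flag threads straight through to get_enabled_providers (next hunk), letting callers restrict the listing to user-created, database-backed providers. A usage sketch, assuming a SyncServer instance `server`:

    # All models, from env-configured and BYOK providers alike:
    all_models = server.list_llm_models()

    # Only models served by BYOK providers stored in the database:
    byok_models = server.list_llm_models(byok_only=True)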
@@ -1207,11 +1218,12 @@ class SyncServer(Server):
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models

-    def get_enabled_providers(self):
+    def get_enabled_providers(self, byok_only: bool = False):
+        providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
+        if byok_only:
+            return list(providers_from_db.values())
         providers_from_env = {p.name: p for p in self._enabled_providers}
-        providers_from_db = {p.name: p for p in self.provider_manager.list_providers()}
-        # Merge the two dictionaries, keeping the values from providers_from_db where conflicts occur
-        return {**providers_from_env, **providers_from_db}.values()
+        return list(providers_from_env.values()) + list(providers_from_db.values())

     @trace_method
     def get_llm_config_from_handle(
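Note the change in merge semantics: the old implementation let database providers shadow env providers that shared a name, while the new one simply concatenates the two lists, relying on the name guard in create_provider (below) to keep a custom provider from reusing its own provider type's reserved name. A self-contained toy illustration, using plain dicts in place of real Provider objects (names here are illustrative; "caren-anthropic" mirrors the test at the bottom of this diff):

    providers_from_env = {"anthropic": {"name": "anthropic", "source": "env"}}
    providers_from_db = {"caren-anthropic": {"name": "caren-anthropic", "source": "db"}}

    # New behavior: plain concatenation instead of dict-merge shadowing.
    combined = list(providers_from_env.values()) + list(providers_from_db.values())
    print([p["name"] for p in combined])  # ['anthropic', 'caren-anthropic']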
@@ -1296,7 +1308,7 @@ class SyncServer(Server):
         return embedding_config

     def get_provider_from_name(self, provider_name: str) -> Provider:
-        providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
+        providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
         elif len(providers) > 1:
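Because the lookup above now spans get_enabled_providers(), providers created at runtime resolve by name exactly like env-configured ones, e.g. (hypothetical name, matching the test at the end of this diff):

    provider = server.get_provider_from_name("caren-anthropic")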
@@ -1,6 +1,7 @@
-from typing import List, Optional
+from typing import List, Optional, Union

 from letta.orm.provider import Provider as ProviderModel
+from letta.schemas.enums import ProviderType
 from letta.schemas.providers import Provider as PydanticProvider
 from letta.schemas.providers import ProviderUpdate
 from letta.schemas.user import User as PydanticUser
@@ -18,6 +19,9 @@ class ProviderManager:
     def create_provider(self, provider: PydanticProvider, actor: PydanticUser) -> PydanticProvider:
         """Create a new provider if it doesn't already exist."""
         with self.session_maker() as session:
+            if provider.name == provider.provider_type.value:
+                raise ValueError("Provider name must be unique and different from provider type")
+
             # Assign the organization id based on the actor
             provider.organization_id = actor.organization_id

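The new guard keeps a custom provider's name distinct from its provider type's reserved name, which is what makes the plain list concatenation in get_enabled_providers safe for the built-in defaults. A sketch of both sides of the check, with field values taken from the tests in this diff:

    import os

    from letta.schemas.enums import ProviderType
    from letta.schemas.providers import Provider as PydanticProvider

    # Rejected: name "anthropic" equals ProviderType.anthropic.value.
    bad = PydanticProvider(name="anthropic", provider_type=ProviderType.anthropic, api_key=os.getenv("ANTHROPIC_API_KEY"))
    # server.provider_manager.create_provider(provider=bad, actor=actor)  # -> ValueError

    # Accepted: a distinct, user-chosen name.
    good = PydanticProvider(name="caren-anthropic", provider_type=ProviderType.anthropic, api_key=os.getenv("ANTHROPIC_API_KEY"))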
@@ -59,29 +63,36 @@ class ProviderManager:
             session.commit()

     @enforce_types
-    def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
+    def list_providers(
+        self,
+        name: Optional[str] = None,
+        provider_type: Optional[ProviderType] = None,
+        after: Optional[str] = None,
+        limit: Optional[int] = 50,
+        actor: PydanticUser = None,
+    ) -> List[PydanticProvider]:
         """List all providers with optional pagination."""
+        filter_kwargs = {}
+        if name:
+            filter_kwargs["name"] = name
+        if provider_type:
+            filter_kwargs["provider_type"] = provider_type
         with self.session_maker() as session:
             providers = ProviderModel.list(
                 db_session=session,
                 after=after,
                 limit=limit,
                 actor=actor,
+                **filter_kwargs,
             )
             return [provider.to_pydantic() for provider in providers]

     @enforce_types
-    def get_anthropic_override_provider_id(self) -> Optional[str]:
-        """Helper function to fetch custom anthropic provider id for v0 BYOK feature"""
-        anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
-        if len(anthropic_provider) != 0:
-            return anthropic_provider[0].id
-        return None
+    def get_provider_id_from_name(self, provider_name: Union[str, None]) -> Optional[str]:
+        providers = self.list_providers(name=provider_name)
+        return providers[0].id if providers else None

     @enforce_types
-    def get_anthropic_override_key(self) -> Optional[str]:
-        """Helper function to fetch custom anthropic key for v0 BYOK feature"""
-        anthropic_provider = [provider for provider in self.list_providers() if provider.name == "anthropic"]
-        if len(anthropic_provider) != 0:
-            return anthropic_provider[0].api_key
-        return None
+    def get_override_key(self, provider_name: Union[str, None]) -> Optional[str]:
+        providers = self.list_providers(name=provider_name)
+        return providers[0].api_key if providers else None
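These two helpers generalize the v0 anthropic-only overrides: any named provider's id and key can now be fetched, and both return None when no matching provider exists. Hypothetical calls against a ProviderManager instance:

    provider_id = provider_manager.get_provider_id_from_name("caren-anthropic")
    override_key = provider_manager.get_override_key("caren-anthropic")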
@@ -105,7 +105,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
     agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)

     llm_client = LLMClient.create(
-        provider=agent_state.llm_config.model_endpoint_type,
+        provider_name=agent_state.llm_config.provider_name,
+        provider_type=agent_state.llm_config.model_endpoint_type,
     )
     if llm_client:
         response = llm_client.send_llm_request(
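LLMClient.create now takes the provider's user-facing name alongside its endpoint type, so the client can resolve a BYOK provider's stored key at request time. The new call shape, with both fields carried on LLMConfig in this diff:

    llm_client = LLMClient.create(
        provider_name=agent_state.llm_config.provider_name,         # e.g. "caren-anthropic"
        provider_type=agent_state.llm_config.model_endpoint_type,   # e.g. "anthropic"
    )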
@@ -19,97 +19,166 @@ from letta.settings import model_settings
 def test_openai():
     api_key = os.getenv("OPENAI_API_KEY")
     assert api_key is not None
-    provider = OpenAIProvider(api_key=api_key, base_url=model_settings.openai_api_base)
+    provider = OpenAIProvider(
+        name="openai",
+        api_key=api_key,
+        base_url=model_settings.openai_api_base,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_deepseek():
     api_key = os.getenv("DEEPSEEK_API_KEY")
     assert api_key is not None
-    provider = DeepSeekProvider(api_key=api_key)
+    provider = DeepSeekProvider(
+        name="deepseek",
+        api_key=api_key,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"


 def test_anthropic():
     api_key = os.getenv("ANTHROPIC_API_KEY")
     assert api_key is not None
-    provider = AnthropicProvider(api_key=api_key)
+    provider = AnthropicProvider(
+        name="anthropic",
+        api_key=api_key,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"


 def test_groq():
-    provider = GroqProvider(api_key=os.getenv("GROQ_API_KEY"))
+    provider = GroqProvider(
+        name="groq",
+        api_key=os.getenv("GROQ_API_KEY"),
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"


 def test_azure():
-    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
+    provider = AzureProvider(
+        name="azure",
+        api_key=os.getenv("AZURE_API_KEY"),
+        base_url=os.getenv("AZURE_BASE_URL"),
+    )
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

-    embed_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embed_models])
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_ollama():
     base_url = os.getenv("OLLAMA_BASE_URL")
     assert base_url is not None
-    provider = OllamaProvider(base_url=base_url, default_prompt_formatter=model_settings.default_prompt_formatter, api_key=None)
+    provider = OllamaProvider(
+        name="ollama",
+        base_url=base_url,
+        default_prompt_formatter=model_settings.default_prompt_formatter,
+        api_key=None,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print(embedding_models)
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_googleai():
     api_key = os.getenv("GEMINI_API_KEY")
     assert api_key is not None
-    provider = GoogleAIProvider(api_key=api_key)
+    provider = GoogleAIProvider(
+        name="google_ai",
+        api_key=api_key,
+    )
     models = provider.list_llm_models()
-    print(models)
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

-    provider.list_embedding_models()
+    embedding_models = provider.list_embedding_models()
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_google_vertex():
-    provider = GoogleVertexProvider(google_cloud_project=os.getenv("GCP_PROJECT_ID"), google_cloud_location=os.getenv("GCP_REGION"))
+    provider = GoogleVertexProvider(
+        name="google_vertex",
+        google_cloud_project=os.getenv("GCP_PROJECT_ID"),
+        google_cloud_location=os.getenv("GCP_REGION"),
+    )
     models = provider.list_llm_models()
-    print(models)
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embedding_models])
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_mistral():
-    provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
+    provider = MistralProvider(
+        name="mistral",
+        api_key=os.getenv("MISTRAL_API_KEY"),
+    )
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"


 def test_together():
-    provider = TogetherProvider(api_key=os.getenv("TOGETHER_API_KEY"), default_prompt_formatter="chatml")
+    provider = TogetherProvider(
+        name="together",
+        api_key=os.getenv("TOGETHER_API_KEY"),
+        default_prompt_formatter="chatml",
+    )
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embedding_models])
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


 def test_anthropic_bedrock():
     from letta.settings import model_settings

-    provider = AnthropicBedrockProvider(aws_region=model_settings.aws_region)
+    provider = AnthropicBedrockProvider(name="bedrock", aws_region=model_settings.aws_region)
     models = provider.list_llm_models()
-    print([m.model for m in models])
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"

     embedding_models = provider.list_embedding_models()
-    print([m.embedding_model for m in embedding_models])
+    assert len(embedding_models) > 0
+    assert embedding_models[0].handle == f"{provider.name}/{embedding_models[0].embedding_model}"


+def test_custom_anthropic():
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    assert api_key is not None
+    provider = AnthropicProvider(
+        name="custom_anthropic",
+        api_key=api_key,
+    )
+    models = provider.list_llm_models()
+    assert len(models) > 0
+    assert models[0].handle == f"{provider.name}/{models[0].model}"
+
+
 # def test_vllm():
@@ -13,7 +13,7 @@ import letta.utils as utils
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
 from letta.orm import Provider, Step
 from letta.schemas.block import CreateBlock
-from letta.schemas.enums import MessageRole
+from letta.schemas.enums import MessageRole, ProviderType
 from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers import Provider as PydanticProvider
@@ -1226,7 +1226,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     actor = server.user_manager.get_user_or_default(user_id)
     provider = server.provider_manager.create_provider(
         provider=PydanticProvider(
-            name="anthropic",
+            name="caren-anthropic",
+            provider_type=ProviderType.anthropic,
             api_key=os.getenv("ANTHROPIC_API_KEY"),
         ),
         actor=actor,
@@ -1234,8 +1235,8 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
     agent = server.create_agent(
         request=CreateAgent(
             memory_blocks=[],
-            model="anthropic/claude-3-opus-20240229",
-            context_window_limit=200000,
+            model="caren-anthropic/claude-3-opus-20240229",
+            context_window_limit=100000,
             embedding="openai/text-embedding-ada-002",
         ),
         actor=actor,
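End to end, the BYOK flow exercised by this test is: register a named provider, then address its models through the `<provider name>/<model>` handle. A condensed sketch, assuming a SyncServer instance `server` and `actor` as in the test above, with import paths assumed from the repo's schema layout:

    import os

    from letta.schemas.agent import CreateAgent
    from letta.schemas.enums import ProviderType
    from letta.schemas.providers import Provider as PydanticProvider

    provider = server.provider_manager.create_provider(
        provider=PydanticProvider(
            name="caren-anthropic",                # must differ from "anthropic"
            provider_type=ProviderType.anthropic,
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        ),
        actor=actor,
    )

    # Models from the new provider are addressable by its handle namespace:
    agent = server.create_agent(
        request=CreateAgent(
            memory_blocks=[],
            model="caren-anthropic/claude-3-opus-20240229",
            context_window_limit=100000,
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )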