diff --git a/letta/agent.py b/letta/agent.py index 6186a6c7d..5720a4e68 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -30,7 +30,7 @@ from letta.persistence_manager import LocalStateManager from letta.schemas.agent import AgentState, AgentStepResponse from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole, OptionState +from letta.schemas.enums import MessageRole from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, UpdateMessage from letta.schemas.openai.chat_completion_response import ChatCompletionResponse @@ -463,15 +463,14 @@ class Agent(BaseAgent): function_call: str = "auto", first_message: bool = False, # hint stream: bool = False, # TODO move to config? - inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT, ) -> ChatCompletionResponse: """Get response from LLM API""" try: response = create( # agent_state=self.agent_state, llm_config=self.agent_state.llm_config, - user_id=self.agent_state.user_id, messages=message_sequence, + user_id=self.agent_state.user_id, functions=self.functions, functions_python=self.functions_python, function_call=function_call, @@ -480,8 +479,6 @@ class Agent(BaseAgent): # streaming stream=stream, stream_interface=self.interface, - # putting inner thoughts in func args or not - inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option, ) if len(response.choices) == 0 or response.choices[0] is None: @@ -822,7 +819,6 @@ class Agent(BaseAgent): first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, stream: bool = False, # TODO move to config? - inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT, ms: Optional[MetadataStore] = None, ) -> AgentStepResponse: """Runs a single step in the agent loop (generates at most one LLM call)""" @@ -861,10 +857,7 @@ class Agent(BaseAgent): counter = 0 while True: response = self._get_ai_reply( - message_sequence=input_message_sequence, - first_message=True, # passed through to the prompt formatter - stream=stream, - inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option, + message_sequence=input_message_sequence, first_message=True, stream=stream # passed through to the prompt formatter ) if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break @@ -877,7 +870,6 @@ class Agent(BaseAgent): response = self._get_ai_reply( message_sequence=input_message_sequence, stream=stream, - inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option, ) # Step 3: check if LLM wanted to call a function @@ -954,7 +946,6 @@ class Agent(BaseAgent): first_message_retry_limit=first_message_retry_limit, skip_verify=skip_verify, stream=stream, - inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option, ms=ms, ) diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 964517179..3a21f2c75 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -320,7 +320,6 @@ def run( ms=ms, no_verify=no_verify, stream=stream, - inner_thoughts_in_kwargs=no_content, ) # TODO: add back no_verify diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index e2385d5de..28095b569 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -53,7 +53,7 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict: return MODEL_LIST -def convert_tools_to_anthropic_format(tools: List[Tool], inner_thoughts_in_kwargs: Optional[bool] = True) -> List[dict]: +def 
convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]: """See: https://docs.anthropic.com/claude/docs/tool-use OpenAI style: diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 05b36f3b7..f35c9a918 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -6,7 +6,6 @@ from typing import Any, List, Union import requests from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING -from letta.schemas.enums import OptionState from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice from letta.utils import json_dumps, printd @@ -200,17 +199,3 @@ def is_context_overflow_error(exception: Union[requests.exceptions.RequestExcept # Generic fail else: return False - - -def derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option: OptionState, model: str): - if inner_thoughts_in_kwargs_option == OptionState.DEFAULT: - # model that are known to not use `content` fields on tool calls - inner_thoughts_in_kwargs = "gpt-4o" in model or "gpt-4-turbo" in model or "gpt-3.5-turbo" in model - else: - inner_thoughts_in_kwargs = True if inner_thoughts_in_kwargs_option == OptionState.YES else False - - if not isinstance(inner_thoughts_in_kwargs, bool): - warnings.warn(f"Bad type detected: {type(inner_thoughts_in_kwargs)}") - inner_thoughts_in_kwargs = bool(inner_thoughts_in_kwargs) - - return inner_thoughts_in_kwargs diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 9864fafe1..7408b25b6 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -1,4 +1,3 @@ -import os import random import time from typing import List, Optional, Union @@ -8,14 +7,12 @@ import requests from letta.constants import CLI_WARNING_PREFIX from letta.llm_api.anthropic import anthropic_chat_completions_request from letta.llm_api.azure_openai import azure_openai_chat_completions_request -from letta.llm_api.cohere import cohere_chat_completions_request from letta.llm_api.google_ai import ( convert_tools_to_google_ai_format, google_ai_chat_completions_request, ) from letta.llm_api.helpers import ( add_inner_thoughts_to_functions, - derive_inner_thoughts_in_kwargs, unpack_all_inner_thoughts_from_kwargs, ) from letta.llm_api.openai import ( @@ -28,7 +25,6 @@ from letta.local_llm.constants import ( INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, ) -from letta.schemas.enums import OptionState from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ( @@ -120,9 +116,6 @@ def create( # streaming? stream: bool = False, stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None, - # TODO move to llm_config? 
- # if unspecified (None), default to something we've tested - inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT, max_tokens: Optional[int] = None, model_settings: Optional[dict] = None, # TODO: eventually pass from server ) -> ChatCompletionResponse: @@ -146,10 +139,7 @@ def create( # only is a problem if we are *not* using an openai proxy raise ValueError(f"OpenAI key is missing from letta config file") - inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option, model=llm_config.model) - data = build_openai_chat_completions_request( - llm_config, messages, user_id, functions, function_call, use_tool_naming, inner_thoughts_in_kwargs, max_tokens - ) + data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens) if stream: # Client requested token streaming data.stream = True @@ -176,7 +166,7 @@ def create( if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.stream_end() - if inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) return response @@ -198,9 +188,8 @@ def create( # Set the llm config model_endpoint from model_settings # For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config llm_config.model_endpoint = model_settings.azure_base_url - inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option, llm_config.model) chat_completion_request = build_openai_chat_completions_request( - llm_config, messages, user_id, functions, function_call, use_tool_naming, inner_thoughts_in_kwargs, max_tokens + llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens ) response = azure_openai_chat_completions_request( @@ -210,7 +199,7 @@ def create( chat_completion_request=chat_completion_request, ) - if inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) return response @@ -224,7 +213,7 @@ def create( if functions is not None: tools = [{"type": "function", "function": f} for f in functions] tools = [Tool(**t) for t in tools] - tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=True) + tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) else: tools = None @@ -237,7 +226,7 @@ def create( contents=[m.to_google_ai_dict() for m in messages], tools=tools, ), - inner_thoughts_in_kwargs=True, + inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, ) elif llm_config.model_endpoint_type == "anthropic": @@ -260,32 +249,32 @@ def create( ), ) - elif llm_config.model_endpoint_type == "cohere": - if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") - if not use_tool_naming: - raise NotImplementedError("Only tool calling supported on Cohere API requests") - - if functions is not None: - tools = [{"type": "function", "function": f} for f in functions] - tools = [Tool(**t) for t in tools] - else: - tools = None - - return cohere_chat_completions_request( - # url=llm_config.model_endpoint, - url="https://api.cohere.ai/v1", # TODO - api_key=os.getenv("COHERE_API_KEY"), # TODO remove - chat_completion_request=ChatCompletionRequest( - 
model="command-r-plus", # TODO - messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], - tools=tools, - tool_choice=function_call, - # user=str(user_id), - # NOTE: max_tokens is required for Anthropic API - # max_tokens=1024, # TODO make dynamic - ), - ) + # elif llm_config.model_endpoint_type == "cohere": + # if stream: + # raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + # if not use_tool_naming: + # raise NotImplementedError("Only tool calling supported on Cohere API requests") + # + # if functions is not None: + # tools = [{"type": "function", "function": f} for f in functions] + # tools = [Tool(**t) for t in tools] + # else: + # tools = None + # + # return cohere_chat_completions_request( + # # url=llm_config.model_endpoint, + # url="https://api.cohere.ai/v1", # TODO + # api_key=os.getenv("COHERE_API_KEY"), # TODO remove + # chat_completion_request=ChatCompletionRequest( + # model="command-r-plus", # TODO + # messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], + # tools=tools, + # tool_choice=function_call, + # # user=str(user_id), + # # NOTE: max_tokens is required for Anthropic API + # # max_tokens=1024, # TODO make dynamic + # ), + # ) elif llm_config.model_endpoint_type == "groq": if stream: @@ -295,8 +284,7 @@ def create( raise ValueError(f"Groq key is missing from letta config file") # force to true for groq, since they don't support 'content' is non-null - inner_thoughts_in_kwargs = True - if inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: functions = add_inner_thoughts_to_functions( functions=functions, inner_thoughts_key=INNER_THOUGHTS_KWARG, @@ -306,7 +294,7 @@ def create( tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None data = ChatCompletionRequest( model=llm_config.model, - messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs) for m in messages], + messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages], tools=tools, tool_choice=function_call, user=str(user_id), @@ -335,7 +323,7 @@ def create( if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.stream_end() - if inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) return response diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 3d203fe2c..55370de5a 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -105,10 +105,9 @@ def build_openai_chat_completions_request( functions: Optional[list], function_call: str, use_tool_naming: bool, - inner_thoughts_in_kwargs: bool, max_tokens: Optional[int], ) -> ChatCompletionRequest: - if inner_thoughts_in_kwargs: + if llm_config.put_inner_thoughts_in_kwargs: functions = add_inner_thoughts_to_functions( functions=functions, inner_thoughts_key=INNER_THOUGHTS_KWARG, @@ -116,7 +115,7 @@ def build_openai_chat_completions_request( ) openai_message_list = [ - cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs)) for m in messages + cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs)) for m in messages ] if llm_config.model: model = llm_config.model diff --git a/letta/main.py b/letta/main.py index f16eb8951..bca3c34a9 100644 --- a/letta/main.py +++ 
b/letta/main.py
@@ -20,7 +20,6 @@
 from letta.cli.cli_load import app as load_app
 from letta.config import LettaConfig
 from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
 from letta.metadata import MetadataStore
-from letta.schemas.enums import OptionState
 # from letta.interface import CLIInterface as interface  # for printing to terminal
 from letta.streaming_interface import AgentRefreshStreamingInterface
@@ -64,7 +63,6 @@ def run_agent_loop(
     no_verify: bool = False,
     strip_ui: bool = False,
     stream: bool = False,
-    inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
 ):
     if isinstance(letta_agent.interface, AgentRefreshStreamingInterface):
         # letta_agent.interface.toggle_streaming(on=stream)
@@ -369,7 +367,6 @@ def run_agent_loop(
                     first_message=False,
                     skip_verify=no_verify,
                     stream=stream,
-                    inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
                     ms=ms,
                 )
             else:
@@ -378,7 +375,6 @@ def run_agent_loop(
                     first_message=False,
                     skip_verify=no_verify,
                     stream=stream,
-                    inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
                     ms=ms,
                 )
             new_messages = step_response.messages
diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py
index ffdb60cff..867eaa9a7 100644
--- a/letta/schemas/llm_config.py
+++ b/letta/schemas/llm_config.py
@@ -1,6 +1,6 @@
 from typing import Literal, Optional

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, root_validator


 class LLMConfig(BaseModel):
@@ -13,6 +13,7 @@ class LLMConfig(BaseModel):
         model_endpoint (str): The endpoint for the model.
         model_wrapper (str): The wrapper for the model. This is used to wrap additional text around the input/output of the model. This is useful for text-to-text completions, such as the Completions API in OpenAI.
         context_window (int): The context window size for the model.
+        put_inner_thoughts_in_kwargs (bool): If True, the model's inner thoughts are sent as a kwarg on the function call instead of in the message content, which helps with both function-calling performance and the generation of inner thoughts.
     """

     # TODO: 🤮 don't default to a vendor! bug city!
@@ -38,10 +39,32 @@ class LLMConfig(BaseModel):
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
+    put_inner_thoughts_in_kwargs: Optional[bool] = Field(
+        True,
+        description="If True, the model's inner thoughts are sent as a kwarg on the function call instead of in the message content, which helps with both function-calling performance and the generation of inner thoughts.",
+    )

     # FIXME hack to silence pydantic protected namespace warning
     model_config = ConfigDict(protected_namespaces=())

+    @root_validator(pre=True)
+    def set_default_put_inner_thoughts(cls, values):
+        """
+        Dynamically set the default for put_inner_thoughts_in_kwargs based on the model field,
+        falling back to True if no specific rule is defined.
+        """
+        model = values.get("model")
+
+        # Models where put_inner_thoughts_in_kwargs should default to False
+        # (for now, just gpt-4)
+        avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
+
+        # Only fill in a default when the caller did not provide a value
+        if values.get("put_inner_thoughts_in_kwargs") is None:
+            values["put_inner_thoughts_in_kwargs"] = model not in avoid_put_inner_thoughts_in_kwargs
+
+        return values
+
     @classmethod
     def default_config(cls, model_name: str):
         if model_name == "gpt-4":
diff --git a/tests/configs/llm_model_configs/azure-gpt-4o-mini.json b/tests/configs/llm_model_configs/azure-gpt-4o-mini.json
index 323b2cae9..b91e9e6c1 100644
--- a/tests/configs/llm_model_configs/azure-gpt-4o-mini.json
+++ b/tests/configs/llm_model_configs/azure-gpt-4o-mini.json
@@ -2,5 +2,6 @@
     "context_window": 128000,
     "model": "gpt-4o-mini",
     "model_endpoint_type": "azure",
-    "model_wrapper": null
+    "model_wrapper": null,
+    "put_inner_thoughts_in_kwargs": true
 }
diff --git a/tests/configs/llm_model_configs/claude-3-opus.json b/tests/configs/llm_model_configs/claude-3-opus.json
index 6281aa964..9516b8708 100644
--- a/tests/configs/llm_model_configs/claude-3-opus.json
+++ b/tests/configs/llm_model_configs/claude-3-opus.json
@@ -3,5 +3,6 @@
     "model": "claude-3-opus-20240229",
     "model_endpoint_type": "anthropic",
     "model_endpoint": "https://api.anthropic.com/v1",
-    "model_wrapper": null
+    "model_wrapper": null,
+    "put_inner_thoughts_in_kwargs": true
 }
diff --git a/tests/configs/llm_model_configs/gemini-pro.json b/tests/configs/llm_model_configs/gemini-pro.json
index 5c425b6d1..e59252822 100644
--- a/tests/configs/llm_model_configs/gemini-pro.json
+++ b/tests/configs/llm_model_configs/gemini-pro.json
@@ -3,5 +3,6 @@
     "model": "gemini-1.5-pro-latest",
     "model_endpoint_type": "google_ai",
     "model_endpoint": "https://generativelanguage.googleapis.com",
-    "model_wrapper": null
+    "model_wrapper": null,
+    "put_inner_thoughts_in_kwargs": true
 }
diff --git a/tests/configs/llm_model_configs/gpt-4.json b/tests/configs/llm_model_configs/gpt-4.json
index c572428e4..dedc8cec5 100644
--- a/tests/configs/llm_model_configs/gpt-4.json
+++ b/tests/configs/llm_model_configs/gpt-4.json
@@ -3,5 +3,6 @@
     "model": "gpt-4",
     "model_endpoint_type": "openai",
     "model_endpoint": "https://api.openai.com/v1",
-    "model_wrapper": null
+    "model_wrapper": null,
+    "put_inner_thoughts_in_kwargs": false
 }
diff --git a/tests/configs/llm_model_configs/groq.json b/tests/configs/llm_model_configs/groq.json
index 62cc875b3..5f5c92f97 100644
--- a/tests/configs/llm_model_configs/groq.json
+++ b/tests/configs/llm_model_configs/groq.json
@@ -3,5 +3,6 @@
     "model": "llama-3.1-70b-versatile",
     "model_endpoint_type": "groq",
     "model_endpoint": "https://api.groq.com/openai/v1",
-    "model_wrapper": null
+    "model_wrapper": null,
+    "put_inner_thoughts_in_kwargs": true
 }
diff --git a/tests/configs/llm_model_configs/letta-hosted.json b/tests/configs/llm_model_configs/letta-hosted.json
index 3ba968226..a0367c469 100644
--- a/tests/configs/llm_model_configs/letta-hosted.json
+++ b/tests/configs/llm_model_configs/letta-hosted.json
@@ -2,5 +2,6 @@
     "context_window": 16384,
     "model_endpoint_type": "openai",
     "model_endpoint": "https://inference.memgpt.ai",
-    "model": "memgpt-openai"
+    "model": "memgpt-openai",
+    "put_inner_thoughts_in_kwargs": true
 }
diff --git a/tests/configs/llm_model_configs/ollama.json b/tests/configs/llm_model_configs/ollama.json
index d18a4e772..71a7f9a42 100644
--- a/tests/configs/llm_model_configs/ollama.json
+++ 
b/tests/configs/llm_model_configs/ollama.json @@ -2,5 +2,6 @@ "context_window": 8192, "model_endpoint_type": "ollama", "model_endpoint": "http://127.0.0.1:11434", - "model": "dolphin2.2-mistral:7b-q6_K" + "model": "dolphin2.2-mistral:7b-q6_K", + "put_inner_thoughts_in_kwargs": true } diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 00e288973..1935ea4b8 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -3,11 +3,7 @@ import logging import uuid from typing import Callable, List, Optional, Union -from letta.llm_api.helpers import ( - derive_inner_thoughts_in_kwargs, - unpack_inner_thoughts_from_kwargs, -) -from letta.schemas.enums import OptionState +from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -130,10 +126,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search" assert_contains_valid_function_call(choice.message, validator_func) - # Get inner_thoughts_in_kwargs - inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(OptionState.DEFAULT, agent_state.llm_config.model) # Assert that the message has an inner monologue - assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs) + assert_contains_correct_inner_monologue(choice, agent_state.llm_config.put_inner_thoughts_in_kwargs) return response
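

A minimal usage sketch (illustrative only, not part of the patch) of how the new put_inner_thoughts_in_kwargs default behaves, following the root_validator added to LLMConfig above; the field values here are arbitrary examples:

from letta.schemas.llm_config import LLMConfig

# "gpt-4" is on the validator's avoid-list, so the flag defaults to False when unset
config = LLMConfig(model="gpt-4", model_endpoint_type="openai", context_window=8192)
assert config.put_inner_thoughts_in_kwargs is False

# any other model falls back to the default of True
config = LLMConfig(model="gpt-4o-mini", model_endpoint_type="openai", context_window=128000)
assert config.put_inner_thoughts_in_kwargs is True

# an explicitly passed value is never overridden by the model-based default
config = LLMConfig(model="gpt-4", model_endpoint_type="openai", context_window=8192, put_inner_thoughts_in_kwargs=True)
assert config.put_inner_thoughts_in_kwargs is True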