feat: Add put_inner_thoughts_in_kwargs as a config setting for the LLM (#1902)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
Authored by Matthew Zhou on 2024-10-17 15:54:03 -07:00, committed by GitHub
Parent: 960303615d
Commit: 01fdf8684c
16 changed files with 81 additions and 99 deletions

View File

@@ -30,7 +30,7 @@ from letta.persistence_manager import LocalStateManager
 from letta.schemas.agent import AgentState, AgentStepResponse
 from letta.schemas.block import Block
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import MessageRole, OptionState
+from letta.schemas.enums import MessageRole
 from letta.schemas.memory import ContextWindowOverview, Memory
 from letta.schemas.message import Message, UpdateMessage
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
@@ -463,15 +463,14 @@ class Agent(BaseAgent):
         function_call: str = "auto",
         first_message: bool = False,  # hint
         stream: bool = False,  # TODO move to config?
-        inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
     ) -> ChatCompletionResponse:
         """Get response from LLM API"""
         try:
             response = create(
                 # agent_state=self.agent_state,
                 llm_config=self.agent_state.llm_config,
-                user_id=self.agent_state.user_id,
                 messages=message_sequence,
+                user_id=self.agent_state.user_id,
                 functions=self.functions,
                 functions_python=self.functions_python,
                 function_call=function_call,
@@ -480,8 +479,6 @@ class Agent(BaseAgent):
                 # streaming
                 stream=stream,
                 stream_interface=self.interface,
-                # putting inner thoughts in func args or not
-                inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
             )
             if len(response.choices) == 0 or response.choices[0] is None:
@@ -822,7 +819,6 @@ class Agent(BaseAgent):
         first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
         skip_verify: bool = False,
         stream: bool = False,  # TODO move to config?
-        inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
         ms: Optional[MetadataStore] = None,
     ) -> AgentStepResponse:
         """Runs a single step in the agent loop (generates at most one LLM call)"""
@@ -861,10 +857,7 @@ class Agent(BaseAgent):
             counter = 0
             while True:
                 response = self._get_ai_reply(
-                    message_sequence=input_message_sequence,
-                    first_message=True,  # passed through to the prompt formatter
-                    stream=stream,
-                    inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
+                    message_sequence=input_message_sequence, first_message=True, stream=stream  # passed through to the prompt formatter
                 )
                 if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
                     break
@@ -877,7 +870,6 @@ class Agent(BaseAgent):
             response = self._get_ai_reply(
                 message_sequence=input_message_sequence,
                 stream=stream,
-                inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
             )

         # Step 3: check if LLM wanted to call a function
@@ -954,7 +946,6 @@ class Agent(BaseAgent):
                 first_message_retry_limit=first_message_retry_limit,
                 skip_verify=skip_verify,
                 stream=stream,
-                inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
                 ms=ms,
             )

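Note: after these hunks, callers no longer thread an OptionState through step(); the behavior is read from the agent's LLM config. A minimal usage sketch (assumed client-side call; the agent construction, user_message, and the exact step() signature are illustrative, not from this diff):

    # The flag now lives on the config object rather than in the call signature.
    agent.agent_state.llm_config.put_inner_thoughts_in_kwargs = True
    step_response = agent.step(user_message, first_message=False, stream=False)  # no inner_thoughts_in_kwargs_option
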
View File

@@ -320,7 +320,6 @@ def run(
             ms=ms,
             no_verify=no_verify,
             stream=stream,
-            inner_thoughts_in_kwargs=no_content,
         )  # TODO: add back no_verify

View File

@@ -53,7 +53,7 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
     return MODEL_LIST

-def convert_tools_to_anthropic_format(tools: List[Tool], inner_thoughts_in_kwargs: Optional[bool] = True) -> List[dict]:
+def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
     """See: https://docs.anthropic.com/claude/docs/tool-use

     OpenAI style:

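Note: with the parameter removed, the converter presumably hardcodes its previous default of True. For orientation, a sketch of the OpenAI-to-Anthropic tool conversion this function performs, assuming the standard Anthropic tool-use schema (the concrete tool is illustrative, not from the repo):

    # OpenAI-style function schema (illustrative):
    openai_tool = {
        "name": "send_message",
        "description": "Send a message to the user.",
        "parameters": {
            "type": "object",
            "properties": {"message": {"type": "string"}},
            "required": ["message"],
        },
    }

    # Anthropic tool use expects "input_schema" where OpenAI has "parameters":
    anthropic_tool = {
        "name": openai_tool["name"],
        "description": openai_tool["description"],
        "input_schema": openai_tool["parameters"],
    }
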
View File

@@ -6,7 +6,6 @@ from typing import Any, List, Union
 import requests
 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
-from letta.schemas.enums import OptionState
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
 from letta.utils import json_dumps, printd
@@ -200,17 +199,3 @@ def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
     # Generic fail
     else:
         return False
-
-def derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option: OptionState, model: str):
-    if inner_thoughts_in_kwargs_option == OptionState.DEFAULT:
-        # model that are known to not use `content` fields on tool calls
-        inner_thoughts_in_kwargs = "gpt-4o" in model or "gpt-4-turbo" in model or "gpt-3.5-turbo" in model
-    else:
-        inner_thoughts_in_kwargs = True if inner_thoughts_in_kwargs_option == OptionState.YES else False
-    if not isinstance(inner_thoughts_in_kwargs, bool):
-        warnings.warn(f"Bad type detected: {type(inner_thoughts_in_kwargs)}")
-        inner_thoughts_in_kwargs = bool(inner_thoughts_in_kwargs)
-    return inner_thoughts_in_kwargs

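Note: the deleted helper derived the flag per call from an OptionState plus substring checks on the model name. After this commit the equivalent decision is made once, by the root_validator added to LLMConfig (see the letta/schemas/llm_config.py hunk below). A sketch of the replacement flow (constructor fields follow that class definition; other required fields may vary):

    from letta.schemas.llm_config import LLMConfig

    # The pre-validator fills the default at construction time:
    config = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
    )
    assert config.put_inner_thoughts_in_kwargs is True  # fallback default; only "gpt-4" maps to False
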
View File

@@ -1,4 +1,3 @@
-import os
 import random
 import time
 from typing import List, Optional, Union
@@ -8,14 +7,12 @@ import requests
 from letta.constants import CLI_WARNING_PREFIX
 from letta.llm_api.anthropic import anthropic_chat_completions_request
 from letta.llm_api.azure_openai import azure_openai_chat_completions_request
-from letta.llm_api.cohere import cohere_chat_completions_request
 from letta.llm_api.google_ai import (
     convert_tools_to_google_ai_format,
     google_ai_chat_completions_request,
 )
 from letta.llm_api.helpers import (
     add_inner_thoughts_to_functions,
-    derive_inner_thoughts_in_kwargs,
     unpack_all_inner_thoughts_from_kwargs,
 )
 from letta.llm_api.openai import (
@@ -28,7 +25,6 @@ from letta.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
 )
-from letta.schemas.enums import OptionState
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import (
@@ -120,9 +116,6 @@ def create(
     # streaming?
     stream: bool = False,
     stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
-    # TODO move to llm_config?
-    # if unspecified (None), default to something we've tested
-    inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
     max_tokens: Optional[int] = None,
     model_settings: Optional[dict] = None,  # TODO: eventually pass from server
 ) -> ChatCompletionResponse:
@@ -146,10 +139,7 @@ def create(
                 # only is a problem if we are *not* using an openai proxy
                 raise ValueError(f"OpenAI key is missing from letta config file")
-        inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option, model=llm_config.model)
-        data = build_openai_chat_completions_request(
-            llm_config, messages, user_id, functions, function_call, use_tool_naming, inner_thoughts_in_kwargs, max_tokens
-        )
+        data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
         if stream:  # Client requested token streaming
             data.stream = True
@@ -176,7 +166,7 @@ def create(
             if isinstance(stream_interface, AgentChunkStreamingInterface):
                 stream_interface.stream_end()
-        if inner_thoughts_in_kwargs:
+        if llm_config.put_inner_thoughts_in_kwargs:
             response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
         return response
@@ -198,9 +188,8 @@ def create(
         # Set the llm config model_endpoint from model_settings
         # For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
         llm_config.model_endpoint = model_settings.azure_base_url
-        inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option, llm_config.model)
         chat_completion_request = build_openai_chat_completions_request(
-            llm_config, messages, user_id, functions, function_call, use_tool_naming, inner_thoughts_in_kwargs, max_tokens
+            llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
         )
         response = azure_openai_chat_completions_request(
@@ -210,7 +199,7 @@ def create(
             chat_completion_request=chat_completion_request,
         )
-        if inner_thoughts_in_kwargs:
+        if llm_config.put_inner_thoughts_in_kwargs:
             response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
         return response
@@ -224,7 +213,7 @@ def create(
         if functions is not None:
             tools = [{"type": "function", "function": f} for f in functions]
             tools = [Tool(**t) for t in tools]
-            tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=True)
+            tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs)
         else:
             tools = None
@@ -237,7 +226,7 @@ def create(
                 contents=[m.to_google_ai_dict() for m in messages],
                 tools=tools,
             ),
-            inner_thoughts_in_kwargs=True,
+            inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
         )

     elif llm_config.model_endpoint_type == "anthropic":
@@ -260,32 +249,32 @@ def create(
             ),
         )
-    elif llm_config.model_endpoint_type == "cohere":
-        if stream:
-            raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
-        if not use_tool_naming:
-            raise NotImplementedError("Only tool calling supported on Cohere API requests")
-        if functions is not None:
-            tools = [{"type": "function", "function": f} for f in functions]
-            tools = [Tool(**t) for t in tools]
-        else:
-            tools = None
-        return cohere_chat_completions_request(
-            # url=llm_config.model_endpoint,
-            url="https://api.cohere.ai/v1",  # TODO
-            api_key=os.getenv("COHERE_API_KEY"),  # TODO remove
-            chat_completion_request=ChatCompletionRequest(
-                model="command-r-plus",  # TODO
-                messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
-                tools=tools,
-                tool_choice=function_call,
-                # user=str(user_id),
-                # NOTE: max_tokens is required for Anthropic API
-                # max_tokens=1024,  # TODO make dynamic
-            ),
-        )
+    # elif llm_config.model_endpoint_type == "cohere":
+    #     if stream:
+    #         raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
+    #     if not use_tool_naming:
+    #         raise NotImplementedError("Only tool calling supported on Cohere API requests")
+    #
+    #     if functions is not None:
+    #         tools = [{"type": "function", "function": f} for f in functions]
+    #         tools = [Tool(**t) for t in tools]
+    #     else:
+    #         tools = None
+    #
+    #     return cohere_chat_completions_request(
+    #         # url=llm_config.model_endpoint,
+    #         url="https://api.cohere.ai/v1",  # TODO
+    #         api_key=os.getenv("COHERE_API_KEY"),  # TODO remove
+    #         chat_completion_request=ChatCompletionRequest(
+    #             model="command-r-plus",  # TODO
+    #             messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
+    #             tools=tools,
+    #             tool_choice=function_call,
+    #             # user=str(user_id),
+    #             # NOTE: max_tokens is required for Anthropic API
+    #             # max_tokens=1024,  # TODO make dynamic
+    #         ),
+    #     )

     elif llm_config.model_endpoint_type == "groq":
         if stream:
@@ -295,8 +284,7 @@ def create(
             raise ValueError(f"Groq key is missing from letta config file")
-        # force to true for groq, since they don't support 'content' is non-null
-        inner_thoughts_in_kwargs = True
-        if inner_thoughts_in_kwargs:
+        if llm_config.put_inner_thoughts_in_kwargs:
             functions = add_inner_thoughts_to_functions(
                 functions=functions,
                 inner_thoughts_key=INNER_THOUGHTS_KWARG,
@@ -306,7 +294,7 @@ def create(
         tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None
         data = ChatCompletionRequest(
             model=llm_config.model,
-            messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs) for m in messages],
+            messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages],
             tools=tools,
             tool_choice=function_call,
             user=str(user_id),
@@ -335,7 +323,7 @@ def create(
             if isinstance(stream_interface, AgentChunkStreamingInterface):
                 stream_interface.stream_end()
-        if inner_thoughts_in_kwargs:
+        if llm_config.put_inner_thoughts_in_kwargs:
             response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
         return response

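Note: every provider branch in create() now gates on the same config flag: when llm_config.put_inner_thoughts_in_kwargs is set, the function schemas gain an extra kwarg before the request and the monologue is unpacked from the tool-call arguments afterwards. A self-contained sketch of the schema-side half (illustrative helper; the real implementation is add_inner_thoughts_to_functions in letta.llm_api.helpers, and the constant value "inner_thoughts" is an assumption):

    INNER_THOUGHTS_KWARG = "inner_thoughts"  # assumed value of the letta constant

    def add_inner_thoughts_to_schema(fn_schema: dict, description: str) -> dict:
        # Illustrative: prepend an inner-thoughts string parameter to one function schema.
        params = fn_schema["parameters"]
        params["properties"][INNER_THOUGHTS_KWARG] = {"type": "string", "description": description}
        params["required"] = [INNER_THOUGHTS_KWARG] + params.get("required", [])
        return fn_schema

    schema = {
        "name": "send_message",
        "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]},
    }
    add_inner_thoughts_to_schema(schema, "Deep inner monologue private to you only.")
    assert schema["parameters"]["required"][0] == "inner_thoughts"
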
View File

@@ -105,10 +105,9 @@ def build_openai_chat_completions_request(
     functions: Optional[list],
     function_call: str,
     use_tool_naming: bool,
-    inner_thoughts_in_kwargs: bool,
     max_tokens: Optional[int],
 ) -> ChatCompletionRequest:
-    if inner_thoughts_in_kwargs:
+    if llm_config.put_inner_thoughts_in_kwargs:
         functions = add_inner_thoughts_to_functions(
             functions=functions,
             inner_thoughts_key=INNER_THOUGHTS_KWARG,
@@ -116,7 +115,7 @@ def build_openai_chat_completions_request(
         )
     openai_message_list = [
-        cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs)) for m in messages
+        cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs)) for m in messages
     ]
     if llm_config.model:
         model = llm_config.model

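Note: on the wire, the flag decides where the agent's monologue travels. An illustrative before/after of a serialized assistant message with one tool call (shapes assumed from Message.to_openai_dict, not captured from a real request):

    # put_inner_thoughts_in_kwargs=False -> monologue stays in "content":
    msg_plain = {
        "role": "assistant",
        "content": "User greeted me; I should respond warmly.",
        "tool_calls": [{"type": "function", "function": {
            "name": "send_message",
            "arguments": '{"message": "Hi there!"}',
        }}],
    }

    # put_inner_thoughts_in_kwargs=True -> monologue moves into the JSON arguments:
    msg_packed = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"type": "function", "function": {
            "name": "send_message",
            "arguments": '{"inner_thoughts": "User greeted me; I should respond warmly.", "message": "Hi there!"}',
        }}],
    }
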
View File

@@ -20,7 +20,6 @@ from letta.cli.cli_load import app as load_app
 from letta.config import LettaConfig
 from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
 from letta.metadata import MetadataStore
-from letta.schemas.enums import OptionState
 # from letta.interface import CLIInterface as interface  # for printing to terminal
 from letta.streaming_interface import AgentRefreshStreamingInterface
@@ -64,7 +63,6 @@ def run_agent_loop(
     no_verify: bool = False,
     strip_ui: bool = False,
     stream: bool = False,
-    inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
 ):
     if isinstance(letta_agent.interface, AgentRefreshStreamingInterface):
         # letta_agent.interface.toggle_streaming(on=stream)
@@ -369,7 +367,6 @@ def run_agent_loop(
                     first_message=False,
                     skip_verify=no_verify,
                     stream=stream,
-                    inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
                     ms=ms,
                 )
             else:
@@ -378,7 +375,6 @@ def run_agent_loop(
                     first_message=False,
                     skip_verify=no_verify,
                     stream=stream,
-                    inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs,
                     ms=ms,
                 )
             new_messages = step_response.messages

View File

@ -1,6 +1,6 @@
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, root_validator
class LLMConfig(BaseModel):
@@ -13,6 +13,7 @@ class LLMConfig(BaseModel):
         model_endpoint (str): The endpoint for the model.
         model_wrapper (str): The wrapper for the model. This is used to wrap additional text around the input/output of the model. This is useful for text-to-text completions, such as the Completions API in OpenAI.
         context_window (int): The context window size for the model.
+        put_inner_thoughts_in_kwargs (bool): Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.
     """
     # TODO: 🤮 don't default to a vendor! bug city!
@@ -38,10 +39,32 @@ class LLMConfig(BaseModel):
     model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
     model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
     context_window: int = Field(..., description="The context window size for the model.")
+    put_inner_thoughts_in_kwargs: Optional[bool] = Field(
+        True,
+        description="Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.",
+    )
     # FIXME hack to silence pydantic protected namespace warning
     model_config = ConfigDict(protected_namespaces=())
+    @root_validator(pre=True)
+    def set_default_put_inner_thoughts(cls, values):
+        """
+        Dynamically set the default for put_inner_thoughts_in_kwargs based on the model field,
+        falling back to True if no specific rule is defined.
+        """
+        model = values.get("model")
+        # Define models where we want put_inner_thoughts_in_kwargs to be False
+        # For now it is gpt-4
+        avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
+        # Only modify the value if it's None or not provided
+        if values.get("put_inner_thoughts_in_kwargs") is None:
+            values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
+        return values
     @classmethod
     def default_config(cls, model_name: str):
         if model_name == "gpt-4":

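Note: a quick sketch of the new default logic (values follow the validator above; constructor fields follow the class definition, other required fields may differ):

    from letta.schemas.llm_config import LLMConfig

    gpt4 = LLMConfig(model="gpt-4", model_endpoint_type="openai",
                     model_endpoint="https://api.openai.com/v1", context_window=8192)
    assert gpt4.put_inner_thoughts_in_kwargs is False  # "gpt-4" is in the avoid-list

    # Exact-match semantics: "gpt-4-0613" is not in the list, so it falls back to True,
    # unlike the removed substring heuristic that matched whole model families.
    gpt4_0613 = LLMConfig(model="gpt-4-0613", model_endpoint_type="openai",
                          model_endpoint="https://api.openai.com/v1", context_window=8192)
    assert gpt4_0613.put_inner_thoughts_in_kwargs is True

    # Explicit values are never overridden by the validator.
    forced = LLMConfig(model="gpt-4", model_endpoint_type="openai",
                       model_endpoint="https://api.openai.com/v1", context_window=8192,
                       put_inner_thoughts_in_kwargs=True)
    assert forced.put_inner_thoughts_in_kwargs is True
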
View File

@@ -2,5 +2,6 @@
   "context_window": 128000,
   "model": "gpt-4o-mini",
   "model_endpoint_type": "azure",
-  "model_wrapper": null
+  "model_wrapper": null,
+  "put_inner_thoughts_in_kwargs": true
 }

View File

@@ -3,5 +3,6 @@
   "model": "claude-3-opus-20240229",
   "model_endpoint_type": "anthropic",
   "model_endpoint": "https://api.anthropic.com/v1",
-  "model_wrapper": null
+  "model_wrapper": null,
+  "put_inner_thoughts_in_kwargs": true
 }

View File

@@ -3,5 +3,6 @@
   "model": "gemini-1.5-pro-latest",
   "model_endpoint_type": "google_ai",
   "model_endpoint": "https://generativelanguage.googleapis.com",
-  "model_wrapper": null
+  "model_wrapper": null,
+  "put_inner_thoughts_in_kwargs": true
 }

View File

@@ -3,5 +3,6 @@
   "model": "gpt-4",
   "model_endpoint_type": "openai",
   "model_endpoint": "https://api.openai.com/v1",
-  "model_wrapper": null
+  "model_wrapper": null,
+  "put_inner_thoughts_in_kwargs": false
 }

View File

@@ -3,5 +3,6 @@
   "model": "llama-3.1-70b-versatile",
   "model_endpoint_type": "groq",
   "model_endpoint": "https://api.groq.com/openai/v1",
-  "model_wrapper": null
+  "model_wrapper": null,
+  "put_inner_thoughts_in_kwargs": true
 }

View File

@@ -2,5 +2,6 @@
   "context_window": 16384,
   "model_endpoint_type": "openai",
   "model_endpoint": "https://inference.memgpt.ai",
-  "model": "memgpt-openai"
+  "model": "memgpt-openai",
+  "put_inner_thoughts_in_kwargs": true
 }

View File

@@ -2,5 +2,6 @@
   "context_window": 8192,
   "model_endpoint_type": "ollama",
   "model_endpoint": "http://127.0.0.1:11434",
-  "model": "dolphin2.2-mistral:7b-q6_K"
+  "model": "dolphin2.2-mistral:7b-q6_K",
+  "put_inner_thoughts_in_kwargs": true
 }

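Note: the bundled model configs now pin the flag explicitly (only the gpt-4 config sets it to false), so the validator's model-based fallback never applies when these files are loaded; this also covers the hardcoded force-to-true for Groq that was removed from create(). A loading sketch (the file path and plain-JSON loading are illustrative):

    import json
    from letta.schemas.llm_config import LLMConfig

    with open("configs/llm_model_configs/gpt-4.json") as f:  # hypothetical path
        cfg = LLMConfig(**json.load(f))
    assert cfg.put_inner_thoughts_in_kwargs is False  # explicit value from the file
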
View File

@@ -3,11 +3,7 @@ import logging
 import uuid
 from typing import Callable, List, Optional, Union
-from letta.llm_api.helpers import (
-    derive_inner_thoughts_in_kwargs,
-    unpack_inner_thoughts_from_kwargs,
-)
-from letta.schemas.enums import OptionState
+from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
@@ -130,10 +126,8 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatCompletionResponse:
     validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search"
     assert_contains_valid_function_call(choice.message, validator_func)
-    # Get inner_thoughts_in_kwargs
-    inner_thoughts_in_kwargs = derive_inner_thoughts_in_kwargs(OptionState.DEFAULT, agent_state.llm_config.model)
     # Assert that the message has an inner monologue
-    assert_contains_correct_inner_monologue(choice, inner_thoughts_in_kwargs)
+    assert_contains_correct_inner_monologue(choice, agent_state.llm_config.put_inner_thoughts_in_kwargs)
     return response