feat: Rework summarizer (#654)

Matthew Zhou 2025-01-22 09:19:26 -10:00 committed by GitHub
parent 434b94f088
commit cd75ebe51c
11 changed files with 289 additions and 165 deletions

View File

@@ -13,9 +13,6 @@ from letta.constants import (
     LETTA_CORE_TOOL_MODULE_NAME,
     LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
     LLM_MAX_TOKENS,
-    MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
-    MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
-    MESSAGE_SUMMARY_WARNING_FRAC,
     REQ_HEARTBEAT_MESSAGE,
 )
 from letta.errors import ContextWindowExceededError
@@ -23,7 +20,7 @@ from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_fun
 from letta.functions.functions import get_function_from_module
 from letta.helpers import ToolRulesSolver
 from letta.interface import AgentInterface
-from letta.llm_api.helpers import is_context_overflow_error
+from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
 from letta.log import get_logger
@@ -52,6 +49,7 @@ from letta.services.passage_manager import PassageManager
 from letta.services.provider_manager import ProviderManager
 from letta.services.step_manager import StepManager
 from letta.services.tool_execution_sandbox import ToolExecutionSandbox
+from letta.settings import summarizer_settings
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
 from letta.utils import (
@@ -66,6 +64,8 @@ from letta.utils import (
     validate_function_response,
 )

+logger = get_logger(__name__)
+

 class BaseAgent(ABC):
     """
@@ -635,7 +635,7 @@ class Agent(BaseAgent):
                     self.logger.info(f"Hit max chaining steps, stopping after {counter} steps")
                     break
                 # Chain handlers
-                elif token_warning:
+                elif token_warning and summarizer_settings.send_memory_warning_message:
                     assert self.agent_state.created_by_id is not None
                     next_input_message = Message.dict_to_message(
                         agent_id=self.agent_state.id,
@@ -686,6 +686,7 @@
         stream: bool = False,  # TODO move to config?
         step_count: Optional[int] = None,
         metadata: Optional[dict] = None,
+        summarize_attempt_count: int = 0,
     ) -> AgentStepResponse:
         """Runs a single step in the agent loop (generates at most one LLM call)"""
@@ -753,9 +754,9 @@
                     LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
                 )

-            if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window):
-                self.logger.warning(
-                    f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
+            if current_total_tokens > summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window):
+                printd(
+                    f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
                 )

                 # Only deliver the alert if we haven't already (this period)
@@ -764,8 +765,8 @@
                     self.agent_alerted_about_memory_pressure = True  # it's up to the outer loop to handle this

             else:
-                self.logger.warning(
-                    f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
+                printd(
+                    f"last response total_tokens ({current_total_tokens}) < {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}"
                 )

             # Log step - this must happen before messages are persisted
@@ -807,28 +808,46 @@
             )
         except Exception as e:
-            self.logger.error(f"step() failed\nmessages = {messages}\nerror = {e}")
+            logger.error(f"step() failed\nmessages = {messages}\nerror = {e}")

             # If we got a context alert, try trimming the messages length, then try again
             if is_context_overflow_error(e):
-                self.logger.warning(
-                    f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages"
-                )
-                # A separate API call to run a summarizer
-                self.summarize_messages_inplace()
-
-                # Try step again
-                return self.inner_step(
-                    messages=messages,
-                    first_message=first_message,
-                    first_message_retry_limit=first_message_retry_limit,
-                    skip_verify=skip_verify,
-                    stream=stream,
-                    metadata=metadata,
-                )
+                in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
+
+                if summarize_attempt_count <= summarizer_settings.max_summarizer_retries:
+                    logger.warning(
+                        f"context window exceeded with limit {self.agent_state.llm_config.context_window}, attempting to summarize ({summarize_attempt_count}/{summarizer_settings.max_summarizer_retries})"
+                    )
+                    # A separate API call to run a summarizer
+                    self.summarize_messages_inplace()
+
+                    # Try step again
+                    return self.inner_step(
+                        messages=messages,
+                        first_message=first_message,
+                        first_message_retry_limit=first_message_retry_limit,
+                        skip_verify=skip_verify,
+                        stream=stream,
+                        metadata=metadata,
+                        summarize_attempt_count=summarize_attempt_count + 1,
+                    )
+                else:
+                    err_msg = f"Ran summarizer {summarize_attempt_count - 1} times for agent id={self.agent_state.id}, but messages are still overflowing the context window."
+                    token_counts = get_token_counts_for_messages(in_context_messages)
+                    logger.error(err_msg)
+                    logger.error(f"num_in_context_messages: {len(self.agent_state.message_ids)}")
+                    logger.error(f"token_counts: {token_counts}")
+                    raise ContextWindowExceededError(
+                        err_msg,
+                        details={
+                            "num_in_context_messages": len(self.agent_state.message_ids),
+                            "in_context_messages_text": [m.text for m in in_context_messages],
+                            "token_counts": token_counts,
+                        },
+                    )
             else:
-                self.logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
+                logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
                 raise e

     def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
@@ -865,109 +884,54 @@ class Agent(BaseAgent):
         return self.inner_step(messages=[user_message], **kwargs)

-    def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
+    def summarize_messages_inplace(self):
         in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
         in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
+        in_context_messages_openai_no_system = in_context_messages_openai[1:]
+        token_counts = get_token_counts_for_messages(in_context_messages)
+
+        logger.info(f"System message token count={token_counts[0]}")
+        logger.info(f"token_counts_no_system={token_counts[1:]}")

         if in_context_messages_openai[0]["role"] != "system":
             raise RuntimeError(f"in_context_messages_openai[0] should be system (instead got {in_context_messages_openai[0]})")

-        # Start at index 1 (past the system message),
-        # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
-        # Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
-        token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
-        message_buffer_token_count = sum(token_counts[1:])  # no system message
-        desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
-
-        candidate_messages_to_summarize = in_context_messages_openai[1:]
-        token_counts = token_counts[1:]
-
-        if preserve_last_N_messages:
-            candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
-            token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
-
-        printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
-        printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
-        printd(f"token_counts={token_counts}")
-        printd(f"message_buffer_token_count={message_buffer_token_count}")
-        printd(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
-        printd(f"len(candidate_messages_to_summarize)={len(candidate_messages_to_summarize)}")
-
         # If at this point there's nothing to summarize, throw an error
-        if len(candidate_messages_to_summarize) == 0:
+        if len(in_context_messages_openai_no_system) == 0:
             raise ContextWindowExceededError(
                 "Not enough messages to compress for summarization",
                 details={
-                    "num_candidate_messages": len(candidate_messages_to_summarize),
+                    "num_candidate_messages": len(in_context_messages_openai_no_system),
                     "num_total_messages": len(in_context_messages_openai),
-                    "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
                 },
             )

-        # Walk down the message buffer (front-to-back) until we hit the target token count
-        tokens_so_far = 0
-        cutoff = 0
-        for i, msg in enumerate(candidate_messages_to_summarize):
-            cutoff = i
-            tokens_so_far += token_counts[i]
-            if tokens_so_far > desired_token_count_to_summarize:
-                break
-        # Account for system message
-        cutoff += 1
-
-        # Try to make an assistant message come after the cutoff
-        try:
-            if in_context_messages_openai[cutoff]["role"] == "user":
-                printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
-                new_cutoff = cutoff + 1
-                if in_context_messages_openai[new_cutoff]["role"] == "user":
-                    printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
-                cutoff = new_cutoff
-        except IndexError:
-            pass
-
-        # Make sure the cutoff isn't on a 'tool' or 'function'
-        if disallow_tool_as_first:
-            while in_context_messages_openai[cutoff]["role"] in ["tool", "function"] and cutoff < len(in_context_messages_openai):
-                printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...")
-                cutoff += 1
+        cutoff = calculate_summarizer_cutoff(in_context_messages=in_context_messages, token_counts=token_counts, logger=logger)

         message_sequence_to_summarize = in_context_messages[1:cutoff]  # do NOT get rid of the system message
-        if len(message_sequence_to_summarize) <= 1:
-            # This prevents a potential infinite loop of summarizing the same message over and over
-            raise ContextWindowExceededError(
-                "Not enough messages to compress for summarization after determining cutoff",
-                details={
-                    "num_candidate_messages": len(message_sequence_to_summarize),
-                    "num_total_messages": len(in_context_messages_openai),
-                    "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
-                },
-            )
-        else:
-            printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(in_context_messages)}")
+        logger.info(f"Attempting to summarize {len(message_sequence_to_summarize)} messages of {len(in_context_messages)}")

         # We can't do summarize logic properly if context_window is undefined
         if self.agent_state.llm_config.context_window is None:
             # Fallback if for some reason context_window is missing, just set to the default
-            print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}")
-            print(f"{self.agent_state}")
+            logger.warning(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}")
             self.agent_state.llm_config.context_window = (
                 LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
             )

         summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize)
-        printd(f"Got summary: {summary}")
+        logger.info(f"Got summary: {summary}")

         # Metadata that's useful for the agent to see
         all_time_message_count = self.message_manager.size(agent_id=self.agent_state.id, actor=self.user)
-        remaining_message_count = len(in_context_messages_openai[cutoff:])
+        remaining_message_count = 1 + len(in_context_messages) - cutoff  # System + remaining
         hidden_message_count = all_time_message_count - remaining_message_count
         summary_message_count = len(message_sequence_to_summarize)
         summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
-        printd(f"Packaged into message: {summary_message}")
+        logger.info(f"Packaged into message: {summary_message}")

         prior_len = len(in_context_messages_openai)
-        self.agent_state = self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user)
+        self.agent_state = self.agent_manager.trim_all_in_context_messages_except_system(agent_id=self.agent_state.id, actor=self.user)
         packed_summary_message = {"role": "user", "content": summary_message}
+
+        # Prepend the summary
         self.agent_state = self.agent_manager.prepend_to_in_context_messages(
             messages=[
                 Message.dict_to_message(
@@ -983,8 +947,12 @@ class Agent(BaseAgent):
         # reset alert
         self.agent_alerted_about_memory_pressure = False

-        printd(f"Ran summarizer, messages length {prior_len} -> {len(in_context_messages_openai)}")
+        curr_in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
+        logger.info(f"Ran summarizer, messages length {prior_len} -> {len(curr_in_context_messages)}")
+        logger.info(
+            f"Summarizer brought down total token count from {sum(token_counts)} -> {sum(get_token_counts_for_messages(curr_in_context_messages))}"
+        )

     def add_function(self, function_name: str) -> str:
         # TODO: refactor
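As a reader's aid, the retry shape introduced above can be reduced to a few lines. This is a hypothetical, simplified sketch, not Letta's API surface: `ContextOverflow`, `llm_call`, and `step_with_retries` are invented stand-ins for `ContextWindowExceededError`, the LLM request, and `inner_step` respectively.

class ContextOverflow(Exception):
    """Stand-in for letta.errors.ContextWindowExceededError."""

def step_with_retries(agent, messages, summarize_attempt_count=0, max_retries=3):
    # Mirrors the reworked inner_step(): one summarize pass per overflow,
    # bounded by max_retries (cf. summarizer_settings.max_summarizer_retries).
    try:
        return agent.llm_call(messages)  # stand-in for the real LLM request
    except ContextOverflow:
        if summarize_attempt_count <= max_retries:
            agent.summarize_messages_inplace()  # compress the in-context history
            return step_with_retries(agent, messages, summarize_attempt_count + 1, max_retries)
        raise  # the summarizer could not shrink the context; fail loudly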

View File

@@ -125,8 +125,6 @@ LLM_MAX_TOKENS = {
     "gpt-3.5-turbo-16k-0613": 16385,  # legacy
     "gpt-3.5-turbo-0301": 4096,  # legacy
 }

-# The amount of tokens before a sytem warning about upcoming truncation is sent to Letta
-MESSAGE_SUMMARY_WARNING_FRAC = 0.75

 # The error message that Letta will receive
 # MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
 # Much longer and more specific variant of the prompt
@@ -138,15 +136,10 @@ MESSAGE_SUMMARY_WARNING_STR = " ".join(
         # "Remember to pass request_heartbeat = true if you would like to send a message immediately after.",
     ]
 )

-# The fraction of tokens we truncate down to
-MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75

 # The ackknowledgement message used in the summarize sequence
 MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the message (and only the summary, nothing else) once I receive the conversation history. I'm ready."

-# Even when summarizing, we want to keep a handful of recent messages
-# These serve as in-context examples of how to use functions / what user messages look like
-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3

 # Maximum length of an error message
 MAX_ERROR_MESSAGE_CHAR_LIMIT = 500

View File

@@ -7,8 +7,10 @@ from typing import Any, List, Union
 import requests

 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
+from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
-from letta.utils import json_dumps, printd
+from letta.settings import summarizer_settings
+from letta.utils import count_tokens, json_dumps, printd


 def _convert_to_structured_output_helper(property: dict) -> dict:
@@ -287,6 +289,54 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -
     return rewritten_choice


+def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts: List[int], logger: "logging.Logger") -> int:
+    if len(in_context_messages) != len(token_counts):
+        raise ValueError(
+            f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}"
+        )
+    in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
+
+    if summarizer_settings.evict_all_messages:
+        logger.info("Evicting all messages...")
+        return len(in_context_messages)
+    else:
+        # Start at index 1 (past the system message),
+        # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
+        # We do the inverse of `desired_memory_token_pressure` to get what we need to remove
+        desired_token_count_to_summarize = int(sum(token_counts) * (1 - summarizer_settings.desired_memory_token_pressure))
+        logger.info(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
+
+        tokens_so_far = 0
+        cutoff = 0
+        for i, msg in enumerate(in_context_messages_openai):
+            # Skip system
+            if i == 0:
+                continue
+
+            cutoff = i
+            tokens_so_far += token_counts[i]
+            if msg["role"] not in ["user", "tool", "function"] and tokens_so_far >= desired_token_count_to_summarize:
+                # Break if the role is NOT a user or tool/function and tokens_so_far is enough
+                break
+            elif len(in_context_messages) - cutoff - 1 <= summarizer_settings.keep_last_n_messages:
+                # Also break if we reached the `keep_last_n_messages` threshold
+                # NOTE: This may be on a user, tool, or function in theory
+                logger.warning(
+                    f"Breaking summary cutoff early on role={msg['role']} because we hit the `keep_last_n_messages`={summarizer_settings.keep_last_n_messages}"
+                )
+                break
+
+        logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...")
+        return cutoff + 1
+
+
+def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]:
+    in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
+    token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
+    return token_counts
+
+
 def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
     """Checks if an exception is due to context overflow (based on common OpenAI response messages)"""
     from letta.utils import printd
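To make the cutoff arithmetic concrete, here is a hypothetical walk-through of calculate_summarizer_cutoff with the default desired_memory_token_pressure of 0.3 and keep_last_n_messages of 0 (roles and token counts invented):

# Hypothetical in-context window: system message + 4 conversation messages.
roles = ["system", "user", "assistant", "user", "assistant"]
token_counts = [40, 20, 80, 20, 20]  # sum = 180
# Inverse of the desired pressure: summarize away (1 - 0.3) = 70% of tokens.
desired_token_count_to_summarize = int(sum(token_counts) * (1 - 0.3))  # = 126
# Skipping index 0, the running total goes 20 -> 100 -> 120 -> 140.
# The loop may only stop on a non-user/non-tool message once the target is met,
# so it breaks at index 4 (assistant, 140 >= 126) and returns cutoff + 1 = 5:
# messages[1:5] are summarized and evicted, leaving just the system message.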

View File

@@ -1,12 +1,13 @@
 from typing import Callable, Dict, List

-from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
+from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
 from letta.llm_api.llm_api_tools import create
 from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
 from letta.schemas.agent import AgentState
 from letta.schemas.enums import MessageRole
 from letta.schemas.memory import Memory
 from letta.schemas.message import Message
+from letta.settings import summarizer_settings
 from letta.utils import count_tokens, printd
@@ -49,8 +50,8 @@ def summarize_messages(
     summary_prompt = SUMMARY_PROMPT_SYSTEM
     summary_input = _format_summary_history(message_sequence_to_summarize)
     summary_input_tkns = count_tokens(summary_input)
-    if summary_input_tkns > MESSAGE_SUMMARY_WARNING_FRAC * context_window:
-        trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8  # For good measure...
+    if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window:
+        trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8  # For good measure...
         cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
         summary_input = str(
             [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])]
@@ -58,10 +59,11 @@ def summarize_messages(
         )

     dummy_agent_id = agent_state.id
-    message_sequence = []
-    message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
-    message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK))
-    message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
+    message_sequence = [
+        Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt),
+        Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK),
+        Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input),
+    ]

     # TODO: We need to eventually have a separate LLM config for the summarizer LLM
     llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
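The truncation guard in summarize_messages protects the summarizer's own request: if the transcript exceeds memory_warning_threshold of the context window, a prefix of it is summarized recursively first. A rough sketch of the arithmetic, with invented numbers:

# Hypothetical: context_window = 8000, memory_warning_threshold = 0.75.
budget = 0.75 * 8000                               # 6000 tokens allowed as summary input
summary_input_tkns = 9000                          # transcript is over budget
trunc_ratio = (budget / summary_input_tkns) * 0.8  # ~0.53, with a 20% safety margin
cutoff = int(100 * trunc_ratio)                    # for a 100-message transcript: 53
# messages[:53] are summarized in a recursive call; that summary plus
# messages[53:] become the input to the outer summarization request.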

View File

@@ -26,7 +26,7 @@ class EnvironmentVariableUpdateBase(LettaBase):
     description: Optional[str] = Field(None, description="An optional description of the environment variable.")

-# Sandbox-Specific Environment Variable
+# Environment Variable
 class SandboxEnvironmentVariableBase(EnvironmentVariableBase):
     __id_prefix__ = "sandbox-env"
     sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")

View File

@@ -464,6 +464,12 @@ class AgentManager:
         new_messages = [message_ids[0]] + message_ids[num:]  # 0 is system message
         return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)

+    @enforce_types
+    def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
+        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
+        new_messages = [message_ids[0]]  # 0 is system message
+        return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
+
     @enforce_types
     def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
         message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
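This new helper is the eviction primitive behind the reworked summarize_messages_inplace: it resets the window to the system message alone, and the packaged summary is then prepended back in. A simplified sketch of the call sequence (variable names abbreviated from the agent code above):

# Simplified call sequence inside summarize_messages_inplace():
agent_state = agent_manager.trim_all_in_context_messages_except_system(
    agent_id=agent_state.id, actor=user
)
# message_ids is now [system_message_id]; the summary is then inserted via
# prepend_to_in_context_messages(), landing at index 1, right after system.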

View File

@@ -18,6 +18,34 @@ class ToolSettings(BaseSettings):
     local_sandbox_dir: Optional[str] = None

+
+class SummarizerSettings(BaseSettings):
+    model_config = SettingsConfigDict(env_prefix="letta_summarizer_", extra="ignore")
+
+    # Controls if we should evict all messages
+    # TODO: Can refactor this into an enum if we have a bunch of different kinds of summarizers
+    evict_all_messages: bool = False
+
+    # The maximum number of retries for the summarizer
+    # If we reach this cutoff, it probably means that the summarizer is not compressing down the in-context messages any further
+    # And we throw a fatal error
+    max_summarizer_retries: int = 3
+
+    # When to warn the model that a summarize command will happen soon
+    # The amount of tokens before a system warning about upcoming truncation is sent to Letta
+    memory_warning_threshold: float = 0.75
+
+    # Whether to send the system memory warning message
+    send_memory_warning_message: bool = False
+
+    # The desired memory pressure to summarize down to
+    desired_memory_token_pressure: float = 0.3
+
+    # The number of messages at the end to keep
+    # Even when summarizing, we may want to keep a handful of recent messages
+    # These serve as in-context examples of how to use functions / what user messages look like
+    keep_last_n_messages: int = 0
+

 class ModelSettings(BaseSettings):
     model_config = SettingsConfigDict(env_file=".env", extra="ignore")
@@ -147,3 +175,4 @@ settings = Settings(_env_parse_none_str="None")
 test_settings = TestSettings()
 model_settings = ModelSettings()
 tool_settings = ToolSettings()
+summarizer_settings = SummarizerSettings()
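Because SummarizerSettings is a pydantic BaseSettings subclass with env_prefix="letta_summarizer_", each field above can be overridden from the environment (matching is case-insensitive). For example, a deployment might tune the summarizer like this; the values are illustrative, not recommendations:

# Illustrative environment overrides (shell):
#   export LETTA_SUMMARIZER_MAX_SUMMARIZER_RETRIES=5
#   export LETTA_SUMMARIZER_DESIRED_MEMORY_TOKEN_PRESSURE=0.5
#   export LETTA_SUMMARIZER_KEEP_LAST_N_MESSAGES=3
from letta.settings import summarizer_settings

print(summarizer_settings.desired_memory_token_pressure)  # 0.5 with the override above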

View File

@@ -161,10 +161,10 @@ def package_system_message(system_message, message_type="system_alert", time=Non
     return json.dumps(packaged_message)

-def package_summarize_message(summary, summary_length, hidden_message_count, total_message_count, timestamp=None):
+def package_summarize_message(summary, summary_message_count, hidden_message_count, total_message_count, timestamp=None):
     context_message = (
         f"Note: prior messages ({hidden_message_count} of {total_message_count} total messages) have been hidden from view due to conversation memory constraints.\n"
-        + f"The following is a summary of the previous {summary_length} messages:\n {summary}"
+        + f"The following is a summary of the previous {summary_message_count} messages:\n {summary}"
     )

     formatted_time = get_local_time() if timestamp is None else timestamp
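The rename from summary_length to summary_message_count only clarifies what the parameter counts; the packaged text is unchanged. With invented counts, the note delivered to the agent reads roughly as follows:

# package_summarize_message(summary="...", summary_message_count=40,
#                           hidden_message_count=58, total_message_count=63)
# produces a message beginning:
#   Note: prior messages (58 of 63 total messages) have been hidden from view
#   due to conversation memory constraints.
#   The following is a summary of the previous 40 messages: ...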

poetry.lock generated
View File

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -416,10 +416,6 @@ files = [
     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"},
     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"},
     {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"},
-    {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"},
-    {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"},
-    {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"},
-    {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"},
     {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"},
     {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"},
     {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"},
@@ -432,14 +428,8 @@ files = [
     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"},
     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"},
     {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"},
-    {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"},
-    {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"},
-    {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"},
-    {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"},
     {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"},
     {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"},
-    {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"},
-    {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"},
     {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"},
     {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"},
     {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"},
@@ -450,24 +440,8 @@ files = [
     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"},
     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"},
     {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"},
-    {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"},
-    {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"},
-    {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"},
-    {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"},
     {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"},
     {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"},
-    {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"},
-    {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"},
-    {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"},
-    {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"},
-    {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"},
-    {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"},
-    {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"},
-    {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"},
-    {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"},
-    {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"},
-    {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"},
-    {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"},
     {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"},
     {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"},
     {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"},
@@ -477,10 +451,6 @@ files = [
     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"},
     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"},
     {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"},
-    {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"},
-    {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"},
-    {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"},
-    {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"},
     {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"},
     {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"},
     {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"},
@@ -492,10 +462,6 @@ files = [
     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"},
     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"},
     {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"},
-    {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"},
-    {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"},
-    {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"},
-    {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"},
     {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"},
     {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"},
     {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"},
@@ -508,10 +474,6 @@ files = [
     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"},
     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"},
     {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"},
-    {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"},
-    {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"},
-    {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"},
-    {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"},
     {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"},
     {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"},
     {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"},
@@ -524,10 +486,6 @@ files = [
     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"},
     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"},
     {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"},
-    {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"},
-    {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"},
-    {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"},
-    {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"},
     {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"},
     {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"},
     {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"},
@@ -2021,7 +1979,7 @@ files = [
 name = "iniconfig"
 version = "2.0.0"
 description = "brain-dead simple config-ini parsing"
-optional = true
+optional = false
 python-versions = ">=3.7"
 files = [
     {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
@@ -3737,7 +3695,7 @@ type = ["mypy (>=1.11.2)"]
 name = "pluggy"
 version = "1.5.0"
 description = "plugin and hook calling mechanisms for python"
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
@@ -4417,7 +4375,7 @@ websocket-client = "!=0.49"
 name = "pytest"
 version = "8.3.4"
 description = "pytest: simple powerful testing with Python"
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
@@ -4453,6 +4411,23 @@ pytest = ">=7.0.0,<9"
 docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
 testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]

+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "pytest-order"
 version = "1.3.0"
@@ -6318,4 +6293,4 @@ tests = ["wikipedia"]
 [metadata]
 lock-version = "2.0"
 python-versions = "<3.14,>=3.10"
-content-hash = "f79e70bc03fff20fcd97a1be2c7421d94458df8ffd92096c487b9dbb81f23164"
+content-hash = "2f552617ff233fe8b07bdec4dc1679935df30030046984962b69ebe625717815"

View File

@@ -94,6 +94,7 @@ bedrock = ["boto3"]
 black = "^24.4.2"
 ipykernel = "^6.29.5"
 ipdb = "^0.13.13"
+pytest-mock = "^3.14.0"

 [tool.black]
 line-length = 140

View File

@@ -1,6 +1,7 @@
 import json
 import os
 import uuid
+from datetime import datetime
 from typing import List

 import pytest
@@ -8,9 +9,13 @@
 from letta import create_client
 from letta.agent import Agent
 from letta.client.client import LocalClient
+from letta.errors import ContextWindowExceededError
+from letta.llm_api.helpers import calculate_summarizer_cutoff
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.enums import MessageRole
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
+from letta.settings import summarizer_settings
 from letta.streaming_interface import StreamingRefreshCLIInterface
 from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
 from tests.helpers.utils import cleanup
@@ -44,6 +49,101 @@ def agent_state(client):
     client.delete_agent(agent_state.id)
+# Sample data setup
+def generate_message(role: str, text: str = None, tool_calls: List = None) -> Message:
+    """Helper to generate a Message object."""
+    return Message(
+        id="message-" + str(uuid.uuid4()),
+        role=MessageRole(role),
+        text=text or f"{role} message text",
+        created_at=datetime.utcnow(),
+        tool_calls=tool_calls or [],
+    )
+
+
+def test_cutoff_calculation(mocker):
+    """Test basic scenarios where the function calculates the cutoff correctly."""
+    # Arrange
+    logger = mocker.Mock()  # Mock logger
+    messages = [
+        generate_message("system"),
+        generate_message("user"),
+        generate_message("assistant"),
+        generate_message("user"),
+        generate_message("assistant"),
+    ]
+    mocker.patch("letta.settings.summarizer_settings.desired_memory_token_pressure", 0.5)
+    mocker.patch("letta.settings.summarizer_settings.evict_all_messages", False)
+
+    # Basic tests
+    token_counts = [4, 2, 8, 2, 2]
+    cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
+    assert cutoff == 3
+    assert messages[cutoff - 1].role == MessageRole.assistant
+
+    token_counts = [4, 2, 2, 2, 2]
+    cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
+    assert cutoff == 5
+    assert messages[cutoff - 1].role == MessageRole.assistant
+
+    token_counts = [2, 2, 3, 2, 2]
+    cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
+    assert cutoff == 3
+    assert messages[cutoff - 1].role == MessageRole.assistant
+
+    # Evict all messages
+    # Should give the end of the token_counts, even though it is not necessary (we could evict just up to the 100)
+    mocker.patch("letta.settings.summarizer_settings.evict_all_messages", True)
+    token_counts = [1, 1, 100, 1, 1]
+    cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
+    assert cutoff == 5
+    assert messages[cutoff - 1].role == MessageRole.assistant
+
+    # Don't evict all messages with the same token_counts; the cutoff now should be at the 100
+    mocker.patch("letta.settings.summarizer_settings.evict_all_messages", False)
+    cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
+    assert cutoff == 3
+    assert messages[cutoff - 1].role == MessageRole.assistant
+
+    # Set `keep_last_n_messages`
+    mocker.patch("letta.settings.summarizer_settings.keep_last_n_messages", 3)
+    token_counts = [4, 2, 2, 2, 2]
+    cutoff = calculate_summarizer_cutoff(messages, token_counts, logger)
+    assert cutoff == 2
+    assert messages[cutoff - 1].role == MessageRole.user
+
+
+def test_summarize_many_messages_basic(client, mock_e2b_api_key_none):
+    small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
+    small_context_llm_config.context_window = 3000
+    small_agent_state = client.create_agent(
+        name="small_context_agent",
+        llm_config=small_context_llm_config,
+    )
+    for _ in range(10):
+        client.user_message(
+            agent_id=small_agent_state.id,
+            message="hi " * 60,
+        )
+    client.delete_agent(small_agent_state.id)
+
+
+def test_summarize_large_message_does_not_loop_infinitely(client, mock_e2b_api_key_none):
+    small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
+    small_context_llm_config.context_window = 2000
+    small_agent_state = client.create_agent(
+        name="super_small_context_agent",
+        llm_config=small_context_llm_config,
+    )
+    with pytest.raises(ContextWindowExceededError, match=f"Ran summarizer {summarizer_settings.max_summarizer_retries}"):
+        client.user_message(
+            agent_id=small_agent_state.id,
+            message="hi " * 1000,
+        )
+    client.delete_agent(small_agent_state.id)
+
+
 def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none):
     """Test summarization via sending the summarize CLI command or via a direct call to the agent object"""

     # First send a few messages (5)
@@ -134,7 +234,7 @@ def test_auto_summarize(client, mock_e2b_api_key_none):
         # "gemini-pro.json",  # TODO: Gemini is broken
     ],
 )
-def test_summarizer(config_filename):
+def test_summarizer(config_filename, client, agent_state):
     namespace = uuid.NAMESPACE_DNS
     agent_name = str(uuid.uuid5(namespace, f"integration-test-summarizer-{config_filename}"))
@@ -175,6 +275,6 @@ def test_summarizer(config_filename):
     )

     # Invoke a summarize
-    letta_agent.summarize_messages_inplace(preserve_last_N_messages=False)
+    letta_agent.summarize_messages_inplace()

     in_context_messages = client.get_in_context_messages(agent_state.id)
     assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}"