From b756f31b940cc2675fea6950a7799d161193dd9d Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 30 Apr 2025 18:08:47 -0700 Subject: [PATCH] feat: Add safety checks to buffer length update and tests to updating voice sleeptime (#1959) --- letta/constants.py | 4 ++ letta/server/server.py | 4 +- letta/services/group_manager.py | 30 +++++++++--- tests/integration_test_voice_agent.py | 66 ++++++++++++++++++++++++++- 4 files changed, 94 insertions(+), 10 deletions(-) diff --git a/letta/constants.py b/letta/constants.py index 28cb17e2e..448277f84 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29 # minimum context window size MIN_CONTEXT_WINDOW = 4096 +# Voice Sleeptime message buffer lengths +DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30 +DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15 + # embeddings MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset DEFAULT_EMBEDDING_CHUNK_SIZE = 300 diff --git a/letta/server/server.py b/letta/server/server.py index 2fa950344..b46ff1ea0 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -862,8 +862,8 @@ class SyncServer(Server): agent_ids=[voice_sleeptime_agent.id], manager_config=VoiceSleeptimeManager( manager_agent_id=main_agent.id, - max_message_buffer_length=30, - min_message_buffer_length=15, + max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, + min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, ), ), actor=actor, diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 3e1ee0234..8bae455f5 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -125,10 +125,10 @@ class GroupManager: sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency if sleeptime_agent_frequency and group.turns_counter is None: group.turns_counter = -1 - case ManagerType.sleeptime: + case ManagerType.voice_sleeptime: manager_agent_id = group_update.manager_config.manager_agent_id - max_message_buffer_length = group_update.manager_config.max_message_buffer_length - min_message_buffer_length = group_update.manager_config.min_message_buffer_length + max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length + min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length if sleeptime_agent_frequency and group.turns_counter is None: group.turns_counter = -1 case _: @@ -304,8 +304,9 @@ class GroupManager: min_name: str = "min_message_buffer_length", ) -> None: """ - 1) If one of max_value/min_value is set, the other must also be set. - 2) If both are set, max_value must be greater than min_value. + 1) Both-or-none: if one is set, the other must be set. + 2) Both must be ints > 4. + 3) max_value must be strictly greater than min_value. """ # 1) require both-or-none if (max_value is None) != (min_value is None): @@ -313,6 +314,21 @@ class GroupManager: f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})" ) - # 2) valid range - if max_value is not None and min_value is not None and max_value <= min_value: + # no further checks if neither is provided + if max_value is None: + return + + # 2) type & lower‐bound checks + if not isinstance(max_value, int) or not isinstance(min_value, int): + raise ValueError( + f"Both '{max_name}' and '{min_name}' must be integers " + f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})" + ) + if max_value <= 4 or min_value <= 4: + raise ValueError( + f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})" + ) + + # 3) ordering + if max_value <= min_value: raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})") diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index f4533395a..bc6c09dbe 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -10,12 +10,13 @@ from openai.types.chat import ChatCompletionChunk from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent from letta.config import LettaConfig +from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentType, CreateAgent from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole, MessageStreamStatus -from letta.schemas.group import ManagerType +from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig @@ -497,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor): server.group_manager.retrieve_group(group_id=group.id, actor=actor) with pytest.raises(NoResultFound): server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) + + +def _modify(group_id, server, actor, max_val, min_val): + """Helper to invoke modify_group with voice_sleeptime config.""" + return server.group_manager.modify_group( + group_id=group_id, + group_update=GroupUpdate( + manager_config=VoiceSleeptimeManagerUpdate( + manager_type=ManagerType.voice_sleeptime, + max_message_buffer_length=max_val, + min_message_buffer_length=min_val, + ) + ), + actor=actor, + ) + + +@pytest.fixture +def group_id(voice_agent): + return voice_agent.multi_agent_group.id + + +def test_valid_buffer_lengths_above_four(group_id, server, actor): + # both > 4 and max > min + updated = _modify(group_id, server, actor, max_val=10, min_val=5) + assert updated.max_message_buffer_length == 10 + assert updated.min_message_buffer_length == 5 + + +def test_valid_buffer_lengths_only_max(group_id, server, actor): + # both > 4 and max > min + updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None) + assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1 + assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + + +def test_valid_buffer_lengths_only_min(group_id, server, actor): + # both > 4 and max > min + updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1) + assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1 + + +@pytest.mark.parametrize( + "max_val,min_val,err_part", + [ + # only one set → both-or-none + (None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"), + (DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"), + # ordering violations + (5, 5, "must be greater than"), + (6, 7, "must be greater than"), + # lower-bound (must both be > 4) + (4, 5, "greater than 4"), + (5, 4, "greater than 4"), + (1, 10, "greater than 4"), + (10, 1, "greater than 4"), + ], +) +def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part): + with pytest.raises(ValueError) as exc: + _modify(group_id, server, actor, max_val, min_val) + assert err_part in str(exc.value)