feat: Add safety checks to buffer length update and tests to updating voice sleeptime (#1959)

This commit is contained in:
Matthew Zhou 2025-04-30 18:08:47 -07:00 committed by GitHub
parent 7fb90ff612
commit b756f31b94
4 changed files with 94 additions and 10 deletions

View File

@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29
# minimum context window size
MIN_CONTEXT_WINDOW = 4096
# Voice Sleeptime message buffer lengths
DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30
DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15
# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset
DEFAULT_EMBEDDING_CHUNK_SIZE = 300

View File

@ -862,8 +862,8 @@ class SyncServer(Server):
agent_ids=[voice_sleeptime_agent.id],
manager_config=VoiceSleeptimeManager(
manager_agent_id=main_agent.id,
max_message_buffer_length=30,
min_message_buffer_length=15,
max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
),
),
actor=actor,

View File

@ -125,10 +125,10 @@ class GroupManager:
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case ManagerType.sleeptime:
case ManagerType.voice_sleeptime:
manager_agent_id = group_update.manager_config.manager_agent_id
max_message_buffer_length = group_update.manager_config.max_message_buffer_length
min_message_buffer_length = group_update.manager_config.min_message_buffer_length
max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length
min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length
if sleeptime_agent_frequency and group.turns_counter is None:
group.turns_counter = -1
case _:
@ -304,8 +304,9 @@ class GroupManager:
min_name: str = "min_message_buffer_length",
) -> None:
"""
1) If one of max_value/min_value is set, the other must also be set.
2) If both are set, max_value must be greater than min_value.
1) Both-or-none: if one is set, the other must be set.
2) Both must be ints > 4.
3) max_value must be strictly greater than min_value.
"""
# 1) require both-or-none
if (max_value is None) != (min_value is None):
@ -313,6 +314,21 @@ class GroupManager:
f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})"
)
# 2) valid range
if max_value is not None and min_value is not None and max_value <= min_value:
# no further checks if neither is provided
if max_value is None:
return
# 2) type & lowerbound checks
if not isinstance(max_value, int) or not isinstance(min_value, int):
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be integers "
f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})"
)
if max_value <= 4 or min_value <= 4:
raise ValueError(
f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})"
)
# 3) ordering
if max_value <= min_value:
raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")

View File

@ -10,12 +10,13 @@ from openai.types.chat import ChatCompletionChunk
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.config import LettaConfig
from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentType, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.group import ManagerType
from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
@ -497,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
server.group_manager.retrieve_group(group_id=group.id, actor=actor)
with pytest.raises(NoResultFound):
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
def _modify(group_id, server, actor, max_val, min_val):
"""Helper to invoke modify_group with voice_sleeptime config."""
return server.group_manager.modify_group(
group_id=group_id,
group_update=GroupUpdate(
manager_config=VoiceSleeptimeManagerUpdate(
manager_type=ManagerType.voice_sleeptime,
max_message_buffer_length=max_val,
min_message_buffer_length=min_val,
)
),
actor=actor,
)
@pytest.fixture
def group_id(voice_agent):
return voice_agent.multi_agent_group.id
def test_valid_buffer_lengths_above_four(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=10, min_val=5)
assert updated.max_message_buffer_length == 10
assert updated.min_message_buffer_length == 5
def test_valid_buffer_lengths_only_max(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
def test_valid_buffer_lengths_only_min(group_id, server, actor):
# both > 4 and max > min
updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1
@pytest.mark.parametrize(
"max_val,min_val,err_part",
[
# only one set → both-or-none
(None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"),
(DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"),
# ordering violations
(5, 5, "must be greater than"),
(6, 7, "must be greater than"),
# lower-bound (must both be > 4)
(4, 5, "greater than 4"),
(5, 4, "greater than 4"),
(1, 10, "greater than 4"),
(10, 1, "greater than 4"),
],
)
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
with pytest.raises(ValueError) as exc:
_modify(group_id, server, actor, max_val, min_val)
assert err_part in str(exc.value)