mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add safety checks to buffer length update and tests to updating voice sleeptime (#1959)
This commit is contained in:
parent
7fb90ff612
commit
b756f31b94
@ -35,6 +35,10 @@ TOOL_CALL_ID_MAX_LEN = 29
|
||||
# minimum context window size
|
||||
MIN_CONTEXT_WINDOW = 4096
|
||||
|
||||
# Voice Sleeptime message buffer lengths
|
||||
DEFAULT_MAX_MESSAGE_BUFFER_LENGTH = 30
|
||||
DEFAULT_MIN_MESSAGE_BUFFER_LENGTH = 15
|
||||
|
||||
# embeddings
|
||||
MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset
|
||||
DEFAULT_EMBEDDING_CHUNK_SIZE = 300
|
||||
|
@ -862,8 +862,8 @@ class SyncServer(Server):
|
||||
agent_ids=[voice_sleeptime_agent.id],
|
||||
manager_config=VoiceSleeptimeManager(
|
||||
manager_agent_id=main_agent.id,
|
||||
max_message_buffer_length=30,
|
||||
min_message_buffer_length=15,
|
||||
max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH,
|
||||
min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
|
@ -125,10 +125,10 @@ class GroupManager:
|
||||
sleeptime_agent_frequency = group_update.manager_config.sleeptime_agent_frequency
|
||||
if sleeptime_agent_frequency and group.turns_counter is None:
|
||||
group.turns_counter = -1
|
||||
case ManagerType.sleeptime:
|
||||
case ManagerType.voice_sleeptime:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
max_message_buffer_length = group_update.manager_config.max_message_buffer_length
|
||||
min_message_buffer_length = group_update.manager_config.min_message_buffer_length
|
||||
max_message_buffer_length = group_update.manager_config.max_message_buffer_length or group.max_message_buffer_length
|
||||
min_message_buffer_length = group_update.manager_config.min_message_buffer_length or group.min_message_buffer_length
|
||||
if sleeptime_agent_frequency and group.turns_counter is None:
|
||||
group.turns_counter = -1
|
||||
case _:
|
||||
@ -304,8 +304,9 @@ class GroupManager:
|
||||
min_name: str = "min_message_buffer_length",
|
||||
) -> None:
|
||||
"""
|
||||
1) If one of max_value/min_value is set, the other must also be set.
|
||||
2) If both are set, max_value must be greater than min_value.
|
||||
1) Both-or-none: if one is set, the other must be set.
|
||||
2) Both must be ints > 4.
|
||||
3) max_value must be strictly greater than min_value.
|
||||
"""
|
||||
# 1) require both-or-none
|
||||
if (max_value is None) != (min_value is None):
|
||||
@ -313,6 +314,21 @@ class GroupManager:
|
||||
f"Both '{max_name}' and '{min_name}' must be provided together " f"(got {max_name}={max_value}, {min_name}={min_value})"
|
||||
)
|
||||
|
||||
# 2) valid range
|
||||
if max_value is not None and min_value is not None and max_value <= min_value:
|
||||
# no further checks if neither is provided
|
||||
if max_value is None:
|
||||
return
|
||||
|
||||
# 2) type & lower‐bound checks
|
||||
if not isinstance(max_value, int) or not isinstance(min_value, int):
|
||||
raise ValueError(
|
||||
f"Both '{max_name}' and '{min_name}' must be integers "
|
||||
f"(got {max_name}={type(max_value).__name__}, {min_name}={type(min_value).__name__})"
|
||||
)
|
||||
if max_value <= 4 or min_value <= 4:
|
||||
raise ValueError(
|
||||
f"Both '{max_name}' and '{min_name}' must be greater than 4 " f"(got {max_name}={max_value}, {min_name}={min_value})"
|
||||
)
|
||||
|
||||
# 3) ordering
|
||||
if max_value <= min_value:
|
||||
raise ValueError(f"'{max_name}' must be greater than '{min_name}' " f"(got {max_name}={max_value} <= {min_name}={min_value})")
|
||||
|
@ -10,12 +10,13 @@ from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentType, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.group import ManagerType
|
||||
from letta.schemas.group import GroupUpdate, ManagerType, VoiceSleeptimeManagerUpdate
|
||||
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@ -497,3 +498,66 @@ async def test_init_voice_convo_agent(voice_agent, server, actor):
|
||||
server.group_manager.retrieve_group(group_id=group.id, actor=actor)
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
|
||||
|
||||
def _modify(group_id, server, actor, max_val, min_val):
|
||||
"""Helper to invoke modify_group with voice_sleeptime config."""
|
||||
return server.group_manager.modify_group(
|
||||
group_id=group_id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=VoiceSleeptimeManagerUpdate(
|
||||
manager_type=ManagerType.voice_sleeptime,
|
||||
max_message_buffer_length=max_val,
|
||||
min_message_buffer_length=min_val,
|
||||
)
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def group_id(voice_agent):
|
||||
return voice_agent.multi_agent_group.id
|
||||
|
||||
|
||||
def test_valid_buffer_lengths_above_four(group_id, server, actor):
|
||||
# both > 4 and max > min
|
||||
updated = _modify(group_id, server, actor, max_val=10, min_val=5)
|
||||
assert updated.max_message_buffer_length == 10
|
||||
assert updated.min_message_buffer_length == 5
|
||||
|
||||
|
||||
def test_valid_buffer_lengths_only_max(group_id, server, actor):
|
||||
# both > 4 and max > min
|
||||
updated = _modify(group_id, server, actor, max_val=DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1, min_val=None)
|
||||
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH + 1
|
||||
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH
|
||||
|
||||
|
||||
def test_valid_buffer_lengths_only_min(group_id, server, actor):
|
||||
# both > 4 and max > min
|
||||
updated = _modify(group_id, server, actor, max_val=None, min_val=DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1)
|
||||
assert updated.max_message_buffer_length == DEFAULT_MAX_MESSAGE_BUFFER_LENGTH
|
||||
assert updated.min_message_buffer_length == DEFAULT_MIN_MESSAGE_BUFFER_LENGTH + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_val,min_val,err_part",
|
||||
[
|
||||
# only one set → both-or-none
|
||||
(None, DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, "must be greater than"),
|
||||
(DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, None, "must be greater than"),
|
||||
# ordering violations
|
||||
(5, 5, "must be greater than"),
|
||||
(6, 7, "must be greater than"),
|
||||
# lower-bound (must both be > 4)
|
||||
(4, 5, "greater than 4"),
|
||||
(5, 4, "greater than 4"),
|
||||
(1, 10, "greater than 4"),
|
||||
(10, 1, "greater than 4"),
|
||||
],
|
||||
)
|
||||
def test_invalid_buffer_lengths(group_id, server, actor, max_val, min_val, err_part):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_modify(group_id, server, actor, max_val, min_val)
|
||||
assert err_part in str(exc.value)
|
||||
|
Loading…
Reference in New Issue
Block a user