feat: Add safety checks for voice agent configuration (#1955)

This commit is contained in:
Matthew Zhou 2025-04-30 16:46:45 -07:00 committed by GitHub
parent a55d74f208
commit 2d7f90e38c
3 changed files with 26 additions and 1 deletions

View File

@ -0,0 +1,6 @@
class IncompatibleAgentType(ValueError):
def __init__(self, expected_type: str, actual_type: str):
message = f"Incompatible agent type: expected '{expected_type}', but got '{actual_type}'."
super().__init__(message)
self.expected_type = expected_type
self.actual_type = actual_type

View File

@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import openai
from letta.agents.base_agent import BaseAgent
from letta.agents.exceptions import IncompatibleAgentType
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
@ -18,7 +19,7 @@ from letta.helpers.tool_execution_helper import (
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate, MessageUpdate
@ -124,9 +125,15 @@ class VoiceAgent(BaseAgent):
"""
if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}")
user_query = input_messages[0].content[0].text
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
# Safety check
if agent_state.agent_type != AgentType.voice_convo_agent:
raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type)
summarizer = self.init_summarizer(agent_state=agent_state)
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)

View File

@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.agents.exceptions import IncompatibleAgentType
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
@ -173,6 +174,17 @@ def create_application() -> "FastAPI":
def shutdown_scheduler():
shutdown_cron_scheduler()
@app.exception_handler(IncompatibleAgentType)
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
return JSONResponse(
status_code=400,
content={
"detail": str(exc),
"expected_type": exc.expected_type,
"actual_type": exc.actual_type,
},
)
@app.exception_handler(Exception)
async def generic_error_handler(request: Request, exc: Exception):
# Log the actual error for debugging