mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add safety checks for voice agent configuration (#1955)
This commit is contained in:
parent
a55d74f208
commit
2d7f90e38c
6
letta/agents/exceptions.py
Normal file
6
letta/agents/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
||||
class IncompatibleAgentType(ValueError):
|
||||
def __init__(self, expected_type: str, actual_type: str):
|
||||
message = f"Incompatible agent type: expected '{expected_type}', but got '{actual_type}'."
|
||||
super().__init__(message)
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
@ -6,6 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
import openai
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@ -18,7 +19,7 @@ from letta.helpers.tool_execution_helper import (
|
||||
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
@ -124,9 +125,15 @@ class VoiceAgent(BaseAgent):
|
||||
"""
|
||||
if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
|
||||
raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}")
|
||||
|
||||
user_query = input_messages[0].content[0].text
|
||||
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
|
||||
# Safety check
|
||||
if agent_state.agent_type != AgentType.voice_convo_agent:
|
||||
raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type)
|
||||
|
||||
summarizer = self.init_summarizer(agent_state=agent_state)
|
||||
|
||||
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
|
@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.__init__ import __version__
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
|
||||
@ -173,6 +174,17 @@ def create_application() -> "FastAPI":
|
||||
def shutdown_scheduler():
|
||||
shutdown_cron_scheduler()
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": str(exc),
|
||||
"expected_type": exc.expected_type,
|
||||
"actual_type": exc.actual_type,
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
# Log the actual error for debugging
|
||||
|
Loading…
Reference in New Issue
Block a user