mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
472 lines
21 KiB
Python
import json
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import openai

from letta.agents.base_agent import BaseAgent
from letta.agents.exceptions import IncompatibleAgentType
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import (
    add_pre_execution_message,
    enable_strict_mode,
    execute_external_tool,
    remove_request_heartbeat,
)
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import (
    AssistantMessage,
    ChatCompletionRequest,
    Tool,
    ToolCall,
    ToolCallFunction,
    ToolMessage,
    UserMessage,
)
from letta.schemas.user import User
from letta.server.rest_api.utils import (
    convert_in_context_letta_messages_to_openai,
    create_assistant_messages_from_openai_response,
    create_input_messages,
    create_letta_messages_from_llm_response,
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.summarizer.enums import SummarizationMode
from letta.services.summarizer.summarizer import Summarizer
from letta.settings import model_settings

logger = get_logger(__name__)


class VoiceAgent(BaseAgent):
    """
    A function-calling loop for streaming OpenAI responses with tool execution.

    This agent:
      - Streams partial tokens in real time for low-latency output.
      - Detects tool calls and invokes external tools.
      - Gracefully handles OpenAI API failures (429, etc.) and streams errors.
    """

    def __init__(
        self,
        agent_id: str,
        openai_client: openai.AsyncClient,
        message_manager: MessageManager,
        agent_manager: AgentManager,
        block_manager: BlockManager,
        passage_manager: PassageManager,
        actor: User,
    ):
        super().__init__(
            agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
        )

        # Summarizer settings
        self.block_manager = block_manager
        self.passage_manager = passage_manager
        # TODO: This is not guaranteed to exist!
        self.summary_block_label = "human"

        # Cached archival memory/message sizes
        self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
        self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
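
    # The voice agent is expected to be paired with exactly one voice
    # sleeptime agent via its multi-agent group; init_summarizer enforces
    # that invariant before wiring the summarizer to that agent.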
    def init_summarizer(self, agent_state: AgentState) -> Summarizer:
        if not agent_state.multi_agent_group:
            raise ValueError("Low-latency voice agent is not part of a multi-agent group; missing sleeptime agent.")
        if len(agent_state.multi_agent_group.agent_ids) != 1:
            raise ValueError(
                f"None or multiple participant agents found in voice sleeptime group: {agent_state.multi_agent_group.agent_ids}"
            )
        voice_sleeptime_agent_id = agent_state.multi_agent_group.agent_ids[0]
        summarizer = Summarizer(
            mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
            summarizer_agent=VoiceSleeptimeAgent(
                agent_id=voice_sleeptime_agent_id,
                convo_agent_state=agent_state,
                message_manager=self.message_manager,
                agent_manager=self.agent_manager,
                actor=self.actor,
                block_manager=self.block_manager,
                passage_manager=self.passage_manager,
                target_block_label=self.summary_block_label,
            ),
            message_buffer_limit=agent_state.multi_agent_group.max_message_buffer_length,
            message_buffer_min=agent_state.multi_agent_group.min_message_buffer_length,
        )

        return summarizer

    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
        raise NotImplementedError("VoiceAgent does not currently implement a non-streaming step.")
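
    # step_stream drives one user turn end to end: it streams assistant tokens
    # as SSE chunks, executes any requested tool calls, and loops (up to
    # max_steps turns) until the model stops, then emits a final
    # "data: [DONE]" chunk.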
    async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
        """
        Main streaming loop that yields partial tokens.
        Whenever we detect a tool call, we yield from _handle_ai_response as well.
        """
        if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
            raise ValueError(f"Voice Agent was invoked with multiple input messages or a message whose role is not `user`: {input_messages}")

        user_query = input_messages[0].content[0].text

        agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)

        # TODO: Refactor this so it uses our in-house clients
        # TODO: For now, piggyback off of the OpenAI client for ease
        if agent_state.llm_config.model_endpoint_type == "anthropic":
            self.openai_client.api_key = model_settings.anthropic_api_key
            self.openai_client.base_url = "https://api.anthropic.com/v1/"
        elif agent_state.llm_config.model_endpoint_type != "openai":
            raise ValueError("Letta voice agents are only compatible with OpenAI or Anthropic.")

        # Safety check
        if agent_state.agent_type != AgentType.voice_convo_agent:
            raise IncompatibleAgentType(expected_type=AgentType.voice_convo_agent, actual_type=agent_state.agent_type)

        summarizer = self.init_summarizer(agent_state=agent_state)

        in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
        memory_edit_timestamp = get_utc_time()
        in_context_messages[0].content[0].text = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=self.num_messages,
            archival_memory_size=self.num_archival_memories,
        )
        letta_message_db_queue = create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=self.actor)
        in_memory_message_history = self.pre_process_input_message(input_messages)

        for _ in range(max_steps):
            # Rebuild memory each loop
            in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
            openai_messages = convert_in_context_letta_messages_to_openai(in_context_messages, exclude_system_messages=True)
            openai_messages.extend(in_memory_message_history)

            request = self._build_openai_request(openai_messages, agent_state)

            stream = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
            streaming_interface = OpenAIChatCompletionsStreamingInterface(stream_pre_execution_message=True)

            # 1) Yield partial tokens from OpenAI
            async for sse_chunk in streaming_interface.process(stream):
                yield sse_chunk

            # 2) Now handle the final AI response. This might yield more text (stalling, etc.)
            should_continue = await self._handle_ai_response(
                user_query,
                streaming_interface,
                agent_state,
                in_memory_message_history,
                letta_message_db_queue,
            )

            if not should_continue:
                break

        # Rebuild context window if desired
        await self._rebuild_context_window(summarizer, in_context_messages, letta_message_db_queue)

        # TODO: This may be out of sync if users add files between steps
        self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
        self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

        yield "data: [DONE]\n\n"
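
    # After the token stream completes, the streaming interface holds the
    # accumulated assistant text and any detected tool call; _handle_ai_response
    # consumes that state, runs the tool if one was requested, and decides
    # whether the step loop should take another turn.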
    async def _handle_ai_response(
        self,
        user_query: str,
        streaming_interface: "OpenAIChatCompletionsStreamingInterface",
        agent_state: AgentState,
        in_memory_message_history: List[Dict[str, Any]],
        letta_message_db_queue: List[Any],
    ) -> bool:
        """
        Now that streaming is done, handle the final AI response.
        This might yield additional SSE tokens if we do stalling.
        Returns True if the step loop in `step_stream` should continue, False otherwise.
        """
        # 1. If we have any leftover content from the partial stream, store it as an assistant message
        if streaming_interface.content_buffer:
            content = "".join(streaming_interface.content_buffer)
            in_memory_message_history.append({"role": "assistant", "content": content})

            assistant_msgs = create_assistant_messages_from_openai_response(
                response_text=content,
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                actor=self.actor,
            )
            letta_message_db_queue.extend(assistant_msgs)

        # 2. If a tool call was requested, handle it
        if streaming_interface.tool_call_happened:
            tool_call_name = streaming_interface.tool_call_name
            tool_call_args_str = streaming_interface.tool_call_args_str or "{}"
            try:
                tool_args = json.loads(tool_call_args_str)
            except json.JSONDecodeError:
                tool_args = {}

            tool_call_id = streaming_interface.tool_call_id or f"call_{uuid.uuid4().hex[:8]}"
            assistant_tool_call_msg = AssistantMessage(
                content=None,
                tool_calls=[
                    ToolCall(
                        id=tool_call_id,
                        function=ToolCallFunction(
                            name=tool_call_name,
                            arguments=tool_call_args_str,
                        ),
                    )
                ],
            )
            in_memory_message_history.append(assistant_tool_call_msg.model_dump())

            tool_result, success_flag = await self._execute_tool(
                user_query=user_query,
                tool_name=tool_call_name,
                tool_args=tool_args,
                agent_state=agent_state,
            )

            # 3. Provide the function_call response back into the conversation
            tool_message = ToolMessage(
                content=json.dumps({"result": tool_result}),
                tool_call_id=tool_call_id,
            )
            in_memory_message_history.append(tool_message.model_dump())

            # 4. Insert a heartbeat message for follow-up
            heartbeat_user_message = UserMessage(
                content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user."
            )
            in_memory_message_history.append(heartbeat_user_message.model_dump())

            # 5. Also store in the DB
            tool_call_messages = create_letta_messages_from_llm_response(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                function_name=tool_call_name,
                function_arguments=tool_args,
                tool_call_id=tool_call_id,
                function_call_success=success_flag,
                function_response=tool_result,
                actor=self.actor,
                add_heartbeat_request_system_message=True,
            )
            letta_message_db_queue.extend(tool_call_messages)

            # Because we have new data, we want to continue the loop in `step_stream`
            return True
        else:
            # No tool call; stop once the model reports finish_reason == "stop"
            return not streaming_interface.finish_reason_stop
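
    # Once the turn is over, persist the queued messages to the DB and let the
    # summarizer decide which message IDs remain in the context window.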
    async def _rebuild_context_window(
        self, summarizer: Summarizer, in_context_messages: List[Message], letta_message_db_queue: List[Message]
    ) -> None:
        new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor)

        # TODO: Make this more general and configurable, less brittle
        new_in_context_messages, updated = summarizer.summarize(
            in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
        )

        self.agent_manager.set_in_context_messages(
            agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
        )
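
    # Note: this override deliberately ignores the `num_messages` and
    # `num_archival_memories` arguments and pins the rebuild to the counts
    # cached in __init__ (refreshed at the end of each step_stream call).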
    async def _rebuild_memory_async(
        self,
        in_context_messages: List[Message],
        agent_state: AgentState,
        num_messages: int | None = None,
        num_archival_memories: int | None = None,
    ) -> List[Message]:
        return await super()._rebuild_memory_async(
            in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
        )

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        tool_schemas = self._build_tool_schemas(agent_state)
        tool_choice = "auto" if tool_schemas else None

        openai_request = ChatCompletionRequest(
            model=agent_state.llm_config.model,
            messages=openai_messages,
            tools=tool_schemas,
            tool_choice=tool_choice,
            user=self.actor.id,
            max_completion_tokens=agent_state.llm_config.max_tokens,
            temperature=agent_state.llm_config.temperature,
            stream=True,
        )
        return openai_request
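
    # Tool schemas sent to the model: a built-in `search_memory` tool plus the
    # agent's external tools, each wrapped to enable strict mode, inject a
    # spoken pre-execution message, and strip the request_heartbeat parameter.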
    def _build_tool_schemas(self, agent_state: AgentState, external_tools_only=True) -> List[Tool]:
        if external_tools_only:
            tools = [t for t in agent_state.tools if t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
        else:
            tools = agent_state.tools

        # Special tool state
        search_memory_utterance_description = (
            "A lengthier message to be uttered while your memories of the current conversation are being re-contextualized. "
            "You MUST also include punctuation at the end of this message. "
            "For example: 'Let me double-check my notes—one moment, please.'"
        )

        search_memory_json = Tool(
            type="function",
            function=enable_strict_mode(  # strict=True ✓
                add_pre_execution_message(  # injects pre_exec_msg ✓
                    {
                        "name": "search_memory",
                        "description": (
                            "Look in long-term or earlier-conversation memory **only when** the "
                            "user asks about something missing from the visible context. "
                            "The user’s latest utterance is sent automatically as the main query.\n\n"
                            "Optional refinements (set unused fields to *null*):\n"
                            "• `convo_keyword_queries` – extra names/IDs if the request is vague.\n"
                            "• `start_minutes_ago` / `end_minutes_ago` – limit results to a recent time window."
                        ),
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "convo_keyword_queries": {
                                    "type": ["array", "null"],
                                    "items": {"type": "string"},
                                    "description": "Extra keywords (e.g., order ID, place name). Use *null* when the utterance is already specific.",
                                },
                                "start_minutes_ago": {
                                    "type": ["integer", "null"],
                                    "description": "Newer bound of the time window, in minutes ago. Use *null* if no lower bound is needed.",
                                },
                                "end_minutes_ago": {
                                    "type": ["integer", "null"],
                                    "description": "Older bound of the time window, in minutes ago. Use *null* if no upper bound is needed.",
                                },
                            },
                            "required": [
                                "convo_keyword_queries",
                                "start_minutes_ago",
                                "end_minutes_ago",
                            ],
                            "additionalProperties": False,
                        },
                    },
                    description=search_memory_utterance_description,
                )
            ),
        )

        # TODO: Customize whether or not to have heartbeats, pre_exec_message, etc.
        return [search_memory_json] + [
            Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
            for t in tools
        ]
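
    # `search_memory` is dispatched internally below; every other tool runs
    # through execute_external_tool with agent-state modifications disabled.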
    async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
        """
        Executes a tool and returns (result, success_flag).
        """
        # Special memory case
        if tool_name == "search_memory":
            tool_result = await self._search_memory(
                archival_query=user_query,
                convo_keyword_queries=tool_args["convo_keyword_queries"],
                start_minutes_ago=tool_args["start_minutes_ago"],
                end_minutes_ago=tool_args["end_minutes_ago"],
                agent_state=agent_state,
            )
            return tool_result, True
        else:
            target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
            if not target_tool:
                return f"Tool not found: {tool_name}", False

            try:
                tool_result, _ = execute_external_tool(
                    agent_state=agent_state,
                    function_name=tool_name,
                    function_args=tool_args,
                    target_letta_tool=target_tool,
                    actor=self.actor,
                    allow_agent_state_modifications=False,
                )
                return tool_result, True
            except Exception as e:
                return f"Failed to call tool. Error: {e}", False
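
    # The time-window arguments are expressed in "minutes ago", so the larger
    # value (end_minutes_ago, the older bound) maps to the earlier start_date,
    # and the smaller value (start_minutes_ago) maps to the later end_date.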
    async def _search_memory(
        self,
        archival_query: str,
        agent_state: AgentState,
        convo_keyword_queries: Optional[List[str]] = None,
        start_minutes_ago: Optional[int] = None,
        end_minutes_ago: Optional[int] = None,
    ) -> str:
        # Retrieve from archival memory
        now = datetime.now(timezone.utc)
        start_date = now - timedelta(minutes=end_minutes_ago) if end_minutes_ago is not None else None
        end_date = now - timedelta(minutes=start_minutes_ago) if start_minutes_ago is not None else None

        # If both bounds exist but got reversed, swap them
        # (shouldn't happen, but guards against the LLM misreading the bounds)
        if start_date and end_date and start_date > end_date:
            start_date, end_date = end_date, start_date

        archival_results = await self.agent_manager.list_passages_async(
            actor=self.actor,
            agent_id=self.agent_id,
            query_text=archival_query,
            limit=5,
            embedding_config=agent_state.embedding_config,
            embed_query=True,
            start_date=start_date,
            end_date=end_date,
        )
        formatted_archival_results = [{"timestamp": str(result.created_at), "content": result.text} for result in archival_results]
        response = {
            "archival_search_results": formatted_archival_results,
        }

        # Retrieve from conversation
        keyword_results = {}
        if convo_keyword_queries:
            for keyword in convo_keyword_queries:
                messages = await self.message_manager.list_messages_for_agent_async(
                    agent_id=self.agent_id,
                    actor=self.actor,
                    query_text=keyword,
                    limit=3,
                )
                if messages:
                    keyword_results[keyword] = [message.content[0].text for message in messages]

            response["convo_keyword_search_results"] = keyword_results

        return json.dumps(response, indent=2)
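
# Example usage (an illustrative sketch only; the manager/actor wiring shown
# here is hypothetical and normally comes from the Letta server's own setup):
#
#   agent = VoiceAgent(
#       agent_id=agent_id,
#       openai_client=openai.AsyncClient(),
#       message_manager=message_manager,
#       agent_manager=agent_manager,
#       block_manager=block_manager,
#       passage_manager=passage_manager,
#       actor=actor,
#   )
#   async for chunk in agent.step_stream(
#       input_messages=[MessageCreate(role=MessageRole.user, content=[...])]
#   ):
#       ...  # forward SSE chunks to the client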