# MemGPT/letta/streaming_interface.py
import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Optional

# from colorama import Fore, Style, init
from rich.console import Console
from rich.live import Live
from rich.markup import escape

from letta.interface import CLIInterface
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse

# init(autoreset=True)

# DEBUG = True  # puts full message outputs in the terminal
DEBUG = False  # only dumps important messages in the terminal

STRIP_UI = False


class AgentChunkStreamingInterface(ABC):
    """Interfaces handle Letta-related events (observer pattern)

    The 'msg' arg provides the scoped message; the optional 'msg_obj' Message arg can provide additional metadata.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta receives a user message"""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta generates some internal monologue"""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta uses send_message"""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta calls a function"""
        raise NotImplementedError

    @abstractmethod
    def process_chunk(
        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
    ):
        """Process a streaming chunk from an OpenAI-compatible server"""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Any setup required before streaming begins"""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Any cleanup required after streaming ends"""
        raise NotImplementedError


class StreamingCLIInterface(AgentChunkStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and prints along the way.

    When a chunk is received, we write the delta to the buffer. If the buffer type has changed,
    we write out a newline + set the formatting for the new line.

    The two buffer types are:
      (1) content (inner thoughts)
      (2) tool_calls (function calling)

    NOTE: this assumes the deltas received in the chunks arrive in order, e.g. that once
    'content' deltas stop streaming, they won't be received again. An alternative
    implementation without this limitation (see StreamingRefreshCLIInterface below) instead
    maintains the partial message state and, on each processed chunk, (1) updates the partial
    message state, then (2) refreshes/rewrites that state to the screen.

    A hypothetical usage sketch (_demo_chunk_streaming) follows this class definition.
    """

    # CLIInterface is static/stateless
    nonstreaming_interface = CLIInterface()

    def __init__(self):
        """The streaming CLI interface state for determining which buffer is currently being written to"""
        self.streaming_buffer_type = None

    def _flush(self):
        pass

def process_chunk(
self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
):
assert len(chunk.choices) == 1, chunk
message_delta = chunk.choices[0].delta
# Starting a new buffer line
if not self.streaming_buffer_type:
assert not (
message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
), f"Error: got both content and tool_calls in message stream\n{message_delta}"
if message_delta.content is not None:
# Write out the prefix for inner thoughts
print("Inner thoughts: ", end="", flush=True)
elif message_delta.tool_calls is not None:
assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
# Write out the prefix for function calling
print("Calling function: ", end="", flush=True)
# Potentially switch/flush a buffer line
else:
pass

        # Write out the delta
if message_delta.content is not None:
if self.streaming_buffer_type and self.streaming_buffer_type != "content":
print()
self.streaming_buffer_type = "content"
# Simple, just write out to the buffer
print(message_delta.content, end="", flush=True)
elif message_delta.tool_calls is not None:
if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls":
print()
self.streaming_buffer_type = "tool_calls"
assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
function_call = message_delta.tool_calls[0].function
# Slightly more complex - want to write parameters in a certain way (paren-style)
# function_name(function_args)
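            # e.g. rendered as: send_message({"message": "Hi!"}  <- the closing ')' is added later, in stream_end()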
if function_call and function_call.name:
# NOTE: need to account for closing the brace later
print(f"{function_call.name}(", end="", flush=True)
if function_call and function_call.arguments:
print(function_call.arguments, end="", flush=True)

    def stream_start(self):
        # should be handled by stream_end(), but just in case
        self.streaming_buffer_type = None

    def stream_end(self):
        if self.streaming_buffer_type is not None:
            # TODO: should have a separate self.tool_call_open_paren flag
            if self.streaming_buffer_type == "tool_calls":
                print(")", end="", flush=True)
            print()  # newline to move the cursor
            self.streaming_buffer_type = None  # reset buffer tracker

    @staticmethod
    def important_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    @staticmethod
    def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    @staticmethod
    def assistant_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        pass
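

# A minimal, hypothetical usage sketch (not part of the original module): it drives
# StreamingCLIInterface with stand-in chunk objects that only mimic the attributes
# process_chunk() actually reads (choices[0].delta.content / .tool_calls, and
# tool_calls[0].function.name / .arguments). Real callers pass ChatCompletionChunkResponse
# objects parsed from an SSE stream.
def _demo_chunk_streaming():
    from types import SimpleNamespace
    from uuid import uuid4

    def fake_chunk(content=None, tool_calls=None):
        # Hypothetical test double for a ChatCompletionChunkResponse
        delta = SimpleNamespace(content=content, tool_calls=tool_calls)
        return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])

    def fake_tool_delta(name=None, arguments=None):
        return [SimpleNamespace(function=SimpleNamespace(name=name, arguments=arguments))]

    interface = StreamingCLIInterface()
    interface.stream_start()
    # 'content' deltas stream first (inner thoughts), then 'tool_calls' deltas
    for chunk in [
        fake_chunk(content="Thinking "),
        fake_chunk(content="out loud..."),
        fake_chunk(tool_calls=fake_tool_delta(name="send_message")),
        fake_chunk(tool_calls=fake_tool_delta(arguments='{"message": "Hi!"}')),
    ]:
        interface.process_chunk(chunk, message_id=str(uuid4()), message_date=datetime.now())
    interface.stream_end()  # closes the open paren and resets the buffer tracker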


class AgentRefreshStreamingInterface(ABC):
    """Same as the AgentChunkStreamingInterface, but process_refresh receives the full
    accumulated response at each step (for rewrite-style rendering) instead of an incremental chunk.

    The 'msg' arg provides the scoped message; the optional 'msg_obj' Message arg can provide additional metadata.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta receives a user message"""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta generates some internal monologue"""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta uses send_message"""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Letta calls a function"""
        raise NotImplementedError

    @abstractmethod
    def process_refresh(self, response: ChatCompletionResponse):
        """Process the accumulated streaming response from an OpenAI-compatible server"""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Any setup required before streaming begins"""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Any cleanup required after streaming ends"""
        raise NotImplementedError

    @abstractmethod
    def toggle_streaming(self, on: bool):
        """Toggle streaming on/off (off = regular CLI interface)"""
        raise NotImplementedError


class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step.

    We maintain the partial message state in the interface, and on each process_refresh we:
      (1) update the partial message state,
      (2) refresh/rewrite that state to the screen.

    See the hypothetical usage sketch under __main__ at the bottom of this file.
    """

    nonstreaming_interface = CLIInterface

    def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True):
        """Initialize the streaming CLI interface state."""
        self.console = Console()

        # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates
        self.live = Live("", console=self.console, refresh_per_second=10)
        # self.live.start()  # Start the Live display context and keep it running

        # Use italics / emoji?
        self.fancy = fancy

        self.streaming = True
        self.separate_send_message = separate_send_message
        self.disable_inner_mono_call = disable_inner_mono_call

    def toggle_streaming(self, on: bool):
        self.streaming = on
        if on:
            self.separate_send_message = True
            self.disable_inner_mono_call = True
        else:
            self.separate_send_message = False
            self.disable_inner_mono_call = False

    def update_output(self, content: str):
        """Update the displayed output with new content."""
        # We use the `Live` object's update mechanism to refresh content without clearing the console
        if not self.fancy:
            content = escape(content)
        self.live.update(self.console.render_str(content), refresh=True)

    def process_refresh(self, response: ChatCompletionResponse):
        """Process the response to rewrite the current output buffer."""
        if not response.choices:
            self.update_output(f"{INNER_THOUGHTS_CLI_SYMBOL} [italic]...[/italic]")
            return  # Early exit if there are no choices

        choice = response.choices[0]
        inner_thoughts = choice.message.content if choice.message.content else ""
        tool_calls = choice.message.tool_calls if choice.message.tool_calls else []

        if self.fancy:
            message_string = f"{INNER_THOUGHTS_CLI_SYMBOL} [italic]{inner_thoughts}[/italic]" if inner_thoughts else ""
        else:
            message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else ""

        if tool_calls:
            function_call = tool_calls[0].function
            function_name = function_call.name  # Function name, can be an empty string
            function_args = function_call.arguments  # Function arguments, can be an empty string
            if message_string:
                message_string += "\n"
            # special case here for send_message: pull the 'message' value out of the arguments
            # and render it as an assistant message instead of as a raw function call
            if self.separate_send_message and function_name == "send_message":
                try:
                    message = json.loads(function_args)["message"]
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Mid-stream, the arguments are usually an incomplete JSON string,
                    # so fall back to stripping the expected prefix off the partial payload
                    prefix = '{\n "message": "'
                    if not function_args or len(function_args) < len(prefix):
                        message = "..."
                    elif function_args.startswith(prefix):
                        message = function_args[len(prefix) :]
                    else:
                        message = function_args
                message_string += f"{ASSISTANT_MESSAGE_CLI_SYMBOL} [bold yellow]{message}[/bold yellow]"
            else:
                message_string += f"{function_name}({function_args})"

        self.update_output(message_string)

    def stream_start(self):
        if self.streaming:
            print()
            self.live.start()  # Start the Live display context and keep it running
            self.update_output(f"{INNER_THOUGHTS_CLI_SYMBOL} [italic]...[/italic]")

    def stream_end(self):
        if self.streaming:
            if self.live.is_started:
                self.live.stop()
                print()
                self.live = Live("", console=self.console, refresh_per_second=10)

    @staticmethod
    def important_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        if self.disable_inner_mono_call:
            return
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        if self.separate_send_message:
            return
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        pass
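

# A minimal, hypothetical sketch (not from the original module) of the refresh-style
# renderer in action. The stand-in objects below only mimic the attributes
# process_refresh() reads (choices[0].message.content / .tool_calls and
# tool_calls[0].function.name / .arguments); real callers pass accumulated
# ChatCompletionResponse objects.
if __name__ == "__main__":
    from time import sleep
    from types import SimpleNamespace

    def fake_response(content=None, name=None, arguments=None):
        # Hypothetical test double for an accumulated ChatCompletionResponse
        tool_calls = [SimpleNamespace(function=SimpleNamespace(name=name, arguments=arguments))] if name else None
        message = SimpleNamespace(content=content, tool_calls=tool_calls)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

    interface = StreamingRefreshCLIInterface()
    interface.stream_start()
    # Each refresh carries the full accumulated state, so the renderer simply rewrites it
    for state in [
        fake_response(content="Thinking"),
        fake_response(content="Thinking out loud..."),
        fake_response(content="Thinking out loud...", name="send_message", arguments='{"message": "Hi!"}'),
    ]:
        interface.process_refresh(state)
        sleep(0.3)  # slow down so the intermediate refreshes are visible
    interface.stream_end()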