# Mirror of https://github.com/cpacker/MemGPT.git
# Synced 2025-06-03 04:30:22 +00:00 · 324 lines · 13 KiB · Python
import ast
import re
from abc import ABC, abstractmethod
from typing import List, Optional

from colorama import Fore, Style, init

from letta.constants import CLI_WARNING_PREFIX
from letta.helpers.json_helpers import json_loads
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
from letta.schemas.message import Message
from letta.utils import printd
|
|
|
|
# colorama setup; autoreset=True asks colorama to reset styling after each print.
init(autoreset=True)

# Default for the `debug` flag of CLIInterface.user_message / function_message.
# Flip to True to echo full message payloads in the terminal; with False only
# the important messages are shown.
DEBUG = False

# When True, the CLI hooks print plain text instead of colorama-styled output.
STRIP_UI = False
|
|
|
|
|
|
class AgentInterface(ABC):
    """Interface through which Letta reports agent events (observer pattern).

    Each hook receives the scoped message text as ``msg``. The optional
    ``msg_obj`` argument can carry the full ``Message`` with additional metadata.
    Concrete interfaces must implement all four hooks below.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Called when Letta receives a user message."""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
        """Called when Letta generates some internal monologue.

        ``chunk_index`` presumably identifies the piece of a streamed monologue;
        TODO confirm against the implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Called when Letta sends a message to the user (``send_message``)."""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
        """Called when Letta calls a function (or reports on a function call)."""
        raise NotImplementedError
|
|
|
|
|
|
class CLIInterface(AgentInterface):
    """Basic interface for dumping agent events to the command line.

    Every hook is a ``staticmethod``, so callers use the class directly
    (e.g. ``CLIInterface.system_message(...)``) without an instance. Output is
    styled with colorama escape codes unless the module-level ``STRIP_UI`` flag
    is set, in which case plain text is printed.
    """

    @staticmethod
    def _print_styled(prefix: str, msg) -> None:
        """Print ``msg`` after the colorama ``prefix`` and follow it with a style reset.

        With ``STRIP_UI`` set, the prefix (including any icon it contains) and the
        reset are dropped and only ``msg`` is printed. ``msg`` is interpolated
        directly instead of through ``str.format`` on a template, so text that
        itself contains ``{`` or ``}`` (e.g. JSON) is printed verbatim.
        """
        if STRIP_UI:
            print(msg)
        else:
            print(f"{prefix}{msg}{Style.RESET_ALL}")

    @staticmethod
    def important_message(msg: str):
        """Print a highlighted notice (bright magenta)."""
        CLIInterface._print_styled(f"{Fore.MAGENTA}{Style.BRIGHT}", msg)

    @staticmethod
    def warning_message(msg: str):
        """Print a warning (bright red).

        The warning is printed in both UI modes; ``STRIP_UI`` only removes the styling.
        """
        CLIInterface._print_styled(f"{Fore.RED}{Style.BRIGHT}", msg)

    @staticmethod
    def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
        """Print inner thoughts in italic grey, prefixed by ``INNER_THOUGHTS_CLI_SYMBOL``."""
        # ANSI escape code for italic is '\x1B[3m'
        CLIInterface._print_styled(f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} ", msg)

    @staticmethod
    def assistant_message(msg: str, msg_obj: Optional[Message] = None):
        """Print a message sent to the user, in yellow, prefixed by ``ASSISTANT_MESSAGE_CLI_SYMBOL``."""
        CLIInterface._print_styled(f"{Fore.YELLOW}{Style.BRIGHT}{ASSISTANT_MESSAGE_CLI_SYMBOL} {Fore.YELLOW}", msg)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        """Print a memory-related notice in light magenta with a brain icon."""
        CLIInterface._print_styled(f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}", msg)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        """Print a system message in magenta with a ``[system]`` tag.

        System text frequently contains braces, which ``_print_styled`` prints as-is.
        """
        CLIInterface._print_styled(f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}", msg)

    @staticmethod
    def user_message(
        msg: str,
        msg_obj: Optional[Message] = None,
        raw: bool = False,
        dump: bool = False,
        debug: bool = DEBUG,
        chunk_index: Optional[int] = None,
    ):
        """Echo a user-side message (user text, heartbeat or system event).

        Nothing is printed unless ``raw``, ``dump`` or ``debug`` is set, because in
        normal use the message should not be repeated.

        Args:
            msg: Message text. Non-``str`` values are ignored. Unless ``raw`` is
                set, it is parsed as JSON and, when it is an object, dispatched on
                its ``"type"`` field (``"user_message"``, ``"heartbeat"``,
                ``"system_message"``). Anything else is printed as parsed.
            msg_obj: Optional full ``Message``; not used by this implementation.
            raw: Print ``msg`` verbatim without parsing it.
            dump: For user messages, print only the ``"message"`` text; heartbeats
                are printed including their ``"type"``.
            debug: Print the parsed payload with its ``"type"`` key removed.
            chunk_index: Not used by this implementation.
        """

        def print_user_message(icon, text, printf=print):
            if STRIP_UI:
                printf(f"{icon} {text}")
            else:
                printf(f"{Fore.GREEN}{Style.BRIGHT}{icon} {Fore.GREEN}{text}{Style.RESET_ALL}")

        def printd_user_message(icon, text):
            # NOTE: despite the name this prints unconditionally; the early return
            # below already limits output to raw/dump/debug mode.
            return print_user_message(icon, text)

        if not (raw or dump or debug):
            # we do not want to repeat the message in normal use
            return
        if not isinstance(msg, str):
            return

        if raw:
            printd_user_message("🧑", msg)
            return

        try:
            msg_json = json_loads(msg)
        except Exception:
            printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
            printd_user_message("🧑", msg)
            return

        # Valid JSON that is not an object, or has no "type", is printed as-is
        # instead of crashing on the "type" lookup.
        msg_type = msg_json.get("type") if isinstance(msg_json, dict) else None

        if msg_type == "user_message":
            if dump:
                print_user_message("🧑", msg_json.get("message", msg))
                return
            msg_json.pop("type")
            printd_user_message("🧑", msg_json)
        elif msg_type == "heartbeat":
            if debug:
                msg_json.pop("type")
                printd_user_message("💓", msg_json)
            elif dump:
                print_user_message("💓", msg_json)
                return
        elif msg_type == "system_message":
            msg_json.pop("type")
            printd_user_message("🖥️", msg_json)
        else:
            printd_user_message("🧑", msg_json)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
        """Report a function/tool event, dispatching on the text of ``msg``.

        * ``dict`` -> printed as-is.
        * ``"Success..."`` / ``"Error: ..."`` -> printed with a 🟢 / 🔴 icon.
        * ``"Ran ..."`` -> ignored (post-execution notices).
        * ``"Running name(args)"`` -> in debug mode, the raw text is printed.
          Otherwise archival/core memory calls and conversation searches are
          summarized, and other functions print nothing.
        * Anything else is parsed as JSON and printed green when its ``"status"``
          is ``"OK"``, red otherwise.

        Messages marked "d" below go through ``printd`` unless ``debug`` is set.
        """

        def print_function_message(icon, text, color=Fore.RED, printf=print):
            if STRIP_UI:
                printf(f"⚡{icon} [function] {text}")
            else:
                printf(f"{color}{Style.BRIGHT}⚡{icon} [function] {color}{text}{Style.RESET_ALL}")

        def printd_function_message(icon, text, color=Fore.RED):
            return print_function_message(icon, text, color, printf=(print if debug else printd))

        def print_detail(plain, styled):
            # Detail lines under a memory-call header, styled unless STRIP_UI.
            print(plain if STRIP_UI else styled)

        if isinstance(msg, dict):
            printd_function_message("", msg)
            return

        if msg.startswith("Success"):
            printd_function_message("🟢", msg)
        elif msg.startswith("Error: "):
            printd_function_message("🔴", msg)
        elif msg.startswith("Ran "):
            # NOTE: ignore 'ran' messages that come post-execution
            return
        elif msg.startswith("Running "):
            if debug:
                printd_function_message("", msg)
                return

            match = re.search(r"Running (\w+)\((.*)\)", msg)
            if not match:
                printd(f"{CLI_WARNING_PREFIX}did not recognize function message")
                printd_function_message("", msg)
                return
            function_name, function_args = match.group(1), match.group(2)

            if function_name in ("archival_memory_insert", "core_memory_append", "core_memory_replace"):
                print_function_message("🧠", f"updating memory with {function_name}")
            elif function_name in ("archival_memory_search", "conversation_search", "conversation_search_date"):
                print_function_message("🧠", f"searching memory with {function_name}")
            else:
                return

            try:
                # function_args is presumably the repr() of the call's argument
                # dict (TODO confirm). It is LLM-derived text, so it is parsed with
                # ast.literal_eval, which accepts only Python literals and cannot
                # execute code the way eval() could.
                call_args = ast.literal_eval(function_args)
                if function_name == "archival_memory_insert":
                    line = f'\t→ {call_args["content"]}'
                    print_detail(line, f"{Style.BRIGHT}{Fore.RED}{line}{Style.RESET_ALL}")
                elif function_name in ("core_memory_append", "core_memory_replace"):
                    old, new = call_args["old_content"], call_args["new_content"]
                    print_detail(
                        f"\t {old}\n\t→ {new}",
                        f"{Style.BRIGHT}\t{Fore.RED} {old}\n\t{Fore.GREEN}→ {new}{Style.RESET_ALL}",
                    )
                else:
                    # archival_memory_search / conversation_search(_date)
                    line = f'\tquery: {call_args["query"]}, page: {call_args["page"]}'
                    print_detail(line, f"{Fore.RED}{line}{Style.RESET_ALL}")
            except Exception as e:
                # Log the unparsed argument text; a parsed dict may not exist here.
                printd(str(e))
                printd(function_args)
        else:
            try:
                msg_dict = json_loads(msg)
            except Exception:
                print(f"{CLI_WARNING_PREFIX}did not recognize function message {type(msg)} {msg}")
                printd_function_message("", msg)
                return
            status_ok = isinstance(msg_dict, dict) and msg_dict.get("status") == "OK"
            printd_function_message("", str(msg), color=(Fore.GREEN if status_ok else Fore.RED))

    @staticmethod
    def _print_call_message(arguments) -> None:
        """Print the ``"message"`` argument of a function/tool call, if present.

        ``arguments`` is the call's JSON-encoded argument string. Calls whose
        decoded arguments have no ``"message"`` print nothing rather than a
        literal ``None``.
        """
        args = json_loads(arguments)
        message = args.get("message") if isinstance(args, dict) else None
        if message is not None:
            CLIInterface.assistant_message(message)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        """Pretty-print a message history using the other hooks of this class.

        Args:
            message_sequence: Messages to print; each is converted with
                ``to_openai_dict()`` first.
            dump: Prefix every message with a descending index ``[N] `` and pass
                ``dump``/``debug`` through to ``user_message``/``function_message``.
        """
        # rewrite to dict format
        openai_messages = [m.to_openai_dict() for m in message_sequence]

        # Indices count down from len(...) to 1.
        for idx, msg in zip(range(len(openai_messages), 0, -1), openai_messages):
            if dump:
                print(f"[{idx}] ", end="")
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                CLIInterface.system_message(content)
            elif role == "assistant":
                # Differentiate between internal monologue, function calls, and messages
                if msg.get("function_call"):
                    if content is not None:
                        CLIInterface.internal_monologue(content)
                    CLIInterface._print_call_message(msg["function_call"].get("arguments"))
                elif msg.get("tool_calls"):
                    if content is not None:
                        CLIInterface.internal_monologue(content)
                    # Only the first tool call is inspected.
                    function_obj = msg["tool_calls"][0].get("function")
                    if function_obj:
                        CLIInterface._print_call_message(function_obj.get("arguments"))
                else:
                    CLIInterface.internal_monologue(content)
            elif role == "user":
                CLIInterface.user_message(content, dump=dump)
            elif role in ("function", "tool"):
                CLIInterface.function_message(content, debug=dump)
            else:
                print(f"Unknown role {role}: {content}")

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        """Print system, assistant and user messages without any interpretation."""
        # rewrite to dict format
        openai_messages = [m.to_openai_dict() for m in message_sequence]

        for msg in openai_messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                CLIInterface.system_message(content)
            elif role == "assistant":
                CLIInterface.assistant_message(content)
            elif role == "user":
                CLIInterface.user_message(content, raw=True)
            else:
                print(f"Unknown role {role}: {content}")

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        """Print the ``to_openai_dict()`` form of each message, one per line."""
        for msg in message_sequence:
            print(msg.to_openai_dict())

    @staticmethod
    def step_yield():
        """No-op hook; the CLI has nothing to do here."""
        pass

    @staticmethod
    def step_complete():
        """No-op hook; the CLI has nothing to do here."""
        pass
|