# Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00).
# Python source file: 323 lines, 11 KiB.
from datetime import datetime
|
|
import re
|
|
import json
|
|
import os
|
|
import pickle
|
|
import platform
|
|
import subprocess
|
|
import sys
|
|
import io
|
|
from contextlib import contextmanager
|
|
import difflib
|
|
import demjson3 as demjson
|
|
import pytz
|
|
import tiktoken
|
|
|
|
import memgpt
|
|
from memgpt.constants import (
|
|
MEMGPT_DIR,
|
|
FUNCTION_RETURN_CHAR_LIMIT,
|
|
CLI_WARNING_PREFIX,
|
|
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
|
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
|
)
|
|
|
|
from memgpt.openai_backcompat.openai_object import OpenAIObject
|
|
|
|
# Module-wide debug switch read by `printd`: when True, `printd` forwards its
# arguments to `print`; when False (the value set here), `printd` prints nothing.
# Set to True while developing to see debug output.
DEBUG = False
|
|
|
|
|
|
@contextmanager
def suppress_stdout():
    """Temporarily silence everything written to ``sys.stdout``.

    Inside the ``with`` body, ``sys.stdout`` points at a throwaway in-memory
    buffer (used e.g. to hide the 'MockLLM' message). The previous stream is
    put back on exit, including when the body raises.
    """
    saved_stream, sys.stdout = sys.stdout, io.StringIO()
    try:
        yield
    finally:
        # Always restore, so an exception in the body can't leave stdout muted.
        sys.stdout = saved_stream
|
|
|
|
|
|
def open_folder_in_explorer(folder_path):
    """
    Open ``folder_path`` in the operating system's native file manager.

    The launcher is chosen from ``platform.system()``: ``explorer`` on Windows,
    ``open`` on macOS (``"Darwin"``) and ``xdg-open`` on Linux. It is run via
    ``subprocess.run(..., check=True)``, so a non-zero exit status raises
    ``subprocess.CalledProcessError``.

    :param folder_path: Absolute path to the folder to be opened.
    :raises ValueError: if ``folder_path`` does not exist.
    :raises OSError: if the platform is not Windows, macOS or Linux.
    """
    if not os.path.exists(folder_path):
        raise ValueError(f"The specified folder {folder_path} does not exist.")

    launchers = {
        "Windows": "explorer",
        "Darwin": "open",
        "Linux": "xdg-open",  # works for most Linux distributions
    }
    os_name = platform.system()
    launcher = launchers.get(os_name)
    if launcher is None:
        raise OSError(f"Unsupported operating system {os_name}.")

    subprocess.run([launcher, folder_path], check=True)
|
|
|
|
|
|
class OpenAIBackcompatUnpickler(pickle.Unpickler):
    """Unpickler that redirects ``openai.openai_object`` globals to ``OpenAIObject``.

    Any global requested from the ``openai.openai_object`` module, whatever its
    name, resolves to the local ``OpenAIObject`` backcompat class. This
    presumably lets pickles written against an older ``openai`` package load
    without it. Every other global goes through the standard lookup.

    NOTE(review): unpickling can execute arbitrary code. Only feed this
    unpickler data the application wrote itself, never untrusted input.
    """

    def find_class(self, module, name):
        if module != "openai.openai_object":
            return super().find_class(module, name)
        return OpenAIObject
|
|
|
|
|
|
def count_tokens(s: str, model: str = "gpt-4") -> int:
    """Return how many tokens ``s`` encodes to under ``model``'s tiktoken encoding.

    The encoding is looked up by model name via ``tiktoken.encoding_for_model``.
    """
    tokenizer = tiktoken.encoding_for_model(model)
    token_ids = tokenizer.encode(s)
    return len(token_ids)
|
|
|
|
|
|
def printd(*args, **kwargs):
    """Debug-only ``print``: forwards all arguments to ``print`` when ``DEBUG`` is truthy.

    When the module-level ``DEBUG`` flag is falsy, nothing is printed.
    """
    if not DEBUG:
        return
    print(*args, **kwargs)
|
|
|
|
|
|
def united_diff(str1, str2):
    """Return a unified diff (as a single string) turning ``str1`` into ``str2``.

    Both inputs are split into lines with their line endings kept, so the diff
    lines join back together without extra separators. Identical inputs yield "".
    """
    before, after = (text.splitlines(True) for text in (str1, str2))
    return "".join(difflib.unified_diff(before, after))
|
|
|
|
|
|
def parse_formatted_time(formatted_time):
    """Parse a timestamp in the ``"%Y-%m-%d %I:%M:%S %p %Z%z"`` format.

    This is the format produced by ``get_local_time()`` /
    ``get_local_time_timezone()``, e.g. ``"2024-01-15 03:04:05 PM PST-0800"``.
    Returns a timezone-aware ``datetime`` whose tzinfo carries both the UTC
    offset and the zone abbreviation.

    ``datetime.strptime``'s ``%Z`` only accepts ``UTC``, ``GMT`` and the names
    in ``time.tzname`` of the machine running the code. A string such as
    ``"... PST-0800"`` therefore fails on a host that is not on Pacific time.
    In that case the abbreviation is split off, the rest is parsed with
    ``%z`` alone, and the abbreviation is re-attached to the result.

    :raises ValueError: if the string does not match the expected format.
    """
    from datetime import timezone  # the module only imports the ``datetime`` class

    try:
        return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
    except ValueError as err:
        strict_error = err

    # The last space-separated field is "<abbreviation><offset>" glued together,
    # e.g. "PST-0800", or "+04+0400" for zones without a letter abbreviation.
    head, _, zone_field = formatted_time.rpartition(" ")
    zone_match = re.fullmatch(r"(\S*?)([+-]\d{4}(?:\d{2}(?:\.\d{6})?)?)", zone_field)
    if not head or zone_match is None:
        raise strict_error
    abbreviation, offset = zone_match.groups()

    try:
        parsed = datetime.strptime(f"{head} {offset}", "%Y-%m-%d %I:%M:%S %p %z")
    except ValueError:
        raise strict_error from None

    if abbreviation:
        # Mirror strptime's own behaviour when both %Z and %z are present.
        parsed = parsed.replace(tzinfo=timezone(parsed.utcoffset(), abbreviation))
    return parsed
|
|
|
|
|
|
def datetime_to_timestamp(dt):
    """Convert a ``datetime`` into whole POSIX seconds.

    Any fractional second is discarded by ``int()`` (truncation toward zero).
    """
    epoch_seconds = dt.timestamp()
    return int(epoch_seconds)
|
|
|
|
|
|
def timestamp_to_datetime(ts):
    """Convert POSIX seconds ``ts`` into a naive ``datetime`` in the host's local time."""
    local_dt = datetime.fromtimestamp(ts)
    return local_dt
|
|
|
|
|
|
def get_local_time_military(timezone="America/Los_Angeles"):
    """Return the current wall-clock time in ``timezone`` as a 24-hour string.

    The result is formatted with ``"%Y-%m-%d %H:%M:%S %Z%z"``. The zone defaults
    to ``"America/Los_Angeles"``, so zero-argument calls behave as before. The
    parameter name and default match ``get_local_time_timezone``.

    :param timezone: time zone name passed to ``pytz.timezone``.
    """
    # Take "now" in UTC first, then convert, so the host's local zone doesn't matter.
    current_time_utc = datetime.now(pytz.utc)
    local_time = current_time_utc.astimezone(pytz.timezone(timezone))
    return local_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")
|
|
|
|
|
|
def get_local_time_timezone(timezone="America/Los_Angeles"):
    """Return the current wall-clock time in ``timezone`` as a 12-hour string.

    The result is formatted with ``"%Y-%m-%d %I:%M:%S %p %Z%z"``, e.g.
    ``"2024-01-15 03:04:05 PM PST-0800"`` for the default zone.

    :param timezone: time zone name passed to ``pytz.timezone``.
    """
    # Capture "now" as an aware UTC instant, then shift it into the requested zone.
    now_utc = datetime.now(pytz.utc)
    target_zone = pytz.timezone(timezone)
    return now_utc.astimezone(target_zone).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
|
|
|
|
|
def get_local_time(timezone=None):
    """Return the current time as a 12-hour ``"%Y-%m-%d %I:%M:%S %p %Z%z"`` string.

    :param timezone: if given, a zone name handed to ``get_local_time_timezone``;
        if None, the host computer's local time zone is used.
    :returns: the formatted time with surrounding whitespace stripped.
    """
    if timezone is None:
        # astimezone() with no argument attaches the machine's local zone.
        now_local = datetime.now().astimezone()
        stamp = now_local.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
    else:
        stamp = get_local_time_timezone(timezone)
    return stamp.strip()
|
|
|
|
|
|
def parse_json(string):
    """Decode ``string`` as JSON, falling back to the lenient demjson parser.

    The standard ``json`` module is tried first. If it fails for any reason,
    the error is printed and ``demjson.decode`` is attempted instead.

    :raises demjson.JSONDecodeError: if demjson cannot decode the string
        either (its error is printed before being re-raised).
    """
    try:
        return json.loads(string)
    except Exception as e:
        print(f"Error parsing json with json package: {e}")

    try:
        return demjson.decode(string)
    except demjson.JSONDecodeError as e:
        print(f"Error parsing json with demjson package: {e}")
        raise
|
|
|
|
|
|
def validate_function_response(function_response_string: object, strict: bool = False, truncate: bool = True) -> str:
    """Check that a function used by MemGPT returned a valid response, and coerce it to a string.

    Responses should be strings (or None) under a character limit. Other
    values are soft-converted:

    * ``None`` becomes the string ``"None"`` (kept for backwards compatibility).
    * ``dict`` values are serialized with ``json.dumps``.
    * anything else is converted with ``str()``.

    :param function_response_string: the raw value returned by the function.
    :param strict: if True, reject non-string, non-None values (dicts included)
        with ``ValueError`` instead of converting them.
    :param truncate: if True, responses longer than ``FUNCTION_RETURN_CHAR_LIMIT``
        characters are cut to that length, a note is appended, and a warning
        is printed.
    :returns: the response as a string.
    :raises ValueError: in strict mode for non-string, non-None input, or when
        the value cannot be serialized/converted. The underlying error is
        chained as ``__cause__``.
    """
    if not isinstance(function_response_string, str):
        # Soft correction for a few basic types

        if function_response_string is None:
            function_response_string = "None"  # backcompat

        elif isinstance(function_response_string, dict):
            if strict:
                # TODO add better error message
                raise ValueError(function_response_string)

            # Allow dict through since it will be cast to json.dumps()
            try:
                # TODO find a better way to do this that won't result in double escapes
                function_response_string = json.dumps(function_response_string)
            except Exception as err:
                # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
                # propagate, and chain the original error for debugging.
                raise ValueError(function_response_string) from err

        else:
            if strict:
                # TODO add better error message
                raise ValueError(function_response_string)

            # Fall back to str() for any other type
            try:
                function_response_string = str(function_response_string)
            except Exception as err:
                raise ValueError(function_response_string) from err

    # Now check the length and make sure it doesn't go over the limit
    # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
    if truncate and len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT:
        print(
            f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated"
        )
        function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]"

    return function_response_string
|
|
|
|
|
|
def list_agent_config_files(sort="last_modified"):
    """List agent directory names under ``MEMGPT_DIR/agents``, ignoring dotfiles.

    Entries whose name starts with ``.`` (e.g. ``.DS_Store``) and entries that
    are not directories are skipped.

    :param sort: ``"last_modified"`` (default) orders by modification time,
        most recent first; ``None`` keeps the ``os.listdir`` order.
    :raises ValueError: for any other ``sort`` value.
    """
    agents_root = os.path.join(MEMGPT_DIR, "agents")

    def _full_path(entry):
        return os.path.join(agents_root, entry)

    agent_names = [
        entry
        for entry in os.listdir(agents_root)
        if not entry.startswith(".") and os.path.isdir(_full_path(entry))
    ]

    if sort is None:
        return agent_names
    if sort != "last_modified":
        raise ValueError(f"Unrecognized sorting option {sort}")

    agent_names.sort(key=lambda entry: os.path.getmtime(_full_path(entry)), reverse=True)
    return agent_names
|
|
|
|
|
|
def list_human_files():
    """Return full paths of all available human profile files.

    Bundled examples (only ``*.txt`` files in the package's ``humans/examples``
    directory) come first, followed by every entry of ``MEMGPT_DIR/humans``
    (not filtered by extension).
    """
    bundled_dir = os.path.join(memgpt.__path__[0], "humans", "examples")
    custom_dir = os.path.join(MEMGPT_DIR, "humans")

    bundled = [os.path.join(bundled_dir, entry) for entry in os.listdir(bundled_dir) if entry.endswith(".txt")]
    custom = [os.path.join(custom_dir, entry) for entry in os.listdir(custom_dir)]
    return bundled + custom
|
|
|
|
|
|
def list_persona_files():
    """Return full paths of all available persona files.

    Bundled examples (only ``*.txt`` files in the package's ``personas/examples``
    directory) come first, followed by every entry of ``MEMGPT_DIR/personas``
    (not filtered by extension).
    """
    bundled_dir = os.path.join(memgpt.__path__[0], "personas", "examples")
    custom_dir = os.path.join(MEMGPT_DIR, "personas")

    bundled = [os.path.join(bundled_dir, entry) for entry in os.listdir(bundled_dir) if entry.endswith(".txt")]
    custom = [os.path.join(custom_dir, entry) for entry in os.listdir(custom_dir)]
    return bundled + custom
|
|
|
|
|
|
def get_human_text(name: str, enforce_limit=True):
    """Return the stripped text of the human profile called ``name``.

    A file matches when its basename equals ``name`` or ``f"{name}.txt"``.
    The first match in ``list_human_files()`` order is used.

    NOTE(review): ``get_human_text`` is defined again further down in this
    module. That later definition is the one bound at import time, so this one
    is not reachable by name; one of the two should be removed.

    :param enforce_limit: if True, reject profiles longer than
        ``CORE_MEMORY_HUMAN_CHAR_LIMIT`` characters.
    :raises ValueError: if no file matches, or the text exceeds the limit.
    """
    for file_path in list_human_files():
        file = os.path.basename(file_path)
        if f"{name}.txt" == file or name == file:
            # Context manager closes the handle deterministically (previously leaked).
            with open(file_path, "r") as f:
                human_text = f.read().strip()
            if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
                raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
            return human_text

    raise ValueError(f"Human {name}.txt not found")
|
|
|
|
|
|
def get_persona_text(name: str, enforce_limit=True):
    """Return the stripped text of the persona called ``name``.

    A file matches when its basename equals ``name`` or ``f"{name}.txt"``.
    The first match in ``list_persona_files()`` order is used.

    :param enforce_limit: if True, reject personas longer than
        ``CORE_MEMORY_PERSONA_CHAR_LIMIT`` characters.
    :raises ValueError: if no file matches, or the text exceeds the limit.
    """
    for file_path in list_persona_files():
        file = os.path.basename(file_path)
        if f"{name}.txt" == file or name == file:
            # Context manager closes the handle deterministically (previously leaked).
            with open(file_path, "r") as f:
                persona_text = f.read().strip()
            if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
                raise ValueError(
                    f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
                )
            return persona_text

    raise ValueError(f"Persona {name}.txt not found")
|
|
|
|
|
|
def get_human_text(name: str, enforce_limit=True):
    """Return the stripped text of the human profile called ``name``.

    NOTE(review): this module defines ``get_human_text`` twice, and this later
    definition is the one bound at import time. It implements the same
    contract as the earlier definition and as ``get_persona_text``: an
    ``enforce_limit`` keyword, a character-limit check, and ``ValueError``
    when nothing matches. The redundant definition should be deleted.

    A file matches when its basename equals ``name`` or ``f"{name}.txt"``.
    The first match in ``list_human_files()`` order is used.

    :param enforce_limit: if True, reject profiles longer than
        ``CORE_MEMORY_HUMAN_CHAR_LIMIT`` characters.
    :raises ValueError: if no file matches, or the text exceeds the limit.
    """
    for file_path in list_human_files():
        file = os.path.basename(file_path)
        if f"{name}.txt" == file or name == file:
            with open(file_path, "r") as f:
                human_text = f.read().strip()
            if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
                raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
            return human_text

    raise ValueError(f"Human {name}.txt not found")
|
|
|
|
|
|
def get_schema_diff(schema_a, schema_b):
    """Return only the changed lines between two JSON-serializable schemas.

    Both schemas are pretty-printed with ``json.dumps(..., indent=2)`` and
    compared with ``difflib.ndiff``. Only removed (``"- "``) and added
    (``"+ "``) lines are kept; unchanged lines and ``"? "`` hint lines are
    dropped. Identical schemas yield "".
    """
    lines_a, lines_b = (
        json.dumps(schema, indent=2).splitlines(keepends=True) for schema in (schema_a, schema_b)
    )
    changed = (line for line in difflib.ndiff(lines_a, lines_b) if line.startswith(("+ ", "- ")))
    return "".join(changed)
|
|
|
|
|
|
# datetime related
def validate_date_format(date_str):
    """Return True if ``date_str`` parses as a calendar date with format ``%Y-%m-%d``.

    Returns False for malformed or impossible dates (e.g. ``"2023-02-29"``)
    and for non-string input.
    """
    try:
        # ``datetime`` here is the class (``from datetime import datetime``), not the
        # module, so ``datetime.datetime.strptime`` would raise AttributeError.
        datetime.strptime(date_str, "%Y-%m-%d")
    except (ValueError, TypeError):
        return False
    return True
|
|
|
|
|
|
def extract_date_from_timestamp(timestamp):
    """Return the leading ``YYYY-MM-DD`` portion of ``timestamp``, or None.

    Only the start of the string is examined; any time or zone text after the
    date is ignored.
    """
    found = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
    if found is None:
        return None
    return found.group(1)
|