import asyncio
import copy
import difflib
import hashlib
import inspect
import io
import os
import pickle
import platform
import random
import re
import subprocess
import sys
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from functools import wraps
from logging import Logger
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
from urllib.parse import urljoin, urlparse

import demjson3 as demjson
import tiktoken
from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename

import letta
from letta.constants import (
    CLI_WARNING_PREFIX,
    CORE_MEMORY_HUMAN_CHAR_LIMIT,
    CORE_MEMORY_PERSONA_CHAR_LIMIT,
    ERROR_MESSAGE_PREFIX,
    LETTA_DIR,
    MAX_FILENAME_LENGTH,
    TOOL_CALL_ID_MAX_LEN,
)
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse

DEBUG = False
if "LOG_LEVEL" in os.environ:
    if os.environ["LOG_LEVEL"] == "DEBUG":
        DEBUG = True

ADJECTIVE_BANK = [
    "beautiful", "gentle", "angry", "vivacious", "grumpy", "luxurious", "fierce", "delicate", "fluffy", "radiant",
    "elated", "magnificent", "sassy", "ecstatic", "lustrous", "gleaming", "sorrowful", "majestic", "proud", "dynamic",
    "energetic", "mysterious", "loyal", "brave", "decisive", "frosty", "cheerful", "adorable", "melancholy", "vibrant",
    "elegant", "gracious", "inquisitive", "opulent", "peaceful", "rebellious", "scintillating", "dazzling", "whimsical", "impeccable",
    "meticulous", "resilient", "charming", "vivacious", "creative", "intuitive", "compassionate", "innovative", "enthusiastic", "tremendous",
    "effervescent", "tenacious", "fearless", "sophisticated", "witty", "optimistic", "exquisite", "sincere", "generous", "kindhearted",
    "serene", "amiable", "adventurous", "bountiful", "courageous", "diligent", "exotic", "grateful", "harmonious", "imaginative",
    "jubilant", "keen", "luminous", "nurturing", "outgoing", "passionate", "quaint", "resourceful", "sturdy", "tactful",
    "unassuming", "versatile", "wondrous", "youthful", "zealous", "ardent", "benevolent", "capricious", "dedicated", "empathetic",
    "fabulous", "gregarious", "humble", "intriguing", "jovial", "kind", "lovable", "mindful", "noble", "original",
    "pleasant", "quixotic", "reliable", "spirited", "tranquil", "unique", "venerable", "warmhearted", "xenodochial", "yearning",
    "zesty", "amusing", "blissful", "calm", "daring", "enthusiastic", "faithful", "graceful", "honest", "incredible",
    "joyful", "kind", "lovely", "merry", "noble", "optimistic", "peaceful", "quirky", "respectful", "sweet",
    "trustworthy", "understanding", "vibrant", "witty", "xenial", "youthful", "zealous", "ambitious", "brilliant", "careful",
    "devoted", "energetic", "friendly", "glorious", "humorous", "intelligent", "jovial", "knowledgeable", "loyal", "modest",
    "nice", "obedient", "patient", "quiet", "resilient", "selfless", "tolerant", "unique", "versatile", "warm",
    "xerothermic", "yielding", "zestful", "amazing", "bold", "charming", "determined", "exciting", "funny", "happy",
    "imaginative", "jolly", "keen", "loving", "magnificent", "nifty", "outstanding", "polite", "quick", "reliable",
    "sincere", "thoughtful", "unusual", "valuable", "wonderful", "xenodochial", "zealful", "admirable", "bright", "clever",
    "dedicated", "extraordinary", "generous", "hardworking", "inspiring", "jubilant", "kindhearted", "lively", "miraculous", "neat",
    "openminded", "passionate", "remarkable", "stunning", "truthful", "upbeat", "vivacious", "welcoming", "yare", "zealous",
]

NOUN_BANK = [
    "lizard", "firefighter", "banana", "castle", "dolphin", "elephant", "forest", "giraffe", "harbor", "iceberg",
    "jewelry", "kangaroo", "library", "mountain", "notebook", "orchard", "penguin", "quilt", "rainbow", "squirrel",
    "teapot", "umbrella", "volcano", "waterfall", "xylophone", "yacht", "zebra", "apple", "butterfly", "caterpillar",
    "dragonfly", "elephant", "flamingo", "gorilla", "hippopotamus", "iguana", "jellyfish", "koala", "lemur", "mongoose",
    "nighthawk", "octopus", "panda", "quokka", "rhinoceros", "salamander", "tortoise", "unicorn", "vulture", "walrus",
    "xenopus", "yak", "zebu", "asteroid", "balloon", "compass", "dinosaur", "eagle", "firefly", "galaxy",
    "hedgehog", "island", "jaguar", "kettle", "lion", "mammoth", "nucleus", "owl", "pumpkin", "quasar",
    "reindeer", "snail", "tiger", "universe", "vampire", "wombat", "xerus", "yellowhammer", "zeppelin", "alligator",
    "buffalo", "cactus", "donkey", "emerald", "falcon", "gazelle", "hamster", "icicle", "jackal", "kitten",
    "leopard", "mushroom", "narwhal", "opossum", "peacock", "quail", "rabbit", "scorpion", "toucan", "urchin",
    "viper", "wolf", "xray", "yucca", "zebu", "acorn", "biscuit", "cupcake", "daisy", "eyeglasses",
    "frisbee", "goblin", "hamburger", "icicle", "jackfruit", "kaleidoscope", "lighthouse", "marshmallow", "nectarine", "obelisk",
    "pancake", "quicksand", "raspberry", "spinach", "truffle", "umbrella", "volleyball", "walnut", "xylophonist", "yogurt",
    "zucchini", "asterisk", "blackberry", "chimpanzee", "dumpling", "espresso", "fireplace", "gnome", "hedgehog", "illustration",
    "jackhammer", "kumquat", "lemongrass", "mandolin", "nugget", "ostrich", "parakeet", "quiche", "racquet", "seashell",
    "tadpole", "unicorn", "vaccination", "wolverine", "xenophobia", "yam", "zeppelin", "accordion", "broccoli", "carousel",
    "daffodil", "eggplant", "flamingo", "grapefruit", "harpsichord", "impression", "jackrabbit", "kitten", "llama", "mandarin",
    "nachos", "obelisk", "papaya", "quokka", "rooster", "sunflower", "turnip", "ukulele", "viper", "waffle",
    "xylograph", "yeti", "zephyr", "abacus", "blueberry", "crocodile", "dandelion", "echidna", "fig", "giraffe",
    "hamster", "iguana", "jackal", "kiwi", "lobster", "marmot", "noodle", "octopus", "platypus", "quail",
    "raccoon", "starfish", "tulip", "urchin", "vampire", "walrus", "xylophone", "yak", "zebra",
]

def deduplicate(target_list: list) -> list:
    seen = set()
    dedup_list = []
    for i in target_list:
        if i not in seen:
            seen.add(i)
            dedup_list.append(i)

    return dedup_list


def smart_urljoin(base_url: str, relative_url: str) -> str:
    """urljoin drops the last path segment of the base URL unless it ends with a trailing '/', so ensure one is present before joining"""
    if not base_url.endswith("/"):
        base_url += "/"
    return urljoin(base_url, relative_url)

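# Illustrative sketch (made-up URL, not a real endpoint): plain urljoin drops the
# last path segment of a base URL without a trailing slash, smart_urljoin keeps it:
#   urljoin("http://localhost:8080/v1", "chat/completions")        # -> 'http://localhost:8080/chat/completions'
#   smart_urljoin("http://localhost:8080/v1", "chat/completions")  # -> 'http://localhost:8080/v1/chat/completions'
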
def get_tool_call_id() -> str:
    # TODO(sarah) make this a slug-style string?
    # e.g. OpenAI: "call_xlIfzR1HqAW7xJPa3ExJSg3C"
    # or similar to agents: "call-xlIfzR1HqAW7xJPa3ExJSg3C"
    return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]


def assistant_function_to_tool(assistant_message: dict) -> dict:
    assert "function_call" in assistant_message
    new_msg = copy.deepcopy(assistant_message)
    function_call = new_msg.pop("function_call")
    new_msg["tool_calls"] = [
        {
            "id": get_tool_call_id(),
            "type": "function",
            "function": function_call,
        }
    ]
    return new_msg


def is_optional_type(hint):
    """Check if the type hint is an Optional type."""
    if isinstance(hint, _GenericAlias):
        return hint.__origin__ is Union and type(None) in hint.__args__
    return False

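# Illustrative sketch (typing.Optional[X] is Union[X, None] under the hood):
#   is_optional_type(Optional[int])  # -> True, assuming `from typing import Optional`
#   is_optional_type(int)            # -> False
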
def enforce_types(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Get type hints, excluding the return type hint
        hints = {k: v for k, v in get_type_hints(func).items() if k != "return"}

        # Get the function's argument names
        arg_names = inspect.getfullargspec(func).args

        # Pair each argument with its corresponding type hint
        args_with_hints = dict(zip(arg_names[1:], args[1:]))  # Skipping 'self'

        # Function to check if a value matches a given type hint
        def matches_type(value, hint):
            origin = get_origin(hint)
            args = get_args(hint)

            if origin is Union:  # Handle Union types (including Optional)
                return any(matches_type(value, arg) for arg in args)
            elif origin is list and isinstance(value, list):  # Handle List[T]
                element_type = args[0] if args else None
                return all(isinstance(v, element_type) for v in value) if element_type else True
            elif origin:  # Handle other generics like Dict, Tuple, etc.
                return isinstance(value, origin)
            else:  # Handle non-generic types
                return isinstance(value, hint)

        # Check types of arguments
        for arg_name, arg_value in args_with_hints.items():
            hint = hints.get(arg_name)
            if hint and not matches_type(arg_value, hint):
                raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}")

        # Check types of keyword arguments
        for arg_name, arg_value in kwargs.items():
            hint = hints.get(arg_name)
            if hint and not matches_type(arg_value, hint):
                raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}")

        return func(*args, **kwargs)

    return wrapper

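# Minimal usage sketch on a hypothetical class (the decorator assumes it wraps a
# method, since it skips the first positional argument as 'self'):
#   class Example:
#       @enforce_types
#       def greet(self, name: str, times: int = 1) -> str:
#           return " ".join([f"hello {name}"] * times)
#
#   Example().greet("world", times=2)  # ok
#   Example().greet(123)               # raises ValueError: Argument name does not match type <class 'str'>
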
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
    """Add in missing tool_call_id fields to a list of messages using function call style

    Walk through the list forwards:
    - If we encounter an assistant message that calls a function ("function_call") but doesn't have a "tool_call_id" field
        - Generate the tool_call_id
        - Then check if the subsequent message is a role == "function" message
            - If so, attach the same tool_call_id to that response message so the pair stays linked
    """
    tool_call_index = None
    tool_call_id = None
    updated_messages = []

    for i, message in enumerate(messages):
        if "role" not in message:
            raise ValueError(f"message missing 'role' field:\n{message}")

        # If we find a function call w/o a tool call ID annotation, annotate it
        if message["role"] == "assistant" and "function_call" in message:
            if "tool_call_id" in message and message["tool_call_id"] is not None:
                printd(f"Message already has tool_call_id")
                tool_call_id = message["tool_call_id"]
            else:
                tool_call_id = str(uuid.uuid4())
                message["tool_call_id"] = tool_call_id
            tool_call_index = i

        # After annotating the call, we expect to find a follow-up response (also unannotated)
        elif message["role"] == "function":
            # We should have a new tool call id in the buffer
            if tool_call_id is None:
                # raise ValueError(
                print(
                    f"Got a function call role, but did not have a saved tool_call_id ready to use (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
                # allow a soft fail in this case
                message["tool_call_id"] = str(uuid.uuid4())
            elif "tool_call_id" in message:
                raise ValueError(
                    f"Got a function call role, but it already had a saved tool_call_id (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
            elif i != tool_call_index + 1:
                raise ValueError(
                    f"Got a function call role, saved tool_call_id came earlier than i-1 (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
            else:
                message["tool_call_id"] = tool_call_id
                tool_call_id = None  # wipe the buffer

        elif message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
            if not allow_tool_roles:
                raise NotImplementedError(
                    f"tool_call_id annotation is meant for deprecated functions style, but got role 'assistant' with 'tool_calls' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )

            if len(message["tool_calls"]) != 1:
                raise NotImplementedError(
                    f"Got unexpected format for tool_calls inside assistant message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )

            assistant_tool_call = message["tool_calls"][0]
            if "id" in assistant_tool_call and assistant_tool_call["id"] is not None:
                printd(f"Message already has id (tool_call_id)")
                tool_call_id = assistant_tool_call["id"]
            else:
                tool_call_id = str(uuid.uuid4())
                message["tool_calls"][0]["id"] = tool_call_id
                # also just put it at the top level for ease-of-access
                # message["tool_call_id"] = tool_call_id
            tool_call_index = i

        elif message["role"] == "tool":
            if not allow_tool_roles:
                raise NotImplementedError(
                    f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )

            # if "tool_call_id" not in message or message["tool_call_id"] is None:
            #     raise ValueError(f"Got a tool call role, but there's no tool_call_id:\n{messages[:i]}\n{message}")

            # We should have a new tool call id in the buffer
            if tool_call_id is None:
                # raise ValueError(
                print(
                    f"Got a tool call role, but did not have a saved tool_call_id ready to use (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
                # allow a soft fail in this case
                message["tool_call_id"] = str(uuid.uuid4())
            elif "tool_call_id" in message and message["tool_call_id"] is not None:
                if tool_call_id is not None and tool_call_id != message["tool_call_id"]:
                    # just wipe it
                    # raise ValueError(
                    #     f"Got a tool call role, but it already had a saved tool_call_id (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                    # )
                    message["tool_call_id"] = tool_call_id
                    tool_call_id = None  # wipe the buffer
                else:
                    tool_call_id = None
            elif i != tool_call_index + 1:
                raise ValueError(
                    f"Got a tool call role, saved tool_call_id came earlier than i-1 (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
                )
            else:
                message["tool_call_id"] = tool_call_id
                tool_call_id = None  # wipe the buffer

        else:
            # eg role == 'user', nothing to do here
            pass

        updated_messages.append(copy.deepcopy(message))

    return updated_messages

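# Illustrative before/after sketch (message contents are made up). Given legacy
# "functions"-style history like:
#   [{"role": "assistant", "function_call": {"name": "send_message", "arguments": "{}"}},
#    {"role": "function", "name": "send_message", "content": "None"}]
# the assistant message is assigned a generated "tool_call_id", and the
# immediately following role == "function" message receives the same id so the
# call/response pair stays linked.
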
def version_less_than(version_a: str, version_b: str) -> bool:
    """Compare versions to check if version_a is less than version_b."""
    # Regular expression to match version strings of the format int.int.int
    version_pattern = re.compile(r"^\d+\.\d+\.\d+$")

    # Assert that version strings match the required format
    if not version_pattern.match(version_a) or not version_pattern.match(version_b):
        raise ValueError("Version strings must be in the format 'int.int.int'")

    # Split the version strings into parts
    parts_a = [int(part) for part in version_a.split(".")]
    parts_b = [int(part) for part in version_b.split(".")]

    # Compare version parts
    return parts_a < parts_b

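# Illustrative sketch: components are compared numerically, not lexicographically:
#   version_less_than("0.9.10", "0.10.0")  # -> True
#   version_less_than("1.0", "1.0.0")      # raises ValueError (must be 'int.int.int')
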
def create_random_username() -> str:
    """Generate a random username by combining an adjective and a noun."""
    adjective = random.choice(ADJECTIVE_BANK).capitalize()
    noun = random.choice(NOUN_BANK).capitalize()
    return adjective + noun


def verify_first_message_correctness(
    response: ChatCompletionResponse, require_send_message: bool = True, require_monologue: bool = False
) -> bool:
    """Can be used to enforce that the first message always uses send_message"""
    response_message = response.choices[0].message

    # First message should be a call to send_message with a non-empty content
    if (hasattr(response_message, "function_call") and response_message.function_call is not None) and (
        hasattr(response_message, "tool_calls") and response_message.tool_calls is not None
    ):
        printd(f"First message includes both function call AND tool call: {response_message}")
        return False
    elif hasattr(response_message, "function_call") and response_message.function_call is not None:
        function_call = response_message.function_call
    elif hasattr(response_message, "tool_calls") and response_message.tool_calls is not None:
        function_call = response_message.tool_calls[0].function
    else:
        printd(f"First message didn't include function call: {response_message}")
        return False

    function_name = function_call.name if function_call is not None else ""
    if require_send_message and function_name != "send_message" and function_name != "archival_memory_search":
        printd(f"First message function call wasn't send_message or archival_memory_search: {response_message}")
        return False

    if require_monologue and (not response_message.content or response_message.content is None or response_message.content == ""):
        printd(f"First message missing internal monologue: {response_message}")
        return False

    if response_message.content:
        ### Extras
        monologue = response_message.content

        def contains_special_characters(s):
            special_characters = '(){}[]"'
            return any(char in s for char in special_characters)

        if contains_special_characters(monologue):
            printd(f"First message internal monologue contained special characters: {response_message}")
            return False
        # if 'functions' in monologue or 'send_message' in monologue or 'inner thought' in monologue.lower():
        if "functions" in monologue or "send_message" in monologue:
            # Sometimes the syntax won't be correct and internal syntax will leak into message.context
            printd(f"First message internal monologue contained reserved words: {response_message}")
            return False

    return True

def is_valid_url(url):
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc])
    except ValueError:
        return False


@contextmanager
def suppress_stdout():
    """Used to temporarily stop stdout (eg for the 'MockLLM' message)"""
    new_stdout = io.StringIO()
    old_stdout = sys.stdout
    sys.stdout = new_stdout
    try:
        yield
    finally:
        sys.stdout = old_stdout


def open_folder_in_explorer(folder_path):
    """
    Opens the specified folder in the system's native file explorer.

    :param folder_path: Absolute path to the folder to be opened.
    """
    if not os.path.exists(folder_path):
        raise ValueError(f"The specified folder {folder_path} does not exist.")

    # Determine the operating system
    os_name = platform.system()

    # Open the folder based on the operating system
    if os_name == "Windows":
        # Windows: use 'explorer' command
        subprocess.run(["explorer", folder_path], check=True)
    elif os_name == "Darwin":
        # macOS: use 'open' command
        subprocess.run(["open", folder_path], check=True)
    elif os_name == "Linux":
        # Linux: use 'xdg-open' command (works for most Linux distributions)
        subprocess.run(["xdg-open", folder_path], check=True)
    else:
        raise OSError(f"Unsupported operating system {os_name}.")


# Custom unpickler
class OpenAIBackcompatUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == "openai.openai_object":
            from letta.openai_backcompat.openai_object import OpenAIObject

            return OpenAIObject
        return super().find_class(module, name)

def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))

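# Illustrative sketch; the exact number depends on the encoding tiktoken maps the
# model name to, so treat it as approximate:
#   count_tokens("hello world")  # -> 2 with the cl100k_base encoding used for gpt-4
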
def printd(*args, **kwargs):
    if DEBUG:
        print(*args, **kwargs)


def united_diff(str1: str, str2: str) -> str:
    lines1 = str1.splitlines(True)
    lines2 = str2.splitlines(True)
    diff = difflib.unified_diff(lines1, lines2)
    return "".join(diff)

def parse_json(string) -> dict:
    """Parse JSON string into JSON with both json and demjson"""
    result = None
    try:
        result = json_loads(string)
        if not isinstance(result, dict):
            raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
        return result
    except Exception as e:
        print(f"Error parsing json with json package, falling back to demjson: {e}")

    try:
        result = demjson.decode(string)
        if not isinstance(result, dict):
            raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
        return result
    except demjson.JSONDecodeError as e:
        print(f"Error parsing json with demjson package (fatal): {e}")
        raise e

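# Illustrative sketch: strict JSON is handled by json_loads; demjson is the
# fallback for "almost JSON" that models sometimes emit (the second call below
# is a hypothetical example of such input):
#   parse_json('{"message": "hi"}')   # -> {'message': 'hi'}
#   parse_json("{'message': 'hi',}")  # strict parse fails; demjson should still recover a dict
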
def validate_function_response(function_response_string: Any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str:
    """Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary.

    Responses need to be strings (or None) that fall under a certain text count limit.
    """
    if not isinstance(function_response_string, str):
        # Soft correction for a few basic types

        if function_response_string is None:
            # function_response_string = "Empty (no function output)"
            function_response_string = "None"  # backcompat

        elif isinstance(function_response_string, dict):
            if strict:
                # TODO add better error message
                raise ValueError(function_response_string)

            # Allow dict through since it will be cast to json.dumps()
            try:
                # TODO find a better way to do this that won't result in double escapes
                function_response_string = json_dumps(function_response_string)
            except:
                raise ValueError(function_response_string)

        else:
            if strict:
                # TODO add better error message
                raise ValueError(function_response_string)

            # Try to convert to a string, but throw a warning to alert the user
            try:
                function_response_string = str(function_response_string)
            except:
                raise ValueError(function_response_string)

    # Now check the length and make sure it doesn't go over the limit
    # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
    if truncate and len(function_response_string) > return_char_limit:
        print(
            f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated"
        )
        function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]"

    return function_response_string

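# Illustrative sketch of the soft corrections above (limits are made up):
#   validate_function_response("ok", return_char_limit=100)      # -> "ok"
#   validate_function_response(None, return_char_limit=100)      # -> "None"
#   validate_function_response({"a": 1}, return_char_limit=100)  # -> the dict serialized via json_dumps
#   validate_function_response("x" * 500, return_char_limit=100) # truncated, with a "[NOTE: ...]" suffix appended
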
def list_agent_config_files(sort="last_modified"):
    """List all agent config files, ignoring dotfiles."""
    agent_dir = os.path.join(LETTA_DIR, "agents")
    files = os.listdir(agent_dir)

    # Remove dotfiles like .DS_Store
    files = [file for file in files if not file.startswith(".")]

    # Remove anything that's not a directory
    files = [file for file in files if os.path.isdir(os.path.join(agent_dir, file))]

    if sort is not None:
        if sort == "last_modified":
            # Sort the directories by last modified (most recent first)
            files.sort(key=lambda x: os.path.getmtime(os.path.join(agent_dir, x)), reverse=True)
        else:
            raise ValueError(f"Unrecognized sorting option {sort}")

    return files


def list_human_files():
    """List all human files"""
    defaults_dir = os.path.join(letta.__path__[0], "humans", "examples")
    user_dir = os.path.join(LETTA_DIR, "humans")

    letta_defaults = os.listdir(defaults_dir)
    letta_defaults = [os.path.join(defaults_dir, f) for f in letta_defaults if f.endswith(".txt")]

    if os.path.exists(user_dir):
        user_added = os.listdir(user_dir)
        user_added = [os.path.join(user_dir, f) for f in user_added]
    else:
        user_added = []
    return letta_defaults + user_added


def list_persona_files():
    """List all persona files"""
    defaults_dir = os.path.join(letta.__path__[0], "personas", "examples")
    user_dir = os.path.join(LETTA_DIR, "personas")

    letta_defaults = os.listdir(defaults_dir)
    letta_defaults = [os.path.join(defaults_dir, f) for f in letta_defaults if f.endswith(".txt")]

    if os.path.exists(user_dir):
        user_added = os.listdir(user_dir)
        user_added = [os.path.join(user_dir, f) for f in user_added]
    else:
        user_added = []
    return letta_defaults + user_added

def get_human_text(name: str, enforce_limit=True):
    for file_path in list_human_files():
        file = os.path.basename(file_path)
        if f"{name}.txt" == file or name == file:
            human_text = open(file_path, "r", encoding="utf-8").read().strip()
            if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
                raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
            return human_text

    raise ValueError(f"Human {name}.txt not found")


def get_persona_text(name: str, enforce_limit=True):
    for file_path in list_persona_files():
        file = os.path.basename(file_path)
        if f"{name}.txt" == file or name == file:
            persona_text = open(file_path, "r", encoding="utf-8").read().strip()
            if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
                raise ValueError(
                    f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
                )
            return persona_text

    raise ValueError(f"Persona {name}.txt not found")

def get_schema_diff(schema_a, schema_b):
    # Assuming f_schema and linked_function['json_schema'] are your JSON schemas
    f_schema_json = json_dumps(schema_a)
    linked_function_json = json_dumps(schema_b)

    # Compute the difference using difflib
    difference = list(difflib.ndiff(f_schema_json.splitlines(keepends=True), linked_function_json.splitlines(keepends=True)))

    # Filter out lines that don't represent changes
    difference = [line for line in difference if line.startswith("+ ") or line.startswith("- ")]

    return "".join(difference)

def create_uuid_from_string(val: str):
    """
    Generate consistent UUID from a string
    from: https://samos-it.com/posts/python-create-uuid-from-random-string-of-words.html
    """
    hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest()
    return uuid.UUID(hex=hex_string)

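# Illustrative sketch: the MD5 digest makes this deterministic, so the same
# string always yields the same UUID (unlike uuid.uuid4()):
#   create_uuid_from_string("my-agent") == create_uuid_from_string("my-agent")  # True
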
def sanitize_filename(filename: str) -> str:
    """
    Sanitize the given filename to prevent directory traversal, invalid characters,
    and reserved names while ensuring it fits within the maximum length allowed by the filesystem.

    Parameters:
        filename (str): The user-provided filename.

    Returns:
        str: A sanitized filename that is unique and safe for use.
    """
    # Extract the base filename to avoid directory components
    filename = os.path.basename(filename)

    # Split the base and extension
    base, ext = os.path.splitext(filename)

    # External sanitization library
    base = pathvalidate_sanitize_filename(base)

    # Cannot start with a period
    if base.startswith("."):
        raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'")

    # Truncate the base name to fit within the maximum allowed length
    max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33  # 32 for UUID + 1 for `_`
    if len(base) > max_base_length:
        base = base[:max_base_length]

    # Append a unique UUID suffix for uniqueness
    unique_suffix = uuid.uuid4().hex
    sanitized_filename = f"{base}_{unique_suffix}{ext}"

    # Return the sanitized filename
    return sanitized_filename

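# Illustrative sketch (the hex suffix is a fresh uuid4 on every call, shown here
# as a placeholder):
#   sanitize_filename("../../etc/notes.txt")  # -> "notes_<32-hex-chars>.txt"
#   sanitize_filename(".hidden.txt")          # raises ValueError (base name starts with '.')
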
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
    from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

    error_msg = f"{ERROR_MESSAGE_PREFIX} executing function {function_name}: {exception_name}: {exception_message}"
    if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
        error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
    return error_msg

def run_async_task(coro: Coroutine[Any, Any, Any]) -> Any:
    """
    Safely runs an asynchronous coroutine in a synchronous context.

    If an event loop is already running, it uses `asyncio.ensure_future`.
    Otherwise, it creates a new event loop and runs the coroutine.

    Args:
        coro: The coroutine to execute.

    Returns:
        The result of the coroutine.
    """
    try:
        # If there's already a running event loop, schedule the coroutine
        loop = asyncio.get_running_loop()
        return loop.run_until_complete(coro) if loop.is_closed() else asyncio.ensure_future(coro)
    except RuntimeError:
        # If no event loop is running, create a new one
        return asyncio.run(coro)

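# Illustrative sketch (hypothetical coroutine, for illustration only): called from
# plain synchronous code this blocks and returns the coroutine's result; called
# while an event loop is already running it returns a scheduled Task instead.
#   async def _answer():
#       return 42
#   run_async_task(_answer())  # -> 42 when invoked from sync code
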
def log_telemetry(logger: Logger, event: str, **kwargs):
    """
    Logs telemetry events with a timestamp.

    :param logger: A logger
    :param event: A string describing the event.
    :param kwargs: Additional key-value pairs for logging metadata.
    """
    from letta.settings import settings

    if settings.verbose_telemetry_logging:
        timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S,%f UTC")  # More readable timestamp
        extra_data = " | ".join(f"{key}={value}" for key, value in kwargs.items() if value is not None)
        logger.info(f"[{timestamp}] EVENT: {event} | {extra_data}")

def make_key(*args, **kwargs):
    return str((args, tuple(sorted(kwargs.items()))))

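# Illustrative sketch: kwargs are sorted, so keyword order does not change the
# resulting key string (handy when it is used as a cache key):
#   make_key(1, b=2, a=3) == make_key(1, a=3, b=2)  # True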