fix: Add tests for migration code and bugfix for pulling LLM/embedding configs from user instead of config (#878)

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders 2024-01-20 22:30:08 -08:00 committed by GitHub
parent e07168663a
commit 75ea61161b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 200 additions and 59 deletions

View File

@ -123,9 +123,13 @@ class Message(Record):
openai_message_dict: dict,
model: Optional[str] = None, # model used to make function call
allow_functions_style: bool = False, # allow deprecated functions style?
created_at: Optional[datetime] = None,
):
"""Convert a ChatCompletion message object into a Message object (synced to DB)"""
assert "role" in openai_message_dict, openai_message_dict
assert "content" in openai_message_dict, openai_message_dict
# If we're going from deprecated function form
if openai_message_dict["role"] == "function":
if not allow_functions_style:
@ -135,6 +139,7 @@ class Message(Record):
# Convert from 'function' response to a 'tool' response
# NOTE: this does not conventionally include a tool_call_id, it's on the caster to provide it
return Message(
created_at=created_at,
user_id=user_id,
agent_id=agent_id,
model=model,
@ -166,6 +171,7 @@ class Message(Record):
]
return Message(
created_at=created_at,
user_id=user_id,
agent_id=agent_id,
model=model,
@ -197,6 +203,7 @@ class Message(Record):
# If we're going from tool-call style
return Message(
created_at=created_at,
user_id=user_id,
agent_id=agent_id,
model=model,

View File

@ -1,5 +1,5 @@
import configparser
import datetime
from datetime import datetime
import os
import pickle
import glob
@ -8,6 +8,8 @@ import traceback
import uuid
import json
import shutil
from typing import Optional
import pytz
import typer
from tqdm import tqdm
@ -21,10 +23,17 @@ from llama_index import (
from memgpt.agent import Agent
from memgpt.data_types import AgentState, User, Passage, Source, Message
from memgpt.metadata import MetadataStore
from memgpt.utils import MEMGPT_DIR, version_less_than, OpenAIBackcompatUnpickler, annotate_message_json_list_with_tool_calls
from memgpt.utils import (
MEMGPT_DIR,
version_less_than,
OpenAIBackcompatUnpickler,
annotate_message_json_list_with_tool_calls,
parse_formatted_time,
)
from memgpt.config import MemGPTConfig
from memgpt.cli.cli_config import configure
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
# This is the version where the breaking change was made
VERSION_CUTOFF = "0.2.12"
@ -33,19 +42,19 @@ VERSION_CUTOFF = "0.2.12"
MIGRATION_BACKUP_FOLDER = "migration_backups"
def wipe_config_and_reconfigure(run_configure=True):
def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True):
"""Wipe (backup) the config file, and launch `memgpt configure`"""
if not os.path.exists(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents"))
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))
# Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# Construct the new backup directory name with the timestamp
backup_filename = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}")
existing_filename = os.path.join(MEMGPT_DIR, "config")
backup_filename = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}")
existing_filename = os.path.join(data_dir, "config")
# Check if the existing file exists before moving
if os.path.exists(existing_filename):
@ -63,11 +72,11 @@ def wipe_config_and_reconfigure(run_configure=True):
MemGPTConfig.load()
def config_is_compatible(allow_empty=False, echo=False) -> bool:
def config_is_compatible(data_dir: str = MEMGPT_DIR, allow_empty=False, echo=False) -> bool:
"""Check if the config is OK to use with 0.2.12, or if it needs to be deleted"""
# NOTE: don't use built-in load(), since that will apply defaults
# memgpt_config = MemGPTConfig.load()
memgpt_config_file = os.path.join(MEMGPT_DIR, "config")
memgpt_config_file = os.path.join(data_dir, "config")
if not os.path.exists(memgpt_config_file):
return True if allow_empty else False
parser = configparser.ConfigParser()
@ -92,9 +101,9 @@ def config_is_compatible(allow_empty=False, echo=False) -> bool:
return True
def agent_is_migrateable(agent_name: str) -> bool:
def agent_is_migrateable(agent_name: str, data_dir: str = MEMGPT_DIR) -> bool:
"""Determine whether or not the agent folder is a migration target"""
agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name)
agent_folder = os.path.join(data_dir, "agents", agent_name)
if not os.path.exists(agent_folder):
raise ValueError(f"Folder {agent_folder} does not exist")
@ -115,14 +124,14 @@ def agent_is_migrateable(agent_name: str) -> bool:
return False
def migrate_source(source_name: str):
def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
"""
Migrate an old source folder (`~/.memgpt/sources/{source_name}`).
"""
# 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index
# TODO
source_path = os.path.join(MEMGPT_DIR, "archival", source_name, "nodes.pkl")
source_path = os.path.join(data_dir, "archival", source_name, "nodes.pkl")
assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}"
# load state from old checkpoint file
@ -130,15 +139,21 @@ def migrate_source(source_name: str):
# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)
# gets default user
ms = MetadataStore(config)
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
raise ValueError(
f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
)
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# insert source into metadata store
source = Source(user_id=user.id, name=source_name)
@ -179,7 +194,7 @@ def migrate_source(source_name: str):
assert source is not None, f"Failed to load source {source_name} from database after migration"
def migrate_agent(agent_name: str):
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)
Steps:
@ -191,7 +206,7 @@ def migrate_agent(agent_name: str):
# 1. Load the agent state JSON from the old folder
# TODO
agent_folder = os.path.join(MEMGPT_DIR, "agents", agent_name)
agent_folder = os.path.join(data_dir, "agents", agent_name)
# migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME)
# load state from old checkpoint file
@ -255,22 +270,45 @@ def migrate_agent(agent_name: str):
# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)
# gets default user
ms = MetadataStore(config)
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
raise ValueError(
f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
)
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# ms.create_user(User(id=user_id))
# user = ms.get_user(user_id=user_id)
# if user is None:
# typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
# sys.exit(1)
# create an agent_id ahead of time
agent_id = uuid.uuid4()
# create all the Messages in the database
# message_objs = []
# for message_dict in annotate_message_json_list_with_tool_calls(state_dict["messages"]):
# message_obj = Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=message_dict,
# model=state_dict["model"] if "model" in state_dict else None,
# # allow_functions_style=False,
# allow_functions_style=True,
# )
# message_objs.append(message_obj)
agent_state = AgentState(
id=agent_id,
name=agent_config["name"],
user_id=user.id,
persona=agent_config["persona"], # eg 'sam_pov'
@ -281,18 +319,97 @@ def migrate_agent(agent_name: str):
persona=state_dict["memory"]["persona"],
system=state_dict["system"],
functions=state_dict["functions"], # this shouldn't matter, since Agent.__init__ will re-link
messages=annotate_message_json_list_with_tool_calls(state_dict["messages"]),
# messages=[str(m.id) for m in message_objs], # this is a list of uuids, not message dicts
),
llm_config=user.default_llm_config,
embedding_config=user.default_embedding_config,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
)
persistence_manager = LocalStateManager(agent_state=agent_state)
# First clean up the recall message history to add tool call ids
full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
for i in range(len(data["all_messages"])):
data["all_messages"][i]["message"] = full_message_history_buffer[i]
# Figure out what messages in recall are in-context, and which are out-of-context
agent_message_cache = state_dict["messages"]
recall_message_full = data["all_messages"]
def messages_are_equal(msg1, msg2):
    """Return True when two OpenAI-style message dicts have matching role and content."""
    same_role = msg1["role"] == msg2["role"]
    same_content = msg1["content"] == msg2["content"]
    return same_role and same_content
in_context_messages = []
out_of_context_messages = []
assert len(agent_message_cache) <= len(recall_message_full), (len(agent_message_cache), len(recall_message_full))
for d in recall_message_full:
# unpack into "timestamp" and "message"
recall_message = d["message"]
recall_timestamp = d["timestamp"]
try:
recall_datetime = parse_formatted_time(recall_timestamp).astimezone(pytz.utc)
except ValueError:
recall_datetime = datetime.strptime(recall_timestamp, "%Y-%m-%d %I:%M:%S %p").astimezone(pytz.utc)
# message object
message_obj = Message.dict_to_message(
created_at=recall_datetime,
user_id=user.id,
agent_id=agent_id,
openai_message_dict=recall_message,
allow_functions_style=True,
)
# message is either in-context, or out-of-context
message_is_in_context = [messages_are_equal(recall_message, cache_message) for cache_message in agent_message_cache]
assert sum(message_is_in_context) <= 1, message_is_in_context
if any(message_is_in_context):
in_context_messages.append(message_obj)
else:
out_of_context_messages.append(message_obj)
assert len(in_context_messages) > 0
assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache))
# assert (
# len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"]
# ), f"{len(in_context_messages)} + {len(out_of_context_messages)} != {state_dict['messages_total']}"
# Now we can insert the messages into the actual recall database
# So when we construct the agent from the state, they will be available
persistence_manager.recall_memory.insert_many(out_of_context_messages)
persistence_manager.recall_memory.insert_many(in_context_messages)
# Overwrite the agent_state message object
agent_state.state["messages"] = [str(m.id) for m in in_context_messages] # this is a list of uuids, not message dicts
## 4. Insert into recall
# TODO should this be 'messages', or 'all_messages'?
# all_messages in recall will have fields "timestamp" and "message"
# full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
# We want to keep the timestamp
# for i in range(len(data["all_messages"])):
# data["all_messages"][i]["message"] = full_message_history_buffer[i]
# messages_to_insert = [
# Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=msg,
# allow_functions_style=True,
# )
# # for msg in data["all_messages"]
# for msg in full_message_history_buffer
# ]
# agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
# print("Finished migrating recall memory")
# 3. Instantiate a new Agent by passing AgentState to Agent.__init__
# NOTE: the Agent.__init__ will trigger a save, which will write to the DB
try:
agent = Agent(
agent_state=agent_state,
messages_total=state_dict["messages_total"], # TODO: do we need this?
# messages_total=state_dict["messages_total"], # TODO: do we need this?
messages_total=len(in_context_messages) + len(out_of_context_messages),
interface=None,
)
except Exception as e:
@ -308,17 +425,6 @@ def migrate_agent(agent_name: str):
# Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail
try:
## 4. Insert into recall
# TODO should this be 'messages', or 'all_messages'?
# all_messages in recall will have fields "timestamp" and "message"
full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
# We want to keep the timestamp
for i in range(len(data["all_messages"])):
data["all_messages"][i]["message"] = full_message_history_buffer[i]
messages_to_insert = [Message.dict_to_message(msg, allow_functions_style=True) for msg in data["all_messages"]]
agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
# print("Finished migrating recall memory")
# TODO should we also assign data["messages"] to RecallMemory.messages?
# 5. Insert into archival
@ -350,7 +456,7 @@ def migrate_agent(agent_name: str):
raise
try:
new_agent_folder = os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
new_agent_folder = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
shutil.move(agent_folder, new_agent_folder)
except Exception as e:
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
@ -358,20 +464,20 @@ def migrate_agent(agent_name: str):
# def migrate_all_agents(stop_on_fail=True):
def migrate_all_agents(stop_on_fail: bool = False) -> dict:
"""Scan over all agent folders in MEMGPT_DIR and migrate each agent."""
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
if not os.path.exists(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(MEMGPT_DIR, MIGRATION_BACKUP_FOLDER, "agents"))
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))
if not config_is_compatible(echo=True):
if not config_is_compatible(data_dir, echo=True):
typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED)
if questionary.confirm(
"To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?"
).ask():
try:
wipe_config_and_reconfigure()
wipe_config_and_reconfigure(data_dir)
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
@ -379,7 +485,7 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict:
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
raise KeyboardInterrupt()
agents_dir = os.path.join(MEMGPT_DIR, "agents")
agents_dir = os.path.join(data_dir, "agents")
# Ensure the directory exists
if not os.path.exists(agents_dir):
@ -392,13 +498,16 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict:
count = 0
failures = []
candidates = []
config = MemGPTConfig.load()
print(config)
ms = MetadataStore(config)
try:
for agent_name in tqdm(agent_folders, desc="Migrating agents"):
# Assuming migrate_agent is a function that takes the agent name and performs migration
try:
if agent_is_migrateable(agent_name=agent_name):
if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir):
candidates.append(agent_name)
migrate_agent(agent_name)
migrate_agent(agent_name, data_dir=data_dir, ms=ms)
count += 1
else:
continue
@ -423,6 +532,7 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict:
if count > 0:
typer.secho(f"{count}/{len(candidates)} agents were successfully migrated to the new database format", fg=typer.colors.GREEN)
del ms
return {
"agent_folders": len(agent_folders),
"migration_candidates": len(candidates),
@ -431,10 +541,10 @@ def migrate_all_agents(stop_on_fail: bool = False) -> dict:
}
def migrate_all_sources(stop_on_fail: bool = False) -> dict:
"""Scan over all agent folders in MEMGPT_DIR and migrate each agent."""
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
sources_dir = os.path.join(MEMGPT_DIR, "archival")
sources_dir = os.path.join(data_dir, "archival")
# Ensure the directory exists
if not os.path.exists(sources_dir):
@ -447,12 +557,14 @@ def migrate_all_sources(stop_on_fail: bool = False) -> dict:
count = 0
failures = []
candidates = []
config = MemGPTConfig.load()
ms = MetadataStore(config)
try:
for source_name in tqdm(source_folders, desc="Migrating data sources"):
# Assuming migrate_agent is a function that takes the agent name and performs migration
try:
candidates.append(source_name)
migrate_source(source_name)
migrate_source(source_name, data_dir, ms=ms)
count += 1
except Exception as e:
failures.append({"name": source_name, "reason": str(e)})
@ -475,6 +587,7 @@ def migrate_all_sources(stop_on_fail: bool = False) -> dict:
if count > 0:
typer.secho(f"{count}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN)
del ms
return {
"source_folders": len(source_folders),
"migration_candidates": len(candidates),

View File

@ -9,20 +9,27 @@ model_endpoint = https://api.openai.com/v1
model_endpoint_type = openai
context_window = 8192
[openai]
key = FAKE_KEY
[embedding]
embedding_endpoint_type = openai
embedding_endpoint = https://api.openai.com/v1
embedding_model = text-embedding-ada-002
embedding_dim = 1536
embedding_chunk_size = 300
[archival_storage]
type = local
type = chroma
path = /Users/sarahwooders/.memgpt/chroma
[recall_storage]
type = sqlite
path = /Users/sarahwooders/.memgpt
[metadata_storage]
type = sqlite
path = /Users/sarahwooders/.memgpt
[version]
memgpt_version = 0.2.11
memgpt_version = 0.2.12
[client]
anon_clientid = 00000000000000000000d67f40108c5c

14
tests/test_migrate.py Normal file
View File

@ -0,0 +1,14 @@
import os
from memgpt.migrate import migrate_all_agents, migrate_all_sources
def test_migrate_0211():
    """End-to-end check that a MemGPT 0.2.11 data directory migrates cleanly.

    Runs the agent migration and then the source migration against the bundled
    0.2.11 fixture directory, asserting that neither reports any failures.
    """
    data_dir = "tests/data/memgpt-0.2.11"

    # Migrate all agent folders found under the fixture data directory.
    agent_res = migrate_all_agents(data_dir)
    assert agent_res["failed_migrations"] == 0, f"Failed agent migrations: {agent_res}"

    # Migrate all archival data sources from the same directory.
    source_res = migrate_all_sources(data_dir)
    assert source_res["failed_migrations"] == 0, f"Failed source migrations: {source_res}"

    # TODO: assert the migrated agents/sources are actually present in the DB