fix: misc updates to the migration code to improve migration UX (#941)

This commit is contained in:
Charles Packer 2024-01-29 17:42:41 -08:00 committed by GitHub
parent de6f6b1987
commit 330b199616
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 157 additions and 41 deletions

View File

@ -33,10 +33,12 @@ from memgpt.metadata import MetadataStore, save_agent
from memgpt.migrate import migrate_all_agents, migrate_all_sources
def migrate():
def migrate(
debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False,
):
"""Migrate old agents (pre 0.2.12) to the new database system"""
migrate_all_agents()
migrate_all_sources()
migrate_all_agents(debug=debug)
migrate_all_sources(debug=debug)
class QuickstartChoice(Enum):
@ -406,14 +408,17 @@ def run(
raise
elif selection == choices[1]:
try:
wipe_config_and_reconfigure(run_configure=False)
# Don't create a config, so that the next block of code asking about quickstart is run
wipe_config_and_reconfigure(run_configure=False, create_config=False)
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
else:
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
raise KeyboardInterrupt()
typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)
if not MemGPTConfig.exists():
# if no config, ask about quickstart
# do you want to do:

View File

@ -8,7 +8,7 @@ import traceback
import uuid
import json
import shutil
from typing import Optional
from typing import Optional, List
import pytz
import typer
@ -42,7 +42,7 @@ VERSION_CUTOFF = "0.2.12"
MIGRATION_BACKUP_FOLDER = "migration_backups"
def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True):
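def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True, create_config=True):
"""Wipe (back up) the config file, then run `memgpt configure` if run_configure is set, else create a default config if create_config is set, else leave no config in place"""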
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
@ -67,7 +67,7 @@ def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True):
if run_configure:
# Either run configure
configure()
else:
elif create_config:
# Or create a new config with defaults
MemGPTConfig.load()
@ -183,7 +183,7 @@ def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Me
assert len(passages) > 0, f"Source {source_name} has no passages"
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
conn.insert_many(passages)
print(f"Inserted {len(passages)} to {source_name}")
# print(f"Inserted {len(passages)} to {source_name}")
except Exception as e:
# delete from metadata store
ms.delete_source(source.id)
@ -194,7 +194,7 @@ def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Me
assert source is not None, f"Failed to load source {source_name} from database after migration"
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None) -> List[str]:
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)
Steps:
@ -202,7 +202,12 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
2. Create a new AgentState using the agent config + agent internal state
3. Instantiate a new Agent by passing AgentState to Agent.__init__
(This will automatically run into a new database)
If success, returns empty list
If warning, returns a list of strings (warning message)
If error, raises an Exception
"""
warnings = []
# 1. Load the agent state JSON from the old folder
# TODO
@ -216,6 +221,7 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
# NOTE this is a soft fail, just allow it to pass
# return
# return [f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}"]
# Sort files based on modified timestamp, with the latest file being the first.
state_filename = max(json_files, key=os.path.getmtime)
@ -231,6 +237,7 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
if not os.path.exists(persistence_filename):
raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")
# return [f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}"]
try:
with open(persistence_filename, "rb") as f:
@ -328,7 +335,10 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
persistence_manager = LocalStateManager(agent_state=agent_state)
# First clean up the recall message history to add tool call ids
full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
# allow_tool_roles in case some of the old messages were actually already in tool call format (for whatever reason)
full_message_history_buffer = annotate_message_json_list_with_tool_calls(
[d["message"] for d in data["all_messages"]], allow_tool_roles=True
)
for i in range(len(data["all_messages"])):
data["all_messages"][i]["message"] = full_message_history_buffer[i]
@ -383,16 +393,22 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
# out_of_context_messages.append(message_obj)
if not any(message_is_in_context):
typer.secho(
f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
fg=typer.colors.RED,
# typer.secho(
# f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
out_of_context_messages.append(message_obj)
else:
if sum(message_is_in_context) > 1:
typer.secho(
f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
fg=typer.colors.RED,
# typer.secho(
# f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
in_context_messages.append(message_obj)
@ -400,12 +416,15 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
# if we're not in the final portion of the recall memory buffer, then it's 100% out-of-context
out_of_context_messages.append(message_obj)
assert len(in_context_messages) > 0
assert len(in_context_messages) > 0, f"Couldn't find any in-context messages (agent_cache = {len(agent_message_cache)})"
# assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache))
if len(in_context_messages) != len(agent_message_cache):
typer.secho(
f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
fg=typer.colors.RED,
# typer.secho(
# f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
# fg=typer.colors.RED,
# )
warnings.append(
f"Uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})"
)
# assert (
# len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"]
@ -468,9 +487,14 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
if os.path.exists(archival_filename):
nodes = pickle.load(open(archival_filename, "rb"))
passages = []
failed_inserts = []
for node in nodes:
if len(node.embedding) != config.default_embedding_config.embedding_dim:
raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
# Previously fatal: raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
# A dimension mismatch is now recorded per passage in failed_inserts (and reported as a warning) instead of raising here.
failed_inserts.append(
f"Cannot migrate passage due to incompatible embedding dimentions ({len(node.embedding)} != {config.default_embedding_config.embedding_dim}) - content = '{node.text}'."
)
passages.append(
Passage(
user_id=user.id,
@ -483,10 +507,15 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
)
if len(passages) > 0:
agent.persistence_manager.archival_memory.storage.insert_many(passages)
print(f"Inserted {len(passages)} passages into archival memory")
# print(f"Inserted {len(passages)} passages into archival memory")
if len(failed_inserts) > 0:
warnings.append(
f"Failed to transfer {len(failed_inserts)}/{len(nodes)} passages from old archival memory: " + ", ".join(failed_inserts)
)
else:
print("No archival memory found at", archival_filename)
# list.append() takes exactly one argument; passing the path as a second positional
# argument (a leftover from the old print(...) call) raises TypeError, so the path is
# formatted into the warning message instead.
warnings.append(f"No archival memory found at {archival_filename}")
except:
ms.delete_agent(agent_state.id)
@ -499,9 +528,11 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
raise
return warnings
# def migrate_all_agents(stop_on_fail=True):
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict:
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
@ -533,8 +564,9 @@ def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -
# Iterate over each folder with a tqdm progress bar
count = 0
successes = []
failures = []
successes = [] # agents that migrated w/o warnings
warnings = [] # agents that migrated but had warnings
failures = [] # agents that failed to migrate (fatal error)
candidates = []
config = MemGPTConfig.load()
print(config)
@ -545,15 +577,19 @@ def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -
try:
if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir):
candidates.append(agent_name)
migrate_agent(agent_name, data_dir=data_dir, ms=ms)
successes.append(agent_name)
migration_warnings = migrate_agent(agent_name, data_dir=data_dir, ms=ms)
if len(migration_warnings) == 0:
successes.append(agent_name)
else:
warnings.append((agent_name, migration_warnings))
count += 1
else:
continue
except Exception as e:
failures.append({"name": agent_name, "reason": str(e)})
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
traceback.print_exc()
if debug:
traceback.print_exc()
if stop_on_fail:
raise
except KeyboardInterrupt:
@ -562,28 +598,50 @@ def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -
if len(candidates) == 0:
typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
else:
typer.secho(f"Inspected {len(agent_folders)} agent folders")
typer.secho(f"Inspected {len(agent_folders)} agent folders for migration")
if len(warnings) > 0:
typer.secho(f"Migration warnings:", fg=typer.colors.BRIGHT_YELLOW)
for warn in warnings:
typer.secho(f"{warn[0]}: {warn[1]}", fg=typer.colors.BRIGHT_YELLOW)
if len(failures) > 0:
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
for fail in failures:
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
typer.secho(f"{len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
if len(failures) > 0:
typer.secho(
f"🔴 {len(failures)}/{len(candidates)} agents failed to migrate (see reasons above)",
fg=typer.colors.RED,
)
typer.secho(f"{[d['name'] for d in failures]}", fg=typer.colors.RED)
if count > 0:
typer.secho(f"{count}/{len(candidates)} agents were successfully migrated to the new database format", fg=typer.colors.GREEN)
if len(warnings) > 0:
typer.secho(
f"🟠 {len(warnings)}/{len(candidates)} agents successfully migrated with warnings (see reasons above)",
fg=typer.colors.BRIGHT_YELLOW,
)
typer.secho(f"{[t[0] for t in warnings]}", fg=typer.colors.BRIGHT_YELLOW)
if len(successes) > 0:
typer.secho(
f"🟢 {len(successes)}/{len(candidates)} agents successfully migrated with no warnings",
fg=typer.colors.GREEN,
)
typer.secho(f"{successes}", fg=typer.colors.GREEN)
del ms
return {
"agent_folders": len(agent_folders),
"migration_candidates": candidates,
"successful_migrations": count,
"successful_migrations": len(successes) + len(warnings),
"failed_migrations": failures,
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
}
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict:
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over the `archival` directory inside data_dir and migrate each data source found there."""
sources_dir = os.path.join(data_dir, "archival")
@ -610,7 +668,8 @@ def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False)
count += 1
except Exception as e:
failures.append({"name": source_name, "reason": str(e)})
traceback.print_exc()
if debug:
traceback.print_exc()
if stop_on_fail:
raise
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)

View File

@ -520,7 +520,7 @@ def enforce_types(func):
return wrapper
def annotate_message_json_list_with_tool_calls(messages: List[dict]):
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
"""Add in missing tool_call_id fields to a list of messages using function call style
Walk through the list forwards:
@ -569,10 +569,62 @@ def annotate_message_json_list_with_tool_calls(messages: List[dict]):
message["tool_call_id"] = tool_call_id
tool_call_id = None # wipe the buffer
elif message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
if not allow_tool_roles:
raise NotImplementedError(
f"tool_call_id annotation is meant for deprecated functions style, but got role 'assistant' with 'tool_calls' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
)
if len(message["tool_calls"]) != 1:
raise NotImplementedError(
f"Got unexpected format for tool_calls inside assistant message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
)
assistant_tool_call = message["tool_calls"][0]
if "id" in assistant_tool_call and assistant_tool_call["id"] is not None:
printd(f"Message already has id (tool_call_id)")
tool_call_id = assistant_tool_call["id"]
else:
tool_call_id = str(uuid.uuid4())
message["tool_calls"][0]["id"] = tool_call_id
# also just put it at the top level for ease-of-access
# message["tool_call_id"] = tool_call_id
tool_call_index = i
elif message["role"] == "tool":
raise NotImplementedError(
f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
)
if not allow_tool_roles:
raise NotImplementedError(
f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
)
# if "tool_call_id" not in message or message["tool_call_id"] is None:
# raise ValueError(f"Got a tool call role, but there's no tool_call_id:\n{messages[:i]}\n{message}")
# We should have a new tool call id in the buffer
if tool_call_id is None:
# raise ValueError(
print(
f"Got a tool call role, but did not have a saved tool_call_id ready to use (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
)
# allow a soft fail in this case
message["tool_call_id"] = str(uuid.uuid4())
elif "tool_call_id" in message and message["tool_call_id"] is not None:
if tool_call_id is not None and tool_call_id != message["tool_call_id"]:
# just wipe it
# raise ValueError(
# f"Got a tool call role, but it already had a saved tool_call_id (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
# )
message["tool_call_id"] = tool_call_id
tool_call_id = None # wipe the buffer
else:
tool_call_id = None
elif i != tool_call_index + 1:
raise ValueError(
f"Got a tool call role, saved tool_call_id came earlier than i-1 (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
)
else:
message["tool_call_id"] = tool_call_id
tool_call_id = None # wipe the buffer
else:
# eg role == 'user', nothing to do here