mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
fix: misc updates to the migration code to improve migration UX (#941)
This commit is contained in:
parent
de6f6b1987
commit
330b199616
@ -33,10 +33,12 @@ from memgpt.metadata import MetadataStore, save_agent
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
|
||||
|
||||
def migrate():
|
||||
def migrate(
|
||||
debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False,
|
||||
):
|
||||
"""Migrate old agents (pre 0.2.12) to the new database system"""
|
||||
migrate_all_agents()
|
||||
migrate_all_sources()
|
||||
migrate_all_agents(debug=debug)
|
||||
migrate_all_sources(debug=debug)
|
||||
|
||||
|
||||
class QuickstartChoice(Enum):
|
||||
@ -406,14 +408,17 @@ def run(
|
||||
raise
|
||||
elif selection == choices[1]:
|
||||
try:
|
||||
wipe_config_and_reconfigure(run_configure=False)
|
||||
# Don't create a config, so that the next block of code asking about quickstart is run
|
||||
wipe_config_and_reconfigure(run_configure=False, create_config=False)
|
||||
except Exception as e:
|
||||
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
|
||||
raise
|
||||
else:
|
||||
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
|
||||
typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)
|
||||
|
||||
if not MemGPTConfig.exists():
|
||||
# if no config, ask about quickstart
|
||||
# do you want to do:
|
||||
|
@ -8,7 +8,7 @@ import traceback
|
||||
import uuid
|
||||
import json
|
||||
import shutil
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
import pytz
|
||||
|
||||
import typer
|
||||
@ -42,7 +42,7 @@ VERSION_CUTOFF = "0.2.12"
|
||||
MIGRATION_BACKUP_FOLDER = "migration_backups"
|
||||
|
||||
|
||||
def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True):
|
||||
def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True, create_config=True):
|
||||
"""Wipe (backup) the config file, and launch `memgpt configure`"""
|
||||
|
||||
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
|
||||
@ -67,7 +67,7 @@ def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True):
|
||||
if run_configure:
|
||||
# Either run configure
|
||||
configure()
|
||||
else:
|
||||
elif create_config:
|
||||
# Or create a new config with defaults
|
||||
MemGPTConfig.load()
|
||||
|
||||
@ -183,7 +183,7 @@ def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Me
|
||||
assert len(passages) > 0, f"Source {source_name} has no passages"
|
||||
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
|
||||
conn.insert_many(passages)
|
||||
print(f"Inserted {len(passages)} to {source_name}")
|
||||
# print(f"Inserted {len(passages)} to {source_name}")
|
||||
except Exception as e:
|
||||
# delete from metadata store
|
||||
ms.delete_source(source.id)
|
||||
@ -194,7 +194,7 @@ def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Me
|
||||
assert source is not None, f"Failed to load source {source_name} from database after migration"
|
||||
|
||||
|
||||
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
|
||||
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None) -> List[str]:
|
||||
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)
|
||||
|
||||
Steps:
|
||||
@ -202,7 +202,12 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
2. Create a new AgentState using the agent config + agent internal state
|
||||
3. Instantiate a new Agent by passing AgentState to Agent.__init__
|
||||
(This will automatically run into a new database)
|
||||
|
||||
If success, returns empty list
|
||||
If warning, returns a list of strings (warning message)
|
||||
If error, raises an Exception
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# 1. Load the agent state JSON from the old folder
|
||||
# TODO
|
||||
@ -216,6 +221,7 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
|
||||
# NOTE this is a soft fail, just allow it to pass
|
||||
# return
|
||||
# return [f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}"]
|
||||
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
state_filename = max(json_files, key=os.path.getmtime)
|
||||
@ -231,6 +237,7 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
|
||||
if not os.path.exists(persistence_filename):
|
||||
raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")
|
||||
# return [f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}"]
|
||||
|
||||
try:
|
||||
with open(persistence_filename, "rb") as f:
|
||||
@ -328,7 +335,10 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
persistence_manager = LocalStateManager(agent_state=agent_state)
|
||||
|
||||
# First clean up the recall message history to add tool call ids
|
||||
full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
|
||||
# allow_tool_roles in case some of the old messages were actually already in tool call format (for whatever reason)
|
||||
full_message_history_buffer = annotate_message_json_list_with_tool_calls(
|
||||
[d["message"] for d in data["all_messages"]], allow_tool_roles=True
|
||||
)
|
||||
for i in range(len(data["all_messages"])):
|
||||
data["all_messages"][i]["message"] = full_message_history_buffer[i]
|
||||
|
||||
@ -383,16 +393,22 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
# out_of_context_messages.append(message_obj)
|
||||
|
||||
if not any(message_is_in_context):
|
||||
typer.secho(
|
||||
f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
|
||||
fg=typer.colors.RED,
|
||||
# typer.secho(
|
||||
# f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
|
||||
# fg=typer.colors.RED,
|
||||
# )
|
||||
warnings.append(
|
||||
f"Didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
|
||||
)
|
||||
out_of_context_messages.append(message_obj)
|
||||
else:
|
||||
if sum(message_is_in_context) > 1:
|
||||
typer.secho(
|
||||
f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
|
||||
fg=typer.colors.RED,
|
||||
# typer.secho(
|
||||
# f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
|
||||
# fg=typer.colors.RED,
|
||||
# )
|
||||
warnings.append(
|
||||
f"Found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
|
||||
)
|
||||
in_context_messages.append(message_obj)
|
||||
|
||||
@ -400,12 +416,15 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
# if we're not in the final portion of the recall memory buffer, then it's 100% out-of-context
|
||||
out_of_context_messages.append(message_obj)
|
||||
|
||||
assert len(in_context_messages) > 0
|
||||
assert len(in_context_messages) > 0, f"Couldn't find any in-context messages (agent_cache = {len(agent_message_cache)})"
|
||||
# assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache))
|
||||
if len(in_context_messages) != len(agent_message_cache):
|
||||
typer.secho(
|
||||
f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
|
||||
fg=typer.colors.RED,
|
||||
# typer.secho(
|
||||
# f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
|
||||
# fg=typer.colors.RED,
|
||||
# )
|
||||
warnings.append(
|
||||
f"Uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})"
|
||||
)
|
||||
# assert (
|
||||
# len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"]
|
||||
@ -468,9 +487,14 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
if os.path.exists(archival_filename):
|
||||
nodes = pickle.load(open(archival_filename, "rb"))
|
||||
passages = []
|
||||
failed_inserts = []
|
||||
for node in nodes:
|
||||
if len(node.embedding) != config.default_embedding_config.embedding_dim:
|
||||
raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
|
||||
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
|
||||
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
|
||||
failed_inserts.append(
|
||||
f"Cannot migrate passage due to incompatible embedding dimentions ({len(node.embedding)} != {config.default_embedding_config.embedding_dim}) - content = '{node.text}'."
|
||||
)
|
||||
passages.append(
|
||||
Passage(
|
||||
user_id=user.id,
|
||||
@ -483,10 +507,15 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
)
|
||||
if len(passages) > 0:
|
||||
agent.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
print(f"Inserted {len(passages)} passages into archival memory")
|
||||
# print(f"Inserted {len(passages)} passages into archival memory")
|
||||
|
||||
if len(failed_inserts) > 0:
|
||||
warnings.append(
|
||||
f"Failed to transfer {len(failed_inserts)}/{len(nodes)} passages from old archival memory: " + ", ".join(failed_inserts)
|
||||
)
|
||||
|
||||
else:
|
||||
print("No archival memory found at", archival_filename)
|
||||
# list.append() takes exactly one argument; build a single message string (same text the old print() emitted)
warnings.append(f"No archival memory found at {archival_filename}")
|
||||
|
||||
except:
|
||||
ms.delete_agent(agent_state.id)
|
||||
@ -499,9 +528,11 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
|
||||
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
|
||||
raise
|
||||
|
||||
return warnings
|
||||
|
||||
|
||||
# def migrate_all_agents(stop_on_fail=True):
|
||||
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict:
|
||||
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
|
||||
"""Scan over all agent folders in data_dir and migrate each agent."""
|
||||
|
||||
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
|
||||
@ -533,8 +564,9 @@ def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -
|
||||
|
||||
# Iterate over each folder with a tqdm progress bar
|
||||
count = 0
|
||||
successes = []
|
||||
failures = []
|
||||
successes = [] # agents that migrated w/o warnings
|
||||
warnings = [] # agents that migrated but had warnings
|
||||
failures = [] # agents that failed to migrate (fatal error)
|
||||
candidates = []
|
||||
config = MemGPTConfig.load()
|
||||
print(config)
|
||||
@ -545,15 +577,19 @@ def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -
|
||||
try:
|
||||
if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir):
|
||||
candidates.append(agent_name)
|
||||
migrate_agent(agent_name, data_dir=data_dir, ms=ms)
|
||||
successes.append(agent_name)
|
||||
migration_warnings = migrate_agent(agent_name, data_dir=data_dir, ms=ms)
|
||||
if len(migration_warnings) == 0:
|
||||
successes.append(agent_name)
|
||||
else:
|
||||
warnings.append((agent_name, migration_warnings))
|
||||
count += 1
|
||||
else:
|
||||
continue
|
||||
except Exception as e:
|
||||
failures.append({"name": agent_name, "reason": str(e)})
|
||||
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
|
||||
traceback.print_exc()
|
||||
if debug:
|
||||
traceback.print_exc()
|
||||
if stop_on_fail:
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
@ -562,28 +598,50 @@ def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -
|
||||
if len(candidates) == 0:
|
||||
typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
|
||||
else:
|
||||
typer.secho(f"Inspected {len(agent_folders)} agent folders")
|
||||
typer.secho(f"Inspected {len(agent_folders)} agent folders for migration")
|
||||
|
||||
if len(warnings) > 0:
|
||||
typer.secho(f"Migration warnings:", fg=typer.colors.BRIGHT_YELLOW)
|
||||
for warn in warnings:
|
||||
typer.secho(f"{warn[0]}: {warn[1]}", fg=typer.colors.BRIGHT_YELLOW)
|
||||
|
||||
if len(failures) > 0:
|
||||
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
|
||||
for fail in failures:
|
||||
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
|
||||
typer.secho(f"❌ {len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
|
||||
|
||||
if len(failures) > 0:
|
||||
typer.secho(
|
||||
f"🔴 {len(failures)}/{len(candidates)} agents failed to migrate (see reasons above)",
|
||||
fg=typer.colors.RED,
|
||||
)
|
||||
typer.secho(f"{[d['name'] for d in failures]}", fg=typer.colors.RED)
|
||||
if count > 0:
|
||||
typer.secho(f"✅ {count}/{len(candidates)} agents were successfully migrated to the new database format", fg=typer.colors.GREEN)
|
||||
|
||||
if len(warnings) > 0:
|
||||
typer.secho(
|
||||
f"🟠 {len(warnings)}/{len(candidates)} agents successfully migrated with warnings (see reasons above)",
|
||||
fg=typer.colors.BRIGHT_YELLOW,
|
||||
)
|
||||
typer.secho(f"{[t[0] for t in warnings]}", fg=typer.colors.BRIGHT_YELLOW)
|
||||
|
||||
if len(successes) > 0:
|
||||
typer.secho(
|
||||
f"🟢 {len(successes)}/{len(candidates)} agents successfully migrated with no warnings",
|
||||
fg=typer.colors.GREEN,
|
||||
)
|
||||
typer.secho(f"{successes}", fg=typer.colors.GREEN)
|
||||
|
||||
del ms
|
||||
return {
|
||||
"agent_folders": len(agent_folders),
|
||||
"migration_candidates": candidates,
|
||||
"successful_migrations": count,
|
||||
"successful_migrations": len(successes) + len(warnings),
|
||||
"failed_migrations": failures,
|
||||
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
|
||||
}
|
||||
|
||||
|
||||
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False) -> dict:
|
||||
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
|
||||
"""Scan over all source folders in data_dir and migrate each source."""
|
||||
|
||||
sources_dir = os.path.join(data_dir, "archival")
|
||||
@ -610,7 +668,8 @@ def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
failures.append({"name": source_name, "reason": str(e)})
|
||||
traceback.print_exc()
|
||||
if debug:
|
||||
traceback.print_exc()
|
||||
if stop_on_fail:
|
||||
raise
|
||||
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
|
||||
|
@ -520,7 +520,7 @@ def enforce_types(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def annotate_message_json_list_with_tool_calls(messages: List[dict]):
|
||||
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
|
||||
"""Add in missing tool_call_id fields to a list of messages using function call style
|
||||
|
||||
Walk through the list forwards:
|
||||
@ -569,10 +569,62 @@ def annotate_message_json_list_with_tool_calls(messages: List[dict]):
|
||||
message["tool_call_id"] = tool_call_id
|
||||
tool_call_id = None # wipe the buffer
|
||||
|
||||
elif message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
|
||||
if not allow_tool_roles:
|
||||
raise NotImplementedError(
|
||||
f"tool_call_id annotation is meant for deprecated functions style, but got role 'assistant' with 'tool_calls' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
|
||||
if len(message["tool_calls"]) != 1:
|
||||
raise NotImplementedError(
|
||||
f"Got unexpected format for tool_calls inside assistant message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
|
||||
assistant_tool_call = message["tool_calls"][0]
|
||||
if "id" in assistant_tool_call and assistant_tool_call["id"] is not None:
|
||||
printd(f"Message already has id (tool_call_id)")
|
||||
tool_call_id = assistant_tool_call["id"]
|
||||
else:
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
message["tool_calls"][0]["id"] = tool_call_id
|
||||
# also just put it at the top level for ease-of-access
|
||||
# message["tool_call_id"] = tool_call_id
|
||||
tool_call_index = i
|
||||
|
||||
elif message["role"] == "tool":
|
||||
raise NotImplementedError(
|
||||
f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
if not allow_tool_roles:
|
||||
raise NotImplementedError(
|
||||
f"tool_call_id annotation is meant for deprecated functions style, but got role 'tool' in message (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
|
||||
# if "tool_call_id" not in message or message["tool_call_id"] is None:
|
||||
# raise ValueError(f"Got a tool call role, but there's no tool_call_id:\n{messages[:i]}\n{message}")
|
||||
|
||||
# We should have a new tool call id in the buffer
|
||||
if tool_call_id is None:
|
||||
# raise ValueError(
|
||||
print(
|
||||
f"Got a tool call role, but did not have a saved tool_call_id ready to use (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
# allow a soft fail in this case
|
||||
message["tool_call_id"] = str(uuid.uuid4())
|
||||
elif "tool_call_id" in message and message["tool_call_id"] is not None:
|
||||
if tool_call_id is not None and tool_call_id != message["tool_call_id"]:
|
||||
# just wipe it
|
||||
# raise ValueError(
|
||||
# f"Got a tool call role, but it already had a saved tool_call_id (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
# )
|
||||
message["tool_call_id"] = tool_call_id
|
||||
tool_call_id = None # wipe the buffer
|
||||
else:
|
||||
tool_call_id = None
|
||||
elif i != tool_call_index + 1:
|
||||
raise ValueError(
|
||||
f"Got a tool call role, saved tool_call_id came earlier than i-1 (i={i}, total={len(messages)}):\n{messages[:i]}\n{message}"
|
||||
)
|
||||
else:
|
||||
message["tool_call_id"] = tool_call_id
|
||||
tool_call_id = None # wipe the buffer
|
||||
|
||||
else:
|
||||
# eg role == 'user', nothing to do here
|
||||
|
Loading…
Reference in New Issue
Block a user