fix: decrease number of saves to MemGPTConfig (#943)

This commit is contained in:
tombedor 2024-02-15 19:08:52 -08:00 committed by GitHub
parent c7fbc03e68
commit a456d82542
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 37 deletions

View File

@ -55,8 +55,15 @@ def str_to_quickstart_choice(choice_str: str) -> QuickstartChoice:
raise ValueError(f"{choice_str} is not a valid QuickstartChoice. Valid options are: {valid_options}")
def set_config_with_dict(new_config: dict) -> bool:
"""Set the base config using a dict"""
def set_config_with_dict(new_config: dict) -> (MemGPTConfig, bool):
"""_summary_
Args:
new_config (dict): Dict of new config values
Returns:
new_config MemGPTConfig, modified (bool): Returns the new config and a boolean indicating if the config was modified
"""
from memgpt.utils import printd
old_config = MemGPTConfig.load()
@ -93,32 +100,7 @@ def set_config_with_dict(new_config: dict) -> bool:
else:
printd(f"Skipping new config {k}: {v} == {new_config[k]}")
if modified:
printd(f"Saving new config file.")
old_config.save()
typer.secho(f"📖 MemGPT configuration file updated!", fg=typer.colors.GREEN)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {old_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {old_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.GREEN,
)
return True
else:
typer.secho(f"📖 MemGPT configuration file unchanged.", fg=typer.colors.WHITE)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {old_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {old_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.WHITE,
)
return False
return (old_config, modified)
def quickstart(
@ -127,7 +109,10 @@ def quickstart(
debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
terminal: bool = True,
):
"""Set the base config file with a single command"""
"""Set the base config file with a single command
This function and `configure` should be the ONLY places where MemGPTConfig.save() is called.
"""
# setup logger
utils.DEBUG = debug
@ -154,7 +139,7 @@ def quickstart(
config = response.json()
# Output a success message and the first few items in the dictionary as a sample
printd("JSON config file downloaded successfully.")
config_was_modified = set_config_with_dict(config)
new_config, config_was_modified = set_config_with_dict(config)
else:
typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
@ -165,7 +150,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded backup config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED)
return
@ -177,7 +162,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
return
@ -203,7 +188,7 @@ def quickstart(
config = response.json()
# Output a success message and the first few items in the dictionary as a sample
print("JSON config file downloaded successfully.")
config_was_modified = set_config_with_dict(config)
new_config, config_was_modified = set_config_with_dict(config)
else:
typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
@ -214,7 +199,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded backup config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Backup config file not found at {backup_config_path}", fg=typer.colors.RED)
return
@ -226,7 +211,7 @@ def quickstart(
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
printd("Loaded config file successfully.")
config_was_modified = set_config_with_dict(backup_config)
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
typer.secho(f"Config file not found at {backup_config_path}", fg=typer.colors.RED)
return
@ -234,6 +219,31 @@ def quickstart(
else:
raise NotImplementedError(backend)
if config_was_modified:
printd(f"Saving new config file.")
new_config.save()
typer.secho(f"📖 MemGPT configuration file updated!", fg=typer.colors.GREEN)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {new_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.GREEN,
)
else:
typer.secho(f"📖 MemGPT configuration file unchanged.", fg=typer.colors.WHITE)
typer.secho(
"\n".join(
[
f"🧠 model\t-> {new_config.default_llm_config.model}",
f"🖥️ endpoint\t-> {new_config.default_llm_config.model_endpoint}",
]
),
fg=typer.colors.WHITE,
)
# 'terminal' = quickstart was run alone, in which case we should guide the user on the next command
if terminal:
if config_was_modified:

View File

@ -571,7 +571,10 @@ def configure_recall_storage(config: MemGPTConfig, credentials: MemGPTCredential
@app.command()
def configure():
"""Updates default MemGPT configurations"""
"""Updates default MemGPT configurations
This function and quickstart should be the ONLY place where MemGPTConfig.save() is called
"""
# check credentials
credentials = MemGPTCredentials.load()

View File

@ -185,7 +185,6 @@ class MemGPTConfig:
anon_clientid = MemGPTConfig.generate_uuid()
config = cls(anon_clientid=anon_clientid, config_path=config_path)
config.create_config_dir() # create dirs
config.save() # save updated config
return config