mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

* Remove AsyncAgent and async from cli Refactor agent.py memory.py Refactor interface.py Refactor main.py Refactor openai_tools.py Refactor cli/cli.py stray asyncs save make legacy embeddings not use async Refactor presets Remove deleted function from import * remove stray prints * typo * another stray print * patch test --------- Co-authored-by: cpacker <packercharles@gmail.com>
264 lines
10 KiB
Python
264 lines
10 KiB
Python
import questionary
|
|
import openai
|
|
from prettytable import PrettyTable
|
|
import typer
|
|
import os
|
|
import shutil
|
|
from collections import defaultdict
|
|
|
|
# from memgpt.cli import app
|
|
from memgpt import utils
|
|
|
|
import memgpt.humans.humans as humans
|
|
import memgpt.personas.personas as personas
|
|
from memgpt.config import MemGPTConfig, AgentConfig
|
|
from memgpt.constants import MEMGPT_DIR
|
|
from memgpt.connectors.storage import StorageConnector
|
|
|
|
# Typer application object; the functions in this module decorated with
# `@app.command()` are registered on it.
app = typer.Typer()
|
|
|
|
|
|
def _default_if_valid(value, choices, fallback=None):
    """Return `value` when it is one of `choices`, otherwise `fallback`.

    Used to sanitize the `default=` passed to `questionary.select`, which
    rejects a default that is not among the offered choices. A `None`
    result lets questionary fall back to its own initial selection.
    """
    return value if value in choices else fallback


@app.command()
def configure():
    """Updates default MemGPT configurations"""

    from memgpt.presets import DEFAULT_PRESET, preset_options

    MemGPTConfig.create_config_dir()

    # Will pre-populate with defaults, or what the user previously set
    config = MemGPTConfig.load()

    # Credentials are only read from the environment; initialise them so
    # every name is bound regardless of which providers get enabled.
    openai_key = None
    azure_key = azure_endpoint = azure_version = None
    azure_deployment = azure_embedding_deployment = None

    # openai credentials
    use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask()
    if use_openai:
        # search for key in environment
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
            # TODO: eventually stop relying on env variables and pass in keys explicitly
            # (e.g. prompt for the key with questionary.text when it is not in the environment)

    # azure credentials
    use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
    use_azure_deployment_ids = False
    if use_azure:
        # search for key in environment
        azure_key = os.getenv("AZURE_OPENAI_KEY")
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        azure_version = os.getenv("AZURE_OPENAI_VERSION")
        azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
        azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")

        if all([azure_key, azure_endpoint, azure_version]):
            print(f"Using Microsoft endpoint {azure_endpoint}.")
            # Deployment ids are optional; they are only saved when both are set.
            if all([azure_deployment, azure_embedding_deployment]):
                print(f"Using deployment id {azure_deployment}")
                use_azure_deployment_ids = True

            # configure openai
            openai.api_type = "azure"
            openai.api_key = azure_key
            openai.api_base = azure_endpoint
            openai.api_version = azure_version
        else:
            print("Missing enviornment variables for Azure. Please set then run `memgpt configure` again.")
            # TODO: allow for manual setting
            use_azure = False

    # TODO: configure local model

    # configure provider
    model_endpoint_options = []
    custom_endpoint = os.getenv("OPENAI_API_BASE")
    if custom_endpoint is not None:
        model_endpoint_options.append(custom_endpoint)
    if use_openai:
        model_endpoint_options.append("openai")
    if use_azure:
        model_endpoint_options.append("azure")
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not model_endpoint_options:
        raise ValueError("No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE.")
    default_endpoint = questionary.select(
        "Select default inference endpoint:",
        model_endpoint_options,
        default=_default_if_valid(config.model_endpoint, model_endpoint_options, fallback=model_endpoint_options[0]),
    ).ask()

    # configure embedding provider
    embedding_endpoint_options = ["local"]  # cannot configure custom endpoint (too confusing)
    if use_azure:
        embedding_endpoint_options.append("azure")
    if use_openai:
        embedding_endpoint_options.append("openai")
    # When the saved value is unusable, prefer the last (hosted) option over "local".
    default_embedding_endpoint = questionary.select(
        "Select default embedding endpoint:",
        embedding_endpoint_options,
        default=_default_if_valid(config.embedding_model, embedding_endpoint_options, fallback=embedding_endpoint_options[-1]),
    ).ask()

    # configure embedding dimensions
    default_embedding_dim = config.embedding_dim
    if default_embedding_endpoint == "local":
        # HF model uses lower dimensionality
        default_embedding_dim = 384
    # NOTE(review): when switching from "local" back to a hosted endpoint the
    # previously saved dimension is kept as-is -- confirm this is intended.

    # configure preset
    # A saved preset that no longer exists must not be used as the default;
    # fall back to DEFAULT_PRESET when that one is offered, else to no default.
    preset_default = _default_if_valid(
        config.preset,
        preset_options,
        fallback=_default_if_valid(DEFAULT_PRESET, preset_options),
    )
    default_preset = questionary.select("Select default preset:", preset_options, default=preset_default).ask()

    # default model
    if use_openai or use_azure:
        model_choices = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
        # config.model may be e.g. "local" from an earlier run, which is not a valid choice here.
        default_model = questionary.select(
            "Select default model (recommended: gpt-4):",
            choices=model_choices,
            default=_default_if_valid(config.model, model_choices),
        ).ask()
    else:
        default_model = "local"  # TODO: figure out if this is ok? this is for local endpoint

    # defaults
    # Local names deliberately differ from the module-level `personas` / `humans` imports.
    persona_names = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
    default_persona = questionary.select(
        "Select default persona:",
        persona_names,
        default=_default_if_valid(config.default_persona, persona_names),
    ).ask()
    human_names = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
    default_human = questionary.select(
        "Select default human:",
        human_names,
        default=_default_if_valid(config.default_human, human_names),
    ).ask()

    # TODO: figure out if we should set a default agent or not
    # (e.g. offer the names from utils.list_agent_config_files() when any agents exist)
    default_agent = None

    # Configure archival storage backend
    archival_storage_options = ["local", "postgres"]
    archival_storage_type = questionary.select(
        "Select storage backend for archival data:",
        archival_storage_options,
        default=_default_if_valid(config.archival_storage_type, archival_storage_options),
    ).ask()
    archival_storage_uri = None
    if archival_storage_type == "postgres":
        archival_storage_uri = questionary.text(
            "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
            default=config.archival_storage_uri if config.archival_storage_uri else "",
        ).ask()

    # TODO: allow configuring embedding model

    config = MemGPTConfig(
        model=default_model,
        preset=default_preset,
        model_endpoint=default_endpoint,
        embedding_model=default_embedding_endpoint,
        embedding_dim=default_embedding_dim,
        default_persona=default_persona,
        default_human=default_human,
        default_agent=default_agent,
        openai_key=openai_key if use_openai else None,
        azure_key=azure_key if use_azure else None,
        azure_endpoint=azure_endpoint if use_azure else None,
        azure_version=azure_version if use_azure else None,
        azure_deployment=azure_deployment if use_azure_deployment_ids else None,
        azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None,
        archival_storage_type=archival_storage_type,
        archival_storage_uri=archival_storage_uri,
    )
    print(f"Saving config to {config.config_path}")
    config.save()
|
|
|
|
|
|
@app.command()
def list(option: str):
    """List all agents, humans, personas, or data sources.

    `option` selects what to print and must be one of "agents", "humans",
    "personas" or "sources"; any other value raises ValueError.
    """
    # NOTE: this function shadows the builtin `list` for the rest of the
    # module, so the body below avoids calling `list(...)`.
    if option == "agents":
        # One row per saved agent config.
        table = PrettyTable()
        table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
        for agent_file in utils.list_agent_config_files():
            agent_name = os.path.basename(agent_file).replace(".json", "")
            agent_config = AgentConfig.load(agent_name)
            table.add_row(
                [
                    agent_name,
                    agent_config.model,
                    agent_config.persona,
                    agent_config.human,
                    ",".join(agent_config.data_sources),
                    agent_config.create_time,
                ]
            )
        print(table)
    elif option == "humans":
        # One row per human file: name (file name without ".txt") and contents.
        table = PrettyTable()
        table.field_names = ["Name", "Text"]
        for human_file in utils.list_human_files():
            with open(human_file, "r") as f:
                text = f.read()
            name = os.path.basename(human_file).replace(".txt", "")
            table.add_row([name, text])
        print(table)
    elif option == "personas":
        # One row per persona file: name (file name without ".txt") and contents.
        table = PrettyTable()
        table.field_names = ["Name", "Text"]
        for persona_file in utils.list_persona_files():
            with open(persona_file, "r") as f:
                text = f.read()
            name = os.path.basename(persona_file).replace(".txt", "")
            table.add_row([name, text])
        print(table)
    elif option == "sources":
        # One row per loaded data source, with the agents that reference it.
        table = PrettyTable()
        table.field_names = ["Name", "Location", "Agents"]
        config = MemGPTConfig.load()
        # TODO: eventually look across all storage connections
        # TODO: add data source stats

        # Map each data source name to the agents whose config lists it.
        source_to_agents = {}
        for agent_file in utils.list_agent_config_files():
            agent_name = os.path.basename(agent_file).replace(".json", "")
            agent_config = AgentConfig.load(agent_name)
            for ds in agent_config.data_sources:
                source_to_agents.setdefault(ds, []).append(agent_name)

        # Every source is reported with the configured storage backend type.
        location = config.archival_storage_type
        for data_source in StorageConnector.list_loaded_data():
            agents = ",".join(source_to_agents.get(data_source, []))
            table.add_row([data_source, location, agents])
        print(table)
    else:
        raise ValueError(f"Unknown option {option}")
|
|
|
|
|
|
@app.command()
def add(
    option: str,  # [human, persona]
    name: str = typer.Option(..., help="Name of human/persona"),
    text: str = typer.Option(None, help="Text of human/persona"),
    filename: str = typer.Option(None, "-f", help="Specify filename"),
):
    """Add a persona/human

    Exactly one of `text` or `filename` must be given: `filename` is copied
    to, or `text` is written to, a file called `name` inside the "personas"
    or "humans" directory under MEMGPT_DIR.

    Raises ValueError for an unknown `option`, when both `text` and
    `filename` are given, or when neither is given.
    """

    if option == "persona":
        directory = os.path.join(MEMGPT_DIR, "personas")
    elif option == "human":
        directory = os.path.join(MEMGPT_DIR, "humans")
    else:
        raise ValueError(f"Unknown kind {option}")

    # Validate with explicit raises; `assert` is stripped under `python -O`.
    if filename is not None and text is not None:
        raise ValueError("Cannot provide both filename and text")
    if not filename and not text:
        raise ValueError("Must provide either filename or text")

    # The target directory may not exist yet on a fresh install.
    os.makedirs(directory, exist_ok=True)
    # NOTE(review): the file is stored under `name` with no ".txt" suffix,
    # while the listing code strips ".txt" from file names -- confirm whether
    # a suffix should be appended here.
    destination = os.path.join(directory, name)

    if filename:
        # copy file to directory
        shutil.copyfile(filename, destination)
    else:
        # write text to file
        with open(destination, "w") as f:
            f.write(text)
|