* for openai, check for the key and, if missing, allow the user to pass it; for azure, throw an error if the key isn't present
* make the earlier azure check stricter, and add similar checks at the embedding endpoint config stage
* override the value in config before saving (previously forgotten)
* clean up the ValueErrors from missing keys so that no stacktrace gets printed; make success text green to match the others
478 lines
20 KiB
Python
import builtins
import os
import shutil

import questionary
import typer
from prettytable import PrettyTable

# from memgpt.cli import app
from memgpt import utils
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.connectors.storage import StorageConnector
from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers

app = typer.Typer()


def get_azure_credentials():
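    """Read Azure OpenAI credentials from the AZURE_OPENAI_* environment variables."""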
    creds = dict(
        azure_key=os.getenv("AZURE_OPENAI_KEY"),
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        azure_version=os.getenv("AZURE_OPENAI_VERSION"),
        azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
        azure_embedding_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
    )
    # embedding endpoint and version default to the non-embedding values
    creds["azure_embedding_endpoint"] = os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT", creds["azure_endpoint"])
    creds["azure_embedding_version"] = os.getenv("AZURE_OPENAI_EMBEDDING_VERSION", creds["azure_version"])
    return creds


def get_openai_credentials():
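    """Read the OpenAI API key from the OPENAI_API_KEY environment variable."""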
    openai_key = os.getenv("OPENAI_API_KEY")
    return openai_key


def configure_llm_endpoint(config: MemGPTConfig):
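    """Prompt for the LLM inference provider and endpoint; returns (model_endpoint_type, model_endpoint)."""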
    # configure model endpoint
    model_endpoint_type, model_endpoint = None, None

    # get default
    default_model_endpoint_type = config.model_endpoint_type
    if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]:  # local model
        default_model_endpoint_type = "local"

    provider = questionary.select(
        "Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
    ).ask()

    # set: model_endpoint_type, model_endpoint
    if provider == "openai":
        # check for key
        if config.openai_key is None:
            # allow key to get pulled from env vars
            openai_api_key = os.getenv("OPENAI_API_KEY", None)
            if openai_api_key is None:
                # if we still can't find it, ask for it as input
                while openai_api_key is None or len(openai_api_key) == 0:
                    # Ask for API key as input
                    openai_api_key = questionary.text(
                        "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
                    ).ask()
            # remember the key in the config before saving
            config.openai_key = openai_api_key
            config.save()

        model_endpoint_type = "openai"
        model_endpoint = "https://api.openai.com/v1"
        model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
        provider = "openai"

    elif provider == "azure":
        # check for necessary vars
        azure_creds = get_azure_credentials()
        if not all([azure_creds["azure_key"], azure_creds["azure_endpoint"], azure_creds["azure_version"]]):
            raise ValueError(
                "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set them, then run `memgpt configure` again."
            )

        model_endpoint_type = "azure"
        model_endpoint = azure_creds["azure_endpoint"]

    else:  # local models
        backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
        default_model_endpoint_type = None
        if config.model_endpoint_type in backend_options:
            # set from previous config
            default_model_endpoint_type = config.model_endpoint_type
        model_endpoint_type = questionary.select(
            "Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):",
            backend_options,
            default=default_model_endpoint_type,
        ).ask()

        # set default endpoint
        # if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
        default_model_endpoint = os.getenv("OPENAI_API_BASE")
        # if OPENAI_API_BASE is not set, try to pull a default IP+port format from a hardcoded set
        if default_model_endpoint is None:
            if model_endpoint_type in DEFAULT_ENDPOINTS:
                default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
                model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
            elif config.model_endpoint:
                model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
            else:
                model_endpoint = None
                while not model_endpoint:
                    model_endpoint = questionary.text("Enter default endpoint:").ask()
                    if "http://" not in model_endpoint and "https://" not in model_endpoint:
                        typer.secho("Endpoint must be a valid address", fg=typer.colors.YELLOW)
                        model_endpoint = None
        else:
            model_endpoint = default_model_endpoint
        assert model_endpoint, "Environment variable OPENAI_API_BASE must be set."

    return model_endpoint_type, model_endpoint


def configure_model(config: MemGPTConfig, model_endpoint_type: str):
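    """Prompt for the default model, model wrapper, and context window for the chosen endpoint type."""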
    # set: model, model_wrapper
    model, model_wrapper = None, None
    if model_endpoint_type == "openai" or model_endpoint_type == "azure":
        model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
        # TODO: select
        valid_model = config.model in model_options
        model = questionary.select(
            "Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
        ).ask()
    else:  # local models
        # ollama also needs model type
        if model_endpoint_type == "ollama":
            default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL
            model = questionary.text(
                "Enter default model name (required for Ollama, see: https://memgpt.readme.io/docs/ollama):",
                default=default_model,
            ).ask()
            model = None if len(model) == 0 else model

        # vllm needs huggingface model tag
        if model_endpoint_type == "vllm":
            default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
            model = questionary.text(
                "Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
                default=default_model,
            ).ask()
            model = None if len(model) == 0 else model
            model_wrapper = None  # no model wrapper for vLLM

        # model wrapper
        if model_endpoint_type != "vllm":
            available_model_wrappers = builtins.list(get_available_wrappers().keys())
            model_wrapper = questionary.select(
                f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
                choices=available_model_wrappers,
                default=DEFAULT_WRAPPER_NAME,
            ).ask()

    # set: context_window
    if str(model) not in LLM_MAX_TOKENS:
        # Ask the user to specify the context length
        context_length_options = [
            str(2**12),  # 4096
            str(2**13),  # 8192
            str(2**14),  # 16384
            str(2**15),  # 32768
            str(2**18),  # 262144
            "custom",  # enter yourself
        ]
        context_window = questionary.select(
            "Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
            choices=context_length_options,
            default=str(LLM_MAX_TOKENS["DEFAULT"]),
        ).ask()

        # If custom, ask for input
        if context_window == "custom":
            while True:
                context_window = questionary.text("Enter context window (e.g. 8192):").ask()
                try:
                    context_window = int(context_window)
                    break
                except ValueError:
                    print("Context window must be a valid integer")
        else:
            context_window = int(context_window)
    else:
        # Pull the context length from the models
        context_window = LLM_MAX_TOKENS[model]
    return model, model_wrapper, context_window


def configure_embedding_endpoint(config: MemGPTConfig):
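    """Prompt for the embedding provider; returns (embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model)."""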
    # configure embedding endpoint
    default_embedding_endpoint_type = config.embedding_endpoint_type

    embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None
    embedding_provider = questionary.select(
        "Select embedding provider:", choices=["openai", "azure", "hugging-face", "local"], default=default_embedding_endpoint_type
    ).ask()

    if embedding_provider == "openai":
        # check for key
        if config.openai_key is None:
            # allow key to get pulled from env vars
            openai_api_key = os.getenv("OPENAI_API_KEY", None)
            if openai_api_key is None:
                # if we still can't find it, ask for it as input
                while openai_api_key is None or len(openai_api_key) == 0:
                    # Ask for API key as input
                    openai_api_key = questionary.text(
                        "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
                    ).ask()
            # remember the key in the config before saving
            config.openai_key = openai_api_key
            config.save()

        embedding_endpoint_type = "openai"
        embedding_endpoint = "https://api.openai.com/v1"
        embedding_dim = 1536

    elif embedding_provider == "azure":
        # check for necessary vars
        azure_creds = get_azure_credentials()
        if not all([azure_creds["azure_key"], azure_creds["azure_embedding_endpoint"], azure_creds["azure_embedding_version"]]):
            raise ValueError(
                "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set them, then run `memgpt configure` again."
            )

        embedding_endpoint_type = "azure"
        embedding_endpoint = azure_creds["azure_embedding_endpoint"]
        embedding_dim = 1536

    elif embedding_provider == "hugging-face":
        # configure hugging face embedding endpoint (https://github.com/huggingface/text-embeddings-inference)
        # supports custom model/endpoints
        embedding_endpoint_type = "hugging-face"
        embedding_endpoint = None

        # get endpoint (re-prompt until a valid address is given, mirroring the LLM endpoint check)
        while not embedding_endpoint:
            embedding_endpoint = questionary.text("Enter default endpoint:").ask()
            if "http://" not in embedding_endpoint and "https://" not in embedding_endpoint:
                typer.secho("Endpoint must be a valid address", fg=typer.colors.YELLOW)
                embedding_endpoint = None

        # get model type
        default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"
        embedding_model = questionary.text(
            "Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
            default=default_embedding_model,
        ).ask()

        # get model dimensions
        default_embedding_dim = config.embedding_dim if config.embedding_dim else "1024"
        embedding_dim = questionary.text("Enter embedding model dimensions (e.g. 1024):", default=str(default_embedding_dim)).ask()
        try:
            embedding_dim = int(embedding_dim)
        except Exception:
            raise ValueError(f"Failed to cast {embedding_dim} to integer.")
    else:  # local models
        embedding_endpoint_type = "local"
        embedding_endpoint = None
        embedding_dim = 384

    return embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model


def configure_cli(config: MemGPTConfig):
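    """Prompt for the default preset, persona, and human used by the CLI."""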
    # set: preset, default_persona, default_human, default_agent
    from memgpt.presets.presets import preset_options

    # preset
    default_preset = config.preset if config.preset and config.preset in preset_options else None
    preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()

    # persona
    personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
    default_persona = config.persona if config.persona and config.persona in personas else None
    persona = questionary.select("Select default persona:", personas, default=default_persona).ask()

    # human
    humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
    default_human = config.human if config.human and config.human in humans else None
    human = questionary.select("Select default human:", humans, default=default_human).ask()

    # TODO: figure out if we should set a default agent or not
    agent = None

    return preset, persona, human, agent


def configure_archival_storage(config: MemGPTConfig):
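    """Prompt for the archival storage backend (local, lancedb, postgres, or chroma) and its URI/path."""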
    # Configure archival storage backend
    archival_storage_options = ["local", "lancedb", "postgres", "chroma"]
    archival_storage_type = questionary.select(
        "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
    ).ask()
    archival_storage_uri, archival_storage_path = None, None

    # configure postgres
    if archival_storage_type == "postgres":
        archival_storage_uri = questionary.text(
            "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
            default=config.archival_storage_uri if config.archival_storage_uri else "",
        ).ask()

    # configure lancedb
    if archival_storage_type == "lancedb":
        archival_storage_uri = questionary.text(
            "Enter lancedb connection string (e.g. ./.lancedb):",
            default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
        ).ask()

    # configure chroma
    if archival_storage_type == "chroma":
        chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="http").ask()
        if chroma_type == "http":
            archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask()
        if chroma_type == "persistent":
            default_archival_storage_path = (
                config.archival_storage_path if config.archival_storage_path else os.path.join(config.config_path, "chroma")
            )
            archival_storage_path = questionary.text("Enter persistent storage location:", default=default_archival_storage_path).ask()

    return archival_storage_type, archival_storage_uri, archival_storage_path


# TODO: allow configuring embedding model


@app.command()
def configure():
    """Updates default MemGPT configurations"""

    # check credentials
    openai_key = get_openai_credentials()
    azure_creds = get_azure_credentials()

    MemGPTConfig.create_config_dir()

    # Will pre-populate with defaults, or what the user previously set
    config = MemGPTConfig.load()
    try:
        model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
        model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
        embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
        default_preset, default_persona, default_human, default_agent = configure_cli(config)
        archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
    except ValueError as e:
        # print the error from a missing key cleanly, without a stacktrace
        typer.secho(str(e), fg=typer.colors.RED)
        return

    config = MemGPTConfig(
        # model configs
        model=model,
        model_endpoint=model_endpoint,
        model_endpoint_type=model_endpoint_type,
        model_wrapper=model_wrapper,
        context_window=context_window,
        # embedding configs
        embedding_endpoint_type=embedding_endpoint_type,
        embedding_endpoint=embedding_endpoint,
        embedding_dim=embedding_dim,
        embedding_model=embedding_model,
        # cli configs
        preset=default_preset,
        persona=default_persona,
        human=default_human,
        agent=default_agent,
        # credentials
        openai_key=openai_key,
        azure_key=azure_creds["azure_key"],
        azure_endpoint=azure_creds["azure_endpoint"],
        azure_version=azure_creds["azure_version"],
        azure_deployment=azure_creds["azure_deployment"],  # OK if None
        azure_embedding_deployment=azure_creds["azure_embedding_deployment"],  # OK if None
        # storage
        archival_storage_type=archival_storage_type,
        archival_storage_uri=archival_storage_uri,
        archival_storage_path=archival_storage_path,
    )
    typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
    config.save()


@app.command()
def list(option: str):
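    """List agents, humans, personas, or data sources."""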
    if option == "agents":
        # List all agents
        table = PrettyTable()
        table.field_names = ["Name", "Model", "Persona", "Human", "Data Source", "Create Time"]
        for agent_file in utils.list_agent_config_files():
            agent_name = os.path.basename(agent_file).replace(".json", "")
            agent_config = AgentConfig.load(agent_name)
            table.add_row(
                [
                    agent_name,
                    agent_config.model,
                    agent_config.persona,
                    agent_config.human,
                    ",".join(agent_config.data_sources),
                    agent_config.create_time,
                ]
            )
        print(table)
    elif option == "humans":
        # List all humans
        table = PrettyTable()
        table.field_names = ["Name", "Text"]
        for human_file in utils.list_human_files():
            with open(human_file, "r") as f:
                text = f.read()
            name = os.path.basename(human_file).replace(".txt", "")
            table.add_row([name, text])
        print(table)
    elif option == "personas":
        # List all personas
        table = PrettyTable()
        table.field_names = ["Name", "Text"]
        for persona_file in utils.list_persona_files():
            with open(persona_file, "r") as f:
                text = f.read()
            name = os.path.basename(persona_file).replace(".txt", "")
            table.add_row([name, text])
        print(table)
    elif option == "sources":
        # List all data sources
        table = PrettyTable()
        table.field_names = ["Name", "Location", "Agents"]
        config = MemGPTConfig.load()
        # TODO: eventually look across all storage connections
        # TODO: add data source stats
        source_to_agents = {}
        for agent_file in utils.list_agent_config_files():
            agent_name = os.path.basename(agent_file).replace(".json", "")
            agent_config = AgentConfig.load(agent_name)
            for ds in agent_config.data_sources:
                if ds in source_to_agents:
                    source_to_agents[ds].append(agent_name)
                else:
                    source_to_agents[ds] = [agent_name]
        for data_source in StorageConnector.list_loaded_data():
            location = config.archival_storage_type
            agents = ",".join(source_to_agents[data_source]) if data_source in source_to_agents else ""
            table.add_row([data_source, location, agents])
        print(table)
    else:
        raise ValueError(f"Unknown option {option}")


@app.command()
def add(
    option: str,  # [human, persona]
    name: str = typer.Option(help="Name of human/persona"),
    text: str = typer.Option(None, help="Text of human/persona"),
    filename: str = typer.Option(None, "-f", help="Specify filename"),
):
    """Add a persona/human"""

    if option == "persona":
        directory = os.path.join(MEMGPT_DIR, "personas")
    elif option == "human":
        directory = os.path.join(MEMGPT_DIR, "humans")
    else:
        raise ValueError(f"Unknown kind {option}")

    if filename:
        assert text is None, "Cannot provide both filename and text"
        # copy file to directory
        shutil.copyfile(filename, os.path.join(directory, name))
    if text:
        assert filename is None, "Cannot provide both filename and text"
        # write text to file
        with open(os.path.join(directory, name), "w") as f:
            f.write(text)