import typer
import json
import sys
import io
import logging
import questionary

from llama_index import set_global_service_context
from llama_index import ServiceContext

from memgpt.interface import CLIInterface as interface  # for printing to terminal
from memgpt.cli.cli_config import configure
import memgpt.presets.presets as presets
import memgpt.utils as utils
from memgpt.utils import printd
from memgpt.persistence_manager import LocalStateManager
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.agent import Agent
from memgpt.embeddings import embedding_model


def run(
    persona: str = typer.Option(None, help="Specify persona"),
    agent: str = typer.Option(None, help="Specify agent save file"),
    human: str = typer.Option(None, help="Specify human"),
    preset: str = typer.Option(None, help="Specify preset"),
    # model flags
    model: str = typer.Option(None, help="Specify the LLM model"),
    model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
    model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
    model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
    context_window: int = typer.Option(None, help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"),
    # other
    first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
    strip_ui: bool = typer.Option(False, help="Remove all the bells and whistles in CLI output (helpful for testing)"),
    debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
    no_verify: bool = typer.Option(False, help="Bypass message verification"),
    yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
):
"""Start chatting with an MemGPT agent
|
|
|
|
Example usage: `memgpt run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo`
|
|
|
|
:param persona: Specify persona
|
|
:param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name)
|
|
:param human: Specify human
|
|
:param model: Specify the LLM model
|
|
|
|
"""
|
|
|
|
    # setup logger
    utils.DEBUG = debug
    logging.getLogger().setLevel(logging.CRITICAL)
    if debug:
        logging.getLogger().setLevel(logging.DEBUG)

    if not MemGPTConfig.exists():  # if no config, run configure
        if yes:
            # use defaults
            config = MemGPTConfig()
        else:
            # use input
            configure()
            config = MemGPTConfig.load()
    else:  # load config
        config = MemGPTConfig.load()

    # force re-configuration if config is from an old version
    if config.memgpt_version is None:  # TODO: eventually add checks for older versions, if config changes again
        typer.secho("MemGPT has been updated to a newer version, so re-running configuration.", fg=typer.colors.YELLOW)
        configure()
        config = MemGPTConfig.load()

    # override with command line arguments
    if debug:
        config.debug = debug
    if no_verify:
        config.no_verify = no_verify

    # determine agent to use, if not provided
    if not yes and not agent:
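        # interactive flow: offer to resume an existing agent if none was specified on the command line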
        agent_files = utils.list_agent_config_files()
        agents = [AgentConfig.load(f).name for f in agent_files]

        if len(agents) > 0 and not any([persona, human, model]):
            select_agent = questionary.confirm("Would you like to select an existing agent?").ask()
            if select_agent:
                agent = questionary.select("Select agent:", choices=agents).ask()

    # configure llama index
    config = MemGPTConfig.load()
    original_stdout = sys.stdout  # unfortunate hack required to suppress confusing print statements from llama index
    sys.stdout = io.StringIO()
    embed_model = embedding_model()
    service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size)
    set_global_service_context(service_context)
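    # any index built or loaded later in this process will use this embedding model
    # (and no LLM), since llama_index falls back to the global service context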
    sys.stdout = original_stdout

    # create agent config
    if agent and AgentConfig.exists(agent):  # use existing agent
        typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
        agent_config = AgentConfig.load(agent)
        printd("State path:", agent_config.save_state_dir())
        printd("Persistent manager path:", agent_config.save_persistence_manager_dir())
        printd("Index path:", agent_config.save_agent_index_dir())
        # persistence_manager = LocalStateManager(agent_config).load()  # TODO: implement load
        # TODO: load prior agent state
        if persona and persona != agent_config.persona:
            typer.secho(f"Warning: Overriding existing persona {agent_config.persona} with {persona}", fg=typer.colors.YELLOW)
            agent_config.persona = persona
            # raise ValueError(f"Cannot override {agent_config.name} existing persona {agent_config.persona} with {persona}")
        if human and human != agent_config.human:
            typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
            agent_config.human = human
            # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")

        # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
        if model and model != agent_config.model:
            typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
            agent_config.model = model
        if context_window is not None and int(context_window) != agent_config.context_window:
            typer.secho(
                f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
            )
            agent_config.context_window = context_window
        if model_wrapper and model_wrapper != agent_config.model_wrapper:
            typer.secho(
                f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
            )
            agent_config.model_wrapper = model_wrapper
        if model_endpoint and model_endpoint != agent_config.model_endpoint:
            typer.secho(
                f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
            )
            agent_config.model_endpoint = model_endpoint
        if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
            typer.secho(
                f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
                fg=typer.colors.YELLOW,
            )
            agent_config.model_endpoint_type = model_endpoint_type

        # Update the agent config with any overrides
        agent_config.save()

        # load existing agent
        memgpt_agent = Agent.load_agent(interface, agent_config)
    else:  # create new agent
        # create new agent config: override defaults with args if provided
        typer.secho("Creating new agent...", fg=typer.colors.GREEN)
        agent_config = AgentConfig(
            name=agent,
            persona=persona,
            human=human,
            preset=preset,
            model=model,
            model_wrapper=model_wrapper,
            model_endpoint_type=model_endpoint_type,
            model_endpoint=model_endpoint,
            context_window=context_window,
        )

        # TODO: allow configurable state manager (only local is supported right now)
        persistence_manager = LocalStateManager(agent_config)  # TODO: insert dataset/pre-fill

        # save new agent config
        agent_config.save()
        typer.secho(f"Created new agent {agent_config.name}.", fg=typer.colors.GREEN)

        # create agent
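        # the preset determines the system prompt and function set the new agent starts with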
        memgpt_agent = presets.use_preset(
            agent_config.preset,
            agent_config,
            agent_config.model,
            utils.get_persona_text(agent_config.persona),
            utils.get_human_text(agent_config.human),
            interface,
            persistence_manager,
        )

    # pretty print agent config
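    # (printd only emits output when debug mode is enabled)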
    printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))

    # start event loop
    from memgpt.main import run_agent_loop

    run_agent_loop(memgpt_agent, first, no_verify, config)  # TODO: add back no_verify


def attach(
    agent: str = typer.Option(help="Specify agent to attach data to"),
    data_source: str = typer.Option(help="Data source to attach to agent"),
):
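    """Load the data contained in a data source into the agent's archival memory.

    :param agent: Specify agent to attach data to
    :param data_source: Data source to attach to agent
    """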
    from memgpt.connectors.storage import StorageConnector
    from tqdm import tqdm

    agent_config = AgentConfig.load(agent)

    # get storage connectors
    source_storage = StorageConnector.get_storage_connector(name=data_source)
    dest_storage = StorageConnector.get_storage_connector(agent_config=agent_config)

    size = source_storage.size()
    typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
    page_size = 100
    generator = source_storage.get_all_paginated(page_size=page_size)  # yields List[Passage]
    for _ in tqdm(range(0, size, page_size)):
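        # pull one page of passages from the source and bulk-insert it into the agent's store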
        passages = next(generator)
        dest_storage.insert_many(passages)

    # save destination storage
    dest_storage.save()

    total_agent_passages = dest_storage.size()

    typer.secho(
        f"Attached data source {data_source} to agent {agent}, consisting of {size} passages. Agent now has {total_agent_passages} embeddings in archival memory.",
        fg=typer.colors.GREEN,
    )


def version():
    import memgpt

    print(memgpt.__version__)
    return memgpt.__version__