Refactor config + determine LLM via config.model_endpoint_type (#422)

* mark deprecated API section

* CLI bug fixes for azure

* check azure before running

* Update README.md

* Update README.md

* bug fix with persona loading

* remove print

* make errors for cli flags more clear

* format

* fix imports

* fix imports

* add prints

* update lock

* update config fields

* cleanup config loading

* commit

* remove asserts

* refactor configure

* put into different functions

* add embedding default

* pass in config

* fixes

* allow overriding openai embedding endpoint

* black

* trying to patch tests (some circular import errors)

* update flags and docs

* patched support for local llms using endpoint and endpoint type passed via configs, not env vars

* missing files

* fix naming

* fix import

* fix two runtime errors

* patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included

* disable debug messages

* made error message for failed load more informative

* don't print dynamic linking function warning unless --debug

* updated tests to work with new cli workflow (disabled openai config test for now)

* added skips for tests when vars are missing

* update bad arg

* revise test to soft pass on empty string too

* don't run configure twice

* extend timeout (try to pass against nltk download)

* update defaults

* typo with endpoint type default

* patch runtime errors for when model is None

* catching another case of 'x in model' when model is None (preemptively)

* allow overrides to local llm related config params

* made model wrapper selection from a list vs raw input

* update test for select instead of input

* Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry

* updated error messages to be more informative with links to readthedocs

* add back gpt3.5-turbo

---------

Co-authored-by: cpacker <packercharles@gmail.com>
Authored by Sarah Wooders on 2023-11-14 15:58:19 -08:00 (committed by GitHub)
parent 8fdc3a29da
commit 28514da5df
23 changed files with 628 additions and 435 deletions


@@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove
 * `--agent`: (str) Name of agent to create or to resume chatting with.
 * `--human`: (str) Name of the human to run the agent with.
 * `--persona`: (str) Name of agent persona to use.
-* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
+* `--model`: (str) LLM model to run (e.g. `gpt-4`, `dolphin_xxx`)
 * `--preset`: (str) MemGPT preset to run agent with.
 * `--first`: (str) Allow user to sent the first message.
 * `--debug`: (bool) Show debug logs (default=False)
 * `--no-verify`: (bool) Bypass message verification (default=False)
 * `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)
+
+You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs:
+* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`)
+* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama)
+* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`)
+* `--context-window`: (int) Size of model context window (specific to model type)
 
 #### Updating the config location
-You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`:
+You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`:
 ```
 export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory
 ```
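As a side note on the `MEMGPT_CONFIG_PATH` override described above, the resolution logic amounts to "use the env var if set, otherwise fall back to the default config file". A minimal sketch, assuming the usual `~/.memgpt` default directory; the helper name is illustrative and not taken from this diff:

```
import os

# assumed default location; MemGPT keeps this in its MEMGPT_DIR constant
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")


def resolve_config_path() -> str:
    # MEMGPT_CONFIG_PATH must point at a file, not a directory
    override = os.getenv("MEMGPT_CONFIG_PATH")
    return override if override else os.path.join(MEMGPT_DIR, "config")
```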


@@ -9,6 +9,7 @@ from memgpt.config import AgentConfig
 from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
 from .memory import CoreMemory as Memory, summarize_messages
 from .openai_tools import completions_with_backoff as create
+from memgpt.openai_tools import chat_completion_with_backoff
 from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
 from .constants import (
     FIRST_MESSAGE_ATTEMPTS,
@@ -73,7 +74,7 @@ def initialize_message_sequence(
     first_user_message = get_login_event()  # event letting MemGPT know the user just logged in

     if include_initial_boot_message:
-        if "gpt-3.5" in model:
+        if model is not None and "gpt-3.5" in model:
             initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
         else:
             initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
@@ -96,37 +97,6 @@ def initialize_message_sequence(
     return messages

-def get_ai_reply(
-    model,
-    message_sequence,
-    functions,
-    function_call="auto",
-    context_window=None,
-):
-    try:
-        response = create(
-            model=model,
-            context_window=context_window,
-            messages=message_sequence,
-            functions=functions,
-            function_call=function_call,
-        )
-
-        # special case for 'length'
-        if response.choices[0].finish_reason == "length":
-            raise Exception("Finish reason was length (maximum context length)")
-
-        # catches for soft errors
-        if response.choices[0].finish_reason not in ["stop", "function_call"]:
-            raise Exception(f"API call finish with bad finish reason: {response}")
-
-        # unpack with response.choices[0].message.content
-        return response
-    except Exception as e:
-        raise e
-
 class Agent(object):
     def __init__(
         self,
@@ -310,7 +280,7 @@ class Agent(object):
         json_files = glob.glob(os.path.join(directory, "*.json"))  # This will list all .json files in the current directory.
         if not json_files:
             print(f"/load error: no .json checkpoint files found")
-            raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
+            raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}")

         # Sort files based on modified timestamp, with the latest file being the first.
         filename = max(json_files, key=os.path.getmtime)
@@ -360,7 +330,7 @@ class Agent(object):
                 # NOTE to handle old configs, instead of erroring here let's just warn
                 # raise ValueError(error_message)
-                print(error_message)
+                printd(error_message)
             linked_function_set[f_name] = linked_function

         messages = state["messages"]
@@ -602,8 +572,7 @@ class Agent(object):
            printd(f"This is the first message. Running extra verifier on AI response.")
            counter = 0
            while True:
-                response = get_ai_reply(
-                    model=self.model,
+                response = self.get_ai_reply(
                    message_sequence=input_message_sequence,
                    functions=self.functions,
                    context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -616,8 +585,7 @@ class Agent(object):
                    raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")

        else:
-            response = get_ai_reply(
-                model=self.model,
+            response = self.get_ai_reply(
                message_sequence=input_message_sequence,
                functions=self.functions,
                context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -785,3 +753,55 @@ class Agent(object):
         # Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
         elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
         return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60
+
+    def get_ai_reply(
+        self,
+        message_sequence,
+        function_call="auto",
+    ):
+        """Get response from LLM API"""
+        # TODO: Legacy code - delete
+        if self.config is None:
+            try:
+                response = create(
+                    model=self.model,
+                    context_window=self.context_window,
+                    messages=message_sequence,
+                    functions=self.functions,
+                    function_call=function_call,
+                )
+
+                # special case for 'length'
+                if response.choices[0].finish_reason == "length":
+                    raise Exception("Finish reason was length (maximum context length)")
+
+                # catches for soft errors
+                if response.choices[0].finish_reason not in ["stop", "function_call"]:
+                    raise Exception(f"API call finish with bad finish reason: {response}")
+
+                # unpack with response.choices[0].message.content
+                return response
+            except Exception as e:
+                raise e
+
+        try:
+            response = chat_completion_with_backoff(
+                agent_config=self.config,
+                model=self.model,  # TODO: remove (is redundant)
+                messages=message_sequence,
+                functions=self.functions,
+                function_call=function_call,
+            )
+            # special case for 'length'
+            if response.choices[0].finish_reason == "length":
+                raise Exception("Finish reason was length (maximum context length)")
+
+            # catches for soft errors
+            if response.choices[0].finish_reason not in ["stop", "function_call"]:
+                raise Exception(f"API call finish with bad finish reason: {response}")
+
+            # unpack with response.choices[0].message.content
+            return response
+        except Exception as e:
+            raise e
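To summarize the control flow introduced above: `get_ai_reply` is now a method on `Agent`, and it routes the request either through the legacy `create` helper (when no config is attached) or through the new `chat_completion_with_backoff`, which picks the backend from `config.model_endpoint_type`. A condensed, free-function sketch of that dispatch; the imports are the ones used in this diff, but the standalone form and the collapsed error handling are for illustration only:

```
from memgpt.openai_tools import completions_with_backoff as create
from memgpt.openai_tools import chat_completion_with_backoff


def get_reply(agent, message_sequence, function_call="auto"):
    """Condensed view of Agent.get_ai_reply (illustrative)."""
    if agent.config is None:
        # legacy path: model and context window are passed explicitly
        response = create(
            model=agent.model,
            context_window=agent.context_window,
            messages=message_sequence,
            functions=agent.functions,
            function_call=function_call,
        )
    else:
        # new path: endpoint and endpoint type come from the agent's config
        response = chat_completion_with_backoff(
            agent_config=agent.config,
            model=agent.model,  # redundant with config, per the TODO above
            messages=message_sequence,
            functions=agent.functions,
            function_call=function_call,
        )

    # the real method repeats these checks in both branches; they are collapsed here
    if response.choices[0].finish_reason not in ["stop", "function_call"]:
        raise Exception(f"API call finished with bad finish reason: {response}")
    return response
```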


@@ -1,4 +1,5 @@
 import typer
+import json
 import sys
 import io
 import logging
@@ -35,16 +36,21 @@ def run(
     persona: str = typer.Option(None, help="Specify persona"),
     agent: str = typer.Option(None, help="Specify agent save file"),
     human: str = typer.Option(None, help="Specify human"),
-    model: str = typer.Option(None, help="Specify the LLM model"),
     preset: str = typer.Option(None, help="Specify preset"),
+    # model flags
+    model: str = typer.Option(None, help="Specify the LLM model"),
+    model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
+    model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
+    model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
+    context_window: int = typer.Option(
+        None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
+    ),
+    # other
     first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
     strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
     debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
     no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
     yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
-    context_window: int = typer.Option(
-        None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
-    ),
 ):
     """Start chatting with an MemGPT agent
@@ -99,11 +105,6 @@ def run(
         set_global_service_context(service_context)
         sys.stdout = original_stdout

-    # overwrite the context_window if specified
-    if context_window is not None and int(context_window) != int(config.context_window):
-        typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
-        config.context_window = str(context_window)
-
     # create agent config
     if agent and AgentConfig.exists(agent):  # use existing agent
         typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@@ -121,10 +122,34 @@ def run(
            typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
            agent_config.human = human
            # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
+
+        # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
        if model and model != agent_config.model:
            typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
            agent_config.model = model
-            # raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
+        if context_window is not None and int(context_window) != agent_config.context_window:
+            typer.secho(
+                f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
+            )
+            agent_config.context_window = context_window
+        if model_wrapper and model_wrapper != agent_config.model_wrapper:
+            typer.secho(
+                f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
+            )
+            agent_config.model_wrapper = model_wrapper
+        if model_endpoint and model_endpoint != agent_config.model_endpoint:
+            typer.secho(
+                f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
+            )
+            agent_config.model_endpoint = model_endpoint
+        if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
+            typer.secho(
+                f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
+                fg=typer.colors.YELLOW,
+            )
+            agent_config.model_endpoint_type = model_endpoint_type
+
+        # Update the agent config with any overrides
        agent_config.save()

        # load existing agent
@@ -133,17 +158,17 @@ def run(

        # create new agent config: override defaults with args if provided
        typer.secho("Creating new agent...", fg=typer.colors.GREEN)
        agent_config = AgentConfig(
-            name=agent if agent else None,
-            persona=persona if persona else config.default_persona,
-            human=human if human else config.default_human,
-            model=model if model else config.model,
-            context_window=context_window if context_window else config.context_window,
-            preset=preset if preset else config.preset,
+            name=agent,
+            persona=persona,
+            human=human,
+            preset=preset,
+            model=model,
+            model_wrapper=model_wrapper,
+            model_endpoint_type=model_endpoint_type,
+            model_endpoint=model_endpoint,
+            context_window=context_window,
        )

-        ## attach data source to agent
-        # agent_config.attach_data_source(data_source)
-
        # TODO: allow configrable state manager (only local is supported right now)
        persistence_manager = LocalStateManager(agent_config)  # TODO: insert dataset/pre-fill
@@ -162,6 +187,9 @@ def run(
        persistence_manager,
    )

+    # pretty print agent config
+    printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))
+
    # start event loop
    from memgpt.main import run_agent_loop
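The four model-related override blocks in `run()` above all repeat the same warn-then-assign pattern. A hypothetical helper that captures the pattern (not part of this diff, shown only to make the logic explicit):

```
import typer


def maybe_override(agent_config, field: str, new_value):
    """Warn and override a single AgentConfig field if the matching CLI flag was passed (illustrative)."""
    if new_value is None:
        return  # flag not set, keep the stored value
    old_value = getattr(agent_config, field)
    if new_value != old_value:
        typer.secho(f"Warning: Overriding existing {field} {old_value} with {new_value}", fg=typer.colors.YELLOW)
        setattr(agent_config, field, new_value)


# equivalent to the explicit blocks in run():
# maybe_override(agent_config, "model", model)
# maybe_override(agent_config, "model_wrapper", model_wrapper)
# maybe_override(agent_config, "model_endpoint", model_endpoint)
# maybe_override(agent_config, "model_endpoint_type", model_endpoint_type)
```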


@@ -1,3 +1,4 @@
import builtins
import questionary import questionary
import openai import openai
from prettytable import PrettyTable from prettytable import PrettyTable
@@ -11,126 +12,118 @@ from memgpt import utils
import memgpt.humans.humans as humans import memgpt.humans.humans as humans
import memgpt.personas.personas as personas import memgpt.personas.personas as personas
from memgpt.config import MemGPTConfig, AgentConfig from memgpt.config import MemGPTConfig, AgentConfig, Config
from memgpt.constants import MEMGPT_DIR from memgpt.constants import MEMGPT_DIR
from memgpt.connectors.storage import StorageConnector from memgpt.connectors.storage import StorageConnector
from memgpt.constants import LLM_MAX_TOKENS from memgpt.constants import LLM_MAX_TOKENS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
app = typer.Typer() app = typer.Typer()
@app.command() def get_azure_credentials():
def configure():
"""Updates default MemGPT configurations"""
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
MemGPTConfig.create_config_dir()
# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
# openai credentials
use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask()
if use_openai:
# search for key in enviornment
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
# TODO: eventually stop relying on env variables and pass in keys explicitly
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
use_azure_deployment_ids = False
if use_azure:
# search for key in enviornment
azure_key = os.getenv("AZURE_OPENAI_KEY") azure_key = os.getenv("AZURE_OPENAI_KEY")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_version = os.getenv("AZURE_OPENAI_VERSION") azure_version = os.getenv("AZURE_OPENAI_VERSION")
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT") azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
return azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment
if all([azure_key, azure_endpoint, azure_version]):
print(f"Using Microsoft endpoint {azure_endpoint}.")
if all([azure_deployment, azure_embedding_deployment]):
print(f"Using deployment id {azure_deployment}")
use_azure_deployment_ids = True
# configure openai def get_openai_credentials():
openai.api_type = "azure" openai_key = os.getenv("OPENAI_API_KEY")
openai.api_key = azure_key return openai_key
openai.api_base = azure_endpoint
openai.api_version = azure_version
else:
print("Missing enviornment variables for Azure. Please set then run `memgpt configure` again.")
# TODO: allow for manual setting
use_azure = False
# TODO: configure local model
# configure provider def configure_llm_endpoint(config: MemGPTConfig):
model_endpoint_options = [] # configure model endpoint
if os.getenv("OPENAI_API_BASE") is not None: model_endpoint_type, model_endpoint = None, None
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
if use_openai: # get default
model_endpoint_options += ["openai"] default_model_endpoint_type = config.model_endpoint_type
if use_azure: if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]: # local model
model_endpoint_options += ["azure"] default_model_endpoint_type = "local"
assert (
len(model_endpoint_options) > 0 provider = questionary.select(
), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server." "Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
valid_default_model = config.model_endpoint in model_endpoint_options
default_endpoint = questionary.select(
"Select default inference endpoint:",
model_endpoint_options,
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
).ask() ).ask()
# configure embedding provider # set: model_endpoint_type, model_endpoint
embedding_endpoint_options = [] if provider == "openai":
if use_azure: model_endpoint_type = "openai"
embedding_endpoint_options += ["azure"] model_endpoint = "https://api.openai.com/v1"
if use_openai: model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
embedding_endpoint_options += ["openai"] provider = "openai"
embedding_endpoint_options += ["local"] elif provider == "azure":
valid_default_embedding = config.embedding_model in embedding_endpoint_options model_endpoint_type = "azure"
# determine the default selection in a smart way _, model_endpoint, _, _, _ = get_azure_credentials()
if "openai" in embedding_endpoint_options and default_endpoint == "openai": else: # local models
# openai llm -> openai embeddings backend_options = ["webui", "llamacpp", "koboldcpp", "ollama", "lmstudio", "openai"]
default_embedding_endpoint_default = "openai" default_model_endpoint_type = None
elif default_endpoint not in ["openai", "azure"]: # is local if config.model_endpoint_type in backend_options:
# local llm -> local embeddings # set from previous config
default_embedding_endpoint_default = "local" default_model_endpoint_type = config.model_endpoint_type
else: else:
default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1] # set form env variable (ok if none)
default_embedding_endpoint = questionary.select( default_model_endpoint_type = os.getenv("BACKEND_TYPE")
"Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default model_endpoint_type = questionary.select(
"Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):",
backend_options,
default=default_model_endpoint_type,
).ask() ).ask()
# configure embedding dimentions # set default endpoint
default_embedding_dim = config.embedding_dim # if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
if default_embedding_endpoint == "local": default_model_endpoint = os.getenv("OPENAI_API_BASE")
# HF model uses lower dimentionality # if OPENAI_API_BASE is not set, try to pull a default IP+port format from a hardcoded set
default_embedding_dim = 384 if default_model_endpoint is None:
if model_endpoint_type in DEFAULT_ENDPOINTS:
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
else:
# default_model_endpoint = None
model_endpoint = None
while not model_endpoint:
model_endpoint = questionary.text("Enter default endpoint:").ask()
if "http://" not in model_endpoint and "https://" not in model_endpoint:
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = None
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
# configure preset return model_endpoint_type, model_endpoint
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()
# default model
if use_openai or use_azure: def configure_model(config: MemGPTConfig, model_endpoint_type: str):
model_options = [] # set: model, model_wrapper
if use_openai: model, model_wrapper = None, None
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"] if model_endpoint_type == "openai" or model_endpoint_type == "azure":
model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
# TODO: select
valid_model = config.model in model_options valid_model = config.model in model_options
default_model = questionary.select( model = questionary.select(
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0] "Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
).ask() ).ask()
else: else: # local models
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint # ollama also needs model type
if model_endpoint_type == "ollama":
default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL
model = questionary.text(
"Enter default model name (required for Ollama, see: https://memgpt.readthedocs.io/en/latest/ollama):",
default=default_model,
).ask()
model = None if len(model) == 0 else model
# get the max tokens (context window) for the model # model wrapper
if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS: available_model_wrappers = builtins.list(get_available_wrappers().keys())
model_wrapper = questionary.select(
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
choices=available_model_wrappers,
default=DEFAULT_WRAPPER_NAME,
).ask()
# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length # Ask the user to specify the context length
context_length_options = [ context_length_options = [
str(2**12), # 4096 str(2**12), # 4096
@@ -140,46 +133,80 @@ def configure():
str(2**18), # 262144 str(2**18), # 262144
"custom", # enter yourself "custom", # enter yourself
] ]
default_model_context_window = questionary.select( context_window = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):", "Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options, choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]), default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask() ).ask()
# If custom, ask for input # If custom, ask for input
if default_model_context_window == "custom": if context_window == "custom":
while True: while True:
default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask() context_window = questionary.text("Enter context window (e.g. 8192)").ask()
try: try:
default_model_context_window = int(default_model_context_window) context_window = int(context_window)
break break
except ValueError: except ValueError:
print(f"Context window must be a valid integer") print(f"Context window must be a valid integer")
else: else:
default_model_context_window = int(default_model_context_window) context_window = int(context_window)
else: else:
# Pull the context length from the models # Pull the context length from the models
default_model_context_window = LLM_MAX_TOKENS[default_model] context_window = LLM_MAX_TOKENS[model]
return model, model_wrapper, context_window
# defaults
def configure_embedding_endpoint(config: MemGPTConfig):
# configure embedding endpoint
default_embedding_endpoint_type = config.embedding_endpoint_type
if config.embedding_endpoint_type is not None and config.embedding_endpoint_type not in ["openai", "azure"]: # local model
default_embedding_endpoint_type = "local"
embedding_endpoint_type, embedding_endpoint, embedding_dim = None, None, None
embedding_provider = questionary.select(
"Select embedding provider:", choices=["openai", "azure", "local"], default=default_embedding_endpoint_type
).ask()
if embedding_provider == "openai":
embedding_endpoint_type = "openai"
embedding_endpoint = "https://api.openai.com/v1"
embedding_dim = 1536
elif embedding_provider == "azure":
embedding_endpoint_type = "azure"
_, _, _, _, embedding_endpoint = get_azure_credentials()
embedding_dim = 1536
else: # local models
embedding_endpoint_type = "local"
embedding_endpoint = None
embedding_dim = 384
return embedding_endpoint_type, embedding_endpoint, embedding_dim
def configure_cli(config: MemGPTConfig):
# set: preset, default_persona, default_human, default_agent``
from memgpt.presets.presets import preset_options
# preset
default_preset = config.preset if config.preset and config.preset in preset_options else None
preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()
# persona
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()] personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
# print(personas) default_persona = config.persona if config.persona and config.persona in personas else None
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask() persona = questionary.select("Select default persona:", personas, default=default_persona).ask()
# human
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()] humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
# print(humans) default_human = config.human if config.human and config.human in humans else None
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask() human = questionary.select("Select default human:", humans, default=default_human).ask()
# TODO: figure out if we should set a default agent or not # TODO: figure out if we should set a default agent or not
default_agent = None agent = None
# agents = [os.path.basename(f).replace(".json", "") for f in utils.list_agent_config_files()]
# if len(agents) > 0: # agents have been created
# default_agent = questionary.select(
# "Select default agent:",
# agents
# ).ask()
# else:
# default_agent = None
return preset, persona, human, agent
def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend # Configure archival storage backend
archival_storage_options = ["local", "postgres"] archival_storage_options = ["local", "postgres"]
archival_storage_type = questionary.select( archival_storage_type = questionary.select(
@@ -191,25 +218,65 @@
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):", "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "", default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask() ).ask()
return archival_storage_type, archival_storage_uri
# TODO: allow configuring embedding model
@app.command()
def configure():
"""Updates default MemGPT configurations"""
MemGPTConfig.create_config_dir()
# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
# check credentials
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
openai_key = get_openai_credentials()
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
if all([azure_key, azure_endpoint, azure_version]):
print(f"Using Microsoft endpoint {azure_endpoint}.")
if all([azure_deployment, azure_embedding_deployment]):
print(f"Using deployment id {azure_deployment}")
else:
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
)
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
if not openai_key:
raise ValueError(
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
)
config = MemGPTConfig( config = MemGPTConfig(
model=default_model, # model configs
context_window=default_model_context_window, model=model,
model_endpoint=model_endpoint,
model_endpoint_type=model_endpoint_type,
model_wrapper=model_wrapper,
context_window=context_window,
# embedding configs
embedding_endpoint_type=embedding_endpoint_type,
embedding_endpoint=embedding_endpoint,
embedding_dim=embedding_dim,
# cli configs
preset=default_preset, preset=default_preset,
model_endpoint=default_endpoint, persona=default_persona,
embedding_model=default_embedding_endpoint, human=default_human,
embedding_dim=default_embedding_dim, agent=default_agent,
default_persona=default_persona, # credentials
default_human=default_human, openai_key=openai_key,
default_agent=default_agent, azure_key=azure_key,
openai_key=openai_key if use_openai else None, azure_endpoint=azure_endpoint,
azure_key=azure_key if use_azure else None, azure_version=azure_version,
azure_endpoint=azure_endpoint if use_azure else None, azure_deployment=azure_deployment,
azure_version=azure_version if use_azure else None, azure_embedding_deployment=azure_embedding_deployment,
azure_deployment=azure_deployment if use_azure_deployment_ids else None, # storage
azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None,
archival_storage_type=archival_storage_type, archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri, archival_storage_uri=archival_storage_uri,
) )
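With the helpers above in place, the new `configure()` body reduces to a straight-line composition: ask about the LLM endpoint, the model, the embedding endpoint, the CLI defaults, and archival storage, then check credentials and write a fresh `MemGPTConfig`. A trimmed sketch of that flow, assuming the helper functions shown above are importable alongside it; credential validation and the final save are omitted because they fall partly outside the hunks shown:

```
from memgpt.config import MemGPTConfig


def configure_sketch():
    """Trimmed outline of the refactored configure() flow (illustrative only)."""
    config = MemGPTConfig.load()  # pre-populated with defaults or previous answers

    model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
    model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
    embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config)
    preset, persona, human, agent = configure_cli(config)
    archival_storage_type, archival_storage_uri = configure_archival_storage(config)

    # credentials are still pulled from the environment, then persisted in the config
    openai_key = get_openai_credentials()
    azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()

    new_config = MemGPTConfig(
        model=model,
        model_endpoint=model_endpoint,
        model_endpoint_type=model_endpoint_type,
        model_wrapper=model_wrapper,
        context_window=context_window,
        embedding_endpoint_type=embedding_endpoint_type,
        embedding_endpoint=embedding_endpoint,
        embedding_dim=embedding_dim,
        preset=preset,
        persona=persona,
        human=human,
        agent=agent,
        openai_key=openai_key,
        azure_key=azure_key,
        azure_endpoint=azure_endpoint,
        azure_version=azure_version,
        azure_deployment=azure_deployment,
        azure_embedding_deployment=azure_embedding_deployment,
        archival_storage_type=archival_storage_type,
        archival_storage_uri=archival_storage_uri,
    )
    # presumably followed by new_config.save(); that line falls outside the hunks shown here
```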


@@ -16,6 +16,7 @@ from colorama import Fore, Style
from typing import List, Type from typing import List, Type
import memgpt
import memgpt.utils as utils import memgpt.utils as utils
from memgpt.interface import CLIInterface as interface from memgpt.interface import CLIInterface as interface
from memgpt.personas.personas import get_persona_text from memgpt.personas.personas import get_persona_text
@@ -40,6 +41,24 @@ model_choices = [
] ]
# helper functions for writing to configs
def get_field(config, section, field):
if section not in config:
return None
if config.has_option(section, field):
return config.get(section, field)
else:
return None
def set_field(config, section, field, value):
if value is None: # cannot write None
return
if section not in config: # create section
config.add_section(section)
config.set(section, field, value)
@dataclass @dataclass
class MemGPTConfig: class MemGPTConfig:
config_path: str = os.path.join(MEMGPT_DIR, "config") config_path: str = os.path.join(MEMGPT_DIR, "config")
@@ -49,9 +68,10 @@ class MemGPTConfig:
preset: str = DEFAULT_PRESET preset: str = DEFAULT_PRESET
# model parameters # model parameters
# provider: str = "openai" # openai, azure, local (TODO) model: str = None
model_endpoint: str = "openai" model_endpoint_type: str = None
model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local model_endpoint: str = None # localhost:8000
model_wrapper: str = None
context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"] context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
# model parameters: openai # model parameters: openai
@@ -65,12 +85,13 @@ class MemGPTConfig:
azure_embedding_deployment: str = None azure_embedding_deployment: str = None
# persona parameters # persona parameters
default_persona: str = personas.DEFAULT persona: str = personas.DEFAULT
default_human: str = humans.DEFAULT human: str = humans.DEFAULT
default_agent: str = None agent: str = None
# embedding parameters # embedding parameters
embedding_model: str = "openai" embedding_endpoint_type: str = "openai" # openai, azure, local
embedding_endpoint: str = None
embedding_dim: int = 1536 embedding_dim: int = 1536
embedding_chunk_size: int = 300 # number of tokens embedding_chunk_size: int = 300 # number of tokens
@@ -89,6 +110,12 @@ class MemGPTConfig:
persistence_manager_save_file: str = None # local file persistence_manager_save_file: str = None # local file
persistence_manager_uri: str = None # db URI persistence_manager_uri: str = None # db URI
def __post_init__(self):
# ensure types
self.embedding_chunk_size = int(self.embedding_chunk_size)
self.embedding_dim = int(self.embedding_dim)
self.context_window = int(self.context_window)
@staticmethod @staticmethod
def generate_uuid() -> str: def generate_uuid() -> str:
return uuid.UUID(int=uuid.getnode()).hex return uuid.UUID(int=uuid.getnode()).hex
@@ -104,72 +131,38 @@ class MemGPTConfig:
config_path = MemGPTConfig.config_path config_path = MemGPTConfig.config_path
if os.path.exists(config_path): if os.path.exists(config_path):
# read existing config
config.read(config_path) config.read(config_path)
config_dict = {
"model": get_field(config, "model", "model"),
"model_endpoint": get_field(config, "model", "model_endpoint"),
"model_endpoint_type": get_field(config, "model", "model_endpoint_type"),
"model_wrapper": get_field(config, "model", "model_wrapper"),
"context_window": get_field(config, "model", "context_window"),
"preset": get_field(config, "defaults", "preset"),
"persona": get_field(config, "defaults", "persona"),
"human": get_field(config, "defaults", "human"),
"agent": get_field(config, "defaults", "agent"),
"openai_key": get_field(config, "openai", "key"),
"azure_key": get_field(config, "azure", "key"),
"azure_endpoint": get_field(config, "azure", "endpoint"),
"azure_version": get_field(config, "azure", "version"),
"azure_deployment": get_field(config, "azure", "deployment"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
"embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"),
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
"archival_storage_type": get_field(config, "archival_storage", "type"),
"archival_storage_path": get_field(config, "archival_storage", "path"),
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
"anon_clientid": get_field(config, "client", "anon_clientid"),
"config_path": config_path,
}
config_dict = {k: v for k, v in config_dict.items() if v is not None}
return cls(**config_dict)
# read config values # create new config
model = config.get("defaults", "model")
context_window = (
int(config.get("defaults", "context_window"))
if config.has_option("defaults", "context_window")
else LLM_MAX_TOKENS["DEFAULT"]
)
preset = config.get("defaults", "preset")
model_endpoint = config.get("defaults", "model_endpoint")
default_persona = config.get("defaults", "persona")
default_human = config.get("defaults", "human")
default_agent = config.get("defaults", "agent") if config.has_option("defaults", "agent") else None
openai_key, openai_model = None, None
if "openai" in config:
openai_key = config.get("openai", "key")
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = None, None, None, None, None
if "azure" in config:
azure_key = config.get("azure", "key")
azure_endpoint = config.get("azure", "endpoint")
azure_version = config.get("azure", "version")
azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
azure_embedding_deployment = (
config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
)
embedding_model = config.get("embedding", "model")
embedding_dim = config.getint("embedding", "dim")
embedding_chunk_size = config.getint("embedding", "chunk_size")
# archival storage
archival_storage_type, archival_storage_path, archival_storage_uri = "local", None, None
if "archival_storage" in config:
archival_storage_type = config.get("archival_storage", "type")
archival_storage_path = config.get("archival_storage", "path") if config.has_option("archival_storage", "path") else None
archival_storage_uri = config.get("archival_storage", "uri") if config.has_option("archival_storage", "uri") else None
anon_clientid = config.get("client", "anon_clientid")
return cls(
model=model,
context_window=context_window,
preset=preset,
model_endpoint=model_endpoint,
default_persona=default_persona,
default_human=default_human,
default_agent=default_agent,
openai_key=openai_key,
azure_key=azure_key,
azure_endpoint=azure_endpoint,
azure_version=azure_version,
azure_deployment=azure_deployment,
azure_embedding_deployment=azure_embedding_deployment,
embedding_model=embedding_model,
embedding_dim=embedding_dim,
embedding_chunk_size=embedding_chunk_size,
archival_storage_type=archival_storage_type,
archival_storage_path=archival_storage_path,
archival_storage_uri=archival_storage_uri,
anon_clientid=anon_clientid,
config_path=config_path,
)
anon_clientid = MemGPTConfig.generate_uuid() anon_clientid = MemGPTConfig.generate_uuid()
config = cls(anon_clientid=anon_clientid, config_path=config_path) config = cls(anon_clientid=anon_clientid, config_path=config_path)
config.save() # save updated config config.save() # save updated config
@@ -179,51 +172,43 @@ class MemGPTConfig:
config = configparser.ConfigParser() config = configparser.ConfigParser()
# CLI defaults # CLI defaults
config.add_section("defaults") set_field(config, "defaults", "preset", self.preset)
config.set("defaults", "model", self.model) set_field(config, "defaults", "persona", self.persona)
config.set("defaults", "context_window", str(self.context_window)) set_field(config, "defaults", "human", self.human)
config.set("defaults", "preset", self.preset) set_field(config, "defaults", "agent", self.agent)
assert self.model_endpoint is not None, "Endpoint must be set"
config.set("defaults", "model_endpoint", self.model_endpoint)
config.set("defaults", "persona", self.default_persona)
config.set("defaults", "human", self.default_human)
if self.default_agent:
config.set("defaults", "agent", self.default_agent)
# security credentials # model defaults
if self.openai_key: set_field(config, "model", "model", self.model)
config.add_section("openai") set_field(config, "model", "model_endpoint", self.model_endpoint)
config.set("openai", "key", self.openai_key) set_field(config, "model", "model_endpoint_type", self.model_endpoint_type)
set_field(config, "model", "model_wrapper", self.model_wrapper)
set_field(config, "model", "context_window", str(self.context_window))
if self.azure_key: # security credentials: openai
config.add_section("azure") set_field(config, "openai", "key", self.openai_key)
config.set("azure", "key", self.azure_key)
config.set("azure", "endpoint", self.azure_endpoint) # security credentials: azure
config.set("azure", "version", self.azure_version) set_field(config, "azure", "key", self.azure_key)
if self.azure_deployment: set_field(config, "azure", "endpoint", self.azure_endpoint)
config.set("azure", "deployment", self.azure_deployment) set_field(config, "azure", "version", self.azure_version)
config.set("azure", "embedding_deployment", self.azure_embedding_deployment) set_field(config, "azure", "deployment", self.azure_deployment)
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
# embeddings # embeddings
config.add_section("embedding") set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type)
config.set("embedding", "model", self.embedding_model) set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint)
config.set("embedding", "dim", str(self.embedding_dim)) set_field(config, "embedding", "embedding_dim", str(self.embedding_dim))
config.set("embedding", "chunk_size", str(self.embedding_chunk_size)) set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size))
# archival storage # archival storage
config.add_section("archival_storage") set_field(config, "archival_storage", "type", self.archival_storage_type)
# print("archival storage", self.archival_storage_type) set_field(config, "archival_storage", "path", self.archival_storage_path)
config.set("archival_storage", "type", self.archival_storage_type) set_field(config, "archival_storage", "uri", self.archival_storage_uri)
if self.archival_storage_path:
config.set("archival_storage", "path", self.archival_storage_path)
if self.archival_storage_uri:
config.set("archival_storage", "uri", self.archival_storage_uri)
# client # client
config.add_section("client")
if not self.anon_clientid: if not self.anon_clientid:
self.anon_clientid = self.generate_uuid() self.anon_clientid = self.generate_uuid()
config.set("client", "anon_clientid", self.anon_clientid) set_field(config, "client", "anon_clientid", self.anon_clientid)
if not os.path.exists(MEMGPT_DIR): if not os.path.exists(MEMGPT_DIR):
os.makedirs(MEMGPT_DIR, exist_ok=True) os.makedirs(MEMGPT_DIR, exist_ok=True)
@@ -262,32 +247,54 @@ class AgentConfig:
self, self,
persona, persona,
human, human,
# model info
model, model,
model_endpoint_type=None,
model_endpoint=None,
model_wrapper=None,
context_window=None, context_window=None,
preset=DEFAULT_PRESET, # embedding info
name=None, embedding_endpoint_type=None,
data_sources=[], embedding_endpoint=None,
embedding_dim=None,
embedding_chunk_size=None,
# other
preset=None,
data_sources=None,
# agent info
agent_config_path=None, agent_config_path=None,
name=None,
create_time=None, create_time=None,
data_source=None, memgpt_version=None,
): ):
if name is None: if name is None:
self.name = f"agent_{self.generate_agent_id()}" self.name = f"agent_{self.generate_agent_id()}"
else: else:
self.name = name self.name = name
self.persona = persona
self.human = human
self.model = model
self.context_window = context_window
self.preset = preset
self.data_sources = data_sources
self.create_time = create_time if create_time is not None else utils.get_local_time()
self.data_source = None # deprecated
if context_window is None: config = MemGPTConfig.load() # get default values
self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"] self.persona = config.persona if persona is None else persona
self.human = config.human if human is None else human
self.preset = config.preset if preset is None else preset
self.context_window = config.context_window if context_window is None else context_window
self.model = config.model if model is None else model
self.model_endpoint_type = config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type
self.model_endpoint = config.model_endpoint if model_endpoint is None else model_endpoint
self.model_wrapper = config.model_wrapper if model_wrapper is None else model_wrapper
self.embedding_endpoint_type = config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
self.embedding_endpoint = config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
self.embedding_dim = config.embedding_dim if embedding_dim is None else embedding_dim
self.embedding_chunk_size = config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size
# agent metadata
self.data_sources = data_sources if data_sources is not None else []
self.create_time = create_time if create_time is not None else utils.get_local_time()
if memgpt_version is None:
import memgpt
self.memgpt_version = memgpt.__version__
else: else:
self.context_window = context_window self.memgpt_version = memgpt_version
# save agent config # save agent config
self.agent_config_path = ( self.agent_config_path = (
@@ -326,6 +333,8 @@ class AgentConfig:
def save(self): def save(self):
# save state of persistence manager # save state of persistence manager
os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True) os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True)
# save version
self.memgpt_version = memgpt.__version__
with open(self.agent_config_path, "w") as f: with open(self.agent_config_path, "w") as f:
json.dump(vars(self), f, indent=4) json.dump(vars(self), f, indent=4)
@@ -342,7 +351,6 @@ class AgentConfig:
assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}" assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
with open(agent_config_path, "r") as f: with open(agent_config_path, "r") as f:
agent_config = json.load(f) agent_config = json.load(f)
# allow compatibility accross versions # allow compatibility accross versions
try: try:
class_args = inspect.getargspec(cls.__init__).args class_args = inspect.getargspec(cls.__init__).args
@@ -354,7 +362,6 @@ class AgentConfig:
if key not in class_args: if key not in class_args:
utils.printd(f"Removing missing argument {key} from agent config") utils.printd(f"Removing missing argument {key} from agent config")
del agent_config[key] del agent_config[key]
return cls(**agent_config) return cls(**agent_config)
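The `get_field`/`set_field` helpers added above let `MemGPTConfig.save()` silently skip `None` values and let `load()` fall back to the dataclass defaults for anything missing from the file. A self-contained round trip using the same helpers as defined in the hunk above; the endpoint values are illustrative examples, not defaults taken from this diff:

```
import configparser


def get_field(config, section, field):
    # return None instead of raising when a section or field is absent
    if section not in config:
        return None
    if config.has_option(section, field):
        return config.get(section, field)
    else:
        return None


def set_field(config, section, field, value):
    if value is None:  # cannot write None
        return
    if section not in config:  # create section
        config.add_section(section)
    config.set(section, field, value)


# round-trip example with the new [model] section layout
config = configparser.ConfigParser()
set_field(config, "model", "model_endpoint_type", "ollama")
set_field(config, "model", "model_endpoint", "http://localhost:11434")
set_field(config, "model", "model_wrapper", None)  # skipped, so load() falls back to the dataclass default

loaded = {
    "model_endpoint_type": get_field(config, "model", "model_endpoint_type"),
    "model_endpoint": get_field(config, "model", "model_endpoint"),
    "model_wrapper": get_field(config, "model", "model_wrapper"),
}
loaded = {k: v for k, v in loaded.items() if v is not None}  # same None-filtering as MemGPTConfig.load()
print(loaded)  # {'model_endpoint_type': 'ollama', 'model_endpoint': 'http://localhost:11434'}
```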


@@ -11,9 +11,9 @@ def embedding_model():
     # load config
     config = MemGPTConfig.load()

-    endpoint = config.embedding_model
+    endpoint = config.embedding_endpoint_type
     if endpoint == "openai":
-        model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
+        model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key)
         return model
     elif endpoint == "azure":
         return OpenAIEmbedding(
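For quick reference, these are the embedding settings that the new `configure_embedding_endpoint()` (earlier in this diff) writes for each provider, which is what `embedding_model()` above now reads back; collected here as a plain mapping for readability:

```
# embedding_endpoint_type -> (embedding_endpoint, embedding_dim), per configure_embedding_endpoint()
EMBEDDING_DEFAULTS = {
    "openai": ("https://api.openai.com/v1", 1536),
    "azure": (None, 1536),  # endpoint is filled from the Azure embedding deployment credentials
    "local": (None, 384),   # lower-dimensional local (HuggingFace) embedding model
}
```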

View File

@@ -4,7 +4,7 @@ import os
 import json
 import math

-from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
+from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE

 ### Functions / tools the agent can use
 # All functions should return a response string (or None)


@@ -4,8 +4,8 @@ import json
 import requests

-from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
-from ...openai_tools import completions_with_backoff as create
+from memgpt.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
+from memgpt.openai_tools import completions_with_backoff as create


 def message_chatgpt(self, message: str):


@@ -10,74 +10,65 @@ from .llamacpp.api import get_llamacpp_completion
from .koboldcpp.api import get_koboldcpp_completion from .koboldcpp.api import get_koboldcpp_completion
from .ollama.api import get_ollama_completion from .ollama.api import get_ollama_completion
 from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
-from .utils import DotDict
+from .constants import DEFAULT_WRAPPER
+from .utils import DotDict, get_available_wrappers
 from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
 from ..errors import LocalLLMConnectionError, LocalLLMError

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
+endpoint = os.getenv("OPENAI_API_BASE")
+endpoint_type = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 DEBUG = False
 # DEBUG = True

-DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper

 has_shown_warning = False


 def get_chat_completion(
-    model,  # no model, since the model is fixed to whatever you set in your own backend
+    model,  # no model required (except for Ollama), since the model is fixed to whatever you set in your own backend
     messages,
     functions=None,
     function_call="auto",
     context_window=None,
+    # required
+    wrapper=None,
+    endpoint=None,
+    endpoint_type=None,
 ):
     assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
+    assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
+    assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
     global has_shown_warning
     grammar_name = None

-    if HOST is None:
-        raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
-    if HOST_TYPE is None:
-        raise ValueError(f"The BACKEND_TYPE environment variable is not defined. Please set it in your environment.")
     if function_call != "auto":
         raise ValueError(f"function_call == {function_call} not supported (auto only)")

+    available_wrappers = get_available_wrappers()
     if messages[0]["role"] == "system" and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip():
         # Special case for if the call we're making is coming from the summarizer
         llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
-    elif model == "airoboros-l2-70b-2.1":
-        llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
-    elif model == "airoboros-l2-70b-2.1-grammar":
-        llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False)
-        # grammar_name = "json"
-        grammar_name = "json_func_calls_with_inner_thoughts"
-    elif model == "dolphin-2.1-mistral-7b":
-        llm_wrapper = dolphin.Dolphin21MistralWrapper()
-    elif model == "dolphin-2.1-mistral-7b-grammar":
-        llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False)
-        # grammar_name = "json"
-        grammar_name = "json_func_calls_with_inner_thoughts"
-    elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
-        llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
-    elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar":
-        llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False)
-        # grammar_name = "json"
-        grammar_name = "json_func_calls_with_inner_thoughts"
-    else:
+    elif wrapper is None:
         # Warn the user that we're using the fallback
         if not has_shown_warning:
             print(
-                f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
+                f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --wrapper)"
             )
             has_shown_warning = True
-        if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]:
+        if endpoint_type in ["koboldcpp", "llamacpp", "webui"]:
             # make the default to use grammar
             llm_wrapper = DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
             # grammar_name = "json"
             grammar_name = "json_func_calls_with_inner_thoughts"
         else:
             llm_wrapper = DEFAULT_WRAPPER()
+    elif wrapper not in available_wrappers:
+        raise ValueError(f"Could not find requested wrapper '{wrapper} in available wrappers list:\n{available_wrappers}")
+    else:
+        llm_wrapper = available_wrappers[wrapper]
+        if "grammar" in wrapper:
+            grammar_name = "json_func_calls_with_inner_thoughts"

-    if grammar_name is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
+    if grammar_name is not None and endpoint_type not in ["koboldcpp", "llamacpp", "webui"]:
         print(f"Warning: grammars are currently only supported when using llama.cpp as the MemGPT local LLM backend")

     # First step: turn the message sequence into a prompt that the model expects
@@ -91,25 +82,25 @@ def get_chat_completion(
     )

     try:
-        if HOST_TYPE == "webui":
-            result = get_webui_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "lmstudio":
-            result = get_lmstudio_completion(prompt, context_window)
-        elif HOST_TYPE == "llamacpp":
-            result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "koboldcpp":
-            result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "ollama":
-            result = get_ollama_completion(prompt, context_window)
+        if endpoint_type == "webui":
+            result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "lmstudio":
+            result = get_lmstudio_completion(endpoint, prompt, context_window)
+        elif endpoint_type == "llamacpp":
+            result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "koboldcpp":
+            result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "ollama":
+            result = get_ollama_completion(endpoint, model, prompt, context_window)
         else:
             raise LocalLLMError(
                 f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
             )
     except requests.exceptions.ConnectionError as e:
-        raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")
+        raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")

     if result is None or result == "":
-        raise LocalLLMError(f"Got back an empty response string from {HOST}")
+        raise LocalLLMError(f"Got back an empty response string from {endpoint}")
     if DEBUG:
         print(f"Raw LLM output:\n{result}")

@@ -123,7 +114,7 @@ def get_chat_completion(
     # unpack with response.choices[0].message.content
     response = DotDict(
         {
-            "model": None,
+            "model": model,
             "choices": [
                 DotDict(
                     {
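For reference, a minimal sketch of how the refactored `get_chat_completion` is now called with the endpoint, endpoint type, and wrapper passed in explicitly. The parameter names come from the diff above; the concrete values (context window, wrapper key, port) and the running koboldcpp server are assumptions:

```python
# Usage sketch only: values are illustrative, and a koboldcpp server is assumed
# to be listening on localhost:5001.
from memgpt.local_llm.chat_completion_proxy import get_chat_completion

messages = [
    {"role": "system", "content": "You are MemGPT."},
    {"role": "user", "content": "Hello!"},
]

response = get_chat_completion(
    model=None,                        # only the Ollama backend needs a model name
    messages=messages,
    context_window=8192,               # must be set explicitly for local LLMs
    wrapper="airoboros-l2-70b-2.1",    # a key from get_available_wrappers()
    endpoint="http://localhost:5001",
    endpoint_type="koboldcpp",
)
print(response.choices[0].message.content)
```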

View File

@@ -0,0 +1,14 @@
+import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
+
+DEFAULT_ENDPOINTS = {
+    "koboldcpp": "http://localhost:5001",
+    "llamacpp": "http://localhost:8080",
+    "lmstudio": "http://localhost:1234",
+    "ollama": "http://localhost:11434",
+    "webui": "http://localhost:5000",
+}
+
+DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
+
+DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
+DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1"
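A sketch of how these new constants could be consumed to seed defaults during `memgpt configure`. The helper below is hypothetical; the real prompt flow lives in the CLI code, which is not part of this hunk:

```python
# Illustrative only: the suggest_local_defaults() helper and its fallback URL are
# assumptions, not code from this commit.
from memgpt.local_llm.constants import (
    DEFAULT_ENDPOINTS,
    DEFAULT_OLLAMA_MODEL,
    DEFAULT_WRAPPER_NAME,
)

def suggest_local_defaults(backend: str) -> dict:
    """Return suggested config values for a local LLM backend (hypothetical helper)."""
    defaults = {
        "model_endpoint_type": backend,
        "model_endpoint": DEFAULT_ENDPOINTS.get(backend, "http://localhost:8000"),  # fallback URL is arbitrary
        "model_wrapper": DEFAULT_WRAPPER_NAME,
    }
    if backend == "ollama":
        # ollama is the only backend that requires a model name in each request
        defaults["model"] = DEFAULT_OLLAMA_MODEL
    return defaults

print(suggest_local_defaults("webui"))
```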

View File

@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
-# DEBUG = False
-DEBUG = True
+DEBUG = False
+# DEBUG = True


-def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
+def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://lite.koboldai.net/koboldcpp_api for API spec"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -27,13 +25,13 @@ def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMP
     if grammar is not None:
         request["grammar"] = load_grammar_file(grammar)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
         # NOTE: llama.cpp server returns the following when it's out of context
         # curl: (52) Empty reply from server
-        URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
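A usage sketch for the per-backend helpers now that the endpoint is a function argument rather than an environment variable. The module path and prompt text are assumptions, and a koboldcpp server must be running on the given port:

```python
# Sketch under assumptions: module path inferred from the repo layout, prompt text invented.
from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion

result = get_koboldcpp_completion(
    endpoint="http://localhost:5001",  # DEFAULT_ENDPOINTS["koboldcpp"]
    prompt="### Instruction:\nSay hello.\n### Response:\n",
    context_window=2048,
    grammar=None,                      # pass a grammar name to constrain output to JSON
)
print(result)
```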

View File

@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 LLAMACPP_API_SUFFIX = "/completion"
-# DEBUG = False
-DEBUG = True
+DEBUG = False
+# DEBUG = True


-def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
+def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -26,13 +24,13 @@ def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPL
     if grammar is not None:
         request["grammar"] = load_grammar_file(grammar)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
         # NOTE: llama.cpp server returns the following when it's out of context
         # curl: (52) Empty reply from server
-        URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()

View File

@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
 LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
 DEBUG = False


-def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat"):
+def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
     """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -25,19 +23,19 @@ def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat")
     if api == "chat":
         # Uses the ChatCompletions API style
         # Seems to work better, probably because it's applying some extra settings under-the-hood?
-        URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
         message_structure = [{"role": "user", "content": prompt}]
         request["messages"] = message_structure
     elif api == "completions":
         # Uses basic string completions (string in, string out)
         # Does not work as well as ChatCompletions for some reason
-        URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
         request["prompt"] = prompt
     else:
         raise ValueError(api)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
         response = requests.post(URI, json=request)

View File

@@ -6,26 +6,25 @@ from .settings import SIMPLE
 from ..utils import count_tokens
 from ...errors import LocalLLMError

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
-MODEL_NAME = os.getenv("OLLAMA_MODEL")  # ollama API requires this in the request
 OLLAMA_API_SUFFIX = "/api/generate"
 DEBUG = False


-def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None):
+def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

-    if MODEL_NAME is None:
-        raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")
+    if model is None:
+        raise LocalLLMError(
+            f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
+        )

     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
     request["prompt"] = prompt
-    request["model"] = MODEL_NAME
+    request["model"] = model
     request["options"]["num_ctx"] = context_window

     # Set grammar
@@ -33,11 +32,11 @@ def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None)
         # request["grammar_string"] = load_grammar_file(grammar)
         raise NotImplementedError(f"Ollama does not support grammars")

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
-        URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
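For context, roughly the request that `get_ollama_completion` now posts to Ollama's `/api/generate` endpoint. Only `model`, `prompt`, and `options.num_ctx` come from the diff above; the trimmed settings and `stream: False` are assumptions:

```python
# Standalone sketch of the request shape; requires an Ollama server with the model pulled.
import requests
from urllib.parse import urljoin

endpoint = "http://localhost:11434"          # DEFAULT_ENDPOINTS["ollama"]
payload = {
    "model": "dolphin2.2-mistral:7b-q6_K",   # ollama requires the model name per request
    "prompt": "### Instruction:\nSay hello.\n### Response:\n",
    "stream": False,                          # assumption: non-streaming generation
    "options": {"num_ctx": 8192},             # context window passed through explicitly
}
response = requests.post(urljoin(endpoint.strip("/") + "/", "api/generate"), json=payload)
print(response.json().get("response", ""))
```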

View File

@@ -1,6 +1,10 @@
 import os
 import tiktoken

+import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
+import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
+import memgpt.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
+

 class DotDict(dict):
     """Allow dot access on properties similar to OpenAI response object"""
@@ -37,3 +41,14 @@ def load_grammar_file(grammar):
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
     return len(encoding.encode(s))
+
+
+def get_available_wrappers() -> dict:
+    return {
+        "airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
+        "airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False),
+        "dolphin-2.1-mistral-7b": dolphin.Dolphin21MistralWrapper(),
+        "dolphin-2.1-mistral-7b-grammar": dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False),
+        "zephyr-7B": zephyr.ZephyrMistralInnerMonologueWrapper(),
+        "zephyr-7B-grammar": zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False),
+    }
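A sketch of how this wrapper registry replaces the old model-name `elif` chain: the configured wrapper name is looked up by key and validated, mirroring the proxy change earlier in this commit (the module path is inferred from the relative imports in this diff):

```python
# Sketch assuming the utils module lives at memgpt.local_llm.utils.
from memgpt.local_llm.utils import get_available_wrappers

wrappers = get_available_wrappers()
print(list(wrappers.keys()))  # e.g. ['airoboros-l2-70b-2.1', 'airoboros-l2-70b-2.1-grammar', ...]

chosen = "dolphin-2.1-mistral-7b"
if chosen not in wrappers:
    raise ValueError(f"Unknown wrapper '{chosen}', pick one of {list(wrappers.keys())}")
llm_wrapper = wrappers[chosen]  # an instantiated wrapper object, ready to format prompts
```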

View File

@@ -5,13 +5,11 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 WEBUI_API_SUFFIX = "/api/v1/generate"
 DEBUG = False


-def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
+def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -26,11 +24,11 @@ def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
     if grammar is not None:
         request["grammar_string"] = load_grammar_file(grammar)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
-        URI = urljoin(HOST.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()

View File

@@ -241,7 +241,7 @@ def main(
     memgpt_persona = persona
     if memgpt_persona is None:
         memgpt_persona = (
-            personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT,
+            personas.GPT35_DEFAULT if (model is not None and "gpt-3.5" in model) else personas.DEFAULT,
             None,  # represents the personas dir in pymemgpt package
         )
     else:

View File

@@ -2,10 +2,14 @@ import random
 import os
 import time
-from .local_llm.chat_completion_proxy import get_chat_completion
+import time
+from typing import Callable, TypeVar
+from memgpt.local_llm.chat_completion_proxy import get_chat_completion

 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion

+R = TypeVar("R")
+
 import openai

@@ -55,6 +59,7 @@ def retry_with_exponential_backoff(
     return wrapper


+# TODO: delete/ignore --legacy
 @retry_with_exponential_backoff
 def completions_with_backoff(**kwargs):
     # Local model
@@ -75,6 +80,38 @@ def completions_with_backoff(**kwargs):
     return openai.ChatCompletion.create(**kwargs)


+@retry_with_exponential_backoff
+def chat_completion_with_backoff(agent_config, **kwargs):
+    from memgpt.utils import printd
+    from memgpt.config import AgentConfig, MemGPTConfig
+
+    printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}")
+    if agent_config.model_endpoint_type == "openai":
+        # openai
+        openai.api_base = agent_config.model_endpoint
+        return openai.ChatCompletion.create(**kwargs)
+    elif agent_config.model_endpoint_type == "azure":
+        # configure openai
+        config = MemGPTConfig.load()  # load credentials (currently not stored in agent config)
+        openai.api_type = "azure"
+        openai.api_key = config.azure_key
+        openai.api_base = config.azure_endpoint
+        openai.api_version = config.azure_version
+        if config.azure_deployment is not None:
+            kwargs["deployment_id"] = config.azure_deployment
+        else:
+            kwargs["engine"] = MODEL_TO_AZURE_ENGINE[config.model]
+        del kwargs["model"]
+        return openai.ChatCompletion.create(**kwargs)
+    else:  # local model
+        kwargs["context_window"] = agent_config.context_window  # specify for open LLMs
+        kwargs["endpoint"] = agent_config.model_endpoint  # specify for open LLMs
+        kwargs["endpoint_type"] = agent_config.model_endpoint_type  # specify for open LLMs
+        kwargs["wrapper"] = agent_config.model_wrapper  # specify for open LLMs
+        return get_chat_completion(**kwargs)
+
+
+# TODO: deprecate
 @retry_with_exponential_backoff
 def create_embedding_with_backoff(**kwargs):
     if using_azure():
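A usage sketch for the new routing entry point. Only the config field names (`model_endpoint_type`, `model_endpoint`, `model_wrapper`, `context_window`) come from this diff; the `SimpleNamespace` stand-in and the `memgpt.openai_tools` module path are assumptions for illustration (MemGPT itself passes a real `AgentConfig`):

```python
# Sketch under assumptions: a SimpleNamespace stands in for AgentConfig and an
# Ollama server with the named model is assumed to be running locally.
from types import SimpleNamespace
from memgpt.openai_tools import chat_completion_with_backoff

agent_config = SimpleNamespace(
    model="dolphin2.2-mistral:7b-q6_K",
    model_endpoint_type="ollama",
    model_endpoint="http://localhost:11434",
    model_wrapper="airoboros-l2-70b-2.1",
    context_window=8192,
)

response = chat_completion_with_backoff(
    agent_config,
    model=agent_config.model,
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```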

View File

@@ -57,5 +57,5 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
         persona_notes=persona,
         human_notes=human,
         # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
-        first_message_verify_mono=True if "gpt-4" in model else False,
+        first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
     )

View File

@@ -13,7 +13,7 @@ def test_configure_memgpt():
 def test_save_load():
-    configure_memgpt()
+    # configure_memgpt()  # rely on configure running first^
     child = pexpect.spawn("memgpt run --agent test_save_load --first --strip_ui")
     child.expect("Enter your message:", timeout=TIMEOUT)

View File

@@ -1,12 +1,12 @@
 # import tempfile
 # import asyncio
-import os
+# import os
 # import asyncio
-from datasets import load_dataset
-import memgpt
-from memgpt.cli.cli_load import load_directory, load_database, load_webpage
+# from datasets import load_dataset
+# import memgpt
+# from memgpt.cli.cli_load import load_directory, load_database, load_webpage
 # import memgpt.presets as presets
 # import memgpt.personas.personas as personas

View File

@@ -1,6 +1,7 @@
 import os
 import subprocess
 import sys
+import pytest

 subprocess.check_call(
     [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
@@ -15,9 +16,11 @@ from memgpt.config import MemGPTConfig, AgentConfig
 import argparse


+@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
 def test_postgres_openai():
-    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
-    if os.getenv("OPENAI_API_KEY") is None:
+    if not os.getenv("PGVECTOR_TEST_DB_URL"):
+        return  # soft pass
+    if not os.getenv("OPENAI_API_KEY"):
         return  # soft pass

     # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
@@ -54,14 +57,16 @@ def test_postgres_openai():
     # print("...finished")


+@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
 def test_postgres_local():
-    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
+    if not os.getenv("PGVECTOR_TEST_DB_URL"):
+        return

     # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
     config = MemGPTConfig(
         archival_storage_type="postgres",
         archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
-        embedding_model="local",
+        embedding_endpoint_type="local",
         embedding_dim=384,  # use HF model
     )
     print(config.config_path)

View File

@@ -3,43 +3,55 @@ import pexpect
 from .constants import TIMEOUT


-def configure_memgpt(enable_openai=True, enable_azure=False):
+def configure_memgpt_localllm():
     child = pexpect.spawn("memgpt configure")
-    child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT)
-    if enable_openai:
-        child.sendline("y")
-    else:
-        child.sendline("n")
-    child.expect("Do you want to enable MemGPT with Azure?", timeout=TIMEOUT)
-    if enable_azure:
-        child.sendline("y")
-    else:
-        child.sendline("n")
-    child.expect("Select default inference endpoint:", timeout=TIMEOUT)
+    child.expect("Select LLM inference provider", timeout=TIMEOUT)
+    child.send("\x1b[B")  # Send the down arrow key
+    child.send("\x1b[B")  # Send the down arrow key
     child.sendline()
-    child.expect("Select default embedding endpoint:", timeout=TIMEOUT)
+    child.expect("Select LLM backend", timeout=TIMEOUT)
     child.sendline()
-    child.expect("Select default preset:", timeout=TIMEOUT)
+    child.expect("Enter default endpoint", timeout=TIMEOUT)
     child.sendline()
-    child.expect("Select default model", timeout=TIMEOUT)
+    child.expect("Select default model wrapper", timeout=TIMEOUT)
     child.sendline()
-    child.expect("Select default persona:", timeout=TIMEOUT)
+    child.expect("Select your model's context window", timeout=TIMEOUT)
     child.sendline()
-    child.expect("Select default human:", timeout=TIMEOUT)
+    child.expect("Select embedding provider", timeout=TIMEOUT)
+    child.send("\x1b[B")  # Send the down arrow key
+    child.send("\x1b[B")  # Send the down arrow key
+    child.sendline()
+    child.expect("Select default preset", timeout=TIMEOUT)
+    child.sendline()
+    child.expect("Select default persona", timeout=TIMEOUT)
+    child.sendline()
+    child.expect("Select default human", timeout=TIMEOUT)
     child.sendline()
-    child.expect("Select storage backend for archival data:", timeout=TIMEOUT)
+    child.expect("Select storage backend for archival data", timeout=TIMEOUT)
     child.sendline()
     child.expect(pexpect.EOF, timeout=TIMEOUT)  # Wait for child to exit
     child.close()

     assert child.isalive() is False, "CLI should have terminated."
     assert child.exitstatus == 0, "CLI did not exit cleanly."
+
+
+def configure_memgpt(enable_openai=False, enable_azure=False):
+    if enable_openai:
+        raise NotImplementedError
+    elif enable_azure:
+        raise NotImplementedError
+    else:
+        configure_memgpt_localllm()
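A sketch of how a test could build on the new helper, in the style of the save/load test earlier in this commit. The agent name and the `/exit` command are illustrative assumptions, and the test is written as if it lives alongside the helpers above so `configure_memgpt` is in scope:

```python
# Sketch under assumptions: agent name and /exit command are illustrative.
import pexpect
from .constants import TIMEOUT


def test_smoke_run_after_configure():
    configure_memgpt()  # defaults to the local-LLM configure path above
    child = pexpect.spawn("memgpt run --agent smoke_test --first --strip_ui")
    child.expect("Enter your message:", timeout=TIMEOUT)
    child.sendline("/exit")
    child.expect(pexpect.EOF, timeout=TIMEOUT)
    child.close()
    assert child.exitstatus == 0, "CLI did not exit cleanly."
```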