Refactor config + determine LLM via config.model_endpoint_type (#422)

* mark depricated API section

* CLI bug fixes for azure

* check azure before running

* Update README.md

* Update README.md

* bug fix with persona loading

* remove print

* make errors for cli flags more clear

* format

* fix imports

* fix imports

* add prints

* update lock

* update config fields

* cleanup config loading

* commit

* remove asserts

* refactor configure

* put into different functions

* add embedding default

* pass in config

* fixes

* allow overriding openai embedding endpoint

* black

* trying to patch tests (some circular import errors)

* update flags and docs

* patched support for local llms using endpoint and endpoint type passed via configs, not env vars

* missing files

* fix naming

* fix import

* fix two runtime errors

* patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included

* disable debug messages

* made error message for failed load more informative

* don't print dynamic linking function warning unless --debug

* updated tests to work with new cli workflow (disabled openai config test for now)

* added skips for tests when vars are missing

* update bad arg

* revise test to soft pass on empty string too

* don't run configure twice

* extend timeout (try to pass against nltk download)

* update defaults

* typo with endpoint type default

* patch runtime errors for when model is None

* catching another case of 'x in model' when model is None (preemptively)

* allow overrides to local llm related config params

* made model wrapper selection from a list vs raw input

* update test for select instead of input

* Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry

* updated error messages to be more informative with links to readthedocs

* add back gpt3.5-turbo

---------

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders 2023-11-14 15:58:19 -08:00 committed by GitHub
parent 8fdc3a29da
commit 28514da5df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 628 additions and 435 deletions

View File

@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove
* `--agent`: (str) Name of agent to create or to resume chatting with.
* `--human`: (str) Name of the human to run the agent with.
* `--persona`: (str) Name of agent persona to use.
* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
* `--model`: (str) LLM model to run (e.g. `gpt-4`, `dolphin_xxx`)
* `--preset`: (str) MemGPT preset to run agent with.
* `--first`: (str) Allow user to sent the first message.
* `--debug`: (bool) Show debug logs (default=False)
* `--no-verify`: (bool) Bypass message verification (default=False)
* `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)
You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs:
* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`)
* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama)
* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`)
* `--context-window`: (int) Size of model context window (specific to model type)
#### Updating the config location
You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`:
You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`:
```
export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory
```

View File

@ -9,6 +9,7 @@ from memgpt.config import AgentConfig
from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from .memory import CoreMemory as Memory, summarize_messages
from .openai_tools import completions_with_backoff as create
from memgpt.openai_tools import chat_completion_with_backoff
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
from .constants import (
FIRST_MESSAGE_ATTEMPTS,
@ -73,7 +74,7 @@ def initialize_message_sequence(
first_user_message = get_login_event() # event letting MemGPT know the user just logged in
if include_initial_boot_message:
if "gpt-3.5" in model:
if model is not None and "gpt-3.5" in model:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
else:
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
@ -96,37 +97,6 @@ def initialize_message_sequence(
return messages
def get_ai_reply(
model,
message_sequence,
functions,
function_call="auto",
context_window=None,
):
try:
response = create(
model=model,
context_window=context_window,
messages=message_sequence,
functions=functions,
function_call=function_call,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")
# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e
class Agent(object):
def __init__(
self,
@ -310,7 +280,7 @@ class Agent(object):
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}")
# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
@ -360,7 +330,7 @@ class Agent(object):
# NOTE to handle old configs, instead of erroring here let's just warn
# raise ValueError(error_message)
print(error_message)
printd(error_message)
linked_function_set[f_name] = linked_function
messages = state["messages"]
@ -602,8 +572,7 @@ class Agent(object):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = get_ai_reply(
model=self.model,
response = self.get_ai_reply(
message_sequence=input_message_sequence,
functions=self.functions,
context_window=None if self.config.context_window is None else int(self.config.context_window),
@ -616,8 +585,7 @@ class Agent(object):
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
else:
response = get_ai_reply(
model=self.model,
response = self.get_ai_reply(
message_sequence=input_message_sequence,
functions=self.functions,
context_window=None if self.config.context_window is None else int(self.config.context_window),
@ -785,3 +753,55 @@ class Agent(object):
# Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60
def get_ai_reply(
self,
message_sequence,
function_call="auto",
):
"""Get response from LLM API"""
# TODO: Legacy code - delete
if self.config is None:
try:
response = create(
model=self.model,
context_window=self.context_window,
messages=message_sequence,
functions=self.functions,
function_call=function_call,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")
# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e
try:
response = chat_completion_with_backoff(
agent_config=self.config,
model=self.model, # TODO: remove (is redundant)
messages=message_sequence,
functions=self.functions,
function_call=function_call,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call"]:
raise Exception(f"API call finish with bad finish reason: {response}")
# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e

View File

@ -1,4 +1,5 @@
import typer
import json
import sys
import io
import logging
@ -35,16 +36,21 @@ def run(
persona: str = typer.Option(None, help="Specify persona"),
agent: str = typer.Option(None, help="Specify agent save file"),
human: str = typer.Option(None, help="Specify human"),
model: str = typer.Option(None, help="Specify the LLM model"),
preset: str = typer.Option(None, help="Specify preset"),
# model flags
model: str = typer.Option(None, help="Specify the LLM model"),
model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
# other
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
context_window: int = typer.Option(
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
),
):
"""Start chatting with an MemGPT agent
@ -99,11 +105,6 @@ def run(
set_global_service_context(service_context)
sys.stdout = original_stdout
# overwrite the context_window if specified
if context_window is not None and int(context_window) != int(config.context_window):
typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
config.context_window = str(context_window)
# create agent config
if agent and AgentConfig.exists(agent): # use existing agent
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@ -121,10 +122,34 @@ def run(
typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
agent_config.human = human
# raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
if model and model != agent_config.model:
typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
agent_config.model = model
# raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
if context_window is not None and int(context_window) != agent_config.context_window:
typer.secho(
f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
)
agent_config.context_window = context_window
if model_wrapper and model_wrapper != agent_config.model_wrapper:
typer.secho(
f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
)
agent_config.model_wrapper = model_wrapper
if model_endpoint and model_endpoint != agent_config.model_endpoint:
typer.secho(
f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
)
agent_config.model_endpoint = model_endpoint
if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
typer.secho(
f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
fg=typer.colors.YELLOW,
)
agent_config.model_endpoint_type = model_endpoint_type
# Update the agent config with any overrides
agent_config.save()
# load existing agent
@ -133,17 +158,17 @@ def run(
# create new agent config: override defaults with args if provided
typer.secho("Creating new agent...", fg=typer.colors.GREEN)
agent_config = AgentConfig(
name=agent if agent else None,
persona=persona if persona else config.default_persona,
human=human if human else config.default_human,
model=model if model else config.model,
context_window=context_window if context_window else config.context_window,
preset=preset if preset else config.preset,
name=agent,
persona=persona,
human=human,
preset=preset,
model=model,
model_wrapper=model_wrapper,
model_endpoint_type=model_endpoint_type,
model_endpoint=model_endpoint,
context_window=context_window,
)
## attach data source to agent
# agent_config.attach_data_source(data_source)
# TODO: allow configrable state manager (only local is supported right now)
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill
@ -162,6 +187,9 @@ def run(
persistence_manager,
)
# pretty print agent config
printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))
# start event loop
from memgpt.main import run_agent_loop

View File

@ -1,3 +1,4 @@
import builtins
import questionary
import openai
from prettytable import PrettyTable
@ -11,126 +12,118 @@ from memgpt import utils
import memgpt.humans.humans as humans
import memgpt.personas.personas as personas
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.config import MemGPTConfig, AgentConfig, Config
from memgpt.constants import MEMGPT_DIR
from memgpt.connectors.storage import StorageConnector
from memgpt.constants import LLM_MAX_TOKENS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
app = typer.Typer()
@app.command()
def configure():
"""Updates default MemGPT configurations"""
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
MemGPTConfig.create_config_dir()
# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
# openai credentials
use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask()
if use_openai:
# search for key in enviornment
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
# TODO: eventually stop relying on env variables and pass in keys explicitly
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
use_azure_deployment_ids = False
if use_azure:
# search for key in enviornment
def get_azure_credentials():
azure_key = os.getenv("AZURE_OPENAI_KEY")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_version = os.getenv("AZURE_OPENAI_VERSION")
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
return azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment
if all([azure_key, azure_endpoint, azure_version]):
print(f"Using Microsoft endpoint {azure_endpoint}.")
if all([azure_deployment, azure_embedding_deployment]):
print(f"Using deployment id {azure_deployment}")
use_azure_deployment_ids = True
# configure openai
openai.api_type = "azure"
openai.api_key = azure_key
openai.api_base = azure_endpoint
openai.api_version = azure_version
else:
print("Missing enviornment variables for Azure. Please set then run `memgpt configure` again.")
# TODO: allow for manual setting
use_azure = False
def get_openai_credentials():
openai_key = os.getenv("OPENAI_API_KEY")
return openai_key
# TODO: configure local model
# configure provider
model_endpoint_options = []
if os.getenv("OPENAI_API_BASE") is not None:
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
if use_openai:
model_endpoint_options += ["openai"]
if use_azure:
model_endpoint_options += ["azure"]
assert (
len(model_endpoint_options) > 0
), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server."
valid_default_model = config.model_endpoint in model_endpoint_options
default_endpoint = questionary.select(
"Select default inference endpoint:",
model_endpoint_options,
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
def configure_llm_endpoint(config: MemGPTConfig):
# configure model endpoint
model_endpoint_type, model_endpoint = None, None
# get default
default_model_endpoint_type = config.model_endpoint_type
if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]: # local model
default_model_endpoint_type = "local"
provider = questionary.select(
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
).ask()
# configure embedding provider
embedding_endpoint_options = []
if use_azure:
embedding_endpoint_options += ["azure"]
if use_openai:
embedding_endpoint_options += ["openai"]
embedding_endpoint_options += ["local"]
valid_default_embedding = config.embedding_model in embedding_endpoint_options
# determine the default selection in a smart way
if "openai" in embedding_endpoint_options and default_endpoint == "openai":
# openai llm -> openai embeddings
default_embedding_endpoint_default = "openai"
elif default_endpoint not in ["openai", "azure"]: # is local
# local llm -> local embeddings
default_embedding_endpoint_default = "local"
# set: model_endpoint_type, model_endpoint
if provider == "openai":
model_endpoint_type = "openai"
model_endpoint = "https://api.openai.com/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
provider = "openai"
elif provider == "azure":
model_endpoint_type = "azure"
_, model_endpoint, _, _, _ = get_azure_credentials()
else: # local models
backend_options = ["webui", "llamacpp", "koboldcpp", "ollama", "lmstudio", "openai"]
default_model_endpoint_type = None
if config.model_endpoint_type in backend_options:
# set from previous config
default_model_endpoint_type = config.model_endpoint_type
else:
default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1]
default_embedding_endpoint = questionary.select(
"Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default
# set form env variable (ok if none)
default_model_endpoint_type = os.getenv("BACKEND_TYPE")
model_endpoint_type = questionary.select(
"Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):",
backend_options,
default=default_model_endpoint_type,
).ask()
# configure embedding dimentions
default_embedding_dim = config.embedding_dim
if default_embedding_endpoint == "local":
# HF model uses lower dimentionality
default_embedding_dim = 384
# set default endpoint
# if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
default_model_endpoint = os.getenv("OPENAI_API_BASE")
# if OPENAI_API_BASE is not set, try to pull a default IP+port format from a hardcoded set
if default_model_endpoint is None:
if model_endpoint_type in DEFAULT_ENDPOINTS:
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
else:
# default_model_endpoint = None
model_endpoint = None
while not model_endpoint:
model_endpoint = questionary.text("Enter default endpoint:").ask()
if "http://" not in model_endpoint and "https://" not in model_endpoint:
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = None
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
# configure preset
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()
return model_endpoint_type, model_endpoint
# default model
if use_openai or use_azure:
model_options = []
if use_openai:
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
def configure_model(config: MemGPTConfig, model_endpoint_type: str):
# set: model, model_wrapper
model, model_wrapper = None, None
if model_endpoint_type == "openai" or model_endpoint_type == "azure":
model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
# TODO: select
valid_model = config.model in model_options
default_model = questionary.select(
model = questionary.select(
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
).ask()
else:
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint
else: # local models
# ollama also needs model type
if model_endpoint_type == "ollama":
default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL
model = questionary.text(
"Enter default model name (required for Ollama, see: https://memgpt.readthedocs.io/en/latest/ollama):",
default=default_model,
).ask()
model = None if len(model) == 0 else model
# get the max tokens (context window) for the model
if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS:
# model wrapper
available_model_wrappers = builtins.list(get_available_wrappers().keys())
model_wrapper = questionary.select(
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
choices=available_model_wrappers,
default=DEFAULT_WRAPPER_NAME,
).ask()
# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length
context_length_options = [
str(2**12), # 4096
@ -140,46 +133,80 @@ def configure():
str(2**18), # 262144
"custom", # enter yourself
]
default_model_context_window = questionary.select(
context_window = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()
# If custom, ask for input
if default_model_context_window == "custom":
if context_window == "custom":
while True:
default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask()
context_window = questionary.text("Enter context window (e.g. 8192)").ask()
try:
default_model_context_window = int(default_model_context_window)
context_window = int(context_window)
break
except ValueError:
print(f"Context window must be a valid integer")
else:
default_model_context_window = int(default_model_context_window)
context_window = int(context_window)
else:
# Pull the context length from the models
default_model_context_window = LLM_MAX_TOKENS[default_model]
context_window = LLM_MAX_TOKENS[model]
return model, model_wrapper, context_window
# defaults
def configure_embedding_endpoint(config: MemGPTConfig):
# configure embedding endpoint
default_embedding_endpoint_type = config.embedding_endpoint_type
if config.embedding_endpoint_type is not None and config.embedding_endpoint_type not in ["openai", "azure"]: # local model
default_embedding_endpoint_type = "local"
embedding_endpoint_type, embedding_endpoint, embedding_dim = None, None, None
embedding_provider = questionary.select(
"Select embedding provider:", choices=["openai", "azure", "local"], default=default_embedding_endpoint_type
).ask()
if embedding_provider == "openai":
embedding_endpoint_type = "openai"
embedding_endpoint = "https://api.openai.com/v1"
embedding_dim = 1536
elif embedding_provider == "azure":
embedding_endpoint_type = "azure"
_, _, _, _, embedding_endpoint = get_azure_credentials()
embedding_dim = 1536
else: # local models
embedding_endpoint_type = "local"
embedding_endpoint = None
embedding_dim = 384
return embedding_endpoint_type, embedding_endpoint, embedding_dim
def configure_cli(config: MemGPTConfig):
# set: preset, default_persona, default_human, default_agent``
from memgpt.presets.presets import preset_options
# preset
default_preset = config.preset if config.preset and config.preset in preset_options else None
preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()
# persona
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
# print(personas)
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
default_persona = config.persona if config.persona and config.persona in personas else None
persona = questionary.select("Select default persona:", personas, default=default_persona).ask()
# human
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
# print(humans)
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()
default_human = config.human if config.human and config.human in humans else None
human = questionary.select("Select default human:", humans, default=default_human).ask()
# TODO: figure out if we should set a default agent or not
default_agent = None
# agents = [os.path.basename(f).replace(".json", "") for f in utils.list_agent_config_files()]
# if len(agents) > 0: # agents have been created
# default_agent = questionary.select(
# "Select default agent:",
# agents
# ).ask()
# else:
# default_agent = None
agent = None
return preset, persona, human, agent
def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_type = questionary.select(
@ -191,25 +218,65 @@ def configure():
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()
return archival_storage_type, archival_storage_uri
# TODO: allow configuring embedding model
@app.command()
def configure():
"""Updates default MemGPT configurations"""
MemGPTConfig.create_config_dir()
# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
# check credentials
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
openai_key = get_openai_credentials()
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
if all([azure_key, azure_endpoint, azure_version]):
print(f"Using Microsoft endpoint {azure_endpoint}.")
if all([azure_deployment, azure_embedding_deployment]):
print(f"Using deployment id {azure_deployment}")
else:
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
)
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
if not openai_key:
raise ValueError(
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
)
config = MemGPTConfig(
model=default_model,
context_window=default_model_context_window,
# model configs
model=model,
model_endpoint=model_endpoint,
model_endpoint_type=model_endpoint_type,
model_wrapper=model_wrapper,
context_window=context_window,
# embedding configs
embedding_endpoint_type=embedding_endpoint_type,
embedding_endpoint=embedding_endpoint,
embedding_dim=embedding_dim,
# cli configs
preset=default_preset,
model_endpoint=default_endpoint,
embedding_model=default_embedding_endpoint,
embedding_dim=default_embedding_dim,
default_persona=default_persona,
default_human=default_human,
default_agent=default_agent,
openai_key=openai_key if use_openai else None,
azure_key=azure_key if use_azure else None,
azure_endpoint=azure_endpoint if use_azure else None,
azure_version=azure_version if use_azure else None,
azure_deployment=azure_deployment if use_azure_deployment_ids else None,
azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None,
persona=default_persona,
human=default_human,
agent=default_agent,
# credentials
openai_key=openai_key,
azure_key=azure_key,
azure_endpoint=azure_endpoint,
azure_version=azure_version,
azure_deployment=azure_deployment,
azure_embedding_deployment=azure_embedding_deployment,
# storage
archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri,
)

View File

@ -16,6 +16,7 @@ from colorama import Fore, Style
from typing import List, Type
import memgpt
import memgpt.utils as utils
from memgpt.interface import CLIInterface as interface
from memgpt.personas.personas import get_persona_text
@ -40,6 +41,24 @@ model_choices = [
]
# helper functions for writing to configs
def get_field(config, section, field):
if section not in config:
return None
if config.has_option(section, field):
return config.get(section, field)
else:
return None
def set_field(config, section, field, value):
if value is None: # cannot write None
return
if section not in config: # create section
config.add_section(section)
config.set(section, field, value)
@dataclass
class MemGPTConfig:
config_path: str = os.path.join(MEMGPT_DIR, "config")
@ -49,9 +68,10 @@ class MemGPTConfig:
preset: str = DEFAULT_PRESET
# model parameters
# provider: str = "openai" # openai, azure, local (TODO)
model_endpoint: str = "openai"
model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local
model: str = None
model_endpoint_type: str = None
model_endpoint: str = None # localhost:8000
model_wrapper: str = None
context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
# model parameters: openai
@ -65,12 +85,13 @@ class MemGPTConfig:
azure_embedding_deployment: str = None
# persona parameters
default_persona: str = personas.DEFAULT
default_human: str = humans.DEFAULT
default_agent: str = None
persona: str = personas.DEFAULT
human: str = humans.DEFAULT
agent: str = None
# embedding parameters
embedding_model: str = "openai"
embedding_endpoint_type: str = "openai" # openai, azure, local
embedding_endpoint: str = None
embedding_dim: int = 1536
embedding_chunk_size: int = 300 # number of tokens
@ -89,6 +110,12 @@ class MemGPTConfig:
persistence_manager_save_file: str = None # local file
persistence_manager_uri: str = None # db URI
def __post_init__(self):
# ensure types
self.embedding_chunk_size = int(self.embedding_chunk_size)
self.embedding_dim = int(self.embedding_dim)
self.context_window = int(self.context_window)
@staticmethod
def generate_uuid() -> str:
return uuid.UUID(int=uuid.getnode()).hex
@ -104,72 +131,38 @@ class MemGPTConfig:
config_path = MemGPTConfig.config_path
if os.path.exists(config_path):
# read existing config
config.read(config_path)
config_dict = {
"model": get_field(config, "model", "model"),
"model_endpoint": get_field(config, "model", "model_endpoint"),
"model_endpoint_type": get_field(config, "model", "model_endpoint_type"),
"model_wrapper": get_field(config, "model", "model_wrapper"),
"context_window": get_field(config, "model", "context_window"),
"preset": get_field(config, "defaults", "preset"),
"persona": get_field(config, "defaults", "persona"),
"human": get_field(config, "defaults", "human"),
"agent": get_field(config, "defaults", "agent"),
"openai_key": get_field(config, "openai", "key"),
"azure_key": get_field(config, "azure", "key"),
"azure_endpoint": get_field(config, "azure", "endpoint"),
"azure_version": get_field(config, "azure", "version"),
"azure_deployment": get_field(config, "azure", "deployment"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
"embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"),
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
"archival_storage_type": get_field(config, "archival_storage", "type"),
"archival_storage_path": get_field(config, "archival_storage", "path"),
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
"anon_clientid": get_field(config, "client", "anon_clientid"),
"config_path": config_path,
}
config_dict = {k: v for k, v in config_dict.items() if v is not None}
return cls(**config_dict)
# read config values
model = config.get("defaults", "model")
context_window = (
int(config.get("defaults", "context_window"))
if config.has_option("defaults", "context_window")
else LLM_MAX_TOKENS["DEFAULT"]
)
preset = config.get("defaults", "preset")
model_endpoint = config.get("defaults", "model_endpoint")
default_persona = config.get("defaults", "persona")
default_human = config.get("defaults", "human")
default_agent = config.get("defaults", "agent") if config.has_option("defaults", "agent") else None
openai_key, openai_model = None, None
if "openai" in config:
openai_key = config.get("openai", "key")
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = None, None, None, None, None
if "azure" in config:
azure_key = config.get("azure", "key")
azure_endpoint = config.get("azure", "endpoint")
azure_version = config.get("azure", "version")
azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
azure_embedding_deployment = (
config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
)
embedding_model = config.get("embedding", "model")
embedding_dim = config.getint("embedding", "dim")
embedding_chunk_size = config.getint("embedding", "chunk_size")
# archival storage
archival_storage_type, archival_storage_path, archival_storage_uri = "local", None, None
if "archival_storage" in config:
archival_storage_type = config.get("archival_storage", "type")
archival_storage_path = config.get("archival_storage", "path") if config.has_option("archival_storage", "path") else None
archival_storage_uri = config.get("archival_storage", "uri") if config.has_option("archival_storage", "uri") else None
anon_clientid = config.get("client", "anon_clientid")
return cls(
model=model,
context_window=context_window,
preset=preset,
model_endpoint=model_endpoint,
default_persona=default_persona,
default_human=default_human,
default_agent=default_agent,
openai_key=openai_key,
azure_key=azure_key,
azure_endpoint=azure_endpoint,
azure_version=azure_version,
azure_deployment=azure_deployment,
azure_embedding_deployment=azure_embedding_deployment,
embedding_model=embedding_model,
embedding_dim=embedding_dim,
embedding_chunk_size=embedding_chunk_size,
archival_storage_type=archival_storage_type,
archival_storage_path=archival_storage_path,
archival_storage_uri=archival_storage_uri,
anon_clientid=anon_clientid,
config_path=config_path,
)
# create new config
anon_clientid = MemGPTConfig.generate_uuid()
config = cls(anon_clientid=anon_clientid, config_path=config_path)
config.save() # save updated config
@ -179,51 +172,43 @@ class MemGPTConfig:
config = configparser.ConfigParser()
# CLI defaults
config.add_section("defaults")
config.set("defaults", "model", self.model)
config.set("defaults", "context_window", str(self.context_window))
config.set("defaults", "preset", self.preset)
assert self.model_endpoint is not None, "Endpoint must be set"
config.set("defaults", "model_endpoint", self.model_endpoint)
config.set("defaults", "persona", self.default_persona)
config.set("defaults", "human", self.default_human)
if self.default_agent:
config.set("defaults", "agent", self.default_agent)
set_field(config, "defaults", "preset", self.preset)
set_field(config, "defaults", "persona", self.persona)
set_field(config, "defaults", "human", self.human)
set_field(config, "defaults", "agent", self.agent)
# security credentials
if self.openai_key:
config.add_section("openai")
config.set("openai", "key", self.openai_key)
# model defaults
set_field(config, "model", "model", self.model)
set_field(config, "model", "model_endpoint", self.model_endpoint)
set_field(config, "model", "model_endpoint_type", self.model_endpoint_type)
set_field(config, "model", "model_wrapper", self.model_wrapper)
set_field(config, "model", "context_window", str(self.context_window))
if self.azure_key:
config.add_section("azure")
config.set("azure", "key", self.azure_key)
config.set("azure", "endpoint", self.azure_endpoint)
config.set("azure", "version", self.azure_version)
if self.azure_deployment:
config.set("azure", "deployment", self.azure_deployment)
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
# security credentials: openai
set_field(config, "openai", "key", self.openai_key)
# security credentials: azure
set_field(config, "azure", "key", self.azure_key)
set_field(config, "azure", "endpoint", self.azure_endpoint)
set_field(config, "azure", "version", self.azure_version)
set_field(config, "azure", "deployment", self.azure_deployment)
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
# embeddings
config.add_section("embedding")
config.set("embedding", "model", self.embedding_model)
config.set("embedding", "dim", str(self.embedding_dim))
config.set("embedding", "chunk_size", str(self.embedding_chunk_size))
set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type)
set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint)
set_field(config, "embedding", "embedding_dim", str(self.embedding_dim))
set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size))
# archival storage
config.add_section("archival_storage")
# print("archival storage", self.archival_storage_type)
config.set("archival_storage", "type", self.archival_storage_type)
if self.archival_storage_path:
config.set("archival_storage", "path", self.archival_storage_path)
if self.archival_storage_uri:
config.set("archival_storage", "uri", self.archival_storage_uri)
set_field(config, "archival_storage", "type", self.archival_storage_type)
set_field(config, "archival_storage", "path", self.archival_storage_path)
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
# client
config.add_section("client")
if not self.anon_clientid:
self.anon_clientid = self.generate_uuid()
config.set("client", "anon_clientid", self.anon_clientid)
set_field(config, "client", "anon_clientid", self.anon_clientid)
if not os.path.exists(MEMGPT_DIR):
os.makedirs(MEMGPT_DIR, exist_ok=True)
@ -262,32 +247,54 @@ class AgentConfig:
self,
persona,
human,
# model info
model,
model_endpoint_type=None,
model_endpoint=None,
model_wrapper=None,
context_window=None,
preset=DEFAULT_PRESET,
name=None,
data_sources=[],
# embedding info
embedding_endpoint_type=None,
embedding_endpoint=None,
embedding_dim=None,
embedding_chunk_size=None,
# other
preset=None,
data_sources=None,
# agent info
agent_config_path=None,
name=None,
create_time=None,
data_source=None,
memgpt_version=None,
):
if name is None:
self.name = f"agent_{self.generate_agent_id()}"
else:
self.name = name
self.persona = persona
self.human = human
self.model = model
self.context_window = context_window
self.preset = preset
self.data_sources = data_sources
self.create_time = create_time if create_time is not None else utils.get_local_time()
self.data_source = None # deprecated
if context_window is None:
self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
config = MemGPTConfig.load() # get default values
self.persona = config.persona if persona is None else persona
self.human = config.human if human is None else human
self.preset = config.preset if preset is None else preset
self.context_window = config.context_window if context_window is None else context_window
self.model = config.model if model is None else model
self.model_endpoint_type = config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type
self.model_endpoint = config.model_endpoint if model_endpoint is None else model_endpoint
self.model_wrapper = config.model_wrapper if model_wrapper is None else model_wrapper
self.embedding_endpoint_type = config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
self.embedding_endpoint = config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
self.embedding_dim = config.embedding_dim if embedding_dim is None else embedding_dim
self.embedding_chunk_size = config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size
# agent metadata
self.data_sources = data_sources if data_sources is not None else []
self.create_time = create_time if create_time is not None else utils.get_local_time()
if memgpt_version is None:
import memgpt
self.memgpt_version = memgpt.__version__
else:
self.context_window = context_window
self.memgpt_version = memgpt_version
# save agent config
self.agent_config_path = (
@ -326,6 +333,8 @@ class AgentConfig:
def save(self):
# save state of persistence manager
os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True)
# save version
self.memgpt_version = memgpt.__version__
with open(self.agent_config_path, "w") as f:
json.dump(vars(self), f, indent=4)
@ -342,7 +351,6 @@ class AgentConfig:
assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
with open(agent_config_path, "r") as f:
agent_config = json.load(f)
# allow compatibility accross versions
try:
class_args = inspect.getargspec(cls.__init__).args
@ -354,7 +362,6 @@ class AgentConfig:
if key not in class_args:
utils.printd(f"Removing missing argument {key} from agent config")
del agent_config[key]
return cls(**agent_config)

View File

@ -11,9 +11,9 @@ def embedding_model():
# load config
config = MemGPTConfig.load()
endpoint = config.embedding_model
endpoint = config.embedding_endpoint_type
if endpoint == "openai":
model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key)
return model
elif endpoint == "azure":
return OpenAIEmbedding(

View File

@ -4,7 +4,7 @@ import os
import json
import math
from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
### Functions / tools the agent can use
# All functions should return a response string (or None)

View File

@ -4,8 +4,8 @@ import json
import requests
from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
from ...openai_tools import completions_with_backoff as create
from memgpt.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
from memgpt.openai_tools import completions_with_backoff as create
def message_chatgpt(self, message: str):

View File

@ -10,74 +10,65 @@ from .llamacpp.api import get_llamacpp_completion
from .koboldcpp.api import get_koboldcpp_completion
from .ollama.api import get_ollama_completion
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
from .utils import DotDict
from .constants import DEFAULT_WRAPPER
from .utils import DotDict, get_available_wrappers
from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from ..errors import LocalLLMConnectionError, LocalLLMError
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
endpoint = os.getenv("OPENAI_API_BASE")
endpoint_type = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
DEBUG = False
# DEBUG = True
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
has_shown_warning = False
def get_chat_completion(
model, # no model, since the model is fixed to whatever you set in your own backend
model, # no model required (except for Ollama), since the model is fixed to whatever you set in your own backend
messages,
functions=None,
function_call="auto",
context_window=None,
# required
wrapper=None,
endpoint=None,
endpoint_type=None,
):
assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
global has_shown_warning
grammar_name = None
if HOST is None:
raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
if HOST_TYPE is None:
raise ValueError(f"The BACKEND_TYPE environment variable is not defined. Please set it in your environment.")
if function_call != "auto":
raise ValueError(f"function_call == {function_call} not supported (auto only)")
available_wrappers = get_available_wrappers()
if messages[0]["role"] == "system" and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip():
# Special case for if the call we're making is coming from the summarizer
llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
elif model == "airoboros-l2-70b-2.1":
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
elif model == "airoboros-l2-70b-2.1-grammar":
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
elif model == "dolphin-2.1-mistral-7b":
llm_wrapper = dolphin.Dolphin21MistralWrapper()
elif model == "dolphin-2.1-mistral-7b-grammar":
llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar":
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
else:
elif wrapper is None:
# Warn the user that we're using the fallback
if not has_shown_warning:
print(
f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --wrapper)"
)
has_shown_warning = True
if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]:
if endpoint_type in ["koboldcpp", "llamacpp", "webui"]:
# make the default to use grammar
llm_wrapper = DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
# grammar_name = "json"
grammar_name = "json_func_calls_with_inner_thoughts"
else:
llm_wrapper = DEFAULT_WRAPPER()
elif wrapper not in available_wrappers:
raise ValueError(f"Could not find requested wrapper '{wrapper} in available wrappers list:\n{available_wrappers}")
else:
llm_wrapper = available_wrappers[wrapper]
if "grammar" in wrapper:
grammar_name = "json_func_calls_with_inner_thoughts"
if grammar_name is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
if grammar_name is not None and endpoint_type not in ["koboldcpp", "llamacpp", "webui"]:
print(f"Warning: grammars are currently only supported when using llama.cpp as the MemGPT local LLM backend")
# First step: turn the message sequence into a prompt that the model expects
@ -91,25 +82,25 @@ def get_chat_completion(
)
try:
if HOST_TYPE == "webui":
result = get_webui_completion(prompt, context_window, grammar=grammar_name)
elif HOST_TYPE == "lmstudio":
result = get_lmstudio_completion(prompt, context_window)
elif HOST_TYPE == "llamacpp":
result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
elif HOST_TYPE == "koboldcpp":
result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
elif HOST_TYPE == "ollama":
result = get_ollama_completion(prompt, context_window)
if endpoint_type == "webui":
result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "lmstudio":
result = get_lmstudio_completion(endpoint, prompt, context_window)
elif endpoint_type == "llamacpp":
result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "koboldcpp":
result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "ollama":
result = get_ollama_completion(endpoint, model, prompt, context_window)
else:
raise LocalLLMError(
f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
)
except requests.exceptions.ConnectionError as e:
raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")
raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
if result is None or result == "":
raise LocalLLMError(f"Got back an empty response string from {HOST}")
raise LocalLLMError(f"Got back an empty response string from {endpoint}")
if DEBUG:
print(f"Raw LLM output:\n{result}")
@ -123,7 +114,7 @@ def get_chat_completion(
# unpack with response.choices[0].message.content
response = DotDict(
{
"model": None,
"model": model,
"choices": [
DotDict(
{

View File

@ -0,0 +1,14 @@
import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
DEFAULT_ENDPOINTS = {
"koboldcpp": "http://localhost:5001",
"llamacpp": "http://localhost:8080",
"lmstudio": "http://localhost:1234",
"ollama": "http://localhost:11434",
"webui": "http://localhost:5000",
}
DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1"

View File

@ -5,14 +5,12 @@ import requests
from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
# DEBUG = False
DEBUG = True
DEBUG = False
# DEBUG = True
def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
@ -27,13 +25,13 @@ def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMP
if grammar is not None:
request["grammar"] = load_grammar_file(grammar)
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
try:
# NOTE: llama.cpp server returns the following when it's out of context
# curl: (52) Empty reply from server
URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()

View File

@ -5,14 +5,12 @@ import requests
from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
LLAMACPP_API_SUFFIX = "/completion"
# DEBUG = False
DEBUG = True
DEBUG = False
# DEBUG = True
def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
@ -26,13 +24,13 @@ def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPL
if grammar is not None:
request["grammar"] = load_grammar_file(grammar)
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
try:
# NOTE: llama.cpp server returns the following when it's out of context
# curl: (52) Empty reply from server
URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()

View File

@ -5,14 +5,12 @@ import requests
from .settings import SIMPLE
from ..utils import count_tokens
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
DEBUG = False
def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat"):
def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
@ -25,19 +23,19 @@ def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat")
if api == "chat":
# Uses the ChatCompletions API style
# Seems to work better, probably because it's applying some extra settings under-the-hood?
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
message_structure = [{"role": "user", "content": prompt}]
request["messages"] = message_structure
elif api == "completions":
# Uses basic string completions (string in, string out)
# Does not work as well as ChatCompletions for some reason
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
request["prompt"] = prompt
else:
raise ValueError(api)
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
try:
response = requests.post(URI, json=request)

View File

@ -6,26 +6,25 @@ from .settings import SIMPLE
from ..utils import count_tokens
from ...errors import LocalLLMError
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
MODEL_NAME = os.getenv("OLLAMA_MODEL") # ollama API requires this in the request
OLLAMA_API_SUFFIX = "/api/generate"
DEBUG = False
def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None):
def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
"""See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
if MODEL_NAME is None:
raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")
if model is None:
raise LocalLLMError(
f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
)
# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["prompt"] = prompt
request["model"] = MODEL_NAME
request["model"] = model
request["options"]["num_ctx"] = context_window
# Set grammar
@ -33,11 +32,11 @@ def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None)
# request["grammar_string"] = load_grammar_file(grammar)
raise NotImplementedError(f"Ollama does not support grammars")
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
try:
URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()

View File

@ -1,6 +1,10 @@
import os
import tiktoken
import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
import memgpt.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
class DotDict(dict):
"""Allow dot access on properties similar to OpenAI response object"""
@ -37,3 +41,14 @@ def load_grammar_file(grammar):
def count_tokens(s: str, model: str = "gpt-4") -> int:
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(s))
def get_available_wrappers() -> dict:
return {
"airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
"airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False),
"dolphin-2.1-mistral-7b": dolphin.Dolphin21MistralWrapper(),
"dolphin-2.1-mistral-7b-grammar": dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False),
"zephyr-7B": zephyr.ZephyrMistralInnerMonologueWrapper(),
"zephyr-7B-grammar": zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False),
}

View File

@ -5,13 +5,11 @@ import requests
from .settings import SIMPLE
from ..utils import load_grammar_file, count_tokens
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
WEBUI_API_SUFFIX = "/api/v1/generate"
DEBUG = False
def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
"""See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
@ -26,11 +24,11 @@ def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
if grammar is not None:
request["grammar_string"] = load_grammar_file(grammar)
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
try:
URI = urljoin(HOST.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()

View File

@ -241,7 +241,7 @@ def main(
memgpt_persona = persona
if memgpt_persona is None:
memgpt_persona = (
personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT,
personas.GPT35_DEFAULT if (model is not None and "gpt-3.5" in model) else personas.DEFAULT,
None, # represents the personas dir in pymemgpt package
)
else:

View File

@ -2,10 +2,14 @@ import random
import os
import time
from .local_llm.chat_completion_proxy import get_chat_completion
import time
from typing import Callable, TypeVar
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
R = TypeVar("R")
import openai
@ -55,6 +59,7 @@ def retry_with_exponential_backoff(
return wrapper
# TODO: delete/ignore --legacy
@retry_with_exponential_backoff
def completions_with_backoff(**kwargs):
# Local model
@ -75,6 +80,38 @@ def completions_with_backoff(**kwargs):
return openai.ChatCompletion.create(**kwargs)
@retry_with_exponential_backoff
def chat_completion_with_backoff(agent_config, **kwargs):
from memgpt.utils import printd
from memgpt.config import AgentConfig, MemGPTConfig
printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}")
if agent_config.model_endpoint_type == "openai":
# openai
openai.api_base = agent_config.model_endpoint
return openai.ChatCompletion.create(**kwargs)
elif agent_config.model_endpoint_type == "azure":
# configure openai
config = MemGPTConfig.load() # load credentials (currently not stored in agent config)
openai.api_type = "azure"
openai.api_key = config.azure_key
openai.api_base = config.azure_endpoint
openai.api_version = config.azure_version
if config.azure_deployment is not None:
kwargs["deployment_id"] = config.azure_deployment
else:
kwargs["engine"] = MODEL_TO_AZURE_ENGINE[config.model]
del kwargs["model"]
return openai.ChatCompletion.create(**kwargs)
else: # local model
kwargs["context_window"] = agent_config.context_window # specify for open LLMs
kwargs["endpoint"] = agent_config.model_endpoint # specify for open LLMs
kwargs["endpoint_type"] = agent_config.model_endpoint_type # specify for open LLMs
kwargs["wrapper"] = agent_config.model_wrapper # specify for open LLMs
return get_chat_completion(**kwargs)
# TODO: deprecate
@retry_with_exponential_backoff
def create_embedding_with_backoff(**kwargs):
if using_azure():

View File

@ -57,5 +57,5 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
persona_notes=persona,
human_notes=human,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if "gpt-4" in model else False,
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
)

View File

@ -13,7 +13,7 @@ def test_configure_memgpt():
def test_save_load():
configure_memgpt()
# configure_memgpt() # rely on configure running first^
child = pexpect.spawn("memgpt run --agent test_save_load --first --strip_ui")
child.expect("Enter your message:", timeout=TIMEOUT)

View File

@ -1,12 +1,12 @@
# import tempfile
# import asyncio
import os
# import os
# import asyncio
from datasets import load_dataset
# from datasets import load_dataset
import memgpt
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
# import memgpt
# from memgpt.cli.cli_load import load_directory, load_database, load_webpage
# import memgpt.presets as presets
# import memgpt.personas.personas as personas

View File

@ -1,6 +1,7 @@
import os
import subprocess
import sys
import pytest
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
@ -15,9 +16,11 @@ from memgpt.config import MemGPTConfig, AgentConfig
import argparse
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
def test_postgres_openai():
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
if os.getenv("OPENAI_API_KEY") is None:
if not os.getenv("PGVECTOR_TEST_DB_URL"):
return # soft pass
if not os.getenv("OPENAI_API_KEY"):
return # soft pass
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
@ -54,14 +57,16 @@ def test_postgres_openai():
# print("...finished")
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
if not os.getenv("PGVECTOR_TEST_DB_URL"):
return
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
config = MemGPTConfig(
archival_storage_type="postgres",
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
embedding_model="local",
embedding_endpoint_type="local",
embedding_dim=384, # use HF model
)
print(config.config_path)

View File

@ -3,43 +3,55 @@ import pexpect
from .constants import TIMEOUT
def configure_memgpt(enable_openai=True, enable_azure=False):
def configure_memgpt_localllm():
child = pexpect.spawn("memgpt configure")
child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT)
if enable_openai:
child.sendline("y")
else:
child.sendline("n")
child.expect("Do you want to enable MemGPT with Azure?", timeout=TIMEOUT)
if enable_azure:
child.sendline("y")
else:
child.sendline("n")
child.expect("Select default inference endpoint:", timeout=TIMEOUT)
child.expect("Select LLM inference provider", timeout=TIMEOUT)
child.send("\x1b[B") # Send the down arrow key
child.send("\x1b[B") # Send the down arrow key
child.sendline()
child.expect("Select default embedding endpoint:", timeout=TIMEOUT)
child.expect("Select LLM backend", timeout=TIMEOUT)
child.sendline()
child.expect("Select default preset:", timeout=TIMEOUT)
child.expect("Enter default endpoint", timeout=TIMEOUT)
child.sendline()
child.expect("Select default model", timeout=TIMEOUT)
child.expect("Select default model wrapper", timeout=TIMEOUT)
child.sendline()
child.expect("Select default persona:", timeout=TIMEOUT)
child.expect("Select your model's context window", timeout=TIMEOUT)
child.sendline()
child.expect("Select default human:", timeout=TIMEOUT)
child.expect("Select embedding provider", timeout=TIMEOUT)
child.send("\x1b[B") # Send the down arrow key
child.send("\x1b[B") # Send the down arrow key
child.sendline()
child.expect("Select default preset", timeout=TIMEOUT)
child.sendline()
child.expect("Select default persona", timeout=TIMEOUT)
child.sendline()
child.expect("Select default human", timeout=TIMEOUT)
child.sendline()
child.expect("Select storage backend for archival data", timeout=TIMEOUT)
child.sendline()
child.expect("Select storage backend for archival data:", timeout=TIMEOUT)
child.sendline()
child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
child.close()
assert child.isalive() is False, "CLI should have terminated."
assert child.exitstatus == 0, "CLI did not exit cleanly."
def configure_memgpt(enable_openai=False, enable_azure=False):
if enable_openai:
raise NotImplementedError
elif enable_azure:
raise NotImplementedError
else:
configure_memgpt_localllm()