mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
Refactor config + determine LLM via config.model_endpoint_type
(#422)
* mark depricated API section * CLI bug fixes for azure * check azure before running * Update README.md * Update README.md * bug fix with persona loading * remove print * make errors for cli flags more clear * format * fix imports * fix imports * add prints * update lock * update config fields * cleanup config loading * commit * remove asserts * refactor configure * put into different functions * add embedding default * pass in config * fixes * allow overriding openai embedding endpoint * black * trying to patch tests (some circular import errors) * update flags and docs * patched support for local llms using endpoint and endpoint type passed via configs, not env vars * missing files * fix naming * fix import * fix two runtime errors * patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included * disable debug messages * made error message for failed load more informative * don't print dynamic linking function warning unless --debug * updated tests to work with new cli workflow (disabled openai config test for now) * added skips for tests when vars are missing * update bad arg * revise test to soft pass on empty string too * don't run configure twice * extend timeout (try to pass against nltk download) * update defaults * typo with endpoint type default * patch runtime errors for when model is None * catching another case of 'x in model' when model is None (preemptively) * allow overrides to local llm related config params * made model wrapper selection from a list vs raw input * update test for select instead of input * Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry * updated error messages to be more informative with links to readthedocs * add back gpt3.5-turbo --------- Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
parent
8fdc3a29da
commit
28514da5df
@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove
|
||||
* `--agent`: (str) Name of agent to create or to resume chatting with.
|
||||
* `--human`: (str) Name of the human to run the agent with.
|
||||
* `--persona`: (str) Name of agent persona to use.
|
||||
* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
|
||||
* `--model`: (str) LLM model to run (e.g. `gpt-4`, `dolphin_xxx`)
|
||||
* `--preset`: (str) MemGPT preset to run agent with.
|
||||
* `--first`: (str) Allow user to sent the first message.
|
||||
* `--debug`: (bool) Show debug logs (default=False)
|
||||
* `--no-verify`: (bool) Bypass message verification (default=False)
|
||||
* `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)
|
||||
|
||||
You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs:
|
||||
* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`)
|
||||
* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama)
|
||||
* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`)
|
||||
* `--context-window`: (int) Size of model context window (specific to model type)
|
||||
|
||||
#### Updating the config location
|
||||
You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`:
|
||||
You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`:
|
||||
```
|
||||
export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory
|
||||
```
|
||||
|
@ -9,6 +9,7 @@ from memgpt.config import AgentConfig
|
||||
from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from .memory import CoreMemory as Memory, summarize_messages
|
||||
from .openai_tools import completions_with_backoff as create
|
||||
from memgpt.openai_tools import chat_completion_with_backoff
|
||||
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
|
||||
from .constants import (
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
@ -73,7 +74,7 @@ def initialize_message_sequence(
|
||||
first_user_message = get_login_event() # event letting MemGPT know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
if "gpt-3.5" in model:
|
||||
if model is not None and "gpt-3.5" in model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
|
||||
@ -96,37 +97,6 @@ def initialize_message_sequence(
|
||||
return messages
|
||||
|
||||
|
||||
def get_ai_reply(
|
||||
model,
|
||||
message_sequence,
|
||||
functions,
|
||||
function_call="auto",
|
||||
context_window=None,
|
||||
):
|
||||
try:
|
||||
response = create(
|
||||
model=model,
|
||||
context_window=context_window,
|
||||
messages=message_sequence,
|
||||
functions=functions,
|
||||
function_call=function_call,
|
||||
)
|
||||
|
||||
# special case for 'length'
|
||||
if response.choices[0].finish_reason == "length":
|
||||
raise Exception("Finish reason was length (maximum context length)")
|
||||
|
||||
# catches for soft errors
|
||||
if response.choices[0].finish_reason not in ["stop", "function_call"]:
|
||||
raise Exception(f"API call finish with bad finish reason: {response}")
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class Agent(object):
|
||||
def __init__(
|
||||
self,
|
||||
@ -310,7 +280,7 @@ class Agent(object):
|
||||
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
|
||||
if not json_files:
|
||||
print(f"/load error: no .json checkpoint files found")
|
||||
raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
|
||||
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}")
|
||||
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
filename = max(json_files, key=os.path.getmtime)
|
||||
@ -360,7 +330,7 @@ class Agent(object):
|
||||
|
||||
# NOTE to handle old configs, instead of erroring here let's just warn
|
||||
# raise ValueError(error_message)
|
||||
print(error_message)
|
||||
printd(error_message)
|
||||
linked_function_set[f_name] = linked_function
|
||||
|
||||
messages = state["messages"]
|
||||
@ -602,8 +572,7 @@ class Agent(object):
|
||||
printd(f"This is the first message. Running extra verifier on AI response.")
|
||||
counter = 0
|
||||
while True:
|
||||
response = get_ai_reply(
|
||||
model=self.model,
|
||||
response = self.get_ai_reply(
|
||||
message_sequence=input_message_sequence,
|
||||
functions=self.functions,
|
||||
context_window=None if self.config.context_window is None else int(self.config.context_window),
|
||||
@ -616,8 +585,7 @@ class Agent(object):
|
||||
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
|
||||
|
||||
else:
|
||||
response = get_ai_reply(
|
||||
model=self.model,
|
||||
response = self.get_ai_reply(
|
||||
message_sequence=input_message_sequence,
|
||||
functions=self.functions,
|
||||
context_window=None if self.config.context_window is None else int(self.config.context_window),
|
||||
@ -785,3 +753,55 @@ class Agent(object):
|
||||
# Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
|
||||
elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
|
||||
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60
|
||||
|
||||
def get_ai_reply(
|
||||
self,
|
||||
message_sequence,
|
||||
function_call="auto",
|
||||
):
|
||||
"""Get response from LLM API"""
|
||||
|
||||
# TODO: Legacy code - delete
|
||||
if self.config is None:
|
||||
try:
|
||||
response = create(
|
||||
model=self.model,
|
||||
context_window=self.context_window,
|
||||
messages=message_sequence,
|
||||
functions=self.functions,
|
||||
function_call=function_call,
|
||||
)
|
||||
|
||||
# special case for 'length'
|
||||
if response.choices[0].finish_reason == "length":
|
||||
raise Exception("Finish reason was length (maximum context length)")
|
||||
|
||||
# catches for soft errors
|
||||
if response.choices[0].finish_reason not in ["stop", "function_call"]:
|
||||
raise Exception(f"API call finish with bad finish reason: {response}")
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
response = chat_completion_with_backoff(
|
||||
agent_config=self.config,
|
||||
model=self.model, # TODO: remove (is redundant)
|
||||
messages=message_sequence,
|
||||
functions=self.functions,
|
||||
function_call=function_call,
|
||||
)
|
||||
# special case for 'length'
|
||||
if response.choices[0].finish_reason == "length":
|
||||
raise Exception("Finish reason was length (maximum context length)")
|
||||
|
||||
# catches for soft errors
|
||||
if response.choices[0].finish_reason not in ["stop", "function_call"]:
|
||||
raise Exception(f"API call finish with bad finish reason: {response}")
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -1,4 +1,5 @@
|
||||
import typer
|
||||
import json
|
||||
import sys
|
||||
import io
|
||||
import logging
|
||||
@ -35,16 +36,21 @@ def run(
|
||||
persona: str = typer.Option(None, help="Specify persona"),
|
||||
agent: str = typer.Option(None, help="Specify agent save file"),
|
||||
human: str = typer.Option(None, help="Specify human"),
|
||||
model: str = typer.Option(None, help="Specify the LLM model"),
|
||||
preset: str = typer.Option(None, help="Specify preset"),
|
||||
# model flags
|
||||
model: str = typer.Option(None, help="Specify the LLM model"),
|
||||
model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
|
||||
model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
|
||||
model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
|
||||
context_window: int = typer.Option(
|
||||
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
|
||||
),
|
||||
# other
|
||||
first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
|
||||
strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
|
||||
debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
|
||||
no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
|
||||
yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
|
||||
context_window: int = typer.Option(
|
||||
None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
|
||||
),
|
||||
):
|
||||
"""Start chatting with an MemGPT agent
|
||||
|
||||
@ -99,11 +105,6 @@ def run(
|
||||
set_global_service_context(service_context)
|
||||
sys.stdout = original_stdout
|
||||
|
||||
# overwrite the context_window if specified
|
||||
if context_window is not None and int(context_window) != int(config.context_window):
|
||||
typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
|
||||
config.context_window = str(context_window)
|
||||
|
||||
# create agent config
|
||||
if agent and AgentConfig.exists(agent): # use existing agent
|
||||
typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
|
||||
@ -121,10 +122,34 @@ def run(
|
||||
typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
|
||||
agent_config.human = human
|
||||
# raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
|
||||
|
||||
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
|
||||
if model and model != agent_config.model:
|
||||
typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
|
||||
agent_config.model = model
|
||||
# raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
|
||||
if context_window is not None and int(context_window) != agent_config.context_window:
|
||||
typer.secho(
|
||||
f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
|
||||
)
|
||||
agent_config.context_window = context_window
|
||||
if model_wrapper and model_wrapper != agent_config.model_wrapper:
|
||||
typer.secho(
|
||||
f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
|
||||
)
|
||||
agent_config.model_wrapper = model_wrapper
|
||||
if model_endpoint and model_endpoint != agent_config.model_endpoint:
|
||||
typer.secho(
|
||||
f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
|
||||
)
|
||||
agent_config.model_endpoint = model_endpoint
|
||||
if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
|
||||
typer.secho(
|
||||
f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
|
||||
fg=typer.colors.YELLOW,
|
||||
)
|
||||
agent_config.model_endpoint_type = model_endpoint_type
|
||||
|
||||
# Update the agent config with any overrides
|
||||
agent_config.save()
|
||||
|
||||
# load existing agent
|
||||
@ -133,17 +158,17 @@ def run(
|
||||
# create new agent config: override defaults with args if provided
|
||||
typer.secho("Creating new agent...", fg=typer.colors.GREEN)
|
||||
agent_config = AgentConfig(
|
||||
name=agent if agent else None,
|
||||
persona=persona if persona else config.default_persona,
|
||||
human=human if human else config.default_human,
|
||||
model=model if model else config.model,
|
||||
context_window=context_window if context_window else config.context_window,
|
||||
preset=preset if preset else config.preset,
|
||||
name=agent,
|
||||
persona=persona,
|
||||
human=human,
|
||||
preset=preset,
|
||||
model=model,
|
||||
model_wrapper=model_wrapper,
|
||||
model_endpoint_type=model_endpoint_type,
|
||||
model_endpoint=model_endpoint,
|
||||
context_window=context_window,
|
||||
)
|
||||
|
||||
## attach data source to agent
|
||||
# agent_config.attach_data_source(data_source)
|
||||
|
||||
# TODO: allow configrable state manager (only local is supported right now)
|
||||
persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill
|
||||
|
||||
@ -162,6 +187,9 @@ def run(
|
||||
persistence_manager,
|
||||
)
|
||||
|
||||
# pretty print agent config
|
||||
printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))
|
||||
|
||||
# start event loop
|
||||
from memgpt.main import run_agent_loop
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import builtins
|
||||
import questionary
|
||||
import openai
|
||||
from prettytable import PrettyTable
|
||||
@ -11,126 +12,118 @@ from memgpt import utils
|
||||
|
||||
import memgpt.humans.humans as humans
|
||||
import memgpt.personas.personas as personas
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.config import MemGPTConfig, AgentConfig, Config
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
from memgpt.constants import LLM_MAX_TOKENS
|
||||
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
|
||||
from memgpt.local_llm.utils import get_available_wrappers
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def configure():
|
||||
"""Updates default MemGPT configurations"""
|
||||
|
||||
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
|
||||
|
||||
MemGPTConfig.create_config_dir()
|
||||
|
||||
# Will pre-populate with defaults, or what the user previously set
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# openai credentials
|
||||
use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask()
|
||||
if use_openai:
|
||||
# search for key in enviornment
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_key:
|
||||
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
|
||||
# TODO: eventually stop relying on env variables and pass in keys explicitly
|
||||
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
|
||||
|
||||
# azure credentials
|
||||
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
|
||||
use_azure_deployment_ids = False
|
||||
if use_azure:
|
||||
# search for key in enviornment
|
||||
def get_azure_credentials():
|
||||
azure_key = os.getenv("AZURE_OPENAI_KEY")
|
||||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
azure_version = os.getenv("AZURE_OPENAI_VERSION")
|
||||
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
|
||||
return azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment
|
||||
|
||||
if all([azure_key, azure_endpoint, azure_version]):
|
||||
print(f"Using Microsoft endpoint {azure_endpoint}.")
|
||||
if all([azure_deployment, azure_embedding_deployment]):
|
||||
print(f"Using deployment id {azure_deployment}")
|
||||
use_azure_deployment_ids = True
|
||||
|
||||
# configure openai
|
||||
openai.api_type = "azure"
|
||||
openai.api_key = azure_key
|
||||
openai.api_base = azure_endpoint
|
||||
openai.api_version = azure_version
|
||||
else:
|
||||
print("Missing enviornment variables for Azure. Please set then run `memgpt configure` again.")
|
||||
# TODO: allow for manual setting
|
||||
use_azure = False
|
||||
def get_openai_credentials():
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
return openai_key
|
||||
|
||||
# TODO: configure local model
|
||||
|
||||
# configure provider
|
||||
model_endpoint_options = []
|
||||
if os.getenv("OPENAI_API_BASE") is not None:
|
||||
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
|
||||
if use_openai:
|
||||
model_endpoint_options += ["openai"]
|
||||
if use_azure:
|
||||
model_endpoint_options += ["azure"]
|
||||
assert (
|
||||
len(model_endpoint_options) > 0
|
||||
), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server."
|
||||
valid_default_model = config.model_endpoint in model_endpoint_options
|
||||
default_endpoint = questionary.select(
|
||||
"Select default inference endpoint:",
|
||||
model_endpoint_options,
|
||||
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
|
||||
def configure_llm_endpoint(config: MemGPTConfig):
|
||||
# configure model endpoint
|
||||
model_endpoint_type, model_endpoint = None, None
|
||||
|
||||
# get default
|
||||
default_model_endpoint_type = config.model_endpoint_type
|
||||
if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]: # local model
|
||||
default_model_endpoint_type = "local"
|
||||
|
||||
provider = questionary.select(
|
||||
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
|
||||
).ask()
|
||||
|
||||
# configure embedding provider
|
||||
embedding_endpoint_options = []
|
||||
if use_azure:
|
||||
embedding_endpoint_options += ["azure"]
|
||||
if use_openai:
|
||||
embedding_endpoint_options += ["openai"]
|
||||
embedding_endpoint_options += ["local"]
|
||||
valid_default_embedding = config.embedding_model in embedding_endpoint_options
|
||||
# determine the default selection in a smart way
|
||||
if "openai" in embedding_endpoint_options and default_endpoint == "openai":
|
||||
# openai llm -> openai embeddings
|
||||
default_embedding_endpoint_default = "openai"
|
||||
elif default_endpoint not in ["openai", "azure"]: # is local
|
||||
# local llm -> local embeddings
|
||||
default_embedding_endpoint_default = "local"
|
||||
# set: model_endpoint_type, model_endpoint
|
||||
if provider == "openai":
|
||||
model_endpoint_type = "openai"
|
||||
model_endpoint = "https://api.openai.com/v1"
|
||||
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
|
||||
provider = "openai"
|
||||
elif provider == "azure":
|
||||
model_endpoint_type = "azure"
|
||||
_, model_endpoint, _, _, _ = get_azure_credentials()
|
||||
else: # local models
|
||||
backend_options = ["webui", "llamacpp", "koboldcpp", "ollama", "lmstudio", "openai"]
|
||||
default_model_endpoint_type = None
|
||||
if config.model_endpoint_type in backend_options:
|
||||
# set from previous config
|
||||
default_model_endpoint_type = config.model_endpoint_type
|
||||
else:
|
||||
default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1]
|
||||
default_embedding_endpoint = questionary.select(
|
||||
"Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default
|
||||
# set form env variable (ok if none)
|
||||
default_model_endpoint_type = os.getenv("BACKEND_TYPE")
|
||||
model_endpoint_type = questionary.select(
|
||||
"Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):",
|
||||
backend_options,
|
||||
default=default_model_endpoint_type,
|
||||
).ask()
|
||||
|
||||
# configure embedding dimentions
|
||||
default_embedding_dim = config.embedding_dim
|
||||
if default_embedding_endpoint == "local":
|
||||
# HF model uses lower dimentionality
|
||||
default_embedding_dim = 384
|
||||
# set default endpoint
|
||||
# if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
|
||||
default_model_endpoint = os.getenv("OPENAI_API_BASE")
|
||||
# if OPENAI_API_BASE is not set, try to pull a default IP+port format from a hardcoded set
|
||||
if default_model_endpoint is None:
|
||||
if model_endpoint_type in DEFAULT_ENDPOINTS:
|
||||
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||
else:
|
||||
# default_model_endpoint = None
|
||||
model_endpoint = None
|
||||
while not model_endpoint:
|
||||
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
if "http://" not in model_endpoint and "https://" not in model_endpoint:
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = None
|
||||
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
|
||||
|
||||
# configure preset
|
||||
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()
|
||||
return model_endpoint_type, model_endpoint
|
||||
|
||||
# default model
|
||||
if use_openai or use_azure:
|
||||
model_options = []
|
||||
if use_openai:
|
||||
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
|
||||
|
||||
def configure_model(config: MemGPTConfig, model_endpoint_type: str):
|
||||
# set: model, model_wrapper
|
||||
model, model_wrapper = None, None
|
||||
if model_endpoint_type == "openai" or model_endpoint_type == "azure":
|
||||
model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
|
||||
# TODO: select
|
||||
valid_model = config.model in model_options
|
||||
default_model = questionary.select(
|
||||
model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
|
||||
).ask()
|
||||
else:
|
||||
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint
|
||||
else: # local models
|
||||
# ollama also needs model type
|
||||
if model_endpoint_type == "ollama":
|
||||
default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL
|
||||
model = questionary.text(
|
||||
"Enter default model name (required for Ollama, see: https://memgpt.readthedocs.io/en/latest/ollama):",
|
||||
default=default_model,
|
||||
).ask()
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
# get the max tokens (context window) for the model
|
||||
if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS:
|
||||
# model wrapper
|
||||
available_model_wrappers = builtins.list(get_available_wrappers().keys())
|
||||
model_wrapper = questionary.select(
|
||||
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
|
||||
choices=available_model_wrappers,
|
||||
default=DEFAULT_WRAPPER_NAME,
|
||||
).ask()
|
||||
|
||||
# set: context_window
|
||||
if str(model) not in LLM_MAX_TOKENS:
|
||||
# Ask the user to specify the context length
|
||||
context_length_options = [
|
||||
str(2**12), # 4096
|
||||
@ -140,46 +133,80 @@ def configure():
|
||||
str(2**18), # 262144
|
||||
"custom", # enter yourself
|
||||
]
|
||||
default_model_context_window = questionary.select(
|
||||
context_window = questionary.select(
|
||||
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
|
||||
choices=context_length_options,
|
||||
default=str(LLM_MAX_TOKENS["DEFAULT"]),
|
||||
).ask()
|
||||
|
||||
# If custom, ask for input
|
||||
if default_model_context_window == "custom":
|
||||
if context_window == "custom":
|
||||
while True:
|
||||
default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask()
|
||||
context_window = questionary.text("Enter context window (e.g. 8192)").ask()
|
||||
try:
|
||||
default_model_context_window = int(default_model_context_window)
|
||||
context_window = int(context_window)
|
||||
break
|
||||
except ValueError:
|
||||
print(f"Context window must be a valid integer")
|
||||
else:
|
||||
default_model_context_window = int(default_model_context_window)
|
||||
context_window = int(context_window)
|
||||
else:
|
||||
# Pull the context length from the models
|
||||
default_model_context_window = LLM_MAX_TOKENS[default_model]
|
||||
context_window = LLM_MAX_TOKENS[model]
|
||||
return model, model_wrapper, context_window
|
||||
|
||||
# defaults
|
||||
|
||||
def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
# configure embedding endpoint
|
||||
|
||||
default_embedding_endpoint_type = config.embedding_endpoint_type
|
||||
if config.embedding_endpoint_type is not None and config.embedding_endpoint_type not in ["openai", "azure"]: # local model
|
||||
default_embedding_endpoint_type = "local"
|
||||
|
||||
embedding_endpoint_type, embedding_endpoint, embedding_dim = None, None, None
|
||||
embedding_provider = questionary.select(
|
||||
"Select embedding provider:", choices=["openai", "azure", "local"], default=default_embedding_endpoint_type
|
||||
).ask()
|
||||
if embedding_provider == "openai":
|
||||
embedding_endpoint_type = "openai"
|
||||
embedding_endpoint = "https://api.openai.com/v1"
|
||||
embedding_dim = 1536
|
||||
elif embedding_provider == "azure":
|
||||
embedding_endpoint_type = "azure"
|
||||
_, _, _, _, embedding_endpoint = get_azure_credentials()
|
||||
embedding_dim = 1536
|
||||
else: # local models
|
||||
embedding_endpoint_type = "local"
|
||||
embedding_endpoint = None
|
||||
embedding_dim = 384
|
||||
return embedding_endpoint_type, embedding_endpoint, embedding_dim
|
||||
|
||||
|
||||
def configure_cli(config: MemGPTConfig):
|
||||
# set: preset, default_persona, default_human, default_agent``
|
||||
from memgpt.presets.presets import preset_options
|
||||
|
||||
# preset
|
||||
default_preset = config.preset if config.preset and config.preset in preset_options else None
|
||||
preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()
|
||||
|
||||
# persona
|
||||
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
|
||||
# print(personas)
|
||||
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
|
||||
default_persona = config.persona if config.persona and config.persona in personas else None
|
||||
persona = questionary.select("Select default persona:", personas, default=default_persona).ask()
|
||||
|
||||
# human
|
||||
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
|
||||
# print(humans)
|
||||
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()
|
||||
default_human = config.human if config.human and config.human in humans else None
|
||||
human = questionary.select("Select default human:", humans, default=default_human).ask()
|
||||
|
||||
# TODO: figure out if we should set a default agent or not
|
||||
default_agent = None
|
||||
# agents = [os.path.basename(f).replace(".json", "") for f in utils.list_agent_config_files()]
|
||||
# if len(agents) > 0: # agents have been created
|
||||
# default_agent = questionary.select(
|
||||
# "Select default agent:",
|
||||
# agents
|
||||
# ).ask()
|
||||
# else:
|
||||
# default_agent = None
|
||||
agent = None
|
||||
|
||||
return preset, persona, human, agent
|
||||
|
||||
|
||||
def configure_archival_storage(config: MemGPTConfig):
|
||||
# Configure archival storage backend
|
||||
archival_storage_options = ["local", "postgres"]
|
||||
archival_storage_type = questionary.select(
|
||||
@ -191,25 +218,65 @@ def configure():
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||
default=config.archival_storage_uri if config.archival_storage_uri else "",
|
||||
).ask()
|
||||
return archival_storage_type, archival_storage_uri
|
||||
|
||||
# TODO: allow configuring embedding model
|
||||
|
||||
@app.command()
|
||||
def configure():
|
||||
"""Updates default MemGPT configurations"""
|
||||
|
||||
MemGPTConfig.create_config_dir()
|
||||
|
||||
# Will pre-populate with defaults, or what the user previously set
|
||||
config = MemGPTConfig.load()
|
||||
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
||||
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
|
||||
embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config)
|
||||
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
||||
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
|
||||
|
||||
# check credentials
|
||||
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
|
||||
openai_key = get_openai_credentials()
|
||||
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
|
||||
if all([azure_key, azure_endpoint, azure_version]):
|
||||
print(f"Using Microsoft endpoint {azure_endpoint}.")
|
||||
if all([azure_deployment, azure_embedding_deployment]):
|
||||
print(f"Using deployment id {azure_deployment}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
|
||||
)
|
||||
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
|
||||
if not openai_key:
|
||||
raise ValueError(
|
||||
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
|
||||
)
|
||||
|
||||
config = MemGPTConfig(
|
||||
model=default_model,
|
||||
context_window=default_model_context_window,
|
||||
# model configs
|
||||
model=model,
|
||||
model_endpoint=model_endpoint,
|
||||
model_endpoint_type=model_endpoint_type,
|
||||
model_wrapper=model_wrapper,
|
||||
context_window=context_window,
|
||||
# embedding configs
|
||||
embedding_endpoint_type=embedding_endpoint_type,
|
||||
embedding_endpoint=embedding_endpoint,
|
||||
embedding_dim=embedding_dim,
|
||||
# cli configs
|
||||
preset=default_preset,
|
||||
model_endpoint=default_endpoint,
|
||||
embedding_model=default_embedding_endpoint,
|
||||
embedding_dim=default_embedding_dim,
|
||||
default_persona=default_persona,
|
||||
default_human=default_human,
|
||||
default_agent=default_agent,
|
||||
openai_key=openai_key if use_openai else None,
|
||||
azure_key=azure_key if use_azure else None,
|
||||
azure_endpoint=azure_endpoint if use_azure else None,
|
||||
azure_version=azure_version if use_azure else None,
|
||||
azure_deployment=azure_deployment if use_azure_deployment_ids else None,
|
||||
azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None,
|
||||
persona=default_persona,
|
||||
human=default_human,
|
||||
agent=default_agent,
|
||||
# credentials
|
||||
openai_key=openai_key,
|
||||
azure_key=azure_key,
|
||||
azure_endpoint=azure_endpoint,
|
||||
azure_version=azure_version,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_embedding_deployment=azure_embedding_deployment,
|
||||
# storage
|
||||
archival_storage_type=archival_storage_type,
|
||||
archival_storage_uri=archival_storage_uri,
|
||||
)
|
||||
|
253
memgpt/config.py
253
memgpt/config.py
@ -16,6 +16,7 @@ from colorama import Fore, Style
|
||||
|
||||
from typing import List, Type
|
||||
|
||||
import memgpt
|
||||
import memgpt.utils as utils
|
||||
from memgpt.interface import CLIInterface as interface
|
||||
from memgpt.personas.personas import get_persona_text
|
||||
@ -40,6 +41,24 @@ model_choices = [
|
||||
]
|
||||
|
||||
|
||||
# helper functions for writing to configs
|
||||
def get_field(config, section, field):
|
||||
if section not in config:
|
||||
return None
|
||||
if config.has_option(section, field):
|
||||
return config.get(section, field)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def set_field(config, section, field, value):
|
||||
if value is None: # cannot write None
|
||||
return
|
||||
if section not in config: # create section
|
||||
config.add_section(section)
|
||||
config.set(section, field, value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemGPTConfig:
|
||||
config_path: str = os.path.join(MEMGPT_DIR, "config")
|
||||
@ -49,9 +68,10 @@ class MemGPTConfig:
|
||||
preset: str = DEFAULT_PRESET
|
||||
|
||||
# model parameters
|
||||
# provider: str = "openai" # openai, azure, local (TODO)
|
||||
model_endpoint: str = "openai"
|
||||
model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local
|
||||
model: str = None
|
||||
model_endpoint_type: str = None
|
||||
model_endpoint: str = None # localhost:8000
|
||||
model_wrapper: str = None
|
||||
context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
|
||||
|
||||
# model parameters: openai
|
||||
@ -65,12 +85,13 @@ class MemGPTConfig:
|
||||
azure_embedding_deployment: str = None
|
||||
|
||||
# persona parameters
|
||||
default_persona: str = personas.DEFAULT
|
||||
default_human: str = humans.DEFAULT
|
||||
default_agent: str = None
|
||||
persona: str = personas.DEFAULT
|
||||
human: str = humans.DEFAULT
|
||||
agent: str = None
|
||||
|
||||
# embedding parameters
|
||||
embedding_model: str = "openai"
|
||||
embedding_endpoint_type: str = "openai" # openai, azure, local
|
||||
embedding_endpoint: str = None
|
||||
embedding_dim: int = 1536
|
||||
embedding_chunk_size: int = 300 # number of tokens
|
||||
|
||||
@ -89,6 +110,12 @@ class MemGPTConfig:
|
||||
persistence_manager_save_file: str = None # local file
|
||||
persistence_manager_uri: str = None # db URI
|
||||
|
||||
def __post_init__(self):
|
||||
# ensure types
|
||||
self.embedding_chunk_size = int(self.embedding_chunk_size)
|
||||
self.embedding_dim = int(self.embedding_dim)
|
||||
self.context_window = int(self.context_window)
|
||||
|
||||
@staticmethod
|
||||
def generate_uuid() -> str:
|
||||
return uuid.UUID(int=uuid.getnode()).hex
|
||||
@ -104,72 +131,38 @@ class MemGPTConfig:
|
||||
config_path = MemGPTConfig.config_path
|
||||
|
||||
if os.path.exists(config_path):
|
||||
# read existing config
|
||||
config.read(config_path)
|
||||
config_dict = {
|
||||
"model": get_field(config, "model", "model"),
|
||||
"model_endpoint": get_field(config, "model", "model_endpoint"),
|
||||
"model_endpoint_type": get_field(config, "model", "model_endpoint_type"),
|
||||
"model_wrapper": get_field(config, "model", "model_wrapper"),
|
||||
"context_window": get_field(config, "model", "context_window"),
|
||||
"preset": get_field(config, "defaults", "preset"),
|
||||
"persona": get_field(config, "defaults", "persona"),
|
||||
"human": get_field(config, "defaults", "human"),
|
||||
"agent": get_field(config, "defaults", "agent"),
|
||||
"openai_key": get_field(config, "openai", "key"),
|
||||
"azure_key": get_field(config, "azure", "key"),
|
||||
"azure_endpoint": get_field(config, "azure", "endpoint"),
|
||||
"azure_version": get_field(config, "azure", "version"),
|
||||
"azure_deployment": get_field(config, "azure", "deployment"),
|
||||
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
|
||||
"embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"),
|
||||
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
|
||||
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
|
||||
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
|
||||
"archival_storage_type": get_field(config, "archival_storage", "type"),
|
||||
"archival_storage_path": get_field(config, "archival_storage", "path"),
|
||||
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
|
||||
"anon_clientid": get_field(config, "client", "anon_clientid"),
|
||||
"config_path": config_path,
|
||||
}
|
||||
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
||||
return cls(**config_dict)
|
||||
|
||||
# read config values
|
||||
model = config.get("defaults", "model")
|
||||
context_window = (
|
||||
int(config.get("defaults", "context_window"))
|
||||
if config.has_option("defaults", "context_window")
|
||||
else LLM_MAX_TOKENS["DEFAULT"]
|
||||
)
|
||||
preset = config.get("defaults", "preset")
|
||||
model_endpoint = config.get("defaults", "model_endpoint")
|
||||
default_persona = config.get("defaults", "persona")
|
||||
default_human = config.get("defaults", "human")
|
||||
default_agent = config.get("defaults", "agent") if config.has_option("defaults", "agent") else None
|
||||
|
||||
openai_key, openai_model = None, None
|
||||
if "openai" in config:
|
||||
openai_key = config.get("openai", "key")
|
||||
|
||||
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = None, None, None, None, None
|
||||
if "azure" in config:
|
||||
azure_key = config.get("azure", "key")
|
||||
azure_endpoint = config.get("azure", "endpoint")
|
||||
azure_version = config.get("azure", "version")
|
||||
azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
|
||||
azure_embedding_deployment = (
|
||||
config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
|
||||
)
|
||||
|
||||
embedding_model = config.get("embedding", "model")
|
||||
embedding_dim = config.getint("embedding", "dim")
|
||||
embedding_chunk_size = config.getint("embedding", "chunk_size")
|
||||
|
||||
# archival storage
|
||||
archival_storage_type, archival_storage_path, archival_storage_uri = "local", None, None
|
||||
if "archival_storage" in config:
|
||||
archival_storage_type = config.get("archival_storage", "type")
|
||||
archival_storage_path = config.get("archival_storage", "path") if config.has_option("archival_storage", "path") else None
|
||||
archival_storage_uri = config.get("archival_storage", "uri") if config.has_option("archival_storage", "uri") else None
|
||||
|
||||
anon_clientid = config.get("client", "anon_clientid")
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
context_window=context_window,
|
||||
preset=preset,
|
||||
model_endpoint=model_endpoint,
|
||||
default_persona=default_persona,
|
||||
default_human=default_human,
|
||||
default_agent=default_agent,
|
||||
openai_key=openai_key,
|
||||
azure_key=azure_key,
|
||||
azure_endpoint=azure_endpoint,
|
||||
azure_version=azure_version,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_embedding_deployment=azure_embedding_deployment,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_chunk_size=embedding_chunk_size,
|
||||
archival_storage_type=archival_storage_type,
|
||||
archival_storage_path=archival_storage_path,
|
||||
archival_storage_uri=archival_storage_uri,
|
||||
anon_clientid=anon_clientid,
|
||||
config_path=config_path,
|
||||
)
|
||||
|
||||
# create new config
|
||||
anon_clientid = MemGPTConfig.generate_uuid()
|
||||
config = cls(anon_clientid=anon_clientid, config_path=config_path)
|
||||
config.save() # save updated config
|
||||
@ -179,51 +172,43 @@ class MemGPTConfig:
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
# CLI defaults
|
||||
config.add_section("defaults")
|
||||
config.set("defaults", "model", self.model)
|
||||
config.set("defaults", "context_window", str(self.context_window))
|
||||
config.set("defaults", "preset", self.preset)
|
||||
assert self.model_endpoint is not None, "Endpoint must be set"
|
||||
config.set("defaults", "model_endpoint", self.model_endpoint)
|
||||
config.set("defaults", "persona", self.default_persona)
|
||||
config.set("defaults", "human", self.default_human)
|
||||
if self.default_agent:
|
||||
config.set("defaults", "agent", self.default_agent)
|
||||
set_field(config, "defaults", "preset", self.preset)
|
||||
set_field(config, "defaults", "persona", self.persona)
|
||||
set_field(config, "defaults", "human", self.human)
|
||||
set_field(config, "defaults", "agent", self.agent)
|
||||
|
||||
# security credentials
|
||||
if self.openai_key:
|
||||
config.add_section("openai")
|
||||
config.set("openai", "key", self.openai_key)
|
||||
# model defaults
|
||||
set_field(config, "model", "model", self.model)
|
||||
set_field(config, "model", "model_endpoint", self.model_endpoint)
|
||||
set_field(config, "model", "model_endpoint_type", self.model_endpoint_type)
|
||||
set_field(config, "model", "model_wrapper", self.model_wrapper)
|
||||
set_field(config, "model", "context_window", str(self.context_window))
|
||||
|
||||
if self.azure_key:
|
||||
config.add_section("azure")
|
||||
config.set("azure", "key", self.azure_key)
|
||||
config.set("azure", "endpoint", self.azure_endpoint)
|
||||
config.set("azure", "version", self.azure_version)
|
||||
if self.azure_deployment:
|
||||
config.set("azure", "deployment", self.azure_deployment)
|
||||
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
|
||||
# security credentials: openai
|
||||
set_field(config, "openai", "key", self.openai_key)
|
||||
|
||||
# security credentials: azure
|
||||
set_field(config, "azure", "key", self.azure_key)
|
||||
set_field(config, "azure", "endpoint", self.azure_endpoint)
|
||||
set_field(config, "azure", "version", self.azure_version)
|
||||
set_field(config, "azure", "deployment", self.azure_deployment)
|
||||
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
|
||||
|
||||
# embeddings
|
||||
config.add_section("embedding")
|
||||
config.set("embedding", "model", self.embedding_model)
|
||||
config.set("embedding", "dim", str(self.embedding_dim))
|
||||
config.set("embedding", "chunk_size", str(self.embedding_chunk_size))
|
||||
set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type)
|
||||
set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint)
|
||||
set_field(config, "embedding", "embedding_dim", str(self.embedding_dim))
|
||||
set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size))
|
||||
|
||||
# archival storage
|
||||
config.add_section("archival_storage")
|
||||
# print("archival storage", self.archival_storage_type)
|
||||
config.set("archival_storage", "type", self.archival_storage_type)
|
||||
if self.archival_storage_path:
|
||||
config.set("archival_storage", "path", self.archival_storage_path)
|
||||
if self.archival_storage_uri:
|
||||
config.set("archival_storage", "uri", self.archival_storage_uri)
|
||||
set_field(config, "archival_storage", "type", self.archival_storage_type)
|
||||
set_field(config, "archival_storage", "path", self.archival_storage_path)
|
||||
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
|
||||
|
||||
# client
|
||||
config.add_section("client")
|
||||
if not self.anon_clientid:
|
||||
self.anon_clientid = self.generate_uuid()
|
||||
config.set("client", "anon_clientid", self.anon_clientid)
|
||||
set_field(config, "client", "anon_clientid", self.anon_clientid)
|
||||
|
||||
if not os.path.exists(MEMGPT_DIR):
|
||||
os.makedirs(MEMGPT_DIR, exist_ok=True)
|
||||
@ -262,32 +247,54 @@ class AgentConfig:
|
||||
self,
|
||||
persona,
|
||||
human,
|
||||
# model info
|
||||
model,
|
||||
model_endpoint_type=None,
|
||||
model_endpoint=None,
|
||||
model_wrapper=None,
|
||||
context_window=None,
|
||||
preset=DEFAULT_PRESET,
|
||||
name=None,
|
||||
data_sources=[],
|
||||
# embedding info
|
||||
embedding_endpoint_type=None,
|
||||
embedding_endpoint=None,
|
||||
embedding_dim=None,
|
||||
embedding_chunk_size=None,
|
||||
# other
|
||||
preset=None,
|
||||
data_sources=None,
|
||||
# agent info
|
||||
agent_config_path=None,
|
||||
name=None,
|
||||
create_time=None,
|
||||
data_source=None,
|
||||
memgpt_version=None,
|
||||
):
|
||||
if name is None:
|
||||
self.name = f"agent_{self.generate_agent_id()}"
|
||||
else:
|
||||
self.name = name
|
||||
self.persona = persona
|
||||
self.human = human
|
||||
self.model = model
|
||||
self.context_window = context_window
|
||||
self.preset = preset
|
||||
self.data_sources = data_sources
|
||||
self.create_time = create_time if create_time is not None else utils.get_local_time()
|
||||
self.data_source = None # deprecated
|
||||
|
||||
if context_window is None:
|
||||
self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
|
||||
config = MemGPTConfig.load() # get default values
|
||||
self.persona = config.persona if persona is None else persona
|
||||
self.human = config.human if human is None else human
|
||||
self.preset = config.preset if preset is None else preset
|
||||
self.context_window = config.context_window if context_window is None else context_window
|
||||
self.model = config.model if model is None else model
|
||||
self.model_endpoint_type = config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type
|
||||
self.model_endpoint = config.model_endpoint if model_endpoint is None else model_endpoint
|
||||
self.model_wrapper = config.model_wrapper if model_wrapper is None else model_wrapper
|
||||
self.embedding_endpoint_type = config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
|
||||
self.embedding_endpoint = config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
|
||||
self.embedding_dim = config.embedding_dim if embedding_dim is None else embedding_dim
|
||||
self.embedding_chunk_size = config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size
|
||||
|
||||
# agent metadata
|
||||
self.data_sources = data_sources if data_sources is not None else []
|
||||
self.create_time = create_time if create_time is not None else utils.get_local_time()
|
||||
if memgpt_version is None:
|
||||
import memgpt
|
||||
|
||||
self.memgpt_version = memgpt.__version__
|
||||
else:
|
||||
self.context_window = context_window
|
||||
self.memgpt_version = memgpt_version
|
||||
|
||||
# save agent config
|
||||
self.agent_config_path = (
|
||||
@ -326,6 +333,8 @@ class AgentConfig:
|
||||
def save(self):
|
||||
# save state of persistence manager
|
||||
os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True)
|
||||
# save version
|
||||
self.memgpt_version = memgpt.__version__
|
||||
with open(self.agent_config_path, "w") as f:
|
||||
json.dump(vars(self), f, indent=4)
|
||||
|
||||
@ -342,7 +351,6 @@ class AgentConfig:
|
||||
assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
|
||||
with open(agent_config_path, "r") as f:
|
||||
agent_config = json.load(f)
|
||||
|
||||
# allow compatibility accross versions
|
||||
try:
|
||||
class_args = inspect.getargspec(cls.__init__).args
|
||||
@ -354,7 +362,6 @@ class AgentConfig:
|
||||
if key not in class_args:
|
||||
utils.printd(f"Removing missing argument {key} from agent config")
|
||||
del agent_config[key]
|
||||
|
||||
return cls(**agent_config)
|
||||
|
||||
|
||||
|
@ -11,9 +11,9 @@ def embedding_model():
|
||||
# load config
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
endpoint = config.embedding_model
|
||||
endpoint = config.embedding_endpoint_type
|
||||
if endpoint == "openai":
|
||||
model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
|
||||
model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key)
|
||||
return model
|
||||
elif endpoint == "azure":
|
||||
return OpenAIEmbedding(
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import json
|
||||
import math
|
||||
|
||||
from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
### Functions / tools the agent can use
|
||||
# All functions should return a response string (or None)
|
||||
|
@ -4,8 +4,8 @@ import json
|
||||
import requests
|
||||
|
||||
|
||||
from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
|
||||
from ...openai_tools import completions_with_backoff as create
|
||||
from memgpt.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
|
||||
from memgpt.openai_tools import completions_with_backoff as create
|
||||
|
||||
|
||||
def message_chatgpt(self, message: str):
|
||||
|
@ -10,74 +10,65 @@ from .llamacpp.api import get_llamacpp_completion
|
||||
from .koboldcpp.api import get_koboldcpp_completion
|
||||
from .ollama.api import get_ollama_completion
|
||||
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
|
||||
from .utils import DotDict
|
||||
from .constants import DEFAULT_WRAPPER
|
||||
from .utils import DotDict, get_available_wrappers
|
||||
from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
||||
from ..errors import LocalLLMConnectionError, LocalLLMError
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
endpoint = os.getenv("OPENAI_API_BASE")
|
||||
endpoint_type = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
DEBUG = False
|
||||
# DEBUG = True
|
||||
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
|
||||
|
||||
has_shown_warning = False
|
||||
|
||||
|
||||
def get_chat_completion(
|
||||
model, # no model, since the model is fixed to whatever you set in your own backend
|
||||
model, # no model required (except for Ollama), since the model is fixed to whatever you set in your own backend
|
||||
messages,
|
||||
functions=None,
|
||||
function_call="auto",
|
||||
context_window=None,
|
||||
# required
|
||||
wrapper=None,
|
||||
endpoint=None,
|
||||
endpoint_type=None,
|
||||
):
|
||||
assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
|
||||
assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
|
||||
assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
|
||||
global has_shown_warning
|
||||
grammar_name = None
|
||||
|
||||
if HOST is None:
|
||||
raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
|
||||
if HOST_TYPE is None:
|
||||
raise ValueError(f"The BACKEND_TYPE environment variable is not defined. Please set it in your environment.")
|
||||
|
||||
if function_call != "auto":
|
||||
raise ValueError(f"function_call == {function_call} not supported (auto only)")
|
||||
|
||||
available_wrappers = get_available_wrappers()
|
||||
if messages[0]["role"] == "system" and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip():
|
||||
# Special case for if the call we're making is coming from the summarizer
|
||||
llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
|
||||
elif model == "airoboros-l2-70b-2.1":
|
||||
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
|
||||
elif model == "airoboros-l2-70b-2.1-grammar":
|
||||
llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False)
|
||||
# grammar_name = "json"
|
||||
grammar_name = "json_func_calls_with_inner_thoughts"
|
||||
elif model == "dolphin-2.1-mistral-7b":
|
||||
llm_wrapper = dolphin.Dolphin21MistralWrapper()
|
||||
elif model == "dolphin-2.1-mistral-7b-grammar":
|
||||
llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False)
|
||||
# grammar_name = "json"
|
||||
grammar_name = "json_func_calls_with_inner_thoughts"
|
||||
elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
|
||||
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
|
||||
elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar":
|
||||
llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False)
|
||||
# grammar_name = "json"
|
||||
grammar_name = "json_func_calls_with_inner_thoughts"
|
||||
else:
|
||||
elif wrapper is None:
|
||||
# Warn the user that we're using the fallback
|
||||
if not has_shown_warning:
|
||||
print(
|
||||
f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
|
||||
f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --wrapper)"
|
||||
)
|
||||
has_shown_warning = True
|
||||
if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]:
|
||||
if endpoint_type in ["koboldcpp", "llamacpp", "webui"]:
|
||||
# make the default to use grammar
|
||||
llm_wrapper = DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
|
||||
# grammar_name = "json"
|
||||
grammar_name = "json_func_calls_with_inner_thoughts"
|
||||
else:
|
||||
llm_wrapper = DEFAULT_WRAPPER()
|
||||
elif wrapper not in available_wrappers:
|
||||
raise ValueError(f"Could not find requested wrapper '{wrapper} in available wrappers list:\n{available_wrappers}")
|
||||
else:
|
||||
llm_wrapper = available_wrappers[wrapper]
|
||||
if "grammar" in wrapper:
|
||||
grammar_name = "json_func_calls_with_inner_thoughts"
|
||||
|
||||
if grammar_name is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
|
||||
if grammar_name is not None and endpoint_type not in ["koboldcpp", "llamacpp", "webui"]:
|
||||
print(f"Warning: grammars are currently only supported when using llama.cpp as the MemGPT local LLM backend")
|
||||
|
||||
# First step: turn the message sequence into a prompt that the model expects
|
||||
@ -91,25 +82,25 @@ def get_chat_completion(
|
||||
)
|
||||
|
||||
try:
|
||||
if HOST_TYPE == "webui":
|
||||
result = get_webui_completion(prompt, context_window, grammar=grammar_name)
|
||||
elif HOST_TYPE == "lmstudio":
|
||||
result = get_lmstudio_completion(prompt, context_window)
|
||||
elif HOST_TYPE == "llamacpp":
|
||||
result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
|
||||
elif HOST_TYPE == "koboldcpp":
|
||||
result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
|
||||
elif HOST_TYPE == "ollama":
|
||||
result = get_ollama_completion(prompt, context_window)
|
||||
if endpoint_type == "webui":
|
||||
result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
|
||||
elif endpoint_type == "lmstudio":
|
||||
result = get_lmstudio_completion(endpoint, prompt, context_window)
|
||||
elif endpoint_type == "llamacpp":
|
||||
result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
|
||||
elif endpoint_type == "koboldcpp":
|
||||
result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
|
||||
elif endpoint_type == "ollama":
|
||||
result = get_ollama_completion(endpoint, model, prompt, context_window)
|
||||
else:
|
||||
raise LocalLLMError(
|
||||
f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
|
||||
)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")
|
||||
raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
|
||||
|
||||
if result is None or result == "":
|
||||
raise LocalLLMError(f"Got back an empty response string from {HOST}")
|
||||
raise LocalLLMError(f"Got back an empty response string from {endpoint}")
|
||||
if DEBUG:
|
||||
print(f"Raw LLM output:\n{result}")
|
||||
|
||||
@ -123,7 +114,7 @@ def get_chat_completion(
|
||||
# unpack with response.choices[0].message.content
|
||||
response = DotDict(
|
||||
{
|
||||
"model": None,
|
||||
"model": model,
|
||||
"choices": [
|
||||
DotDict(
|
||||
{
|
||||
|
14
memgpt/local_llm/constants.py
Normal file
14
memgpt/local_llm/constants.py
Normal file
@ -0,0 +1,14 @@
|
||||
import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
|
||||
|
||||
DEFAULT_ENDPOINTS = {
|
||||
"koboldcpp": "http://localhost:5001",
|
||||
"llamacpp": "http://localhost:8080",
|
||||
"lmstudio": "http://localhost:1234",
|
||||
"ollama": "http://localhost:11434",
|
||||
"webui": "http://localhost:5000",
|
||||
}
|
||||
|
||||
DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
|
||||
|
||||
DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
|
||||
DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1"
|
@ -5,14 +5,12 @@ import requests
|
||||
from .settings import SIMPLE
|
||||
from ..utils import load_grammar_file, count_tokens
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
|
||||
# DEBUG = False
|
||||
DEBUG = True
|
||||
DEBUG = False
|
||||
# DEBUG = True
|
||||
|
||||
|
||||
def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
|
||||
def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
|
||||
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
if prompt_tokens > context_window:
|
||||
@ -27,13 +25,13 @@ def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMP
|
||||
if grammar is not None:
|
||||
request["grammar"] = load_grammar_file(grammar)
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
if not endpoint.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
# NOTE: llama.cpp server returns the following when it's out of context
|
||||
# curl: (52) Empty reply from server
|
||||
URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
|
||||
URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
@ -5,14 +5,12 @@ import requests
|
||||
from .settings import SIMPLE
|
||||
from ..utils import load_grammar_file, count_tokens
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
LLAMACPP_API_SUFFIX = "/completion"
|
||||
# DEBUG = False
|
||||
DEBUG = True
|
||||
DEBUG = False
|
||||
# DEBUG = True
|
||||
|
||||
|
||||
def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
|
||||
def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
|
||||
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
if prompt_tokens > context_window:
|
||||
@ -26,13 +24,13 @@ def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPL
|
||||
if grammar is not None:
|
||||
request["grammar"] = load_grammar_file(grammar)
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
if not endpoint.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
# NOTE: llama.cpp server returns the following when it's out of context
|
||||
# curl: (52) Empty reply from server
|
||||
URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
|
||||
URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
@ -5,14 +5,12 @@ import requests
|
||||
from .settings import SIMPLE
|
||||
from ..utils import count_tokens
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
|
||||
LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat"):
|
||||
def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
|
||||
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
if prompt_tokens > context_window:
|
||||
@ -25,19 +23,19 @@ def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat")
|
||||
if api == "chat":
|
||||
# Uses the ChatCompletions API style
|
||||
# Seems to work better, probably because it's applying some extra settings under-the-hood?
|
||||
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
|
||||
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
|
||||
message_structure = [{"role": "user", "content": prompt}]
|
||||
request["messages"] = message_structure
|
||||
elif api == "completions":
|
||||
# Uses basic string completions (string in, string out)
|
||||
# Does not work as well as ChatCompletions for some reason
|
||||
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
|
||||
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
|
||||
request["prompt"] = prompt
|
||||
else:
|
||||
raise ValueError(api)
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
if not endpoint.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
response = requests.post(URI, json=request)
|
||||
|
@ -6,26 +6,25 @@ from .settings import SIMPLE
|
||||
from ..utils import count_tokens
|
||||
from ...errors import LocalLLMError
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
MODEL_NAME = os.getenv("OLLAMA_MODEL") # ollama API requires this in the request
|
||||
OLLAMA_API_SUFFIX = "/api/generate"
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None):
|
||||
def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
|
||||
"""See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
if MODEL_NAME is None:
|
||||
raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")
|
||||
if model is None:
|
||||
raise LocalLLMError(
|
||||
f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
|
||||
)
|
||||
|
||||
# Settings for the generation, includes the prompt + stop tokens, max length, etc
|
||||
request = settings
|
||||
request["prompt"] = prompt
|
||||
request["model"] = MODEL_NAME
|
||||
request["model"] = model
|
||||
request["options"]["num_ctx"] = context_window
|
||||
|
||||
# Set grammar
|
||||
@ -33,11 +32,11 @@ def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None)
|
||||
# request["grammar_string"] = load_grammar_file(grammar)
|
||||
raise NotImplementedError(f"Ollama does not support grammars")
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
if not endpoint.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
|
||||
URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
@ -1,6 +1,10 @@
|
||||
import os
|
||||
import tiktoken
|
||||
|
||||
import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
|
||||
import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
|
||||
import memgpt.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
|
||||
|
||||
|
||||
class DotDict(dict):
|
||||
"""Allow dot access on properties similar to OpenAI response object"""
|
||||
@ -37,3 +41,14 @@ def load_grammar_file(grammar):
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
return len(encoding.encode(s))
|
||||
|
||||
|
||||
def get_available_wrappers() -> dict:
|
||||
return {
|
||||
"airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
|
||||
"airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False),
|
||||
"dolphin-2.1-mistral-7b": dolphin.Dolphin21MistralWrapper(),
|
||||
"dolphin-2.1-mistral-7b-grammar": dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False),
|
||||
"zephyr-7B": zephyr.ZephyrMistralInnerMonologueWrapper(),
|
||||
"zephyr-7B-grammar": zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False),
|
||||
}
|
||||
|
@ -5,13 +5,11 @@ import requests
|
||||
from .settings import SIMPLE
|
||||
from ..utils import load_grammar_file, count_tokens
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
WEBUI_API_SUFFIX = "/api/v1/generate"
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
|
||||
def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
|
||||
"""See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
if prompt_tokens > context_window:
|
||||
@ -26,11 +24,11 @@ def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
|
||||
if grammar is not None:
|
||||
request["grammar_string"] = load_grammar_file(grammar)
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
if not endpoint.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
URI = urljoin(HOST.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
|
||||
URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
@ -241,7 +241,7 @@ def main(
|
||||
memgpt_persona = persona
|
||||
if memgpt_persona is None:
|
||||
memgpt_persona = (
|
||||
personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT,
|
||||
personas.GPT35_DEFAULT if (model is not None and "gpt-3.5" in model) else personas.DEFAULT,
|
||||
None, # represents the personas dir in pymemgpt package
|
||||
)
|
||||
else:
|
||||
|
@ -2,10 +2,14 @@ import random
|
||||
import os
|
||||
import time
|
||||
|
||||
from .local_llm.chat_completion_proxy import get_chat_completion
|
||||
import time
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
R = TypeVar("R")
|
||||
|
||||
import openai
|
||||
|
||||
@ -55,6 +59,7 @@ def retry_with_exponential_backoff(
|
||||
return wrapper
|
||||
|
||||
|
||||
# TODO: delete/ignore --legacy
|
||||
@retry_with_exponential_backoff
|
||||
def completions_with_backoff(**kwargs):
|
||||
# Local model
|
||||
@ -75,6 +80,38 @@ def completions_with_backoff(**kwargs):
|
||||
return openai.ChatCompletion.create(**kwargs)
|
||||
|
||||
|
||||
@retry_with_exponential_backoff
|
||||
def chat_completion_with_backoff(agent_config, **kwargs):
|
||||
from memgpt.utils import printd
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
|
||||
printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}")
|
||||
if agent_config.model_endpoint_type == "openai":
|
||||
# openai
|
||||
openai.api_base = agent_config.model_endpoint
|
||||
return openai.ChatCompletion.create(**kwargs)
|
||||
elif agent_config.model_endpoint_type == "azure":
|
||||
# configure openai
|
||||
config = MemGPTConfig.load() # load credentials (currently not stored in agent config)
|
||||
openai.api_type = "azure"
|
||||
openai.api_key = config.azure_key
|
||||
openai.api_base = config.azure_endpoint
|
||||
openai.api_version = config.azure_version
|
||||
if config.azure_deployment is not None:
|
||||
kwargs["deployment_id"] = config.azure_deployment
|
||||
else:
|
||||
kwargs["engine"] = MODEL_TO_AZURE_ENGINE[config.model]
|
||||
del kwargs["model"]
|
||||
return openai.ChatCompletion.create(**kwargs)
|
||||
else: # local model
|
||||
kwargs["context_window"] = agent_config.context_window # specify for open LLMs
|
||||
kwargs["endpoint"] = agent_config.model_endpoint # specify for open LLMs
|
||||
kwargs["endpoint_type"] = agent_config.model_endpoint_type # specify for open LLMs
|
||||
kwargs["wrapper"] = agent_config.model_wrapper # specify for open LLMs
|
||||
return get_chat_completion(**kwargs)
|
||||
|
||||
|
||||
# TODO: deprecate
|
||||
@retry_with_exponential_backoff
|
||||
def create_embedding_with_backoff(**kwargs):
|
||||
if using_azure():
|
||||
|
@ -57,5 +57,5 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
|
||||
persona_notes=persona,
|
||||
human_notes=human,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if "gpt-4" in model else False,
|
||||
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ def test_configure_memgpt():
|
||||
|
||||
|
||||
def test_save_load():
|
||||
configure_memgpt()
|
||||
# configure_memgpt() # rely on configure running first^
|
||||
child = pexpect.spawn("memgpt run --agent test_save_load --first --strip_ui")
|
||||
|
||||
child.expect("Enter your message:", timeout=TIMEOUT)
|
||||
|
@ -1,12 +1,12 @@
|
||||
# import tempfile
|
||||
# import asyncio
|
||||
import os
|
||||
# import os
|
||||
|
||||
# import asyncio
|
||||
from datasets import load_dataset
|
||||
# from datasets import load_dataset
|
||||
|
||||
import memgpt
|
||||
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
|
||||
# import memgpt
|
||||
# from memgpt.cli.cli_load import load_directory, load_database, load_webpage
|
||||
|
||||
# import memgpt.presets as presets
|
||||
# import memgpt.personas.personas as personas
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
|
||||
@ -15,9 +16,11 @@ from memgpt.config import MemGPTConfig, AgentConfig
|
||||
import argparse
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
||||
def test_postgres_openai():
|
||||
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
|
||||
if os.getenv("OPENAI_API_KEY") is None:
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
return # soft pass
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
return # soft pass
|
||||
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
@ -54,14 +57,16 @@ def test_postgres_openai():
|
||||
# print("...finished")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
|
||||
def test_postgres_local():
|
||||
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
return
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
|
||||
config = MemGPTConfig(
|
||||
archival_storage_type="postgres",
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
embedding_model="local",
|
||||
embedding_endpoint_type="local",
|
||||
embedding_dim=384, # use HF model
|
||||
)
|
||||
print(config.config_path)
|
||||
|
@ -3,43 +3,55 @@ import pexpect
|
||||
from .constants import TIMEOUT
|
||||
|
||||
|
||||
def configure_memgpt(enable_openai=True, enable_azure=False):
|
||||
def configure_memgpt_localllm():
|
||||
child = pexpect.spawn("memgpt configure")
|
||||
|
||||
child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT)
|
||||
if enable_openai:
|
||||
child.sendline("y")
|
||||
else:
|
||||
child.sendline("n")
|
||||
|
||||
child.expect("Do you want to enable MemGPT with Azure?", timeout=TIMEOUT)
|
||||
if enable_azure:
|
||||
child.sendline("y")
|
||||
else:
|
||||
child.sendline("n")
|
||||
|
||||
child.expect("Select default inference endpoint:", timeout=TIMEOUT)
|
||||
child.expect("Select LLM inference provider", timeout=TIMEOUT)
|
||||
child.send("\x1b[B") # Send the down arrow key
|
||||
child.send("\x1b[B") # Send the down arrow key
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default embedding endpoint:", timeout=TIMEOUT)
|
||||
child.expect("Select LLM backend", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default preset:", timeout=TIMEOUT)
|
||||
child.expect("Enter default endpoint", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default model", timeout=TIMEOUT)
|
||||
child.expect("Select default model wrapper", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default persona:", timeout=TIMEOUT)
|
||||
child.expect("Select your model's context window", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default human:", timeout=TIMEOUT)
|
||||
child.expect("Select embedding provider", timeout=TIMEOUT)
|
||||
child.send("\x1b[B") # Send the down arrow key
|
||||
child.send("\x1b[B") # Send the down arrow key
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default preset", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default persona", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select default human", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select storage backend for archival data", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect("Select storage backend for archival data:", timeout=TIMEOUT)
|
||||
child.sendline()
|
||||
|
||||
child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
|
||||
child.close()
|
||||
assert child.isalive() is False, "CLI should have terminated."
|
||||
assert child.exitstatus == 0, "CLI did not exit cleanly."
|
||||
|
||||
|
||||
def configure_memgpt(enable_openai=False, enable_azure=False):
|
||||
if enable_openai:
|
||||
raise NotImplementedError
|
||||
elif enable_azure:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
configure_memgpt_localllm()
|
||||
|
Loading…
Reference in New Issue
Block a user