Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
Refactor config + determine LLM via config.model_endpoint_type (#422)

* mark deprecated API section
* CLI bug fixes for azure
* check azure before running
* Update README.md
* Update README.md
* bug fix with persona loading
* remove print
* make errors for cli flags more clear
* format
* fix imports
* fix imports
* add prints
* update lock
* update config fields
* cleanup config loading
* commit
* remove asserts
* refactor configure
* put into different functions
* add embedding default
* pass in config
* fixes
* allow overriding openai embedding endpoint
* black
* trying to patch tests (some circular import errors)
* update flags and docs
* patched support for local llms using endpoint and endpoint type passed via configs, not env vars
* missing files
* fix naming
* fix import
* fix two runtime errors
* patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included
* disable debug messages
* made error message for failed load more informative
* don't print dynamic linking function warning unless --debug
* updated tests to work with new cli workflow (disabled openai config test for now)
* added skips for tests when vars are missing
* update bad arg
* revise test to soft pass on empty string too
* don't run configure twice
* extend timeout (try to pass against nltk download)
* update defaults
* typo with endpoint type default
* patch runtime errors for when model is None
* catching another case of 'x in model' when model is None (preemptively)
* allow overrides to local llm related config params
* made model wrapper selection from a list vs raw input
* update test for select instead of input
* Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry
* updated error messages to be more informative with links to readthedocs
* add back gpt3.5-turbo

---------

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in: parent 8fdc3a29da, commit 28514da5df
@@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove
 * `--agent`: (str) Name of agent to create or to resume chatting with.
 * `--human`: (str) Name of the human to run the agent with.
 * `--persona`: (str) Name of agent persona to use.
-* `--model`: (str) LLM model to run [gpt-4, gpt-3.5].
+* `--model`: (str) LLM model to run (e.g. `gpt-4`, `dolphin_xxx`)
 * `--preset`: (str) MemGPT preset to run agent with.
 * `--first`: (str) Allow user to sent the first message.
 * `--debug`: (bool) Show debug logs (default=False)
 * `--no-verify`: (bool) Bypass message verification (default=False)
 * `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False)

+You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs:
+* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`)
+* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama)
+* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`)
+* `--context-window`: (int) Size of model context window (specific to model type)
+
 #### Updating the config location
-You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`:
+You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`:
 ```
 export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory
 ```
@@ -9,6 +9,7 @@ from memgpt.config import AgentConfig
 from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
 from .memory import CoreMemory as Memory, summarize_messages
 from .openai_tools import completions_with_backoff as create
+from memgpt.openai_tools import chat_completion_with_backoff
 from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
 from .constants import (
     FIRST_MESSAGE_ATTEMPTS,
@@ -73,7 +74,7 @@ def initialize_message_sequence(
     first_user_message = get_login_event() # event letting MemGPT know the user just logged in

     if include_initial_boot_message:
-        if "gpt-3.5" in model:
+        if model is not None and "gpt-3.5" in model:
             initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
         else:
             initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
@@ -96,37 +97,6 @@ def initialize_message_sequence(
     return messages


-def get_ai_reply(
-    model,
-    message_sequence,
-    functions,
-    function_call="auto",
-    context_window=None,
-):
-    try:
-        response = create(
-            model=model,
-            context_window=context_window,
-            messages=message_sequence,
-            functions=functions,
-            function_call=function_call,
-        )
-
-        # special case for 'length'
-        if response.choices[0].finish_reason == "length":
-            raise Exception("Finish reason was length (maximum context length)")
-
-        # catches for soft errors
-        if response.choices[0].finish_reason not in ["stop", "function_call"]:
-            raise Exception(f"API call finish with bad finish reason: {response}")
-
-        # unpack with response.choices[0].message.content
-        return response
-
-    except Exception as e:
-        raise e
-
-
 class Agent(object):
     def __init__(
         self,
@@ -310,7 +280,7 @@ class Agent(object):
         json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
         if not json_files:
             print(f"/load error: no .json checkpoint files found")
-            raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}")
+            raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}")

         # Sort files based on modified timestamp, with the latest file being the first.
         filename = max(json_files, key=os.path.getmtime)
@@ -360,7 +330,7 @@ class Agent(object):

             # NOTE to handle old configs, instead of erroring here let's just warn
             # raise ValueError(error_message)
-            print(error_message)
+            printd(error_message)
             linked_function_set[f_name] = linked_function

         messages = state["messages"]
@@ -602,8 +572,7 @@ class Agent(object):
            printd(f"This is the first message. Running extra verifier on AI response.")
            counter = 0
            while True:
-                response = get_ai_reply(
-                    model=self.model,
+                response = self.get_ai_reply(
                     message_sequence=input_message_sequence,
                     functions=self.functions,
                     context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -616,8 +585,7 @@ class Agent(object):
                    raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")

        else:
-            response = get_ai_reply(
-                model=self.model,
+            response = self.get_ai_reply(
                 message_sequence=input_message_sequence,
                 functions=self.functions,
                 context_window=None if self.config.context_window is None else int(self.config.context_window),
@@ -785,3 +753,55 @@ class Agent(object):
         # Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
         elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start
         return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60
+
+    def get_ai_reply(
+        self,
+        message_sequence,
+        function_call="auto",
+    ):
+        """Get response from LLM API"""
+
+        # TODO: Legacy code - delete
+        if self.config is None:
+            try:
+                response = create(
+                    model=self.model,
+                    context_window=self.context_window,
+                    messages=message_sequence,
+                    functions=self.functions,
+                    function_call=function_call,
+                )
+
+                # special case for 'length'
+                if response.choices[0].finish_reason == "length":
+                    raise Exception("Finish reason was length (maximum context length)")
+
+                # catches for soft errors
+                if response.choices[0].finish_reason not in ["stop", "function_call"]:
+                    raise Exception(f"API call finish with bad finish reason: {response}")
+
+                # unpack with response.choices[0].message.content
+                return response
+            except Exception as e:
+                raise e
+
+        try:
+            response = chat_completion_with_backoff(
+                agent_config=self.config,
+                model=self.model, # TODO: remove (is redundant)
+                messages=message_sequence,
+                functions=self.functions,
+                function_call=function_call,
+            )
+            # special case for 'length'
+            if response.choices[0].finish_reason == "length":
+                raise Exception("Finish reason was length (maximum context length)")
+
+            # catches for soft errors
+            if response.choices[0].finish_reason not in ["stop", "function_call"]:
+                raise Exception(f"API call finish with bad finish reason: {response}")
+
+            # unpack with response.choices[0].message.content
+            return response
+        except Exception as e:
+            raise e
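The hunk above moves completion calls into an `Agent.get_ai_reply` method that validates `finish_reason` before handing the response back. The block below is a minimal, self-contained sketch of that validation pattern only; `Choice`, `FakeResponse`, and `validate_reply` are illustrative stand-ins, not MemGPT code.

```python
from dataclasses import dataclass, field
from typing import List

# Stand-in response objects shaped like the OpenAI-style responses the diff
# inspects via response.choices[0].finish_reason (names are hypothetical).
@dataclass
class Choice:
    finish_reason: str
    content: str = ""

@dataclass
class FakeResponse:
    choices: List[Choice] = field(default_factory=list)

def validate_reply(response: FakeResponse) -> FakeResponse:
    """Mirror the finish_reason checks shown in the diff above."""
    finish_reason = response.choices[0].finish_reason
    if finish_reason == "length":
        # The model ran out of context; surface it as a hard error.
        raise Exception("Finish reason was length (maximum context length)")
    if finish_reason not in ["stop", "function_call"]:
        raise Exception(f"API call finished with bad finish reason: {response}")
    return response

# Usage: a normal completion passes through, a truncated one raises.
validate_reply(FakeResponse(choices=[Choice(finish_reason="stop")]))
try:
    validate_reply(FakeResponse(choices=[Choice(finish_reason="length")]))
except Exception as err:
    print(err)
```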
@@ -1,4 +1,5 @@
 import typer
+import json
 import sys
 import io
 import logging
@@ -35,16 +36,21 @@ def run(
     persona: str = typer.Option(None, help="Specify persona"),
     agent: str = typer.Option(None, help="Specify agent save file"),
     human: str = typer.Option(None, help="Specify human"),
-    model: str = typer.Option(None, help="Specify the LLM model"),
     preset: str = typer.Option(None, help="Specify preset"),
+    # model flags
+    model: str = typer.Option(None, help="Specify the LLM model"),
+    model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"),
+    model_endpoint: str = typer.Option(None, help="Specify the LLM model endpoint"),
+    model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
+    context_window: int = typer.Option(
+        None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
+    ),
+    # other
     first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"),
     strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"),
     debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"),
     no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"),
     yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"),
-    context_window: int = typer.Option(
-        None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)"
-    ),
 ):
     """Start chatting with an MemGPT agent

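Each of the new model flags above defaults to `None` via `typer.Option`, which is what lets `run` tell "not supplied" apart from a real value and apply only the flags the user actually passed. The standalone sketch below (not MemGPT code; script and variable names are chosen for illustration) shows that pattern in isolation.

```python
import typer

app = typer.Typer()

@app.command()
def run(
    model: str = typer.Option(None, help="Specify the LLM model"),
    model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"),
    context_window: int = typer.Option(None, "--context_window", help="Context window of the LLM you are using"),
):
    # Flags the user did not pass stay None, so they can be filtered out before
    # being layered on top of the saved config as overrides.
    supplied = {"model": model, "model_endpoint_type": model_endpoint_type, "context_window": context_window}
    overrides = {k: v for k, v in supplied.items() if v is not None}
    typer.echo(f"overrides: {overrides}")

if __name__ == "__main__":
    app()  # e.g. python sketch.py --model gpt-4 --context_window 8192
```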
@@ -99,11 +105,6 @@ def run(
         set_global_service_context(service_context)
         sys.stdout = original_stdout

-    # overwrite the context_window if specified
-    if context_window is not None and int(context_window) != int(config.context_window):
-        typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
-        config.context_window = str(context_window)
-
     # create agent config
     if agent and AgentConfig.exists(agent): # use existing agent
         typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN)
@@ -121,10 +122,34 @@ def run(
             typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW)
             agent_config.human = human
             # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")

+        # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
         if model and model != agent_config.model:
             typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW)
             agent_config.model = model
-            # raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}")
+        if context_window is not None and int(context_window) != agent_config.context_window:
+            typer.secho(
+                f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW
+            )
+            agent_config.context_window = context_window
+        if model_wrapper and model_wrapper != agent_config.model_wrapper:
+            typer.secho(
+                f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW
+            )
+            agent_config.model_wrapper = model_wrapper
+        if model_endpoint and model_endpoint != agent_config.model_endpoint:
+            typer.secho(
+                f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW
+            )
+            agent_config.model_endpoint = model_endpoint
+        if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type:
+            typer.secho(
+                f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}",
+                fg=typer.colors.YELLOW,
+            )
+            agent_config.model_endpoint_type = model_endpoint_type
+
+        # Update the agent config with any overrides
         agent_config.save()

         # load existing agent
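The overrides above repeat one warn-then-assign step per field. As an illustrative aside (not code from this commit), the same behavior can be expressed with a small helper driven by `getattr`/`setattr`:

```python
import typer

def override_field(cfg, field_name, new_value):
    """Warn and override one config attribute when a CLI flag was supplied."""
    old_value = getattr(cfg, field_name)
    if new_value is not None and new_value != old_value:
        typer.secho(f"Warning: Overriding existing {field_name} {old_value} with {new_value}", fg=typer.colors.YELLOW)
        setattr(cfg, field_name, new_value)

# Hypothetical usage with the flag variables from run():
# for name, value in [("model", model), ("model_wrapper", model_wrapper),
#                     ("model_endpoint", model_endpoint), ("model_endpoint_type", model_endpoint_type)]:
#     override_field(agent_config, name, value)
```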
@@ -133,17 +158,17 @@ def run(
         # create new agent config: override defaults with args if provided
         typer.secho("Creating new agent...", fg=typer.colors.GREEN)
         agent_config = AgentConfig(
-            name=agent if agent else None,
-            persona=persona if persona else config.default_persona,
-            human=human if human else config.default_human,
-            model=model if model else config.model,
-            context_window=context_window if context_window else config.context_window,
-            preset=preset if preset else config.preset,
+            name=agent,
+            persona=persona,
+            human=human,
+            preset=preset,
+            model=model,
+            model_wrapper=model_wrapper,
+            model_endpoint_type=model_endpoint_type,
+            model_endpoint=model_endpoint,
+            context_window=context_window,
         )

-        ## attach data source to agent
-        # agent_config.attach_data_source(data_source)
-
         # TODO: allow configrable state manager (only local is supported right now)
         persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill

@@ -162,6 +187,9 @@ def run(
         persistence_manager,
     )

+    # pretty print agent config
+    printd(json.dumps(vars(agent_config), indent=4, sort_keys=True))
+
     # start event loop
     from memgpt.main import run_agent_loop

@@ -1,3 +1,4 @@
+import builtins
 import questionary
 import openai
 from prettytable import PrettyTable
@@ -11,126 +12,118 @@ from memgpt import utils
|
|||||||
|
|
||||||
import memgpt.humans.humans as humans
|
import memgpt.humans.humans as humans
|
||||||
import memgpt.personas.personas as personas
|
import memgpt.personas.personas as personas
|
||||||
from memgpt.config import MemGPTConfig, AgentConfig
|
from memgpt.config import MemGPTConfig, AgentConfig, Config
|
||||||
from memgpt.constants import MEMGPT_DIR
|
from memgpt.constants import MEMGPT_DIR
|
||||||
from memgpt.connectors.storage import StorageConnector
|
from memgpt.connectors.storage import StorageConnector
|
||||||
from memgpt.constants import LLM_MAX_TOKENS
|
from memgpt.constants import LLM_MAX_TOKENS
|
||||||
|
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
|
||||||
|
from memgpt.local_llm.utils import get_available_wrappers
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
def get_azure_credentials():
|
||||||
def configure():
|
|
||||||
"""Updates default MemGPT configurations"""
|
|
||||||
|
|
||||||
from memgpt.presets.presets import DEFAULT_PRESET, preset_options
|
|
||||||
|
|
||||||
MemGPTConfig.create_config_dir()
|
|
||||||
|
|
||||||
# Will pre-populate with defaults, or what the user previously set
|
|
||||||
config = MemGPTConfig.load()
|
|
||||||
|
|
||||||
# openai credentials
|
|
||||||
use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask()
|
|
||||||
if use_openai:
|
|
||||||
# search for key in enviornment
|
|
||||||
openai_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
if not openai_key:
|
|
||||||
print("Missing enviornment variables for OpenAI. Please set them and run `memgpt configure` again.")
|
|
||||||
# TODO: eventually stop relying on env variables and pass in keys explicitly
|
|
||||||
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
|
|
||||||
|
|
||||||
# azure credentials
|
|
||||||
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
|
|
||||||
use_azure_deployment_ids = False
|
|
||||||
if use_azure:
|
|
||||||
# search for key in enviornment
|
|
||||||
azure_key = os.getenv("AZURE_OPENAI_KEY")
|
azure_key = os.getenv("AZURE_OPENAI_KEY")
|
||||||
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||||
azure_version = os.getenv("AZURE_OPENAI_VERSION")
|
azure_version = os.getenv("AZURE_OPENAI_VERSION")
|
||||||
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||||
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
|
azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
|
||||||
|
return azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment
|
||||||
|
|
||||||
if all([azure_key, azure_endpoint, azure_version]):
|
|
||||||
print(f"Using Microsoft endpoint {azure_endpoint}.")
|
|
||||||
if all([azure_deployment, azure_embedding_deployment]):
|
|
||||||
print(f"Using deployment id {azure_deployment}")
|
|
||||||
use_azure_deployment_ids = True
|
|
||||||
|
|
||||||
# configure openai
|
def get_openai_credentials():
|
||||||
openai.api_type = "azure"
|
openai_key = os.getenv("OPENAI_API_KEY")
|
||||||
openai.api_key = azure_key
|
return openai_key
|
||||||
openai.api_base = azure_endpoint
|
|
||||||
openai.api_version = azure_version
|
|
||||||
else:
|
|
||||||
print("Missing enviornment variables for Azure. Please set then run `memgpt configure` again.")
|
|
||||||
# TODO: allow for manual setting
|
|
||||||
use_azure = False
|
|
||||||
|
|
||||||
# TODO: configure local model
|
|
||||||
|
|
||||||
# configure provider
|
def configure_llm_endpoint(config: MemGPTConfig):
|
||||||
model_endpoint_options = []
|
# configure model endpoint
|
||||||
if os.getenv("OPENAI_API_BASE") is not None:
|
model_endpoint_type, model_endpoint = None, None
|
||||||
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
|
|
||||||
if use_openai:
|
# get default
|
||||||
model_endpoint_options += ["openai"]
|
default_model_endpoint_type = config.model_endpoint_type
|
||||||
if use_azure:
|
if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]: # local model
|
||||||
model_endpoint_options += ["azure"]
|
default_model_endpoint_type = "local"
|
||||||
assert (
|
|
||||||
len(model_endpoint_options) > 0
|
provider = questionary.select(
|
||||||
), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server."
|
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
|
||||||
valid_default_model = config.model_endpoint in model_endpoint_options
|
|
||||||
default_endpoint = questionary.select(
|
|
||||||
"Select default inference endpoint:",
|
|
||||||
model_endpoint_options,
|
|
||||||
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
|
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
# configure embedding provider
|
# set: model_endpoint_type, model_endpoint
|
||||||
embedding_endpoint_options = []
|
if provider == "openai":
|
||||||
if use_azure:
|
model_endpoint_type = "openai"
|
||||||
embedding_endpoint_options += ["azure"]
|
model_endpoint = "https://api.openai.com/v1"
|
||||||
if use_openai:
|
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
|
||||||
embedding_endpoint_options += ["openai"]
|
provider = "openai"
|
||||||
embedding_endpoint_options += ["local"]
|
elif provider == "azure":
|
||||||
valid_default_embedding = config.embedding_model in embedding_endpoint_options
|
model_endpoint_type = "azure"
|
||||||
# determine the default selection in a smart way
|
_, model_endpoint, _, _, _ = get_azure_credentials()
|
||||||
if "openai" in embedding_endpoint_options and default_endpoint == "openai":
|
else: # local models
|
||||||
# openai llm -> openai embeddings
|
backend_options = ["webui", "llamacpp", "koboldcpp", "ollama", "lmstudio", "openai"]
|
||||||
default_embedding_endpoint_default = "openai"
|
default_model_endpoint_type = None
|
||||||
elif default_endpoint not in ["openai", "azure"]: # is local
|
if config.model_endpoint_type in backend_options:
|
||||||
# local llm -> local embeddings
|
# set from previous config
|
||||||
default_embedding_endpoint_default = "local"
|
default_model_endpoint_type = config.model_endpoint_type
|
||||||
else:
|
else:
|
||||||
default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1]
|
# set form env variable (ok if none)
|
||||||
default_embedding_endpoint = questionary.select(
|
default_model_endpoint_type = os.getenv("BACKEND_TYPE")
|
||||||
"Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default
|
model_endpoint_type = questionary.select(
|
||||||
|
"Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):",
|
||||||
|
backend_options,
|
||||||
|
default=default_model_endpoint_type,
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
# configure embedding dimentions
|
# set default endpoint
|
||||||
default_embedding_dim = config.embedding_dim
|
# if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use
|
||||||
if default_embedding_endpoint == "local":
|
default_model_endpoint = os.getenv("OPENAI_API_BASE")
|
||||||
# HF model uses lower dimentionality
|
# if OPENAI_API_BASE is not set, try to pull a default IP+port format from a hardcoded set
|
||||||
default_embedding_dim = 384
|
if default_model_endpoint is None:
|
||||||
|
if model_endpoint_type in DEFAULT_ENDPOINTS:
|
||||||
|
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
|
||||||
|
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||||
|
else:
|
||||||
|
# default_model_endpoint = None
|
||||||
|
model_endpoint = None
|
||||||
|
while not model_endpoint:
|
||||||
|
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||||
|
if "http://" not in model_endpoint and "https://" not in model_endpoint:
|
||||||
|
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||||
|
model_endpoint = None
|
||||||
|
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
|
||||||
|
|
||||||
# configure preset
|
return model_endpoint_type, model_endpoint
|
||||||
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()
|
|
||||||
|
|
||||||
# default model
|
|
||||||
if use_openai or use_azure:
|
def configure_model(config: MemGPTConfig, model_endpoint_type: str):
|
||||||
model_options = []
|
# set: model, model_wrapper
|
||||||
if use_openai:
|
model, model_wrapper = None, None
|
||||||
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
|
if model_endpoint_type == "openai" or model_endpoint_type == "azure":
|
||||||
|
model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
|
||||||
|
# TODO: select
|
||||||
valid_model = config.model in model_options
|
valid_model = config.model in model_options
|
||||||
default_model = questionary.select(
|
model = questionary.select(
|
||||||
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
|
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
|
||||||
).ask()
|
).ask()
|
||||||
else:
|
else: # local models
|
||||||
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint
|
# ollama also needs model type
|
||||||
|
if model_endpoint_type == "ollama":
|
||||||
|
default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL
|
||||||
|
model = questionary.text(
|
||||||
|
"Enter default model name (required for Ollama, see: https://memgpt.readthedocs.io/en/latest/ollama):",
|
||||||
|
default=default_model,
|
||||||
|
).ask()
|
||||||
|
model = None if len(model) == 0 else model
|
||||||
|
|
||||||
# get the max tokens (context window) for the model
|
# model wrapper
|
||||||
if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS:
|
available_model_wrappers = builtins.list(get_available_wrappers().keys())
|
||||||
|
model_wrapper = questionary.select(
|
||||||
|
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
|
||||||
|
choices=available_model_wrappers,
|
||||||
|
default=DEFAULT_WRAPPER_NAME,
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
# set: context_window
|
||||||
|
if str(model) not in LLM_MAX_TOKENS:
|
||||||
# Ask the user to specify the context length
|
# Ask the user to specify the context length
|
||||||
context_length_options = [
|
context_length_options = [
|
||||||
str(2**12), # 4096
|
str(2**12), # 4096
|
||||||
@@ -140,46 +133,80 @@ def configure():
|
|||||||
str(2**18), # 262144
|
str(2**18), # 262144
|
||||||
"custom", # enter yourself
|
"custom", # enter yourself
|
||||||
]
|
]
|
||||||
default_model_context_window = questionary.select(
|
context_window = questionary.select(
|
||||||
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
|
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
|
||||||
choices=context_length_options,
|
choices=context_length_options,
|
||||||
default=str(LLM_MAX_TOKENS["DEFAULT"]),
|
default=str(LLM_MAX_TOKENS["DEFAULT"]),
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
# If custom, ask for input
|
# If custom, ask for input
|
||||||
if default_model_context_window == "custom":
|
if context_window == "custom":
|
||||||
while True:
|
while True:
|
||||||
default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask()
|
context_window = questionary.text("Enter context window (e.g. 8192)").ask()
|
||||||
try:
|
try:
|
||||||
default_model_context_window = int(default_model_context_window)
|
context_window = int(context_window)
|
||||||
break
|
break
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(f"Context window must be a valid integer")
|
print(f"Context window must be a valid integer")
|
||||||
else:
|
else:
|
||||||
default_model_context_window = int(default_model_context_window)
|
context_window = int(context_window)
|
||||||
else:
|
else:
|
||||||
# Pull the context length from the models
|
# Pull the context length from the models
|
||||||
default_model_context_window = LLM_MAX_TOKENS[default_model]
|
context_window = LLM_MAX_TOKENS[model]
|
||||||
|
return model, model_wrapper, context_window
|
||||||
|
|
||||||
# defaults
|
|
||||||
|
def configure_embedding_endpoint(config: MemGPTConfig):
|
||||||
|
# configure embedding endpoint
|
||||||
|
|
||||||
|
default_embedding_endpoint_type = config.embedding_endpoint_type
|
||||||
|
if config.embedding_endpoint_type is not None and config.embedding_endpoint_type not in ["openai", "azure"]: # local model
|
||||||
|
default_embedding_endpoint_type = "local"
|
||||||
|
|
||||||
|
embedding_endpoint_type, embedding_endpoint, embedding_dim = None, None, None
|
||||||
|
embedding_provider = questionary.select(
|
||||||
|
"Select embedding provider:", choices=["openai", "azure", "local"], default=default_embedding_endpoint_type
|
||||||
|
).ask()
|
||||||
|
if embedding_provider == "openai":
|
||||||
|
embedding_endpoint_type = "openai"
|
||||||
|
embedding_endpoint = "https://api.openai.com/v1"
|
||||||
|
embedding_dim = 1536
|
||||||
|
elif embedding_provider == "azure":
|
||||||
|
embedding_endpoint_type = "azure"
|
||||||
|
_, _, _, _, embedding_endpoint = get_azure_credentials()
|
||||||
|
embedding_dim = 1536
|
||||||
|
else: # local models
|
||||||
|
embedding_endpoint_type = "local"
|
||||||
|
embedding_endpoint = None
|
||||||
|
embedding_dim = 384
|
||||||
|
return embedding_endpoint_type, embedding_endpoint, embedding_dim
|
||||||
|
|
||||||
|
|
||||||
|
def configure_cli(config: MemGPTConfig):
|
||||||
|
# set: preset, default_persona, default_human, default_agent``
|
||||||
|
from memgpt.presets.presets import preset_options
|
||||||
|
|
||||||
|
# preset
|
||||||
|
default_preset = config.preset if config.preset and config.preset in preset_options else None
|
||||||
|
preset = questionary.select("Select default preset:", preset_options, default=default_preset).ask()
|
||||||
|
|
||||||
|
# persona
|
||||||
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
|
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
|
||||||
# print(personas)
|
default_persona = config.persona if config.persona and config.persona in personas else None
|
||||||
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
|
persona = questionary.select("Select default persona:", personas, default=default_persona).ask()
|
||||||
|
|
||||||
|
# human
|
||||||
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
|
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
|
||||||
# print(humans)
|
default_human = config.human if config.human and config.human in humans else None
|
||||||
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()
|
human = questionary.select("Select default human:", humans, default=default_human).ask()
|
||||||
|
|
||||||
# TODO: figure out if we should set a default agent or not
|
# TODO: figure out if we should set a default agent or not
|
||||||
default_agent = None
|
agent = None
|
||||||
# agents = [os.path.basename(f).replace(".json", "") for f in utils.list_agent_config_files()]
|
|
||||||
# if len(agents) > 0: # agents have been created
|
|
||||||
# default_agent = questionary.select(
|
|
||||||
# "Select default agent:",
|
|
||||||
# agents
|
|
||||||
# ).ask()
|
|
||||||
# else:
|
|
||||||
# default_agent = None
|
|
||||||
|
|
||||||
|
return preset, persona, human, agent
|
||||||
|
|
||||||
|
|
||||||
|
def configure_archival_storage(config: MemGPTConfig):
|
||||||
# Configure archival storage backend
|
# Configure archival storage backend
|
||||||
archival_storage_options = ["local", "postgres"]
|
archival_storage_options = ["local", "postgres"]
|
||||||
archival_storage_type = questionary.select(
|
archival_storage_type = questionary.select(
|
||||||
@@ -191,25 +218,65 @@ def configure():
|
|||||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||||
default=config.archival_storage_uri if config.archival_storage_uri else "",
|
default=config.archival_storage_uri if config.archival_storage_uri else "",
|
||||||
).ask()
|
).ask()
|
||||||
|
return archival_storage_type, archival_storage_uri
|
||||||
|
|
||||||
# TODO: allow configuring embedding model
|
|
||||||
|
@app.command()
|
||||||
|
def configure():
|
||||||
|
"""Updates default MemGPT configurations"""
|
||||||
|
|
||||||
|
MemGPTConfig.create_config_dir()
|
||||||
|
|
||||||
|
# Will pre-populate with defaults, or what the user previously set
|
||||||
|
config = MemGPTConfig.load()
|
||||||
|
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
||||||
|
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
|
||||||
|
embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config)
|
||||||
|
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
||||||
|
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
|
||||||
|
|
||||||
|
# check credentials
|
||||||
|
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
|
||||||
|
openai_key = get_openai_credentials()
|
||||||
|
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
|
||||||
|
if all([azure_key, azure_endpoint, azure_version]):
|
||||||
|
print(f"Using Microsoft endpoint {azure_endpoint}.")
|
||||||
|
if all([azure_deployment, azure_embedding_deployment]):
|
||||||
|
print(f"Using deployment id {azure_deployment}")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
|
||||||
|
)
|
||||||
|
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
|
||||||
|
if not openai_key:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
|
||||||
|
)
|
||||||
|
|
||||||
config = MemGPTConfig(
|
config = MemGPTConfig(
|
||||||
model=default_model,
|
# model configs
|
||||||
context_window=default_model_context_window,
|
model=model,
|
||||||
|
model_endpoint=model_endpoint,
|
||||||
|
model_endpoint_type=model_endpoint_type,
|
||||||
|
model_wrapper=model_wrapper,
|
||||||
|
context_window=context_window,
|
||||||
|
# embedding configs
|
||||||
|
embedding_endpoint_type=embedding_endpoint_type,
|
||||||
|
embedding_endpoint=embedding_endpoint,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
# cli configs
|
||||||
preset=default_preset,
|
preset=default_preset,
|
||||||
model_endpoint=default_endpoint,
|
persona=default_persona,
|
||||||
embedding_model=default_embedding_endpoint,
|
human=default_human,
|
||||||
embedding_dim=default_embedding_dim,
|
agent=default_agent,
|
||||||
default_persona=default_persona,
|
# credentials
|
||||||
default_human=default_human,
|
openai_key=openai_key,
|
||||||
default_agent=default_agent,
|
azure_key=azure_key,
|
||||||
openai_key=openai_key if use_openai else None,
|
azure_endpoint=azure_endpoint,
|
||||||
azure_key=azure_key if use_azure else None,
|
azure_version=azure_version,
|
||||||
azure_endpoint=azure_endpoint if use_azure else None,
|
azure_deployment=azure_deployment,
|
||||||
azure_version=azure_version if use_azure else None,
|
azure_embedding_deployment=azure_embedding_deployment,
|
||||||
azure_deployment=azure_deployment if use_azure_deployment_ids else None,
|
# storage
|
||||||
azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None,
|
|
||||||
archival_storage_type=archival_storage_type,
|
archival_storage_type=archival_storage_type,
|
||||||
archival_storage_uri=archival_storage_uri,
|
archival_storage_uri=archival_storage_uri,
|
||||||
)
|
)
|
||||||
|
memgpt/config.py (253 changed lines)
@@ -16,6 +16,7 @@ from colorama import Fore, Style

 from typing import List, Type

+import memgpt
 import memgpt.utils as utils
 from memgpt.interface import CLIInterface as interface
 from memgpt.personas.personas import get_persona_text
@@ -40,6 +41,24 @@ model_choices = [
 ]


+# helper functions for writing to configs
+def get_field(config, section, field):
+    if section not in config:
+        return None
+    if config.has_option(section, field):
+        return config.get(section, field)
+    else:
+        return None
+
+
+def set_field(config, section, field, value):
+    if value is None: # cannot write None
+        return
+    if section not in config: # create section
+        config.add_section(section)
+    config.set(section, field, value)
+
+
 @dataclass
 class MemGPTConfig:
     config_path: str = os.path.join(MEMGPT_DIR, "config")
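The new `get_field`/`set_field` helpers wrap `configparser` so that a missing section or option reads back as `None` instead of raising, and `None` values are simply skipped on write. A small self-contained usage sketch (the section and field names below are just for illustration):

```python
import configparser

def get_field(config, section, field):
    # Missing section or option -> None instead of an exception
    if section not in config:
        return None
    if config.has_option(section, field):
        return config.get(section, field)
    else:
        return None

def set_field(config, section, field, value):
    if value is None:  # cannot write None
        return
    if section not in config:  # create section on demand
        config.add_section(section)
    config.set(section, field, value)

config = configparser.ConfigParser()
set_field(config, "model", "model_endpoint_type", "ollama")
set_field(config, "model", "model_wrapper", None)  # silently skipped

print(get_field(config, "model", "model_endpoint_type"))  # -> "ollama"
print(get_field(config, "model", "model_wrapper"))        # -> None
print(get_field(config, "azure", "key"))                  # -> None (section absent)
```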
@@ -49,9 +68,10 @@ class MemGPTConfig:
     preset: str = DEFAULT_PRESET

     # model parameters
-    # provider: str = "openai" # openai, azure, local (TODO)
-    model_endpoint: str = "openai"
-    model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local
+    model: str = None
+    model_endpoint_type: str = None
+    model_endpoint: str = None # localhost:8000
+    model_wrapper: str = None
     context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]

     # model parameters: openai
@@ -65,12 +85,13 @@ class MemGPTConfig:
     azure_embedding_deployment: str = None

     # persona parameters
-    default_persona: str = personas.DEFAULT
-    default_human: str = humans.DEFAULT
-    default_agent: str = None
+    persona: str = personas.DEFAULT
+    human: str = humans.DEFAULT
+    agent: str = None

     # embedding parameters
-    embedding_model: str = "openai"
+    embedding_endpoint_type: str = "openai" # openai, azure, local
+    embedding_endpoint: str = None
     embedding_dim: int = 1536
     embedding_chunk_size: int = 300 # number of tokens

@@ -89,6 +110,12 @@ class MemGPTConfig:
     persistence_manager_save_file: str = None # local file
     persistence_manager_uri: str = None # db URI

+    def __post_init__(self):
+        # ensure types
+        self.embedding_chunk_size = int(self.embedding_chunk_size)
+        self.embedding_dim = int(self.embedding_dim)
+        self.context_window = int(self.context_window)
+
     @staticmethod
     def generate_uuid() -> str:
         return uuid.UUID(int=uuid.getnode()).hex
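The new `__post_init__` hook matters because values read back from the INI-style config file via `configparser` arrive as strings; coercing them in the dataclass keeps numeric fields usable whether they came from a file or from Python. A minimal sketch of the same idea (the class and values below are illustrative, not MemGPT's):

```python
from dataclasses import dataclass

@dataclass
class NumericConfig:
    context_window: int = 8192
    embedding_dim: int = 1536

    def __post_init__(self):
        # configparser hands back strings, so coerce here to keep types consistent
        self.context_window = int(self.context_window)
        self.embedding_dim = int(self.embedding_dim)

# Values loaded from an INI file would be strings; the hook normalizes them.
cfg = NumericConfig(context_window="8192", embedding_dim="384")
assert isinstance(cfg.context_window, int) and cfg.context_window == 8192
assert cfg.embedding_dim == 384
```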
@@ -104,72 +131,38 @@ class MemGPTConfig:
|
|||||||
config_path = MemGPTConfig.config_path
|
config_path = MemGPTConfig.config_path
|
||||||
|
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
|
# read existing config
|
||||||
config.read(config_path)
|
config.read(config_path)
|
||||||
|
config_dict = {
|
||||||
|
"model": get_field(config, "model", "model"),
|
||||||
|
"model_endpoint": get_field(config, "model", "model_endpoint"),
|
||||||
|
"model_endpoint_type": get_field(config, "model", "model_endpoint_type"),
|
||||||
|
"model_wrapper": get_field(config, "model", "model_wrapper"),
|
||||||
|
"context_window": get_field(config, "model", "context_window"),
|
||||||
|
"preset": get_field(config, "defaults", "preset"),
|
||||||
|
"persona": get_field(config, "defaults", "persona"),
|
||||||
|
"human": get_field(config, "defaults", "human"),
|
||||||
|
"agent": get_field(config, "defaults", "agent"),
|
||||||
|
"openai_key": get_field(config, "openai", "key"),
|
||||||
|
"azure_key": get_field(config, "azure", "key"),
|
||||||
|
"azure_endpoint": get_field(config, "azure", "endpoint"),
|
||||||
|
"azure_version": get_field(config, "azure", "version"),
|
||||||
|
"azure_deployment": get_field(config, "azure", "deployment"),
|
||||||
|
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
|
||||||
|
"embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"),
|
||||||
|
"embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"),
|
||||||
|
"embedding_dim": get_field(config, "embedding", "embedding_dim"),
|
||||||
|
"embedding_chunk_size": get_field(config, "embedding", "chunk_size"),
|
||||||
|
"archival_storage_type": get_field(config, "archival_storage", "type"),
|
||||||
|
"archival_storage_path": get_field(config, "archival_storage", "path"),
|
||||||
|
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
|
||||||
|
"anon_clientid": get_field(config, "client", "anon_clientid"),
|
||||||
|
"config_path": config_path,
|
||||||
|
}
|
||||||
|
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
||||||
|
return cls(**config_dict)
|
||||||
|
|
||||||
# read config values
|
# create new config
|
||||||
model = config.get("defaults", "model")
|
|
||||||
context_window = (
|
|
||||||
int(config.get("defaults", "context_window"))
|
|
||||||
if config.has_option("defaults", "context_window")
|
|
||||||
else LLM_MAX_TOKENS["DEFAULT"]
|
|
||||||
)
|
|
||||||
preset = config.get("defaults", "preset")
|
|
||||||
model_endpoint = config.get("defaults", "model_endpoint")
|
|
||||||
default_persona = config.get("defaults", "persona")
|
|
||||||
default_human = config.get("defaults", "human")
|
|
||||||
default_agent = config.get("defaults", "agent") if config.has_option("defaults", "agent") else None
|
|
||||||
|
|
||||||
openai_key, openai_model = None, None
|
|
||||||
if "openai" in config:
|
|
||||||
openai_key = config.get("openai", "key")
|
|
||||||
|
|
||||||
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = None, None, None, None, None
|
|
||||||
if "azure" in config:
|
|
||||||
azure_key = config.get("azure", "key")
|
|
||||||
azure_endpoint = config.get("azure", "endpoint")
|
|
||||||
azure_version = config.get("azure", "version")
|
|
||||||
azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None
|
|
||||||
azure_embedding_deployment = (
|
|
||||||
config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None
|
|
||||||
)
|
|
||||||
|
|
||||||
embedding_model = config.get("embedding", "model")
|
|
||||||
embedding_dim = config.getint("embedding", "dim")
|
|
||||||
embedding_chunk_size = config.getint("embedding", "chunk_size")
|
|
||||||
|
|
||||||
# archival storage
|
|
||||||
archival_storage_type, archival_storage_path, archival_storage_uri = "local", None, None
|
|
||||||
if "archival_storage" in config:
|
|
||||||
archival_storage_type = config.get("archival_storage", "type")
|
|
||||||
archival_storage_path = config.get("archival_storage", "path") if config.has_option("archival_storage", "path") else None
|
|
||||||
archival_storage_uri = config.get("archival_storage", "uri") if config.has_option("archival_storage", "uri") else None
|
|
||||||
|
|
||||||
anon_clientid = config.get("client", "anon_clientid")
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
model=model,
|
|
||||||
context_window=context_window,
|
|
||||||
preset=preset,
|
|
||||||
model_endpoint=model_endpoint,
|
|
||||||
default_persona=default_persona,
|
|
||||||
default_human=default_human,
|
|
||||||
default_agent=default_agent,
|
|
||||||
openai_key=openai_key,
|
|
||||||
azure_key=azure_key,
|
|
||||||
azure_endpoint=azure_endpoint,
|
|
||||||
azure_version=azure_version,
|
|
||||||
azure_deployment=azure_deployment,
|
|
||||||
azure_embedding_deployment=azure_embedding_deployment,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
embedding_chunk_size=embedding_chunk_size,
|
|
||||||
archival_storage_type=archival_storage_type,
|
|
||||||
archival_storage_path=archival_storage_path,
|
|
||||||
archival_storage_uri=archival_storage_uri,
|
|
||||||
anon_clientid=anon_clientid,
|
|
||||||
config_path=config_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
anon_clientid = MemGPTConfig.generate_uuid()
|
anon_clientid = MemGPTConfig.generate_uuid()
|
||||||
config = cls(anon_clientid=anon_clientid, config_path=config_path)
|
config = cls(anon_clientid=anon_clientid, config_path=config_path)
|
||||||
config.save() # save updated config
|
config.save() # save updated config
|
||||||
@@ -179,51 +172,43 @@ class MemGPTConfig:
|
|||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
|
|
||||||
# CLI defaults
|
# CLI defaults
|
||||||
config.add_section("defaults")
|
set_field(config, "defaults", "preset", self.preset)
|
||||||
config.set("defaults", "model", self.model)
|
set_field(config, "defaults", "persona", self.persona)
|
||||||
config.set("defaults", "context_window", str(self.context_window))
|
set_field(config, "defaults", "human", self.human)
|
||||||
config.set("defaults", "preset", self.preset)
|
set_field(config, "defaults", "agent", self.agent)
|
||||||
assert self.model_endpoint is not None, "Endpoint must be set"
|
|
||||||
config.set("defaults", "model_endpoint", self.model_endpoint)
|
|
||||||
config.set("defaults", "persona", self.default_persona)
|
|
||||||
config.set("defaults", "human", self.default_human)
|
|
||||||
if self.default_agent:
|
|
||||||
config.set("defaults", "agent", self.default_agent)
|
|
||||||
|
|
||||||
# security credentials
|
# model defaults
|
||||||
if self.openai_key:
|
set_field(config, "model", "model", self.model)
|
||||||
config.add_section("openai")
|
set_field(config, "model", "model_endpoint", self.model_endpoint)
|
||||||
config.set("openai", "key", self.openai_key)
|
set_field(config, "model", "model_endpoint_type", self.model_endpoint_type)
|
||||||
|
set_field(config, "model", "model_wrapper", self.model_wrapper)
|
||||||
|
set_field(config, "model", "context_window", str(self.context_window))
|
||||||
|
|
||||||
if self.azure_key:
|
# security credentials: openai
|
||||||
config.add_section("azure")
|
set_field(config, "openai", "key", self.openai_key)
|
||||||
config.set("azure", "key", self.azure_key)
|
|
||||||
config.set("azure", "endpoint", self.azure_endpoint)
|
# security credentials: azure
|
||||||
config.set("azure", "version", self.azure_version)
|
set_field(config, "azure", "key", self.azure_key)
|
||||||
if self.azure_deployment:
|
set_field(config, "azure", "endpoint", self.azure_endpoint)
|
||||||
config.set("azure", "deployment", self.azure_deployment)
|
set_field(config, "azure", "version", self.azure_version)
|
||||||
config.set("azure", "embedding_deployment", self.azure_embedding_deployment)
|
set_field(config, "azure", "deployment", self.azure_deployment)
|
||||||
|
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
config.add_section("embedding")
|
set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type)
|
||||||
config.set("embedding", "model", self.embedding_model)
|
set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint)
|
||||||
config.set("embedding", "dim", str(self.embedding_dim))
|
set_field(config, "embedding", "embedding_dim", str(self.embedding_dim))
|
||||||
config.set("embedding", "chunk_size", str(self.embedding_chunk_size))
|
set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size))
|
||||||
|
|
||||||
# archival storage
|
# archival storage
|
||||||
config.add_section("archival_storage")
|
set_field(config, "archival_storage", "type", self.archival_storage_type)
|
||||||
# print("archival storage", self.archival_storage_type)
|
set_field(config, "archival_storage", "path", self.archival_storage_path)
|
||||||
config.set("archival_storage", "type", self.archival_storage_type)
|
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
|
||||||
if self.archival_storage_path:
|
|
||||||
config.set("archival_storage", "path", self.archival_storage_path)
|
|
||||||
if self.archival_storage_uri:
|
|
||||||
config.set("archival_storage", "uri", self.archival_storage_uri)
|
|
||||||
|
|
||||||
# client
|
# client
|
||||||
config.add_section("client")
|
|
||||||
if not self.anon_clientid:
|
if not self.anon_clientid:
|
||||||
self.anon_clientid = self.generate_uuid()
|
self.anon_clientid = self.generate_uuid()
|
||||||
config.set("client", "anon_clientid", self.anon_clientid)
|
set_field(config, "client", "anon_clientid", self.anon_clientid)
|
||||||
|
|
||||||
if not os.path.exists(MEMGPT_DIR):
|
if not os.path.exists(MEMGPT_DIR):
|
||||||
os.makedirs(MEMGPT_DIR, exist_ok=True)
|
os.makedirs(MEMGPT_DIR, exist_ok=True)
|
||||||
@@ -262,32 +247,54 @@ class AgentConfig:
|
|||||||
self,
|
self,
|
||||||
persona,
|
persona,
|
||||||
human,
|
human,
|
||||||
|
# model info
|
||||||
model,
|
model,
|
||||||
|
model_endpoint_type=None,
|
||||||
|
model_endpoint=None,
|
||||||
|
model_wrapper=None,
|
||||||
context_window=None,
|
context_window=None,
|
||||||
preset=DEFAULT_PRESET,
|
# embedding info
|
||||||
name=None,
|
embedding_endpoint_type=None,
|
||||||
data_sources=[],
|
embedding_endpoint=None,
|
||||||
|
embedding_dim=None,
|
||||||
|
embedding_chunk_size=None,
|
||||||
|
# other
|
||||||
|
preset=None,
|
||||||
|
data_sources=None,
|
||||||
|
# agent info
|
||||||
agent_config_path=None,
|
agent_config_path=None,
|
||||||
|
name=None,
|
||||||
create_time=None,
|
create_time=None,
|
||||||
data_source=None,
|
memgpt_version=None,
|
||||||
):
|
):
|
||||||
if name is None:
|
if name is None:
|
||||||
self.name = f"agent_{self.generate_agent_id()}"
|
self.name = f"agent_{self.generate_agent_id()}"
|
||||||
else:
|
else:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.persona = persona
|
|
||||||
self.human = human
|
|
||||||
self.model = model
|
|
||||||
self.context_window = context_window
|
|
||||||
self.preset = preset
|
|
||||||
self.data_sources = data_sources
|
|
||||||
self.create_time = create_time if create_time is not None else utils.get_local_time()
|
|
||||||
self.data_source = None # deprecated
|
|
||||||
|
|
||||||
if context_window is None:
|
config = MemGPTConfig.load() # get default values
|
||||||
self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
|
self.persona = config.persona if persona is None else persona
|
||||||
|
self.human = config.human if human is None else human
|
||||||
|
self.preset = config.preset if preset is None else preset
|
||||||
|
self.context_window = config.context_window if context_window is None else context_window
|
||||||
|
self.model = config.model if model is None else model
|
||||||
|
self.model_endpoint_type = config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type
|
||||||
|
self.model_endpoint = config.model_endpoint if model_endpoint is None else model_endpoint
|
||||||
|
self.model_wrapper = config.model_wrapper if model_wrapper is None else model_wrapper
|
||||||
|
self.embedding_endpoint_type = config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
|
||||||
|
self.embedding_endpoint = config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
|
||||||
|
self.embedding_dim = config.embedding_dim if embedding_dim is None else embedding_dim
|
||||||
|
self.embedding_chunk_size = config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size
|
||||||
|
|
||||||
|
# agent metadata
|
||||||
|
self.data_sources = data_sources if data_sources is not None else []
|
||||||
|
self.create_time = create_time if create_time is not None else utils.get_local_time()
|
||||||
|
if memgpt_version is None:
|
||||||
|
import memgpt
|
||||||
|
|
||||||
|
self.memgpt_version = memgpt.__version__
|
||||||
else:
|
else:
|
||||||
self.context_window = context_window
|
self.memgpt_version = memgpt_version
|
||||||
|
|
||||||
# save agent config
|
# save agent config
|
||||||
self.agent_config_path = (
|
self.agent_config_path = (
|
||||||
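With the constructor rewritten as above, every argument left as None now falls back to the value stored by `memgpt configure` (via `MemGPTConfig.load()`). A hypothetical usage sketch, assuming a config file already exists and using illustrative persona/human/model names:

from memgpt.config import AgentConfig

agent_config = AgentConfig(
    persona="sam_pov", # illustrative
    human="basic", # illustrative
    model="dolphin-2.1-mistral-7b", # illustrative
)
# everything not passed explicitly is inherited from the saved MemGPTConfig
print(agent_config.model_endpoint_type, agent_config.model_endpoint, agent_config.model_wrapper)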
@@ -326,6 +333,8 @@ class AgentConfig:
     def save(self):
         # save state of persistence manager
         os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True)
+        # save version
+        self.memgpt_version = memgpt.__version__
         with open(self.agent_config_path, "w") as f:
             json.dump(vars(self), f, indent=4)

@@ -342,7 +351,6 @@ class AgentConfig:
         assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
         with open(agent_config_path, "r") as f:
             agent_config = json.load(f)

         # allow compatibility accross versions
         try:
             class_args = inspect.getargspec(cls.__init__).args

@@ -354,7 +362,6 @@ class AgentConfig:
             if key not in class_args:
                 utils.printd(f"Removing missing argument {key} from agent config")
                 del agent_config[key]

         return cls(**agent_config)
@@ -11,9 +11,9 @@ def embedding_model():
     # load config
     config = MemGPTConfig.load()

-    endpoint = config.embedding_model
+    endpoint = config.embedding_endpoint_type
     if endpoint == "openai":
-        model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key)
+        model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key)
         return model
     elif endpoint == "azure":
         return OpenAIEmbedding(
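In other words, the embedding client is now routed by the configured `embedding_endpoint_type` rather than a hard-coded OpenAI URL. A rough sketch of the equivalent selection logic (the `OpenAIEmbedding` import path is assumed, from llama-index):

from llama_index.embeddings import OpenAIEmbedding # import path assumed
from memgpt.config import MemGPTConfig

config = MemGPTConfig.load()
if config.embedding_endpoint_type == "openai":
    # point the embedding client at whatever endpoint was configured
    model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key)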
@@ -4,7 +4,7 @@ import os
 import json
 import math

-from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
+from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE

 ### Functions / tools the agent can use
 # All functions should return a response string (or None)

@@ -4,8 +4,8 @@ import json
 import requests


-from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
-from ...openai_tools import completions_with_backoff as create
+from memgpt.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS
+from memgpt.openai_tools import completions_with_backoff as create


 def message_chatgpt(self, message: str):
@@ -10,74 +10,65 @@ from .llamacpp.api import get_llamacpp_completion
 from .koboldcpp.api import get_koboldcpp_completion
 from .ollama.api import get_ollama_completion
 from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
-from .utils import DotDict
+from .constants import DEFAULT_WRAPPER
+from .utils import DotDict, get_available_wrappers
 from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
 from ..errors import LocalLLMConnectionError, LocalLLMError

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
+endpoint = os.getenv("OPENAI_API_BASE")
+endpoint_type = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
 DEBUG = False
 # DEBUG = True
-DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
 has_shown_warning = False


 def get_chat_completion(
-    model, # no model, since the model is fixed to whatever you set in your own backend
+    model, # no model required (except for Ollama), since the model is fixed to whatever you set in your own backend
     messages,
     functions=None,
     function_call="auto",
     context_window=None,
+    # required
+    wrapper=None,
+    endpoint=None,
+    endpoint_type=None,
 ):
     assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
+    assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
+    assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
     global has_shown_warning
     grammar_name = None

-    if HOST is None:
-        raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.")
-    if HOST_TYPE is None:
-        raise ValueError(f"The BACKEND_TYPE environment variable is not defined. Please set it in your environment.")
-
     if function_call != "auto":
         raise ValueError(f"function_call == {function_call} not supported (auto only)")

+    available_wrappers = get_available_wrappers()
     if messages[0]["role"] == "system" and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip():
         # Special case for if the call we're making is coming from the summarizer
         llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper()
-    elif model == "airoboros-l2-70b-2.1":
-        llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
-    elif model == "airoboros-l2-70b-2.1-grammar":
-        llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False)
-        # grammar_name = "json"
-        grammar_name = "json_func_calls_with_inner_thoughts"
-    elif model == "dolphin-2.1-mistral-7b":
-        llm_wrapper = dolphin.Dolphin21MistralWrapper()
-    elif model == "dolphin-2.1-mistral-7b-grammar":
-        llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False)
-        # grammar_name = "json"
-        grammar_name = "json_func_calls_with_inner_thoughts"
-    elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta":
-        llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper()
-    elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar":
-        llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False)
-        # grammar_name = "json"
-        grammar_name = "json_func_calls_with_inner_thoughts"
-    else:
+    elif wrapper is None:
         # Warn the user that we're using the fallback
         if not has_shown_warning:
             print(
-                f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)"
+                f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --wrapper)"
             )
             has_shown_warning = True
-        if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]:
+        if endpoint_type in ["koboldcpp", "llamacpp", "webui"]:
             # make the default to use grammar
             llm_wrapper = DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
             # grammar_name = "json"
             grammar_name = "json_func_calls_with_inner_thoughts"
         else:
             llm_wrapper = DEFAULT_WRAPPER()
+    elif wrapper not in available_wrappers:
+        raise ValueError(f"Could not find requested wrapper '{wrapper} in available wrappers list:\n{available_wrappers}")
+    else:
+        llm_wrapper = available_wrappers[wrapper]
+        if "grammar" in wrapper:
+            grammar_name = "json_func_calls_with_inner_thoughts"

-    if grammar_name is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
+    if grammar_name is not None and endpoint_type not in ["koboldcpp", "llamacpp", "webui"]:
         print(f"Warning: grammars are currently only supported when using llama.cpp as the MemGPT local LLM backend")

     # First step: turn the message sequence into a prompt that the model expects

@@ -91,25 +82,25 @@ def get_chat_completion(
     )

     try:
-        if HOST_TYPE == "webui":
-            result = get_webui_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "lmstudio":
-            result = get_lmstudio_completion(prompt, context_window)
-        elif HOST_TYPE == "llamacpp":
-            result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "koboldcpp":
-            result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "ollama":
-            result = get_ollama_completion(prompt, context_window)
+        if endpoint_type == "webui":
+            result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "lmstudio":
+            result = get_lmstudio_completion(endpoint, prompt, context_window)
+        elif endpoint_type == "llamacpp":
+            result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "koboldcpp":
+            result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "ollama":
+            result = get_ollama_completion(endpoint, model, prompt, context_window)
         else:
             raise LocalLLMError(
                 f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
             )
     except requests.exceptions.ConnectionError as e:
-        raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")
+        raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")

     if result is None or result == "":
-        raise LocalLLMError(f"Got back an empty response string from {HOST}")
+        raise LocalLLMError(f"Got back an empty response string from {endpoint}")
     if DEBUG:
         print(f"Raw LLM output:\n{result}")

@@ -123,7 +114,7 @@ def get_chat_completion(
     # unpack with response.choices[0].message.content
     response = DotDict(
         {
-            "model": None,
+            "model": model,
             "choices": [
                 DotDict(
                     {
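Putting the proxy changes together: callers now hand the wrapper name, endpoint, and endpoint type to `get_chat_completion` explicitly instead of relying on the `OPENAI_API_BASE`/`BACKEND_TYPE` environment variables. A hypothetical direct call (all values illustrative):

from memgpt.local_llm.chat_completion_proxy import get_chat_completion

response = get_chat_completion(
    model=None, # only Ollama needs a model name
    messages=[{"role": "user", "content": "Hello"}],
    context_window=8192,
    wrapper="dolphin-2.1-mistral-7b", # a key from get_available_wrappers()
    endpoint="http://localhost:5000", # e.g. text-generation-webui
    endpoint_type="webui",
)
print(response.choices[0].message.content) # unpacked like an OpenAI response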
memgpt/local_llm/constants.py (new file, +14 lines)
@@ -0,0 +1,14 @@
+import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
+
+DEFAULT_ENDPOINTS = {
+    "koboldcpp": "http://localhost:5001",
+    "llamacpp": "http://localhost:8080",
+    "lmstudio": "http://localhost:1234",
+    "ollama": "http://localhost:11434",
+    "webui": "http://localhost:5000",
+}
+
+DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
+
+DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
+DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1"
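These constants give `memgpt configure` sensible prefills for each local backend; a minimal lookup sketch (the actual CLI wiring lives in the configure code elsewhere in this commit):

from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME

backend = "ollama" # illustrative choice
default_endpoint = DEFAULT_ENDPOINTS[backend] # -> "http://localhost:11434"
default_model = DEFAULT_OLLAMA_MODEL if backend == "ollama" else None
print(default_endpoint, default_model, DEFAULT_WRAPPER_NAME)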
@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
 KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
-# DEBUG = False
-DEBUG = True
+DEBUG = False
+# DEBUG = True


-def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
+def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://lite.koboldai.net/koboldcpp_api for API spec"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:

@@ -27,13 +25,13 @@ def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMP
     if grammar is not None:
         request["grammar"] = load_grammar_file(grammar)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
         # NOTE: llama.cpp server returns the following when it's out of context
         # curl: (52) Empty reply from server
-        URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
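The same refactor repeats in each backend module below (llama.cpp, LM Studio, Ollama, web UI): the module-level `HOST`/`HOST_TYPE` environment lookups disappear, and the endpoint is passed in by the caller, validated, and joined with the backend-specific API suffix. The shared pattern, roughly (`post_to_backend` is an illustrative name, not a function in the repo):

from urllib.parse import urljoin
import requests

def post_to_backend(endpoint: str, api_suffix: str, request: dict) -> requests.Response:
    # sketch of the validation + URL construction used by each backend helper
    if not endpoint.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
    uri = urljoin(endpoint.strip("/") + "/", api_suffix.strip("/"))
    return requests.post(uri, json=request)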
@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
 LLAMACPP_API_SUFFIX = "/completion"
-# DEBUG = False
-DEBUG = True
+DEBUG = False
+# DEBUG = True


-def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
+def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:

@@ -26,13 +24,13 @@ def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPL
     if grammar is not None:
         request["grammar"] = load_grammar_file(grammar)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
         # NOTE: llama.cpp server returns the following when it's out of context
         # curl: (52) Empty reply from server
-        URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
 LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
 LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
 DEBUG = False


-def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat"):
+def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
     """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:

@@ -25,19 +23,19 @@ def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat")
     if api == "chat":
         # Uses the ChatCompletions API style
         # Seems to work better, probably because it's applying some extra settings under-the-hood?
-        URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
         message_structure = [{"role": "user", "content": prompt}]
         request["messages"] = message_structure
     elif api == "completions":
         # Uses basic string completions (string in, string out)
         # Does not work as well as ChatCompletions for some reason
-        URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
         request["prompt"] = prompt
     else:
         raise ValueError(api)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
         response = requests.post(URI, json=request)
@@ -6,26 +6,25 @@ from .settings import SIMPLE
 from ..utils import count_tokens
 from ...errors import LocalLLMError

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
-MODEL_NAME = os.getenv("OLLAMA_MODEL") # ollama API requires this in the request
 OLLAMA_API_SUFFIX = "/api/generate"
 DEBUG = False


-def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None):
+def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

-    if MODEL_NAME is None:
-        raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")
+    if model is None:
+        raise LocalLLMError(
+            f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
+        )

     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
     request["prompt"] = prompt
-    request["model"] = MODEL_NAME
+    request["model"] = model
     request["options"]["num_ctx"] = context_window

     # Set grammar

@@ -33,11 +32,11 @@ def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None)
     # request["grammar_string"] = load_grammar_file(grammar)
     raise NotImplementedError(f"Ollama does not support grammars")

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
-        URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
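For reference, the Ollama branch above ends up POSTing a JSON body along these lines to `<endpoint>/api/generate` (a sketch with illustrative values; the real `settings` dict contributes stop tokens and other generation options):

import requests

endpoint = "http://localhost:11434" # illustrative
request = {
    "model": "dolphin2.2-mistral:7b-q6_K", # now taken from the agent config, not OLLAMA_MODEL
    "prompt": "### Instruction:\n...", # wrapper-formatted prompt (illustrative)
    "options": {"num_ctx": 8192}, # context_window passed through
    "stream": False, # assumption: single non-streaming generate call
}
response = requests.post(endpoint + "/api/generate", json=request)
print(response.json().get("response", ""))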
@@ -1,6 +1,10 @@
 import os
 import tiktoken

+import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
+import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
+import memgpt.local_llm.llm_chat_completion_wrappers.zephyr as zephyr


 class DotDict(dict):
     """Allow dot access on properties similar to OpenAI response object"""

@@ -37,3 +41,14 @@ def load_grammar_file(grammar):
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
     return len(encoding.encode(s))


+def get_available_wrappers() -> dict:
+    return {
+        "airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
+        "airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False),
+        "dolphin-2.1-mistral-7b": dolphin.Dolphin21MistralWrapper(),
+        "dolphin-2.1-mistral-7b-grammar": dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False),
+        "zephyr-7B": zephyr.ZephyrMistralInnerMonologueWrapper(),
+        "zephyr-7B-grammar": zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False),
+    }
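This registry is what lets `memgpt configure` present the wrapper choices as a select list and lets the proxy resolve the chosen name; a short usage sketch:

from memgpt.local_llm.utils import get_available_wrappers

wrappers = get_available_wrappers()
print(list(wrappers.keys())) # names offered during `memgpt configure`

choice = "airoboros-l2-70b-2.1-grammar" # illustrative selection
llm_wrapper = wrappers[choice]
uses_grammar = "grammar" in choice # mirrors the check in the proxy above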
@@ -5,13 +5,11 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens

-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
 WEBUI_API_SUFFIX = "/api/v1/generate"
 DEBUG = False


-def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
+def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:

@@ -26,11 +24,11 @@ def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
     if grammar is not None:
         request["grammar_string"] = load_grammar_file(grammar)

-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

     try:
-        URI = urljoin(HOST.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
@@ -241,7 +241,7 @@ def main(
     memgpt_persona = persona
     if memgpt_persona is None:
         memgpt_persona = (
-            personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT,
+            personas.GPT35_DEFAULT if (model is not None and "gpt-3.5" in model) else personas.DEFAULT,
             None, # represents the personas dir in pymemgpt package
         )
     else:
@@ -2,10 +2,14 @@ import random
 import os
 import time

-from .local_llm.chat_completion_proxy import get_chat_completion
+import time
+from typing import Callable, TypeVar

+from memgpt.local_llm.chat_completion_proxy import get_chat_completion

 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
+R = TypeVar("R")

 import openai

@@ -55,6 +59,7 @@ def retry_with_exponential_backoff(
     return wrapper


+# TODO: delete/ignore --legacy
 @retry_with_exponential_backoff
 def completions_with_backoff(**kwargs):
     # Local model

@@ -75,6 +80,38 @@ def completions_with_backoff(**kwargs):
     return openai.ChatCompletion.create(**kwargs)


+@retry_with_exponential_backoff
+def chat_completion_with_backoff(agent_config, **kwargs):
+    from memgpt.utils import printd
+    from memgpt.config import AgentConfig, MemGPTConfig
+
+    printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}")
+    if agent_config.model_endpoint_type == "openai":
+        # openai
+        openai.api_base = agent_config.model_endpoint
+        return openai.ChatCompletion.create(**kwargs)
+    elif agent_config.model_endpoint_type == "azure":
+        # configure openai
+        config = MemGPTConfig.load() # load credentials (currently not stored in agent config)
+        openai.api_type = "azure"
+        openai.api_key = config.azure_key
+        openai.api_base = config.azure_endpoint
+        openai.api_version = config.azure_version
+        if config.azure_deployment is not None:
+            kwargs["deployment_id"] = config.azure_deployment
+        else:
+            kwargs["engine"] = MODEL_TO_AZURE_ENGINE[config.model]
+            del kwargs["model"]
+        return openai.ChatCompletion.create(**kwargs)
+    else: # local model
+        kwargs["context_window"] = agent_config.context_window # specify for open LLMs
+        kwargs["endpoint"] = agent_config.model_endpoint # specify for open LLMs
+        kwargs["endpoint_type"] = agent_config.model_endpoint_type # specify for open LLMs
+        kwargs["wrapper"] = agent_config.model_wrapper # specify for open LLMs
+        return get_chat_completion(**kwargs)


+# TODO: deprecate
 @retry_with_exponential_backoff
 def create_embedding_with_backoff(**kwargs):
     if using_azure():
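A rough sketch of how the new `chat_completion_with_backoff` is meant to be driven from the agent loop, assuming `memgpt configure` has already been run (message payload and field values are illustrative):

from memgpt.config import AgentConfig
from memgpt.openai_tools import chat_completion_with_backoff

agent_config = AgentConfig(persona="sam_pov", human="basic", model=None) # model falls back to the configured default
response = chat_completion_with_backoff(
    agent_config=agent_config, # routing key: agent_config.model_endpoint_type (openai / azure / local)
    model=agent_config.model,
    messages=[{"role": "user", "content": "Hello"}],
)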
@@ -57,5 +57,5 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
         persona_notes=persona,
         human_notes=human,
         # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
-        first_message_verify_mono=True if "gpt-4" in model else False,
+        first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
     )
@@ -13,7 +13,7 @@ def test_configure_memgpt():


 def test_save_load():
-    configure_memgpt()
+    # configure_memgpt() # rely on configure running first^
     child = pexpect.spawn("memgpt run --agent test_save_load --first --strip_ui")

     child.expect("Enter your message:", timeout=TIMEOUT)
@@ -1,12 +1,12 @@
 # import tempfile
 # import asyncio
-import os
+# import os

 # import asyncio
-from datasets import load_dataset
+# from datasets import load_dataset

-import memgpt
-from memgpt.cli.cli_load import load_directory, load_database, load_webpage
+# import memgpt
+# from memgpt.cli.cli_load import load_directory, load_database, load_webpage

 # import memgpt.presets as presets
 # import memgpt.personas.personas as personas
@@ -1,6 +1,7 @@
 import os
 import subprocess
 import sys
+import pytest

 subprocess.check_call(
     [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]

@@ -15,9 +16,11 @@ from memgpt.config import MemGPTConfig, AgentConfig
 import argparse


+@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
 def test_postgres_openai():
-    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
-    if os.getenv("OPENAI_API_KEY") is None:
+    if not os.getenv("PGVECTOR_TEST_DB_URL"):
+        return # soft pass
+    if not os.getenv("OPENAI_API_KEY"):
         return # soft pass

     # os.environ["MEMGPT_CONFIG_PATH"] = "./config"

@@ -54,14 +57,16 @@ def test_postgres_openai():
     # print("...finished")


+@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
 def test_postgres_local():
-    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
+    if not os.getenv("PGVECTOR_TEST_DB_URL"):
+        return
     # os.environ["MEMGPT_CONFIG_PATH"] = "./config"

     config = MemGPTConfig(
         archival_storage_type="postgres",
         archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
-        embedding_model="local",
+        embedding_endpoint_type="local",
         embedding_dim=384, # use HF model
     )
     print(config.config_path)
@@ -3,43 +3,55 @@ import pexpect
 from .constants import TIMEOUT


-def configure_memgpt(enable_openai=True, enable_azure=False):
+def configure_memgpt_localllm():
     child = pexpect.spawn("memgpt configure")

-    child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT)
-    if enable_openai:
-        child.sendline("y")
-    else:
-        child.sendline("n")
-
-    child.expect("Do you want to enable MemGPT with Azure?", timeout=TIMEOUT)
-    if enable_azure:
-        child.sendline("y")
-    else:
-        child.sendline("n")
-
-    child.expect("Select default inference endpoint:", timeout=TIMEOUT)
+    child.expect("Select LLM inference provider", timeout=TIMEOUT)
+    child.send("\x1b[B") # Send the down arrow key
+    child.send("\x1b[B") # Send the down arrow key
     child.sendline()

-    child.expect("Select default embedding endpoint:", timeout=TIMEOUT)
+    child.expect("Select LLM backend", timeout=TIMEOUT)
     child.sendline()

-    child.expect("Select default preset:", timeout=TIMEOUT)
+    child.expect("Enter default endpoint", timeout=TIMEOUT)
     child.sendline()

-    child.expect("Select default model", timeout=TIMEOUT)
+    child.expect("Select default model wrapper", timeout=TIMEOUT)
     child.sendline()

-    child.expect("Select default persona:", timeout=TIMEOUT)
+    child.expect("Select your model's context window", timeout=TIMEOUT)
     child.sendline()

-    child.expect("Select default human:", timeout=TIMEOUT)
+    child.expect("Select embedding provider", timeout=TIMEOUT)
+    child.send("\x1b[B") # Send the down arrow key
+    child.send("\x1b[B") # Send the down arrow key
+    child.sendline()
+
+    child.expect("Select default preset", timeout=TIMEOUT)
+    child.sendline()
+
+    child.expect("Select default persona", timeout=TIMEOUT)
+    child.sendline()
+
+    child.expect("Select default human", timeout=TIMEOUT)
+    child.sendline()
+
+    child.expect("Select storage backend for archival data", timeout=TIMEOUT)
     child.sendline()

-    child.expect("Select storage backend for archival data:", timeout=TIMEOUT)
     child.sendline()

     child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
     child.close()
     assert child.isalive() is False, "CLI should have terminated."
     assert child.exitstatus == 0, "CLI did not exit cleanly."


+def configure_memgpt(enable_openai=False, enable_azure=False):
+    if enable_openai:
+        raise NotImplementedError
+    elif enable_azure:
+        raise NotImplementedError
+    else:
+        configure_memgpt_localllm()