don't allow bad endpoint addresses during memgpt configure

This commit is contained in:
cpacker 2024-01-02 13:27:19 -08:00
parent 366add62c5
commit bad6abc2f4

View File

@ -123,16 +123,21 @@ def configure_llm_endpoint(config: MemGPTConfig):
if model_endpoint_type in DEFAULT_ENDPOINTS: if model_endpoint_type in DEFAULT_ENDPOINTS:
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type] default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask() model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
while not utils.is_valid_url(model_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
elif config.model_endpoint: elif config.model_endpoint:
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask() model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
while not utils.is_valid_url(model_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
else: else:
# default_model_endpoint = None # default_model_endpoint = None
model_endpoint = None model_endpoint = None
while not model_endpoint: model_endpoint = questionary.text("Enter default endpoint:").ask()
while not utils.is_valid_url(model_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = questionary.text("Enter default endpoint:").ask() model_endpoint = questionary.text("Enter default endpoint:").ask()
if "http://" not in model_endpoint and "https://" not in model_endpoint:
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
model_endpoint = None
else: else:
model_endpoint = default_model_endpoint model_endpoint = default_model_endpoint
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set." assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
@ -330,9 +335,9 @@ def configure_embedding_endpoint(config: MemGPTConfig):
# get endpoint # get endpoint
embedding_endpoint = questionary.text("Enter default endpoint:").ask() embedding_endpoint = questionary.text("Enter default endpoint:").ask()
if "http://" not in embedding_endpoint and "https://" not in embedding_endpoint: while not utils.is_valid_url(embedding_endpoint):
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
embedding_endpoint = None embedding_endpoint = questionary.text("Enter default endpoint:").ask()
# get model type # get model type
default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5" default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"