mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
don't allow bad endpoint addresses during memgpt configure
This commit is contained in:
parent
366add62c5
commit
bad6abc2f4
@ -123,16 +123,21 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
||||
if model_endpoint_type in DEFAULT_ENDPOINTS:
|
||||
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||
while not utils.is_valid_url(model_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||
elif config.model_endpoint:
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
||||
while not utils.is_valid_url(model_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
||||
else:
|
||||
# default_model_endpoint = None
|
||||
model_endpoint = None
|
||||
while not model_endpoint:
|
||||
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
while not utils.is_valid_url(model_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
if "http://" not in model_endpoint and "https://" not in model_endpoint:
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
model_endpoint = None
|
||||
else:
|
||||
model_endpoint = default_model_endpoint
|
||||
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
|
||||
@ -330,9 +335,9 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
||||
|
||||
# get endpoint
|
||||
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
if "http://" not in embedding_endpoint and "https://" not in embedding_endpoint:
|
||||
while not utils.is_valid_url(embedding_endpoint):
|
||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||
embedding_endpoint = None
|
||||
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||
|
||||
# get model type
|
||||
default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"
|
||||
|
Loading…
Reference in New Issue
Block a user