mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
don't allow bad endpoint addresses during memgpt configure
This commit is contained in:
parent
366add62c5
commit
bad6abc2f4
@ -123,16 +123,21 @@ def configure_llm_endpoint(config: MemGPTConfig):
|
|||||||
if model_endpoint_type in DEFAULT_ENDPOINTS:
|
if model_endpoint_type in DEFAULT_ENDPOINTS:
|
||||||
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
|
default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type]
|
||||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||||
|
while not utils.is_valid_url(model_endpoint):
|
||||||
|
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||||
|
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||||
elif config.model_endpoint:
|
elif config.model_endpoint:
|
||||||
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
||||||
|
while not utils.is_valid_url(model_endpoint):
|
||||||
|
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||||
|
model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask()
|
||||||
else:
|
else:
|
||||||
# default_model_endpoint = None
|
# default_model_endpoint = None
|
||||||
model_endpoint = None
|
model_endpoint = None
|
||||||
while not model_endpoint:
|
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||||
|
while not utils.is_valid_url(model_endpoint):
|
||||||
|
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||||
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
model_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||||
if "http://" not in model_endpoint and "https://" not in model_endpoint:
|
|
||||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
|
||||||
model_endpoint = None
|
|
||||||
else:
|
else:
|
||||||
model_endpoint = default_model_endpoint
|
model_endpoint = default_model_endpoint
|
||||||
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
|
assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set."
|
||||||
@ -330,9 +335,9 @@ def configure_embedding_endpoint(config: MemGPTConfig):
|
|||||||
|
|
||||||
# get endpoint
|
# get endpoint
|
||||||
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||||
if "http://" not in embedding_endpoint and "https://" not in embedding_endpoint:
|
while not utils.is_valid_url(embedding_endpoint):
|
||||||
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW)
|
||||||
embedding_endpoint = None
|
embedding_endpoint = questionary.text("Enter default endpoint:").ask()
|
||||||
|
|
||||||
# get model type
|
# get model type
|
||||||
default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"
|
default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"
|
||||||
|
Loading…
Reference in New Issue
Block a user