feat: implement remaining Admin routes in client and add tests (#1157)

This commit is contained in:
Sarah Wooders 2024-03-16 20:03:19 -07:00 committed by GitHub
parent d73c64f20e
commit 03e9f3ebbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 395 additions and 91 deletions

View File

@ -277,33 +277,6 @@ def create_default_user_or_exit(config: MemGPTConfig, ms: MetadataStore):
return user
def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"):
"""Generate a self-signed SSL certificate.
NOTE: intended to be used for development only.
"""
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:4096",
"-keyout",
key_path,
"-out",
cert_path,
"-days",
"365",
"-nodes",
"-subj",
"/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost",
],
check=True,
)
return cert_path, key_path
def server(
type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
@ -311,22 +284,22 @@ def server(
use_ssl: Annotated[bool, typer.Option(help="Run the server using HTTPS?")] = False,
ssl_cert: Annotated[Optional[str], typer.Option(help="Path to SSL certificate (if use_ssl is True)")] = None,
ssl_key: Annotated[Optional[str], typer.Option(help="Path to SSL key file (if use_ssl is True)")] = None,
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = True,
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
):
"""Launch a MemGPT server process"""
if debug:
from memgpt.server.server import logger as server_logger
# if debug:
# from memgpt.server.server import logger as server_logger
# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)
# # Set the logging level
# server_logger.setLevel(logging.DEBUG)
# # Create a StreamHandler
# stream_handler = logging.StreamHandler()
# # Set the formatter (optional)
# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# stream_handler.setFormatter(formatter)
# # Add the handler to the logger
# server_logger.addHandler(stream_handler)
if type == ServerChoice.rest_api:
import uvicorn
@ -341,35 +314,45 @@ def server(
sys.exit(1)
try:
if use_ssl:
if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
ssl_certfile, ssl_keyfile = generate_self_signed_cert()
print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
else:
ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
from memgpt.server.rest_api.server import start_server
# This will start the server on HTTPS
assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
print(
f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
)
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
)
else:
# Start the subprocess in a new session
print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
)
start_server(
port=port,
host=host,
use_ssl=use_ssl,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
debug=debug,
)
# if use_ssl:
# if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
# ssl_certfile, ssl_keyfile = generate_self_signed_cert()
# print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
# else:
# ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
# print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
# # This will start the server on HTTPS
# assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
# assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
# print(
# f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
# )
# uvicorn.run(
# app,
# host=host or "localhost",
# port=port or REST_DEFAULT_PORT,
# ssl_keyfile=ssl_keyfile,
# ssl_certfile=ssl_certfile,
# )
# else:
# # Start the subprocess in a new session
# print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
# uvicorn.run(
# app,
# host=host or "localhost",
# port=port or REST_DEFAULT_PORT,
# )
except KeyboardInterrupt:
# Handle CTRL-C
@ -377,6 +360,19 @@ def server(
sys.exit(0)
elif type == ServerChoice.ws_api:
if debug:
from memgpt.server.server import logger as server_logger
# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)
if port is None:
port = WS_DEFAULT_PORT

View File

@ -1,5 +1,16 @@
from typing import Optional
import uuid
import requests
from requests import HTTPError
from memgpt.server.rest_api.admin.users import (
CreateUserResponse,
CreateAPIKeyResponse,
GetAPIKeysResponse,
DeleteAPIKeyResponse,
DeleteUserResponse,
GetAllUsersResponse,
)
class Admin:
@ -14,13 +25,58 @@ class Admin:
self.token = token
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
def create_user(self, user_id: Optional[str] = None):
def get_users(self):
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return GetAllUsersResponse(**response.json())
def create_key(self, user_id: uuid.UUID, key_name: str):
payload = {"user_id": str(user_id), "key_name": key_name}
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
print(response.json())
if response.status_code != 200:
raise HTTPError(response.json())
return CreateAPIKeyResponse(**response.json())
def get_keys(self, user_id: uuid.UUID):
params = {"user_id": str(user_id)}
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
print(response.text, response.status_code)
return GetAPIKeysResponse(**response.json())
def delete_key(self, api_key: str):
params = {"api_key": api_key}
response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return DeleteAPIKeyResponse(**response.json())
def create_user(self, user_id: Optional[uuid.UUID] = None):
payload = {"user_id": str(user_id) if user_id else None}
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
if response.status_code != 200:
raise HTTPError(response.json())
response_json = response.json()
print(response_json)
return response_json["user_id"], response_json["api_key"]
return CreateUserResponse(**response_json)
def delete_user(self, user_id: str):
response = requests.delete(f"{self.base_url}/admin/users/{user_id}", headers=self.headers)
return response.json()
def delete_user(self, user_id: uuid.UUID):
params = {"user_id": str(user_id)}
response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return DeleteUserResponse(**response.json())
def _reset_server(self):
# DANGER: this will delete all users and keys
# clear all state associated with users
# TODO: clear out all agents, presets, etc.
users = self.get_users().user_list
for user in users:
keys = self.get_keys(user["user_id"]).api_key_list
for key in keys:
self.delete_key(key)
self.delete_user(user["user_id"])

View File

@ -140,7 +140,10 @@ class RESTClient(AbstractClient):
}
}
response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create agent: {response.text}")
response_json = response.json()
print(response_json)
llm_config = LLMConfig(**response_json["agent_state"]["llm_config"])
embedding_config = EmbeddingConfig(**response_json["agent_state"]["embedding_config"])
agent_state = AgentState(

View File

@ -16,17 +16,17 @@ class GetAllUsersResponse(BaseModel):
class CreateUserRequest(BaseModel):
user_id: Optional[str] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
user_id: Optional[uuid.UUID] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
api_key_name: Optional[str] = Field(None, description="Name for API key autogenerated on user creation (optional).")
class CreateUserResponse(BaseModel):
user_id: str = Field(..., description="Identifier of the user (UUID).")
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
api_key: str = Field(..., description="New API key generated for user.")
class CreateAPIKeyRequest(BaseModel):
user_id: str = Field(..., description="Identifier of the user (UUID).")
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
name: Optional[str] = Field(None, description="Name for the API key (optional).")
@ -35,7 +35,7 @@ class CreateAPIKeyResponse(BaseModel):
class GetAPIKeysRequest(BaseModel):
user_id: str = Field(..., description="Identifier of the user (UUID).")
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
class GetAPIKeysResponse(BaseModel):
@ -49,7 +49,7 @@ class DeleteAPIKeyResponse(BaseModel):
class DeleteUserResponse(BaseModel):
message: str
user_id_deleted: str
user_id_deleted: uuid.UUID
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@ -76,7 +76,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
request = CreateUserRequest()
new_user = User(
id=None if not request.user_id else uuid.UUID(request.user_id),
id=None if not request.user_id else request.user_id,
# TODO can add more fields (name? metadata?)
)
@ -98,17 +98,18 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token)
return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token)
@router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(user_id):
@router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(
user_id: uuid.UUID = Query(..., description="The user_id key to be deleted."),
):
# TODO make a soft deletion, instead of a hard deletion
try:
user_id_uuid = uuid.UUID(user_id)
user = server.ms.get_user(user_id=user_id_uuid)
user = server.ms.get_user(user_id=user_id)
if user is None:
raise HTTPException(status_code=404, detail=f"User does not exist")
server.ms.delete_user(user_id=user_id_uuid)
server.ms.delete_user(user_id=user_id)
except HTTPException:
raise
except Exception as e:
@ -121,8 +122,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
Create a new API key for a user
"""
try:
user_id = uuid.UUID(request.user_id)
token = server.ms.create_api_key(user_id=user_id, name=request.name)
token = server.ms.create_api_key(user_id=request.user_id, name=request.name)
except HTTPException:
raise
except Exception as e:
@ -131,19 +131,20 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users/keys", tags=["admin"], response_model=GetAPIKeysResponse)
def get_api_keys(
user_id: str = Query(..., description="The unique identifier of the user."),
user_id: uuid.UUID = Query(..., description="The unique identifier of the user."),
):
"""
Get a list of all API keys for a user
"""
print("GET USERS", user_id)
try:
user_id = uuid.UUID(user_id)
tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
processed_tokens = [t.token for t in tokens]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
print("TOKENS", processed_tokens)
return GetAPIKeysResponse(api_key_list=processed_tokens)
@router.delete("/users/keys", tags=["admin"], response_model=DeleteAPIKeyResponse)

View File

@ -1,4 +1,7 @@
import json
import uvicorn
from typing import Optional
import logging
import os
import secrets
@ -23,6 +26,9 @@ from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.server import SyncServer
from memgpt.config import MemGPTConfig
from memgpt.server.constants import REST_DEFAULT_PORT
import subprocess
"""
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
@ -147,3 +153,83 @@ def on_shutdown():
global server
server.save_agents()
server = None
def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"):
"""Generate a self-signed SSL certificate.
NOTE: intended to be used for development only.
"""
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:4096",
"-keyout",
key_path,
"-out",
cert_path,
"-days",
"365",
"-nodes",
"-subj",
"/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost",
],
check=True,
)
return cert_path, key_path
def start_server(
port: Optional[int] = None,
host: Optional[str] = None,
use_ssl: bool = False,
ssl_cert: Optional[str] = None,
ssl_key: Optional[str] = None,
debug: bool = False,
):
print("DEBUG", debug)
if debug:
from memgpt.server.server import logger as server_logger
# Set the logging level
server_logger.setLevel(logging.DEBUG)
# Create a StreamHandler
stream_handler = logging.StreamHandler()
# Set the formatter (optional)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
# Add the handler to the logger
server_logger.addHandler(stream_handler)
if use_ssl:
if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
ssl_certfile, ssl_keyfile = generate_self_signed_cert()
print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
else:
ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
# This will start the server on HTTPS
assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
print(
f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
)
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
)
else:
# Start the subprocess in a new session
print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
uvicorn.run(
app,
host=host or "localhost",
port=port or REST_DEFAULT_PORT,
)

View File

@ -1,4 +1,5 @@
import json
import subprocess
import logging
import uuid
from abc import abstractmethod
@ -7,6 +8,7 @@ from threading import Lock
from typing import Union, Callable, Optional, List
from fastapi import HTTPException
import uvicorn
import memgpt.constants as constants
import memgpt.presets.presets as presets

158
tests/test_admin_client.py Normal file
View File

@ -0,0 +1,158 @@
import uuid
import os
import time
import threading
from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from dotenv import load_dotenv
from tests.config import TestMGPTConfig
from memgpt.server.rest_api.server import start_server
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig
from .utils import wipe_config, wipe_memgpt_home
import pytest
import uuid
test_base_url = "http://localhost:8283"
# admin credentials
test_server_token = "test_server_token"
def run_server():
import uvicorn
from memgpt.server.rest_api.server import app
load_dotenv()
# Use os.getenv with a fallback to os.environ.get
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
assert db_url, "Missing PGVECTOR_TEST_DB_URL"
if os.getenv("OPENAI_API_KEY"):
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
),
# llms
default_llm_config=LLMConfig(
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt-4",
),
)
credentials = MemGPTCredentials(
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
default_embedding_config=EmbeddingConfig(
embedding_endpoint_type="hugging-face",
embedding_endpoint="https://embeddings.memgpt.ai",
embedding_model="BAAI/bge-large-en-v1.5",
embedding_dim=1024,
),
# llms
default_llm_config=LLMConfig(
model_endpoint_type="vllm",
model_endpoint="https://api.memgpt.ai",
model="ehartford/dolphin-2.5-mixtral-8x7b",
),
)
credentials = MemGPTCredentials()
config.save()
credentials.save()
# start server
start_server(debug=True)
@pytest.fixture(scope="session", autouse=True)
def start_uvicorn_server():
"""Starts Uvicorn server in a background thread."""
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
print("Starting server...")
time.sleep(5)
yield
@pytest.fixture(scope="module")
def admin_client():
# Setup: Create a user via the client before the tests
admin = Admin(test_base_url, test_server_token)
admin._reset_server()
yield admin
def test_admin_client(admin_client):
# create a user
user_id = uuid.uuid4()
create_user1_response = admin_client.create_user(user_id)
assert user_id == create_user1_response.user_id, f"Expected {user_id}, got {create_user1_response.user_id}"
# create another user
create_user_2_response = admin_client.create_user()
# create keys
key1_name = "test_key1"
key2_name = "test_key2"
create_key1_response = admin_client.create_key(user_id, key1_name)
create_key2_response = admin_client.create_key(create_user_2_response.user_id, key2_name)
# list users
users = admin_client.get_users()
assert len(users.user_list) == 2
print(users.user_list)
assert user_id in [uuid.UUID(u["user_id"]) for u in users.user_list]
# list keys
user1_keys = admin_client.get_keys(user_id)
assert len(user1_keys.api_key_list) == 2, f"Expected 2 keys, got {user1_keys}"
assert create_key1_response.api_key in user1_keys.api_key_list, f"Expected {create_key1_response.api_key} in {user1_keys.api_key_list}"
assert (
create_user1_response.api_key in user1_keys.api_key_list
), f"Expected {create_user1_response.api_key} in {user1_keys.api_key_list}"
# delete key
delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
assert delete_key1_response.api_key_deleted == create_key1_response.api_key
assert len(admin_client.get_keys(user_id).api_key_list) == 1
delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
assert delete_key2_response.api_key_deleted == create_key2_response.api_key
assert len(admin_client.get_keys(create_user_2_response.user_id).api_key_list) == 1
# delete users
delete_user1_response = admin_client.delete_user(user_id)
assert delete_user1_response.user_id_deleted == user_id
delete_user2_response = admin_client.delete_user(create_user_2_response.user_id)
assert delete_user2_response.user_id_deleted == create_user_2_response.user_id
# list users
users = admin_client.get_users()
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"

View File

@ -4,6 +4,7 @@ import time
import threading
from dotenv import load_dotenv
from memgpt.server.rest_api.server import start_server
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from dotenv import load_dotenv
@ -94,7 +95,7 @@ def run_server():
config.save()
credentials.save()
uvicorn.run(app, host="localhost", port=8283, log_level="info")
start_server(debug=False)
@pytest.fixture(scope="session", autouse=True)
@ -113,8 +114,9 @@ def user_token():
# Setup: Create a user via the client before the tests
admin = Admin(test_base_url, test_server_token)
user_id, token = admin.create_user(test_user_id) # Adjust as per your client's method
print(user_id, token)
response = admin.create_user(test_user_id) # Adjust as per your client's method
user_id = response.user_id
token = response.api_key
yield token