diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 7a34a5b9e..76a07db24 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -277,33 +277,6 @@ def create_default_user_or_exit(config: MemGPTConfig, ms: MetadataStore): return user -def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"): - """Generate a self-signed SSL certificate. - - NOTE: intended to be used for development only. - """ - subprocess.run( - [ - "openssl", - "req", - "-x509", - "-newkey", - "rsa:4096", - "-keyout", - key_path, - "-out", - cert_path, - "-days", - "365", - "-nodes", - "-subj", - "/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost", - ], - check=True, - ) - return cert_path, key_path - - def server( type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest", port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None, @@ -311,22 +284,22 @@ def server( use_ssl: Annotated[bool, typer.Option(help="Run the server using HTTPS?")] = False, ssl_cert: Annotated[Optional[str], typer.Option(help="Path to SSL certificate (if use_ssl is True)")] = None, ssl_key: Annotated[Optional[str], typer.Option(help="Path to SSL key file (if use_ssl is True)")] = None, - debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = True, + debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False, ): """Launch a MemGPT server process""" - if debug: - from memgpt.server.server import logger as server_logger + # if debug: + # from memgpt.server.server import logger as server_logger - # Set the logging level - server_logger.setLevel(logging.DEBUG) - # Create a StreamHandler - stream_handler = logging.StreamHandler() - # Set the formatter (optional) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - stream_handler.setFormatter(formatter) - # Add the handler to the logger - server_logger.addHandler(stream_handler) + # # Set the logging level + # server_logger.setLevel(logging.DEBUG) + # # Create a StreamHandler + # stream_handler = logging.StreamHandler() + # # Set the formatter (optional) + # formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + # stream_handler.setFormatter(formatter) + # # Add the handler to the logger + # server_logger.addHandler(stream_handler) if type == ServerChoice.rest_api: import uvicorn @@ -341,35 +314,45 @@ def server( sys.exit(1) try: - if use_ssl: - if ssl_cert is None: # No certificate path provided, generate a self-signed certificate - ssl_certfile, ssl_keyfile = generate_self_signed_cert() - print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}") - else: - ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both - print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}") + from memgpt.server.rest_api.server import start_server - # This will start the server on HTTPS - assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile - assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile - print( - f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}" - ) - uvicorn.run( - app, - host=host or "localhost", - port=port or REST_DEFAULT_PORT, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ) - else: - # Start the subprocess in a new session - print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}") - uvicorn.run( - app, - host=host or "localhost", - port=port or REST_DEFAULT_PORT, - ) + start_server( + port=port, + host=host, + use_ssl=use_ssl, + ssl_cert=ssl_cert, + ssl_key=ssl_key, + debug=debug, + ) + # if use_ssl: + # if ssl_cert is None: # No certificate path provided, generate a self-signed certificate + # ssl_certfile, ssl_keyfile = generate_self_signed_cert() + # print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}") + # else: + # ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both + # print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}") + + # # This will start the server on HTTPS + # assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile + # assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile + # print( + # f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}" + # ) + # uvicorn.run( + # app, + # host=host or "localhost", + # port=port or REST_DEFAULT_PORT, + # ssl_keyfile=ssl_keyfile, + # ssl_certfile=ssl_certfile, + # ) + # else: + # # Start the subprocess in a new session + # print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}") + # uvicorn.run( + # app, + # host=host or "localhost", + # port=port or REST_DEFAULT_PORT, + # ) except KeyboardInterrupt: # Handle CTRL-C @@ -377,6 +360,19 @@ def server( sys.exit(0) elif type == ServerChoice.ws_api: + if debug: + from memgpt.server.server import logger as server_logger + + # Set the logging level + server_logger.setLevel(logging.DEBUG) + # Create a StreamHandler + stream_handler = logging.StreamHandler() + # Set the formatter (optional) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + stream_handler.setFormatter(formatter) + # Add the handler to the logger + server_logger.addHandler(stream_handler) + if port is None: port = WS_DEFAULT_PORT diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index a47c956b3..dd2eea095 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -1,5 +1,16 @@ from typing import Optional +import uuid import requests +from requests import HTTPError + +from memgpt.server.rest_api.admin.users import ( + CreateUserResponse, + CreateAPIKeyResponse, + GetAPIKeysResponse, + DeleteAPIKeyResponse, + DeleteUserResponse, + GetAllUsersResponse, +) class Admin: @@ -14,13 +25,58 @@ class Admin: self.token = token self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"} - def create_user(self, user_id: Optional[str] = None): + def get_users(self): + response = requests.get(f"{self.base_url}/admin/users", headers=self.headers) + if response.status_code != 200: + raise HTTPError(response.json()) + return GetAllUsersResponse(**response.json()) + + def create_key(self, user_id: uuid.UUID, key_name: str): + payload = {"user_id": str(user_id), "key_name": key_name} + response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload) + print(response.json()) + if response.status_code != 200: + raise HTTPError(response.json()) + return CreateAPIKeyResponse(**response.json()) + + def get_keys(self, user_id: uuid.UUID): + params = {"user_id": str(user_id)} + response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) + if response.status_code != 200: + raise HTTPError(response.json()) + print(response.text, response.status_code) + return GetAPIKeysResponse(**response.json()) + + def delete_key(self, api_key: str): + params = {"api_key": api_key} + response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) + if response.status_code != 200: + raise HTTPError(response.json()) + return DeleteAPIKeyResponse(**response.json()) + + def create_user(self, user_id: Optional[uuid.UUID] = None): payload = {"user_id": str(user_id) if user_id else None} response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload) + if response.status_code != 200: + raise HTTPError(response.json()) response_json = response.json() print(response_json) - return response_json["user_id"], response_json["api_key"] + return CreateUserResponse(**response_json) - def delete_user(self, user_id: str): - response = requests.delete(f"{self.base_url}/admin/users/{user_id}", headers=self.headers) - return response.json() + def delete_user(self, user_id: uuid.UUID): + params = {"user_id": str(user_id)} + response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers) + if response.status_code != 200: + raise HTTPError(response.json()) + return DeleteUserResponse(**response.json()) + + def _reset_server(self): + # DANGER: this will delete all users and keys + # clear all state associated with users + # TODO: clear out all agents, presets, etc. + users = self.get_users().user_list + for user in users: + keys = self.get_keys(user["user_id"]).api_key_list + for key in keys: + self.delete_key(key) + self.delete_user(user["user_id"]) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index dbccf0ace..fced3fd79 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -140,7 +140,10 @@ class RESTClient(AbstractClient): } } response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to create agent: {response.text}") response_json = response.json() + print(response_json) llm_config = LLMConfig(**response_json["agent_state"]["llm_config"]) embedding_config = EmbeddingConfig(**response_json["agent_state"]["embedding_config"]) agent_state = AgentState( diff --git a/memgpt/server/rest_api/admin/users.py b/memgpt/server/rest_api/admin/users.py index 0b7468efc..0f6f575ca 100644 --- a/memgpt/server/rest_api/admin/users.py +++ b/memgpt/server/rest_api/admin/users.py @@ -16,17 +16,17 @@ class GetAllUsersResponse(BaseModel): class CreateUserRequest(BaseModel): - user_id: Optional[str] = Field(None, description="Identifier of the user (optional, generated automatically if null).") + user_id: Optional[uuid.UUID] = Field(None, description="Identifier of the user (optional, generated automatically if null).") api_key_name: Optional[str] = Field(None, description="Name for API key autogenerated on user creation (optional).") class CreateUserResponse(BaseModel): - user_id: str = Field(..., description="Identifier of the user (UUID).") + user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).") api_key: str = Field(..., description="New API key generated for user.") class CreateAPIKeyRequest(BaseModel): - user_id: str = Field(..., description="Identifier of the user (UUID).") + user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).") name: Optional[str] = Field(None, description="Name for the API key (optional).") @@ -35,7 +35,7 @@ class CreateAPIKeyResponse(BaseModel): class GetAPIKeysRequest(BaseModel): - user_id: str = Field(..., description="Identifier of the user (UUID).") + user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).") class GetAPIKeysResponse(BaseModel): @@ -49,7 +49,7 @@ class DeleteAPIKeyResponse(BaseModel): class DeleteUserResponse(BaseModel): message: str - user_id_deleted: str + user_id_deleted: uuid.UUID def setup_admin_router(server: SyncServer, interface: QueuingInterface): @@ -76,7 +76,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): request = CreateUserRequest() new_user = User( - id=None if not request.user_id else uuid.UUID(request.user_id), + id=None if not request.user_id else request.user_id, # TODO can add more fields (name? metadata?) ) @@ -98,17 +98,18 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") - return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token) + return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token) - @router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse) - def delete_user(user_id): + @router.delete("/users", tags=["admin"], response_model=DeleteUserResponse) + def delete_user( + user_id: uuid.UUID = Query(..., description="The user_id key to be deleted."), + ): # TODO make a soft deletion, instead of a hard deletion try: - user_id_uuid = uuid.UUID(user_id) - user = server.ms.get_user(user_id=user_id_uuid) + user = server.ms.get_user(user_id=user_id) if user is None: raise HTTPException(status_code=404, detail=f"User does not exist") - server.ms.delete_user(user_id=user_id_uuid) + server.ms.delete_user(user_id=user_id) except HTTPException: raise except Exception as e: @@ -121,8 +122,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): Create a new API key for a user """ try: - user_id = uuid.UUID(request.user_id) - token = server.ms.create_api_key(user_id=user_id, name=request.name) + token = server.ms.create_api_key(user_id=request.user_id, name=request.name) except HTTPException: raise except Exception as e: @@ -131,19 +131,20 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface): @router.get("/users/keys", tags=["admin"], response_model=GetAPIKeysResponse) def get_api_keys( - user_id: str = Query(..., description="The unique identifier of the user."), + user_id: uuid.UUID = Query(..., description="The unique identifier of the user."), ): """ Get a list of all API keys for a user """ + print("GET USERS", user_id) try: - user_id = uuid.UUID(user_id) tokens = server.ms.get_all_api_keys_for_user(user_id=user_id) processed_tokens = [t.token for t in tokens] except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") + print("TOKENS", processed_tokens) return GetAPIKeysResponse(api_key_list=processed_tokens) @router.delete("/users/keys", tags=["admin"], response_model=DeleteAPIKeyResponse) diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index 6a52ea11c..88d80e79a 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -1,4 +1,7 @@ import json +import uvicorn +from typing import Optional +import logging import os import secrets @@ -23,6 +26,9 @@ from memgpt.server.rest_api.static_files import mount_static_files from memgpt.server.rest_api.tools.index import setup_tools_index_router from memgpt.server.rest_api.sources.index import setup_sources_index_router from memgpt.server.server import SyncServer +from memgpt.config import MemGPTConfig +from memgpt.server.constants import REST_DEFAULT_PORT +import subprocess """ Basic REST API sitting on top of the internal MemGPT python server (SyncServer) @@ -147,3 +153,83 @@ def on_shutdown(): global server server.save_agents() server = None + + +def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"): + """Generate a self-signed SSL certificate. + + NOTE: intended to be used for development only. + """ + subprocess.run( + [ + "openssl", + "req", + "-x509", + "-newkey", + "rsa:4096", + "-keyout", + key_path, + "-out", + cert_path, + "-days", + "365", + "-nodes", + "-subj", + "/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost", + ], + check=True, + ) + return cert_path, key_path + + +def start_server( + port: Optional[int] = None, + host: Optional[str] = None, + use_ssl: bool = False, + ssl_cert: Optional[str] = None, + ssl_key: Optional[str] = None, + debug: bool = False, +): + print("DEBUG", debug) + if debug: + from memgpt.server.server import logger as server_logger + + # Set the logging level + server_logger.setLevel(logging.DEBUG) + # Create a StreamHandler + stream_handler = logging.StreamHandler() + # Set the formatter (optional) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + stream_handler.setFormatter(formatter) + # Add the handler to the logger + server_logger.addHandler(stream_handler) + + if use_ssl: + if ssl_cert is None: # No certificate path provided, generate a self-signed certificate + ssl_certfile, ssl_keyfile = generate_self_signed_cert() + print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}") + else: + ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both + print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}") + + # This will start the server on HTTPS + assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile + assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile + print( + f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}" + ) + uvicorn.run( + app, + host=host or "localhost", + port=port or REST_DEFAULT_PORT, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ) + else: + # Start the subprocess in a new session + print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}") + uvicorn.run( + app, + host=host or "localhost", + port=port or REST_DEFAULT_PORT, + ) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index c98855600..42554f458 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,4 +1,5 @@ import json +import subprocess import logging import uuid from abc import abstractmethod @@ -7,6 +8,7 @@ from threading import Lock from typing import Union, Callable, Optional, List from fastapi import HTTPException +import uvicorn import memgpt.constants as constants import memgpt.presets.presets as presets diff --git a/tests/test_admin_client.py b/tests/test_admin_client.py new file mode 100644 index 000000000..ab1ac3a53 --- /dev/null +++ b/tests/test_admin_client.py @@ -0,0 +1,158 @@ +import uuid +import os +import time +import threading +from dotenv import load_dotenv + +from memgpt import Admin, create_client +from memgpt.constants import DEFAULT_PRESET +from dotenv import load_dotenv + +from tests.config import TestMGPTConfig + +from memgpt.server.rest_api.server import start_server +from memgpt.credentials import MemGPTCredentials +from memgpt.data_types import EmbeddingConfig, LLMConfig +from .utils import wipe_config, wipe_memgpt_home + + +import pytest +import uuid + +test_base_url = "http://localhost:8283" + +# admin credentials +test_server_token = "test_server_token" + + +def run_server(): + import uvicorn + from memgpt.server.rest_api.server import app + + load_dotenv() + + # Use os.getenv with a fallback to os.environ.get + db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL") + assert db_url, "Missing PGVECTOR_TEST_DB_URL" + + if os.getenv("OPENAI_API_KEY"): + config = TestMGPTConfig( + archival_storage_uri=db_url, + recall_storage_uri=db_url, + metadata_storage_uri=db_url, + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + ), + # llms + default_llm_config=LLMConfig( + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ), + ) + credentials = MemGPTCredentials( + openai_key=os.getenv("OPENAI_API_KEY"), + ) + else: # hosted + config = TestMGPTConfig( + archival_storage_uri=db_url, + recall_storage_uri=db_url, + metadata_storage_uri=db_url, + archival_storage_type="postgres", + recall_storage_type="postgres", + metadata_storage_type="postgres", + # embeddings + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_dim=1024, + ), + # llms + default_llm_config=LLMConfig( + model_endpoint_type="vllm", + model_endpoint="https://api.memgpt.ai", + model="ehartford/dolphin-2.5-mixtral-8x7b", + ), + ) + credentials = MemGPTCredentials() + + config.save() + credentials.save() + + # start server + start_server(debug=True) + + +@pytest.fixture(scope="session", autouse=True) +def start_uvicorn_server(): + """Starts Uvicorn server in a background thread.""" + + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + print("Starting server...") + time.sleep(5) + yield + + +@pytest.fixture(scope="module") +def admin_client(): + # Setup: Create a user via the client before the tests + + admin = Admin(test_base_url, test_server_token) + admin._reset_server() + yield admin + + +def test_admin_client(admin_client): + # create a user + user_id = uuid.uuid4() + create_user1_response = admin_client.create_user(user_id) + assert user_id == create_user1_response.user_id, f"Expected {user_id}, got {create_user1_response.user_id}" + + # create another user + create_user_2_response = admin_client.create_user() + + # create keys + key1_name = "test_key1" + key2_name = "test_key2" + create_key1_response = admin_client.create_key(user_id, key1_name) + create_key2_response = admin_client.create_key(create_user_2_response.user_id, key2_name) + + # list users + users = admin_client.get_users() + assert len(users.user_list) == 2 + print(users.user_list) + assert user_id in [uuid.UUID(u["user_id"]) for u in users.user_list] + + # list keys + user1_keys = admin_client.get_keys(user_id) + assert len(user1_keys.api_key_list) == 2, f"Expected 2 keys, got {user1_keys}" + assert create_key1_response.api_key in user1_keys.api_key_list, f"Expected {create_key1_response.api_key} in {user1_keys.api_key_list}" + assert ( + create_user1_response.api_key in user1_keys.api_key_list + ), f"Expected {create_user1_response.api_key} in {user1_keys.api_key_list}" + + # delete key + delete_key1_response = admin_client.delete_key(create_key1_response.api_key) + assert delete_key1_response.api_key_deleted == create_key1_response.api_key + assert len(admin_client.get_keys(user_id).api_key_list) == 1 + delete_key2_response = admin_client.delete_key(create_key2_response.api_key) + assert delete_key2_response.api_key_deleted == create_key2_response.api_key + assert len(admin_client.get_keys(create_user_2_response.user_id).api_key_list) == 1 + + # delete users + delete_user1_response = admin_client.delete_user(user_id) + assert delete_user1_response.user_id_deleted == user_id + delete_user2_response = admin_client.delete_user(create_user_2_response.user_id) + assert delete_user2_response.user_id_deleted == create_user_2_response.user_id + + # list users + users = admin_client.get_users() + assert len(users.user_list) == 0, f"Expected 0 users, got {users}" diff --git a/tests/test_client.py b/tests/test_client.py index d54da509c..0b9e22691 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,6 +4,7 @@ import time import threading from dotenv import load_dotenv +from memgpt.server.rest_api.server import start_server from memgpt import Admin, create_client from memgpt.constants import DEFAULT_PRESET from dotenv import load_dotenv @@ -94,7 +95,7 @@ def run_server(): config.save() credentials.save() - uvicorn.run(app, host="localhost", port=8283, log_level="info") + start_server(debug=False) @pytest.fixture(scope="session", autouse=True) @@ -113,8 +114,9 @@ def user_token(): # Setup: Create a user via the client before the tests admin = Admin(test_base_url, test_server_token) - user_id, token = admin.create_user(test_user_id) # Adjust as per your client's method - print(user_id, token) + response = admin.create_user(test_user_id) # Adjust as per your client's method + user_id = response.user_id + token = response.api_key yield token