mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: implement remaining Admin
routes in client and add tests (#1157)
This commit is contained in:
parent
d73c64f20e
commit
03e9f3ebbf
@ -277,33 +277,6 @@ def create_default_user_or_exit(config: MemGPTConfig, ms: MetadataStore):
|
||||
return user
|
||||
|
||||
|
||||
def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"):
|
||||
"""Generate a self-signed SSL certificate.
|
||||
|
||||
NOTE: intended to be used for development only.
|
||||
"""
|
||||
subprocess.run(
|
||||
[
|
||||
"openssl",
|
||||
"req",
|
||||
"-x509",
|
||||
"-newkey",
|
||||
"rsa:4096",
|
||||
"-keyout",
|
||||
key_path,
|
||||
"-out",
|
||||
cert_path,
|
||||
"-days",
|
||||
"365",
|
||||
"-nodes",
|
||||
"-subj",
|
||||
"/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
def server(
|
||||
type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
|
||||
port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
|
||||
@ -311,22 +284,22 @@ def server(
|
||||
use_ssl: Annotated[bool, typer.Option(help="Run the server using HTTPS?")] = False,
|
||||
ssl_cert: Annotated[Optional[str], typer.Option(help="Path to SSL certificate (if use_ssl is True)")] = None,
|
||||
ssl_key: Annotated[Optional[str], typer.Option(help="Path to SSL key file (if use_ssl is True)")] = None,
|
||||
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = True,
|
||||
debug: Annotated[bool, typer.Option(help="Turn debugging output on")] = False,
|
||||
):
|
||||
"""Launch a MemGPT server process"""
|
||||
|
||||
if debug:
|
||||
from memgpt.server.server import logger as server_logger
|
||||
# if debug:
|
||||
# from memgpt.server.server import logger as server_logger
|
||||
|
||||
# Set the logging level
|
||||
server_logger.setLevel(logging.DEBUG)
|
||||
# Create a StreamHandler
|
||||
stream_handler = logging.StreamHandler()
|
||||
# Set the formatter (optional)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
stream_handler.setFormatter(formatter)
|
||||
# Add the handler to the logger
|
||||
server_logger.addHandler(stream_handler)
|
||||
# # Set the logging level
|
||||
# server_logger.setLevel(logging.DEBUG)
|
||||
# # Create a StreamHandler
|
||||
# stream_handler = logging.StreamHandler()
|
||||
# # Set the formatter (optional)
|
||||
# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
# stream_handler.setFormatter(formatter)
|
||||
# # Add the handler to the logger
|
||||
# server_logger.addHandler(stream_handler)
|
||||
|
||||
if type == ServerChoice.rest_api:
|
||||
import uvicorn
|
||||
@ -341,35 +314,45 @@ def server(
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
if use_ssl:
|
||||
if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
|
||||
ssl_certfile, ssl_keyfile = generate_self_signed_cert()
|
||||
print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
|
||||
else:
|
||||
ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
|
||||
print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
|
||||
from memgpt.server.rest_api.server import start_server
|
||||
|
||||
# This will start the server on HTTPS
|
||||
assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
|
||||
assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
|
||||
print(
|
||||
f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host or "localhost",
|
||||
port=port or REST_DEFAULT_PORT,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_certfile=ssl_certfile,
|
||||
)
|
||||
else:
|
||||
# Start the subprocess in a new session
|
||||
print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host or "localhost",
|
||||
port=port or REST_DEFAULT_PORT,
|
||||
)
|
||||
start_server(
|
||||
port=port,
|
||||
host=host,
|
||||
use_ssl=use_ssl,
|
||||
ssl_cert=ssl_cert,
|
||||
ssl_key=ssl_key,
|
||||
debug=debug,
|
||||
)
|
||||
# if use_ssl:
|
||||
# if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
|
||||
# ssl_certfile, ssl_keyfile = generate_self_signed_cert()
|
||||
# print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
|
||||
# else:
|
||||
# ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
|
||||
# print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
|
||||
|
||||
# # This will start the server on HTTPS
|
||||
# assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
|
||||
# assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
|
||||
# print(
|
||||
# f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
|
||||
# )
|
||||
# uvicorn.run(
|
||||
# app,
|
||||
# host=host or "localhost",
|
||||
# port=port or REST_DEFAULT_PORT,
|
||||
# ssl_keyfile=ssl_keyfile,
|
||||
# ssl_certfile=ssl_certfile,
|
||||
# )
|
||||
# else:
|
||||
# # Start the subprocess in a new session
|
||||
# print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
|
||||
# uvicorn.run(
|
||||
# app,
|
||||
# host=host or "localhost",
|
||||
# port=port or REST_DEFAULT_PORT,
|
||||
# )
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# Handle CTRL-C
|
||||
@ -377,6 +360,19 @@ def server(
|
||||
sys.exit(0)
|
||||
|
||||
elif type == ServerChoice.ws_api:
|
||||
if debug:
|
||||
from memgpt.server.server import logger as server_logger
|
||||
|
||||
# Set the logging level
|
||||
server_logger.setLevel(logging.DEBUG)
|
||||
# Create a StreamHandler
|
||||
stream_handler = logging.StreamHandler()
|
||||
# Set the formatter (optional)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
stream_handler.setFormatter(formatter)
|
||||
# Add the handler to the logger
|
||||
server_logger.addHandler(stream_handler)
|
||||
|
||||
if port is None:
|
||||
port = WS_DEFAULT_PORT
|
||||
|
||||
|
@ -1,5 +1,16 @@
|
||||
from typing import Optional
|
||||
import uuid
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from memgpt.server.rest_api.admin.users import (
|
||||
CreateUserResponse,
|
||||
CreateAPIKeyResponse,
|
||||
GetAPIKeysResponse,
|
||||
DeleteAPIKeyResponse,
|
||||
DeleteUserResponse,
|
||||
GetAllUsersResponse,
|
||||
)
|
||||
|
||||
|
||||
class Admin:
|
||||
@ -14,13 +25,58 @@ class Admin:
|
||||
self.token = token
|
||||
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
|
||||
|
||||
def create_user(self, user_id: Optional[str] = None):
|
||||
def get_users(self):
|
||||
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return GetAllUsersResponse(**response.json())
|
||||
|
||||
def create_key(self, user_id: uuid.UUID, key_name: str):
|
||||
payload = {"user_id": str(user_id), "key_name": key_name}
|
||||
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
|
||||
print(response.json())
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return CreateAPIKeyResponse(**response.json())
|
||||
|
||||
def get_keys(self, user_id: uuid.UUID):
|
||||
params = {"user_id": str(user_id)}
|
||||
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
print(response.text, response.status_code)
|
||||
return GetAPIKeysResponse(**response.json())
|
||||
|
||||
def delete_key(self, api_key: str):
|
||||
params = {"api_key": api_key}
|
||||
response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return DeleteAPIKeyResponse(**response.json())
|
||||
|
||||
def create_user(self, user_id: Optional[uuid.UUID] = None):
|
||||
payload = {"user_id": str(user_id) if user_id else None}
|
||||
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
response_json = response.json()
|
||||
print(response_json)
|
||||
return response_json["user_id"], response_json["api_key"]
|
||||
return CreateUserResponse(**response_json)
|
||||
|
||||
def delete_user(self, user_id: str):
|
||||
response = requests.delete(f"{self.base_url}/admin/users/{user_id}", headers=self.headers)
|
||||
return response.json()
|
||||
def delete_user(self, user_id: uuid.UUID):
|
||||
params = {"user_id": str(user_id)}
|
||||
response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return DeleteUserResponse(**response.json())
|
||||
|
||||
def _reset_server(self):
|
||||
# DANGER: this will delete all users and keys
|
||||
# clear all state associated with users
|
||||
# TODO: clear out all agents, presets, etc.
|
||||
users = self.get_users().user_list
|
||||
for user in users:
|
||||
keys = self.get_keys(user["user_id"]).api_key_list
|
||||
for key in keys:
|
||||
self.delete_key(key)
|
||||
self.delete_user(user["user_id"])
|
||||
|
@ -140,7 +140,10 @@ class RESTClient(AbstractClient):
|
||||
}
|
||||
}
|
||||
response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create agent: {response.text}")
|
||||
response_json = response.json()
|
||||
print(response_json)
|
||||
llm_config = LLMConfig(**response_json["agent_state"]["llm_config"])
|
||||
embedding_config = EmbeddingConfig(**response_json["agent_state"]["embedding_config"])
|
||||
agent_state = AgentState(
|
||||
|
@ -16,17 +16,17 @@ class GetAllUsersResponse(BaseModel):
|
||||
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
|
||||
user_id: Optional[uuid.UUID] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
|
||||
api_key_name: Optional[str] = Field(None, description="Name for API key autogenerated on user creation (optional).")
|
||||
|
||||
|
||||
class CreateUserResponse(BaseModel):
|
||||
user_id: str = Field(..., description="Identifier of the user (UUID).")
|
||||
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
|
||||
api_key: str = Field(..., description="New API key generated for user.")
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
user_id: str = Field(..., description="Identifier of the user (UUID).")
|
||||
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
|
||||
name: Optional[str] = Field(None, description="Name for the API key (optional).")
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ class CreateAPIKeyResponse(BaseModel):
|
||||
|
||||
|
||||
class GetAPIKeysRequest(BaseModel):
|
||||
user_id: str = Field(..., description="Identifier of the user (UUID).")
|
||||
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
|
||||
|
||||
|
||||
class GetAPIKeysResponse(BaseModel):
|
||||
@ -49,7 +49,7 @@ class DeleteAPIKeyResponse(BaseModel):
|
||||
|
||||
class DeleteUserResponse(BaseModel):
|
||||
message: str
|
||||
user_id_deleted: str
|
||||
user_id_deleted: uuid.UUID
|
||||
|
||||
|
||||
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
@ -76,7 +76,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
request = CreateUserRequest()
|
||||
|
||||
new_user = User(
|
||||
id=None if not request.user_id else uuid.UUID(request.user_id),
|
||||
id=None if not request.user_id else request.user_id,
|
||||
# TODO can add more fields (name? metadata?)
|
||||
)
|
||||
|
||||
@ -98,17 +98,18 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return CreateUserResponse(user_id=str(new_user_ret.id), api_key=token.token)
|
||||
return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token)
|
||||
|
||||
@router.delete("/users/{user_id}", tags=["admin"], response_model=DeleteUserResponse)
|
||||
def delete_user(user_id):
|
||||
@router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
|
||||
def delete_user(
|
||||
user_id: uuid.UUID = Query(..., description="The user_id key to be deleted."),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
user_id_uuid = uuid.UUID(user_id)
|
||||
user = server.ms.get_user(user_id=user_id_uuid)
|
||||
user = server.ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
server.ms.delete_user(user_id=user_id_uuid)
|
||||
server.ms.delete_user(user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -121,8 +122,7 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
Create a new API key for a user
|
||||
"""
|
||||
try:
|
||||
user_id = uuid.UUID(request.user_id)
|
||||
token = server.ms.create_api_key(user_id=user_id, name=request.name)
|
||||
token = server.ms.create_api_key(user_id=request.user_id, name=request.name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -131,19 +131,20 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
|
||||
@router.get("/users/keys", tags=["admin"], response_model=GetAPIKeysResponse)
|
||||
def get_api_keys(
|
||||
user_id: str = Query(..., description="The unique identifier of the user."),
|
||||
user_id: uuid.UUID = Query(..., description="The unique identifier of the user."),
|
||||
):
|
||||
"""
|
||||
Get a list of all API keys for a user
|
||||
"""
|
||||
print("GET USERS", user_id)
|
||||
try:
|
||||
user_id = uuid.UUID(user_id)
|
||||
tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
|
||||
processed_tokens = [t.token for t in tokens]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
print("TOKENS", processed_tokens)
|
||||
return GetAPIKeysResponse(api_key_list=processed_tokens)
|
||||
|
||||
@router.delete("/users/keys", tags=["admin"], response_model=DeleteAPIKeyResponse)
|
||||
|
@ -1,4 +1,7 @@
|
||||
import json
|
||||
import uvicorn
|
||||
from typing import Optional
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
|
||||
@ -23,6 +26,9 @@ from memgpt.server.rest_api.static_files import mount_static_files
|
||||
from memgpt.server.rest_api.tools.index import setup_tools_index_router
|
||||
from memgpt.server.rest_api.sources.index import setup_sources_index_router
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.server.constants import REST_DEFAULT_PORT
|
||||
import subprocess
|
||||
|
||||
"""
|
||||
Basic REST API sitting on top of the internal MemGPT python server (SyncServer)
|
||||
@ -147,3 +153,83 @@ def on_shutdown():
|
||||
global server
|
||||
server.save_agents()
|
||||
server = None
|
||||
|
||||
|
||||
def generate_self_signed_cert(cert_path="selfsigned.crt", key_path="selfsigned.key"):
|
||||
"""Generate a self-signed SSL certificate.
|
||||
|
||||
NOTE: intended to be used for development only.
|
||||
"""
|
||||
subprocess.run(
|
||||
[
|
||||
"openssl",
|
||||
"req",
|
||||
"-x509",
|
||||
"-newkey",
|
||||
"rsa:4096",
|
||||
"-keyout",
|
||||
key_path,
|
||||
"-out",
|
||||
cert_path,
|
||||
"-days",
|
||||
"365",
|
||||
"-nodes",
|
||||
"-subj",
|
||||
"/C=US/ST=Denial/L=Springfield/O=Dis/CN=localhost",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
def start_server(
|
||||
port: Optional[int] = None,
|
||||
host: Optional[str] = None,
|
||||
use_ssl: bool = False,
|
||||
ssl_cert: Optional[str] = None,
|
||||
ssl_key: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
):
|
||||
print("DEBUG", debug)
|
||||
if debug:
|
||||
from memgpt.server.server import logger as server_logger
|
||||
|
||||
# Set the logging level
|
||||
server_logger.setLevel(logging.DEBUG)
|
||||
# Create a StreamHandler
|
||||
stream_handler = logging.StreamHandler()
|
||||
# Set the formatter (optional)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
stream_handler.setFormatter(formatter)
|
||||
# Add the handler to the logger
|
||||
server_logger.addHandler(stream_handler)
|
||||
|
||||
if use_ssl:
|
||||
if ssl_cert is None: # No certificate path provided, generate a self-signed certificate
|
||||
ssl_certfile, ssl_keyfile = generate_self_signed_cert()
|
||||
print(f"Running server with self-signed SSL cert: {ssl_certfile}, {ssl_keyfile}")
|
||||
else:
|
||||
ssl_certfile, ssl_keyfile = ssl_cert, ssl_key # Assuming cert includes both
|
||||
print(f"Running server with provided SSL cert: {ssl_certfile}, {ssl_keyfile}")
|
||||
|
||||
# This will start the server on HTTPS
|
||||
assert isinstance(ssl_certfile, str) and os.path.exists(ssl_certfile), ssl_certfile
|
||||
assert isinstance(ssl_keyfile, str) and os.path.exists(ssl_keyfile), ssl_keyfile
|
||||
print(
|
||||
f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT} --ssl-keyfile {ssl_keyfile} --ssl-certfile {ssl_certfile}"
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host or "localhost",
|
||||
port=port or REST_DEFAULT_PORT,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_certfile=ssl_certfile,
|
||||
)
|
||||
else:
|
||||
# Start the subprocess in a new session
|
||||
print(f"Running: uvicorn {app}:app --host {host or 'localhost'} --port {port or REST_DEFAULT_PORT}")
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host or "localhost",
|
||||
port=port or REST_DEFAULT_PORT,
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import subprocess
|
||||
import logging
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
@ -7,6 +8,7 @@ from threading import Lock
|
||||
from typing import Union, Callable, Optional, List
|
||||
|
||||
from fastapi import HTTPException
|
||||
import uvicorn
|
||||
|
||||
import memgpt.constants as constants
|
||||
import memgpt.presets.presets as presets
|
||||
|
158
tests/test_admin_client.py
Normal file
158
tests/test_admin_client.py
Normal file
@ -0,0 +1,158 @@
|
||||
import uuid
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from tests.config import TestMGPTConfig
|
||||
|
||||
from memgpt.server.rest_api.server import start_server
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig
|
||||
from .utils import wipe_config, wipe_memgpt_home
|
||||
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
test_base_url = "http://localhost:8283"
|
||||
|
||||
# admin credentials
|
||||
test_server_token = "test_server_token"
|
||||
|
||||
|
||||
def run_server():
|
||||
import uvicorn
|
||||
from memgpt.server.rest_api.server import app
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Use os.getenv with a fallback to os.environ.get
|
||||
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
|
||||
assert db_url, "Missing PGVECTOR_TEST_DB_URL"
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
default_embedding_config=EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
),
|
||||
# llms
|
||||
default_llm_config=LLMConfig(
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
),
|
||||
)
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
default_embedding_config=EmbeddingConfig(
|
||||
embedding_endpoint_type="hugging-face",
|
||||
embedding_endpoint="https://embeddings.memgpt.ai",
|
||||
embedding_model="BAAI/bge-large-en-v1.5",
|
||||
embedding_dim=1024,
|
||||
),
|
||||
# llms
|
||||
default_llm_config=LLMConfig(
|
||||
model_endpoint_type="vllm",
|
||||
model_endpoint="https://api.memgpt.ai",
|
||||
model="ehartford/dolphin-2.5-mixtral-8x7b",
|
||||
),
|
||||
)
|
||||
credentials = MemGPTCredentials()
|
||||
|
||||
config.save()
|
||||
credentials.save()
|
||||
|
||||
# start server
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def start_uvicorn_server():
|
||||
"""Starts Uvicorn server in a background thread."""
|
||||
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
print("Starting server...")
|
||||
time.sleep(5)
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def admin_client():
|
||||
# Setup: Create a user via the client before the tests
|
||||
|
||||
admin = Admin(test_base_url, test_server_token)
|
||||
admin._reset_server()
|
||||
yield admin
|
||||
|
||||
|
||||
def test_admin_client(admin_client):
|
||||
# create a user
|
||||
user_id = uuid.uuid4()
|
||||
create_user1_response = admin_client.create_user(user_id)
|
||||
assert user_id == create_user1_response.user_id, f"Expected {user_id}, got {create_user1_response.user_id}"
|
||||
|
||||
# create another user
|
||||
create_user_2_response = admin_client.create_user()
|
||||
|
||||
# create keys
|
||||
key1_name = "test_key1"
|
||||
key2_name = "test_key2"
|
||||
create_key1_response = admin_client.create_key(user_id, key1_name)
|
||||
create_key2_response = admin_client.create_key(create_user_2_response.user_id, key2_name)
|
||||
|
||||
# list users
|
||||
users = admin_client.get_users()
|
||||
assert len(users.user_list) == 2
|
||||
print(users.user_list)
|
||||
assert user_id in [uuid.UUID(u["user_id"]) for u in users.user_list]
|
||||
|
||||
# list keys
|
||||
user1_keys = admin_client.get_keys(user_id)
|
||||
assert len(user1_keys.api_key_list) == 2, f"Expected 2 keys, got {user1_keys}"
|
||||
assert create_key1_response.api_key in user1_keys.api_key_list, f"Expected {create_key1_response.api_key} in {user1_keys.api_key_list}"
|
||||
assert (
|
||||
create_user1_response.api_key in user1_keys.api_key_list
|
||||
), f"Expected {create_user1_response.api_key} in {user1_keys.api_key_list}"
|
||||
|
||||
# delete key
|
||||
delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
|
||||
assert delete_key1_response.api_key_deleted == create_key1_response.api_key
|
||||
assert len(admin_client.get_keys(user_id).api_key_list) == 1
|
||||
delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
|
||||
assert delete_key2_response.api_key_deleted == create_key2_response.api_key
|
||||
assert len(admin_client.get_keys(create_user_2_response.user_id).api_key_list) == 1
|
||||
|
||||
# delete users
|
||||
delete_user1_response = admin_client.delete_user(user_id)
|
||||
assert delete_user1_response.user_id_deleted == user_id
|
||||
delete_user2_response = admin_client.delete_user(create_user_2_response.user_id)
|
||||
assert delete_user2_response.user_id_deleted == create_user_2_response.user_id
|
||||
|
||||
# list users
|
||||
users = admin_client.get_users()
|
||||
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
@ -4,6 +4,7 @@ import time
|
||||
import threading
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from memgpt.server.rest_api.server import start_server
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from dotenv import load_dotenv
|
||||
@ -94,7 +95,7 @@ def run_server():
|
||||
config.save()
|
||||
credentials.save()
|
||||
|
||||
uvicorn.run(app, host="localhost", port=8283, log_level="info")
|
||||
start_server(debug=False)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@ -113,8 +114,9 @@ def user_token():
|
||||
# Setup: Create a user via the client before the tests
|
||||
|
||||
admin = Admin(test_base_url, test_server_token)
|
||||
user_id, token = admin.create_user(test_user_id) # Adjust as per your client's method
|
||||
print(user_id, token)
|
||||
response = admin.create_user(test_user_id) # Adjust as per your client's method
|
||||
user_id = response.user_id
|
||||
token = response.api_key
|
||||
|
||||
yield token
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user