feat: cursor pagination of get_all_users in /admin/users route (#1441)

This commit is contained in:
Eren-Ajani Tshimanga 2024-06-09 17:24:11 -07:00 committed by GitHub
parent b7e8a11399
commit 0ceea243a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 76 additions and 9 deletions

View File

@ -28,8 +28,9 @@ class Admin:
self.token = token
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
def get_users(self):
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers)
def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
payload = {"cursor": str(cursor) if cursor else None, "limit": limit}
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
if response.status_code != 200:
raise HTTPError(response.json())
return GetAllUsersResponse(**response.json())

View File

@ -17,6 +17,7 @@ from sqlalchemy import (
String,
TypeDecorator,
create_engine,
desc,
func,
)
from sqlalchemy.dialects.postgresql import UUID
@ -647,11 +648,19 @@ class MetadataStore:
return results[0].to_record()
@enforce_types
def get_all_users(self) -> List[User]:
# TODO make paginated
def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
with self.session_maker() as session:
results = session.query(UserModel).all()
return [r.to_record() for r in results]
query = session.query(UserModel).order_by(desc(UserModel.id))
if cursor:
query = query.filter(UserModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
user_records = [r.to_record() for r in results]
next_cursor = user_records[-1].id
assert isinstance(next_cursor, uuid.UUID)
return next_cursor, user_records
@enforce_types
def get_source(

View File

@ -11,7 +11,13 @@ from memgpt.server.server import SyncServer
router = APIRouter()
class GetAllUsersRequest(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor to which to start the paginated request.")
limit: Optional[int] = Field(50, description="Maximum number of users to retrieve per page.")
class GetAllUsersResponse(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
user_list: List[dict] = Field(..., description="A list of users.")
@ -54,18 +60,18 @@ class DeleteUserResponse(BaseModel):
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
def get_all_users():
def get_all_users(request: GetAllUsersRequest = Body(...)):
"""
Get a list of all users in the database
"""
try:
users = server.ms.get_all_users()
next_cursor, users = server.ms.get_all_users(request.cursor, request.limit)
processed_users = [{"user_id": user.id} for user in users]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return GetAllUsersResponse(user_list=processed_users)
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)
@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
def create_user(request: Optional[CreateUserRequest] = Body(None)):

View File

@ -79,3 +79,54 @@ def test_admin_client(admin_client):
# list users
users = admin_client.get_users()
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
def test_get_users_pagination(admin_client):
_reset_config()
page_size = 5
num_users = 7
expected_users_remainder = num_users - page_size
# create users
all_user_ids = []
for i in range(num_users):
user_id = uuid.uuid4()
all_user_ids.append(user_id)
key_name = "test_key" + f"{i}"
create_user_response = admin_client.create_user(user_id)
admin_client.create_key(create_user_response.user_id, key_name)
# list users in page 1
get_all_users_response1 = admin_client.get_users(limit=page_size)
cursor1 = get_all_users_response1.cursor
user_list1 = get_all_users_response1.user_list
assert len(user_list1) == page_size
# list users in page 2 using cursor
get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
cursor2 = get_all_users_response2.cursor
user_list2 = get_all_users_response2.user_list
assert len(user_list2) == expected_users_remainder
assert cursor1 != cursor2
# delete users
clean_up_users_and_keys(all_user_ids)
# list users to check pagination with no users
users = admin_client.get_users()
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
def clean_up_users_and_keys(user_id_list):
admin_client = Admin(test_base_url, test_server_token)
# clean up all keys and users
for user_id in user_id_list:
keys_list = admin_client.get_keys(user_id)
for key in keys_list:
admin_client.delete_key(key)
admin_client.delete_user(user_id)