mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: cursor pagination of get_all_users in /admin/users route (#1441)
This commit is contained in:
parent
b7e8a11399
commit
0ceea243a3
@ -28,8 +28,9 @@ class Admin:
|
||||
self.token = token
|
||||
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
|
||||
|
||||
def get_users(self):
|
||||
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers)
|
||||
def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
|
||||
payload = {"cursor": str(cursor) if cursor else None, "limit": limit}
|
||||
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return GetAllUsersResponse(**response.json())
|
||||
|
@ -17,6 +17,7 @@ from sqlalchemy import (
|
||||
String,
|
||||
TypeDecorator,
|
||||
create_engine,
|
||||
desc,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
@ -647,11 +648,19 @@ class MetadataStore:
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_all_users(self) -> List[User]:
|
||||
# TODO make paginated
|
||||
def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
|
||||
with self.session_maker() as session:
|
||||
results = session.query(UserModel).all()
|
||||
return [r.to_record() for r in results]
|
||||
query = session.query(UserModel).order_by(desc(UserModel.id))
|
||||
if cursor:
|
||||
query = query.filter(UserModel.id < cursor)
|
||||
results = query.limit(limit).all()
|
||||
if not results:
|
||||
return None, []
|
||||
user_records = [r.to_record() for r in results]
|
||||
next_cursor = user_records[-1].id
|
||||
assert isinstance(next_cursor, uuid.UUID)
|
||||
|
||||
return next_cursor, user_records
|
||||
|
||||
@enforce_types
|
||||
def get_source(
|
||||
|
@ -11,7 +11,13 @@ from memgpt.server.server import SyncServer
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class GetAllUsersRequest(BaseModel):
|
||||
cursor: Optional[uuid.UUID] = Field(None, description="Cursor to which to start the paginated request.")
|
||||
limit: Optional[int] = Field(50, description="Maximum number of users to retrieve per page.")
|
||||
|
||||
|
||||
class GetAllUsersResponse(BaseModel):
|
||||
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
|
||||
user_list: List[dict] = Field(..., description="A list of users.")
|
||||
|
||||
|
||||
@ -54,18 +60,18 @@ class DeleteUserResponse(BaseModel):
|
||||
|
||||
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
|
||||
def get_all_users():
|
||||
def get_all_users(request: GetAllUsersRequest = Body(...)):
|
||||
"""
|
||||
Get a list of all users in the database
|
||||
"""
|
||||
try:
|
||||
users = server.ms.get_all_users()
|
||||
next_cursor, users = server.ms.get_all_users(request.cursor, request.limit)
|
||||
processed_users = [{"user_id": user.id} for user in users]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return GetAllUsersResponse(user_list=processed_users)
|
||||
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)
|
||||
|
||||
@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
|
||||
def create_user(request: Optional[CreateUserRequest] = Body(None)):
|
||||
|
@ -79,3 +79,54 @@ def test_admin_client(admin_client):
|
||||
# list users
|
||||
users = admin_client.get_users()
|
||||
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
||||
|
||||
|
||||
def test_get_users_pagination(admin_client):
|
||||
_reset_config()
|
||||
|
||||
page_size = 5
|
||||
num_users = 7
|
||||
expected_users_remainder = num_users - page_size
|
||||
|
||||
# create users
|
||||
all_user_ids = []
|
||||
for i in range(num_users):
|
||||
|
||||
user_id = uuid.uuid4()
|
||||
all_user_ids.append(user_id)
|
||||
key_name = "test_key" + f"{i}"
|
||||
|
||||
create_user_response = admin_client.create_user(user_id)
|
||||
admin_client.create_key(create_user_response.user_id, key_name)
|
||||
|
||||
# list users in page 1
|
||||
get_all_users_response1 = admin_client.get_users(limit=page_size)
|
||||
cursor1 = get_all_users_response1.cursor
|
||||
user_list1 = get_all_users_response1.user_list
|
||||
assert len(user_list1) == page_size
|
||||
|
||||
# list users in page 2 using cursor
|
||||
get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
|
||||
cursor2 = get_all_users_response2.cursor
|
||||
user_list2 = get_all_users_response2.user_list
|
||||
|
||||
assert len(user_list2) == expected_users_remainder
|
||||
assert cursor1 != cursor2
|
||||
|
||||
# delete users
|
||||
clean_up_users_and_keys(all_user_ids)
|
||||
|
||||
# list users to check pagination with no users
|
||||
users = admin_client.get_users()
|
||||
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
||||
|
||||
|
||||
def clean_up_users_and_keys(user_id_list):
|
||||
admin_client = Admin(test_base_url, test_server_token)
|
||||
|
||||
# clean up all keys and users
|
||||
for user_id in user_id_list:
|
||||
keys_list = admin_client.get_keys(user_id)
|
||||
for key in keys_list:
|
||||
admin_client.delete_key(key)
|
||||
admin_client.delete_user(user_id)
|
||||
|
Loading…
Reference in New Issue
Block a user