mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: cursor pagination of get_all_users in /admin/users route (#1441)
This commit is contained in:
parent
b7e8a11399
commit
0ceea243a3
@ -28,8 +28,9 @@ class Admin:
|
|||||||
self.token = token
|
self.token = token
|
||||||
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
|
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
def get_users(self):
|
def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
|
||||||
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers)
|
payload = {"cursor": str(cursor) if cursor else None, "limit": limit}
|
||||||
|
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise HTTPError(response.json())
|
raise HTTPError(response.json())
|
||||||
return GetAllUsersResponse(**response.json())
|
return GetAllUsersResponse(**response.json())
|
||||||
|
@ -17,6 +17,7 @@ from sqlalchemy import (
|
|||||||
String,
|
String,
|
||||||
TypeDecorator,
|
TypeDecorator,
|
||||||
create_engine,
|
create_engine,
|
||||||
|
desc,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
@ -647,11 +648,19 @@ class MetadataStore:
|
|||||||
return results[0].to_record()
|
return results[0].to_record()
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def get_all_users(self) -> List[User]:
|
def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
|
||||||
# TODO make paginated
|
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
results = session.query(UserModel).all()
|
query = session.query(UserModel).order_by(desc(UserModel.id))
|
||||||
return [r.to_record() for r in results]
|
if cursor:
|
||||||
|
query = query.filter(UserModel.id < cursor)
|
||||||
|
results = query.limit(limit).all()
|
||||||
|
if not results:
|
||||||
|
return None, []
|
||||||
|
user_records = [r.to_record() for r in results]
|
||||||
|
next_cursor = user_records[-1].id
|
||||||
|
assert isinstance(next_cursor, uuid.UUID)
|
||||||
|
|
||||||
|
return next_cursor, user_records
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def get_source(
|
def get_source(
|
||||||
|
@ -11,7 +11,13 @@ from memgpt.server.server import SyncServer
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class GetAllUsersRequest(BaseModel):
|
||||||
|
cursor: Optional[uuid.UUID] = Field(None, description="Cursor to which to start the paginated request.")
|
||||||
|
limit: Optional[int] = Field(50, description="Maximum number of users to retrieve per page.")
|
||||||
|
|
||||||
|
|
||||||
class GetAllUsersResponse(BaseModel):
|
class GetAllUsersResponse(BaseModel):
|
||||||
|
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
|
||||||
user_list: List[dict] = Field(..., description="A list of users.")
|
user_list: List[dict] = Field(..., description="A list of users.")
|
||||||
|
|
||||||
|
|
||||||
@ -54,18 +60,18 @@ class DeleteUserResponse(BaseModel):
|
|||||||
|
|
||||||
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||||
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
|
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
|
||||||
def get_all_users():
|
def get_all_users(request: GetAllUsersRequest = Body(...)):
|
||||||
"""
|
"""
|
||||||
Get a list of all users in the database
|
Get a list of all users in the database
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
users = server.ms.get_all_users()
|
next_cursor, users = server.ms.get_all_users(request.cursor, request.limit)
|
||||||
processed_users = [{"user_id": user.id} for user in users]
|
processed_users = [{"user_id": user.id} for user in users]
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"{e}")
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
return GetAllUsersResponse(user_list=processed_users)
|
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)
|
||||||
|
|
||||||
@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
|
@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
|
||||||
def create_user(request: Optional[CreateUserRequest] = Body(None)):
|
def create_user(request: Optional[CreateUserRequest] = Body(None)):
|
||||||
|
@ -79,3 +79,54 @@ def test_admin_client(admin_client):
|
|||||||
# list users
|
# list users
|
||||||
users = admin_client.get_users()
|
users = admin_client.get_users()
|
||||||
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_users_pagination(admin_client):
|
||||||
|
_reset_config()
|
||||||
|
|
||||||
|
page_size = 5
|
||||||
|
num_users = 7
|
||||||
|
expected_users_remainder = num_users - page_size
|
||||||
|
|
||||||
|
# create users
|
||||||
|
all_user_ids = []
|
||||||
|
for i in range(num_users):
|
||||||
|
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
all_user_ids.append(user_id)
|
||||||
|
key_name = "test_key" + f"{i}"
|
||||||
|
|
||||||
|
create_user_response = admin_client.create_user(user_id)
|
||||||
|
admin_client.create_key(create_user_response.user_id, key_name)
|
||||||
|
|
||||||
|
# list users in page 1
|
||||||
|
get_all_users_response1 = admin_client.get_users(limit=page_size)
|
||||||
|
cursor1 = get_all_users_response1.cursor
|
||||||
|
user_list1 = get_all_users_response1.user_list
|
||||||
|
assert len(user_list1) == page_size
|
||||||
|
|
||||||
|
# list users in page 2 using cursor
|
||||||
|
get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
|
||||||
|
cursor2 = get_all_users_response2.cursor
|
||||||
|
user_list2 = get_all_users_response2.user_list
|
||||||
|
|
||||||
|
assert len(user_list2) == expected_users_remainder
|
||||||
|
assert cursor1 != cursor2
|
||||||
|
|
||||||
|
# delete users
|
||||||
|
clean_up_users_and_keys(all_user_ids)
|
||||||
|
|
||||||
|
# list users to check pagination with no users
|
||||||
|
users = admin_client.get_users()
|
||||||
|
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
|
||||||
|
|
||||||
|
|
||||||
|
def clean_up_users_and_keys(user_id_list):
|
||||||
|
admin_client = Admin(test_base_url, test_server_token)
|
||||||
|
|
||||||
|
# clean up all keys and users
|
||||||
|
for user_id in user_id_list:
|
||||||
|
keys_list = admin_client.get_keys(user_id)
|
||||||
|
for key in keys_list:
|
||||||
|
admin_client.delete_key(key)
|
||||||
|
admin_client.delete_user(user_id)
|
||||||
|
Loading…
Reference in New Issue
Block a user