MemGPT/tests/sdk/conftest.py
Kian Jones 6e25591eee feat(ci): Adds Integration Testing Against cloud-api (#1909)
Co-authored-by: Shubham Naik <shub@letta.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
2025-05-08 11:28:41 -07:00

307 lines
11 KiB
Python

import os
import threading
import time
from typing import Any, Dict, List, Optional, Tuple
import pytest
from dotenv import load_dotenv
from letta_client import Letta
def run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
# This fixture starts the server once for the entire test session
@pytest.fixture(scope="session")
def server():
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
return server_url
# This fixture creates a client for each test module
@pytest.fixture(scope="session")
def client(server):
# Use the server URL from the server fixture
server_url = server
print("Running client tests with server:", server_url)
# Overide the base_url if the LETTA_API_URL is set
api_url = os.getenv("LETTA_API_URL")
base_url = api_url if api_url else server_url
# create the Letta client
yield Letta(base_url=base_url, token=None)
def skip_test_if_not_implemented(handler, resource_name, test_name):
if not hasattr(handler, test_name):
pytest.skip(f"client.{resource_name}.{test_name} not implemented")
def create_test_module(
resource_name: str,
id_param_name: str,
create_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
upsert_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
modify_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
list_params: List[Tuple[Dict[str, Any], int]] = [],
) -> Dict[str, Any]:
"""Create a test module for a resource.
This function creates all the necessary test methods and returns them in a dictionary
that can be added to the globals() of the module.
Args:
resource_name: Name of the resource (e.g., "blocks", "tools")
id_param_name: Name of the ID parameter (e.g., "block_id", "tool_id")
create_params: List of (name, params, expected_error) tuples for create tests
modify_params: List of (name, params, expected_error) tuples for modify tests
list_params: List of (query_params, expected_count) tuples for list tests
Returns:
Dict: A dictionary of all test functions that should be added to the module globals
"""
# Create shared test state
test_item_ids = {}
# Create fixture functions
@pytest.fixture(scope="session")
def handler(client):
return getattr(client, resource_name)
@pytest.fixture(scope="session")
def caren_agent(client, request):
"""Create an agent to be used as manager in supervisor groups."""
agent = client.agents.create(
name="caren_agent",
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
)
# Add finalizer to ensure cleanup happens in the right order
request.addfinalizer(lambda: client.agents.delete(agent_id=agent.id))
return agent
# Create standalone test functions
@pytest.mark.order(0)
def test_create(handler, caren_agent, name, params, extra_expected_values, expected_error):
"""Test creating a resource."""
skip_test_if_not_implemented(handler, resource_name, "create")
# Use preprocess_params which adds fixtures
processed_params = preprocess_params(params, caren_agent)
processed_extra_expected = preprocess_params(extra_expected_values, caren_agent)
try:
item = handler.create(**processed_params)
except Exception as e:
if expected_error is not None:
if hasattr(e, "status_code"):
assert e.status_code == expected_error
elif hasattr(e, "status"):
assert e.status == expected_error
else:
pytest.fail(f"Expected error with status {expected_error}, but got {type(e)}: {e}")
else:
raise e
# Store item ID for later tests
test_item_ids[name] = item.id
# Verify item properties
expected_values = processed_params | processed_extra_expected
for key, value in expected_values.items():
if hasattr(item, key):
assert custom_model_dump(getattr(item, key)) == value
@pytest.mark.order(1)
def test_retrieve(handler):
"""Test retrieving resources."""
skip_test_if_not_implemented(handler, resource_name, "retrieve")
for name, item_id in test_item_ids.items():
kwargs = {id_param_name: item_id}
item = handler.retrieve(**kwargs)
assert hasattr(item, "id") and item.id == item_id, f"{resource_name.capitalize()} {name} with id {item_id} not found"
@pytest.mark.order(2)
def test_upsert(handler, name, params, extra_expected_values, expected_error):
"""Test upserting resources."""
skip_test_if_not_implemented(handler, resource_name, "upsert")
existing_item_id = test_item_ids[name]
try:
item = handler.upsert(**params)
except Exception as e:
if expected_error is not None:
if hasattr(e, "status_code"):
assert e.status_code == expected_error
elif hasattr(e, "status"):
assert e.status == expected_error
else:
pytest.fail(f"Expected error with status {expected_error}, but got {type(e)}: {e}")
else:
raise e
assert existing_item_id == item.id
# Verify item properties
expected_values = params | extra_expected_values
for key, value in expected_values.items():
if hasattr(item, key):
assert custom_model_dump(getattr(item, key)) == value
@pytest.mark.order(3)
def test_modify(handler, caren_agent, name, params, extra_expected_values, expected_error):
"""Test modifying a resource."""
skip_test_if_not_implemented(handler, resource_name, "modify")
if name not in test_item_ids:
pytest.skip(f"Item '{name}' not found in test_items")
kwargs = {id_param_name: test_item_ids[name]}
kwargs.update(params)
processed_params = preprocess_params(kwargs, caren_agent)
processed_extra_expected = preprocess_params(extra_expected_values, caren_agent)
try:
item = handler.modify(**processed_params)
except Exception as e:
if expected_error is not None:
assert isinstance(e, expected_error), f"Expected error with type {expected_error}, but got {type(e)}: {e}"
return
else:
raise e
# Verify item properties
expected_values = processed_params | processed_extra_expected
for key, value in expected_values.items():
if hasattr(item, key):
assert custom_model_dump(getattr(item, key)) == value
# Verify via retrieve as well
retrieve_kwargs = {id_param_name: item.id}
retrieved_item = handler.retrieve(**retrieve_kwargs)
expected_values = processed_params | processed_extra_expected
for key, value in expected_values.items():
if hasattr(retrieved_item, key):
assert custom_model_dump(getattr(retrieved_item, key)) == value
@pytest.mark.order(4)
def test_list(handler, query_params, count):
"""Test listing resources."""
skip_test_if_not_implemented(handler, resource_name, "list")
all_items = handler.list(**query_params)
test_items_list = [item.id for item in all_items if item.id in test_item_ids.values()]
assert len(test_items_list) == count
@pytest.mark.order(-1)
def test_delete(handler):
"""Test deleting resources."""
skip_test_if_not_implemented(handler, resource_name, "delete")
for item_id in test_item_ids.values():
kwargs = {id_param_name: item_id}
handler.delete(**kwargs)
for name, item_id in test_item_ids.items():
try:
kwargs = {id_param_name: item_id}
item = handler.retrieve(**kwargs)
raise AssertionError(f"{resource_name.capitalize()} {name} with id {item.id} was not deleted")
except Exception as e:
if isinstance(e, AssertionError):
raise e
if hasattr(e, "status_code"):
assert e.status_code == 404, f"Expected 404 error, got {e.status_code}"
else:
raise AssertionError(f"Unexpected error type: {type(e)}")
test_item_ids.clear()
# Create test methods dictionary
result = {
"handler": handler,
"caren_agent": caren_agent,
"test_create": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", create_params)(test_create),
"test_retrieve": test_retrieve,
"test_upsert": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", upsert_params)(test_upsert),
"test_modify": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", modify_params)(test_modify),
"test_delete": test_delete,
"test_list": pytest.mark.parametrize("query_params, count", list_params)(test_list),
}
return result
def custom_model_dump(model):
"""
Dumps the given model to a form that can be easily compared.
Args:
model: The model to dump
Returns:
The dumped model
"""
if isinstance(model, (str, int, float, bool, type(None))):
return model
if isinstance(model, list):
return [custom_model_dump(item) for item in model]
else:
return model.model_dump()
def add_fixture_params(value, caren_agent):
"""
Replaces string values containing '.id' with their mapped values.
Args:
value: The value to process (should be a string)
caren_agent: The agent object to use for ID replacement
Returns:
The processed value with ID strings replaced by actual values
"""
param_to_fixture_mapping = {
"caren_agent.id": caren_agent.id,
}
return param_to_fixture_mapping.get(value, value)
def preprocess_params(params, caren_agent):
"""
Recursively processes a nested structure of dictionaries and lists,
replacing string values containing '.id' with their mapped values.
Args:
params: The parameters to process (dict, list, or scalar value)
caren_agent: The agent object to use for ID replacement
Returns:
The processed parameters with ID strings replaced by actual values
"""
if isinstance(params, dict):
# Process each key-value pair in the dictionary
return {key: preprocess_params(value, caren_agent) for key, value in params.items()}
elif isinstance(params, list):
# Process each item in the list
return [preprocess_params(item, caren_agent) for item in params]
elif isinstance(params, str) and ".id" in params:
# Replace string values containing '.id' with their mapped values
return add_fixture_params(params, caren_agent)
else:
# Return other values unchanged
return params