from datetime import datetime
from unittest.mock import MagicMock, Mock, patch

import pytest
from composio.client.collections import ActionModel, ActionParametersModel, ActionResponseModel, AppModel
from fastapi.testclient import TestClient

from letta.orm.errors import NoResultFound
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.server.rest_api.app import app
from letta.server.rest_api.utils import get_letta_server
from tests.helpers.utils import create_tool_from_func


@pytest.fixture
def client():
    """Return a FastAPI ``TestClient`` bound to the module-level Letta ``app``."""
    return TestClient(app)


@pytest.fixture
def mock_sync_server():
    """Route the app's ``get_letta_server`` dependency to a fresh ``Mock`` for one test.

    Yields:
        The ``Mock`` that request handlers receive as the server.

    On teardown the previous override (if any) is restored, or ours is removed.
    ``app`` is a module-level singleton shared by every test that imports it, so
    leaving the override in place would leak this Mock into unrelated tests.
    """
    mock_server = Mock()
    previous_override = app.dependency_overrides.get(get_letta_server)
    app.dependency_overrides[get_letta_server] = lambda: mock_server
    yield mock_server
    # Undo exactly what we changed so test order cannot affect other tests.
    if previous_override is None:
        app.dependency_overrides.pop(get_letta_server, None)
    else:
        app.dependency_overrides[get_letta_server] = previous_override


@pytest.fixture
def add_integers_tool():
    """Yield a tool object built by ``create_tool_from_func`` from a small ``add`` function."""

    def add(x: int, y: int) -> int:
        """
        Simple function that adds two integers.

        Parameters:
            x (int): The first integer to add.
            y (int): The second integer to add.

        Returns:
            int: The result of adding x and y.
        """
        return x + y

    tool = create_tool_from_func(add)
    yield tool


@pytest.fixture
def create_integers_tool(add_integers_tool):
    """Yield a ``ToolCreate`` payload mirroring the fields of ``add_integers_tool``."""
    tool_create = ToolCreate(
        name=add_integers_tool.name,
        description=add_integers_tool.description,
        tags=add_integers_tool.tags,
        module=add_integers_tool.module,
        source_code=add_integers_tool.source_code,
        source_type=add_integers_tool.source_type,
        json_schema=add_integers_tool.json_schema,
    )
    yield tool_create


@pytest.fixture
def update_integers_tool(add_integers_tool):
    """Yield a ``ToolUpdate`` payload mirroring the fields of ``add_integers_tool``."""
    tool_update = ToolUpdate(
        name=add_integers_tool.name,
        description=add_integers_tool.description,
        tags=add_integers_tool.tags,
        module=add_integers_tool.module,
        source_code=add_integers_tool.source_code,
        source_type=add_integers_tool.source_type,
        json_schema=add_integers_tool.json_schema,
    )
    yield tool_update


@pytest.fixture
def composio_apps():
    """Yield a one-element list holding a Composio ``AppModel`` for the "affinity" app."""
    affinity_app = AppModel(
        name="affinity",
        key="affinity",
        appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8",
        description="Affinity helps private capital investors to find, manage, and close more deals",
        categories=["CRM"],
        meta={
            "is_custom_app": False,
            "triggersCount": 0,
            "actionsCount": 20,
            "documentation_doc_text": None,
            "configuration_docs_text": None,
        },
        logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
        docs=None,
        group=None,
        status=None,
        enabled=False,
        no_auth=False,
        auth_schemes=None,
        testConnectors=None,
        documentation_doc_text=None,
        configuration_docs_text=None,
    )
    yield [affinity_app]


@pytest.fixture
def composio_actions():
    """Yield a one-element list holding the ``AFFINITY_GET_ALL_COMPANIES`` ``ActionModel``."""
    yield [
        ActionModel(
            name="AFFINITY_GET_ALL_COMPANIES",
            display_name="Get all companies",
            parameters=ActionParametersModel(
                properties={
                    "cursor": {"default": None, "description": "Cursor for the next or previous page", "title": "Cursor", "type": "string"},
                    "limit": {"default": 100, "description": "Number of items to include in the page", "title": "Limit", "type": "integer"},
                    "ids": {"default": None, "description": "Company IDs", "items": {"type": "integer"}, "title": "Ids", "type": "array"},
                    "fieldIds": {
                        "default": None,
                        "description": "Field IDs for which to return field data",
                        "items": {"type": "string"},
                        "title": "Fieldids",
                        "type": "array",
                    },
                    "fieldTypes": {
                        "default": None,
                        "description": "Field Types for which to return field data",
                        "items": {"enum": ["enriched", "global", "relationship-intelligence"], "title": "FieldtypesEnm", "type": "string"},
                        "title": "Fieldtypes",
                        "type": "array",
                    },
                },
                title="GetAllCompaniesRequest",
                type="object",
                required=None,
            ),
            response=ActionResponseModel(
                properties={
                    "data": {"title": "Data", "type": "object"},
                    "successful": {
                        "description": "Whether or not the action execution was successful or not",
                        "title": "Successful",
                        "type": "boolean",
                    },
                    "error": {
                        "anyOf": [{"type": "string"}, {"type": "null"}],
                        "default": None,
                        "description": "Error if any occurred during the execution of the action",
                        "title": "Error",
                    },
                },
                title="GetAllCompaniesResponse",
                type="object",
                required=["data", "successful"],
            ),
            appName="affinity",
            appId="affinity",
            tags=["companies", "important"],
            enabled=False,
            logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
            description="Affinity Api Allows Paginated Access To Company Info And Custom Fields. Use `Field Ids` Or `Field Types` To Specify Data In A Request. Retrieve Field I Ds/Types Via Get `/V2/Companies/Fields`. Export Permission Needed.",
        )
    ]


def configure_mock_sync_server(mock_sync_server):
    """Prime ``mock_sync_server`` with a Composio API key env var and a default user.

    After this call, ``sandbox_config_manager.list_sandbox_env_vars_by_key`` returns
    a single entry whose ``.value`` is ``"mock_composio_api_key"``, and
    ``user_manager.get_user_or_default`` returns a plain ``Mock``.
    """
    # Mock sandbox config manager to return a valid API key
    mock_api_key = Mock()
    mock_api_key.value = "mock_composio_api_key"
    mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key]

    # Mock user retrieval
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock()  # Provide additional attributes if needed


# ======================================================================================================================
# Tools Routes Tests
# ======================================================================================================================


def test_delete_tool(client, mock_sync_server, add_integers_tool):
    mock_sync_server.tool_manager.delete_tool_by_id = MagicMock()

    response = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})

    assert response.status_code == 200
    mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with(
        tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )


def test_get_tool(client, mock_sync_server, add_integers_tool):
    mock_sync_server.tool_manager.get_tool_by_id.return_value = add_integers_tool

    response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert response.json()["id"] == add_integers_tool.id
    assert response.json()["source_code"] == add_integers_tool.source_code
    mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with(
        tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )


def test_get_tool_404(client, mock_sync_server, add_integers_tool):
    mock_sync_server.tool_manager.get_tool_by_id.return_value = None

    response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})

    assert response.status_code == 404
    assert response.json()["detail"] == f"Tool with id {add_integers_tool.id} not found."


def test_get_tool_id(client, mock_sync_server, add_integers_tool):
    mock_sync_server.tool_manager.get_tool_by_name.return_value = add_integers_tool

    response = client.get(f"/v1/tools/name/{add_integers_tool.name}", headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert response.json() == add_integers_tool.id
    mock_sync_server.tool_manager.get_tool_by_name.assert_called_once_with(
        tool_name=add_integers_tool.name, actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )


def test_get_tool_id_404(client, mock_sync_server):
    mock_sync_server.tool_manager.get_tool_by_name.return_value = None

    response = client.get("/v1/tools/name/UnknownTool", headers={"user_id": "test_user"})

    assert response.status_code == 404
    assert "Tool with name UnknownTool" in response.json()["detail"]


def test_list_tools(client, mock_sync_server, add_integers_tool):
    mock_sync_server.tool_manager.list_tools.return_value = [add_integers_tool]

    response = client.get("/v1/tools", headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert len(response.json()) == 1
    assert response.json()[0]["id"] == add_integers_tool.id
    mock_sync_server.tool_manager.list_tools.assert_called_once()


def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
    mock_sync_server.tool_manager.create_tool.return_value = add_integers_tool

    response = client.post("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert response.json()["id"] == add_integers_tool.id
    mock_sync_server.tool_manager.create_tool.assert_called_once()


def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
    mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool

    response = client.put("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert response.json()["id"] == add_integers_tool.id
    mock_sync_server.tool_manager.create_or_update_tool.assert_called_once()


def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool):
    mock_sync_server.tool_manager.update_tool_by_id.return_value = add_integers_tool

    response = client.patch(f"/v1/tools/{add_integers_tool.id}", json=update_integers_tool.model_dump(), headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert response.json()["id"] == add_integers_tool.id
    mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with(
        tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )


def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
    mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]

    response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})

    assert response.status_code == 200
    assert len(response.json()) == 1
    assert response.json()[0]["id"] == add_integers_tool.id
    mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
        actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )


def test_list_composio_apps(client, mock_sync_server, composio_apps):
    configure_mock_sync_server(mock_sync_server)

    mock_sync_server.get_composio_apps.return_value = composio_apps

    response = client.get("/v1/tools/composio/apps")

    assert response.status_code == 200
    assert len(response.json()) == 1
    mock_sync_server.get_composio_apps.assert_called_once()


def test_list_composio_actions_by_app(client, mock_sync_server, composio_actions):
    configure_mock_sync_server(mock_sync_server)

    mock_sync_server.get_composio_actions_from_app_name.return_value = composio_actions

    response = client.get("/v1/tools/composio/apps/App1/actions")

    assert response.status_code == 200
    assert len(response.json()) == 1
    mock_sync_server.get_composio_actions_from_app_name.assert_called_once_with(composio_app_name="App1", api_key="mock_composio_api_key")


def test_add_composio_tool(client, mock_sync_server, add_integers_tool):
    configure_mock_sync_server(mock_sync_server)

    # Mock ToolCreate.from_composio to return the expected ToolCreate object
    with patch("letta.schemas.tool.ToolCreate.from_composio") as mock_from_composio:
        mock_from_composio.return_value = ToolCreate(
            name=add_integers_tool.name,
            source_code=add_integers_tool.source_code,
            json_schema=add_integers_tool.json_schema,
        )

        # Mock server behavior
        mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool

        # Perform the request
        response = client.post(f"/v1/tools/composio/{add_integers_tool.name}", headers={"user_id": "test_user"})

        # Assertions
        assert response.status_code == 200
        assert response.json()["id"] == add_integers_tool.id
        mock_sync_server.tool_manager.create_or_update_tool.assert_called_once()

        # Verify the mocked from_composio method was called
        mock_from_composio.assert_called_once_with(action_name=add_integers_tool.name, api_key="mock_composio_api_key")


# ======================================================================================================================
# Runs Routes Tests
# ======================================================================================================================


def test_get_run_messages(client, mock_sync_server):
    """Test getting messages for a run."""
    # Create properly formatted mock messages
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; it is kept
    # naive here in case the message schema expects naive timestamps — confirm.
    current_time = datetime.utcnow()
    messages_data = [
        {
            "id": f"message-{i:08x}",  # Matches pattern '^message-[a-fA-F0-9]{8}'
            "text": f"Test message {i}",
            "role": "user",
            "organization_id": "org-123",
            "agent_id": "agent-123",
            "model": "gpt-4",
            "name": "test-user",
            "tool_calls": [],
            "tool_call_id": None,
            "created_at": current_time,
            "updated_at": current_time,
            "created_by_id": "user-123",
            "last_updated_by_id": "user-123",
        }
        for i in range(2)
    ]

    mock_messages = []
    for msg_data in messages_data:
        mock_msg = Mock()
        for key, value in msg_data.items():
            setattr(mock_msg, key, value)
        mock_messages.append(mock_msg)

    # Configure mock server responses
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_job_messages.return_value = mock_messages

    # Test successful retrieval
    response = client.get(
        "/v1/runs/run-12345678/messages",
        headers={"user_id": "user-123"},
        params={"limit": 10, "cursor": messages_data[1]["id"], "role": "user"},
    )
    assert response.status_code == 200
    assert len(response.json()) == 2
    assert response.json()[0]["id"] == messages_data[0]["id"]
    assert response.json()[1]["id"] == messages_data[1]["id"]

    # Verify mock calls
    mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
    mock_sync_server.job_manager.get_job_messages.assert_called_once_with(
        job_id="run-12345678",
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
        limit=10,
        cursor=messages_data[1]["id"],
        start_date=None,
        end_date=None,
        query_text=None,
        ascending=True,
        tags=None,
        match_all_tags=False,
        role="user",
        tool_name=None,
    )


def test_get_run_messages_not_found(client, mock_sync_server):
    """Test getting messages for a non-existent run."""
    # Configure mock responses
    error_message = "Run 'run-nonexistent' not found"
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_job_messages.side_effect = NoResultFound(error_message)

    response = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"})

    assert response.status_code == 404
    assert error_message in response.json()["detail"]


def test_get_run_usage(client, mock_sync_server):
    """Test getting usage statistics for a run."""
    # Configure mock responses
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_usage = Mock(
        completion_tokens=100,
        prompt_tokens=200,
        total_tokens=300,
    )
    mock_sync_server.job_manager.get_job_usage.return_value = mock_usage

    # Make request
    response = client.get("/v1/runs/run-12345678/usage", headers={"user_id": "user-123"})

    # Check response
    assert response.status_code == 200
    assert response.json() == {
        "completion_tokens": 100,
        "prompt_tokens": 200,
        "total_tokens": 300,
    }

    # Verify mock calls
    mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
    mock_sync_server.job_manager.get_job_usage.assert_called_once_with(
        job_id="run-12345678",
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )


def test_get_run_usage_not_found(client, mock_sync_server):
    """Test getting usage statistics for a non-existent run."""
    # Configure mock responses
    error_message = "Run 'run-nonexistent' not found"
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_job_usage.side_effect = NoResultFound(error_message)

    # Make request
    response = client.get("/v1/runs/run-nonexistent/usage", headers={"user_id": "user-123"})

    assert response.status_code == 404
    assert error_message in response.json()["detail"]