mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
458 lines
18 KiB
Python
458 lines
18 KiB
Python
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock, patch

import pytest
from composio.client.collections import ActionModel, ActionParametersModel, ActionResponseModel, AppModel
from fastapi.testclient import TestClient

from letta.orm.errors import NoResultFound
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.server.rest_api.app import app
from letta.server.rest_api.utils import get_letta_server
from tests.helpers.utils import create_tool_from_func
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a ``TestClient`` wired to the Letta REST ``app``.

    The client is not entered as a context manager, so (per Starlette's
    TestClient docs) the app's lifespan handlers are not run.
    """
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture
def mock_sync_server():
    """Swap the app's ``get_letta_server`` dependency for a ``Mock`` during one test.

    Any route that obtains the server through FastAPI dependency injection
    receives this same ``Mock`` instance, so the test can configure return
    values and inspect calls on it.

    The override is removed on teardown. ``app`` is a module-level singleton,
    so without that cleanup the Mock server would leak into every later test
    (in this or other modules) that uses the same app.

    Yields:
        Mock: the stand-in server object.
    """
    mock_server = Mock()
    app.dependency_overrides[get_letta_server] = lambda: mock_server
    yield mock_server
    # Teardown (pytest runs this even when the test fails): restore the real dependency.
    app.dependency_overrides.pop(get_letta_server, None)
|
|
|
|
|
|
@pytest.fixture
def add_integers_tool():
    """Yield a tool object built from a small ``add(x, y)`` function.

    ``add`` is passed to ``create_tool_from_func``. Its name, signature and
    docstring are presumably what that helper turns into the tool's name,
    description, source code and JSON schema (TODO confirm in
    tests.helpers.utils), so ``add`` should not be edited casually: other
    fixtures and tests compare against the resulting tool's fields.
    """
    def add(x: int, y: int) -> int:
        """
        Simple function that adds two integers.

        Parameters:
            x (int): The first integer to add.
            y (int): The second integer to add.

        Returns:
            int: The result of adding x and y.
        """
        return x + y

    tool = create_tool_from_func(add)
    yield tool
|
|
|
|
|
|
@pytest.fixture
def create_integers_tool(add_integers_tool):
    """Yield a ``ToolCreate`` request body copied field-by-field from ``add_integers_tool``."""
    copied_fields = ("name", "description", "tags", "module", "source_code", "source_type", "json_schema")
    payload = ToolCreate(**{field: getattr(add_integers_tool, field) for field in copied_fields})
    yield payload
|
|
|
|
|
|
@pytest.fixture
def update_integers_tool(add_integers_tool):
    """Yield a ``ToolUpdate`` request body copied field-by-field from ``add_integers_tool``."""
    source = add_integers_tool
    mirrored = {
        attr: getattr(source, attr)
        for attr in ("name", "description", "tags", "module", "source_code", "source_type", "json_schema")
    }
    yield ToolUpdate(**mirrored)
|
|
|
|
|
|
@pytest.fixture
def composio_apps():
    """Yield a one-element list holding a canned Composio ``AppModel`` for Affinity."""
    affinity_meta = {
        "is_custom_app": False,
        "triggersCount": 0,
        "actionsCount": 20,
        "documentation_doc_text": None,
        "configuration_docs_text": None,
    }
    app_model = AppModel(
        name="affinity",
        key="affinity",
        appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8",
        description="Affinity helps private capital investors to find, manage, and close more deals",
        categories=["CRM"],
        meta=affinity_meta,
        logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
        docs=None,
        group=None,
        status=None,
        enabled=False,
        no_auth=False,
        auth_schemes=None,
        testConnectors=None,
        documentation_doc_text=None,
        configuration_docs_text=None,
    )
    yield [app_model]
|
|
|
|
|
|
@pytest.fixture
def composio_actions():
    """Yield a one-element list with a canned ``AFFINITY_GET_ALL_COMPANIES`` ``ActionModel``."""
    # JSON-schema properties of the action's request parameters.
    request_properties = {
        "cursor": {"default": None, "description": "Cursor for the next or previous page", "title": "Cursor", "type": "string"},
        "limit": {"default": 100, "description": "Number of items to include in the page", "title": "Limit", "type": "integer"},
        "ids": {"default": None, "description": "Company IDs", "items": {"type": "integer"}, "title": "Ids", "type": "array"},
        "fieldIds": {
            "default": None,
            "description": "Field IDs for which to return field data",
            "items": {"type": "string"},
            "title": "Fieldids",
            "type": "array",
        },
        "fieldTypes": {
            "default": None,
            "description": "Field Types for which to return field data",
            "items": {"enum": ["enriched", "global", "relationship-intelligence"], "title": "FieldtypesEnm", "type": "string"},
            "title": "Fieldtypes",
            "type": "array",
        },
    }
    # JSON-schema properties of the action's response envelope.
    response_properties = {
        "data": {"title": "Data", "type": "object"},
        "successful": {
            "description": "Whether or not the action execution was successful or not",
            "title": "Successful",
            "type": "boolean",
        },
        "error": {
            "anyOf": [{"type": "string"}, {"type": "null"}],
            "default": None,
            "description": "Error if any occurred during the execution of the action",
            "title": "Error",
        },
    }

    request_schema = ActionParametersModel(
        properties=request_properties,
        title="GetAllCompaniesRequest",
        type="object",
        required=None,
    )
    response_schema = ActionResponseModel(
        properties=response_properties,
        title="GetAllCompaniesResponse",
        type="object",
        required=["data", "successful"],
    )
    get_all_companies = ActionModel(
        name="AFFINITY_GET_ALL_COMPANIES",
        display_name="Get all companies",
        parameters=request_schema,
        response=response_schema,
        appName="affinity",
        appId="affinity",
        tags=["companies", "important"],
        enabled=False,
        logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
        description="Affinity Api Allows Paginated Access To Company Info And Custom Fields. Use `Field Ids` Or `Field Types` To Specify Data In A Request. Retrieve Field I Ds/Types Via Get `/V2/Companies/Fields`. Export Permission Needed.",
    )
    yield [get_all_companies]
|
|
|
|
|
|
def configure_mock_sync_server(mock_sync_server):
    """Prime a mocked server with the values the Composio route tests rely on.

    - ``sandbox_config_manager.list_sandbox_env_vars_by_key`` returns a
      single env-var stand-in whose ``value`` is ``"mock_composio_api_key"``.
    - ``user_manager.get_user_or_default`` returns a fresh ``Mock`` user.
    """
    # ``value`` is not a reserved Mock constructor argument, so it becomes a plain attribute.
    api_key_var = Mock(value="mock_composio_api_key")
    sandbox_manager = mock_sync_server.sandbox_config_manager
    sandbox_manager.list_sandbox_env_vars_by_key.return_value = [api_key_var]

    # Stand-in actor; add attributes here if a route starts reading them.
    users = mock_sync_server.user_manager
    users.get_user_or_default.return_value = Mock()
|
|
|
|
|
|
# ======================================================================================================================
|
|
# Tools Routes Tests
|
|
# ======================================================================================================================
|
|
def test_delete_tool(client, mock_sync_server, add_integers_tool):
    """DELETE /v1/tools/{id} forwards the tool id and resolved actor to the tool manager."""
    tool_manager = mock_sync_server.tool_manager
    tool_manager.delete_tool_by_id = MagicMock()
    # Mock caches ``return_value``, so this is the same object the route receives as actor.
    expected_actor = mock_sync_server.user_manager.get_user_or_default.return_value

    resp = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    tool_manager.delete_tool_by_id.assert_called_once_with(tool_id=add_integers_tool.id, actor=expected_actor)
|
|
|
|
|
|
def test_get_tool(client, mock_sync_server, add_integers_tool):
    """GET /v1/tools/{id} returns the tool the manager looks up by id."""
    tool_id = add_integers_tool.id
    lookup = mock_sync_server.tool_manager.get_tool_by_id
    lookup.return_value = add_integers_tool

    resp = client.get(f"/v1/tools/{tool_id}", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    body = resp.json()
    assert body["id"] == tool_id
    assert body["source_code"] == add_integers_tool.source_code
    lookup.assert_called_once_with(
        tool_id=tool_id,
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )
|
|
|
|
|
|
def test_get_tool_404(client, mock_sync_server, add_integers_tool):
    """GET /v1/tools/{id} answers 404 with a descriptive detail when the manager finds nothing."""
    missing_id = add_integers_tool.id
    mock_sync_server.tool_manager.get_tool_by_id.return_value = None

    resp = client.get(f"/v1/tools/{missing_id}", headers={"user_id": "test_user"})

    assert resp.status_code == 404
    expected_detail = f"Tool with id {missing_id} not found."
    assert resp.json()["detail"] == expected_detail
|
|
|
|
|
|
def test_get_tool_id(client, mock_sync_server, add_integers_tool):
    """GET /v1/tools/name/{name} returns the id of the tool found by name."""
    name_lookup = mock_sync_server.tool_manager.get_tool_by_name
    name_lookup.return_value = add_integers_tool

    resp = client.get(f"/v1/tools/name/{add_integers_tool.name}", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json() == add_integers_tool.id
    name_lookup.assert_called_once_with(
        tool_name=add_integers_tool.name,
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )
|
|
|
|
|
|
def test_get_tool_id_404(client, mock_sync_server):
    """GET /v1/tools/name/{name} answers 404 when no tool carries that name."""
    mock_sync_server.tool_manager.get_tool_by_name.return_value = None

    resp = client.get("/v1/tools/name/UnknownTool", headers={"user_id": "test_user"})

    assert resp.status_code == 404
    detail = resp.json()["detail"]
    assert "Tool with name UnknownTool" in detail
|
|
|
|
|
|
def test_list_tools(client, mock_sync_server, add_integers_tool):
    """GET /v1/tools returns every tool the manager lists."""
    list_tools = mock_sync_server.tool_manager.list_tools
    list_tools.return_value = [add_integers_tool]

    resp = client.get("/v1/tools", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    tools = resp.json()
    assert len(tools) == 1
    assert tools[0]["id"] == add_integers_tool.id
    list_tools.assert_called_once()
|
|
|
|
|
|
def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
    """POST /v1/tools creates a tool from the JSON body and returns the created tool."""
    creator = mock_sync_server.tool_manager.create_tool
    creator.return_value = add_integers_tool
    payload = create_integers_tool.model_dump()

    resp = client.post("/v1/tools", json=payload, headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json()["id"] == add_integers_tool.id
    creator.assert_called_once()
|
|
|
|
|
|
def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
    """PUT /v1/tools upserts a tool from the JSON body and returns the stored tool."""
    upserter = mock_sync_server.tool_manager.create_or_update_tool
    upserter.return_value = add_integers_tool
    payload = create_integers_tool.model_dump()

    resp = client.put("/v1/tools", json=payload, headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json()["id"] == add_integers_tool.id
    upserter.assert_called_once()
|
|
|
|
|
|
def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool):
    """PATCH /v1/tools/{id} hands an equal ``ToolUpdate`` and the actor to the manager."""
    tool_id = add_integers_tool.id
    updater = mock_sync_server.tool_manager.update_tool_by_id
    updater.return_value = add_integers_tool
    payload = update_integers_tool.model_dump()

    resp = client.patch(f"/v1/tools/{tool_id}", json=payload, headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json()["id"] == tool_id
    updater.assert_called_once_with(
        tool_id=tool_id,
        tool_update=update_integers_tool,
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )
|
|
|
|
|
|
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
    """POST /v1/tools/add-base-tools upserts the base tools for the resolved actor."""
    upsert_base = mock_sync_server.tool_manager.upsert_base_tools
    upsert_base.return_value = [add_integers_tool]

    resp = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    returned = resp.json()
    assert len(returned) == 1
    assert returned[0]["id"] == add_integers_tool.id
    upsert_base.assert_called_once_with(actor=mock_sync_server.user_manager.get_user_or_default.return_value)
|
|
|
|
|
|
def test_list_composio_apps(client, mock_sync_server, composio_apps):
    """GET /v1/tools/composio/apps returns the apps reported by the server."""
    configure_mock_sync_server(mock_sync_server)
    get_apps = mock_sync_server.get_composio_apps
    get_apps.return_value = composio_apps

    resp = client.get("/v1/tools/composio/apps")

    assert resp.status_code == 200
    assert len(resp.json()) == 1
    get_apps.assert_called_once()
|
|
|
|
|
|
def test_list_composio_actions_by_app(client, mock_sync_server, composio_actions):
    """GET /v1/tools/composio/apps/{app}/actions queries actions using the sandbox-stored API key."""
    configure_mock_sync_server(mock_sync_server)
    get_actions = mock_sync_server.get_composio_actions_from_app_name
    get_actions.return_value = composio_actions

    resp = client.get("/v1/tools/composio/apps/App1/actions")

    assert resp.status_code == 200
    assert len(resp.json()) == 1
    get_actions.assert_called_once_with(composio_app_name="App1", api_key="mock_composio_api_key")
|
|
|
|
|
|
def test_add_composio_tool(client, mock_sync_server, add_integers_tool):
    """POST /v1/tools/composio/{action} builds a ToolCreate via ``from_composio`` and upserts it."""
    configure_mock_sync_server(mock_sync_server)
    action_name = add_integers_tool.name
    upserter = mock_sync_server.tool_manager.create_or_update_tool

    # Stub ToolCreate.from_composio so the route receives a canned ToolCreate.
    with patch("letta.schemas.tool.ToolCreate.from_composio") as fake_from_composio:
        fake_from_composio.return_value = ToolCreate(
            name=action_name,
            source_code=add_integers_tool.source_code,
            json_schema=add_integers_tool.json_schema,
        )
        upserter.return_value = add_integers_tool

        resp = client.post(f"/v1/tools/composio/{action_name}", headers={"user_id": "test_user"})

        assert resp.status_code == 200
        assert resp.json()["id"] == add_integers_tool.id
        upserter.assert_called_once()
        # The route must look the action up with the API key supplied by the sandbox config.
        fake_from_composio.assert_called_once_with(action_name=action_name, api_key="mock_composio_api_key")
|
|
|
|
|
|
# ======================================================================================================================
|
|
# Runs Routes Tests
|
|
# ======================================================================================================================
|
|
|
|
|
|
def test_get_run_messages(client, mock_sync_server):
    """Test getting messages for a run.

    The mocked job manager returns two stand-in messages. The route must
    return both, and must pass the query parameters (limit, cursor, role)
    through to ``job_manager.get_job_messages``, with the remaining filter
    arguments equal to the values asserted below.
    """
    # Create properly formatted mock messages.
    # Naive UTC "now": the same value/type ``datetime.utcnow()`` produced,
    # without that method's Python 3.12+ deprecation warning.
    current_time = datetime.now(timezone.utc).replace(tzinfo=None)
    messages_data = [
        {
            "id": f"message-{i:08x}",  # Matches pattern '^message-[a-fA-F0-9]{8}'
            "text": f"Test message {i}",
            "role": "user",
            "organization_id": "org-123",
            "agent_id": "agent-123",
            "model": "gpt-4",
            "name": "test-user",
            "tool_calls": [],
            "tool_call_id": None,
            "created_at": current_time,
            "updated_at": current_time,
            "created_by_id": "user-123",
            "last_updated_by_id": "user-123",
        }
        for i in range(2)
    ]

    # Attributes are set with setattr rather than Mock(**msg_data). The Mock
    # constructor treats ``name`` as the mock's own repr name, so passing it
    # there would not create a plain ``.name`` attribute.
    mock_messages = []
    for msg_data in messages_data:
        mock_msg = Mock()
        for key, value in msg_data.items():
            setattr(mock_msg, key, value)
        mock_messages.append(mock_msg)

    # Configure mock server responses
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_job_messages.return_value = mock_messages

    # Test successful retrieval
    response = client.get(
        "/v1/runs/run-12345678/messages",
        headers={"user_id": "user-123"},
        params={"limit": 10, "cursor": messages_data[1]["id"], "role": "user"},
    )
    assert response.status_code == 200
    assert len(response.json()) == 2
    assert response.json()[0]["id"] == messages_data[0]["id"]
    assert response.json()[1]["id"] == messages_data[1]["id"]

    # Verify mock calls
    mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
    mock_sync_server.job_manager.get_job_messages.assert_called_once_with(
        job_id="run-12345678",
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
        limit=10,
        cursor=messages_data[1]["id"],
        start_date=None,
        end_date=None,
        query_text=None,
        ascending=True,
        tags=None,
        match_all_tags=False,
        role="user",
        tool_name=None,
    )
|
|
|
|
|
|
def test_get_run_messages_not_found(client, mock_sync_server):
    """Test getting messages for a non-existent run: ``NoResultFound`` surfaces as a 404."""
    missing_run_error = "Run 'run-nonexistent' not found"
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    job_manager = mock_sync_server.job_manager
    job_manager.get_job_messages.side_effect = NoResultFound(missing_run_error)

    resp = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"})

    assert resp.status_code == 404
    assert missing_run_error in resp.json()["detail"]
|
|
|
|
|
|
def test_get_run_usage(client, mock_sync_server):
    """Test getting usage statistics for a run: the token counts are echoed back as JSON."""
    token_counts = {
        "completion_tokens": 100,
        "prompt_tokens": 200,
        "total_tokens": 300,
    }
    actor_lookup = mock_sync_server.user_manager.get_user_or_default
    actor_lookup.return_value = Mock(id="user-123")
    usage_lookup = mock_sync_server.job_manager.get_job_usage
    usage_lookup.return_value = Mock(**token_counts)

    resp = client.get("/v1/runs/run-12345678/usage", headers={"user_id": "user-123"})

    assert resp.status_code == 200
    assert resp.json() == token_counts

    actor_lookup.assert_called_once_with(user_id="user-123")
    usage_lookup.assert_called_once_with(
        job_id="run-12345678",
        actor=actor_lookup.return_value,
    )
|
|
|
|
|
|
def test_get_run_usage_not_found(client, mock_sync_server):
    """Test getting usage statistics for a non-existent run: ``NoResultFound`` becomes a 404."""
    missing_run_error = "Run 'run-nonexistent' not found"
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    job_manager = mock_sync_server.job_manager
    job_manager.get_job_usage.side_effect = NoResultFound(missing_run_error)

    resp = client.get("/v1/runs/run-nonexistent/usage", headers={"user_id": "user-123"})

    assert resp.status_code == 404
    assert missing_run_error in resp.json()["detail"]
|