mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add e2e tests where agents invoke MCP tooling (#2020)
This commit is contained in:
parent
0ef64dc365
commit
7b13ab91b1
@ -1,8 +1 @@
|
|||||||
{
|
{}
|
||||||
"mcpServers": {
|
|
||||||
"github_composio": {
|
|
||||||
"transport": "sse",
|
|
||||||
"url": "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,18 +1,21 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
import venv
|
import venv
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mcp import Tool as MCPTool
|
from dotenv import load_dotenv
|
||||||
|
from letta_client import Letta, McpTool, ToolCallMessage, ToolReturnMessage
|
||||||
|
|
||||||
import letta.constants as constants
|
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig
|
||||||
from letta.config import LettaConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.functions.mcp_client.types import MCPServerType, SSEServerConfig, StdioServerConfig
|
from letta.schemas.letta_message_content import TextContent
|
||||||
from letta.schemas.tool import ToolCreate
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.server.server import SyncServer
|
from letta.schemas.message import MessageCreate
|
||||||
from letta.utils import parse_json
|
from tests.utils import wait_for_server
|
||||||
|
|
||||||
|
|
||||||
def create_virtualenv_and_install_requirements(requirements_path: Path, name="venv") -> Path:
|
def create_virtualenv_and_install_requirements(requirements_path: Path, name="venv") -> Path:
|
||||||
@ -40,122 +43,165 @@ def create_virtualenv_and_install_requirements(requirements_path: Path, name="ve
|
|||||||
return venv_dir
|
return venv_dir
|
||||||
|
|
||||||
|
|
||||||
|
# --- Server Management --- #
|
||||||
|
|
||||||
|
|
||||||
|
def _run_server():
|
||||||
|
"""Starts the Letta server in a background thread."""
|
||||||
|
load_dotenv()
|
||||||
|
from letta.server.rest_api.app import start_server
|
||||||
|
|
||||||
|
start_server(debug=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def empty_mcp_config(tmp_path):
|
def empty_mcp_config():
|
||||||
path = Path(__file__).parent / "mcp_config.json"
|
path = Path(__file__).parent / "mcp_config.json"
|
||||||
path.write_text(json.dumps({})) # writes "{}"
|
path.write_text(json.dumps({})) # writes "{}"
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture()
|
||||||
def server(empty_mcp_config):
|
def server_url(empty_mcp_config):
|
||||||
config = LettaConfig.load()
|
"""Ensures a server is running and returns its base URL."""
|
||||||
print("CONFIG PATH", config.config_path)
|
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||||
|
|
||||||
config.save()
|
if not os.getenv("LETTA_SERVER_URL"):
|
||||||
|
thread = threading.Thread(target=_run_server, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
wait_for_server(url)
|
||||||
|
|
||||||
old_dir = constants.LETTA_DIR
|
return url
|
||||||
constants.LETTA_DIR = str(Path(__file__).parent)
|
|
||||||
|
|
||||||
server = SyncServer()
|
|
||||||
yield server
|
|
||||||
constants.LETTA_DIR = old_dir
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture()
|
||||||
def default_user(server):
|
def client(server_url):
|
||||||
user = server.user_manager.get_user_or_default()
|
"""Creates a REST client for testing."""
|
||||||
yield user
|
client = Letta(base_url=server_url)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
def test_sse_mcp_server(server, default_user):
|
@pytest.fixture()
|
||||||
assert server.mcp_clients == {}
|
def agent_state(client):
|
||||||
|
"""Creates an agent and ensures cleanup after tests."""
|
||||||
|
agent_state = client.agents.create(
|
||||||
|
name=f"test_compl_{str(uuid.uuid4())[5:]}",
|
||||||
|
include_base_tools=True,
|
||||||
|
memory_blocks=[
|
||||||
|
{
|
||||||
|
"label": "human",
|
||||||
|
"value": "Name: Matt",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": "persona",
|
||||||
|
"value": "Friendly agent",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||||
|
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||||
|
)
|
||||||
|
yield agent_state
|
||||||
|
client.agents.delete(agent_state.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sse_mcp_server(client, agent_state):
|
||||||
mcp_server_name = "github_composio"
|
mcp_server_name = "github_composio"
|
||||||
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
||||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||||
server.add_mcp_server_to_config(sse_mcp_config)
|
client.tools.add_mcp_server(request=sse_mcp_config)
|
||||||
|
|
||||||
# Check that it's in clients
|
|
||||||
assert mcp_server_name in server.mcp_clients
|
|
||||||
|
|
||||||
# Check that it's in the server mapping
|
# Check that it's in the server mapping
|
||||||
mcp_server_mapping = server.get_mcp_servers()
|
mcp_server_mapping = client.tools.list_mcp_servers()
|
||||||
assert mcp_server_name in mcp_server_mapping
|
assert mcp_server_name in mcp_server_mapping
|
||||||
assert mcp_server_mapping[mcp_server_name] == sse_mcp_config
|
|
||||||
|
|
||||||
# Check tools
|
# Check tools
|
||||||
tools = server.get_tools_from_mcp_server(mcp_server_name)
|
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||||
assert len(tools) > 0
|
assert len(tools) > 0
|
||||||
assert isinstance(tools[0], MCPTool)
|
assert isinstance(tools[0], McpTool)
|
||||||
star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None)
|
star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None)
|
||||||
|
|
||||||
# Check that one of the tools are executable
|
# Check that one of the tools are executable
|
||||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=star_mcp_tool)
|
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name)
|
||||||
server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
|
|
||||||
|
|
||||||
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(
|
tool_args = {"owner": "letta-ai", "repo": "letta"}
|
||||||
tool_name=star_mcp_tool.name, tool_args={"owner": "letta-ai", "repo": "letta"}
|
|
||||||
|
# Add to agent, have agent invoke tool
|
||||||
|
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
|
||||||
|
response = client.agents.messages.create(
|
||||||
|
agent_id=agent_state.id,
|
||||||
|
messages=[
|
||||||
|
MessageCreate(
|
||||||
|
role="user",
|
||||||
|
content=[TextContent(text=f"Use the `{letta_tool.name}` tool with these arguments: {tool_args}.")],
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
assert not is_error
|
seq = response.messages
|
||||||
function_response = parse_json(function_response)
|
calls = [m for m in seq if isinstance(m, ToolCallMessage)]
|
||||||
assert function_response.get("successful"), function_response
|
assert calls, "Expected a ToolCallMessage"
|
||||||
assert function_response.get("data").get("details") == "Action executed successfully", function_response
|
assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"
|
||||||
|
|
||||||
|
returns = [m for m in seq if isinstance(m, ToolReturnMessage)]
|
||||||
|
assert returns, "Expected a ToolReturnMessage"
|
||||||
|
tr = returns[0]
|
||||||
|
# status field
|
||||||
|
assert tr.status == "success", f"Bad status: {tr.status}"
|
||||||
|
# parse JSON payload
|
||||||
|
payload = json.loads(tr.tool_return)
|
||||||
|
assert payload.get("successful", False), f"Tool returned failure payload: {payload}"
|
||||||
|
assert payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {payload}"
|
||||||
|
|
||||||
|
|
||||||
def test_stdio_mcp_server(server, default_user):
|
def test_stdio_mcp_server(client, agent_state):
|
||||||
assert server.mcp_clients == {}
|
req_file = Path(__file__).parent / "weather" / "requirements.txt"
|
||||||
|
create_virtualenv_and_install_requirements(req_file, name="venv")
|
||||||
# Create venv
|
|
||||||
create_virtualenv_and_install_requirements(Path(__file__).parent / "weather" / "requirements.txt")
|
|
||||||
|
|
||||||
mcp_server_name = "weather"
|
mcp_server_name = "weather"
|
||||||
command = str(Path(__file__).parent / "weather" / "venv" / "bin" / "python3")
|
command = str(Path(__file__).parent / "weather" / "venv" / "bin" / "python3")
|
||||||
args = [str(Path(__file__).parent / "weather" / "weather.py")]
|
args = [str(Path(__file__).parent / "weather" / "weather.py")]
|
||||||
stdio_mcp_config = StdioServerConfig(server_name=mcp_server_name, command=command, args=args)
|
|
||||||
server.add_mcp_server_to_config(stdio_mcp_config)
|
|
||||||
|
|
||||||
# Check that it's in clients
|
stdio_config = StdioServerConfig(
|
||||||
assert mcp_server_name in server.mcp_clients
|
server_name=mcp_server_name,
|
||||||
|
command=command,
|
||||||
# Check that it's in the server mapping
|
args=args,
|
||||||
mcp_server_mapping = server.get_mcp_servers()
|
|
||||||
assert mcp_server_name in mcp_server_mapping
|
|
||||||
assert mcp_server_mapping[mcp_server_name] == StdioServerConfig(
|
|
||||||
server_name=mcp_server_name, type=MCPServerType.STDIO, command=command, args=args, env=None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that it can return valid tools
|
client.tools.add_mcp_server(request=stdio_config)
|
||||||
tools = server.get_tools_from_mcp_server(mcp_server_name)
|
|
||||||
assert tools == [
|
|
||||||
MCPTool(
|
|
||||||
name="get_alerts",
|
|
||||||
description="Get weather alerts for a US state.\n\n Args:\n state: Two-letter US state code (e.g. CA, NY)\n ",
|
|
||||||
inputSchema={
|
|
||||||
"properties": {"state": {"title": "State", "type": "string"}},
|
|
||||||
"required": ["state"],
|
|
||||||
"title": "get_alertsArguments",
|
|
||||||
"type": "object",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
MCPTool(
|
|
||||||
name="get_forecast",
|
|
||||||
description="Get weather forecast for a location.\n\n Args:\n latitude: Latitude of the location\n longitude: Longitude of the location\n ",
|
|
||||||
inputSchema={
|
|
||||||
"properties": {"latitude": {"title": "Latitude", "type": "number"}, "longitude": {"title": "Longitude", "type": "number"}},
|
|
||||||
"required": ["latitude", "longitude"],
|
|
||||||
"title": "get_forecastArguments",
|
|
||||||
"type": "object",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
get_alerts_mcp_tool = tools[0]
|
|
||||||
|
|
||||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=get_alerts_mcp_tool)
|
servers = client.tools.list_mcp_servers()
|
||||||
server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
|
assert mcp_server_name in servers
|
||||||
|
|
||||||
# Attempt running the tool
|
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||||
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(tool_name="get_alerts", tool_args={"state": "CA"})
|
assert tools, "Expected at least one tool from the weather MCP server"
|
||||||
assert not is_error
|
assert any(t.name == "get_alerts" for t in tools), f"Got: {[t.name for t in tools]}"
|
||||||
assert len(function_response) > 20, function_response # Crude heuristic for an expected result
|
|
||||||
|
get_alerts = next(t for t in tools if t.name == "get_alerts")
|
||||||
|
|
||||||
|
letta_tool = client.tools.add_mcp_tool(
|
||||||
|
mcp_server_name=mcp_server_name,
|
||||||
|
mcp_tool_name=get_alerts.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
|
||||||
|
|
||||||
|
response = client.agents.messages.create(
|
||||||
|
agent_id=agent_state.id,
|
||||||
|
messages=[
|
||||||
|
MessageCreate(
|
||||||
|
role="user",
|
||||||
|
content=[TextContent(text=(f"Use the `{letta_tool.name}` tool with these arguments: " f"{{'state': 'CA'}}."))],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
calls = [m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "get_alerts"]
|
||||||
|
assert calls, "Expected a get_alerts ToolCallMessage"
|
||||||
|
|
||||||
|
returns = [m for m in response.messages if isinstance(m, ToolReturnMessage) and m.tool_call_id == calls[0].tool_call.tool_call_id]
|
||||||
|
assert returns, "Expected a ToolReturnMessage for get_alerts"
|
||||||
|
ret = returns[0]
|
||||||
|
|
||||||
|
assert ret.status == "success", f"Unexpected status: {ret.status}"
|
||||||
|
# make sure there's at least some payload
|
||||||
|
assert len(ret.tool_return.strip()) >= 10, f"Expected at least 10 characters in tool_return, got {len(ret.tool_return.strip())}"
|
||||||
|
Loading…
Reference in New Issue
Block a user