feat: Support external codebases as a tool execution sandbox (#2159)

This commit is contained in:
Matthew Zhou 2024-12-04 11:37:51 -08:00 committed by GitHub
parent 128ec1aac7
commit 01afe75844
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 298 additions and 37 deletions

3
.gitignore vendored
View File

@ -1018,3 +1018,6 @@ pgdata/
letta/.pytest_cache/
memgpy/pytest.ini
**/**/pytest_cache
## ignore venvs
tests/test_tool_sandbox/restaurant_management_system/venv

View File

@ -24,6 +24,11 @@ class SandboxRunResult(BaseModel):
class LocalSandboxConfig(BaseModel):
sandbox_dir: str = Field(..., description="Directory for the sandbox environment.")
use_venv: bool = Field(False, description="Whether or not to use the venv, or run directly in the same run loop.")
venv_name: str = Field(
"venv",
description="The name for the venv in the sandbox directory. We first search for an existing venv with this name, otherwise, we make it from the requirements.txt.",
)
@property
def type(self) -> "SandboxType":

View File

@ -1876,7 +1876,8 @@ class SyncServer(Server):
apps = self.composio_client.apps.get()
apps_with_actions = []
for app in apps:
if app.meta["actionsCount"] > 0:
# A bit of hacky logic until composio patches this
if app.meta["actionsCount"] > 0 and not app.name.lower().endswith("_beta"):
apps_with_actions.append(app)
return apps_with_actions

View File

@ -4,9 +4,12 @@ import io
import os
import pickle
import runpy
import subprocess
import sys
import tempfile
from typing import Any, Optional
import uuid
import venv
from typing import Any, Dict, Optional, TextIO
from letta.log import get_logger
from letta.schemas.agent import AgentState
@ -24,6 +27,11 @@ class ToolExecutionSandbox:
METADATA_CONFIG_STATE_KEY = "config_state"
REQUIREMENT_TXT_NAME = "requirements.txt"
# For generating long, random marker hashes
NAMESPACE = uuid.NAMESPACE_DNS
LOCAL_SANDBOX_RESULT_START_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-start-marker"))
LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker"))
# This is the variable name in the auto-generated code that contains the function results
# We make this a long random string to avoid collisions with any variables in the user's code
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
@ -65,12 +73,10 @@ class ToolExecutionSandbox:
"""
if tool_settings.e2b_api_key:
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
code = self.generate_execution_script(agent_state=agent_state)
result = self.run_e2b_sandbox(code=code)
result = self.run_e2b_sandbox(agent_state=agent_state)
else:
logger.debug(f"Using local sandbox to execute {self.tool_name}")
code = self.generate_execution_script(agent_state=agent_state)
result = self.run_local_dir_sandbox(code=code)
result = self.run_local_dir_sandbox(agent_state=agent_state)
# Log out any stdout from the tool run
logger.debug(f"Executed tool '{self.tool_name}', logging stdout from tool run: \n")
@ -94,12 +100,14 @@ class ToolExecutionSandbox:
os.environ.clear()
os.environ.update(original_env) # Restore original environment variables
def run_local_dir_sandbox(self, code: str) -> Optional[SandboxRunResult]:
def run_local_dir_sandbox(self, agent_state: AgentState) -> Optional[SandboxRunResult]:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config()
# Get environment variables for the sandbox
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env = os.environ.copy()
env.update(env_vars)
# Safety checks
if not os.path.isdir(local_configs.sandbox_dir):
@ -107,6 +115,12 @@ class ToolExecutionSandbox:
# Write the code to a temp file in the sandbox_dir
with tempfile.NamedTemporaryFile(mode="w", dir=local_configs.sandbox_dir, suffix=".py", delete=False) as temp_file:
if local_configs.use_venv:
# If using venv, we need to wrap with special string markers to separate out the output and the stdout (since it is all in stdout)
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
else:
code = self.generate_execution_script(agent_state=agent_state)
temp_file.write(code)
temp_file.flush()
temp_file_path = temp_file.name
@ -114,39 +128,122 @@ class ToolExecutionSandbox:
# Save the old stdout
old_stdout = sys.stdout
try:
# Redirect stdout to capture script output
captured_stdout = io.StringIO()
sys.stdout = captured_stdout
# Execute the temp file
with self.temporary_env_vars(env_vars):
result = runpy.run_path(temp_file_path, init_globals=env_vars)
# Fetch the result
func_result = result.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME)
func_return, agent_state = self.parse_best_effort(func_result)
# Restore stdout and collect captured output
sys.stdout = old_stdout
stdout_output = captured_stdout.getvalue()
return SandboxRunResult(
func_return=func_return,
agent_state=agent_state,
stdout=[stdout_output],
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
if local_configs.use_venv:
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
else:
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path, old_stdout)
except Exception as e:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
raise e
finally:
# Clean up the temp file and restore stdout
sys.stdout = old_stdout
os.remove(temp_file_path)
def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
    """
    Execute the generated tool script with the sandbox directory's virtual environment.

    The venv is looked up at `<sandbox_dir>/<venv_name>`; when that directory is missing it is
    created via `create_venv_for_local_sandbox`. The script runs in a subprocess whose working
    directory is the sandbox directory, with a 60 second timeout, and the function result is
    recovered from the marker-delimited section of its stdout.

    Args:
        sbx_config (SandboxConfig): The sandbox config holding the local sandbox settings
        env (Dict[str, str]): Environment variables for the subprocess (not modified by this method)
        temp_file_path (str): Path to the generated script to execute

    Returns:
        result (SandboxRunResult): The function return, agent state, remaining stdout and config fingerprint

    Raises:
        FileNotFoundError: If the venv has no `bin/python3` interpreter
        TimeoutError: If the subprocess exceeds the timeout
        RuntimeError: If the subprocess exits with a non-zero status, or anything else goes wrong
    """
    local_configs = sbx_config.get_local_config()
    venv_path = os.path.join(local_configs.sandbox_dir, local_configs.venv_name)

    # Safety checks for the venv
    # Verify that the venv path exists and is a directory, otherwise build it on the fly
    if not os.path.isdir(venv_path):
        logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...")
        self.create_venv_for_local_sandbox(sandbox_dir_path=local_configs.sandbox_dir, venv_path=venv_path, env=env)

    # Ensure the python interpreter exists in the virtual environment
    venv_bin_dir = os.path.join(venv_path, "bin")
    python_executable = os.path.join(venv_bin_dir, "python3")
    if not os.path.isfile(python_executable):
        raise FileNotFoundError(f"Python executable not found in virtual environment: {python_executable}")

    # Set up env for venv on a copy, so the caller's mapping is left untouched
    venv_env = dict(env)
    venv_env["VIRTUAL_ENV"] = venv_path
    inherited_path = venv_env.get("PATH")  # PATH may legitimately be absent from the provided env
    venv_env["PATH"] = venv_bin_dir + os.pathsep + inherited_path if inherited_path else venv_bin_dir

    # Execute the code in a subprocess
    try:
        result = subprocess.run(
            [python_executable, temp_file_path],
            env=venv_env,
            cwd=local_configs.sandbox_dir,  # Run from within sandbox_dir
            timeout=60,
            capture_output=True,
            text=True,
        )
    except subprocess.TimeoutExpired as e:
        raise TimeoutError(f"Executing tool {self.tool_name} has timed out.") from e
    except Exception as e:
        raise RuntimeError(f"Executing tool {self.tool_name} has an unexpected error: {e}") from e

    # The exit status decides success: warnings and log lines also land on stderr,
    # so stderr content alone does not mean the tool failed
    if result.returncode != 0:
        logger.error(f"Sandbox execution error: {result.stderr}")
        raise RuntimeError(f"Sandbox execution error: {result.stderr}")
    if result.stderr:
        logger.warning(f"Tool {self.tool_name} wrote to stderr: {result.stderr}")

    try:
        func_result, stdout = self.parse_out_function_results_markers(result.stdout)
        func_return, agent_state = self.parse_best_effort(func_result)
    except Exception as e:
        raise RuntimeError(f"Executing tool {self.tool_name} has an unexpected error: {e}") from e

    return SandboxRunResult(
        func_return=func_return, agent_state=agent_state, stdout=[stdout], sandbox_config_fingerprint=sbx_config.fingerprint()
    )
def run_local_dir_sandbox_runpy(
    self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str, old_stdout: TextIO
) -> SandboxRunResult:
    """
    Execute the generated tool script in the current process via `runpy`.

    `sys.stdout` is swapped for an in-memory buffer while the script runs and is set back to
    `old_stdout` once the result has been parsed. The function result is read from the
    script's globals under `LOCAL_SANDBOX_RESULT_VAR_NAME`.

    Args:
        sbx_config (SandboxConfig): The sandbox config, used for the result fingerprint
        env_vars (Dict[str, str]): Sandbox environment variables, applied temporarily and passed as initial globals
        temp_file_path (str): Path to the generated script to execute
        old_stdout (TextIO): The stream to restore as `sys.stdout` afterwards

    Returns:
        result (SandboxRunResult): The function return, agent state, captured stdout and config fingerprint
    """
    # Everything the script prints from here on lands in this buffer
    output_buffer = io.StringIO()
    sys.stdout = output_buffer

    # Run the temp file with the sandbox env vars in place
    with self.temporary_env_vars(env_vars):
        script_globals = runpy.run_path(temp_file_path, init_globals=env_vars)

    # Pull the serialized result out of the script's globals and decode it
    func_return, agent_state = self.parse_best_effort(script_globals.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME))

    # Hand stdout back before building the response
    sys.stdout = old_stdout
    return SandboxRunResult(
        func_return=func_return,
        agent_state=agent_state,
        stdout=[output_buffer.getvalue()],
        sandbox_config_fingerprint=sbx_config.fingerprint(),
    )
def parse_out_function_results_markers(self, text: str):
    """
    Split subprocess stdout into the marker-delimited function result and everything else.

    Args:
        text (str): The full stdout of the sandboxed script

    Returns:
        A `(func_result, stdout)` tuple: the text between the start and end markers, and the
        original text with the markers and the enclosed result removed.

    Raises:
        ValueError: If the start marker, or an end marker following it, is not present in `text`
    """
    start_marker = self.LOCAL_SANDBOX_RESULT_START_MARKER
    end_marker = self.LOCAL_SANDBOX_RESULT_END_MARKER

    marker_start = text.find(start_marker)
    if marker_start == -1:
        raise ValueError("Could not find the result start marker in the sandbox output")
    result_start = marker_start + len(start_marker)

    # Only look for the end marker after the start marker, and skip it using its own length
    result_end = text.find(end_marker, result_start)
    if result_end == -1:
        raise ValueError("Could not find the result end marker in the sandbox output")

    remaining_stdout = text[:marker_start] + text[result_end + len(end_marker) :]
    return text[result_start:result_end], remaining_stdout
def create_venv_for_local_sandbox(self, sandbox_dir_path: str, venv_path: str, env: Dict[str, str]):
    """
    Build a virtual environment for the local sandbox and install its requirements.

    Creates the venv (with pip) at `venv_path`, upgrades pip, and then installs from the
    `requirements.txt` in `sandbox_dir_path` when that file exists.

    Args:
        sandbox_dir_path (str): The sandbox directory, searched for the requirements file
        venv_path (str): Where the virtual environment is created
        env (Dict[str, str]): Environment variables passed to the pip subprocesses

    Raises:
        RuntimeError: If any pip invocation exits with a non-zero status
    """
    # Create the environment itself; pip is bootstrapped into it
    venv.create(venv_path, with_pip=True)
    pip_executable = os.path.join(venv_path, "bin", "pip")

    def run_pip(*pip_args: str) -> None:
        # check=True turns a failing pip run into CalledProcessError
        subprocess.run([pip_executable, *pip_args], env=env, check=True)

    try:
        # Bring pip up to date first
        logger.info("Upgrading pip in the virtual environment...")
        run_pip("install", "--upgrade", "pip")

        # Then install the codebase's dependencies, if it declares any
        requirements_file = os.path.join(sandbox_dir_path, self.REQUIREMENT_TXT_NAME)
        if not os.path.isfile(requirements_file):
            logger.warning("No requirements.txt file provided or the file does not exist. Skipping package installation.")
        else:
            logger.info(f"Installing packages from requirements file: {requirements_file}")
            run_pip("install", "-r", requirements_file)
            logger.info("Successfully installed packages from requirements.txt")
    except subprocess.CalledProcessError as setup_error:
        logger.error(f"Error while setting up the virtual environment: {setup_error}")
        raise RuntimeError(f"Failed to set up the virtual environment: {setup_error}")
# e2b sandbox specific functions
def run_e2b_sandbox(self, code: str) -> Optional[SandboxRunResult]:
def run_e2b_sandbox(self, agent_state: AgentState) -> Optional[SandboxRunResult]:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
if not sbx or self.force_recreate:
@ -158,6 +255,7 @@ class ToolExecutionSandbox:
# Get environment variables for the sandbox
# TODO: We set limit to 100 here, but maybe we want it uncapped? Realistically this should be fine.
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
code = self.generate_execution_script(agent_state=agent_state)
execution = sbx.run_code(code, envs=env_vars)
if execution.error is not None:
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
@ -236,13 +334,14 @@ class ToolExecutionSandbox:
args.append(arg.arg)
return args
def generate_execution_script(self, agent_state: AgentState) -> str:
def generate_execution_script(self, agent_state: AgentState, wrap_print_with_markers: bool = False) -> str:
"""
Generate code to run inside of execution sandbox.
Passes into a serialized agent state into the code, to be accessed by the tool.
Args:
agent_state (AgentState): The agent state
wrap_print_with_markers (bool): If true, we wrap the final statement with a `print` and wrap with special markers
Returns:
code (str): The generated code strong
@ -286,7 +385,14 @@ class ToolExecutionSandbox:
code += (
f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME} = base64.b64encode(pickle.dumps({self.LOCAL_SANDBOX_RESULT_VAR_NAME})).decode('utf-8')\n"
)
code += f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME}\n"
if wrap_print_with_markers:
code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_START_MARKER}')\n"
code += f"sys.stdout.write(str({self.LOCAL_SANDBOX_RESULT_VAR_NAME}))\n"
code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_END_MARKER}')\n"
else:
code += f"{self.LOCAL_SANDBOX_RESULT_VAR_NAME}\n"
return code
def _convert_param_to_value(self, param_type: str, raw_value: str) -> str:

View File

@ -184,7 +184,7 @@ def composio_github_star_tool(test_user):
@pytest.fixture
def clear_core_memory(test_user):
def clear_core_memory_tool(test_user):
def clear_memory(agent_state: AgentState):
"""Clear the core memory"""
agent_state.memory.get_block("human").value = ""
@ -202,6 +202,17 @@ def core_memory_replace_tool(test_user):
yield tool
@pytest.fixture
def external_codebase_tool(test_user):
    """Yield a persisted tool built from `adjust_menu_prices` in the sample restaurant codebase."""
    # Imported lazily so the sample codebase is only touched by tests that request this fixture
    from tests.test_tool_sandbox.restaurant_management_system.adjust_menu_prices import (
        adjust_menu_prices,
    )

    unsaved_tool = create_tool_from_func(adjust_menu_prices)
    saved_tool = ToolManager().create_or_update_tool(unsaved_tool, test_user)
    yield saved_tool
@pytest.fixture
def agent_state():
client = create_client()
@ -214,6 +225,8 @@ def agent_state():
# Local sandbox tests
@pytest.mark.local_sandbox
def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_user):
args = {"x": 10, "y": 5}
@ -231,10 +244,10 @@ def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_us
@pytest.mark.local_sandbox
def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory, test_user, agent_state):
def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory_tool, test_user, agent_state):
args = {}
# Run again to get actual response
sandbox = ToolExecutionSandbox(clear_core_memory.name, args, user_id=test_user.id)
sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, args, user_id=test_user.id)
result = sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
assert result.agent_state.memory.get_block("persona").value == ""
@ -313,6 +326,29 @@ def test_local_sandbox_e2e_composio_star_github(mock_e2b_api_key_none, check_com
assert result.func_return["details"] == "Action executed successfully"
@pytest.mark.local_sandbox
def test_local_sandbox_external_codebase(mock_e2b_api_key_none, external_codebase_tool, test_user):
    """Run a tool from the sample external codebase inside that codebase's venv-backed local sandbox."""
    config_manager = SandboxConfigManager(tool_settings)

    # Point the local sandbox at the external codebase and switch on venv execution
    codebase_dir = Path(__file__).parent / "test_tool_sandbox" / "restaurant_management_system"
    venv_sandbox_config = LocalSandboxConfig(sandbox_dir=str(codebase_dir), use_venv=True)
    config_manager.create_or_update_sandbox_config(
        sandbox_config_create=SandboxConfigCreate(config=venv_sandbox_config.model_dump()),
        actor=test_user,
    )

    # Execute the tool with a 10 percent adjustment
    tool_args = {"percentage": 10}
    run_result = ToolExecutionSandbox(external_codebase_tool.name, tool_args, user_id=test_user.id).run()

    # The return value and the tool's own stdout must both come back
    expected_summary = "Price Adjustments:\nBurger: $8.99 -> $9.89\nFries: $2.99 -> $3.29\nSoda: $1.99 -> $2.19"
    assert run_result.func_return == expected_summary
    assert "Hello World" in run_result.stdout[0]
# E2B sandbox tests
@ -366,8 +402,8 @@ def test_e2b_sandbox_reuses_same_sandbox(check_e2b_key_is_set, list_tool, test_u
@pytest.mark.e2b_sandbox
def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory, test_user, agent_state):
sandbox = ToolExecutionSandbox(clear_core_memory.name, {}, user_id=test_user.id)
def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state):
sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, {}, user_id=test_user.id)
# run the sandbox
result = sandbox.run(agent_state=agent_state)

View File

@ -0,0 +1,33 @@
def adjust_menu_prices(percentage: float) -> str:
    """
    Tool: Adjust Menu Prices
    Description: Adjusts the prices of all menu items by a given percentage.
    Args:
        percentage (float): The percentage by which to adjust prices. Positive for an increase, negative for a decrease.
    Returns:
        str: A formatted string summarizing the price adjustments.
    """
    # Imports are kept inside the function body rather than at module level
    import cowsay
    from core.menu import Menu, MenuItem  # classes that live in the external codebase
    from core.utils import format_currency  # helper that lives in the external codebase

    if not isinstance(percentage, (int, float)):
        raise TypeError("percentage must be a number")

    # Build a small throwaway menu to operate on
    menu = Menu()
    for dish, base_price, course in (
        ("Burger", 8.99, "Main"),
        ("Fries", 2.99, "Side"),
        ("Soda", 1.99, "Drink"),
    ):
        menu.add_item(MenuItem(dish, base_price, course))

    # Apply the adjustment to every item, keeping a before/after line for each
    summary_lines = []
    for entry in menu.items:
        previous_price = entry.price
        entry.price += entry.price * (percentage / 100)
        summary_lines.append(f"{entry.name}: {format_currency(previous_price)} -> {format_currency(entry.price)}")

    # Emit something on stdout through a third-party package
    cowsay.cow("Hello World")

    return "Price Adjustments:\n" + "\n".join(summary_lines)

View File

@ -0,0 +1,7 @@
class Customer:
    """A customer identified by name, carrying a running loyalty point balance."""

    def __init__(self, name: str, loyalty_points: int = 0):
        # The balance starts at zero unless an initial value is given
        self.name, self.loyalty_points = name, loyalty_points

    def add_loyalty_points(self, points: int):
        """Add `points` to the customer's loyalty point balance."""
        self.loyalty_points += points

View File

@ -0,0 +1,26 @@
from typing import List
class MenuItem:
    """A single menu entry with a name, a price and a category."""

    def __init__(self, name: str, price: float, category: str):
        self.name, self.price, self.category = name, price, category

    def __repr__(self):
        # Rendered as "<name> ($<price to 2 decimals>) - <category>"
        return "{} (${:.2f}) - {}".format(self.name, self.price, self.category)
class Menu:
    """An ordered collection of menu items."""

    def __init__(self):
        self.items: List[MenuItem] = []

    def add_item(self, item: MenuItem):
        """Append `item` to the end of the menu."""
        self.items.append(item)

    def update_price(self, name: str, new_price: float):
        """
        Set the price of the first item called `name` to `new_price`.

        Raises:
            ValueError: If no item on the menu has that name
        """
        target = next((entry for entry in self.items if entry.name == name), None)
        if target is None:
            raise ValueError(f"Menu item '{name}' not found.")
        target.price = new_price

View File

@ -0,0 +1,16 @@
from typing import Dict
class Order:
    """An order placed by a customer, stored as item names mapped to quantities."""

    def __init__(self, customer_name: str, items: Dict[str, int]):
        self.customer_name = customer_name
        # Maps each ordered item's name to the quantity requested
        self.items = items

    def calculate_total(self, menu):
        """
        Sum price times quantity for every ordered item, using prices from `menu`.

        Raises:
            ValueError: If an ordered item name does not appear in `menu.items`
        """
        running_total = 0
        for wanted_name, quantity in self.items.items():
            for entry in menu.items:
                if entry.name == wanted_name:
                    # The first menu item with a matching name supplies the price
                    running_total += entry.price * quantity
                    break
            else:
                raise ValueError(f"Menu item '{wanted_name}' not found.")
        return running_total

View File

@ -0,0 +1,2 @@
def format_currency(value: float) -> str:
    """Return `value` as a dollar amount with exactly two decimal places, e.g. "$8.99"."""
    return "$" + format(value, ".2f")

View File

@ -0,0 +1 @@
cowsay

View File

@ -0,0 +1,25 @@
import os
import runpy
def generate_and_execute_tool(tool_name: str, args: dict, *, cleanup: bool = False):
    """
    Generate a small runner script for a tool and execute it as `__main__`.

    The script is written to `tools/<tool_name>_execution.py` next to this file. It imports
    `<tool_name>` from `restaurant_management_system.tools.<tool_name>`, calls it with `args`
    as keyword arguments, and prints the result.

    Args:
        tool_name (str): Name of the tool module and of the function inside it
        args (dict): Keyword arguments for the tool; values are embedded via `repr`
        cleanup (bool): If true, delete the generated script after it has run

    Raises:
        ValueError: If `tool_name` or any key of `args` is not a valid Python identifier
    """
    # These names are pasted into generated source code, so anything that is not a plain
    # identifier is rejected up front instead of being executed
    invalid_names = [name for name in (tool_name, *args) if not (isinstance(name, str) and name.isidentifier())]
    if invalid_names:
        raise ValueError(f"Not valid Python identifiers: {invalid_names!r}")

    # Define the tool's directory and file
    tools_dir = os.path.join(os.path.dirname(__file__), "tools")
    script_path = os.path.join(tools_dir, f"{tool_name}_execution.py")

    # Generate the Python script
    arg_str = ", ".join(f"{key}={value!r}" for key, value in args.items())
    script_source = (
        f"from restaurant_management_system.tools.{tool_name} import {tool_name}\n\n"
        "if __name__ == '__main__':\n"
        f" result = {tool_name}({arg_str})\n"
        " print(result)\n"
    )
    # Python reads source files as UTF-8 by default, so write the script the same way
    with open(script_path, "w", encoding="utf-8") as script_file:
        script_file.write(script_source)

    # Execute the script
    try:
        runpy.run_path(script_path, run_name="__main__")
    finally:
        # Optional: clean up the generated script, even when the tool raised
        if cleanup:
            os.remove(script_path)
# Run the adjust_menu_prices tool with a 10 percent adjustment.
# NOTE: there is no `__main__` guard, so this also runs whenever the module is imported.
generate_and_execute_tool("adjust_menu_prices", {"percentage": 10})