feat: Factor database access out of Tool Sandboxes unless it is necessary (#1711)

This commit is contained in:
Matthew Zhou 2025-04-14 18:00:15 -07:00 committed by GitHub
parent f207c814ce
commit c5b6fb5744
5 changed files with 133 additions and 27 deletions

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Tuple, Type
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.tool_executor.tool_executor import (
@ -45,10 +45,18 @@ class ToolExecutorFactory:
class ToolExecutionManager:
"""Manager class for tool execution operations."""
def __init__(self, agent_state: AgentState, actor: User):
def __init__(
    self,
    agent_state: AgentState,
    actor: User,
    sandbox_config: Optional[SandboxConfig] = None,
    sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
    """Bind the manager to an agent, an acting user, and optional sandbox overrides.

    Args:
        agent_state: State of the agent on whose behalf tools are executed.
        actor: User performing the execution.
        sandbox_config: Optional sandbox configuration; kept as-is and handed
            to the executor on every ``execute_tool`` call.
        sandbox_env_vars: Optional environment variables; kept as-is and handed
            to the executor on every ``execute_tool`` call.
    """
    # Caller-supplied sandbox overrides (both may be None).
    self.sandbox_config, self.sandbox_env_vars = sandbox_config, sandbox_env_vars

    # Execution context.
    self.agent_state, self.actor = agent_state, actor

    # Logger named after this module.
    self.logger = get_logger(__name__)
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
"""
@ -67,7 +75,9 @@ class ToolExecutionManager:
executor = ToolExecutorFactory.get_executor(tool.tool_type)
# Execute the tool
return executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
return executor.execute(
function_name, function_args, self.agent_state, tool, self.actor, self.sandbox_config, self.sandbox_env_vars
)
except Exception as e:
self.logger.error(f"Error executing tool {function_name}: {str(e)}")

View File

@ -1,6 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
@ -8,7 +8,7 @@ from letta.functions.helpers import execute_composio_action, generate_composio_a
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@ -25,7 +25,14 @@ class ToolExecutor(ABC):
@abstractmethod
def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
"""Execute the tool and return the result."""
@ -34,7 +41,14 @@ class LettaCoreToolExecutor(ToolExecutor):
"""Executor for LETTA core tools with direct implementation of functions."""
def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
# Map function names to method calls
function_map = {
@ -184,7 +198,14 @@ class LettaMemoryToolExecutor(ToolExecutor):
"""Executor for LETTA memory core tools with direct implementation."""
def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
# Map function names to method calls
function_map = {
@ -244,7 +265,14 @@ class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools."""
def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
action_name = generate_composio_action_from_func_name(tool.name)
@ -324,7 +352,14 @@ class SandboxToolExecutor(ToolExecutor):
"""Executor for sandboxed tools."""
async def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]:
# Store original memory state
@ -338,9 +373,13 @@ class SandboxToolExecutor(ToolExecutor):
# Execute in sandbox depending on API key
if tool_settings.e2b_api_key:
sandbox = AsyncToolSandboxE2B(function_name, function_args, actor, tool_object=tool)
sandbox = AsyncToolSandboxE2B(
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
)
else:
sandbox = AsyncToolSandboxLocal(function_name, function_args, actor, tool_object=tool)
sandbox = AsyncToolSandboxLocal(
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
)
sandbox_run_result = await sandbox.run(agent_state=agent_state_copy)

View File

@ -7,7 +7,8 @@ from typing import Any, Dict, Optional, Tuple
from letta.functions.helpers import generate_model_from_args_json_schema
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.tool import Tool
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
@ -20,7 +21,15 @@ class AsyncToolSandboxBase(ABC):
LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker"))
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
def __init__(self, tool_name: str, args: dict, user, tool_object=None):
def __init__(
self,
tool_name: str,
args: dict,
user,
tool_object: Optional[Tool] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
self.tool_name = tool_name
self.args = args
self.user = user
@ -33,7 +42,12 @@ class AsyncToolSandboxBase(ABC):
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)
self.sandbox_config_manager = SandboxConfigManager()
# Store provided values or create manager to fetch them later
self.provided_sandbox_config = sandbox_config
self.provided_sandbox_env_vars = sandbox_env_vars
# Only create the manager if we need to (lazy initialization)
self._sandbox_config_manager = None
# See if we should inject agent_state or not based on the presence of the "agent_state" arg
if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name):
@ -41,6 +55,13 @@ class AsyncToolSandboxBase(ABC):
else:
self.inject_agent_state = False
# Lazily initialize the manager only when needed
@property
def sandbox_config_manager(self):
    """Return the sandbox config manager, constructing it on first access.

    The instance is cached in ``self._sandbox_config_manager``; later reads
    return that same object without constructing another manager.
    """
    manager = self._sandbox_config_manager
    if manager is None:
        # First access: build the manager and remember it for later calls.
        manager = SandboxConfigManager()
        self._sandbox_config_manager = manager
    return manager
@abstractmethod
async def run(
self,

View File

@ -1,8 +1,9 @@
from typing import Dict, Optional
from typing import Any, Dict, Optional
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.utils import get_friendly_error_msg
@ -12,8 +13,17 @@ logger = get_logger(__name__)
class AsyncToolSandboxE2B(AsyncToolSandboxBase):
METADATA_CONFIG_STATE_KEY = "config_state"
def __init__(self, tool_name: str, args: dict, user, force_recreate=True, tool_object=None):
super().__init__(tool_name, args, user, tool_object)
def __init__(
    self,
    tool_name: str,
    args: dict,
    user,
    force_recreate=True,
    tool_object: Optional[Tool] = None,
    sandbox_config: Optional[SandboxConfig] = None,
    sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
    """Set up an E2B-backed sandbox for a single tool invocation.

    Args:
        tool_name: Name of the tool to run.
        args: Arguments for the tool call.
        user: Acting user; passed through to the base class.
        force_recreate: Stored on the instance as ``self.force_recreate``.
        tool_object: Optional tool passed through to the base class.
        sandbox_config: Optional pre-fetched sandbox config. When truthy,
            ``run_e2b_sandbox`` uses it instead of asking the sandbox config
            manager for the default E2B config.
        sandbox_env_vars: Optional pre-fetched env vars. When truthy,
            ``run_e2b_sandbox`` uses them instead of loading env vars through
            the sandbox config manager; a falsy value (``None`` or an empty
            dict) still triggers the manager lookup.
    """
    # Both overrides are forwarded by keyword to the shared base initializer.
    overrides = {
        "sandbox_config": sandbox_config,
        "sandbox_env_vars": sandbox_env_vars,
    }
    super().__init__(tool_name, args, user, tool_object, **overrides)
    self.force_recreate = force_recreate
async def run(
@ -36,7 +46,10 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
async def run_e2b_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> SandboxRunResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
# TODO: So this defaults to force recreating always
# TODO: Eventually, provision one sandbox PER agent, and that agent re-uses that one specifically
e2b_sandbox = await self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
@ -50,7 +63,14 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
# Get environment variables for the sandbox
# TODO: We set limit to 100 here, but maybe we want it uncapped? Realistically this should be fine.
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env_vars = {}
if self.provided_sandbox_env_vars:
env_vars.update(self.provided_sandbox_env_vars)
else:
db_env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(
sandbox_config_id=sbx_config.id, actor=self.user, limit=100
)
env_vars.update(db_env_vars)
# Get environment variables for this agent specifically
if agent_state:
env_vars.update(agent_state.get_agent_env_vars_as_dict())

View File

@ -2,10 +2,11 @@ import asyncio
import os
import sys
import tempfile
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult, SandboxType
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.helpers.tool_execution_helper import (
create_venv_for_local_sandbox,
find_python_executable,
@ -21,8 +22,17 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
METADATA_CONFIG_STATE_KEY = "config_state"
REQUIREMENT_TXT_NAME = "requirements.txt"
def __init__(self, tool_name: str, args: dict, user, force_recreate_venv=False, tool_object=None):
super().__init__(tool_name, args, user, tool_object)
def __init__(
    self,
    tool_name: str,
    args: dict,
    user,
    force_recreate_venv=False,
    tool_object: Optional[Tool] = None,
    sandbox_config: Optional[SandboxConfig] = None,
    sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
    """Set up a local sandbox for a single tool invocation.

    Args:
        tool_name: Name of the tool to run.
        args: Arguments for the tool call.
        user: Acting user; passed through to the base class.
        force_recreate_venv: Stored on the instance as
            ``self.force_recreate_venv``.
        tool_object: Optional tool passed through to the base class.
        sandbox_config: Optional pre-fetched sandbox config. When truthy, the
            run path uses it instead of asking the sandbox config manager for
            the default LOCAL config.
        sandbox_env_vars: Optional pre-fetched env vars. When truthy, they are
            merged into the process environment instead of env vars loaded
            through the sandbox config manager; a falsy value (``None`` or an
            empty dict) still triggers the manager lookup.
    """
    # Both overrides are forwarded by keyword to the shared base initializer.
    overrides = {
        "sandbox_config": sandbox_config,
        "sandbox_env_vars": sandbox_env_vars,
    }
    super().__init__(tool_name, args, user, tool_object, **overrides)
    self.force_recreate_venv = force_recreate_venv
async def run(
@ -49,14 +59,20 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
always via subprocess for multi-core parallelism.
"""
# Get sandbox configuration
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config()
use_venv = local_configs.use_venv
# Prepare environment variables
env = os.environ.copy()
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env.update(env_vars)
if self.provided_sandbox_env_vars:
env.update(self.provided_sandbox_env_vars)
else:
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env.update(env_vars)
if agent_state:
env.update(agent_state.get_agent_env_vars_as_dict())