feat: Factor out database access unless necessary in Tool Sandboxes (#1711)

This commit is contained in:
Matthew Zhou 2025-04-14 18:00:15 -07:00 committed by GitHub
parent f207c814ce
commit c5b6fb5744
5 changed files with 133 additions and 27 deletions

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Tuple, Type
from letta.log import get_logger from letta.log import get_logger
from letta.orm.enums import ToolType from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
from letta.schemas.user import User from letta.schemas.user import User
from letta.services.tool_executor.tool_executor import ( from letta.services.tool_executor.tool_executor import (
@ -45,10 +45,18 @@ class ToolExecutorFactory:
class ToolExecutionManager: class ToolExecutionManager:
"""Manager class for tool execution operations.""" """Manager class for tool execution operations."""
def __init__(self, agent_state: AgentState, actor: User): def __init__(
self,
agent_state: AgentState,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
self.agent_state = agent_state self.agent_state = agent_state
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.actor = actor self.actor = actor
self.sandbox_config = sandbox_config
self.sandbox_env_vars = sandbox_env_vars
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]: def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
""" """
@ -67,7 +75,9 @@ class ToolExecutionManager:
executor = ToolExecutorFactory.get_executor(tool.tool_type) executor = ToolExecutorFactory.get_executor(tool.tool_type)
# Execute the tool # Execute the tool
return executor.execute(function_name, function_args, self.agent_state, tool, self.actor) return executor.execute(
function_name, function_args, self.agent_state, tool, self.actor, self.sandbox_config, self.sandbox_env_vars
)
except Exception as e: except Exception as e:
self.logger.error(f"Error executing tool {function_name}: {str(e)}") self.logger.error(f"Error executing tool {function_name}: {str(e)}")

View File

@ -1,6 +1,6 @@
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
@ -8,7 +8,7 @@ from letta.functions.helpers import execute_composio_action, generate_composio_a
from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.json_helpers import json_dumps from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
from letta.schemas.user import User from letta.schemas.user import User
from letta.services.agent_manager import AgentManager from letta.services.agent_manager import AgentManager
@ -25,7 +25,14 @@ class ToolExecutor(ABC):
@abstractmethod @abstractmethod
def execute( def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]: ) -> Tuple[Any, Optional[SandboxRunResult]]:
"""Execute the tool and return the result.""" """Execute the tool and return the result."""
@ -34,7 +41,14 @@ class LettaCoreToolExecutor(ToolExecutor):
"""Executor for LETTA core tools with direct implementation of functions.""" """Executor for LETTA core tools with direct implementation of functions."""
def execute( def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]: ) -> Tuple[Any, Optional[SandboxRunResult]]:
# Map function names to method calls # Map function names to method calls
function_map = { function_map = {
@ -184,7 +198,14 @@ class LettaMemoryToolExecutor(ToolExecutor):
"""Executor for LETTA memory core tools with direct implementation.""" """Executor for LETTA memory core tools with direct implementation."""
def execute( def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]: ) -> Tuple[Any, Optional[SandboxRunResult]]:
# Map function names to method calls # Map function names to method calls
function_map = { function_map = {
@ -244,7 +265,14 @@ class ExternalComposioToolExecutor(ToolExecutor):
"""Executor for external Composio tools.""" """Executor for external Composio tools."""
def execute( def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]: ) -> Tuple[Any, Optional[SandboxRunResult]]:
action_name = generate_composio_action_from_func_name(tool.name) action_name = generate_composio_action_from_func_name(tool.name)
@ -324,7 +352,14 @@ class SandboxToolExecutor(ToolExecutor):
"""Executor for sandboxed tools.""" """Executor for sandboxed tools."""
async def execute( async def execute(
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Optional[SandboxRunResult]]: ) -> Tuple[Any, Optional[SandboxRunResult]]:
# Store original memory state # Store original memory state
@ -338,9 +373,13 @@ class SandboxToolExecutor(ToolExecutor):
# Execute in sandbox depending on API key # Execute in sandbox depending on API key
if tool_settings.e2b_api_key: if tool_settings.e2b_api_key:
sandbox = AsyncToolSandboxE2B(function_name, function_args, actor, tool_object=tool) sandbox = AsyncToolSandboxE2B(
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
)
else: else:
sandbox = AsyncToolSandboxLocal(function_name, function_args, actor, tool_object=tool) sandbox = AsyncToolSandboxLocal(
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
)
sandbox_run_result = await sandbox.run(agent_state=agent_state_copy) sandbox_run_result = await sandbox.run(agent_state=agent_state_copy)

View File

@ -7,7 +7,8 @@ from typing import Any, Dict, Optional, Tuple
from letta.functions.helpers import generate_model_from_args_json_schema from letta.functions.helpers import generate_model_from_args_json_schema
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
from letta.schemas.tool import Tool
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
from letta.services.organization_manager import OrganizationManager from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.sandbox_config_manager import SandboxConfigManager
@ -20,7 +21,15 @@ class AsyncToolSandboxBase(ABC):
LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker")) LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker"))
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt" LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
def __init__(self, tool_name: str, args: dict, user, tool_object=None): def __init__(
self,
tool_name: str,
args: dict,
user,
tool_object: Optional[Tool] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
self.tool_name = tool_name self.tool_name = tool_name
self.args = args self.args = args
self.user = user self.user = user
@ -33,7 +42,12 @@ class AsyncToolSandboxBase(ABC):
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}" f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
) )
self.sandbox_config_manager = SandboxConfigManager() # Store provided values or create manager to fetch them later
self.provided_sandbox_config = sandbox_config
self.provided_sandbox_env_vars = sandbox_env_vars
# Only create the manager if we need to (lazy initialization)
self._sandbox_config_manager = None
# See if we should inject agent_state or not based on the presence of the "agent_state" arg # See if we should inject agent_state or not based on the presence of the "agent_state" arg
if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name): if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name):
@ -41,6 +55,13 @@ class AsyncToolSandboxBase(ABC):
else: else:
self.inject_agent_state = False self.inject_agent_state = False
# Lazily initialize the manager only when needed
@property
def sandbox_config_manager(self):
if self._sandbox_config_manager is None:
self._sandbox_config_manager = SandboxConfigManager()
return self._sandbox_config_manager
@abstractmethod @abstractmethod
async def run( async def run(
self, self,

View File

@ -1,8 +1,9 @@
from typing import Dict, Optional from typing import Any, Dict, Optional
from letta.log import get_logger from letta.log import get_logger
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.tool_sandbox.base import AsyncToolSandboxBase from letta.services.tool_sandbox.base import AsyncToolSandboxBase
from letta.utils import get_friendly_error_msg from letta.utils import get_friendly_error_msg
@ -12,8 +13,17 @@ logger = get_logger(__name__)
class AsyncToolSandboxE2B(AsyncToolSandboxBase): class AsyncToolSandboxE2B(AsyncToolSandboxBase):
METADATA_CONFIG_STATE_KEY = "config_state" METADATA_CONFIG_STATE_KEY = "config_state"
def __init__(self, tool_name: str, args: dict, user, force_recreate=True, tool_object=None): def __init__(
super().__init__(tool_name, args, user, tool_object) self,
tool_name: str,
args: dict,
user,
force_recreate=True,
tool_object: Optional[Tool] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
self.force_recreate = force_recreate self.force_recreate = force_recreate
async def run( async def run(
@ -36,6 +46,9 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
async def run_e2b_sandbox( async def run_e2b_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> SandboxRunResult: ) -> SandboxRunResult:
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user) sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
# TODO: So this defaults to force recreating always # TODO: So this defaults to force recreating always
# TODO: Eventually, provision one sandbox PER agent, and that agent re-uses that one specifically # TODO: Eventually, provision one sandbox PER agent, and that agent re-uses that one specifically
@ -50,7 +63,14 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
# Get environment variables for the sandbox # Get environment variables for the sandbox
# TODO: We set limit to 100 here, but maybe we want it uncapped? Realistically this should be fine. # TODO: We set limit to 100 here, but maybe we want it uncapped? Realistically this should be fine.
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) env_vars = {}
if self.provided_sandbox_env_vars:
env_vars.update(self.provided_sandbox_env_vars)
else:
db_env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(
sandbox_config_id=sbx_config.id, actor=self.user, limit=100
)
env_vars.update(db_env_vars)
# Get environment variables for this agent specifically # Get environment variables for this agent specifically
if agent_state: if agent_state:
env_vars.update(agent_state.get_agent_env_vars_as_dict()) env_vars.update(agent_state.get_agent_env_vars_as_dict())

View File

@ -2,10 +2,11 @@ import asyncio
import os import os
import sys import sys
import tempfile import tempfile
from typing import Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from letta.schemas.agent import AgentState from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult, SandboxType from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.helpers.tool_execution_helper import ( from letta.services.helpers.tool_execution_helper import (
create_venv_for_local_sandbox, create_venv_for_local_sandbox,
find_python_executable, find_python_executable,
@ -21,8 +22,17 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
METADATA_CONFIG_STATE_KEY = "config_state" METADATA_CONFIG_STATE_KEY = "config_state"
REQUIREMENT_TXT_NAME = "requirements.txt" REQUIREMENT_TXT_NAME = "requirements.txt"
def __init__(self, tool_name: str, args: dict, user, force_recreate_venv=False, tool_object=None): def __init__(
super().__init__(tool_name, args, user, tool_object) self,
tool_name: str,
args: dict,
user,
force_recreate_venv=False,
tool_object: Optional[Tool] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
self.force_recreate_venv = force_recreate_venv self.force_recreate_venv = force_recreate_venv
async def run( async def run(
@ -49,12 +59,18 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
always via subprocess for multi-core parallelism. always via subprocess for multi-core parallelism.
""" """
# Get sandbox configuration # Get sandbox configuration
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user) sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config() local_configs = sbx_config.get_local_config()
use_venv = local_configs.use_venv use_venv = local_configs.use_venv
# Prepare environment variables # Prepare environment variables
env = os.environ.copy() env = os.environ.copy()
if self.provided_sandbox_env_vars:
env.update(self.provided_sandbox_env_vars)
else:
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env.update(env_vars) env.update(env_vars)