mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Factor out database access unless necessary in Tool Sandboxes (#1711)
This commit is contained in:
parent
f207c814ce
commit
c5b6fb5744
@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Tuple, Type
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.services.tool_executor.tool_executor import (
|
||||
@ -45,10 +45,18 @@ class ToolExecutorFactory:
|
||||
class ToolExecutionManager:
|
||||
"""Manager class for tool execution operations."""
|
||||
|
||||
def __init__(self, agent_state: AgentState, actor: User):
|
||||
def __init__(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.agent_state = agent_state
|
||||
self.logger = get_logger(__name__)
|
||||
self.actor = actor
|
||||
self.sandbox_config = sandbox_config
|
||||
self.sandbox_env_vars = sandbox_env_vars
|
||||
|
||||
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""
|
||||
@ -67,7 +75,9 @@ class ToolExecutionManager:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
|
||||
# Execute the tool
|
||||
return executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
return executor.execute(
|
||||
function_name, function_args, self.agent_state, tool, self.actor, self.sandbox_config, self.sandbox_env_vars
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
|
@ -1,6 +1,6 @@
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
@ -8,7 +8,7 @@ from letta.functions.helpers import execute_composio_action, generate_composio_a
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@ -25,7 +25,14 @@ class ToolExecutor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""Execute the tool and return the result."""
|
||||
|
||||
@ -34,7 +41,14 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA core tools with direct implementation of functions."""
|
||||
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
@ -184,7 +198,14 @@ class LettaMemoryToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA memory core tools with direct implementation."""
|
||||
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
@ -244,7 +265,14 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
"""Executor for external Composio tools."""
|
||||
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
action_name = generate_composio_action_from_func_name(tool.name)
|
||||
|
||||
@ -324,7 +352,14 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
"""Executor for sandboxed tools."""
|
||||
|
||||
async def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
|
||||
# Store original memory state
|
||||
@ -338,9 +373,13 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
# Execute in sandbox depending on API key
|
||||
if tool_settings.e2b_api_key:
|
||||
sandbox = AsyncToolSandboxE2B(function_name, function_args, actor, tool_object=tool)
|
||||
sandbox = AsyncToolSandboxE2B(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
else:
|
||||
sandbox = AsyncToolSandboxLocal(function_name, function_args, actor, tool_object=tool)
|
||||
sandbox = AsyncToolSandboxLocal(
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
|
||||
sandbox_run_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
|
@ -7,7 +7,8 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
@ -20,7 +21,15 @@ class AsyncToolSandboxBase(ABC):
|
||||
LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker"))
|
||||
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"
|
||||
|
||||
def __init__(self, tool_name: str, args: dict, user, tool_object=None):
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict,
|
||||
user,
|
||||
tool_object: Optional[Tool] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.tool_name = tool_name
|
||||
self.args = args
|
||||
self.user = user
|
||||
@ -33,7 +42,12 @@ class AsyncToolSandboxBase(ABC):
|
||||
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
|
||||
)
|
||||
|
||||
self.sandbox_config_manager = SandboxConfigManager()
|
||||
# Store provided values or create manager to fetch them later
|
||||
self.provided_sandbox_config = sandbox_config
|
||||
self.provided_sandbox_env_vars = sandbox_env_vars
|
||||
|
||||
# Only create the manager if we need to (lazy initialization)
|
||||
self._sandbox_config_manager = None
|
||||
|
||||
# See if we should inject agent_state or not based on the presence of the "agent_state" arg
|
||||
if "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name):
|
||||
@ -41,6 +55,13 @@ class AsyncToolSandboxBase(ABC):
|
||||
else:
|
||||
self.inject_agent_state = False
|
||||
|
||||
# Lazily initialize the manager only when needed
|
||||
@property
|
||||
def sandbox_config_manager(self):
|
||||
if self._sandbox_config_manager is None:
|
||||
self._sandbox_config_manager = SandboxConfigManager()
|
||||
return self._sandbox_config_manager
|
||||
|
||||
@abstractmethod
|
||||
async def run(
|
||||
self,
|
||||
|
@ -1,8 +1,9 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@ -12,8 +13,17 @@ logger = get_logger(__name__)
|
||||
class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
METADATA_CONFIG_STATE_KEY = "config_state"
|
||||
|
||||
def __init__(self, tool_name: str, args: dict, user, force_recreate=True, tool_object=None):
|
||||
super().__init__(tool_name, args, user, tool_object)
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict,
|
||||
user,
|
||||
force_recreate=True,
|
||||
tool_object: Optional[Tool] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
|
||||
self.force_recreate = force_recreate
|
||||
|
||||
async def run(
|
||||
@ -36,7 +46,10 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
async def run_e2b_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> SandboxRunResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
||||
if self.provided_sandbox_config:
|
||||
sbx_config = self.provided_sandbox_config
|
||||
else:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
||||
# TODO: So this defaults to force recreating always
|
||||
# TODO: Eventually, provision one sandbox PER agent, and that agent re-uses that one specifically
|
||||
e2b_sandbox = await self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
|
||||
@ -50,7 +63,14 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
|
||||
# Get environment variables for the sandbox
|
||||
# TODO: We set limit to 100 here, but maybe we want it uncapped? Realistically this should be fine.
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
env_vars = {}
|
||||
if self.provided_sandbox_env_vars:
|
||||
env_vars.update(self.provided_sandbox_env_vars)
|
||||
else:
|
||||
db_env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(
|
||||
sandbox_config_id=sbx_config.id, actor=self.user, limit=100
|
||||
)
|
||||
env_vars.update(db_env_vars)
|
||||
# Get environment variables for this agent specifically
|
||||
if agent_state:
|
||||
env_vars.update(agent_state.get_agent_env_vars_as_dict())
|
||||
|
@ -2,10 +2,11 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.services.helpers.tool_execution_helper import (
|
||||
create_venv_for_local_sandbox,
|
||||
find_python_executable,
|
||||
@ -21,8 +22,17 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
METADATA_CONFIG_STATE_KEY = "config_state"
|
||||
REQUIREMENT_TXT_NAME = "requirements.txt"
|
||||
|
||||
def __init__(self, tool_name: str, args: dict, user, force_recreate_venv=False, tool_object=None):
|
||||
super().__init__(tool_name, args, user, tool_object)
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict,
|
||||
user,
|
||||
force_recreate_venv=False,
|
||||
tool_object: Optional[Tool] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(tool_name, args, user, tool_object, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars)
|
||||
self.force_recreate_venv = force_recreate_venv
|
||||
|
||||
async def run(
|
||||
@ -49,14 +59,20 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
always via subprocess for multi-core parallelism.
|
||||
"""
|
||||
# Get sandbox configuration
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
|
||||
if self.provided_sandbox_config:
|
||||
sbx_config = self.provided_sandbox_config
|
||||
else:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
|
||||
local_configs = sbx_config.get_local_config()
|
||||
use_venv = local_configs.use_venv
|
||||
|
||||
# Prepare environment variables
|
||||
env = os.environ.copy()
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
env.update(env_vars)
|
||||
if self.provided_sandbox_env_vars:
|
||||
env.update(self.provided_sandbox_env_vars)
|
||||
else:
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
env.update(env_vars)
|
||||
|
||||
if agent_state:
|
||||
env.update(agent_state.get_agent_env_vars_as_dict())
|
||||
|
Loading…
Reference in New Issue
Block a user