feat: Extend tool runs to also take in environment variables (#554)

This commit is contained in:
Matthew Zhou 2025-01-08 14:16:19 -10:00 committed by GitHub
parent 6db73894ea
commit 71031e9cda
5 changed files with 58 additions and 29 deletions

View File

@ -212,13 +212,9 @@ class ToolUpdate(LettaBase):
# TODO: Remove this, and clean usage of ToolUpdate everywhere else
class ToolRun(LettaBase):
id: str = Field(..., description="The ID of the tool to run.")
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
class ToolRunFromSource(LettaBase):
source_code: str = Field(..., description="The source code of the function.")
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
args: Dict[str, str] = Field(..., description="The arguments to pass to the tool.")
env_vars: Dict[str, str] = Field(None, description="The environment variables to pass to the tool.")
name: Optional[str] = Field(None, description="The name of the tool to run.")
source_type: Optional[str] = Field(None, description="The type of the source code.")

View File

@ -181,6 +181,7 @@ def run_tool_from_source(
tool_source=request.source_code,
tool_source_type=request.source_type,
tool_args=request.args,
tool_env_vars=request.env_vars,
tool_name=request.name,
actor=actor,
)

View File

@ -1,11 +1,10 @@
# inspecting tools
import json
import os
import traceback
import warnings
from abc import abstractmethod
from datetime import datetime
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from composio.client import Composio
from composio.client.collections import ActionModel, AppModel
@ -1117,22 +1116,17 @@ class SyncServer(Server):
def run_tool_from_source(
self,
actor: User,
tool_args: str,
tool_args: Dict[str, str],
tool_source: str,
tool_env_vars: Optional[Dict[str, str]] = None,
tool_source_type: Optional[str] = None,
tool_name: Optional[str] = None,
) -> ToolReturnMessage:
"""Run a tool from source code"""
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
raise ValueError("Invalid JSON string for tool_args")
if tool_source_type is not None and tool_source_type != "python":
raise ValueError("Only Python source code is supported at this time")
# NOTE: we're creating a floating Tool object and NOT persiting to DB
# NOTE: we're creating a floating Tool object and NOT persisting to DB
tool = Tool(
name=tool_name,
source_code=tool_source,
@ -1144,7 +1138,9 @@ class SyncServer(Server):
# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
agent_state=agent_state, additional_env_vars=tool_env_vars
)
return ToolReturnMessage(
id="null",
tool_call_id="null",

View File

@ -59,22 +59,23 @@ class ToolExecutionSandbox:
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.force_recreate = force_recreate
def run(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
"""
Run the tool in a sandbox environment.
Args:
agent_state (Optional[AgentState]): The state of the agent invoking the tool
additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox
Returns:
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
"""
if tool_settings.e2b_api_key:
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
result = self.run_e2b_sandbox(agent_state=agent_state)
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
else:
logger.debug(f"Using local sandbox to execute {self.tool_name}")
result = self.run_local_dir_sandbox(agent_state=agent_state)
result = self.run_local_dir_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
# Log out any stdout/stderr from the tool run
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
@ -98,19 +99,25 @@ class ToolExecutionSandbox:
os.environ.clear()
os.environ.update(original_env) # Restore original environment variables
def run_local_dir_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
def run_local_dir_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> SandboxRunResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config()
# Get environment variables for the sandbox
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env = os.environ.copy()
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env.update(env_vars)
# Get environment variables for this agent specifically
if agent_state:
env.update(agent_state.get_agent_env_vars_as_dict())
# Finally, get any that are passed explicitly into the `run` function call
if additional_env_vars:
env.update(additional_env_vars)
# Safety checks
if not os.path.isdir(local_configs.sandbox_dir):
raise FileNotFoundError(f"Sandbox directory does not exist: {local_configs.sandbox_dir}")
@ -277,7 +284,7 @@ class ToolExecutionSandbox:
# e2b sandbox specific functions
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
if not sbx or self.force_recreate:
@ -300,6 +307,10 @@ class ToolExecutionSandbox:
if agent_state:
env_vars.update(agent_state.get_agent_env_vars_as_dict())
# Finally, get any that are passed explicitly into the `run` function call
if additional_env_vars:
env_vars.update(additional_env_vars)
code = self.generate_execution_script(agent_state=agent_state)
execution = sbx.run_code(code, envs=env_vars)

View File

@ -687,6 +687,18 @@ def ingest(message: str):
'''
EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR = '''
def ingest():
"""
Ingest a message into the system.
Returns:
str: The result of ingesting the message.
"""
import os
return os.getenv("secret")
'''
EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = '''
def util_do_nothing():
@ -721,7 +733,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"message": "Hello, world!"}),
tool_args={"message": "Hello, world!"},
# tool_name="ingest",
)
print(result)
@ -730,11 +742,24 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
assert not result.stdout
assert not result.stderr
result = server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR,
tool_source_type="python",
tool_args={},
tool_env_vars={"secret": "banana"},
)
print(result)
assert result.status == "success"
assert result.tool_return == "banana", result.tool_return
assert not result.stdout
assert not result.stderr
result = server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
tool_args={"message": "Well well well"},
# tool_name="ingest",
)
print(result)
@ -747,7 +772,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"bad_arg": "oh no"}),
tool_args={"bad_arg": "oh no"},
# tool_name="ingest",
)
print(result)
@ -763,7 +788,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
tool_args={"message": "Well well well"},
# tool_name="ingest",
)
print(result)
@ -778,7 +803,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
tool_args={"message": "Well well well"},
tool_name="ingest",
)
print(result)
@ -793,7 +818,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({}),
tool_args={},
tool_name="util_do_nothing",
)
print(result)