mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Extend tool runs to also take in environment variables (#554)
This commit is contained in:
parent
6db73894ea
commit
71031e9cda
@ -212,13 +212,9 @@ class ToolUpdate(LettaBase):
|
||||
# TODO: Remove this, and clean usage of ToolUpdate everywhere else
|
||||
|
||||
|
||||
class ToolRun(LettaBase):
|
||||
id: str = Field(..., description="The ID of the tool to run.")
|
||||
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
|
||||
|
||||
|
||||
class ToolRunFromSource(LettaBase):
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
|
||||
args: Dict[str, str] = Field(..., description="The arguments to pass to the tool.")
|
||||
env_vars: Dict[str, str] = Field(None, description="The environment variables to pass to the tool.")
|
||||
name: Optional[str] = Field(None, description="The name of the tool to run.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
|
@ -181,6 +181,7 @@ def run_tool_from_source(
|
||||
tool_source=request.source_code,
|
||||
tool_source_type=request.source_type,
|
||||
tool_args=request.args,
|
||||
tool_env_vars=request.env_vars,
|
||||
tool_name=request.name,
|
||||
actor=actor,
|
||||
)
|
||||
|
@ -1,11 +1,10 @@
|
||||
# inspecting tools
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from composio.client import Composio
|
||||
from composio.client.collections import ActionModel, AppModel
|
||||
@ -1117,22 +1116,17 @@ class SyncServer(Server):
|
||||
def run_tool_from_source(
|
||||
self,
|
||||
actor: User,
|
||||
tool_args: str,
|
||||
tool_args: Dict[str, str],
|
||||
tool_source: str,
|
||||
tool_env_vars: Optional[Dict[str, str]] = None,
|
||||
tool_source_type: Optional[str] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
) -> ToolReturnMessage:
|
||||
"""Run a tool from source code"""
|
||||
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON string for tool_args")
|
||||
|
||||
if tool_source_type is not None and tool_source_type != "python":
|
||||
raise ValueError("Only Python source code is supported at this time")
|
||||
|
||||
# NOTE: we're creating a floating Tool object and NOT persiting to DB
|
||||
# NOTE: we're creating a floating Tool object and NOT persisting to DB
|
||||
tool = Tool(
|
||||
name=tool_name,
|
||||
source_code=tool_source,
|
||||
@ -1144,7 +1138,9 @@ class SyncServer(Server):
|
||||
|
||||
# Next, attempt to run the tool with the sandbox
|
||||
try:
|
||||
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
|
||||
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state, additional_env_vars=tool_env_vars
|
||||
)
|
||||
return ToolReturnMessage(
|
||||
id="null",
|
||||
tool_call_id="null",
|
||||
|
@ -59,22 +59,23 @@ class ToolExecutionSandbox:
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
self.force_recreate = force_recreate
|
||||
|
||||
def run(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
|
||||
def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment.
|
||||
|
||||
Args:
|
||||
agent_state (Optional[AgentState]): The state of the agent invoking the tool
|
||||
additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
|
||||
"""
|
||||
if tool_settings.e2b_api_key:
|
||||
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
|
||||
result = self.run_e2b_sandbox(agent_state=agent_state)
|
||||
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
|
||||
else:
|
||||
logger.debug(f"Using local sandbox to execute {self.tool_name}")
|
||||
result = self.run_local_dir_sandbox(agent_state=agent_state)
|
||||
result = self.run_local_dir_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
|
||||
|
||||
# Log out any stdout/stderr from the tool run
|
||||
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
|
||||
@ -98,19 +99,25 @@ class ToolExecutionSandbox:
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env) # Restore original environment variables
|
||||
|
||||
def run_local_dir_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
|
||||
def run_local_dir_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> SandboxRunResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
|
||||
local_configs = sbx_config.get_local_config()
|
||||
|
||||
# Get environment variables for the sandbox
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
env = os.environ.copy()
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
env.update(env_vars)
|
||||
|
||||
# Get environment variables for this agent specifically
|
||||
if agent_state:
|
||||
env.update(agent_state.get_agent_env_vars_as_dict())
|
||||
|
||||
# Finally, get any that are passed explicitly into the `run` function call
|
||||
if additional_env_vars:
|
||||
env.update(additional_env_vars)
|
||||
|
||||
# Safety checks
|
||||
if not os.path.isdir(local_configs.sandbox_dir):
|
||||
raise FileNotFoundError(f"Sandbox directory does not exist: {local_configs.sandbox_dir}")
|
||||
@ -277,7 +284,7 @@ class ToolExecutionSandbox:
|
||||
|
||||
# e2b sandbox specific functions
|
||||
|
||||
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
|
||||
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
||||
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
|
||||
if not sbx or self.force_recreate:
|
||||
@ -300,6 +307,10 @@ class ToolExecutionSandbox:
|
||||
if agent_state:
|
||||
env_vars.update(agent_state.get_agent_env_vars_as_dict())
|
||||
|
||||
# Finally, get any that are passed explicitly into the `run` function call
|
||||
if additional_env_vars:
|
||||
env_vars.update(additional_env_vars)
|
||||
|
||||
code = self.generate_execution_script(agent_state=agent_state)
|
||||
execution = sbx.run_code(code, envs=env_vars)
|
||||
|
||||
|
@ -687,6 +687,18 @@ def ingest(message: str):
|
||||
|
||||
'''
|
||||
|
||||
EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR = '''
|
||||
def ingest():
|
||||
"""
|
||||
Ingest a message into the system.
|
||||
|
||||
Returns:
|
||||
str: The result of ingesting the message.
|
||||
"""
|
||||
import os
|
||||
return os.getenv("secret")
|
||||
'''
|
||||
|
||||
|
||||
EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = '''
|
||||
def util_do_nothing():
|
||||
@ -721,7 +733,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args=json.dumps({"message": "Hello, world!"}),
|
||||
tool_args={"message": "Hello, world!"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
@ -730,11 +742,24 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
assert not result.stdout
|
||||
assert not result.stderr
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR,
|
||||
tool_source_type="python",
|
||||
tool_args={},
|
||||
tool_env_vars={"secret": "banana"},
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "banana", result.tool_return
|
||||
assert not result.stdout
|
||||
assert not result.stderr
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args=json.dumps({"message": "Well well well"}),
|
||||
tool_args={"message": "Well well well"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
@ -747,7 +772,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args=json.dumps({"bad_arg": "oh no"}),
|
||||
tool_args={"bad_arg": "oh no"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
@ -763,7 +788,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
tool_args=json.dumps({"message": "Well well well"}),
|
||||
tool_args={"message": "Well well well"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
@ -778,7 +803,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
tool_args=json.dumps({"message": "Well well well"}),
|
||||
tool_args={"message": "Well well well"},
|
||||
tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
@ -793,7 +818,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
tool_args=json.dumps({}),
|
||||
tool_args={},
|
||||
tool_name="util_do_nothing",
|
||||
)
|
||||
print(result)
|
||||
|
Loading…
Reference in New Issue
Block a user