mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Make the tool runner take a schema (#1328)
This commit is contained in:
parent
21cb2e6cfe
commit
50ab98cfe3
@ -250,3 +250,6 @@ class ToolRunFromSource(LettaBase):
|
||||
name: Optional[str] = Field(None, description="The name of the tool to run.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
args_json_schema: Optional[Dict] = Field(None, description="The args JSON schema of the function.")
|
||||
json_schema: Optional[Dict] = Field(
|
||||
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
|
||||
)
|
||||
|
@ -192,6 +192,7 @@ def run_tool_from_source(
|
||||
tool_env_vars=request.env_vars,
|
||||
tool_name=request.name,
|
||||
tool_args_json_schema=request.args_json_schema,
|
||||
tool_json_schema=request.json_schema,
|
||||
actor=actor,
|
||||
)
|
||||
except LettaToolCreateError as e:
|
||||
|
@ -1202,6 +1202,7 @@ class SyncServer(Server):
|
||||
tool_source_type: Optional[str] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
tool_args_json_schema: Optional[Dict[str, Any]] = None,
|
||||
tool_json_schema: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolReturnMessage:
|
||||
"""Run a tool from source code"""
|
||||
if tool_source_type is not None and tool_source_type != "python":
|
||||
@ -1213,6 +1214,11 @@ class SyncServer(Server):
|
||||
source_code=tool_source,
|
||||
args_json_schema=tool_args_json_schema,
|
||||
)
|
||||
|
||||
# If tools_json_schema is explicitly passed in, override it on the created Tool object
|
||||
if tool_json_schema:
|
||||
tool.json_schema = tool_json_schema
|
||||
|
||||
assert tool.name is not None, "Failed to create tool object"
|
||||
|
||||
# TODO eventually allow using agent state in tools
|
||||
|
@ -798,22 +798,25 @@ def ingest(message: str):
|
||||
'''
|
||||
|
||||
|
||||
def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
"""Test that the server can run tools"""
|
||||
import pytest
|
||||
|
||||
|
||||
def test_tool_run_basic(server, mock_e2b_api_key_none, user):
|
||||
"""Test running a simple tool from source"""
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args={"message": "Hello, world!"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "Ingested message Hello, world!", result.tool_return
|
||||
assert result.tool_return == "Ingested message Hello, world!"
|
||||
assert not result.stdout
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_with_env_var(server, mock_e2b_api_key_none, user):
|
||||
"""Test running a tool that uses an environment variable"""
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR,
|
||||
@ -821,56 +824,45 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
tool_args={},
|
||||
tool_env_vars={"secret": "banana"},
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "banana", result.tool_return
|
||||
assert result.tool_return == "banana"
|
||||
assert not result.stdout
|
||||
assert not result.stderr
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args={"message": "Well well well"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "Ingested message Well well well", result.tool_return
|
||||
assert not result.stdout
|
||||
assert not result.stderr
|
||||
|
||||
def test_tool_run_invalid_args(server, mock_e2b_api_key_none, user):
|
||||
"""Test running a tool with incorrect arguments"""
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args={"bad_arg": "oh no"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "error"
|
||||
assert "Error" in result.tool_return, result.tool_return
|
||||
assert "missing 1 required positional argument" in result.tool_return, result.tool_return
|
||||
assert "Error" in result.tool_return
|
||||
assert "missing 1 required positional argument" in result.tool_return
|
||||
assert not result.stdout
|
||||
assert result.stderr
|
||||
assert "missing 1 required positional argument" in result.stderr[0]
|
||||
|
||||
# Test that we can still pull the tool out by default (pulls that last tool in the source)
|
||||
|
||||
def test_tool_run_with_distractor(server, mock_e2b_api_key_none, user):
|
||||
"""Test running a tool with a distractor function in the source"""
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
tool_args={"message": "Well well well"},
|
||||
# tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "Ingested message Well well well", result.tool_return
|
||||
assert result.tool_return == "Ingested message Well well well"
|
||||
assert result.stdout
|
||||
assert "I'm a distractor" in result.stdout[0]
|
||||
assert not result.stderr
|
||||
|
||||
# Test that we can pull the tool out by name
|
||||
|
||||
def test_tool_run_explicit_tool_name(server, mock_e2b_api_key_none, user):
|
||||
"""Test selecting a tool by name when multiple tools exist in the source"""
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
@ -878,14 +870,15 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
tool_args={"message": "Well well well"},
|
||||
tool_name="ingest",
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "Ingested message Well well well", result.tool_return
|
||||
assert result.tool_return == "Ingested message Well well well"
|
||||
assert result.stdout
|
||||
assert "I'm a distractor" in result.stdout[0]
|
||||
assert not result.stderr
|
||||
|
||||
# Test that we can pull a different tool out by name
|
||||
|
||||
def test_tool_run_util_function(server, mock_e2b_api_key_none, user):
|
||||
"""Test selecting a utility function that does not return anything meaningful"""
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
@ -893,14 +886,44 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id):
|
||||
tool_args={},
|
||||
tool_name="util_do_nothing",
|
||||
)
|
||||
print(result)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == str(None), result.tool_return
|
||||
assert result.tool_return == str(None)
|
||||
assert result.stdout
|
||||
assert "I'm a distractor" in result.stdout[0]
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_with_explicit_json_schema(server, mock_e2b_api_key_none, user):
|
||||
"""Test overriding the autogenerated JSON schema with an explicit one"""
|
||||
explicit_json_schema = {
|
||||
"name": "ingest",
|
||||
"description": "Blah blah blah.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {"type": "string", "description": "The message to ingest into the system."},
|
||||
"request_heartbeat": {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
},
|
||||
},
|
||||
"required": ["message", "request_heartbeat"],
|
||||
},
|
||||
}
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
tool_args={"message": "Custom schema test"},
|
||||
tool_json_schema=explicit_json_schema,
|
||||
)
|
||||
assert result.status == "success"
|
||||
assert result.tool_return == "Ingested message Custom schema test"
|
||||
assert not result.stdout
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_composio_client_simple(server):
|
||||
apps = server.get_composio_apps()
|
||||
# Assert there's some amount of apps returned
|
||||
|
Loading…
Reference in New Issue
Block a user