feat: Add tool_type column (#576)

Matthew Zhou 2025-01-10 12:52:15 -10:00 committed by GitHub
parent 696940e481
commit 14a13b80b0
11 changed files with 225 additions and 42 deletions

View File

@ -0,0 +1,70 @@
"""Add tool types
Revision ID: e20573fe9b86
Revises: 915b68780108
Create Date: 2025-01-09 15:11:47.779646
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm.enums import ToolType
# revision identifiers, used by Alembic.
revision: str = "e20573fe9b86"
down_revision: Union[str, None] = "915b68780108"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Step 1: Add the column as nullable with no default
op.add_column("tools", sa.Column("tool_type", sa.String(), nullable=True))
# Step 2: Backpopulate the tool_type column based on tool name
# Define the list of Letta core tools
letta_core_value = ToolType.LETTA_CORE.value
letta_memory_core_value = ToolType.LETTA_MEMORY_CORE.value
custom_value = ToolType.CUSTOM.value
# Update tool_type for Letta core tools
op.execute(
f"""
UPDATE tools
SET tool_type = '{letta_core_value}'
WHERE name IN ({','.join(f"'{name}'" for name in BASE_TOOLS)});
"""
)
op.execute(
f"""
UPDATE tools
SET tool_type = '{letta_memory_core_value}'
WHERE name IN ({','.join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
"""
)
# Update tool_type for all other tools
op.execute(
f"""
UPDATE tools
SET tool_type = '{custom_value}'
WHERE tool_type IS NULL;
"""
)
# Step 3: Alter the column to be non-nullable
op.alter_column("tools", "tool_type", nullable=False)
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
def downgrade() -> None:
# Revert the changes made during the upgrade
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
op.drop_column("tools", "tool_type")
# ### end Alembic commands ###
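For reference, the upgrade above follows the usual expand-and-tighten pattern: add the column as nullable, backfill it, then alter it to NOT NULL. The backfill interpolates tool names into raw SQL via f-strings; a minimal sketch of the same backfill using bound parameters instead (not part of this commit, reusing the same constants and enum) would look like:

import sqlalchemy as sa
from alembic import op

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm.enums import ToolType

def backfill_tool_types() -> None:
    # Lightweight table stub; only the columns the backfill touches
    tools = sa.table("tools", sa.column("name", sa.String), sa.column("tool_type", sa.String))
    op.execute(tools.update().where(tools.c.name.in_(BASE_TOOLS)).values(tool_type=ToolType.LETTA_CORE.value))
    op.execute(tools.update().where(tools.c.name.in_(BASE_MEMORY_TOOLS)).values(tool_type=ToolType.LETTA_MEMORY_CORE.value))
    # Anything still untyped is a user-defined tool
    op.execute(tools.update().where(tools.c.tool_type.is_(None)).values(tool_type=ToolType.CUSTOM.value))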

View File

@ -7,11 +7,11 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
from letta.constants import (
BASE_TOOLS,
CLI_WARNING_PREFIX,
ERROR_MESSAGE_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
LETTA_CORE_TOOL_MODULE_NAME,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
@ -19,6 +19,7 @@ from letta.constants import (
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import ContextWindowExceededError
from letta.functions.functions import get_function_from_module
from letta.helpers import ToolRulesSolver
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
@ -26,6 +27,7 @@ from letta.llm_api.llm_api_tools import create
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.memory import summarize_messages
from letta.orm import User
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
@ -153,7 +155,7 @@ class Agent(BaseAgent):
raise ValueError(f"Invalid JSON format in message: {msg.text}")
return None
def update_memory_if_change(self, new_memory: Memory) -> bool:
def update_memory_if_changed(self, new_memory: Memory) -> bool:
"""
Update internal memory object and system prompt if there have been modifications.
@ -192,39 +194,45 @@ class Agent(BaseAgent):
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
"""
# TODO: Get rid of this. It's pretty shady that we exec the function just to get the type hints for the args.
env = {}
env.update(globals())
exec(target_letta_tool.source_code, env)
callable_func = env[target_letta_tool.json_schema["name"]]
spec = inspect.getfullargspec(callable_func).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])
# TODO: add agent manager here
orig_memory_str = self.agent_state.memory.compile()
# TODO: need to have an AgentState object that actually has full access to the block data
# this is because the sandbox tools need to be able to access block.value to edit this data
try:
# TODO: This is NO BUENO
# TODO: Matching purely by names is extremely problematic; users can create tools with these names and run them in the agent loop
# TODO: We will probably have to match the function strings exactly for safety
if function_name in BASE_TOOLS:
if target_letta_tool.tool_type == ToolType.LETTA_CORE:
# base tools are allowed to access the `Agent` object and run on the database
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE:
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
agent_state_copy = self.agent_state.__deepcopy__()
function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
self.update_memory_if_changed(agent_state_copy.memory)
else:
# TODO: Get rid of this. It's pretty shady that we exec the function just to get the type hints for the args.
env = {}
env.update(globals())
exec(target_letta_tool.source_code, env)
callable_func = env[target_letta_tool.json_schema["name"]]
spec = inspect.getfullargspec(callable_func).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(
agent_state=self.agent_state.__deepcopy__()
)
# TODO: This is only temporary; we can remove this after we publish a pip package with this object
agent_state_copy = self.agent_state.__deepcopy__()
agent_state_copy.tools = []
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if updated_agent_state is not None:
self.update_memory_if_change(updated_agent_state.memory)
self.update_memory_if_changed(updated_agent_state.memory)
except Exception as e:
# Need to catch the error here, or else truncation won't happen
# TODO: modify to function execution error
@ -677,7 +685,7 @@ class Agent(BaseAgent):
current_persisted_memory = Memory(
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()]
) # read blocks from DB
self.update_memory_if_change(current_persisted_memory)
self.update_memory_if_changed(current_persisted_memory)
# Step 1: add user message
if isinstance(messages, Message):
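The core behavioral change in this file: tool execution now dispatches on the persisted tool_type instead of matching function_name against BASE_TOOLS. Condensed into a standalone sketch (the helper names come from the diff; the function signature is simplified for illustration):

from letta.constants import LETTA_CORE_TOOL_MODULE_NAME
from letta.functions.functions import get_function_from_module
from letta.orm.enums import ToolType

def dispatch_tool(agent, tool, function_name, function_args):
    if tool.tool_type == ToolType.LETTA_CORE:
        # runs in-process with access to the Agent object
        func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
        return func(self=agent, **function_args)
    if tool.tool_type == ToolType.LETTA_MEMORY_CORE:
        # runs in-process against a deep copy of the agent state; memory
        # edits are persisted afterwards via update_memory_if_changed
        func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
        state_copy = agent.agent_state.__deepcopy__()
        response = func(agent_state=state_copy, **function_args)
        agent.update_memory_if_changed(state_copy.memory)
        return response
    # everything else is a custom tool and runs in the sandbox
    raise NotImplementedError("custom tools run via ToolExecutionSandbox, as in the diff above")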

View File

@ -8,6 +8,9 @@ API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY"
COMPOSIO_TOOL_TAG_NAME = "composio"
LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"
# String in the error message for when the context window is too large
# Example full message:

View File

@ -1,3 +1,4 @@
import importlib
import inspect
from textwrap import dedent # remove indentation
from types import ModuleType
@ -64,6 +65,70 @@ def parse_source_code(func) -> str:
return source_code
def get_function_from_module(module_name: str, function_name: str):
"""
Dynamically imports a function from a specified module.
Args:
module_name (str): The name of the module to import (e.g., 'letta.functions.function_sets.base').
function_name (str): The name of the function to retrieve.
Returns:
Callable: The imported function.
Raises:
ModuleNotFoundError: If the specified module cannot be found.
AttributeError: If the function is not found in the module.
"""
try:
# Dynamically import the module
module = importlib.import_module(module_name)
# Retrieve the function
return getattr(module, function_name)
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found.")
except AttributeError:
raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.")
def get_json_schema_from_module(module_name: str, function_name: str) -> dict:
"""
Dynamically loads a specific function from a module and generates its JSON schema.
Args:
module_name (str): The name of the module to import (e.g., 'letta.functions.function_sets.base').
function_name (str): The name of the function to retrieve.
Returns:
dict: The JSON schema for the specified function.
Raises:
ModuleNotFoundError: If the specified module cannot be found.
AttributeError: If the function is not found in the module.
ValueError: If the attribute is not a user-defined function.
"""
try:
# Dynamically import the module
module = importlib.import_module(module_name)
# Retrieve the function
attr = getattr(module, function_name, None)
# Check if it's a user-defined function
if not (inspect.isfunction(attr) and attr.__module__ == module.__name__):
raise ValueError(f"'{function_name}' is not a user-defined function in module '{module_name}'")
# Generate schema (assuming a `generate_schema` function exists)
generated_schema = generate_schema(attr)
return generated_schema
except ModuleNotFoundError:
raise ModuleNotFoundError(f"Module '{module_name}' not found.")
except AttributeError:
raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.")
def load_function_set(module: ModuleType) -> dict:
"""Load the functions and generate schema for them, given a module object"""
function_dict = {}
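A short, hypothetical usage of the two new helpers above, assuming send_message is one of the functions defined in letta.functions.function_sets.base:

from letta.constants import LETTA_CORE_TOOL_MODULE_NAME
from letta.functions.functions import get_function_from_module, get_json_schema_from_module

# Resolve the callable, then generate its OpenAI-style schema on the fly
send_message = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, "send_message")
schema = get_json_schema_from_module(LETTA_CORE_TOOL_MODULE_NAME, "send_message")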

View File

@ -109,6 +109,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"""converts to the basic pydantic model counterpart"""
state = {
"id": self.id,
"organization_id": self.organization_id,
"name": self.name,
"description": self.description,
"message_ids": self.message_ids,

View File

@ -4,7 +4,7 @@ from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
# TODO everything in functions should live in this model
from letta.orm.enums import ToolSourceType
from letta.orm.enums import ToolSourceType, ToolType
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.tool import Tool as PydanticTool
@ -29,12 +29,17 @@ class Tool(SqlalchemyBase, OrganizationMixin):
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
tool_type: Mapped[ToolType] = mapped_column(
String,
default=ToolType.CUSTOM,
doc="The type of tool. This affects whether or not we generate json_schema and source_code on the fly.",
)
return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatible JSON schema of the function.")
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatible JSON schema of the function.")
module: Mapped[Optional[str]] = mapped_column(
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
)
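With the typed column in place, callers can filter tools by category instead of maintaining name lists. A hedged query sketch (session setup is assumed; .value is used since the column is mapped as a plain String):

from sqlalchemy import select

from letta.orm.enums import ToolType
from letta.orm.tool import Tool

def list_core_tools(session):
    # Select every tool whose tool_type was assigned/backfilled as letta_core
    stmt = select(Tool).where(Tool.tool_type == ToolType.LETTA_CORE.value)
    return session.execute(stmt).scalars().all()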

View File

@ -2,10 +2,11 @@ from typing import Dict, List, Optional
from pydantic import Field, model_validator
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
from letta.functions.functions import derive_openai_json_schema
from letta.constants import COMPOSIO_TOOL_TAG_NAME, FUNCTION_RETURN_CHAR_LIMIT, LETTA_CORE_TOOL_MODULE_NAME
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
from letta.orm.enums import ToolType
from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall
@ -28,6 +29,7 @@ class Tool(BaseTool):
"""
id: str = BaseTool.generate_id_field()
tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.")
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
module: Optional[str] = Field(None, description="The module of the function.")
@ -36,7 +38,7 @@ class Tool(BaseTool):
tags: List[str] = Field([], description="Metadata tags.")
# code
source_code: str = Field(..., description="The source code of the function.")
source_code: Optional[str] = Field(None, description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
# tool configuration
@ -51,9 +53,19 @@ class Tool(BaseTool):
"""
Populate missing fields: name, description, and json_schema.
"""
# Derive JSON schema if not provided
if not self.json_schema:
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
if self.tool_type == ToolType.CUSTOM:
# If it's a custom tool, we need to ensure source_code is present
if not self.source_code:
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
# Always derive json_schema so we serve the freshest possible schema
# TODO: Instead of checking the tag, we should have `COMPOSIO` as a specific ToolType
# TODO: We skip this for Composio because Composio JSON schemas are derived differently
if not (COMPOSIO_TOOL_TAG_NAME in self.tags):
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}:
# If it's letta core tool, we generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
# Derive name from the JSON schema if not provided
if not self.name:
@ -125,7 +137,7 @@ class ToolCreate(LettaBase):
description = composio_tool.description
source_type = "python"
tags = ["composio"]
tags = [COMPOSIO_TOOL_TAG_NAME]
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action_name)
json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
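To make the new validator branches concrete, an illustrative sketch of constructing the two kinds of tool (not test code from this commit; field requirements beyond those shown are assumptions):

from letta.orm.enums import ToolType
from letta.schemas.tool import Tool

# A custom tool must carry source_code; json_schema is (re)derived from it
custom_tool = Tool(
    tool_type=ToolType.CUSTOM,
    source_code='def ping() -> str:\n    """Return pong."""\n    return "pong"',
)

# A core tool may omit source_code entirely; json_schema is generated on the
# fly from the module named by LETTA_CORE_TOOL_MODULE_NAME
core_tool = Tool(tool_type=ToolType.LETTA_CORE, name="send_message")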

View File

@ -1,10 +1,10 @@
import importlib
import inspect
import warnings
from typing import List, Optional
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.functions.functions import derive_openai_json_schema, load_function_set
from letta.orm.enums import ToolType
# TODO: Remove this once we translate all of these to the ORM
from letta.orm.errors import NoResultFound
@ -32,10 +32,10 @@ class ToolManager:
self.session_maker = db_context
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
@enforce_types
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Derive json_schema
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
if tool:
# Put to dict and remove fields that should not be reset
@ -63,6 +63,7 @@ class ToolManager:
if pydantic_tool.description is None:
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
tool_data = pydantic_tool.model_dump()
tool = ToolModel(**tool_data)
tool.create(session, actor=actor) # Re-raise other database-related errors
return tool.to_pydantic()
@ -113,8 +114,6 @@ class ToolManager:
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
pydantic_tool = tool.to_pydantic()
update_data["name"] if "name" in update_data.keys() else None
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)
tool.json_schema = new_schema
@ -155,12 +154,19 @@ class ToolManager:
tools = []
for name, schema in functions_to_schema.items():
if name in BASE_TOOLS + BASE_MEMORY_TOOLS:
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("letta-base")
# BASE_MEMORY_TOOLS should be executed in an e2b sandbox
# so they should NOT be letta_core tools; instead, they are treated like custom tools
if name in BASE_TOOLS:
tool_type = ToolType.LETTA_CORE
elif name in BASE_MEMORY_TOOLS:
tool_type = ToolType.LETTA_MEMORY_CORE
else:
raise ValueError(f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS}")
# create the tool
tools.append(
self.create_or_update_tool(
@ -168,9 +174,7 @@ class ToolManager:
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
tool_type=tool_type,
),
actor=actor,
)
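Note that upsert_base_tools no longer passes module, source_code, or json_schema for base tools: it only tags each tool with its ToolType and lets the Tool schema validator generate the json_schema from the core module. The classification rule, pulled out as a hypothetical helper (the diff inlines this logic):

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm.enums import ToolType

def classify_base_tool(name: str) -> ToolType:
    if name in BASE_TOOLS:
        return ToolType.LETTA_CORE
    if name in BASE_MEMORY_TOOLS:
        return ToolType.LETTA_MEMORY_CORE
    raise ValueError(f"Tool name {name} is not a base tool")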

View File

@ -197,7 +197,7 @@ def composio_gmail_get_profile_tool(test_user):
@pytest.fixture
def clear_core_memory_tool(test_user):
def clear_memory(agent_state: AgentState):
def clear_memory(agent_state: "AgentState"):
"""Clear the core memory"""
agent_state.memory.get_block("human").value = ""
agent_state.memory.get_block("persona").value = ""

View File

@ -42,7 +42,9 @@ def run_server():
@pytest.fixture(
params=[{"server": False}, {"server": True}], # whether to use REST API server
params=[
{"server": False},
], # {"server": True}], # whether to use REST API server
# params=[{"server": False}], # whether to use REST API server
scope="module",
)
@ -121,7 +123,6 @@ def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTCli
assert (
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
# cleanup
client.delete_agent(agent_state1.id)

View File

@ -30,6 +30,7 @@ from letta.orm import (
User,
)
from letta.orm.agents_tags import AgentsTags
from letta.orm.enums import ToolType
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
@ -1368,6 +1369,7 @@ def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
assert fetched_tool.tags == print_tool.tags
assert fetched_tool.source_code == print_tool.source_code
assert fetched_tool.source_type == print_tool.source_type
assert fetched_tool.tool_type == ToolType.CUSTOM
def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
@ -1382,6 +1384,7 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
assert fetched_tool.tags == print_tool.tags
assert fetched_tool.source_code == print_tool.source_code
assert fetched_tool.source_type == print_tool.source_type
assert fetched_tool.tool_type == ToolType.CUSTOM
def test_list_tools(server: SyncServer, print_tool, default_user):
@ -1445,6 +1448,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, p
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code)
assert updated_tool.json_schema == new_schema
assert updated_tool.tool_type == ToolType.CUSTOM
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user):
@ -1483,6 +1487,7 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name)
assert updated_tool.json_schema == new_schema
assert updated_tool.name == name
assert updated_tool.tool_type == ToolType.CUSTOM
def test_update_tool_multi_user(server: SyncServer, print_tool, default_user, other_user):
@ -1519,6 +1524,15 @@ def test_upsert_base_tools(server: SyncServer, default_user):
tools = server.tool_manager.upsert_base_tools(actor=default_user)
assert sorted([t.name for t in tools]) == expected_tool_names
# Confirm that the returned tools have no source_code but do have a json_schema
for t in tools:
if t.name in BASE_TOOLS:
assert t.tool_type == ToolType.LETTA_CORE
else:
assert t.tool_type == ToolType.LETTA_MEMORY_CORE
assert t.source_code is None
assert t.json_schema
# ======================================================================================================================
# Message Manager Tests