Add parent tool rule (#1648)

This commit is contained in:
cthomas 2025-04-09 15:22:15 -07:00 committed by GitHub
parent c0e1f793cf
commit 3713393966
8 changed files with 47 additions and 11 deletions

View File

@ -226,7 +226,7 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m
if line_number is None: if line_number is None:
line_number = len(current_value_list) line_number = len(current_value_list)
if replace: if replace:
current_value_list[line_number] = new_memory current_value_list[line_number - 1] = new_memory
else: else:
current_value_list.insert(line_number, new_memory) current_value_list.insert(line_number, new_memory)
new_value = "\n".join(current_value_list) new_value = "\n".join(current_value_list)

View File

@ -28,6 +28,7 @@ from letta.schemas.tool_rule import (
ContinueToolRule, ContinueToolRule,
InitToolRule, InitToolRule,
MaxCountPerStepToolRule, MaxCountPerStepToolRule,
ParentToolRule,
TerminalToolRule, TerminalToolRule,
ToolRule, ToolRule,
) )
@ -89,7 +90,7 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
return data return data
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]: def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[ToolRule]:
"""Convert a list of dictionaries back into ToolRule objects.""" """Convert a list of dictionaries back into ToolRule objects."""
if not data: if not data:
return [] return []
@ -99,7 +100,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
def deserialize_tool_rule( def deserialize_tool_rule(
data: Dict, data: Dict,
) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]: ) -> ToolRule:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'.""" """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
rule_type = ToolRuleType(data.get("type")) rule_type = ToolRuleType(data.get("type"))
@ -118,6 +119,8 @@ def deserialize_tool_rule(
return ContinueToolRule(**data) return ContinueToolRule(**data)
elif rule_type == ToolRuleType.max_count_per_step: elif rule_type == ToolRuleType.max_count_per_step:
return MaxCountPerStepToolRule(**data) return MaxCountPerStepToolRule(**data)
elif rule_type == ToolRuleType.parent_last_tool:
return ParentToolRule(**data)
raise ValueError(f"Unknown ToolRule type: {rule_type}") raise ValueError(f"Unknown ToolRule type: {rule_type}")

View File

@ -10,6 +10,7 @@ from letta.schemas.tool_rule import (
ContinueToolRule, ContinueToolRule,
InitToolRule, InitToolRule,
MaxCountPerStepToolRule, MaxCountPerStepToolRule,
ParentToolRule,
TerminalToolRule, TerminalToolRule,
) )
@ -33,6 +34,9 @@ class ToolRulesSolver(BaseModel):
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field( child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
) )
parent_tool_rules: List[ParentToolRule] = Field(
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
)
terminal_tool_rules: List[TerminalToolRule] = Field( terminal_tool_rules: List[TerminalToolRule] = Field(
default_factory=list, description="Terminal tool rules that end the agent loop if called." default_factory=list, description="Terminal tool rules that end the agent loop if called."
) )
@ -44,6 +48,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: Optional[List[InitToolRule]] = None, init_tool_rules: Optional[List[InitToolRule]] = None,
continue_tool_rules: Optional[List[ContinueToolRule]] = None, continue_tool_rules: Optional[List[ContinueToolRule]] = None,
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None, child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
parent_tool_rules: Optional[List[ParentToolRule]] = None,
terminal_tool_rules: Optional[List[TerminalToolRule]] = None, terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
tool_call_history: Optional[List[str]] = None, tool_call_history: Optional[List[str]] = None,
**kwargs, **kwargs,
@ -52,6 +57,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules=init_tool_rules or [], init_tool_rules=init_tool_rules or [],
continue_tool_rules=continue_tool_rules or [], continue_tool_rules=continue_tool_rules or [],
child_based_tool_rules=child_based_tool_rules or [], child_based_tool_rules=child_based_tool_rules or [],
parent_tool_rules=parent_tool_rules or [],
terminal_tool_rules=terminal_tool_rules or [], terminal_tool_rules=terminal_tool_rules or [],
tool_call_history=tool_call_history or [], tool_call_history=tool_call_history or [],
**kwargs, **kwargs,
@ -78,6 +84,9 @@ class ToolRulesSolver(BaseModel):
elif rule.type == ToolRuleType.max_count_per_step: elif rule.type == ToolRuleType.max_count_per_step:
assert isinstance(rule, MaxCountPerStepToolRule) assert isinstance(rule, MaxCountPerStepToolRule)
self.child_based_tool_rules.append(rule) self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.parent_last_tool:
assert isinstance(rule, ParentToolRule)
self.parent_tool_rules.append(rule)
def register_tool_call(self, tool_name: str): def register_tool_call(self, tool_name: str):
"""Update the internal state to track tool call history.""" """Update the internal state to track tool call history."""
@ -102,13 +111,14 @@ class ToolRulesSolver(BaseModel):
# If there are init tool rules, only return those defined in the init tool rules # If there are init tool rules, only return those defined in the init tool rules
return [rule.tool_name for rule in self.init_tool_rules] return [rule.tool_name for rule in self.init_tool_rules]
else: else:
# Otherwise, return all the available tools # Otherwise, return all tools besides those constrained by parent tool rules
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
return list(available_tools) return list(available_tools)
else: else:
# Collect valid tools from all child-based rules # Collect valid tools from all child-based rules
valid_tool_sets = [ valid_tool_sets = [
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
for rule in self.child_based_tool_rules for rule in self.child_based_tool_rules + self.parent_tool_rules
] ]
# Compute intersection of all valid tool sets # Compute intersection of all valid tool sets

View File

@ -347,13 +347,19 @@ class AnthropicClient(LLMClientBase):
if content_part.type == "text": if content_part.type == "text":
content = strip_xml_tags(string=content_part.text, tag="thinking") content = strip_xml_tags(string=content_part.text, tag="thinking")
if content_part.type == "tool_use": if content_part.type == "tool_use":
# hack for tool rules
input = json.loads(json.dumps(content_part.input))
if "id" in input and input["id"].startswith("toolu_") and "function" in input:
arguments = str(input["function"]["arguments"])
else:
arguments = json.dumps(content_part.input, indent=2)
tool_calls = [ tool_calls = [
ToolCall( ToolCall(
id=content_part.id, id=content_part.id,
type="function", type="function",
function=FunctionCall( function=FunctionCall(
name=content_part.name, name=content_part.name,
arguments=json.dumps(content_part.input, indent=2), arguments=arguments,
), ),
) )
] ]

View File

@ -64,3 +64,4 @@ class ToolRuleType(str, Enum):
conditional = "conditional" conditional = "conditional"
constrain_child_tools = "constrain_child_tools" constrain_child_tools = "constrain_child_tools"
max_count_per_step = "max_count_per_step" max_count_per_step = "max_count_per_step"
parent_last_tool = "parent_last_tool"

View File

@ -29,6 +29,19 @@ class ChildToolRule(BaseToolRule):
return set(self.children) if last_tool == self.tool_name else available_tools return set(self.children) if last_tool == self.tool_name else available_tools
class ParentToolRule(BaseToolRule):
"""
A ToolRule that only allows a child tool to be called if the parent has been called.
"""
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
children: List[str] = Field(..., description="The children tools that can be invoked.")
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
last_tool = tool_call_history[-1] if tool_call_history else None
return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
class ConditionalToolRule(BaseToolRule): class ConditionalToolRule(BaseToolRule):
""" """
A ToolRule that conditionally maps to different child tools based on the output. A ToolRule that conditionally maps to different child tools based on the output.
@ -128,6 +141,6 @@ class MaxCountPerStepToolRule(BaseToolRule):
ToolRule = Annotated[ ToolRule = Annotated[
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule], Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
Field(discriminator="type"), Field(discriminator="type"),
] ]

View File

@ -43,6 +43,7 @@ from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
from letta.schemas.tool_rule import ParentToolRule as PydanticParentToolRule
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser from letta.schemas.user import User as PydanticUser
@ -159,9 +160,9 @@ class AgentManager:
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name)) tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS: elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name)) tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
# we may want to add additional rules for sleeptime agents
# if agent_create.agent_type == AgentType.sleeptime_agent: if agent_create.agent_type == AgentType.sleeptime_agent:
# tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"])) tool_rules.append(PydanticParentToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
# if custom rules, check tool rules are valid # if custom rules, check tool rules are valid
if agent_create.tool_rules: if agent_create.tool_rules:

View File

@ -9,7 +9,7 @@ from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound from letta.orm.errors import NoResultFound
from letta.schemas.agent import CreateAgent from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock from letta.schemas.block import CreateBlock
from letta.schemas.enums import JobStatus from letta.schemas.enums import JobStatus, ToolRuleType
from letta.schemas.group import ( from letta.schemas.group import (
DynamicManager, DynamicManager,
GroupCreate, GroupCreate,
@ -507,6 +507,8 @@ async def test_sleeptime_group_chat(server, actor):
assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools
assert "core_memory_insert" in sleeptime_agent_tools assert "core_memory_insert" in sleeptime_agent_tools
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.parent_last_tool]) > 0
# 5. Send messages and verify run ids # 5. Send messages and verify run ids
message_text = [ message_text = [
"my favorite color is orange", "my favorite color is orange",