Add parent tool rule (#1648)

This commit is contained in:
cthomas 2025-04-09 15:22:15 -07:00 committed by GitHub
parent c0e1f793cf
commit 3713393966
8 changed files with 47 additions and 11 deletions

View File

@ -226,7 +226,7 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m
if line_number is None:
line_number = len(current_value_list)
if replace:
current_value_list[line_number] = new_memory
current_value_list[line_number - 1] = new_memory
else:
current_value_list.insert(line_number, new_memory)
new_value = "\n".join(current_value_list)

View File

@ -28,6 +28,7 @@ from letta.schemas.tool_rule import (
ContinueToolRule,
InitToolRule,
MaxCountPerStepToolRule,
ParentToolRule,
TerminalToolRule,
ToolRule,
)
@ -89,7 +90,7 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
return data
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[ToolRule]:
"""Convert a list of dictionaries back into ToolRule objects."""
if not data:
return []
@ -99,7 +100,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
def deserialize_tool_rule(
data: Dict,
) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]:
) -> ToolRule:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
rule_type = ToolRuleType(data.get("type"))
@ -118,6 +119,8 @@ def deserialize_tool_rule(
return ContinueToolRule(**data)
elif rule_type == ToolRuleType.max_count_per_step:
return MaxCountPerStepToolRule(**data)
elif rule_type == ToolRuleType.parent_last_tool:
return ParentToolRule(**data)
raise ValueError(f"Unknown ToolRule type: {rule_type}")

View File

@ -10,6 +10,7 @@ from letta.schemas.tool_rule import (
ContinueToolRule,
InitToolRule,
MaxCountPerStepToolRule,
ParentToolRule,
TerminalToolRule,
)
@ -33,6 +34,9 @@ class ToolRulesSolver(BaseModel):
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
)
parent_tool_rules: List[ParentToolRule] = Field(
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
)
terminal_tool_rules: List[TerminalToolRule] = Field(
default_factory=list, description="Terminal tool rules that end the agent loop if called."
)
@ -44,6 +48,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: Optional[List[InitToolRule]] = None,
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
parent_tool_rules: Optional[List[ParentToolRule]] = None,
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
tool_call_history: Optional[List[str]] = None,
**kwargs,
@ -52,6 +57,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules=init_tool_rules or [],
continue_tool_rules=continue_tool_rules or [],
child_based_tool_rules=child_based_tool_rules or [],
parent_tool_rules=parent_tool_rules or [],
terminal_tool_rules=terminal_tool_rules or [],
tool_call_history=tool_call_history or [],
**kwargs,
@ -78,6 +84,9 @@ class ToolRulesSolver(BaseModel):
elif rule.type == ToolRuleType.max_count_per_step:
assert isinstance(rule, MaxCountPerStepToolRule)
self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.parent_last_tool:
assert isinstance(rule, ParentToolRule)
self.parent_tool_rules.append(rule)
def register_tool_call(self, tool_name: str):
"""Update the internal state to track tool call history."""
@ -102,13 +111,14 @@ class ToolRulesSolver(BaseModel):
# If there are init tool rules, only return those defined in the init tool rules
return [rule.tool_name for rule in self.init_tool_rules]
else:
# Otherwise, return all the available tools
# Otherwise, return all tools besides those constrained by parent tool rules
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
return list(available_tools)
else:
# Collect valid tools from all child-based rules
valid_tool_sets = [
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
for rule in self.child_based_tool_rules
for rule in self.child_based_tool_rules + self.parent_tool_rules
]
# Compute intersection of all valid tool sets

View File

@ -347,13 +347,19 @@ class AnthropicClient(LLMClientBase):
if content_part.type == "text":
content = strip_xml_tags(string=content_part.text, tag="thinking")
if content_part.type == "tool_use":
# hack for tool rules
input = json.loads(json.dumps(content_part.input))
if "id" in input and input["id"].startswith("toolu_") and "function" in input:
arguments = str(input["function"]["arguments"])
else:
arguments = json.dumps(content_part.input, indent=2)
tool_calls = [
ToolCall(
id=content_part.id,
type="function",
function=FunctionCall(
name=content_part.name,
arguments=json.dumps(content_part.input, indent=2),
arguments=arguments,
),
)
]

View File

@ -64,3 +64,4 @@ class ToolRuleType(str, Enum):
conditional = "conditional"
constrain_child_tools = "constrain_child_tools"
max_count_per_step = "max_count_per_step"
parent_last_tool = "parent_last_tool"

View File

@ -29,6 +29,19 @@ class ChildToolRule(BaseToolRule):
return set(self.children) if last_tool == self.tool_name else available_tools
class ParentToolRule(BaseToolRule):
"""
A ToolRule that only allows a child tool to be called if the parent has been called.
"""
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
children: List[str] = Field(..., description="The children tools that can be invoked.")
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
last_tool = tool_call_history[-1] if tool_call_history else None
return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
class ConditionalToolRule(BaseToolRule):
"""
A ToolRule that conditionally maps to different child tools based on the output.
@ -128,6 +141,6 @@ class MaxCountPerStepToolRule(BaseToolRule):
ToolRule = Annotated[
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule],
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
Field(discriminator="type"),
]

View File

@ -43,6 +43,7 @@ from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
from letta.schemas.tool_rule import ParentToolRule as PydanticParentToolRule
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
@ -159,9 +160,9 @@ class AgentManager:
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
# we may want to add additional rules for sleeptime agents
# if agent_create.agent_type == AgentType.sleeptime_agent:
# tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
if agent_create.agent_type == AgentType.sleeptime_agent:
tool_rules.append(PydanticParentToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
# if custom rules, check tool rules are valid
if agent_create.tool_rules:

View File

@ -9,7 +9,7 @@ from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.enums import JobStatus
from letta.schemas.enums import JobStatus, ToolRuleType
from letta.schemas.group import (
DynamicManager,
GroupCreate,
@ -507,6 +507,8 @@ async def test_sleeptime_group_chat(server, actor):
assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools
assert "core_memory_insert" in sleeptime_agent_tools
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.parent_last_tool]) > 0
# 5. Send messages and verify run ids
message_text = [
"my favorite color is orange",