From 371339396637b653404bd208adb24f8164a10fba Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 9 Apr 2025 15:22:15 -0700 Subject: [PATCH] Add parent tool rule (#1648) --- letta/functions/function_sets/base.py | 2 +- letta/helpers/converters.py | 7 +++++-- letta/helpers/tool_rule_solver.py | 14 ++++++++++++-- letta/llm_api/anthropic_client.py | 8 +++++++- letta/schemas/enums.py | 1 + letta/schemas/tool_rule.py | 15 ++++++++++++++- letta/services/agent_manager.py | 7 ++++--- tests/test_multi_agent.py | 4 +++- 8 files changed, 47 insertions(+), 11 deletions(-) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 44163deb7..ec1784c7c 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -226,7 +226,7 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m if line_number is None: line_number = len(current_value_list) if replace: - current_value_list[line_number] = new_memory + current_value_list[line_number - 1] = new_memory else: current_value_list.insert(line_number, new_memory) new_value = "\n".join(current_value_list) diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 6ffe25fb7..853ad7272 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -28,6 +28,7 @@ from letta.schemas.tool_rule import ( ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, + ParentToolRule, TerminalToolRule, ToolRule, ) @@ -89,7 +90,7 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str, return data -def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]: +def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[ToolRule]: """Convert a list of dictionaries back into ToolRule objects.""" if not data: return [] @@ -99,7 +100,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu def deserialize_tool_rule( data: Dict, -) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]: +) -> ToolRule: """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'.""" rule_type = ToolRuleType(data.get("type")) @@ -118,6 +119,8 @@ def deserialize_tool_rule( return ContinueToolRule(**data) elif rule_type == ToolRuleType.max_count_per_step: return MaxCountPerStepToolRule(**data) + elif rule_type == ToolRuleType.parent_last_tool: + return ParentToolRule(**data) raise ValueError(f"Unknown ToolRule type: {rule_type}") diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index b0ff1d799..15e5700e3 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -10,6 +10,7 @@ from letta.schemas.tool_rule import ( ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, + ParentToolRule, TerminalToolRule, ) @@ -33,6 +34,9 @@ class ToolRulesSolver(BaseModel): child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field( default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." ) + parent_tool_rules: List[ParentToolRule] = Field( + default_factory=list, description="Filter tool rules to be used to filter out tools from the available set." + ) terminal_tool_rules: List[TerminalToolRule] = Field( default_factory=list, description="Terminal tool rules that end the agent loop if called." ) @@ -44,6 +48,7 @@ class ToolRulesSolver(BaseModel): init_tool_rules: Optional[List[InitToolRule]] = None, continue_tool_rules: Optional[List[ContinueToolRule]] = None, child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None, + parent_tool_rules: Optional[List[ParentToolRule]] = None, terminal_tool_rules: Optional[List[TerminalToolRule]] = None, tool_call_history: Optional[List[str]] = None, **kwargs, @@ -52,6 +57,7 @@ class ToolRulesSolver(BaseModel): init_tool_rules=init_tool_rules or [], continue_tool_rules=continue_tool_rules or [], child_based_tool_rules=child_based_tool_rules or [], + parent_tool_rules=parent_tool_rules or [], terminal_tool_rules=terminal_tool_rules or [], tool_call_history=tool_call_history or [], **kwargs, @@ -78,6 +84,9 @@ class ToolRulesSolver(BaseModel): elif rule.type == ToolRuleType.max_count_per_step: assert isinstance(rule, MaxCountPerStepToolRule) self.child_based_tool_rules.append(rule) + elif rule.type == ToolRuleType.parent_last_tool: + assert isinstance(rule, ParentToolRule) + self.parent_tool_rules.append(rule) def register_tool_call(self, tool_name: str): """Update the internal state to track tool call history.""" @@ -102,13 +111,14 @@ class ToolRulesSolver(BaseModel): # If there are init tool rules, only return those defined in the init tool rules return [rule.tool_name for rule in self.init_tool_rules] else: - # Otherwise, return all the available tools + # Otherwise, return all tools besides those constrained by parent tool rules + available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules)) return list(available_tools) else: # Collect valid tools from all child-based rules valid_tool_sets = [ rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) - for rule in self.child_based_tool_rules + for rule in self.child_based_tool_rules + self.parent_tool_rules ] # Compute intersection of all valid tool sets diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index ece420169..a5afbc08e 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -347,13 +347,19 @@ class AnthropicClient(LLMClientBase): if content_part.type == "text": content = strip_xml_tags(string=content_part.text, tag="thinking") if content_part.type == "tool_use": + # hack for tool rules + input = json.loads(json.dumps(content_part.input)) + if "id" in input and input["id"].startswith("toolu_") and "function" in input: + arguments = str(input["function"]["arguments"]) + else: + arguments = json.dumps(content_part.input, indent=2) tool_calls = [ ToolCall( id=content_part.id, type="function", function=FunctionCall( name=content_part.name, - arguments=json.dumps(content_part.input, indent=2), + arguments=arguments, ), ) ] diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 2fd6446f5..f566908b6 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -64,3 +64,4 @@ class ToolRuleType(str, Enum): conditional = "conditional" constrain_child_tools = "constrain_child_tools" max_count_per_step = "max_count_per_step" + parent_last_tool = "parent_last_tool" diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 37158063a..4a658e2c2 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -29,6 +29,19 @@ class ChildToolRule(BaseToolRule): return set(self.children) if last_tool == self.tool_name else available_tools +class ParentToolRule(BaseToolRule): + """ + A ToolRule that only allows a child tool to be called if the parent has been called. + """ + + type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool + children: List[str] = Field(..., description="The children tools that can be invoked.") + + def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: + last_tool = tool_call_history[-1] if tool_call_history else None + return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children) + + class ConditionalToolRule(BaseToolRule): """ A ToolRule that conditionally maps to different child tools based on the output. @@ -128,6 +141,6 @@ class MaxCountPerStepToolRule(BaseToolRule): ToolRule = Annotated[ - Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule], + Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule], Field(discriminator="type"), ] diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d90dbba24..ffbef8507 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -43,6 +43,7 @@ from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.source import Source as PydanticSource from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule +from letta.schemas.tool_rule import ParentToolRule as PydanticParentToolRule from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser @@ -159,9 +160,9 @@ class AgentManager: tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name)) elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS: tool_rules.append(PydanticContinueToolRule(tool_name=tool_name)) - # we may want to add additional rules for sleeptime agents - # if agent_create.agent_type == AgentType.sleeptime_agent: - # tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"])) + + if agent_create.agent_type == AgentType.sleeptime_agent: + tool_rules.append(PydanticParentToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"])) # if custom rules, check tool rules are valid if agent_create.tool_rules: diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 11462cc15..4a5b825c5 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -9,7 +9,7 @@ from letta.orm.enums import JobType from letta.orm.errors import NoResultFound from letta.schemas.agent import CreateAgent from letta.schemas.block import CreateBlock -from letta.schemas.enums import JobStatus +from letta.schemas.enums import JobStatus, ToolRuleType from letta.schemas.group import ( DynamicManager, GroupCreate, @@ -507,6 +507,8 @@ async def test_sleeptime_group_chat(server, actor): assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools assert "core_memory_insert" in sleeptime_agent_tools + assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.parent_last_tool]) > 0 + # 5. Send messages and verify run ids message_text = [ "my favorite color is orange",