mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
Add parent tool rule (#1648)
This commit is contained in:
parent
c0e1f793cf
commit
3713393966
@ -226,7 +226,7 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m
|
|||||||
if line_number is None:
|
if line_number is None:
|
||||||
line_number = len(current_value_list)
|
line_number = len(current_value_list)
|
||||||
if replace:
|
if replace:
|
||||||
current_value_list[line_number] = new_memory
|
current_value_list[line_number - 1] = new_memory
|
||||||
else:
|
else:
|
||||||
current_value_list.insert(line_number, new_memory)
|
current_value_list.insert(line_number, new_memory)
|
||||||
new_value = "\n".join(current_value_list)
|
new_value = "\n".join(current_value_list)
|
||||||
|
@ -28,6 +28,7 @@ from letta.schemas.tool_rule import (
|
|||||||
ContinueToolRule,
|
ContinueToolRule,
|
||||||
InitToolRule,
|
InitToolRule,
|
||||||
MaxCountPerStepToolRule,
|
MaxCountPerStepToolRule,
|
||||||
|
ParentToolRule,
|
||||||
TerminalToolRule,
|
TerminalToolRule,
|
||||||
ToolRule,
|
ToolRule,
|
||||||
)
|
)
|
||||||
@ -89,7 +90,7 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
|
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[ToolRule]:
|
||||||
"""Convert a list of dictionaries back into ToolRule objects."""
|
"""Convert a list of dictionaries back into ToolRule objects."""
|
||||||
if not data:
|
if not data:
|
||||||
return []
|
return []
|
||||||
@ -99,7 +100,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
|
|||||||
|
|
||||||
def deserialize_tool_rule(
|
def deserialize_tool_rule(
|
||||||
data: Dict,
|
data: Dict,
|
||||||
) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]:
|
) -> ToolRule:
|
||||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
|
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
|
||||||
rule_type = ToolRuleType(data.get("type"))
|
rule_type = ToolRuleType(data.get("type"))
|
||||||
|
|
||||||
@ -118,6 +119,8 @@ def deserialize_tool_rule(
|
|||||||
return ContinueToolRule(**data)
|
return ContinueToolRule(**data)
|
||||||
elif rule_type == ToolRuleType.max_count_per_step:
|
elif rule_type == ToolRuleType.max_count_per_step:
|
||||||
return MaxCountPerStepToolRule(**data)
|
return MaxCountPerStepToolRule(**data)
|
||||||
|
elif rule_type == ToolRuleType.parent_last_tool:
|
||||||
|
return ParentToolRule(**data)
|
||||||
raise ValueError(f"Unknown ToolRule type: {rule_type}")
|
raise ValueError(f"Unknown ToolRule type: {rule_type}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from letta.schemas.tool_rule import (
|
|||||||
ContinueToolRule,
|
ContinueToolRule,
|
||||||
InitToolRule,
|
InitToolRule,
|
||||||
MaxCountPerStepToolRule,
|
MaxCountPerStepToolRule,
|
||||||
|
ParentToolRule,
|
||||||
TerminalToolRule,
|
TerminalToolRule,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,6 +34,9 @@ class ToolRulesSolver(BaseModel):
|
|||||||
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
|
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
|
||||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
||||||
)
|
)
|
||||||
|
parent_tool_rules: List[ParentToolRule] = Field(
|
||||||
|
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
|
||||||
|
)
|
||||||
terminal_tool_rules: List[TerminalToolRule] = Field(
|
terminal_tool_rules: List[TerminalToolRule] = Field(
|
||||||
default_factory=list, description="Terminal tool rules that end the agent loop if called."
|
default_factory=list, description="Terminal tool rules that end the agent loop if called."
|
||||||
)
|
)
|
||||||
@ -44,6 +48,7 @@ class ToolRulesSolver(BaseModel):
|
|||||||
init_tool_rules: Optional[List[InitToolRule]] = None,
|
init_tool_rules: Optional[List[InitToolRule]] = None,
|
||||||
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
|
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
|
||||||
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
|
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
|
||||||
|
parent_tool_rules: Optional[List[ParentToolRule]] = None,
|
||||||
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
|
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
|
||||||
tool_call_history: Optional[List[str]] = None,
|
tool_call_history: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -52,6 +57,7 @@ class ToolRulesSolver(BaseModel):
|
|||||||
init_tool_rules=init_tool_rules or [],
|
init_tool_rules=init_tool_rules or [],
|
||||||
continue_tool_rules=continue_tool_rules or [],
|
continue_tool_rules=continue_tool_rules or [],
|
||||||
child_based_tool_rules=child_based_tool_rules or [],
|
child_based_tool_rules=child_based_tool_rules or [],
|
||||||
|
parent_tool_rules=parent_tool_rules or [],
|
||||||
terminal_tool_rules=terminal_tool_rules or [],
|
terminal_tool_rules=terminal_tool_rules or [],
|
||||||
tool_call_history=tool_call_history or [],
|
tool_call_history=tool_call_history or [],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -78,6 +84,9 @@ class ToolRulesSolver(BaseModel):
|
|||||||
elif rule.type == ToolRuleType.max_count_per_step:
|
elif rule.type == ToolRuleType.max_count_per_step:
|
||||||
assert isinstance(rule, MaxCountPerStepToolRule)
|
assert isinstance(rule, MaxCountPerStepToolRule)
|
||||||
self.child_based_tool_rules.append(rule)
|
self.child_based_tool_rules.append(rule)
|
||||||
|
elif rule.type == ToolRuleType.parent_last_tool:
|
||||||
|
assert isinstance(rule, ParentToolRule)
|
||||||
|
self.parent_tool_rules.append(rule)
|
||||||
|
|
||||||
def register_tool_call(self, tool_name: str):
|
def register_tool_call(self, tool_name: str):
|
||||||
"""Update the internal state to track tool call history."""
|
"""Update the internal state to track tool call history."""
|
||||||
@ -102,13 +111,14 @@ class ToolRulesSolver(BaseModel):
|
|||||||
# If there are init tool rules, only return those defined in the init tool rules
|
# If there are init tool rules, only return those defined in the init tool rules
|
||||||
return [rule.tool_name for rule in self.init_tool_rules]
|
return [rule.tool_name for rule in self.init_tool_rules]
|
||||||
else:
|
else:
|
||||||
# Otherwise, return all the available tools
|
# Otherwise, return all tools besides those constrained by parent tool rules
|
||||||
|
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
|
||||||
return list(available_tools)
|
return list(available_tools)
|
||||||
else:
|
else:
|
||||||
# Collect valid tools from all child-based rules
|
# Collect valid tools from all child-based rules
|
||||||
valid_tool_sets = [
|
valid_tool_sets = [
|
||||||
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||||
for rule in self.child_based_tool_rules
|
for rule in self.child_based_tool_rules + self.parent_tool_rules
|
||||||
]
|
]
|
||||||
|
|
||||||
# Compute intersection of all valid tool sets
|
# Compute intersection of all valid tool sets
|
||||||
|
@ -347,13 +347,19 @@ class AnthropicClient(LLMClientBase):
|
|||||||
if content_part.type == "text":
|
if content_part.type == "text":
|
||||||
content = strip_xml_tags(string=content_part.text, tag="thinking")
|
content = strip_xml_tags(string=content_part.text, tag="thinking")
|
||||||
if content_part.type == "tool_use":
|
if content_part.type == "tool_use":
|
||||||
|
# hack for tool rules
|
||||||
|
input = json.loads(json.dumps(content_part.input))
|
||||||
|
if "id" in input and input["id"].startswith("toolu_") and "function" in input:
|
||||||
|
arguments = str(input["function"]["arguments"])
|
||||||
|
else:
|
||||||
|
arguments = json.dumps(content_part.input, indent=2)
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=content_part.id,
|
id=content_part.id,
|
||||||
type="function",
|
type="function",
|
||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=content_part.name,
|
name=content_part.name,
|
||||||
arguments=json.dumps(content_part.input, indent=2),
|
arguments=arguments,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@ -64,3 +64,4 @@ class ToolRuleType(str, Enum):
|
|||||||
conditional = "conditional"
|
conditional = "conditional"
|
||||||
constrain_child_tools = "constrain_child_tools"
|
constrain_child_tools = "constrain_child_tools"
|
||||||
max_count_per_step = "max_count_per_step"
|
max_count_per_step = "max_count_per_step"
|
||||||
|
parent_last_tool = "parent_last_tool"
|
||||||
|
@ -29,6 +29,19 @@ class ChildToolRule(BaseToolRule):
|
|||||||
return set(self.children) if last_tool == self.tool_name else available_tools
|
return set(self.children) if last_tool == self.tool_name else available_tools
|
||||||
|
|
||||||
|
|
||||||
|
class ParentToolRule(BaseToolRule):
|
||||||
|
"""
|
||||||
|
A ToolRule that only allows a child tool to be called if the parent has been called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
|
||||||
|
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||||
|
|
||||||
|
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||||
|
last_tool = tool_call_history[-1] if tool_call_history else None
|
||||||
|
return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
|
||||||
|
|
||||||
|
|
||||||
class ConditionalToolRule(BaseToolRule):
|
class ConditionalToolRule(BaseToolRule):
|
||||||
"""
|
"""
|
||||||
A ToolRule that conditionally maps to different child tools based on the output.
|
A ToolRule that conditionally maps to different child tools based on the output.
|
||||||
@ -128,6 +141,6 @@ class MaxCountPerStepToolRule(BaseToolRule):
|
|||||||
|
|
||||||
|
|
||||||
ToolRule = Annotated[
|
ToolRule = Annotated[
|
||||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule],
|
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
@ -43,6 +43,7 @@ from letta.schemas.passage import Passage as PydanticPassage
|
|||||||
from letta.schemas.source import Source as PydanticSource
|
from letta.schemas.source import Source as PydanticSource
|
||||||
from letta.schemas.tool import Tool as PydanticTool
|
from letta.schemas.tool import Tool as PydanticTool
|
||||||
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
|
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
|
||||||
|
from letta.schemas.tool_rule import ParentToolRule as PydanticParentToolRule
|
||||||
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
||||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
@ -159,9 +160,9 @@ class AgentManager:
|
|||||||
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
|
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
|
||||||
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
|
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
|
||||||
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
|
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
|
||||||
# we may want to add additional rules for sleeptime agents
|
|
||||||
# if agent_create.agent_type == AgentType.sleeptime_agent:
|
if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||||
# tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
|
tool_rules.append(PydanticParentToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
|
||||||
|
|
||||||
# if custom rules, check tool rules are valid
|
# if custom rules, check tool rules are valid
|
||||||
if agent_create.tool_rules:
|
if agent_create.tool_rules:
|
||||||
|
@ -9,7 +9,7 @@ from letta.orm.enums import JobType
|
|||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.schemas.agent import CreateAgent
|
from letta.schemas.agent import CreateAgent
|
||||||
from letta.schemas.block import CreateBlock
|
from letta.schemas.block import CreateBlock
|
||||||
from letta.schemas.enums import JobStatus
|
from letta.schemas.enums import JobStatus, ToolRuleType
|
||||||
from letta.schemas.group import (
|
from letta.schemas.group import (
|
||||||
DynamicManager,
|
DynamicManager,
|
||||||
GroupCreate,
|
GroupCreate,
|
||||||
@ -507,6 +507,8 @@ async def test_sleeptime_group_chat(server, actor):
|
|||||||
assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools
|
assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools
|
||||||
assert "core_memory_insert" in sleeptime_agent_tools
|
assert "core_memory_insert" in sleeptime_agent_tools
|
||||||
|
|
||||||
|
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.parent_last_tool]) > 0
|
||||||
|
|
||||||
# 5. Send messages and verify run ids
|
# 5. Send messages and verify run ids
|
||||||
message_text = [
|
message_text = [
|
||||||
"my favorite color is orange",
|
"my favorite color is orange",
|
||||||
|
Loading…
Reference in New Issue
Block a user