mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
Add parent tool rule (#1648)
This commit is contained in:
parent
c0e1f793cf
commit
3713393966
@ -226,7 +226,7 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m
|
||||
if line_number is None:
|
||||
line_number = len(current_value_list)
|
||||
if replace:
|
||||
current_value_list[line_number] = new_memory
|
||||
current_value_list[line_number - 1] = new_memory
|
||||
else:
|
||||
current_value_list.insert(line_number, new_memory)
|
||||
new_value = "\n".join(current_value_list)
|
||||
|
@ -28,6 +28,7 @@ from letta.schemas.tool_rule import (
|
||||
ContinueToolRule,
|
||||
InitToolRule,
|
||||
MaxCountPerStepToolRule,
|
||||
ParentToolRule,
|
||||
TerminalToolRule,
|
||||
ToolRule,
|
||||
)
|
||||
@ -89,7 +90,7 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
|
||||
return data
|
||||
|
||||
|
||||
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
|
||||
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[ToolRule]:
|
||||
"""Convert a list of dictionaries back into ToolRule objects."""
|
||||
if not data:
|
||||
return []
|
||||
@ -99,7 +100,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
|
||||
|
||||
def deserialize_tool_rule(
|
||||
data: Dict,
|
||||
) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]:
|
||||
) -> ToolRule:
|
||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
|
||||
rule_type = ToolRuleType(data.get("type"))
|
||||
|
||||
@ -118,6 +119,8 @@ def deserialize_tool_rule(
|
||||
return ContinueToolRule(**data)
|
||||
elif rule_type == ToolRuleType.max_count_per_step:
|
||||
return MaxCountPerStepToolRule(**data)
|
||||
elif rule_type == ToolRuleType.parent_last_tool:
|
||||
return ParentToolRule(**data)
|
||||
raise ValueError(f"Unknown ToolRule type: {rule_type}")
|
||||
|
||||
|
||||
|
@ -10,6 +10,7 @@ from letta.schemas.tool_rule import (
|
||||
ContinueToolRule,
|
||||
InitToolRule,
|
||||
MaxCountPerStepToolRule,
|
||||
ParentToolRule,
|
||||
TerminalToolRule,
|
||||
)
|
||||
|
||||
@ -33,6 +34,9 @@ class ToolRulesSolver(BaseModel):
|
||||
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
|
||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
||||
)
|
||||
parent_tool_rules: List[ParentToolRule] = Field(
|
||||
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
|
||||
)
|
||||
terminal_tool_rules: List[TerminalToolRule] = Field(
|
||||
default_factory=list, description="Terminal tool rules that end the agent loop if called."
|
||||
)
|
||||
@ -44,6 +48,7 @@ class ToolRulesSolver(BaseModel):
|
||||
init_tool_rules: Optional[List[InitToolRule]] = None,
|
||||
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
|
||||
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
|
||||
parent_tool_rules: Optional[List[ParentToolRule]] = None,
|
||||
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
|
||||
tool_call_history: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
@ -52,6 +57,7 @@ class ToolRulesSolver(BaseModel):
|
||||
init_tool_rules=init_tool_rules or [],
|
||||
continue_tool_rules=continue_tool_rules or [],
|
||||
child_based_tool_rules=child_based_tool_rules or [],
|
||||
parent_tool_rules=parent_tool_rules or [],
|
||||
terminal_tool_rules=terminal_tool_rules or [],
|
||||
tool_call_history=tool_call_history or [],
|
||||
**kwargs,
|
||||
@ -78,6 +84,9 @@ class ToolRulesSolver(BaseModel):
|
||||
elif rule.type == ToolRuleType.max_count_per_step:
|
||||
assert isinstance(rule, MaxCountPerStepToolRule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.parent_last_tool:
|
||||
assert isinstance(rule, ParentToolRule)
|
||||
self.parent_tool_rules.append(rule)
|
||||
|
||||
def register_tool_call(self, tool_name: str):
|
||||
"""Update the internal state to track tool call history."""
|
||||
@ -102,13 +111,14 @@ class ToolRulesSolver(BaseModel):
|
||||
# If there are init tool rules, only return those defined in the init tool rules
|
||||
return [rule.tool_name for rule in self.init_tool_rules]
|
||||
else:
|
||||
# Otherwise, return all the available tools
|
||||
# Otherwise, return all tools besides those constrained by parent tool rules
|
||||
available_tools = available_tools - set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
|
||||
return list(available_tools)
|
||||
else:
|
||||
# Collect valid tools from all child-based rules
|
||||
valid_tool_sets = [
|
||||
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||
for rule in self.child_based_tool_rules
|
||||
for rule in self.child_based_tool_rules + self.parent_tool_rules
|
||||
]
|
||||
|
||||
# Compute intersection of all valid tool sets
|
||||
|
@ -347,13 +347,19 @@ class AnthropicClient(LLMClientBase):
|
||||
if content_part.type == "text":
|
||||
content = strip_xml_tags(string=content_part.text, tag="thinking")
|
||||
if content_part.type == "tool_use":
|
||||
# hack for tool rules
|
||||
input = json.loads(json.dumps(content_part.input))
|
||||
if "id" in input and input["id"].startswith("toolu_") and "function" in input:
|
||||
arguments = str(input["function"]["arguments"])
|
||||
else:
|
||||
arguments = json.dumps(content_part.input, indent=2)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=content_part.id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=content_part.name,
|
||||
arguments=json.dumps(content_part.input, indent=2),
|
||||
arguments=arguments,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
@ -64,3 +64,4 @@ class ToolRuleType(str, Enum):
|
||||
conditional = "conditional"
|
||||
constrain_child_tools = "constrain_child_tools"
|
||||
max_count_per_step = "max_count_per_step"
|
||||
parent_last_tool = "parent_last_tool"
|
||||
|
@ -29,6 +29,19 @@ class ChildToolRule(BaseToolRule):
|
||||
return set(self.children) if last_tool == self.tool_name else available_tools
|
||||
|
||||
|
||||
class ParentToolRule(BaseToolRule):
|
||||
"""
|
||||
A ToolRule that only allows a child tool to be called if the parent has been called.
|
||||
"""
|
||||
|
||||
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
|
||||
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||
last_tool = tool_call_history[-1] if tool_call_history else None
|
||||
return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
|
||||
|
||||
|
||||
class ConditionalToolRule(BaseToolRule):
|
||||
"""
|
||||
A ToolRule that conditionally maps to different child tools based on the output.
|
||||
@ -128,6 +141,6 @@ class MaxCountPerStepToolRule(BaseToolRule):
|
||||
|
||||
|
||||
ToolRule = Annotated[
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule],
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
@ -43,6 +43,7 @@ from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
|
||||
from letta.schemas.tool_rule import ParentToolRule as PydanticParentToolRule
|
||||
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@ -159,9 +160,9 @@ class AgentManager:
|
||||
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
|
||||
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
|
||||
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
|
||||
# we may want to add additional rules for sleeptime agents
|
||||
# if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||
# tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
|
||||
|
||||
if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||
tool_rules.append(PydanticParentToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
|
||||
|
||||
# if custom rules, check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
|
@ -9,7 +9,7 @@ from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.enums import JobStatus, ToolRuleType
|
||||
from letta.schemas.group import (
|
||||
DynamicManager,
|
||||
GroupCreate,
|
||||
@ -507,6 +507,8 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
assert "view_core_memory_with_line_numbers" in sleeptime_agent_tools
|
||||
assert "core_memory_insert" in sleeptime_agent_tools
|
||||
|
||||
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.parent_last_tool]) > 0
|
||||
|
||||
# 5. Send messages and verify run ids
|
||||
message_text = [
|
||||
"my favorite color is orange",
|
||||
|
Loading…
Reference in New Issue
Block a user