fix: Fix updating tools (#1886)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou 2024-10-15 16:51:18 -07:00 committed by GitHub
parent 508680f086
commit 174f31b32b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 349 additions and 50 deletions

View File

@ -96,6 +96,12 @@ class AbstractClient(object):
):
raise NotImplementedError
def add_tool_to_agent(self, agent_id: str, tool_id: str):
raise NotImplementedError
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
raise NotImplementedError
def rename_agent(self, agent_id: str, new_name: str):
raise NotImplementedError
@ -474,6 +480,39 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())
def add_tool_to_agent(self, agent_id: str, tool_id: str):
"""
Add tool to an existing agent
Args:
agent_id (str): ID of the agent
tool_id (str): A tool id
Returns:
agent_state (AgentState): State of the updated agent
"""
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/add-tool/{tool_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
"""
Removes tools from an existing agent
Args:
agent_id (str): ID of the agent
tool_id (str): The tool id
Returns:
agent_state (AgentState): State of the updated agent
"""
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/remove-tool/{tool_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())
def rename_agent(self, agent_id: str, new_name: str):
"""
Rename an agent
@ -1653,6 +1692,36 @@ class LocalClient(AbstractClient):
)
return agent_state
def add_tool_to_agent(self, agent_id: str, tool_id: str):
"""
Add tool to an existing agent
Args:
agent_id (str): ID of the agent
tool_id (str): A tool id
Returns:
agent_state (AgentState): State of the updated agent
"""
self.interface.clear()
agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
return agent_state
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
"""
Removes tools from an existing agent
Args:
agent_id (str): ID of the agent
tool_id (str): The tool id
Returns:
agent_state (AgentState): State of the updated agent
"""
self.interface.clear()
agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
return agent_state
def rename_agent(self, agent_id: str, new_name: str):
"""
Rename an agent
@ -2081,30 +2150,37 @@ class LocalClient(AbstractClient):
Returns:
None
"""
existing_tool_id = self.get_tool_id(tool.name)
if existing_tool_id:
if self.tool_with_name_and_user_id_exists(tool):
if update:
self.server.update_tool(
return self.server.update_tool(
ToolUpdate(
id=existing_tool_id,
id=tool.id,
description=tool.description,
source_type=tool.source_type,
source_code=tool.source_code,
tags=tool.tags,
json_schema=tool.json_schema,
name=tool.name,
)
),
self.user_id,
)
else:
raise ValueError(f"Tool with name {tool.name} already exists")
# call server function
return self.server.create_tool(
ToolCreate(
source_type=tool.source_type, source_code=tool.source_code, name=tool.name, json_schema=tool.json_schema, tags=tool.tags
),
user_id=self.user_id,
update=update,
)
raise ValueError(f"Tool with id={tool.id} and name={tool.name}already exists")
else:
# call server function
return self.server.create_tool(
ToolCreate(
id=tool.id,
description=tool.description,
source_type=tool.source_type,
source_code=tool.source_code,
name=tool.name,
json_schema=tool.json_schema,
tags=tool.tags,
),
user_id=self.user_id,
update=update,
)
# TODO: Use the above function `add_tool` here as there is duplicate logic
def create_tool(
@ -2170,7 +2246,9 @@ class LocalClient(AbstractClient):
source_type = "python"
return self.server.update_tool(ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name))
return self.server.update_tool(
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
)
def list_tools(self):
"""
@ -2215,7 +2293,17 @@ class LocalClient(AbstractClient):
"""
return self.server.get_tool_id(name, self.user_id)
# data sources
def tool_with_name_and_user_id_exists(self, tool: Tool) -> bool:
"""
Check if the tool with name and user_id exists
Args:
tool (Tool): the tool
Returns:
(bool): True if the id exists, False otherwise.
"""
return self.server.tool_with_name_and_user_id_exists(tool, self.user_id)
def load_data(self, connector: DataConnector, source_name: str):
"""

View File

@ -1,6 +1,9 @@
import re
from datetime import datetime
from typing import Optional
from IPython.display import HTML, display
from sqlalchemy.testing.plugin.plugin_base import warnings
from letta.local_llm.constants import (
ASSISTANT_MESSAGE_CLI_SYMBOL,
@ -64,3 +67,15 @@ def pprint(messages):
html_content += "</div>"
display(HTML(html_content))
def derive_function_name_regex(function_string: str) -> Optional[str]:
# Regular expression to match the function name
match = re.search(r"def\s+([a-zA-Z_]\w*)\s*\(", function_string)
if match:
function_name = match.group(1)
return function_name
else:
warnings.warn("No function name found.")
return None

View File

@ -577,7 +577,7 @@ class MetadataStore:
@enforce_types
def create_tool(self, tool: Tool):
with self.session_maker() as session:
if self.get_tool(tool_name=tool.name, user_id=tool.user_id) is not None:
if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None:
raise ValueError(f"Tool with name {tool.name} already exists")
session.add(ToolModel(**vars(tool)))
session.commit()
@ -620,9 +620,9 @@ class MetadataStore:
session.commit()
@enforce_types
def update_tool(self, tool: Tool):
def update_tool(self, tool_id: str, tool: Tool):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool.id).update(vars(tool))
session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool))
session.commit()
@enforce_types
@ -815,6 +815,15 @@ class MetadataStore:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
# assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"

View File

@ -128,3 +128,8 @@ class AgentStepResponse(BaseModel):
..., description="Whether the agent step ended because the in-context memory is near its limit."
)
usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.")
class RemoveToolsFromAgent(BaseModel):
agent_id: str = Field(..., description="The id of the agent.")
tool_ids: Optional[List[str]] = Field(None, description="The tools to be removed from the agent.")

View File

@ -176,7 +176,9 @@ class Tool(BaseTool):
class ToolCreate(BaseTool):
id: Optional[str] = Field(None, description="The unique identifier of the tool. If this is not provided, it will be autogenerated.")
name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).")
description: Optional[str] = Field(None, description="The description of the tool.")
tags: List[str] = Field([], description="Metadata tags.")
source_code: str = Field(..., description="The source code of the function.")
json_schema: Optional[Dict] = Field(
@ -187,6 +189,7 @@ class ToolCreate(BaseTool):
class ToolUpdate(ToolCreate):
id: str = Field(..., description="The unique identifier of the tool.")
description: Optional[str] = Field(None, description="The description of the tool.")
name: Optional[str] = Field(None, description="The name of the function.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
source_code: Optional[str] = Field(None, description="The source code of the function.")

View File

@ -100,6 +100,34 @@ def update_agent(
return server.update_agent(update_agent, user_id=actor.id)
@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
def add_tool_to_agent(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Add tools to an exsiting agent"""
actor = server.get_user_or_default(user_id=user_id)
update_agent.id = agent_id
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
def remove_tool_from_agent(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Add tools to an exsiting agent"""
actor = server.get_user_or_default(user_id=user_id)
update_agent.id = agent_id
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
def get_agent_state(
agent_id: str,

View File

@ -105,4 +105,4 @@ def update_tool(
"""
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
# actor = server.get_user_or_default(user_id=user_id)
return server.update_tool(request)
return server.update_tool(request, user_id)

View File

@ -16,6 +16,7 @@ import letta.system as system
from letta.agent import Agent, save_agent
from letta.agent_store.db import attach_base
from letta.agent_store.storage import StorageConnector, TableType
from letta.client.utils import derive_function_name_regex
from letta.credentials import LettaCredentials
from letta.data_sources.connectors import DataConnector, load_data
@ -965,6 +966,80 @@ class SyncServer(Server):
# TODO: probably reload the agent somehow?
return letta_agent.agent_state
def add_tool_to_agent(
self,
agent_id: str,
tool_id: str,
user_id: str,
):
"""Update the agents core memory block, return the new state"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Get the agent object (loaded in memory)
letta_agent = self._get_or_load_agent(agent_id=agent_id)
# Get all the tool objects from the request
tool_objs = []
tool_obj = self.ms.get_tool(tool_id=tool_id, user_id=user_id)
assert tool_obj, f"Tool with id={tool_id} does not exist"
tool_objs.append(tool_obj)
for tool in letta_agent.tools:
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
assert tool_obj, f"Tool with id={tool.id} does not exist"
# If it's not the already added tool
if tool_obj.id != tool_id:
tool_objs.append(tool_obj)
# replace the list of tool names ("ids") inside the agent state
letta_agent.agent_state.tools = [tool.name for tool in tool_objs]
# then attempt to link the tools modules
letta_agent.link_tools(tool_objs)
# save the agent
save_agent(letta_agent, self.ms)
return letta_agent.agent_state
def remove_tool_from_agent(
self,
agent_id: str,
tool_id: str,
user_id: str,
):
"""Update the agents core memory block, return the new state"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Get the agent object (loaded in memory)
letta_agent = self._get_or_load_agent(agent_id=agent_id)
# Get all the tool_objs
tool_objs = []
for tool in letta_agent.tools:
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
assert tool_obj, f"Tool with id={tool.id} does not exist"
# If it's not the tool we want to remove
if tool_obj.id != tool_id:
tool_objs.append(tool_obj)
# replace the list of tool names ("ids") inside the agent state
letta_agent.agent_state.tools = [tool.name for tool in tool_objs]
# then attempt to link the tools modules
letta_agent.link_tools(tool_objs)
# save the agent
save_agent(letta_agent, self.ms)
return letta_agent.agent_state
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
"""Convert AgentState to a dict for a JSON response"""
assert agent_state is not None
@ -1751,6 +1826,15 @@ class SyncServer(Server):
"""Get tool by ID."""
return self.ms.get_tool(tool_id=tool_id)
def tool_with_name_and_user_id_exists(self, tool: Tool, user_id: Optional[str] = None) -> bool:
"""Check if tool exists"""
tool = self.ms.get_tool_with_name_and_user_id(tool_name=tool.name, user_id=user_id)
if tool is None:
return False
else:
return True
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
"""Get tool ID from name and user_id."""
tool = self.ms.get_tool(tool_name=name, user_id=user_id)
@ -1758,16 +1842,27 @@ class SyncServer(Server):
return None
return tool.id
def update_tool(
self,
request: ToolUpdate,
) -> Tool:
def update_tool(self, request: ToolUpdate, user_id: Optional[str] = None) -> Tool:
"""Update an existing tool"""
existing_tool = self.ms.get_tool(tool_id=request.id)
if not existing_tool:
raise ValueError(f"Tool does not exist")
if request.name:
existing_tool = self.ms.get_tool_with_name_and_user_id(tool_name=request.name, user_id=user_id)
if existing_tool is None:
raise ValueError(f"Tool with name={request.name}, user_id={user_id} does not exist")
else:
existing_tool = self.ms.get_tool(tool_id=request.id)
if existing_tool is None:
raise ValueError(f"Tool with id={request.id} does not exist")
# Preserve the original tool id
# As we can override the tool id as well
# This is probably bad design if this is exposed to users...
original_id = existing_tool.id
# override updated fields
if request.id:
existing_tool.id = request.id
if request.description:
existing_tool.description = request.description
if request.source_code:
existing_tool.source_code = request.source_code
if request.source_type:
@ -1776,10 +1871,15 @@ class SyncServer(Server):
existing_tool.tags = request.tags
if request.json_schema:
existing_tool.json_schema = request.json_schema
# If name is explicitly provided here, overide the tool name
if request.name:
existing_tool.name = request.name
# Otherwise, if there's no name, and there's source code, we try to derive the name
elif request.source_code:
existing_tool.name = derive_function_name_regex(request.source_code)
self.ms.update_tool(existing_tool)
self.ms.update_tool(original_id, existing_tool)
return self.ms.get_tool(tool_id=request.id)
def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields
@ -1817,15 +1917,23 @@ class SyncServer(Server):
assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen."
# check if already exists:
existing_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id)
existing_tool = self.ms.get_tool(tool_id=request.id, tool_name=request.name, user_id=user_id)
if existing_tool:
if update:
updated_tool = self.update_tool(ToolUpdate(id=existing_tool.id, **vars(request)))
# id is an optional field, so we will fill it with the existing tool id
if not request.id:
request.id = existing_tool.id
updated_tool = self.update_tool(ToolUpdate(**vars(request)), user_id)
assert updated_tool is not None, f"Failed to update tool {request.name}"
return updated_tool
else:
raise ValueError(f"Tool {request.name} already exists and update=False")
# check for description
description = None
if request.description:
description = request.description
tool = Tool(
name=request.name,
source_code=request.source_code,
@ -1833,9 +1941,14 @@ class SyncServer(Server):
tags=request.tags,
json_schema=json_schema,
user_id=user_id,
description=description,
)
if request.id:
tool.id = request.id
self.ms.create_tool(tool)
created_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id)
created_tool = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
return created_tool
def delete_tool(self, tool_id: str):

View File

@ -1,3 +1,4 @@
import uuid
from typing import Union
import pytest
@ -9,6 +10,7 @@ from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
from letta.schemas.tool import Tool
@pytest.fixture(scope="module")
@ -16,12 +18,17 @@ def client():
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@pytest.fixture(scope="module")
def agent(client):
agent_state = client.create_agent(name="test_agent")
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "test_new_client_test_agent"))
agent_state = client.create_agent(name=agent_uuid)
yield agent_state
client.delete_agent(agent_state.id)
@ -114,6 +121,52 @@ def test_agent(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_state_test.id)
def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
# Create and add two tools to the client
# tool 1
from composio_langchain import Action
github_tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
client.add_tool(github_tool)
# tool 2
from crewai_tools import ScrapeWebsiteTool
scrape_website_tool = Tool.from_crewai(ScrapeWebsiteTool(website_url="https://www.example.com"))
client.add_tool(scrape_website_tool)
# assert both got added
tools = client.list_tools()
assert github_tool.id in [t.id for t in tools]
assert scrape_website_tool.id in [t.id for t in tools]
# Assert that all combinations of tool_names, tool_user_ids are unique
combinations = [(t.name, t.user_id) for t in tools]
assert len(combinations) == len(set(combinations))
# create agent
agent_state = agent
curr_num_tools = len(agent_state.tools)
# add both tools to agent in steps
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=github_tool.id)
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=scrape_website_tool.id)
# confirm that both tools are in the agent state
curr_tools = agent_state.tools
assert len(curr_tools) == curr_num_tools + 2
assert github_tool.name in curr_tools
assert scrape_website_tool.name in curr_tools
# remove only the github tool
agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id)
# confirm that only one tool left
curr_tools = agent_state.tools
assert len(curr_tools) == curr_num_tools + 1
assert github_tool.name not in curr_tools
assert scrape_website_tool.name in curr_tools
def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
@ -242,8 +295,7 @@ def test_tools(client: Union[LocalClient, RESTClient]):
print(msg)
# create tool
len(client.list_tools())
tool = client.create_tool(print_tool, tags=["extras"])
tool = client.create_tool(func=print_tool, tags=["extras"])
# list tools
tools = client.list_tools()
@ -258,19 +310,13 @@ def test_tools(client: Union[LocalClient, RESTClient]):
assert client.get_tool(tool.id).tags == extras2
# update tool: source code
client.update_tool(tool.id, name="print_tool2", func=print_tool2)
client.update_tool(tool.id, func=print_tool2)
assert client.get_tool(tool.id).name == "print_tool2"
## delete tool
# client.delete_tool(tool.id)
# assert len(client.list_tools()) == orig_tool_length
def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
from composio_langchain import Action
from letta.schemas.tool import Tool
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
@ -292,8 +338,6 @@ def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
from crewai_tools import ScrapeWebsiteTool
from letta.schemas.tool import Tool
crewai_tool = ScrapeWebsiteTool()
# Translate to memGPT Tool
@ -329,8 +373,6 @@ def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
from crewai_tools import ScrapeWebsiteTool
from letta.schemas.tool import Tool
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
# Translate to memGPT Tool
@ -363,8 +405,6 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from letta.schemas.tool import Tool
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
@ -397,8 +437,6 @@ def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, REST
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from letta.schemas.tool import Tool
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)