mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
fix: Fix updating tools (#1886)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
parent
508680f086
commit
174f31b32b
@ -96,6 +96,12 @@ class AbstractClient(object):
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -474,6 +480,39 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to update agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
"""
|
||||
Add tool to an existing agent
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
tool_id (str): A tool id
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/add-tool/{tool_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
"""
|
||||
Removes tools from an existing agent
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
tool_id (str): The tool id
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/remove-tool/{tool_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
"""
|
||||
Rename an agent
|
||||
@ -1653,6 +1692,36 @@ class LocalClient(AbstractClient):
|
||||
)
|
||||
return agent_state
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
"""
|
||||
Add tool to an existing agent
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
tool_id (str): A tool id
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
|
||||
return agent_state
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
"""
|
||||
Removes tools from an existing agent
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
tool_id (str): The tool id
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
|
||||
return agent_state
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
"""
|
||||
Rename an agent
|
||||
@ -2081,30 +2150,37 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
existing_tool_id = self.get_tool_id(tool.name)
|
||||
if existing_tool_id:
|
||||
if self.tool_with_name_and_user_id_exists(tool):
|
||||
if update:
|
||||
self.server.update_tool(
|
||||
return self.server.update_tool(
|
||||
ToolUpdate(
|
||||
id=existing_tool_id,
|
||||
id=tool.id,
|
||||
description=tool.description,
|
||||
source_type=tool.source_type,
|
||||
source_code=tool.source_code,
|
||||
tags=tool.tags,
|
||||
json_schema=tool.json_schema,
|
||||
name=tool.name,
|
||||
)
|
||||
),
|
||||
self.user_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Tool with name {tool.name} already exists")
|
||||
|
||||
# call server function
|
||||
return self.server.create_tool(
|
||||
ToolCreate(
|
||||
source_type=tool.source_type, source_code=tool.source_code, name=tool.name, json_schema=tool.json_schema, tags=tool.tags
|
||||
),
|
||||
user_id=self.user_id,
|
||||
update=update,
|
||||
)
|
||||
raise ValueError(f"Tool with id={tool.id} and name={tool.name}already exists")
|
||||
else:
|
||||
# call server function
|
||||
return self.server.create_tool(
|
||||
ToolCreate(
|
||||
id=tool.id,
|
||||
description=tool.description,
|
||||
source_type=tool.source_type,
|
||||
source_code=tool.source_code,
|
||||
name=tool.name,
|
||||
json_schema=tool.json_schema,
|
||||
tags=tool.tags,
|
||||
),
|
||||
user_id=self.user_id,
|
||||
update=update,
|
||||
)
|
||||
|
||||
# TODO: Use the above function `add_tool` here as there is duplicate logic
|
||||
def create_tool(
|
||||
@ -2170,7 +2246,9 @@ class LocalClient(AbstractClient):
|
||||
|
||||
source_type = "python"
|
||||
|
||||
return self.server.update_tool(ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name))
|
||||
return self.server.update_tool(
|
||||
ToolUpdate(id=id, source_type=source_type, source_code=source_code, tags=tags, name=name), self.user_id
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
"""
|
||||
@ -2215,7 +2293,17 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
return self.server.get_tool_id(name, self.user_id)
|
||||
|
||||
# data sources
|
||||
def tool_with_name_and_user_id_exists(self, tool: Tool) -> bool:
|
||||
"""
|
||||
Check if the tool with name and user_id exists
|
||||
|
||||
Args:
|
||||
tool (Tool): the tool
|
||||
|
||||
Returns:
|
||||
(bool): True if the id exists, False otherwise.
|
||||
"""
|
||||
return self.server.tool_with_name_and_user_id_exists(tool, self.user_id)
|
||||
|
||||
def load_data(self, connector: DataConnector, source_name: str):
|
||||
"""
|
||||
|
@ -1,6 +1,9 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from IPython.display import HTML, display
|
||||
from sqlalchemy.testing.plugin.plugin_base import warnings
|
||||
|
||||
from letta.local_llm.constants import (
|
||||
ASSISTANT_MESSAGE_CLI_SYMBOL,
|
||||
@ -64,3 +67,15 @@ def pprint(messages):
|
||||
html_content += "</div>"
|
||||
|
||||
display(HTML(html_content))
|
||||
|
||||
|
||||
def derive_function_name_regex(function_string: str) -> Optional[str]:
|
||||
# Regular expression to match the function name
|
||||
match = re.search(r"def\s+([a-zA-Z_]\w*)\s*\(", function_string)
|
||||
|
||||
if match:
|
||||
function_name = match.group(1)
|
||||
return function_name
|
||||
else:
|
||||
warnings.warn("No function name found.")
|
||||
return None
|
||||
|
@ -577,7 +577,7 @@ class MetadataStore:
|
||||
@enforce_types
|
||||
def create_tool(self, tool: Tool):
|
||||
with self.session_maker() as session:
|
||||
if self.get_tool(tool_name=tool.name, user_id=tool.user_id) is not None:
|
||||
if self.get_tool(tool_id=tool.id, tool_name=tool.name, user_id=tool.user_id) is not None:
|
||||
raise ValueError(f"Tool with name {tool.name} already exists")
|
||||
session.add(ToolModel(**vars(tool)))
|
||||
session.commit()
|
||||
@ -620,9 +620,9 @@ class MetadataStore:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def update_tool(self, tool: Tool):
|
||||
def update_tool(self, tool_id: str, tool: Tool):
|
||||
with self.session_maker() as session:
|
||||
session.query(ToolModel).filter(ToolModel.id == tool.id).update(vars(tool))
|
||||
session.query(ToolModel).filter(ToolModel.id == tool_id).update(vars(tool))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
@ -815,6 +815,15 @@ class MetadataStore:
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
|
||||
if user_id:
|
||||
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
# assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_tool_with_name_and_user_id(self, tool_name: Optional[str] = None, user_id: Optional[str] = None) -> Optional[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
|
@ -128,3 +128,8 @@ class AgentStepResponse(BaseModel):
|
||||
..., description="Whether the agent step ended because the in-context memory is near its limit."
|
||||
)
|
||||
usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.")
|
||||
|
||||
|
||||
class RemoveToolsFromAgent(BaseModel):
|
||||
agent_id: str = Field(..., description="The id of the agent.")
|
||||
tool_ids: Optional[List[str]] = Field(None, description="The tools to be removed from the agent.")
|
||||
|
@ -176,7 +176,9 @@ class Tool(BaseTool):
|
||||
|
||||
|
||||
class ToolCreate(BaseTool):
|
||||
id: Optional[str] = Field(None, description="The unique identifier of the tool. If this is not provided, it will be autogenerated.")
|
||||
name: Optional[str] = Field(None, description="The name of the function (auto-generated from source_code if not provided).")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
json_schema: Optional[Dict] = Field(
|
||||
@ -187,6 +189,7 @@ class ToolCreate(BaseTool):
|
||||
|
||||
class ToolUpdate(ToolCreate):
|
||||
id: str = Field(..., description="The unique identifier of the tool.")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
name: Optional[str] = Field(None, description="The name of the function.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
source_code: Optional[str] = Field(None, description="The source code of the function.")
|
||||
|
@ -100,6 +100,34 @@ def update_agent(
|
||||
return server.update_agent(update_agent, user_id=actor.id)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
|
||||
def add_tool_to_agent(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Add tools to an exsiting agent"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
update_agent.id = agent_id
|
||||
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
|
||||
def remove_tool_from_agent(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Add tools to an exsiting agent"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
update_agent.id = agent_id
|
||||
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
|
||||
def get_agent_state(
|
||||
agent_id: str,
|
||||
|
@ -105,4 +105,4 @@ def update_tool(
|
||||
"""
|
||||
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
|
||||
# actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.update_tool(request)
|
||||
return server.update_tool(request, user_id)
|
||||
|
@ -16,6 +16,7 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.agent_store.db import attach_base
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.client.utils import derive_function_name_regex
|
||||
from letta.credentials import LettaCredentials
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
|
||||
@ -965,6 +966,80 @@ class SyncServer(Server):
|
||||
# TODO: probably reload the agent somehow?
|
||||
return letta_agent.agent_state
|
||||
|
||||
def add_tool_to_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
|
||||
# Get all the tool objects from the request
|
||||
tool_objs = []
|
||||
tool_obj = self.ms.get_tool(tool_id=tool_id, user_id=user_id)
|
||||
assert tool_obj, f"Tool with id={tool_id} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
for tool in letta_agent.tools:
|
||||
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
|
||||
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
||||
|
||||
# If it's not the already added tool
|
||||
if tool_obj.id != tool_id:
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
# replace the list of tool names ("ids") inside the agent state
|
||||
letta_agent.agent_state.tools = [tool.name for tool in tool_objs]
|
||||
|
||||
# then attempt to link the tools modules
|
||||
letta_agent.link_tools(tool_objs)
|
||||
|
||||
# save the agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
return letta_agent.agent_state
|
||||
|
||||
def remove_tool_from_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
|
||||
# Get all the tool_objs
|
||||
tool_objs = []
|
||||
for tool in letta_agent.tools:
|
||||
tool_obj = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
|
||||
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
||||
|
||||
# If it's not the tool we want to remove
|
||||
if tool_obj.id != tool_id:
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
# replace the list of tool names ("ids") inside the agent state
|
||||
letta_agent.agent_state.tools = [tool.name for tool in tool_objs]
|
||||
|
||||
# then attempt to link the tools modules
|
||||
letta_agent.link_tools(tool_objs)
|
||||
|
||||
# save the agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
return letta_agent.agent_state
|
||||
|
||||
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
|
||||
"""Convert AgentState to a dict for a JSON response"""
|
||||
assert agent_state is not None
|
||||
@ -1751,6 +1826,15 @@ class SyncServer(Server):
|
||||
"""Get tool by ID."""
|
||||
return self.ms.get_tool(tool_id=tool_id)
|
||||
|
||||
def tool_with_name_and_user_id_exists(self, tool: Tool, user_id: Optional[str] = None) -> bool:
|
||||
"""Check if tool exists"""
|
||||
tool = self.ms.get_tool_with_name_and_user_id(tool_name=tool.name, user_id=user_id)
|
||||
|
||||
if tool is None:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
|
||||
"""Get tool ID from name and user_id."""
|
||||
tool = self.ms.get_tool(tool_name=name, user_id=user_id)
|
||||
@ -1758,16 +1842,27 @@ class SyncServer(Server):
|
||||
return None
|
||||
return tool.id
|
||||
|
||||
def update_tool(
|
||||
self,
|
||||
request: ToolUpdate,
|
||||
) -> Tool:
|
||||
def update_tool(self, request: ToolUpdate, user_id: Optional[str] = None) -> Tool:
|
||||
"""Update an existing tool"""
|
||||
existing_tool = self.ms.get_tool(tool_id=request.id)
|
||||
if not existing_tool:
|
||||
raise ValueError(f"Tool does not exist")
|
||||
if request.name:
|
||||
existing_tool = self.ms.get_tool_with_name_and_user_id(tool_name=request.name, user_id=user_id)
|
||||
if existing_tool is None:
|
||||
raise ValueError(f"Tool with name={request.name}, user_id={user_id} does not exist")
|
||||
else:
|
||||
existing_tool = self.ms.get_tool(tool_id=request.id)
|
||||
if existing_tool is None:
|
||||
raise ValueError(f"Tool with id={request.id} does not exist")
|
||||
|
||||
# Preserve the original tool id
|
||||
# As we can override the tool id as well
|
||||
# This is probably bad design if this is exposed to users...
|
||||
original_id = existing_tool.id
|
||||
|
||||
# override updated fields
|
||||
if request.id:
|
||||
existing_tool.id = request.id
|
||||
if request.description:
|
||||
existing_tool.description = request.description
|
||||
if request.source_code:
|
||||
existing_tool.source_code = request.source_code
|
||||
if request.source_type:
|
||||
@ -1776,10 +1871,15 @@ class SyncServer(Server):
|
||||
existing_tool.tags = request.tags
|
||||
if request.json_schema:
|
||||
existing_tool.json_schema = request.json_schema
|
||||
|
||||
# If name is explicitly provided here, overide the tool name
|
||||
if request.name:
|
||||
existing_tool.name = request.name
|
||||
# Otherwise, if there's no name, and there's source code, we try to derive the name
|
||||
elif request.source_code:
|
||||
existing_tool.name = derive_function_name_regex(request.source_code)
|
||||
|
||||
self.ms.update_tool(existing_tool)
|
||||
self.ms.update_tool(original_id, existing_tool)
|
||||
return self.ms.get_tool(tool_id=request.id)
|
||||
|
||||
def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool: # TODO: add other fields
|
||||
@ -1817,15 +1917,23 @@ class SyncServer(Server):
|
||||
assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen."
|
||||
|
||||
# check if already exists:
|
||||
existing_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id)
|
||||
existing_tool = self.ms.get_tool(tool_id=request.id, tool_name=request.name, user_id=user_id)
|
||||
if existing_tool:
|
||||
if update:
|
||||
updated_tool = self.update_tool(ToolUpdate(id=existing_tool.id, **vars(request)))
|
||||
# id is an optional field, so we will fill it with the existing tool id
|
||||
if not request.id:
|
||||
request.id = existing_tool.id
|
||||
updated_tool = self.update_tool(ToolUpdate(**vars(request)), user_id)
|
||||
assert updated_tool is not None, f"Failed to update tool {request.name}"
|
||||
return updated_tool
|
||||
else:
|
||||
raise ValueError(f"Tool {request.name} already exists and update=False")
|
||||
|
||||
# check for description
|
||||
description = None
|
||||
if request.description:
|
||||
description = request.description
|
||||
|
||||
tool = Tool(
|
||||
name=request.name,
|
||||
source_code=request.source_code,
|
||||
@ -1833,9 +1941,14 @@ class SyncServer(Server):
|
||||
tags=request.tags,
|
||||
json_schema=json_schema,
|
||||
user_id=user_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
if request.id:
|
||||
tool.id = request.id
|
||||
|
||||
self.ms.create_tool(tool)
|
||||
created_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id)
|
||||
created_tool = self.ms.get_tool(tool_id=tool.id, user_id=user_id)
|
||||
return created_tool
|
||||
|
||||
def delete_tool(self, tool_id: str):
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
@ -9,6 +10,7 @@ from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -16,12 +18,17 @@ def client():
|
||||
client = create_client()
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client):
|
||||
agent_state = client.create_agent(name="test_agent")
|
||||
# Generate uuid for agent name for this example
|
||||
namespace = uuid.NAMESPACE_DNS
|
||||
agent_uuid = str(uuid.uuid5(namespace, "test_new_client_test_agent"))
|
||||
|
||||
agent_state = client.create_agent(name=agent_uuid)
|
||||
yield agent_state
|
||||
|
||||
client.delete_agent(agent_state.id)
|
||||
@ -114,6 +121,52 @@ def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_state_test.id)
|
||||
|
||||
|
||||
def test_agent_add_remove_tools(client: Union[LocalClient, RESTClient], agent):
|
||||
# Create and add two tools to the client
|
||||
# tool 1
|
||||
from composio_langchain import Action
|
||||
|
||||
github_tool = Tool.get_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
client.add_tool(github_tool)
|
||||
# tool 2
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
scrape_website_tool = Tool.from_crewai(ScrapeWebsiteTool(website_url="https://www.example.com"))
|
||||
client.add_tool(scrape_website_tool)
|
||||
|
||||
# assert both got added
|
||||
tools = client.list_tools()
|
||||
assert github_tool.id in [t.id for t in tools]
|
||||
assert scrape_website_tool.id in [t.id for t in tools]
|
||||
|
||||
# Assert that all combinations of tool_names, tool_user_ids are unique
|
||||
combinations = [(t.name, t.user_id) for t in tools]
|
||||
assert len(combinations) == len(set(combinations))
|
||||
|
||||
# create agent
|
||||
agent_state = agent
|
||||
curr_num_tools = len(agent_state.tools)
|
||||
|
||||
# add both tools to agent in steps
|
||||
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=github_tool.id)
|
||||
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=scrape_website_tool.id)
|
||||
|
||||
# confirm that both tools are in the agent state
|
||||
curr_tools = agent_state.tools
|
||||
assert len(curr_tools) == curr_num_tools + 2
|
||||
assert github_tool.name in curr_tools
|
||||
assert scrape_website_tool.name in curr_tools
|
||||
|
||||
# remove only the github tool
|
||||
agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id)
|
||||
|
||||
# confirm that only one tool left
|
||||
curr_tools = agent_state.tools
|
||||
assert len(curr_tools) == curr_num_tools + 1
|
||||
assert github_tool.name not in curr_tools
|
||||
assert scrape_website_tool.name in curr_tools
|
||||
|
||||
|
||||
def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
|
||||
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
|
||||
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
|
||||
@ -242,8 +295,7 @@ def test_tools(client: Union[LocalClient, RESTClient]):
|
||||
print(msg)
|
||||
|
||||
# create tool
|
||||
len(client.list_tools())
|
||||
tool = client.create_tool(print_tool, tags=["extras"])
|
||||
tool = client.create_tool(func=print_tool, tags=["extras"])
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
@ -258,19 +310,13 @@ def test_tools(client: Union[LocalClient, RESTClient]):
|
||||
assert client.get_tool(tool.id).tags == extras2
|
||||
|
||||
# update tool: source code
|
||||
client.update_tool(tool.id, name="print_tool2", func=print_tool2)
|
||||
client.update_tool(tool.id, func=print_tool2)
|
||||
assert client.get_tool(tool.id).name == "print_tool2"
|
||||
|
||||
## delete tool
|
||||
# client.delete_tool(tool.id)
|
||||
# assert len(client.list_tools()) == orig_tool_length
|
||||
|
||||
|
||||
def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
|
||||
from composio_langchain import Action
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
|
||||
client = create_client()
|
||||
|
||||
@ -292,8 +338,6 @@ def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
crewai_tool = ScrapeWebsiteTool()
|
||||
|
||||
# Translate to memGPT Tool
|
||||
@ -329,8 +373,6 @@ def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
crewai_tool = ScrapeWebsiteTool(website_url="https://www.example.com")
|
||||
|
||||
# Translate to memGPT Tool
|
||||
@ -363,8 +405,6 @@ def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
@ -397,8 +437,6 @@ def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, REST
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user