chore: Catch orm specific issues in agents v1 routes and simplify tool add/remove from agent (#2259)

This commit is contained in:
Matthew Zhou 2024-12-16 11:01:47 -08:00 committed by GitHub
parent 12d25a3c3e
commit f1e125d360
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 173 additions and 46 deletions

View File

@ -14,6 +14,8 @@ from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.letta_response import LettaResponse
from letta.server.constants import REST_DEFAULT_PORT
@ -45,6 +47,7 @@ from letta.settings import settings
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
interface: StreamingServerInterface = StreamingServerInterface
server = SyncServer(default_interface_factory=lambda: interface())
logger = get_logger(__name__)
# TODO: remove
password = None
@ -170,6 +173,16 @@ def create_application() -> "FastAPI":
},
)
@app.exception_handler(NoResultFound)
async def no_result_found_handler(request: Request, exc: NoResultFound):
logger.error(f"NoResultFound request: {request}")
logger.error(f"NoResultFound: {exc}")
return JSONResponse(
status_code=404,
content={"detail": str(exc)},
)
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
return JSONResponse(status_code=400, content={"detail": str(exc)})

View File

@ -19,7 +19,6 @@ from letta.agent import Agent, save_agent
from letta.chat_only_agent import ChatOnlyAgent
from letta.credentials import LettaCredentials
from letta.data_sources.connectors import DataConnector, load_data
from letta.errors import LettaAgentNotFoundError
# TODO use custom interface
from letta.interface import AgentInterface # abstract
@ -399,9 +398,6 @@ class SyncServer(Server):
with agent_lock:
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state is None:
raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist")
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor)
@ -901,32 +897,14 @@ class SyncServer(Server):
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
# TODO: This is very redundant, and should probably be simplified
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
letta_agent.link_tools(agent_state.tools)
# Get all the tool objects from the request
tool_objs = []
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
assert tool_obj, f"Tool with id={tool_id} does not exist"
tool_objs.append(tool_obj)
for tool in letta_agent.agent_state.tools:
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
assert tool_obj, f"Tool with id={tool.id} does not exist"
# If it's not the already added tool
if tool_obj.id != tool_id:
tool_objs.append(tool_obj)
# replace the list of tool names ("ids") inside the agent state
letta_agent.agent_state.tools = tool_objs
# then attempt to link the tools modules
letta_agent.link_tools(tool_objs)
# save the agent
save_agent(letta_agent)
return letta_agent.agent_state
return agent_state
def remove_tool_from_agent(
self,
@ -937,29 +915,13 @@ class SyncServer(Server):
"""Remove tools from an existing agent"""
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
letta_agent.link_tools(agent_state.tools)
# Get all the tool_objs
tool_objs = []
for tool in letta_agent.agent_state.tools:
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
assert tool_obj, f"Tool with id={tool.id} does not exist"
# If it's not the tool we want to remove
if tool_obj.id != tool_id:
tool_objs.append(tool_obj)
# replace the list of tool names ("ids") inside the agent state
letta_agent.agent_state.tools = tool_objs
# then attempt to link the tools modules
letta_agent.link_tools(tool_objs)
# save the agent
save_agent(letta_agent)
return letta_agent.agent_state
return agent_state
# convert name->id

View File

@ -1,6 +1,7 @@
from typing import Dict, List, Optional
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
@ -25,6 +26,8 @@ from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.utils import enforce_types
logger = get_logger(__name__)
# Agent Manager Class
class AgentManager:
@ -403,3 +406,74 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
# ======================================================================================================================
# Tool Management
# ======================================================================================================================
@enforce_types
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a tool to an agent.
Args:
agent_id: ID of the agent to attach the tool to.
tool_id: ID of the tool to attach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
with self.session_maker() as session:
# Verify the agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Use the _process_relationship helper to attach the tool
_process_relationship(
session=session,
agent=agent,
relationship_name="tools",
model_class=ToolModel,
item_ids=[tool_id],
allow_partial=False, # Ensure the tool exists
replace=False, # Extend the existing tools
)
# Commit and refresh the agent
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a tool from an agent.
Args:
agent_id: ID of the agent to detach the tool from.
tool_id: ID of the tool to detach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
with self.session_maker() as session:
# Verify the agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Filter out the tool to be detached
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
# Update the tools relationship
agent.tools = remaining_tools
# Commit and refresh the agent
agent.update(session, actor=actor)
return agent.to_pydantic()

View File

@ -438,6 +438,82 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
assert updated_agent.message_ids == update_agent_request.message_ids
# ======================================================================================================================
# AgentManager Tests - Tools Relationship
# ======================================================================================================================
def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_user):
"""Test attaching a tool to an agent."""
# Attach the tool
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Verify attachment through get_agent_by_id
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
assert print_tool.id in [t.id for t in agent.tools]
# Verify that attaching the same tool again doesn't cause duplication
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
assert len([t for t in agent.tools if t.id == print_tool.id]) == 1
def test_detach_tool(server: SyncServer, sarah_agent, print_tool, default_user):
"""Test detaching a tool from an agent."""
# Attach the tool first
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Verify it's attached
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
assert print_tool.id in [t.id for t in agent.tools]
# Detach the tool
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Verify it's detached
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
assert print_tool.id not in [t.id for t in agent.tools]
# Verify that detaching an already detached tool doesn't cause issues
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
"""Test attaching a tool to a nonexistent agent."""
with pytest.raises(NoResultFound):
server.agent_manager.attach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
"""Test attaching a nonexistent tool to an agent."""
with pytest.raises(NoResultFound):
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
"""Test detaching a tool from a nonexistent agent."""
with pytest.raises(NoResultFound):
server.agent_manager.detach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user):
"""Test listing tools attached to an agent."""
# Initially should have no tools
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
assert len(agent.tools) == 0
# Attach tools
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user)
# List tools and verify
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
attached_tool_ids = [t.id for t in agent.tools]
assert len(attached_tool_ids) == 2
assert print_tool.id in attached_tool_ids
assert other_tool.id in attached_tool_ids
# ======================================================================================================================
# AgentManager Tests - Sources Relationship
# ======================================================================================================================
@ -693,6 +769,7 @@ def test_attach_block(server: SyncServer, sarah_agent, default_block, default_us
assert agent.memory.blocks[0].label == default_block.label
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_attach_block_duplicate_label(server: SyncServer, sarah_agent, default_block, other_block, default_user):
"""Test attempting to attach a block with a duplicate label."""
# Set up both blocks with same label
@ -1143,6 +1220,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
assert print_tool.organization_id == default_organization.id
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
data = print_tool.model_dump(exclude=["id"])
tool = PydanticTool(**data)