mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
chore: Catch orm specific issues in agents v1 routes and simplify tool add/remove from agent (#2259)
This commit is contained in:
parent
12d25a3c3e
commit
f1e125d360
@ -14,6 +14,8 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
from letta.__init__ import __version__
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
|
||||
@ -45,6 +47,7 @@ from letta.settings import settings
|
||||
# NOTE(charles): @ethan I had to add this to get the global as the bottom to work
|
||||
interface: StreamingServerInterface = StreamingServerInterface
|
||||
server = SyncServer(default_interface_factory=lambda: interface())
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# TODO: remove
|
||||
password = None
|
||||
@ -170,6 +173,16 @@ def create_application() -> "FastAPI":
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(NoResultFound)
|
||||
async def no_result_found_handler(request: Request, exc: NoResultFound):
|
||||
logger.error(f"NoResultFound request: {request}")
|
||||
logger.error(f"NoResultFound: {exc}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": str(exc)},
|
||||
)
|
||||
|
||||
@app.exception_handler(ValueError)
|
||||
async def value_error_handler(request: Request, exc: ValueError):
|
||||
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||
|
@ -19,7 +19,6 @@ from letta.agent import Agent, save_agent
|
||||
from letta.chat_only_agent import ChatOnlyAgent
|
||||
from letta.credentials import LettaCredentials
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.errors import LettaAgentNotFoundError
|
||||
|
||||
# TODO use custom interface
|
||||
from letta.interface import AgentInterface # abstract
|
||||
@ -399,9 +398,6 @@ class SyncServer(Server):
|
||||
with agent_lock:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
|
||||
if agent_state is None:
|
||||
raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist")
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
agent = Agent(agent_state=agent_state, interface=interface, user=actor)
|
||||
@ -901,32 +897,14 @@ class SyncServer(Server):
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
# TODO: This is very redundant, and should probably be simplified
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
letta_agent.link_tools(agent_state.tools)
|
||||
|
||||
# Get all the tool objects from the request
|
||||
tool_objs = []
|
||||
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
assert tool_obj, f"Tool with id={tool_id} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
for tool in letta_agent.agent_state.tools:
|
||||
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
|
||||
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
||||
|
||||
# If it's not the already added tool
|
||||
if tool_obj.id != tool_id:
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
# replace the list of tool names ("ids") inside the agent state
|
||||
letta_agent.agent_state.tools = tool_objs
|
||||
|
||||
# then attempt to link the tools modules
|
||||
letta_agent.link_tools(tool_objs)
|
||||
|
||||
# save the agent
|
||||
save_agent(letta_agent)
|
||||
return letta_agent.agent_state
|
||||
return agent_state
|
||||
|
||||
def remove_tool_from_agent(
|
||||
self,
|
||||
@ -937,29 +915,13 @@ class SyncServer(Server):
|
||||
"""Remove tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
letta_agent.link_tools(agent_state.tools)
|
||||
|
||||
# Get all the tool_objs
|
||||
tool_objs = []
|
||||
for tool in letta_agent.agent_state.tools:
|
||||
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
|
||||
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
||||
|
||||
# If it's not the tool we want to remove
|
||||
if tool_obj.id != tool_id:
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
# replace the list of tool names ("ids") inside the agent state
|
||||
letta_agent.agent_state.tools = tool_objs
|
||||
|
||||
# then attempt to link the tools modules
|
||||
letta_agent.link_tools(tool_objs)
|
||||
|
||||
# save the agent
|
||||
save_agent(letta_agent)
|
||||
return letta_agent.agent_state
|
||||
return agent_state
|
||||
|
||||
# convert name->id
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Source as SourceModel
|
||||
@ -25,6 +26,8 @@ from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Agent Manager Class
|
||||
class AgentManager:
|
||||
@ -403,3 +406,74 @@ class AgentManager:
|
||||
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tool Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the tool to.
|
||||
tool_id: ID of the tool to attach.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or tool is not found.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Use the _process_relationship helper to attach the tool
|
||||
_process_relationship(
|
||||
session=session,
|
||||
agent=agent,
|
||||
relationship_name="tools",
|
||||
model_class=ToolModel,
|
||||
item_ids=[tool_id],
|
||||
allow_partial=False, # Ensure the tool exists
|
||||
replace=False, # Extend the existing tools
|
||||
)
|
||||
|
||||
# Commit and refresh the agent
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the tool from.
|
||||
tool_id: ID of the tool to detach.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or tool is not found.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Filter out the tool to be detached
|
||||
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
|
||||
|
||||
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
|
||||
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
|
||||
|
||||
# Update the tools relationship
|
||||
agent.tools = remaining_tools
|
||||
|
||||
# Commit and refresh the agent
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
@ -438,6 +438,82 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
assert updated_agent.message_ids == update_agent_request.message_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Tools Relationship
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_user):
|
||||
"""Test attaching a tool to an agent."""
|
||||
# Attach the tool
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Verify attachment through get_agent_by_id
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
# Verify that attaching the same tool again doesn't cause duplication
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
assert len([t for t in agent.tools if t.id == print_tool.id]) == 1
|
||||
|
||||
|
||||
def test_detach_tool(server: SyncServer, sarah_agent, print_tool, default_user):
|
||||
"""Test detaching a tool from an agent."""
|
||||
# Attach the tool first
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Verify it's attached
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
# Detach the tool
|
||||
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Verify it's detached
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id not in [t.id for t in agent.tools]
|
||||
|
||||
# Verify that detaching an already detached tool doesn't cause issues
|
||||
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
"""Test attaching a tool to a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.attach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test attaching a nonexistent tool to an agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
|
||||
|
||||
|
||||
def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
"""Test detaching a tool from a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.detach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user):
|
||||
"""Test listing tools attached to an agent."""
|
||||
# Initially should have no tools
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
assert len(agent.tools) == 0
|
||||
|
||||
# Attach tools
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user)
|
||||
|
||||
# List tools and verify
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
attached_tool_ids = [t.id for t in agent.tools]
|
||||
assert len(attached_tool_ids) == 2
|
||||
assert print_tool.id in attached_tool_ids
|
||||
assert other_tool.id in attached_tool_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Sources Relationship
|
||||
# ======================================================================================================================
|
||||
@ -693,6 +769,7 @@ def test_attach_block(server: SyncServer, sarah_agent, default_block, default_us
|
||||
assert agent.memory.blocks[0].label == default_block.label
|
||||
|
||||
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||
def test_attach_block_duplicate_label(server: SyncServer, sarah_agent, default_block, other_block, default_user):
|
||||
"""Test attempting to attach a block with a duplicate label."""
|
||||
# Set up both blocks with same label
|
||||
@ -1143,6 +1220,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
|
||||
assert print_tool.organization_id == default_organization.id
|
||||
|
||||
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||
data = print_tool.model_dump(exclude=["id"])
|
||||
tool = PydanticTool(**data)
|
||||
|
Loading…
Reference in New Issue
Block a user