feat: support detaching sources from agents (#1791)

This commit is contained in:
Sarah Wooders 2024-09-25 15:18:46 -07:00 committed by GitHub
parent 14a159b09b
commit 10cb0c118b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 29 additions and 7 deletions

View File

@ -1110,6 +1110,7 @@ class RESTClient(AbstractClient):
params = {"agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/detach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return Source(**response.json())
# server configuration commands
@ -2158,7 +2159,16 @@ class LocalClient(AbstractClient):
self.server.attach_source_to_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
self.server.detach_source_from_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
"""
Detach a source from an agent by removing all `Passage` objects that were loaded from the source from archival memory.
Args:
agent_id (str): ID of the agent
source_id (str): ID of the source
source_name (str): Name of the source
Returns:
source (Source): Detached source
"""
return self.server.detach_source_from_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
def list_sources(self) -> List[Source]:
"""

View File

@ -114,7 +114,7 @@ def attach_source_to_agent(
return source
@router.post("/{source_id}/detach", response_model=None, operation_id="detach_agent_from_source")
@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
def detach_source_from_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
@ -125,7 +125,7 @@ def detach_source_from_agent(
"""
actor = server.get_current_user()
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")

View File

@ -1628,8 +1628,18 @@ class SyncServer(Server):
source_id: Optional[str] = None,
source_name: Optional[str] = None,
) -> Source:
# TODO: remove all passages coresponding to source from agent's archival memory
raise NotImplementedError
if not source_id:
assert source_name is not None, "source_name must be provided if source_id is not"
source = self.ms.get_source(source_name=source_name, user_id=user_id)
source_id = source.id
else:
source = self.ms.get_source(source_id=source_id)
# delete all Passage objects with source_id==source_id from agent's archival memory
agent = self._get_or_load_agent(agent_id=agent_id)
archival_memory = agent.persistence_manager.archival_memory
archival_memory.storage.delete({"source_id": source_id})
return source
def list_attached_sources(self, agent_id: str) -> List[Source]:
# list all attached sources to an agent

View File

@ -391,8 +391,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
print(sources)
# detach the source
# TODO: add when implemented
# client.detach_source(source.name, agent.id)
deleted_source = client.detach_source(source_id=source.id, agent_id=agent.id)
assert deleted_source.id == source.id
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}"
# delete the source
client.delete_source(source.id)