Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

feat: move source_id to path variable (#1171)

parent 89cc4b98ed
commit d664e25cef
@@ -446,9 +446,8 @@ class RESTClient(AbstractClient):

     def load_file_into_source(self, filename: str, source_id: uuid.UUID):
         """Load {filename} and insert into source"""
-        params = {"source_id": str(source_id)}
         files = {"file": open(filename, "rb")}
-        response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers)
+        response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
         return response.json()

     def create_source(self, name: str) -> Source:
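For reference, a minimal sketch of what a caller sends after this change; the base URL, auth header, and source UUID below are placeholders for illustration, not values taken from this commit.

```python
import uuid

import requests

# Placeholder connection details (not part of this commit).
base_url = "http://localhost:8283"
headers = {"Authorization": "Bearer <token>"}
source_id = uuid.uuid4()  # in practice, the id of a source created via the API

# The source id now travels in the URL path; no "source_id" query parameter is sent.
with open("CONTRIBUTING.md", "rb") as f:
    response = requests.post(
        f"{base_url}/api/sources/{source_id}/upload",
        files={"file": f},
        headers=headers,
    )
print(response.json())
```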
@@ -466,17 +465,17 @@ class RESTClient(AbstractClient):
             embedding_model=response_json["embedding_config"]["embedding_model"],
         )

-    def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
+    def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
         """Attach a source to an agent"""
-        params = {"source_name": source_name, "agent_id": agent_id}
-        response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
+        params = {"agent_id": agent_id}
+        response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
         assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
         return response.json()

-    def detach_source(self, source_name: str, agent_id: uuid.UUID):
+    def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
         """Detach a source from an agent"""
-        params = {"source_name": source_name, "agent_id": str(agent_id)}
-        response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers)
+        params = {"agent_id": str(agent_id)}
+        response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
         assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
         return response.json()

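Attach and detach follow the same pattern; a hedged sketch with placeholder ids and connection details:

```python
import uuid

import requests

# Placeholder values, as in the upload sketch above.
base_url = "http://localhost:8283"
headers = {"Authorization": "Bearer <token>"}
source_id = uuid.uuid4()
agent_id = uuid.uuid4()

# Attach: the source id sits in the path, the agent id stays a query parameter.
resp = requests.post(f"{base_url}/api/sources/{source_id}/attach", params={"agent_id": str(agent_id)}, headers=headers)
assert resp.status_code == 200, resp.text

# Detach mirrors attach on the new /sources/{source_id}/detach route.
resp = requests.post(f"{base_url}/api/sources/{source_id}/detach", params={"agent_id": str(agent_id)}, headers=headers)
assert resp.status_code == 200, resp.text
```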
@@ -615,8 +614,8 @@ class LocalClient(AbstractClient):
     def create_source(self, name: str):
         self.server.create_source(user_id=self.user_id, name=name)

-    def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
-        self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id)
+    def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
+        self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)

     def delete_agent(self, agent_id: uuid.UUID):
         self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
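The LocalClient mirrors the REST change. A small sketch of the new call shape; the module path is assumed from this diff's context and the client is expected to be constructed elsewhere.

```python
import uuid

from memgpt.client.client import LocalClient  # module path assumed, not confirmed by this diff


def attach_source(client: LocalClient, source_id: uuid.UUID, agent_id: uuid.UUID) -> None:
    # The LocalClient now forwards a source_id (not a source_name) to the server.
    client.attach_source_to_agent(source_id=source_id, agent_id=agent_id)
```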
@@ -118,17 +118,17 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"{e}")

-    @router.post("/sources/attach", tags=["sources"], response_model=SourceModel)
+    @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
     async def attach_source_to_agent(
+        source_id: uuid.UUID,
         agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
-        source_name: str = Query(..., description="The name of the source to attach."),
         user_id: uuid.UUID = Depends(get_current_user_with_server),
     ):
         """Attach a data source to an existing agent."""
         interface.clear()
         assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
         assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
-        source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
+        source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
         return SourceModel(
             name=source.name,
             description=None,  # TODO: actually store descriptions
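A standalone FastAPI sketch (not MemGPT code) of the pattern this hunk adopts: a parameter whose name appears in the route path is parsed and validated as a path variable, while Query(...) parameters remain in the query string.

```python
import uuid

from fastapi import FastAPI, Query

app = FastAPI()


@app.post("/sources/{source_id}/attach")
async def attach(
    source_id: uuid.UUID,              # taken from the URL path and validated as a UUID
    agent_id: uuid.UUID = Query(...),  # still supplied as ?agent_id=<uuid>
):
    # FastAPI rejects non-UUID values in either position with a 422 before this body runs.
    return {"source_id": str(source_id), "agent_id": str(agent_id)}
```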
@@ -138,20 +138,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
             created_at=source.created_at,
         )

-    @router.post("/sources/detach", tags=["sources"], response_model=SourceModel)
+    @router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
     async def detach_source_from_agent(
+        source_id: uuid.UUID,
         agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
-        source_name: str = Query(..., description="The name of the source to detach."),
         user_id: uuid.UUID = Depends(get_current_user_with_server),
     ):
         """Detach a data source from an existing agent."""
-        server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
+        server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

-    @router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
+    @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
     async def upload_file_to_source(
         # file: UploadFile = UploadFile(..., description="The file to upload."),
         file: UploadFile,
-        source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
+        source_id: uuid.UUID,
         user_id: uuid.UUID = Depends(get_current_user_with_server),
     ):
         """Upload a file to a data source."""
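The upload route combines a path parameter with multipart form data. A standalone sketch (not MemGPT code) of that combination; the "file" field name matches the files={"file": ...} payload sent by the REST client above.

```python
import uuid

from fastapi import FastAPI, UploadFile

app = FastAPI()


@app.post("/sources/{source_id}/upload")
async def upload(source_id: uuid.UUID, file: UploadFile):
    # source_id comes from the URL; the file arrives as multipart/form-data under "file".
    contents = await file.read()
    return {"source_id": str(source_id), "filename": file.filename, "size": len(contents)}
```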
@@ -173,18 +173,18 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
         # TODO: actually return added passages/documents
         return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)

-    @router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
+    @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
     async def list_passages(
-        source_id: uuid.UUID = Body(...),
+        source_id: uuid.UUID,
         user_id: uuid.UUID = Depends(get_current_user_with_server),
     ):
         """List all passages associated with a data source."""
         passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
         return GetSourcePassagesResponse(passages=passages)

-    @router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
+    @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
     async def list_documents(
-        source_id: uuid.UUID = Body(...),
+        source_id: uuid.UUID,
         user_id: uuid.UUID = Depends(get_current_user_with_server),
     ):
         """List all documents associated with a data source."""
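Both listing endpoints are now plain GETs keyed by the path; note the "/passages " route literal keeps a trailing space carried over from the original source. A hedged sketch of calling them, with placeholder connection details and assuming the intended path without the stray space:

```python
import uuid

import requests

# Placeholder values for illustration; the routes come from the hunk above.
base_url = "http://localhost:8283"
headers = {"Authorization": "Bearer <token>"}
source_id = uuid.uuid4()

# No JSON body is needed any more; the source id is part of the URL.
passages = requests.get(f"{base_url}/api/sources/{source_id}/passages", headers=headers).json()
documents = requests.get(f"{base_url}/api/sources/{source_id}/documents", headers=headers).json()
```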
@@ -1293,11 +1293,17 @@ class SyncServer(LockingServer):
         passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
         return passage_count, document_count

-    def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
+    def attach_source_to_agent(
+        self,
+        user_id: uuid.UUID,
+        agent_id: uuid.UUID,
+        source_id: Optional[uuid.UUID] = None,
+        source_name: Optional[str] = None,
+    ):
         # attach a data source to an agent
-        data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
+        data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name)
         if data_source is None:
-            raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}")
+            raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}")

         # get connection to data source storage
         source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
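The server now resolves a source by id or, as a fallback, by name. A minimal sketch of that id-or-name lookup pattern using a hypothetical in-memory record store (not MemGPT's MetadataStore):

```python
import uuid
from typing import Dict, Optional


class SourceRecord:
    """Hypothetical stand-in for a stored source row."""

    def __init__(self, id: uuid.UUID, name: str, user_id: uuid.UUID):
        self.id, self.name, self.user_id = id, name, user_id


def get_source(
    records: Dict[uuid.UUID, SourceRecord],
    user_id: uuid.UUID,
    source_id: Optional[uuid.UUID] = None,
    source_name: Optional[str] = None,
) -> Optional[SourceRecord]:
    # Prefer an explicit id; otherwise fall back to a name lookup scoped to the user.
    if source_id is not None:
        record = records.get(source_id)
        return record if record is not None and record.user_id == user_id else None
    if source_name is not None:
        for record in records.values():
            if record.user_id == user_id and record.name == source_name:
                return record
    return None
```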
@@ -1310,7 +1316,13 @@ class SyncServer(LockingServer):

         return data_source

-    def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
+    def detach_source_from_agent(
+        self,
+        user_id: uuid.UUID,
+        agent_id: uuid.UUID,
+        source_id: Optional[uuid.UUID] = None,
+        source_name: Optional[str] = None,
+    ):
         # TODO: remove all passages coresponding to source from agent's archival memory
         raise NotImplementedError

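detach_source_from_agent stays unimplemented, but its widened signature implies the caller must supply one identifier or the other. A hypothetical guard (not MemGPT code) for that contract:

```python
import uuid
from typing import Optional


def require_source_reference(source_id: Optional[uuid.UUID], source_name: Optional[str]) -> None:
    # With both arguments optional, the caller still has to identify the source somehow.
    if source_id is None and source_name is None:
        raise ValueError("Either source_id or source_name must be provided")
```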
@@ -270,12 +270,12 @@ def test_sources(client, agent):
     # load a file into a source
     filename = "CONTRIBUTING.md"
     num_passages = 20
-    response = client.load_file_into_source(filename, source.id)
+    response = client.load_file_into_source(filename=filename, source_id=source.id)
     print(response)

     # attach a source
     # TODO: make sure things run in the right order
-    client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)
+    client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)

     # list archival memory
     archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
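The test now drives the whole flow by id. A compressed sketch of that order; the module path is assumed from this diff's context, and the client and agent id are expected to come from the test fixtures.

```python
import uuid

from memgpt.client.client import RESTClient  # module path assumed, not confirmed by this diff


def exercise_source_flow(client: RESTClient, agent_id: uuid.UUID) -> None:
    # Same order as the test: create a source, load a file into it by id,
    # then attach it to the agent by id.
    source = client.create_source(name="test_source")
    client.load_file_into_source(filename="CONTRIBUTING.md", source_id=source.id)
    client.attach_source_to_agent(source_id=source.id, agent_id=agent_id)
```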
@@ -157,7 +157,7 @@ def test_attach_source_to_agent(server, user_id, agent_id):
     assert len(passages_before) == 0

     # attach source
-    server.attach_source_to_agent(user_id, agent_id, "test_source")
+    server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")

     # check archival memory size
     passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)