mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: move source_id
to path variable (#1171)
This commit is contained in:
parent
89cc4b98ed
commit
d664e25cef
@ -446,9 +446,8 @@ class RESTClient(AbstractClient):
|
|||||||
|
|
||||||
def load_file_into_source(self, filename: str, source_id: uuid.UUID):
|
def load_file_into_source(self, filename: str, source_id: uuid.UUID):
|
||||||
"""Load {filename} and insert into source"""
|
"""Load {filename} and insert into source"""
|
||||||
params = {"source_id": str(source_id)}
|
|
||||||
files = {"file": open(filename, "rb")}
|
files = {"file": open(filename, "rb")}
|
||||||
response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def create_source(self, name: str) -> Source:
|
def create_source(self, name: str) -> Source:
|
||||||
@ -466,17 +465,17 @@ class RESTClient(AbstractClient):
|
|||||||
embedding_model=response_json["embedding_config"]["embedding_model"],
|
embedding_model=response_json["embedding_config"]["embedding_model"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
|
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||||
"""Attach a source to an agent"""
|
"""Attach a source to an agent"""
|
||||||
params = {"source_name": source_name, "agent_id": agent_id}
|
params = {"agent_id": agent_id}
|
||||||
response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
|
||||||
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def detach_source(self, source_name: str, agent_id: uuid.UUID):
|
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||||
"""Detach a source from an agent"""
|
"""Detach a source from an agent"""
|
||||||
params = {"source_name": source_name, "agent_id": str(agent_id)}
|
params = {"agent_id": str(agent_id)}
|
||||||
response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
|
||||||
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@ -615,8 +614,8 @@ class LocalClient(AbstractClient):
|
|||||||
def create_source(self, name: str):
|
def create_source(self, name: str):
|
||||||
self.server.create_source(user_id=self.user_id, name=name)
|
self.server.create_source(user_id=self.user_id, name=name)
|
||||||
|
|
||||||
def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
|
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||||
self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id)
|
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)
|
||||||
|
|
||||||
def delete_agent(self, agent_id: uuid.UUID):
|
def delete_agent(self, agent_id: uuid.UUID):
|
||||||
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
|
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
|
||||||
|
@ -118,17 +118,17 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"{e}")
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
|
|
||||||
@router.post("/sources/attach", tags=["sources"], response_model=SourceModel)
|
@router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
|
||||||
async def attach_source_to_agent(
|
async def attach_source_to_agent(
|
||||||
|
source_id: uuid.UUID,
|
||||||
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
|
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||||
source_name: str = Query(..., description="The name of the source to attach."),
|
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
):
|
):
|
||||||
"""Attach a data source to an existing agent."""
|
"""Attach a data source to an existing agent."""
|
||||||
interface.clear()
|
interface.clear()
|
||||||
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
|
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
|
||||||
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
|
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
|
||||||
source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
|
source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
|
||||||
return SourceModel(
|
return SourceModel(
|
||||||
name=source.name,
|
name=source.name,
|
||||||
description=None, # TODO: actually store descriptions
|
description=None, # TODO: actually store descriptions
|
||||||
@ -138,20 +138,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
created_at=source.created_at,
|
created_at=source.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
@router.post("/sources/detach", tags=["sources"], response_model=SourceModel)
|
@router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
|
||||||
async def detach_source_from_agent(
|
async def detach_source_from_agent(
|
||||||
|
source_id: uuid.UUID,
|
||||||
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
|
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
|
||||||
source_name: str = Query(..., description="The name of the source to detach."),
|
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
):
|
):
|
||||||
"""Detach a data source from an existing agent."""
|
"""Detach a data source from an existing agent."""
|
||||||
server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
|
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
|
||||||
|
|
||||||
@router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
|
@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
|
||||||
async def upload_file_to_source(
|
async def upload_file_to_source(
|
||||||
# file: UploadFile = UploadFile(..., description="The file to upload."),
|
# file: UploadFile = UploadFile(..., description="The file to upload."),
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
|
source_id: uuid.UUID,
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
):
|
):
|
||||||
"""Upload a file to a data source."""
|
"""Upload a file to a data source."""
|
||||||
@ -173,18 +173,18 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
# TODO: actually return added passages/documents
|
# TODO: actually return added passages/documents
|
||||||
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)
|
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)
|
||||||
|
|
||||||
@router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
|
@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
|
||||||
async def list_passages(
|
async def list_passages(
|
||||||
source_id: uuid.UUID = Body(...),
|
source_id: uuid.UUID,
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
):
|
):
|
||||||
"""List all passages associated with a data source."""
|
"""List all passages associated with a data source."""
|
||||||
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
|
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
|
||||||
return GetSourcePassagesResponse(passages=passages)
|
return GetSourcePassagesResponse(passages=passages)
|
||||||
|
|
||||||
@router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
|
@router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
|
||||||
async def list_documents(
|
async def list_documents(
|
||||||
source_id: uuid.UUID = Body(...),
|
source_id: uuid.UUID,
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
):
|
):
|
||||||
"""List all documents associated with a data source."""
|
"""List all documents associated with a data source."""
|
||||||
|
@ -1293,11 +1293,17 @@ class SyncServer(LockingServer):
|
|||||||
passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
|
passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
|
||||||
return passage_count, document_count
|
return passage_count, document_count
|
||||||
|
|
||||||
def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
|
def attach_source_to_agent(
|
||||||
|
self,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
agent_id: uuid.UUID,
|
||||||
|
source_id: Optional[uuid.UUID] = None,
|
||||||
|
source_name: Optional[str] = None,
|
||||||
|
):
|
||||||
# attach a data source to an agent
|
# attach a data source to an agent
|
||||||
data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
|
data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name)
|
||||||
if data_source is None:
|
if data_source is None:
|
||||||
raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}")
|
raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}")
|
||||||
|
|
||||||
# get connection to data source storage
|
# get connection to data source storage
|
||||||
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||||
@ -1310,7 +1316,13 @@ class SyncServer(LockingServer):
|
|||||||
|
|
||||||
return data_source
|
return data_source
|
||||||
|
|
||||||
def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
|
def detach_source_from_agent(
|
||||||
|
self,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
agent_id: uuid.UUID,
|
||||||
|
source_id: Optional[uuid.UUID] = None,
|
||||||
|
source_name: Optional[str] = None,
|
||||||
|
):
|
||||||
# TODO: remove all passages coresponding to source from agent's archival memory
|
# TODO: remove all passages coresponding to source from agent's archival memory
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -270,12 +270,12 @@ def test_sources(client, agent):
|
|||||||
# load a file into a source
|
# load a file into a source
|
||||||
filename = "CONTRIBUTING.md"
|
filename = "CONTRIBUTING.md"
|
||||||
num_passages = 20
|
num_passages = 20
|
||||||
response = client.load_file_into_source(filename, source.id)
|
response = client.load_file_into_source(filename=filename, source_id=source.id)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
# attach a source
|
# attach a source
|
||||||
# TODO: make sure things run in the right order
|
# TODO: make sure things run in the right order
|
||||||
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)
|
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
|
||||||
|
|
||||||
# list archival memory
|
# list archival memory
|
||||||
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
||||||
|
@ -157,7 +157,7 @@ def test_attach_source_to_agent(server, user_id, agent_id):
|
|||||||
assert len(passages_before) == 0
|
assert len(passages_before) == 0
|
||||||
|
|
||||||
# attach source
|
# attach source
|
||||||
server.attach_source_to_agent(user_id, agent_id, "test_source")
|
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||||
|
|
||||||
# check archival memory size
|
# check archival memory size
|
||||||
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
|
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
|
||||||
|
Loading…
Reference in New Issue
Block a user