MemGPT/memgpt/server/rest_api/sources/index.py
2024-03-20 21:37:52 -07:00

195 lines
8.1 KiB
Python

import uuid
import tempfile
import os
import hashlib
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.data_types import Source
from memgpt.data_sources.connectors import DirectoryConnector
# Module-level router for the data-source endpoints: setup_sources_index_router
# registers its route handlers on this object and returns it.
router = APIRouter()
"""
Implement the following functions:
* List all available sources
* Create a new source
* Delete a source
* Upload a file to a server that is loaded into a specific source
* Paginated get all passages from a source
* Paginated get all documents from a source
* Attach a source to an agent
"""
# Response body of GET /sources: wraps the list of sources handed back by the server.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class ListSourcesResponse(BaseModel):
    sources: List[SourceModel] = Field(
        default=...,
        description="List of available sources.",
    )
# Request body of POST /sources. `name` is required; `description` is optional.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class CreateSourceRequest(BaseModel):
    name: str = Field(
        default=...,
        description="The name of the source.",
    )
    description: Optional[str] = Field(
        default=None,
        description="The description of the source.",
    )
# Envelope holding a single newly created source.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class CreateSourceResponse(BaseModel):
    source: SourceModel = Field(
        default=...,
        description="The newly created source.",
    )
# Model describing an upload request that carries one required file.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class UploadFileToSourceRequest(BaseModel):
    file: UploadFile = Field(
        default=...,
        description="The file to upload.",
    )
# Response body of POST /sources/{source_id}/upload: the target source plus
# the passage and document counts reported for the load.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class UploadFileToSourceResponse(BaseModel):
    source: SourceModel = Field(
        default=...,
        description="The source the file was uploaded to.",
    )
    added_passages: int = Field(
        default=...,
        description="The number of passages added to the source.",
    )
    added_documents: int = Field(
        default=...,
        description="The number of documents added to the source.",
    )
# Response body of the passages listing endpoint: wraps the list of passages.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class GetSourcePassagesResponse(BaseModel):
    passages: List[PassageModel] = Field(
        default=...,
        description="List of passages from the source.",
    )
# Response body of the documents listing endpoint: wraps the list of documents.
# (Documented with comments, not a docstring, so the generated schema is unchanged.)
class GetSourceDocumentsResponse(BaseModel):
    documents: List[DocumentModel] = Field(
        default=...,
        description="List of documents from the source.",
    )
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the data-source REST endpoints on the module-level ``router``.

    Args:
        server: Backend used to list, create, delete, attach, detach and load
            data sources.
        interface: Queuing interface that most handlers clear before doing work.
        password: Bound, together with ``server``, into the ``get_current_user``
            dependency that yields the requesting user's id.

    Returns:
        The module-level ``router`` with the ``/sources`` routes registered.
    """
    # Pre-bind server and password so only the remaining parameters of
    # get_current_user are left for FastAPI to resolve per request.
    get_current_user_with_server = partial(get_current_user, server, password)

    def _get_source_or_404(source_id: uuid.UUID, user_id: uuid.UUID):
        """Look up a source record and answer 404 when the lookup yields None."""
        source = server.ms.get_source(source_id=source_id, user_id=user_id)
        if source is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Source source_id={source_id} not found",
            )
        return source

    @router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
    async def list_sources(
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """List all data sources created by a user."""
        # Clear the interface
        interface.clear()

        try:
            sources = server.list_all_sources(user_id=user_id)
            return ListSourcesResponse(sources=sources)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}") from e

    @router.post("/sources", tags=["sources"], response_model=SourceModel)
    async def create_source(
        request: CreateSourceRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Create a new data source."""
        interface.clear()

        try:
            # TODO: don't use Source and just use SourceModel once pydantic migration is complete
            source = server.create_source(name=request.name, user_id=user_id)
            return SourceModel(
                name=source.name,
                description=None,  # TODO: actually store descriptions
                user_id=source.user_id,
                id=source.id,
                embedding_config=server.server_embedding_config,
                created_at=source.created_at.timestamp(),
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}") from e

    @router.delete("/sources/{source_id}", tags=["sources"])
    async def delete_source(
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Delete a data source."""
        interface.clear()

        try:
            server.delete_source(source_id=source_id, user_id=user_id)
            return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}") from e

    @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
    async def attach_source_to_agent(
        source_id: uuid.UUID,
        agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Attach a data source to an existing agent."""
        interface.clear()

        # Internal sanity checks (note: `assert` is stripped when running with -O).
        assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
        assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"

        source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
        # NOTE(review): create_source passes created_at.timestamp() while this
        # handler passes created_at unchanged -- confirm which form SourceModel expects.
        return SourceModel(
            name=source.name,
            description=None,  # TODO: actually store descriptions
            user_id=source.user_id,
            id=source.id,
            embedding_config=server.server_embedding_config,
            created_at=source.created_at,
        )

    @router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
    async def detach_source_from_agent(
        source_id: uuid.UUID,
        agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Detach a data source from an existing agent and return the source."""
        # Resolve the source first: the route declares response_model=SourceModel,
        # so returning nothing would fail response validation after the detach ran.
        # NOTE(review): assumes server.ms.get_source yields a record exposing the same
        # attributes as the object returned by attach_source_to_agent -- confirm.
        source = _get_source_or_404(source_id, user_id)

        server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

        return SourceModel(
            name=source.name,
            description=None,  # TODO: actually store descriptions
            user_id=source.user_id,
            id=source.id,
            embedding_config=server.server_embedding_config,
            created_at=source.created_at,
        )

    @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
    async def upload_file_to_source(
        # file: UploadFile = UploadFile(..., description="The file to upload."),
        file: UploadFile,
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Upload a file to a data source."""
        interface.clear()
        source = _get_source_or_404(source_id, user_id)

        # The filename is supplied by the client. Keep only its last path component
        # so an absolute path or "../" segments cannot point outside the temporary
        # directory, and fall back to a fixed name when nothing usable remains.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            safe_name = "upload"

        # write the file to a temporary directory (deleted after the context manager exits)
        with tempfile.TemporaryDirectory() as tmpdirname:
            file_path = os.path.join(tmpdirname, safe_name)
            with open(file_path, "wb") as buffer:
                buffer.write(file.file.read())

            # read the file
            connector = DirectoryConnector(input_files=[file_path])

            # load the data into the source via the connector; this has to happen
            # while the temporary directory still exists
            passage_count, document_count = server.load_data(user_id=user_id, source_name=source.name, connector=connector)

        # Counts are the values reported by server.load_data.
        return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)

    # The path previously ended in a stray space ("/passages "), which is not the
    # URL clients request.
    @router.get("/sources/{source_id}/passages", tags=["sources"], response_model=GetSourcePassagesResponse)
    async def list_passages(
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """List all passages associated with a data source."""
        passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
        return GetSourcePassagesResponse(passages=passages)

    @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
    async def list_documents(
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """List all documents associated with a data source."""
        documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
        return GetSourceDocumentsResponse(documents=documents)

    return router