refactor: Add OpenAI assistants API endpoints to memgpt server (#1006)

Sarah Wooders 2024-02-14 15:51:56 -08:00 committed by GitHub
parent 326923b578
commit 9c39e8cefd
3 changed files with 328 additions and 339 deletions
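
This refactor replaces the standalone FastAPI app that the assistants module used to create at import time with a setup_openai_assistant_router(server, interface) factory that returns an APIRouter. The main REST server then mounts that router under the /v1 prefix, so the OpenAI-compatible assistants, threads, and runs routes are served by the same process as the existing /api endpoints. A minimal sketch of the router-factory pattern, with simplified names (the real factory lives in memgpt.server.rest_api.openai_assistants.assistants and also receives a QueuingInterface):

# Minimal sketch of the router-factory pattern adopted in this commit;
# names other than APIRouter/FastAPI are simplified placeholders.
from fastapi import APIRouter, FastAPI

def setup_example_router(server) -> APIRouter:
    router = APIRouter()

    @router.get("/assistants", tags=["assistants"])
    def list_assistants():
        # Endpoint bodies close over the server passed into the factory
        # instead of relying on module-level globals.
        return []

    return router

app = FastAPI()
# Mounting with a prefix means the route above is served at GET /v1/assistants.
app.include_router(setup_example_router(server=None), prefix="/v1")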


@@ -45,15 +45,7 @@ Start the server with:
poetry run uvicorn assistants:app --reload --port 8080
"""
interface: QueuingInterface = QueuingInterface()
server: SyncServer = SyncServer(default_interface=interface)
# router = APIRouter()
app = FastAPI()
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
router = APIRouter()
class CreateAssistantRequest(BaseModel):
@@ -167,16 +159,14 @@ class SubmitToolOutputsToRunRequest(BaseModel):
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
# TODO: remove this (when we have user auth)
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
@app.get("/v1/health", tags=["assistant"])
def get_health():
return {"status": "healthy"}
# create assistant (MemGPT agent)
@app.post("/v1/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
# create assistant (MemGPT agent)
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
# TODO: create preset
return OpenAIAssistant(
id=DEFAULT_PRESET,
@@ -190,12 +180,11 @@ def create_assistant(request: CreateAssistantRequest = Body(...)):
metadata=request.metadata,
)
@app.post("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
def create_assistant_file(
@router.post("/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
def create_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantFileRequest = Body(...),
):
):
# TODO: add file to assistant
return AssistantFile(
id=request.file_id,
@@ -203,75 +192,75 @@ def create_assistant_file(
assistant_id=assistant_id,
)
@app.get("/v1/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
def list_assistants(
@router.get("/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
def list_assistants(
limit: int = Query(1000, description="How many assistants to retrieve."),
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="`before` is an object ID that defines your place in the list."
),
):
# TODO: implement list assistants (i.e. list available MemGPT presets)
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
def list_assistant_files(
@router.get("/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
def list_assistant_files(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
limit: int = Query(1000, description="How many files to retrieve."),
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="`before` is an object ID that defines your place in the list."
),
):
# TODO: list attached data sources to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
@router.get("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
):
# TODO: get and return preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
@router.get("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
):
# TODO: return data source attached to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def modify_assistant(
@router.post("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def modify_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantRequest = Body(...),
):
):
# TODO: modify preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
def delete_assistant(
@router.delete("/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
def delete_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
):
# TODO: delete preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
@router.delete("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
):
# TODO: delete source on preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads", tags=["assistants"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
@router.post("/threads", tags=["threads"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
@@ -290,40 +279,36 @@ def create_thread(request: CreateThreadRequest = Body(...)):
created_at=int(agent_state.created_at.timestamp()),
)
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def retrieve_thread(
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
):
agent = server.get_agent(uuid.UUID(thread_id))
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
)
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def modify_thread(
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/threads/{thread_id}", tags=["assistants"], response_model=DeleteThreadResponse)
def delete_thread(
@router.delete("/threads/{thread_id}", tags=["threads"], response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=OpenAIMessage)
def create_message(
@router.post("/threads/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
):
):
agent_id = uuid.UUID(thread_id)
# create message object
message = Message(
@@ -348,15 +333,18 @@ def create_message(
)
return openai_message
@app.get("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=ListMessagesResponse)
def list_messages(
@router.get("/threads/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="`before` is an object ID that defines your place in the list."
),
):
after_uuid = uuid.UUID(after) if after else None
before_uuid = uuid.UUID(before) if before else None
agent_id = uuid.UUID(thread_id)
@@ -389,14 +377,12 @@ def list_messages(
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
@app.get("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def retrieve_message(
@router.get("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
):
):
message_id = uuid.UUID(message_id)
agent_id = uuid.UUID(thread_id)
message = server.get_agent_message(agent_id, message_id)
@@ -411,32 +397,29 @@ def retrieve_message(
# metadata=message.metadata,
)
@app.get("/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["assistants"], response_model=MessageFile)
def retrieve_message_file(
@router.get("/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
):
# TODO: implement?
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def modify_message(
@router.post("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
request: ModifyMessageRequest = Body(...),
):
):
# TODO: add metadata field to message so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=OpenAIRun)
def create_run(
@router.post("/threads/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateRunRequest = Body(...),
):
):
# TODO: add request.instructions as a message?
agent_id = uuid.UUID(thread_id)
# TODO: override preset of agent with request.assistant_id
@@ -455,78 +438,80 @@ def create_run(
instructions=request.instructions,
)
@app.post("/v1/threads/runs", tags=["assistants"], response_model=OpenAIRun)
def create_thread_and_run(
@router.post("/threads/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
request: CreateThreadRunRequest = Body(...),
):
):
# TODO: add a bunch of messages and execute
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=List[OpenAIRun])
def list_runs(
@router.get("/threads/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many runs to retrieve."),
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="`before` is an object ID that defines your place in the list."
),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps", tags=["assistants"], response_model=List[OpenAIRunStep])
def list_run_steps(
@router.get("/threads/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
limit: int = Query(1000, description="How many run steps to retrieve."),
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="`before` is an object ID that defines your place in the list."
),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def retrieve_run(
@router.get("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["assistants"], response_model=OpenAIRunStep)
def retrieve_run_step(
@router.get("/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
step_id: str = Path(..., description="The unique identifier of the run step."),
):
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def modify_run(
@router.post("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: ModifyRunRequest = Body(...),
):
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["assistants"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
@router.post("/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: SubmitToolOutputsToRunRequest = Body(...),
):
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs/{run_id}/cancel", tags=["assistants"], response_model=OpenAIRun)
def cancel_run(
@router.post("/threads/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
return router
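
With the router returned here and mounted by the main server (next file), the assistants and threads routes are reachable under the /v1 prefix. A quick smoke test against a locally running server, as a sketch only (the host, port, and empty request body are assumptions, not taken from this diff):

# Sketch of calling the mounted routes; host/port and payload are assumed.
import requests

BASE = "http://localhost:8283/v1"  # assumed local MemGPT REST server

# POST /v1/threads maps to create_thread above (backed by a MemGPT agent).
thread = requests.post(f"{BASE}/threads", json={}).json()

# GET /v1/threads/{thread_id} maps to retrieve_thread.
print(requests.get(f"{BASE}/threads/{thread['id']}").json())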


@@ -16,6 +16,7 @@ from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.humans.index import setup_humans_index_router
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.rest_api.static_files import mount_static_files
@@ -33,6 +34,7 @@ server: SyncServer = SyncServer(default_interface=interface)
API_PREFIX = "/api"
OPENAI_API_PREFIX = "/v1"
CORS_ORIGINS = [
"http://localhost:4200",
@@ -65,6 +67,8 @@ app.include_router(setup_personas_index_router(server, interface), prefix=API_PR
app.include_router(setup_models_index_router(server, interface), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface), prefix=API_PREFIX)
# /v1/assistants endpoints
app.include_router(setup_openai_assistant_router(server, interface), prefix=OPENAI_API_PREFIX)
# / static files
mount_static_files(app)
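
Including the router here means the old standalone entry point from the assistants module docstring (uvicorn assistants:app) is no longer how these routes are served; they come up with the main REST server alongside the /api endpoints and static files. A sketch of launching the unified app directly (the port is an assumption; the module path matches the import the test file below switches to):

# Sketch: serve the unified FastAPI app, exposing both /api and /v1 routes.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("memgpt.server.rest_api.server:app", port=8283, reload=True)  # port assumed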


@@ -3,7 +3,7 @@ from fastapi.testclient import TestClient
import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.openai_assistants.assistants import app
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET
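
Because the test module now imports the unified app instead of the standalone assistants app, endpoint tests exercise the same FastAPI instance that serves the rest of the REST API. A hedged sketch of such a test (the payload fields are assumptions; the actual test body is outside this diff, and the only grounded expectation is that the stub create_assistant returns an assistant whose id is DEFAULT_PRESET):

# Sketch only: request fields are assumptions; the real test body is not shown in this diff.
from fastapi.testclient import TestClient

from memgpt.constants import DEFAULT_PRESET
from memgpt.server.rest_api.server import app

client = TestClient(app)

def test_create_assistant():
    # create_assistant is currently a stub that echoes DEFAULT_PRESET as the assistant id
    response = client.post("/v1/assistants", json={"model": DEFAULT_PRESET, "name": "test-assistant"})
    assert response.status_code == 200
    assert response.json()["id"] == DEFAULT_PRESET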