refactor: Add OpenAI assistants API endpoints to memgpt server (#1006)

This commit is contained in:
Sarah Wooders 2024-02-14 15:51:56 -08:00 committed by GitHub
parent 326923b578
commit 9c39e8cefd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 328 additions and 339 deletions

View File

@ -45,15 +45,7 @@ Start the server with:
poetry run uvicorn assistants:app --reload --port 8080
"""
interface: QueuingInterface = QueuingInterface()
server: SyncServer = SyncServer(default_interface=interface)
# router = APIRouter()
app = FastAPI()
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
router = APIRouter()
class CreateAssistantRequest(BaseModel):
@ -167,366 +159,359 @@ class SubmitToolOutputsToRunRequest(BaseModel):
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
# TODO: remove this (when we have user auth)
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
# create assistant (MemGPT agent)
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
# TODO: create preset
return OpenAIAssistant(
id=DEFAULT_PRESET,
name="default_preset",
description=request.description,
created_at=int(datetime.now().timestamp()),
model=request.model,
instructions=request.instructions,
tools=request.tools,
file_ids=request.file_ids,
metadata=request.metadata,
)
@app.get("/v1/health", tags=["assistant"])
def get_health():
return {"status": "healthy"}
@router.post("/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
def create_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantFileRequest = Body(...),
):
# TODO: add file to assistant
return AssistantFile(
id=request.file_id,
created_at=int(datetime.now().timestamp()),
assistant_id=assistant_id,
)
@router.get("/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
def list_assistants(
limit: int = Query(1000, description="How many assistants to retrieve."),
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: implement list assistants (i.e. list available MemGPT presets)
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
# create assistant (MemGPT agent)
@app.post("/v1/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
# TODO: create preset
return OpenAIAssistant(
id=DEFAULT_PRESET,
name="default_preset",
description=request.description,
created_at=int(datetime.now().timestamp()),
model=request.model,
instructions=request.instructions,
tools=request.tools,
file_ids=request.file_ids,
metadata=request.metadata,
)
@router.get("/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
def list_assistant_files(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
limit: int = Query(1000, description="How many files to retrieve."),
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: list attached data sources to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: get and return preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
def create_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantFileRequest = Body(...),
):
# TODO: add file to assistant
return AssistantFile(
id=request.file_id,
created_at=int(datetime.now().timestamp()),
assistant_id=assistant_id,
)
@router.get("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: return data source attached to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def modify_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantRequest = Body(...),
):
# TODO: modify preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
def list_assistants(
limit: int = Query(1000, description="How many assistants to retrieve."),
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: implement list assistants (i.e. list available MemGPT presets)
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
def delete_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: delete preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: delete source on preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
def list_assistant_files(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
limit: int = Query(1000, description="How many files to retrieve."),
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: list attached data sources to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads", tags=["threads"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
print("Create thread/agent", request)
# create a memgpt agent
agent_state = server.create_agent(
user_id=user_id,
agent_config={
"user_id": user_id,
},
)
# TODO: insert messages into recall memory
return OpenAIThread(
id=str(agent_state.id),
created_at=int(agent_state.created_at.timestamp()),
)
@app.get("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: get and return preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
agent = server.get_agent(uuid.UUID(thread_id))
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
)
@router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: return data source attached to preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/threads/{thread_id}", tags=["threads"], response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
):
agent_id = uuid.UUID(thread_id)
# create message object
message = Message(
user_id=user_id,
agent_id=agent_id,
role=request.role,
text=request.content,
)
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
# add message to agent
agent._append_to_messages([message])
@app.post("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def modify_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
request: CreateAssistantRequest = Body(...),
):
# TODO: modify preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
def delete_assistant(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
# TODO: delete preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
def delete_assistant_file(
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: delete source on preset
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads", tags=["assistants"], response_model=OpenAIThread)
def create_thread(request: CreateThreadRequest = Body(...)):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
print("Create thread/agent", request)
# create a memgpt agent
agent_state = server.create_agent(
user_id=user_id,
agent_config={
"user_id": user_id,
},
)
# TODO: insert messages into recall memory
return OpenAIThread(
id=str(agent_state.id),
created_at=int(agent_state.created_at.timestamp()),
)
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
agent = server.get_agent(uuid.UUID(thread_id))
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
)
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.delete("/v1/threads/{thread_id}", tags=["assistants"], response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
):
agent_id = uuid.UUID(thread_id)
# create message object
message = Message(
user_id=user_id,
agent_id=agent_id,
role=request.role,
text=request.content,
)
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
# add message to agent
agent._append_to_messages([message])
openai_message = OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=message.text)],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
return openai_message
@app.get("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
after_uuid = uuid.UUID(after) if before else None
before_uuid = uuid.UUID(before) if before else None
agent_id = uuid.UUID(thread_id)
reverse = True if (order == "desc") else False
cursor, json_messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=limit,
after=after_uuid,
before=before_uuid,
order_by="created_at",
reverse=reverse,
)
print(json_messages[0]["text"])
# convert to openai style messages
openai_messages = [
OpenAIMessage(
id=str(message["id"]),
created_at=int(message["created_at"].timestamp()),
content=[Text(text=message["text"])],
role=message["role"],
thread_id=str(message["agent_id"]),
assistant_id=DEFAULT_PRESET # TODO: update this
openai_message = OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=message.text)],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
for message in json_messages
]
print("MESSAGES", openai_messages)
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
return openai_message
@router.get("/threads/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
after_uuid = uuid.UUID(after) if before else None
before_uuid = uuid.UUID(before) if before else None
agent_id = uuid.UUID(thread_id)
reverse = True if (order == "desc") else False
cursor, json_messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=limit,
after=after_uuid,
before=before_uuid,
order_by="created_at",
reverse=reverse,
)
print(json_messages[0]["text"])
# convert to openai style messages
openai_messages = [
OpenAIMessage(
id=str(message["id"]),
created_at=int(message["created_at"].timestamp()),
content=[Text(text=message["text"])],
role=message["role"],
thread_id=str(message["agent_id"]),
assistant_id=DEFAULT_PRESET # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
for message in json_messages
]
print("MESSAGES", openai_messages)
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
app.get("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
router.get("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
):
message_id = uuid.UUID(message_id)
agent_id = uuid.UUID(thread_id)
message = server.get_agent_message(agent_id, message_id)
return OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=message.text)],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
def retrieve_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
):
message_id = uuid.UUID(message_id)
agent_id = uuid.UUID(thread_id)
message = server.get_agent_message(agent_id, message_id)
return OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=message.text)],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# file_ids=message.file_ids,
# metadata=message.metadata,
)
@router.get("/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: implement?
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
request: ModifyMessageRequest = Body(...),
):
# TODO: add metada field to message so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["assistants"], response_model=MessageFile)
def retrieve_message_file(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: implement?
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateRunRequest = Body(...),
):
# TODO: add request.instructions as a message?
agent_id = uuid.UUID(thread_id)
# TODO: override preset of agent with request.assistant_id
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
agent.step(user_message=None) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(datetime.now().timestamp())
return OpenAIRun(
id=run_id,
created_at=create_time,
thread_id=str(agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
status="completed", # TODO: eventaully allow offline execution
expires_at=create_time,
model=agent.agent_state.llm_config.model,
instructions=request.instructions,
)
@router.post("/threads/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
request: CreateThreadRunRequest = Body(...),
):
# TODO: add a bunch of messages and execute
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
def modify_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
request: ModifyMessageRequest = Body(...),
):
# TODO: add metada field to message so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many runs to retrieve."),
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
limit: int = Query(1000, description="How many run steps to retrieve."),
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
after: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
before: str = Query(
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=OpenAIRun)
def create_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateRunRequest = Body(...),
):
# TODO: add request.instructions as a message?
agent_id = uuid.UUID(thread_id)
# TODO: override preset of agent with request.assistant_id
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
agent.step(user_message=None) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(datetime.now().timestamp())
return OpenAIRun(
id=run_id,
created_at=create_time,
thread_id=str(agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
status="completed", # TODO: eventaully allow offline execution
expires_at=create_time,
model=agent.agent_state.llm_config.model,
instructions=request.instructions,
)
@router.get("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
step_id: str = Path(..., description="The unique identifier of the run step."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/runs", tags=["assistants"], response_model=OpenAIRun)
def create_thread_and_run(
request: CreateThreadRunRequest = Body(...),
):
# TODO: add a bunch of messages and execute
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: ModifyRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: SubmitToolOutputsToRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=List[OpenAIRun])
def list_runs(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many runs to retrieve."),
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/threads/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps", tags=["assistants"], response_model=List[OpenAIRunStep])
def list_run_steps(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
limit: int = Query(1000, description="How many run steps to retrieve."),
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def retrieve_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["assistants"], response_model=OpenAIRunStep)
def retrieve_run_step(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
step_id: str = Path(..., description="The unique identifier of the run step."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
def modify_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: ModifyRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["assistants"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: SubmitToolOutputsToRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@app.post("/v1/threads/{thread_id}/runs/{run_id}/cancel", tags=["assistants"], response_model=OpenAIRun)
def cancel_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
return router

View File

@ -16,6 +16,7 @@ from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.humans.index import setup_humans_index_router
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.rest_api.static_files import mount_static_files
@ -33,6 +34,7 @@ server: SyncServer = SyncServer(default_interface=interface)
API_PREFIX = "/api"
OPENAI_API_PREFIX = "/v1"
CORS_ORIGINS = [
"http://localhost:4200",
@ -65,6 +67,8 @@ app.include_router(setup_personas_index_router(server, interface), prefix=API_PR
app.include_router(setup_models_index_router(server, interface), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface), prefix=API_PREFIX)
# /v1/assistants endpoints
app.include_router(setup_openai_assistant_router(server, interface), prefix=OPENAI_API_PREFIX)
# / static files
mount_static_files(app)

View File

@ -3,7 +3,7 @@ from fastapi.testclient import TestClient
import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.openai_assistants.assistants import app
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET