mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

refactor: Add OpenAI assistants API endpoints to memgpt server (#1006)

This commit is contained in:
parent 326923b578
commit 9c39e8cefd
@@ -45,15 +45,7 @@ Start the server with:
poetry run uvicorn assistants:app --reload --port 8080
"""

interface: QueuingInterface = QueuingInterface()
server: SyncServer = SyncServer(default_interface=interface)


# router = APIRouter()
app = FastAPI()

user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
router = APIRouter()


class CreateAssistantRequest(BaseModel):
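For orientation (not part of the diff): the router built by setup_openai_assistant_router below is mounted on the main FastAPI app under the "/v1" prefix, as the memgpt/server/rest_api/server.py hunk at the end of this commit shows. A minimal sketch of that wiring, using only imports and calls that appear in this commit:

# Sketch only: mirrors the include_router call added to server.py in this commit.
from fastapi import FastAPI

from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.server import SyncServer

interface = QueuingInterface()
server = SyncServer(default_interface=interface)

app = FastAPI()
# "/v1" matches OPENAI_API_PREFIX in server.py, so routes land at /v1/assistants, /v1/threads, ...
app.include_router(setup_openai_assistant_router(server, interface), prefix="/v1")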
@@ -167,366 +159,359 @@ class SubmitToolOutputsToRunRequest(BaseModel):


# TODO: implement mechanism for creating/authenticating users associated with a bearer token
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
    # TODO: remove this (when we have user auth)
    user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
    print(f"User ID: {user_id}")

    # create assistant (MemGPT agent)
    @router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
    def create_assistant(request: CreateAssistantRequest = Body(...)):
        # TODO: create preset
        return OpenAIAssistant(
            id=DEFAULT_PRESET,
            name="default_preset",
            description=request.description,
            created_at=int(datetime.now().timestamp()),
            model=request.model,
            instructions=request.instructions,
            tools=request.tools,
            file_ids=request.file_ids,
            metadata=request.metadata,
        )

@app.get("/v1/health", tags=["assistant"])
def get_health():
    return {"status": "healthy"}

    @router.post("/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
    def create_assistant_file(
        assistant_id: str = Path(..., description="The unique identifier of the assistant."),
        request: CreateAssistantFileRequest = Body(...),
    ):
        # TODO: add file to assistant
        return AssistantFile(
            id=request.file_id,
            created_at=int(datetime.now().timestamp()),
            assistant_id=assistant_id,
        )

    @router.get("/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
    def list_assistants(
        limit: int = Query(1000, description="How many assistants to retrieve."),
        order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
        after: str = Query(
            None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
        ),
        before: str = Query(
            None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
        ),
    ):
        # TODO: implement list assistants (i.e. list available MemGPT presets)
        raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

# create assistant (MemGPT agent)
@app.post("/v1/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):
    # TODO: create preset
    return OpenAIAssistant(
        id=DEFAULT_PRESET,
        name="default_preset",
        description=request.description,
        created_at=int(datetime.now().timestamp()),
        model=request.model,
        instructions=request.instructions,
        tools=request.tools,
        file_ids=request.file_ids,
        metadata=request.metadata,
    )
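For illustration (not part of the commit): once the server is running, the create-assistant route above can be exercised with any HTTP client. A hypothetical call using the requests library, assuming the standalone app from the module docstring is listening on port 8080; the JSON field names follow the handler above, but the required fields of CreateAssistantRequest may differ:

import requests

# Hypothetical client call against POST /v1/assistants (values are illustrative).
resp = requests.post(
    "http://localhost:8080/v1/assistants",
    json={
        "model": "gpt-4",
        "name": "my_assistant",
        "instructions": "You are a helpful assistant with long-term memory.",
    },
)
print(resp.json())  # an OpenAIAssistant-shaped payload with id=DEFAULT_PRESET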
@router.get("/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
|
||||
def list_assistant_files(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
limit: int = Query(1000, description="How many files to retrieve."),
|
||||
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: list attached data sources to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def retrieve_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: get and return preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.post("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=AssistantFile)
|
||||
def create_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantFileRequest = Body(...),
|
||||
):
|
||||
# TODO: add file to assistant
|
||||
return AssistantFile(
|
||||
id=request.file_id,
|
||||
created_at=int(datetime.now().timestamp()),
|
||||
assistant_id=assistant_id,
|
||||
)
|
||||
@router.get("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
|
||||
def retrieve_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: return data source attached to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def modify_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantRequest = Body(...),
|
||||
):
|
||||
# TODO: modify preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.get("/v1/assistants", tags=["assistants"], response_model=List[OpenAIAssistant])
|
||||
def list_assistants(
|
||||
limit: int = Query(1000, description="How many assistants to retrieve."),
|
||||
order: str = Query("asc", description="Order of assistants to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: implement list assistants (i.e. list available MemGPT presets)
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@router.delete("/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
|
||||
def delete_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: delete preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.delete("/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
|
||||
def delete_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: delete source on preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.get("/v1/assistants/{assistant_id}/files", tags=["assistants"], response_model=List[AssistantFile])
|
||||
def list_assistant_files(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
limit: int = Query(1000, description="How many files to retrieve."),
|
||||
order: str = Query("asc", description="Order of files to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: list attached data sources to preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@router.post("/threads", tags=["threads"], response_model=OpenAIThread)
|
||||
def create_thread(request: CreateThreadRequest = Body(...)):
|
||||
# TODO: use requests.description and requests.metadata fields
|
||||
# TODO: handle requests.file_ids and requests.tools
|
||||
# TODO: eventually allow request to override embedding/llm model
|
||||
|
||||
print("Create thread/agent", request)
|
||||
# create a memgpt agent
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
agent_config={
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
# TODO: insert messages into recall memory
|
||||
return OpenAIThread(
|
||||
id=str(agent_state.id),
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
)
|
||||
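For illustration (not part of the commit): creating a thread through the route above spins up a MemGPT agent and returns its UUID as the thread id. A hypothetical call, assuming the same localhost:8080 base URL as before; CreateThreadRequest fields beyond an empty body are ignored by this handler per its TODOs:

import requests

# Hypothetical: create a thread (backed by a new MemGPT agent) and keep its id for later calls.
resp = requests.post("http://localhost:8080/v1/threads", json={})
thread_id = resp.json()["id"]  # the agent's UUID, returned as the OpenAIThread id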
@app.get("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
def retrieve_assistant(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
):
    # TODO: get and return preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

    @router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
    def retrieve_thread(
        thread_id: str = Path(..., description="The unique identifier of the thread."),
    ):
        agent = server.get_agent(uuid.UUID(thread_id))
        return OpenAIThread(
            id=str(agent.id),
            created_at=int(agent.created_at.timestamp()),
        )

    @router.get("/threads/{thread_id}", tags=["threads"], response_model=OpenAIThread)
    def modify_thread(
        thread_id: str = Path(..., description="The unique identifier of the thread."),
        request: ModifyThreadRequest = Body(...),
    ):
        # TODO: add agent metadata so this can be modified
        raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

@app.get("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=AssistantFile)
def retrieve_assistant_file(
    assistant_id: str = Path(..., description="The unique identifier of the assistant."),
    file_id: str = Path(..., description="The unique identifier of the file."),
):
    # TODO: return data source attached to preset
    raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

    @router.delete("/threads/{thread_id}", tags=["threads"], response_model=DeleteThreadResponse)
    def delete_thread(
        thread_id: str = Path(..., description="The unique identifier of the thread."),
    ):
        # TODO: delete agent
        raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

    @router.post("/threads/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
    def create_message(
        thread_id: str = Path(..., description="The unique identifier of the thread."),
        request: CreateMessageRequest = Body(...),
    ):
        agent_id = uuid.UUID(thread_id)
        # create message object
        message = Message(
            user_id=user_id,
            agent_id=agent_id,
            role=request.role,
            text=request.content,
        )
        agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        # add message to agent
        agent._append_to_messages([message])
@app.post("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def modify_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
request: CreateAssistantRequest = Body(...),
|
||||
):
|
||||
# TODO: modify preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.delete("/v1/assistants/{assistant_id}", tags=["assistants"], response_model=DeleteAssistantResponse)
|
||||
def delete_assistant(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
):
|
||||
# TODO: delete preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.delete("/v1/assistants/{assistant_id}/files/{file_id}", tags=["assistants"], response_model=DeleteAssistantFileResponse)
|
||||
def delete_assistant_file(
|
||||
assistant_id: str = Path(..., description="The unique identifier of the assistant."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: delete source on preset
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads", tags=["assistants"], response_model=OpenAIThread)
|
||||
def create_thread(request: CreateThreadRequest = Body(...)):
|
||||
# TODO: use requests.description and requests.metadata fields
|
||||
# TODO: handle requests.file_ids and requests.tools
|
||||
# TODO: eventually allow request to override embedding/llm model
|
||||
|
||||
print("Create thread/agent", request)
|
||||
# create a memgpt agent
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
agent_config={
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
# TODO: insert messages into recall memory
|
||||
return OpenAIThread(
|
||||
id=str(agent_state.id),
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
|
||||
def retrieve_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
):
|
||||
agent = server.get_agent(uuid.UUID(thread_id))
|
||||
return OpenAIThread(
|
||||
id=str(agent.id),
|
||||
created_at=int(agent.created_at.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}", tags=["assistants"], response_model=OpenAIThread)
|
||||
def modify_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: ModifyThreadRequest = Body(...),
|
||||
):
|
||||
# TODO: add agent metadata so this can be modified
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.delete("/v1/threads/{thread_id}", tags=["assistants"], response_model=DeleteThreadResponse)
|
||||
def delete_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
):
|
||||
# TODO: delete agent
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=OpenAIMessage)
|
||||
def create_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateMessageRequest = Body(...),
|
||||
):
|
||||
agent_id = uuid.UUID(thread_id)
|
||||
# create message object
|
||||
message = Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role=request.role,
|
||||
text=request.content,
|
||||
)
|
||||
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
# add message to agent
|
||||
agent._append_to_messages([message])
|
||||
|
||||
openai_message = OpenAIMessage(
|
||||
id=str(message.id),
|
||||
created_at=int(message.created_at.timestamp()),
|
||||
content=[Text(text=message.text)],
|
||||
role=message.role,
|
||||
thread_id=str(message.agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
# file_ids=message.file_ids,
|
||||
# metadata=message.metadata,
|
||||
)
|
||||
return openai_message
|
||||
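For illustration (not part of the commit): posting a message to the route above wraps the payload in a MemGPT Message, appends it to the agent's in-context message list, and echoes it back in OpenAI message format. A hypothetical call, reusing the thread_id from the earlier sketch and the assumed localhost:8080 base URL:

import requests

# Hypothetical: append a user message to the thread; role and content mirror the fields
# the handler reads from CreateMessageRequest.
resp = requests.post(
    f"http://localhost:8080/v1/threads/{thread_id}/messages",
    json={"role": "user", "content": "Remember that my favorite color is blue."},
)
print(resp.json()["id"])  # id of the stored message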
@app.get("/v1/threads/{thread_id}/messages", tags=["assistants"], response_model=ListMessagesResponse)
def list_messages(
    thread_id: str = Path(..., description="The unique identifier of the thread."),
    limit: int = Query(1000, description="How many messages to retrieve."),
    order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
    after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
    before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
    after_uuid = uuid.UUID(after) if after else None
    before_uuid = uuid.UUID(before) if before else None
    agent_id = uuid.UUID(thread_id)
    reverse = True if (order == "desc") else False
    cursor, json_messages = server.get_agent_recall_cursor(
        user_id=user_id,
        agent_id=agent_id,
        limit=limit,
        after=after_uuid,
        before=before_uuid,
        order_by="created_at",
        reverse=reverse,
    )
    print(json_messages[0]["text"])
    # convert to openai style messages
    openai_messages = [
        OpenAIMessage(
            id=str(message["id"]),
            created_at=int(message["created_at"].timestamp()),
            content=[Text(text=message["text"])],
            role=message["role"],
            thread_id=str(message["agent_id"]),
            assistant_id=DEFAULT_PRESET  # TODO: update this
        openai_message = OpenAIMessage(
            id=str(message.id),
            created_at=int(message.created_at.timestamp()),
            content=[Text(text=message.text)],
            role=message.role,
            thread_id=str(message.agent_id),
            assistant_id=DEFAULT_PRESET,  # TODO: update this
            # file_ids=message.file_ids,
            # metadata=message.metadata,
        )
        for message in json_messages
    ]
    print("MESSAGES", openai_messages)
    # TODO: cast back to message objects
    return ListMessagesResponse(messages=openai_messages)
        return openai_message

    @router.get("/threads/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
    def list_messages(
        thread_id: str = Path(..., description="The unique identifier of the thread."),
        limit: int = Query(1000, description="How many messages to retrieve."),
        order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
        after: str = Query(
            None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
        ),
        before: str = Query(
            None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
        ),
    ):
        after_uuid = uuid.UUID(after) if after else None
        before_uuid = uuid.UUID(before) if before else None
        agent_id = uuid.UUID(thread_id)
        reverse = True if (order == "desc") else False
        cursor, json_messages = server.get_agent_recall_cursor(
            user_id=user_id,
            agent_id=agent_id,
            limit=limit,
            after=after_uuid,
            before=before_uuid,
            order_by="created_at",
            reverse=reverse,
        )
        print(json_messages[0]["text"])
        # convert to openai style messages
        openai_messages = [
            OpenAIMessage(
                id=str(message["id"]),
                created_at=int(message["created_at"].timestamp()),
                content=[Text(text=message["text"])],
                role=message["role"],
                thread_id=str(message["agent_id"]),
                assistant_id=DEFAULT_PRESET  # TODO: update this
                # file_ids=message.file_ids,
                # metadata=message.metadata,
            )
            for message in json_messages
        ]
        print("MESSAGES", openai_messages)
        # TODO: cast back to message objects
        return ListMessagesResponse(messages=openai_messages)
app.get("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
|
||||
router.get("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
|
||||
|
||||
def retrieve_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
):
|
||||
message_id = uuid.UUID(message_id)
|
||||
agent_id = uuid.UUID(thread_id)
|
||||
message = server.get_agent_message(agent_id, message_id)
|
||||
return OpenAIMessage(
|
||||
id=str(message.id),
|
||||
created_at=int(message.created_at.timestamp()),
|
||||
content=[Text(text=message.text)],
|
||||
role=message.role,
|
||||
thread_id=str(message.agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
# file_ids=message.file_ids,
|
||||
# metadata=message.metadata,
|
||||
)
|
||||
|
||||
def retrieve_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
):
|
||||
message_id = uuid.UUID(message_id)
|
||||
agent_id = uuid.UUID(thread_id)
|
||||
message = server.get_agent_message(agent_id, message_id)
|
||||
return OpenAIMessage(
|
||||
id=str(message.id),
|
||||
created_at=int(message.created_at.timestamp()),
|
||||
content=[Text(text=message.text)],
|
||||
role=message.role,
|
||||
thread_id=str(message.agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
# file_ids=message.file_ids,
|
||||
# metadata=message.metadata,
|
||||
)
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
|
||||
def retrieve_message_file(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: implement?
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
|
||||
def modify_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
request: ModifyMessageRequest = Body(...),
|
||||
):
|
||||
# TODO: add metada field to message so this can be modified
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", tags=["assistants"], response_model=MessageFile)
|
||||
def retrieve_message_file(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: implement?
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@router.post("/threads/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
|
||||
def create_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add request.instructions as a message?
|
||||
agent_id = uuid.UUID(thread_id)
|
||||
# TODO: override preset of agent with request.assistant_id
|
||||
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
agent.step(user_message=None) # already has messages added
|
||||
run_id = str(uuid.uuid4())
|
||||
create_time = int(datetime.now().timestamp())
|
||||
return OpenAIRun(
|
||||
id=run_id,
|
||||
created_at=create_time,
|
||||
thread_id=str(agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
status="completed", # TODO: eventaully allow offline execution
|
||||
expires_at=create_time,
|
||||
model=agent.agent_state.llm_config.model,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
|
||||
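For illustration (not part of the commit): create_run above executes a single agent.step() synchronously and reports the run as "completed", so a client can immediately read the agent's replies back from the messages endpoint. A hypothetical end-to-end sketch, again assuming the localhost:8080 base URL and the thread_id from the earlier sketches; assistant_id is accepted but not yet used by this handler:

import requests

BASE = "http://localhost:8080/v1"  # assumed base URL, matching the uvicorn line in the module docstring

# Hypothetical: run the thread, then page through the resulting messages.
run = requests.post(
    f"{BASE}/threads/{thread_id}/runs",
    json={"assistant_id": "default_preset", "instructions": ""},
).json()
print(run["status"])  # "completed"

messages = requests.get(f"{BASE}/threads/{thread_id}/messages", params={"limit": 10}).json()
for m in messages["messages"]:
    print(m["role"], m["content"])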
@router.post("/threads/runs", tags=["runs"], response_model=OpenAIRun)
|
||||
def create_thread_and_run(
|
||||
request: CreateThreadRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add a bunch of messages and execute
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/messages/{message_id}", tags=["assistants"], response_model=OpenAIMessage)
|
||||
def modify_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
request: ModifyMessageRequest = Body(...),
|
||||
):
|
||||
# TODO: add metada field to message so this can be modified
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@router.get("/threads/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
|
||||
def list_runs(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
limit: int = Query(1000, description="How many runs to retrieve."),
|
||||
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
|
||||
def list_run_steps(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
limit: int = Query(1000, description="How many run steps to retrieve."),
|
||||
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
before: str = Query(
|
||||
None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."
|
||||
),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=OpenAIRun)
|
||||
def create_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add request.instructions as a message?
|
||||
agent_id = uuid.UUID(thread_id)
|
||||
# TODO: override preset of agent with request.assistant_id
|
||||
agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
agent.step(user_message=None) # already has messages added
|
||||
run_id = str(uuid.uuid4())
|
||||
create_time = int(datetime.now().timestamp())
|
||||
return OpenAIRun(
|
||||
id=run_id,
|
||||
created_at=create_time,
|
||||
thread_id=str(agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
status="completed", # TODO: eventaully allow offline execution
|
||||
expires_at=create_time,
|
||||
model=agent.agent_state.llm_config.model,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
|
||||
def retrieve_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
|
||||
def retrieve_run_step(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
step_id: str = Path(..., description="The unique identifier of the run step."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.post("/v1/threads/runs", tags=["assistants"], response_model=OpenAIRun)
|
||||
def create_thread_and_run(
|
||||
request: CreateThreadRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add a bunch of messages and execute
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
|
||||
def modify_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: ModifyRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
|
||||
def submit_tool_outputs_to_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: SubmitToolOutputsToRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs", tags=["assistants"], response_model=List[OpenAIRun])
|
||||
def list_runs(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
limit: int = Query(1000, description="How many runs to retrieve."),
|
||||
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
|
||||
def cancel_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps", tags=["assistants"], response_model=List[OpenAIRunStep])
|
||||
def list_run_steps(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
limit: int = Query(1000, description="How many run steps to retrieve."),
|
||||
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
|
||||
def retrieve_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.get("/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["assistants"], response_model=OpenAIRunStep)
|
||||
def retrieve_run_step(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
step_id: str = Path(..., description="The unique identifier of the run step."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs/{run_id}", tags=["assistants"], response_model=OpenAIRun)
|
||||
def modify_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: ModifyRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["assistants"], response_model=OpenAIRun)
|
||||
def submit_tool_outputs_to_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: SubmitToolOutputsToRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@app.post("/v1/threads/{thread_id}/runs/{run_id}/cancel", tags=["assistants"], response_model=OpenAIRun)
|
||||
def cancel_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
return router
|
||||
|
@@ -16,6 +16,7 @@ from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.humans.index import setup_humans_index_router
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.rest_api.static_files import mount_static_files

@@ -33,6 +34,7 @@ server: SyncServer = SyncServer(default_interface=interface)


API_PREFIX = "/api"
OPENAI_API_PREFIX = "/v1"

CORS_ORIGINS = [
    "http://localhost:4200",

@@ -65,6 +67,8 @@ app.include_router(setup_personas_index_router(server, interface), prefix=API_PR
app.include_router(setup_models_index_router(server, interface), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface), prefix=API_PREFIX)
# /v1/assistants endpoints
app.include_router(setup_openai_assistant_router(server, interface), prefix=OPENAI_API_PREFIX)
# / static files
mount_static_files(app)
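For illustration (not part of the commit): because the routes registered in assistants.py omit the version prefix (e.g. "/assistants", "/threads/{thread_id}/runs"), the include_router call above is what produces the OpenAI-style URLs under OPENAI_API_PREFIX ("/v1"). A hypothetical smoke test in the spirit of the test-file hunk below; the request body and expected response are assumptions, since CreateAssistantRequest's required fields are not shown in this diff:

from fastapi.testclient import TestClient

from memgpt.server.rest_api.server import app

client = TestClient(app)
# Hypothetical: the router path "/assistants" is served under the "/v1" prefix.
resp = client.post("/v1/assistants", json={"model": "gpt-4", "name": "test", "instructions": "You are MemGPT."})
print(resp.status_code, resp.json())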
@@ -3,7 +3,7 @@ from fastapi.testclient import TestClient
import uuid

from memgpt.server.server import SyncServer
from memgpt.server.rest_api.openai_assistants.assistants import app
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET