MemGPT/memgpt/server/rest_api/agents/message.py
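
"""FastAPI routes for an agent's messages: paginated retrieval of message
history (offset- and cursor-based) and sending new user/system messages,
with optional SSE streaming of agent steps and tokens."""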
import asyncio
import traceback
import uuid
from datetime import datetime
from enum import Enum
from functools import partial
from typing import List, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse

from memgpt.models.pydantic_models import MemGPTUsageStatistics
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer

router = APIRouter()


class MessageRoleType(str, Enum):
    user = "user"
    system = "system"


class UserMessageRequest(BaseModel):
    message: str = Field(..., description="The message content to be processed by the agent.")
    name: Optional[str] = Field(default=None, description="Name of the message request sender")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
    stream_steps: bool = Field(
        default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
    )
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
    )
    timestamp: Optional[datetime] = Field(
        None,
        description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
    )
    stream: bool = Field(
        default=False,
        description="Legacy flag for the old streaming API, will be deprecated in the future.",
        deprecated=True,
    )

    # @validator("timestamp", pre=True, always=True)
    # def validate_timestamp(cls, value: Optional[datetime]) -> Optional[datetime]:
    #     if value is None:
    #         return value  # If the timestamp is None, just return None, implying default handling to set server-side
    #     if not isinstance(value, datetime):
    #         raise TypeError("Timestamp must be a datetime object with timezone information.")
    #     if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
    #         raise ValueError("Timestamp must be timezone-aware.")
    #     # Convert timestamp to UTC if it's not already in UTC
    #     if value.tzinfo.utcoffset(value) != timezone.utc.utcoffset(value):
    #         value = value.astimezone(timezone.utc)
    #     return value
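
# Example request body (an illustrative sketch based on the fields above; the
# values themselves are made up, not taken from the MemGPT docs):
#   {
#       "message": "Hello, agent!",
#       "role": "user",
#       "stream_steps": true,
#       "stream_tokens": false
#   }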


class UserMessageResponse(BaseModel):
    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
    usage: MemGPTUsageStatistics = Field(..., description="Usage statistics for the completion.")


class GetAgentMessagesRequest(BaseModel):
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")


class GetAgentMessagesCursorRequest(BaseModel):
    before: Optional[uuid.UUID] = Field(None, description="Message before which to retrieve the returned messages.")
    limit: int = Field(..., description="Maximum number of messages to retrieve.")


class GetAgentMessagesResponse(BaseModel):
    messages: list = Field(..., description="List of message objects.")


async def send_message_to_agent(
    server: SyncServer,
    agent_id: uuid.UUID,
    user_id: uuid.UUID,
    role: str,
    message: str,
    stream_legacy: bool,  # legacy
    stream_steps: bool,
    stream_tokens: bool,
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
) -> Union[StreamingResponse, UserMessageResponse]:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy."""

    # TODO this is a total hack but is required until we move streaming into the model config
    if server.server_llm_config.model_endpoint != "https://api.openai.com/v1":
        stream_tokens = False

    # Handle the legacy streaming mode
    if stream_legacy:
        # NOTE: override
        stream_steps = True
        stream_tokens = False
        include_final_message = False
    else:
        include_final_message = True

    if role == "user" or role is None:
        message_func = server.user_message
    elif role == "system":
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")

    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")

    # For streaming response
    try:
        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Enable token-streaming within the request if desired
        streaming_interface.streaming_mode = stream_tokens
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode
        # NOTE: for the legacy 'stream' flag
        streaming_interface.nonstreaming_legacy_mode = stream_legacy
        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream

        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
        )
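
        # NOTE on the stream shape (inferred from this file, not from the MemGPT
        # docs): the generator yields SSE-style chunks, with "[DONE_GEN]" and
        # "[DONE_STEP]" marking generation/step boundaries and "[DONE]" terminating
        # the stream; see the sentinel filtering in the buffered branch below.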

        if stream_steps:
            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )
        else:
            # Buffer the stream, then return the full list of messages
            generated_stream = []
            async for message in streaming_interface.get_generator():
                generated_stream.append(message)
                if "data" in message and message["data"] == "[DONE]":
                    break
            # Strip the control sentinels so only real messages are returned
            filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]
            usage = await task
            return UserMessageResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
    except Exception as e:
        print(e)
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")


def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(get_current_user, server, password)

    @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages(
        agent_id: uuid.UUID,
        start: int = Query(..., description="Message index to start on (reverse chronological)."),
        count: int = Query(..., description="How many messages to retrieve."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
        """
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesRequest(start=start, count=count)

        interface.clear()
        messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
        return GetAgentMessagesResponse(messages=messages)
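
    # Example: GET /agents/<agent_id>/messages?start=0&count=10
    # (illustrative values; retrieves `count` messages beginning at index `start`,
    # reverse chronological, per the Query descriptions above)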

    @router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages_cursor(
        agent_id: uuid.UUID,
        before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
        limit: int = Query(10, description="Maximum number of messages to retrieve."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the in-context messages of a specific agent. Paginated, provide before and limit to iterate.
        """
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesCursorRequest(before=before, limit=limit)

        interface.clear()
        [_, messages] = server.get_agent_recall_cursor(
            user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
        )
        return GetAgentMessagesResponse(messages=messages)

    @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
    async def send_message(
        # background_tasks: BackgroundTasks,
        agent_id: uuid.UUID,
        request: UserMessageRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Process a user message and return the agent's response.

        This endpoint accepts a message from a user and processes it through the agent.
        It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
        """
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=request.role,
            message=request.message,
            stream_steps=request.stream_steps,
            stream_tokens=request.stream_tokens,
            timestamp=request.timestamp,
            # legacy
            stream_legacy=request.stream,
        )

    return router
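

# Example usage (a sketch, not from the MemGPT docs -- the port, URL, and auth
# scheme below are assumptions about how this router and `get_current_user`
# are wired up in a running server):
#
#   curl -X POST http://localhost:8283/agents/<agent_id>/messages \
#     -H "Authorization: Bearer <password>" \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello, agent!", "role": "user", "stream_steps": false}'
#
# With stream_steps=true the endpoint returns a text/event-stream response;
# otherwise it buffers the stream and returns a UserMessageResponse JSON body.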