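"""FastAPI routes for agent messages.

Provides paginated retrieval of an agent's in-context messages and an endpoint for
sending user/system messages, with optional server-sent-event (SSE) streaming of
agent steps or individual tokens.
"""
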
import asyncio
import uuid
from datetime import datetime
from enum import Enum
from functools import partial
from typing import List, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse

from memgpt.models.pydantic_models import MemGPTUsageStatistics
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer

router = APIRouter()


class MessageRoleType(str, Enum):
    user = "user"
    system = "system"


class UserMessageRequest(BaseModel):
    message: str = Field(..., description="The message content to be processed by the agent.")
    name: Optional[str] = Field(default=None, description="Name of the message request sender")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
    stream_steps: bool = Field(
        default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
    )
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
    )
    timestamp: Optional[datetime] = Field(
        None,
        description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
    )
    stream: bool = Field(
        default=False,
        description="Legacy flag for old streaming API, will be deprecated in the future.",
        deprecated=True,
    )

    # @validator("timestamp", pre=True, always=True)
    # def validate_timestamp(cls, value: Optional[datetime]) -> Optional[datetime]:
    #     if value is None:
    #         return value  # If the timestamp is None, just return None, implying default handling to set server-side
    #
    #     if not isinstance(value, datetime):
    #         raise TypeError("Timestamp must be a datetime object with timezone information.")
    #
    #     if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
    #         raise ValueError("Timestamp must be timezone-aware.")
    #
    #     # Convert timestamp to UTC if it's not already in UTC
    #     if value.tzinfo.utcoffset(value) != timezone.utc.utcoffset(value):
    #         value = value.astimezone(timezone.utc)
    #
    #     return value
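
# Example request body for POST /agents/{agent_id}/messages (illustrative values):
# {
#     "message": "Hello, agent!",
#     "role": "user",
#     "stream_steps": true,
#     "stream_tokens": false
# }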


class UserMessageResponse(BaseModel):
    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
    usage: MemGPTUsageStatistics = Field(..., description="Usage statistics for the completion.")


class GetAgentMessagesRequest(BaseModel):
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")


class GetAgentMessagesCursorRequest(BaseModel):
    before: Optional[uuid.UUID] = Field(..., description="Message before which to retrieve the returned messages.")
    limit: int = Field(..., description="Maximum number of messages to retrieve.")


class GetAgentMessagesResponse(BaseModel):
    messages: list = Field(..., description="List of message objects.")


async def send_message_to_agent(
    server: SyncServer,
    agent_id: uuid.UUID,
    user_id: uuid.UUID,
    role: str,
    message: str,
    stream_legacy: bool,  # legacy
    stream_steps: bool,
    stream_tokens: bool,
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
) -> Union[StreamingResponse, UserMessageResponse]:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy."""

    # TODO this is a total hack but is required until we move streaming into the model config
    if server.server_llm_config.model_endpoint != "https://api.openai.com/v1":
        stream_tokens = False

    # handle the legacy mode streaming
    if stream_legacy:
        # NOTE: override
        stream_steps = True
        stream_tokens = False
        include_final_message = False
    else:
        include_final_message = True

    if role == "user" or role is None:
        message_func = server.user_message
    elif role == "system":
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")

    if not stream_steps and stream_tokens:
        raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
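
    # Summary of the streaming modes resolved by the flags above:
    #   stream_steps=False, stream_tokens=False -> buffered UserMessageResponse
    #   stream_steps=True,  stream_tokens=False -> SSE stream of agent steps
    #   stream_steps=True,  stream_tokens=True  -> SSE stream of individual tokens
    # (token streaming is also forced off for non-OpenAI endpoints by the hack above)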

    # For streaming response
    try:
        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Enable token-streaming within the request if desired
        streaming_interface.streaming_mode = stream_tokens
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode

        # NOTE: for legacy 'stream' flag
        streaming_interface.nonstreaming_legacy_mode = stream_legacy
        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream

        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
        )

        if stream_steps:
            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )
        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for chunk in streaming_interface.get_generator():
                generated_stream.append(chunk)
                if "data" in chunk and chunk["data"] == "[DONE]":
                    break
            filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]
            usage = await task
            return UserMessageResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
    except Exception as e:
        print(e)
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
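
# Illustrative SSE wire format for a streaming response (standard SSE framing from
# sse_async_generator; the payload keys shown are examples, the exact shapes depend
# on the StreamingServerInterface implementation):
#   data: {"internal_monologue": "..."}
#   data: {"function_call": "..."}
#   data: [DONE_STEP]
#   data: [DONE]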


def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages(
        agent_id: uuid.UUID,
        start: int = Query(..., description="Message index to start on (reverse chronological)."),
        count: int = Query(..., description="How many messages to retrieve."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
        """
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesRequest(start=start, count=count)

        interface.clear()
        messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
        return GetAgentMessagesResponse(messages=messages)
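
    # Example (illustrative): fetch the 10 most recent in-context messages
    #   GET /agents/{agent_id}/messages?start=0&count=10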

    @router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages_cursor(
        agent_id: uuid.UUID,
        before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
        limit: int = Query(10, description="Maximum number of messages to retrieve."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the in-context messages of a specific agent. Paginated; provide before (a message ID cursor) and limit to iterate.
        """
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesCursorRequest(before=before, limit=limit)

        interface.clear()
        [_, messages] = server.get_agent_recall_cursor(
            user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
        )
        # print("====> messages-cursor DEBUG")
        # for i, msg in enumerate(messages):
        #     print(f"message {i+1}/{len(messages)}")
        #     print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
        #     print(f"ISO format string: {msg['created_at']}")
        #     print(msg)
        return GetAgentMessagesResponse(messages=messages)
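
    # Example (illustrative): page backwards from a known message ID
    #   GET /agents/{agent_id}/messages-cursor?before=<message-uuid>&limit=10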

    @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
    async def send_message(
        # background_tasks: BackgroundTasks,
        agent_id: uuid.UUID,
        request: UserMessageRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Process a user message and return the agent's response.

        This endpoint accepts a message from a user and processes it through the agent.
        It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
        """
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=request.role,
            message=request.message,
            stream_steps=request.stream_steps,
            stream_tokens=request.stream_tokens,
            timestamp=request.timestamp,
            # legacy
            stream_legacy=request.stream,
        )

    return router
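

# Illustrative wiring (the app object and "/api" prefix are assumptions, not part of this module):
#   app = FastAPI()
#   app.include_router(setup_agents_message_router(server, interface, password), prefix="/api")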