MemGPT/letta/server/rest_api/utils.py

import asyncio
import json
import traceback
import warnings
from enum import Enum
from typing import AsyncGenerator, Optional, Type, Union

from fastapi import Header
from pydantic import BaseModel

from letta.schemas.usage import LettaUsageStatistics
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.server import SyncServer

# from letta.orm.user import User
# from letta.orm.utilities import get_db_session

SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n"
SSE_FINISH_MSG = "[DONE]"  # mimic openai
SSE_ARTIFICIAL_DELAY = 0.1


def sse_formatter(data: Union[dict, str]) -> str:
    """Prefix with 'data: ' and terminate with a double newline, per the SSE wire format"""
    assert isinstance(data, (dict, str)), f"Expected type dict or str, got type {type(data)}"
    data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
    return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"


async def sse_async_generator(
    generator: AsyncGenerator,
    usage_task: Optional[asyncio.Task] = None,
    finish_message: bool = True,
):
    """
    Wraps a generator for use in Server-Sent Events (SSE), handling errors and ensuring a completion message.

    Args:
    - generator: An asynchronous generator yielding data chunks.
    - usage_task: Optional task resolving to LettaUsageStatistics, emitted as a final "usage" event.
    - finish_message: Whether to emit the terminal '[DONE]' event when the stream closes.

    Yields:
    - Formatted Server-Sent Event strings.
    """
    try:
        async for chunk in generator:
            # Normalize each chunk to something sse_formatter accepts (dict or str)
            if isinstance(chunk, BaseModel):
                chunk = chunk.model_dump()
            elif isinstance(chunk, Enum):
                chunk = str(chunk.value)
            elif not isinstance(chunk, dict):
                chunk = str(chunk)
            yield sse_formatter(chunk)

        # If we have a usage task, wait for it and send its result
        if usage_task is not None:
            try:
                usage = await usage_task
                # Double-check the type before serializing it into the stream
                if not isinstance(usage, LettaUsageStatistics):
                    raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
                yield sse_formatter({"usage": usage.model_dump()})
            except Exception as e:
                warnings.warn(f"Error getting usage data: {e}")
                yield sse_formatter({"error": "Failed to get usage data"})
    except Exception as e:
        print("stream decoder hit error:", e)
        traceback.print_exc()
        yield sse_formatter({"error": "stream decoder encountered an error"})
    finally:
        if finish_message:
            # Signal that the stream is complete
            yield sse_formatter(SSE_FINISH_MSG)
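

# A minimal wiring sketch (illustrative; the route and the `stream_chunks`
# generator are hypothetical, not part of this module):
#
#     from fastapi import FastAPI
#     from fastapi.responses import StreamingResponse
#
#     app = FastAPI()
#
#     @app.get("/stream")
#     async def stream():
#         async def stream_chunks():
#             for i in range(3):
#                 yield {"chunk": i}
#
#         return StreamingResponse(
#             sse_async_generator(stream_chunks()),
#             media_type="text/event-stream",
#         )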


def get_letta_server() -> SyncServer:
    # Check if a global server is already instantiated
    from letta.server.rest_api.app import server

    # assert isinstance(server, SyncServer)
    return server
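
# Typical use as a FastAPI dependency (illustrative; `router` and the route are
# hypothetical):
#
#     from fastapi import Depends
#
#     @router.get("/agents")
#     def list_agents(server: SyncServer = Depends(get_letta_server)):
#         ...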


# Dependency to get user_id from headers
def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optional[str]:
    return user_id
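
# Header-based dependency in action (illustrative; the route is hypothetical):
#
#     @router.post("/messages")
#     def send_message(user_id: Optional[str] = Depends(get_user_id)):
#         ...
#
# Clients supply the id via an HTTP header, e.g. `-H "user_id: user-123"` in curl.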


# TODO: why does this double up the interface?
def get_current_interface() -> Type[StreamingServerInterface]:
    # Note: returns the class itself, not an instance
    return StreamingServerInterface