mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
384 lines
14 KiB
Python
384 lines
14 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
import uuid
|
|
import warnings
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
|
|
|
|
from fastapi import Header, HTTPException
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
|
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
|
from openai.types.chat.completion_create_params import CompletionCreateParams
|
|
from pydantic import BaseModel
|
|
|
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
|
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
|
from letta.helpers.datetime_helpers import get_utc_time
|
|
from letta.log import get_logger
|
|
from letta.schemas.enums import MessageRole
|
|
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
|
from letta.schemas.message import Message, MessageCreate
|
|
from letta.schemas.usage import LettaUsageStatistics
|
|
from letta.schemas.user import User
|
|
from letta.server.rest_api.interface import StreamingServerInterface
|
|
from letta.system import get_heartbeat, package_function_response
|
|
|
|
if TYPE_CHECKING:
|
|
from letta.server.server import SyncServer
|
|
|
|
|
|
SSE_PREFIX = "data: "
|
|
SSE_SUFFIX = "\n\n"
|
|
SSE_FINISH_MSG = "[DONE]" # mimic openai
|
|
SSE_ARTIFICIAL_DELAY = 0.1
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def sse_formatter(data: Union[dict, str]) -> str:
|
|
"""Prefix with 'data: ', and always include double newlines"""
|
|
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
|
|
data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
|
|
# print(f"data: {data_str}\n\n")
|
|
return f"data: {data_str}\n\n"
|
|
|
|
|
|
async def sse_async_generator(
|
|
generator: AsyncGenerator,
|
|
usage_task: Optional[asyncio.Task] = None,
|
|
finish_message=True,
|
|
):
|
|
"""
|
|
Wraps a generator for use in Server-Sent Events (SSE), handling errors and ensuring a completion message.
|
|
|
|
Args:
|
|
- generator: An asynchronous generator yielding data chunks.
|
|
|
|
Yields:
|
|
- Formatted Server-Sent Event strings.
|
|
"""
|
|
try:
|
|
async for chunk in generator:
|
|
# yield f"data: {json.dumps(chunk)}\n\n"
|
|
if isinstance(chunk, BaseModel):
|
|
chunk = chunk.model_dump()
|
|
elif isinstance(chunk, Enum):
|
|
chunk = str(chunk.value)
|
|
elif not isinstance(chunk, dict):
|
|
chunk = str(chunk)
|
|
yield sse_formatter(chunk)
|
|
|
|
# If we have a usage task, wait for it and send its result
|
|
if usage_task is not None:
|
|
try:
|
|
usage = await usage_task
|
|
# Double-check the type
|
|
if not isinstance(usage, LettaUsageStatistics):
|
|
err_msg = f"Expected LettaUsageStatistics, got {type(usage)}"
|
|
logger.error(err_msg)
|
|
raise ValueError(err_msg)
|
|
yield sse_formatter(usage.model_dump(exclude={"steps_messages"}))
|
|
|
|
except ContextWindowExceededError as e:
|
|
log_error_to_sentry(e)
|
|
logger.error(f"ContextWindowExceededError error: {e}")
|
|
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
|
|
|
|
except RateLimitExceededError as e:
|
|
log_error_to_sentry(e)
|
|
logger.error(f"RateLimitExceededError error: {e}")
|
|
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
|
|
|
|
except Exception as e:
|
|
log_error_to_sentry(e)
|
|
logger.error(f"Caught unexpected Exception: {e}")
|
|
yield sse_formatter({"error": f"Stream failed (internal error occurred)"})
|
|
|
|
except Exception as e:
|
|
log_error_to_sentry(e)
|
|
logger.error(f"Caught unexpected Exception: {e}")
|
|
yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})
|
|
|
|
finally:
|
|
if finish_message:
|
|
# Signal that the stream is complete
|
|
yield sse_formatter(SSE_FINISH_MSG)
|
|
|
|
|
|
# TODO: why does this double up the interface?
|
|
def get_letta_server() -> "SyncServer":
|
|
# Check if a global server is already instantiated
|
|
from letta.server.rest_api.app import server
|
|
|
|
# assert isinstance(server, SyncServer)
|
|
return server
|
|
|
|
|
|
# Dependency to get user_id from headers
|
|
def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optional[str]:
|
|
return user_id
|
|
|
|
|
|
def get_current_interface() -> StreamingServerInterface:
|
|
return StreamingServerInterface
|
|
|
|
|
|
def log_error_to_sentry(e):
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
warnings.warn(f"SSE stream generator failed: {e}")
|
|
|
|
# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
|
|
# Print the stack trace
|
|
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
|
|
import sentry_sdk
|
|
|
|
sentry_sdk.capture_exception(e)
|
|
|
|
|
|
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
|
|
"""
|
|
Converts a user input message into the internal structured format.
|
|
"""
|
|
new_messages = []
|
|
for input_message in input_messages:
|
|
# Construct the Message object
|
|
new_message = Message(
|
|
id=f"message-{uuid.uuid4()}",
|
|
role=input_message.role,
|
|
content=input_message.content,
|
|
name=input_message.name,
|
|
otid=input_message.otid,
|
|
sender_id=input_message.sender_id,
|
|
organization_id=actor.organization_id,
|
|
agent_id=agent_id,
|
|
model=None,
|
|
tool_calls=None,
|
|
tool_call_id=None,
|
|
created_at=get_utc_time(),
|
|
)
|
|
new_messages.append(new_message)
|
|
|
|
return new_messages
|
|
|
|
|
|
def create_letta_messages_from_llm_response(
|
|
agent_id: str,
|
|
model: str,
|
|
function_name: str,
|
|
function_arguments: Dict,
|
|
tool_call_id: str,
|
|
function_call_success: bool,
|
|
function_response: Optional[str],
|
|
actor: User,
|
|
add_heartbeat_request_system_message: bool = False,
|
|
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
|
pre_computed_assistant_message_id: Optional[str] = None,
|
|
pre_computed_tool_message_id: Optional[str] = None,
|
|
) -> List[Message]:
|
|
messages = []
|
|
|
|
# Construct the tool call with the assistant's message
|
|
function_arguments["request_heartbeat"] = True
|
|
tool_call = OpenAIToolCall(
|
|
id=tool_call_id,
|
|
function=OpenAIFunction(
|
|
name=function_name,
|
|
arguments=json.dumps(function_arguments),
|
|
),
|
|
type="function",
|
|
)
|
|
# TODO: Use ToolCallContent instead of tool_calls
|
|
# TODO: This helps preserve ordering
|
|
assistant_message = Message(
|
|
role=MessageRole.assistant,
|
|
content=reasoning_content if reasoning_content else [],
|
|
organization_id=actor.organization_id,
|
|
agent_id=agent_id,
|
|
model=model,
|
|
tool_calls=[tool_call],
|
|
tool_call_id=tool_call_id,
|
|
created_at=get_utc_time(),
|
|
)
|
|
if pre_computed_assistant_message_id:
|
|
assistant_message.id = pre_computed_assistant_message_id
|
|
messages.append(assistant_message)
|
|
|
|
# TODO: Use ToolReturnContent instead of TextContent
|
|
# TODO: This helps preserve ordering
|
|
tool_message = Message(
|
|
role=MessageRole.tool,
|
|
content=[TextContent(text=package_function_response(function_call_success, function_response))],
|
|
organization_id=actor.organization_id,
|
|
agent_id=agent_id,
|
|
model=model,
|
|
tool_calls=[],
|
|
tool_call_id=tool_call_id,
|
|
created_at=get_utc_time(),
|
|
)
|
|
if pre_computed_tool_message_id:
|
|
tool_message.id = pre_computed_tool_message_id
|
|
messages.append(tool_message)
|
|
|
|
if add_heartbeat_request_system_message:
|
|
heartbeat_system_message = create_heartbeat_system_message(
|
|
agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor
|
|
)
|
|
messages.append(heartbeat_system_message)
|
|
|
|
return messages
|
|
|
|
|
|
def create_heartbeat_system_message(
|
|
agent_id: str,
|
|
model: str,
|
|
function_call_success: bool,
|
|
actor: User,
|
|
) -> Message:
|
|
text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE
|
|
heartbeat_system_message = Message(
|
|
role=MessageRole.user,
|
|
content=[TextContent(text=get_heartbeat(text_content))],
|
|
organization_id=actor.organization_id,
|
|
agent_id=agent_id,
|
|
model=model,
|
|
tool_calls=[],
|
|
tool_call_id=None,
|
|
created_at=get_utc_time(),
|
|
)
|
|
return heartbeat_system_message
|
|
|
|
|
|
def create_assistant_messages_from_openai_response(
|
|
response_text: str,
|
|
agent_id: str,
|
|
model: str,
|
|
actor: User,
|
|
) -> List[Message]:
|
|
"""
|
|
Converts an OpenAI response into Messages that follow the internal
|
|
paradigm where LLM responses are structured as tool calls instead of content.
|
|
"""
|
|
tool_call_id = str(uuid.uuid4())
|
|
|
|
return create_letta_messages_from_llm_response(
|
|
agent_id=agent_id,
|
|
model=model,
|
|
function_name=DEFAULT_MESSAGE_TOOL,
|
|
function_arguments={DEFAULT_MESSAGE_TOOL_KWARG: response_text}, # Avoid raw string manipulation
|
|
tool_call_id=tool_call_id,
|
|
function_call_success=True,
|
|
function_response=None,
|
|
actor=actor,
|
|
add_heartbeat_request_system_message=False,
|
|
)
|
|
|
|
|
|
def convert_letta_messages_to_openai(messages: List[Message]) -> List[dict]:
|
|
"""
|
|
Flattens Letta's messages (with system, user, assistant, tool roles, etc.)
|
|
into standard OpenAI chat messages (system, user, assistant).
|
|
|
|
Transformation rules:
|
|
1. Assistant + send_message tool_call => content = tool_call's "message"
|
|
2. Tool (role=tool) referencing send_message => skip
|
|
3. User messages might store actual text inside JSON => parse that into content
|
|
4. System => pass through as normal
|
|
"""
|
|
|
|
openai_messages = []
|
|
|
|
for msg in messages:
|
|
# 1. Assistant + 'send_message' tool_calls => flatten
|
|
if msg.role == MessageRole.assistant and msg.tool_calls:
|
|
# Find any 'send_message' tool_calls
|
|
send_message_calls = [tc for tc in msg.tool_calls if tc.function.name == "send_message"]
|
|
if send_message_calls:
|
|
# If we have multiple calls, just pick the first or merge them
|
|
# Typically there's only one.
|
|
tc = send_message_calls[0]
|
|
arguments = json.loads(tc.function.arguments)
|
|
# Extract the "message" string
|
|
extracted_text = arguments.get("message", "")
|
|
|
|
# Create a new content with the extracted text
|
|
msg = Message(
|
|
id=msg.id,
|
|
role=msg.role,
|
|
content=[TextContent(text=extracted_text)],
|
|
organization_id=msg.organization_id,
|
|
agent_id=msg.agent_id,
|
|
model=msg.model,
|
|
name=msg.name,
|
|
tool_calls=None, # no longer needed
|
|
tool_call_id=None,
|
|
created_at=msg.created_at,
|
|
)
|
|
|
|
# 2. If role=tool and it's referencing send_message => skip
|
|
if msg.role == MessageRole.tool and msg.name == "send_message":
|
|
# Usually 'tool' messages with `send_message` are just status/OK messages
|
|
# that OpenAI doesn't need to see. So skip them.
|
|
continue
|
|
|
|
# 3. User messages might store text in JSON => parse it
|
|
if msg.role == MessageRole.user:
|
|
# Example: content=[TextContent(text='{"type": "user_message","message":"Hello"}')]
|
|
# Attempt to parse JSON and extract "message"
|
|
if msg.content and msg.content[0].text.strip().startswith("{"):
|
|
try:
|
|
parsed = json.loads(msg.content[0].text)
|
|
# If there's a "message" field, use that as the content
|
|
if "message" in parsed:
|
|
actual_user_text = parsed["message"]
|
|
msg = Message(
|
|
id=msg.id,
|
|
role=msg.role,
|
|
content=[TextContent(text=actual_user_text)],
|
|
organization_id=msg.organization_id,
|
|
agent_id=msg.agent_id,
|
|
model=msg.model,
|
|
name=msg.name,
|
|
tool_calls=msg.tool_calls,
|
|
tool_call_id=msg.tool_call_id,
|
|
created_at=msg.created_at,
|
|
)
|
|
except json.JSONDecodeError:
|
|
pass # It's not JSON, leave as-is
|
|
|
|
# 4. System is left as-is (or any other role that doesn't need special handling)
|
|
#
|
|
# Finally, convert to dict using your existing method
|
|
openai_messages.append(msg.to_openai_dict())
|
|
|
|
return openai_messages
|
|
|
|
|
|
def get_messages_from_completion_request(completion_request: CompletionCreateParams) -> List[Dict]:
|
|
try:
|
|
messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
|
|
except KeyError:
|
|
# Handle the case where "messages" is not present in the request
|
|
raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
|
|
except TypeError:
|
|
# Handle the case where "messages" is not iterable
|
|
raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
|
|
except Exception as e:
|
|
# Catch any other unexpected errors and include the exception message
|
|
raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
|
|
|
|
if messages[-1]["role"] != "user":
|
|
logger.error(f"The last message does not have a `user` role: {messages}")
|
|
raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
|
|
|
|
input_message = messages[-1]
|
|
if not isinstance(input_message["content"], str):
|
|
logger.error(f"The input message does not have valid content: {input_message}")
|
|
raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")
|
|
|
|
return messages
|