Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
feat: add message type literal to usage stats (#2297)
Co-authored-by: Caren Thomas <caren@caren-mac.local>
parent 803833e97e
commit b9b77fdc02
@@ -59,8 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
                     yield ToolCallMessage(**chunk_data)
                 elif "tool_return" in chunk_data:
                     yield ToolReturnMessage(**chunk_data)
-                elif "usage" in chunk_data:
-                    yield LettaUsageStatistics(**chunk_data["usage"])
+                elif "step_count" in chunk_data:
+                    yield LettaUsageStatistics(**chunk_data)
                 else:
                     raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
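As a minimal illustration of the dispatch change above: usage statistics now arrive as a flat payload that the client can recognize by its step_count key (and the new message_type field) instead of a nested "usage" object. The chunk below is made up, and the tool-call branches are reduced to labels:

# Hypothetical chunk, as parsed from a single SSE "data:" line after this change.
chunk_data = {
    "message_type": "usage_statistics",
    "completion_tokens": 10,
    "prompt_tokens": 42,
    "total_tokens": 52,
    "step_count": 1,
}

if "tool_call" in chunk_data:
    kind = "tool_call"
elif "tool_return" in chunk_data:
    kind = "tool_return"
elif "step_count" in chunk_data:
    # The real code constructs LettaUsageStatistics(**chunk_data) here.
    kind = "usage_statistics"
else:
    raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

print(kind)  # -> usage_statistics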
@@ -1,3 +1,4 @@
+from typing import Literal
 from pydantic import BaseModel, Field
 
 
@@ -11,7 +12,7 @@ class LettaUsageStatistics(BaseModel):
         total_tokens (int): The total number of tokens processed by the agent.
         step_count (int): The number of steps taken by the agent.
     """
 
+    message_type: Literal["usage_statistics"] = "usage_statistics"
     completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
     prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
     total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
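To show what the new literal does to serialized output, here is a self-contained sketch that re-creates the model locally as it appears in the hunk above (a stand-in, not an import from the Letta package; the step_count field is assumed from the docstring):

from typing import Literal

from pydantic import BaseModel, Field


class LettaUsageStatistics(BaseModel):
    # Stand-in mirroring the fields visible in the diff; step_count is assumed.
    message_type: Literal["usage_statistics"] = "usage_statistics"
    completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
    prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
    total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
    step_count: int = Field(0, description="The number of steps taken by the agent.")


usage = LettaUsageStatistics(completion_tokens=10, prompt_tokens=42, total_tokens=52, step_count=1)
print(usage.model_dump())
# {'message_type': 'usage_statistics', 'completion_tokens': 10, 'prompt_tokens': 42,
#  'total_tokens': 52, 'step_count': 1}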
@@ -61,7 +61,7 @@ async def sse_async_generator(
             # Double-check the type
             if not isinstance(usage, LettaUsageStatistics):
                 raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
-            yield sse_formatter({"usage": usage.model_dump()})
+            yield sse_formatter(usage.model_dump())
 
         except ContextWindowExceededError as e:
             log_error_to_sentry(e)
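Assuming sse_formatter simply JSON-encodes its argument onto an SSE "data:" line (an assumption; the helper itself is not shown in this diff), the change to the wire format looks roughly like this:

import json

usage_dump = {
    "message_type": "usage_statistics",
    "completion_tokens": 10,
    "prompt_tokens": 42,
    "total_tokens": 52,
    "step_count": 1,
}

# Before: usage statistics were wrapped under a "usage" key.
before = f"data: {json.dumps({'usage': usage_dump})}\n\n"

# After: the dump is sent flat, self-described by message_type.
after = f"data: {json.dumps(usage_dump)}\n\n"

print(before, after, sep="")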