feat: add message type literal to usage stats (#2297)

Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
cthomas 2024-12-20 17:13:56 -08:00 committed by GitHub
parent 803833e97e
commit b9b77fdc02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 4 deletions

View File

@ -59,8 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
yield ToolCallMessage(**chunk_data)
elif "tool_return" in chunk_data:
yield ToolReturnMessage(**chunk_data)
elif "usage" in chunk_data:
yield LettaUsageStatistics(**chunk_data["usage"])
elif "step_count" in chunk_data:
yield LettaUsageStatistics(**chunk_data)
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

View File

@ -1,3 +1,4 @@
from typing import Literal
from pydantic import BaseModel, Field
@ -11,7 +12,7 @@ class LettaUsageStatistics(BaseModel):
total_tokens (int): The total number of tokens processed by the agent.
step_count (int): The number of steps taken by the agent.
"""
message_type: Literal["usage_statistics"] = "usage_statistics"
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")

View File

@ -61,7 +61,7 @@ async def sse_async_generator(
# Double-check the type
if not isinstance(usage, LettaUsageStatistics):
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
yield sse_formatter({"usage": usage.model_dump()})
yield sse_formatter(usage.model_dump())
except ContextWindowExceededError as e:
log_error_to_sentry(e)