feat: add agent to steps table and support filtering (#1212)

This commit is contained in:
cthomas 2025-03-07 10:10:29 -08:00 committed by GitHub
parent 6fdcb49f17
commit 52dba65bde
11 changed files with 58 additions and 10 deletions

View File

@ -0,0 +1,31 @@
"""add agent id to steps
Revision ID: d211df879a5f
Revises: 2f4ede6ae33b
Create Date: 2025-03-06 21:42:22.289345
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d211df879a5f"  # ID of this migration
down_revision: Union[str, None] = "2f4ede6ae33b"  # parent migration this one applies on top of
branch_labels: Union[str, Sequence[str], None] = None  # no named branch
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependency
def upgrade() -> None:
    """Add a nullable ``agent_id`` string column to the ``steps`` table.

    Nullable so existing step rows (created before this migration) remain valid.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("steps", sa.Column("agent_id", sa.String(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert :func:`upgrade` by dropping the ``agent_id`` column from ``steps``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("steps", "agent_id")
    # ### end Alembic commands ###

View File

@ -909,6 +909,7 @@ class Agent(BaseAgent):
# Log step - this must happen before messages are persisted
step = self.step_manager.log_step(
actor=self.user,
agent_id=self.agent_state.id,
provider_name=self.agent_state.llm_config.model_endpoint_type,
model=self.agent_state.llm_config.model,
model_endpoint=self.agent_state.llm_config.model_endpoint,

View File

@ -4,7 +4,6 @@ import time
from typing import Callable, Dict, Generator, List, Optional, Union
import requests
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
import letta.utils
from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT
@ -29,7 +28,7 @@ from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage

View File

@ -33,6 +33,7 @@ class Step(SqlalchemyBase):
# Optional link back to the async job run that produced this step; detached
# (set to NULL) rather than cascaded when the job is deleted.
job_id: Mapped[Optional[str]] = mapped_column(
    ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identifier of the job run that triggered this step"
)
# Fix: agent_id's doc was a copy-paste of the model column's doc; job_id's doc
# had the typo "unique identified". Column types are inferred from the
# Mapped[...] annotations (first positional arg is None).
agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The unique identifier of the agent that performed this step.")
provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.")

View File

@ -18,6 +18,7 @@ class Step(StepBase):
job_id: Optional[str] = Field(
None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
)
agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.")
provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
model: Optional[str] = Field(None, description="The name of the model used for this step.")
model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.")

View File

@ -19,7 +19,6 @@ from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUni
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.passage import Passage, PassageUpdate
from letta.schemas.run import Run
from letta.schemas.source import Source

View File

@ -20,6 +20,7 @@ def list_steps(
start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
@ -42,6 +43,7 @@ def list_steps(
limit=limit,
order=order,
model=model,
agent_id=agent_id,
)

View File

@ -1,7 +1,6 @@
import json
from typing import List, Optional
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from sqlalchemy import and_, or_
from letta.log import get_logger

View File

@ -33,10 +33,15 @@ class StepManager:
limit: Optional[int] = 50,
order: Optional[str] = None,
model: Optional[str] = None,
agent_id: Optional[str] = None,
) -> List[PydanticStep]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
filter_kwargs = {"organization_id": actor.organization_id, "model": model}
filter_kwargs = {"organization_id": actor.organization_id}
if model:
filter_kwargs["model"] = model
if agent_id:
filter_kwargs["agent_id"] = agent_id
steps = StepModel.list(
db_session=session,
@ -54,6 +59,7 @@ class StepManager:
def log_step(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
model: str,
model_endpoint: Optional[str],
@ -65,6 +71,7 @@ class StepManager:
step_data = {
"origin": None,
"organization_id": actor.organization_id,
"agent_id": agent_id,
"provider_id": provider_id,
"provider_name": provider_name,
"model": model,

View File

@ -26,7 +26,7 @@ from letta.schemas.letta_message import (
ToolReturnMessage,
UserMessage,
)
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics

View File

@ -24,7 +24,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
@ -3391,13 +3391,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
# ======================================================================================================================
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user):
"""Test adding and retrieving job usage statistics."""
job_manager = server.job_manager
step_manager = server.step_manager
# Add usage statistics
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",
@ -3441,13 +3442,14 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
assert len(steps) == 0
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user):
"""Test adding multiple usage statistics entries for a job."""
job_manager = server.job_manager
step_manager = server.step_manager
# Add first usage statistics entry
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",
@ -3463,6 +3465,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
# Add second usage statistics entry
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",
@ -3489,6 +3492,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
assert len(steps) == 2
# get agent steps
steps = step_manager.list_steps(agent_id=sarah_agent.id, actor=default_user)
assert len(steps) == 2
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
"""Test getting usage statistics for a nonexistent job."""
@ -3498,12 +3505,13 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user)
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user):
"""Test adding usage statistics for a nonexistent job."""
step_manager = server.step_manager
with pytest.raises(NoResultFound):
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",