From 52dba65bdeadffe2714505854d1fc6ca6de6d10a Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 7 Mar 2025 10:10:29 -0800 Subject: [PATCH] feat: add agent to steps table and support filtering (#1212) --- .../d211df879a5f_add_agent_id_to_steps.py | 31 +++++++++++++++++++ letta/agent.py | 1 + letta/client/client.py | 3 +- letta/orm/step.py | 1 + letta/schemas/step.py | 1 + letta/server/rest_api/routers/v1/agents.py | 1 - letta/server/rest_api/routers/v1/steps.py | 2 ++ letta/services/message_manager.py | 1 - letta/services/step_manager.py | 9 +++++- tests/test_client_legacy.py | 2 +- tests/test_managers.py | 16 +++++++--- 11 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 alembic/versions/d211df879a5f_add_agent_id_to_steps.py diff --git a/alembic/versions/d211df879a5f_add_agent_id_to_steps.py b/alembic/versions/d211df879a5f_add_agent_id_to_steps.py new file mode 100644 index 000000000..d857fca47 --- /dev/null +++ b/alembic/versions/d211df879a5f_add_agent_id_to_steps.py @@ -0,0 +1,31 @@ +"""add agent id to steps + +Revision ID: d211df879a5f +Revises: 2f4ede6ae33b +Create Date: 2025-03-06 21:42:22.289345 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d211df879a5f" +down_revision: Union[str, None] = "2f4ede6ae33b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("steps", sa.Column("agent_id", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("steps", "agent_id") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index d6a62de7e..5655e04cd 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -909,6 +909,7 @@ class Agent(BaseAgent): # Log step - this must happen before messages are persisted step = self.step_manager.log_step( actor=self.user, + agent_id=self.agent_state.id, provider_name=self.agent_state.llm_config.model_endpoint_type, model=self.agent_state.llm_config.model, model_endpoint=self.agent_state.llm_config.model_endpoint, diff --git a/letta/client/client.py b/letta/client/client.py index 3e9de1178..4405a167d 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -4,7 +4,6 @@ import time from typing import Callable, Dict, Generator, List, Optional, Union import requests -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall import letta.utils from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT @@ -29,7 +28,7 @@ from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary -from letta.schemas.message import Message, MessageCreate, MessageUpdate +from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.organization import Organization from letta.schemas.passage import Passage diff --git a/letta/orm/step.py b/letta/orm/step.py index f13fac6e4..ce7b82442 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -33,6 +33,7 @@ class Step(SqlalchemyBase): job_id: Mapped[Optional[str]] = mapped_column( ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified of the job run that triggered this step" ) + agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.") model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.") diff --git a/letta/schemas/step.py b/letta/schemas/step.py index f0e7f0800..d25d8b684 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -18,6 +18,7 @@ class Step(StepBase): job_id: Optional[str] = Field( None, description="The unique identifier of the job that this step belongs to. Only included for async calls." ) + agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.") provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.") model: Optional[str] = Field(None, description="The name of the model used for this step.") model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 7d549b3bd..2b7aec1cc 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -19,7 +19,6 @@ from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUni from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory -from letta.schemas.message import Message, MessageUpdate from letta.schemas.passage import Passage, PassageUpdate from letta.schemas.run import Run from letta.schemas.source import Source diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py index 7c67de9c0..fa31e2bd6 100644 --- a/letta/server/rest_api/routers/v1/steps.py +++ b/letta/server/rest_api/routers/v1/steps.py @@ -20,6 +20,7 @@ def list_steps( start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'), end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'), model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"), + agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -42,6 +43,7 @@ def list_steps( limit=limit, order=order, model=model, + agent_id=agent_id, ) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 6cd3efc71..b07cb5d8a 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -1,7 +1,6 @@ import json from typing import List, Optional -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from sqlalchemy import and_, or_ from letta.log import get_logger diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index dbaf9f90b..fc5ed3cf4 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -33,10 +33,15 @@ class StepManager: limit: Optional[int] = 50, order: Optional[str] = None, model: Optional[str] = None, + agent_id: Optional[str] = None, ) -> List[PydanticStep]: """List all jobs with optional pagination and status filter.""" with self.session_maker() as session: - filter_kwargs = {"organization_id": actor.organization_id, "model": model} + filter_kwargs = {"organization_id": actor.organization_id} + if model: + filter_kwargs["model"] = model + if agent_id: + filter_kwargs["agent_id"] = agent_id steps = StepModel.list( db_session=session, @@ -54,6 +59,7 @@ class StepManager: def log_step( self, actor: PydanticUser, + agent_id: str, provider_name: str, model: str, model_endpoint: Optional[str], @@ -65,6 +71,7 @@ class StepManager: step_data = { "origin": None, "organization_id": actor.organization_id, + "agent_id": agent_id, "provider_id": provider_id, "provider_name": provider_name, "model": model, diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 2d7ed16e6..3a744fc29 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -26,7 +26,7 @@ from letta.schemas.letta_message import ( ToolReturnMessage, UserMessage, ) -from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse +from letta.schemas.letta_response import LettaStreamingResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate from letta.schemas.usage import LettaUsageStatistics diff --git a/tests/test_managers.py b/tests/test_managers.py index b86cf7182..61bfad925 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -24,7 +24,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig -from letta.schemas.letta_message import LettaMessage, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage +from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -3391,13 +3391,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ # ====================================================================================================================== -def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user): +def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user): """Test adding and retrieving job usage statistics.""" job_manager = server.job_manager step_manager = server.step_manager # Add usage statistics step_manager.log_step( + agent_id=sarah_agent.id, provider_name="openai", model="gpt-4", model_endpoint="https://api.openai.com/v1", @@ -3441,13 +3442,14 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u assert len(steps) == 0 -def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user): +def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user): """Test adding multiple usage statistics entries for a job.""" job_manager = server.job_manager step_manager = server.step_manager # Add first usage statistics entry step_manager.log_step( + agent_id=sarah_agent.id, provider_name="openai", model="gpt-4", model_endpoint="https://api.openai.com/v1", @@ -3463,6 +3465,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u # Add second usage statistics entry step_manager.log_step( + agent_id=sarah_agent.id, provider_name="openai", model="gpt-4", model_endpoint="https://api.openai.com/v1", @@ -3489,6 +3492,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) assert len(steps) == 2 + # get agent steps + steps = step_manager.list_steps(agent_id=sarah_agent.id, actor=default_user) + assert len(steps) == 2 + def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): """Test getting usage statistics for a nonexistent job.""" @@ -3498,12 +3505,13 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user) -def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user): +def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user): """Test adding usage statistics for a nonexistent job.""" step_manager = server.step_manager with pytest.raises(NoResultFound): step_manager.log_step( + agent_id=sarah_agent.id, provider_name="openai", model="gpt-4", model_endpoint="https://api.openai.com/v1",