mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add agent to steps table and support filtering (#1212)
This commit is contained in:
parent
6fdcb49f17
commit
52dba65bde
31
alembic/versions/d211df879a5f_add_agent_id_to_steps.py
Normal file
31
alembic/versions/d211df879a5f_add_agent_id_to_steps.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""add agent id to steps
|
||||
|
||||
Revision ID: d211df879a5f
|
||||
Revises: 2f4ede6ae33b
|
||||
Create Date: 2025-03-06 21:42:22.289345
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d211df879a5f"
|
||||
down_revision: Union[str, None] = "2f4ede6ae33b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("steps", sa.Column("agent_id", sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("steps", "agent_id")
|
||||
# ### end Alembic commands ###
|
@ -909,6 +909,7 @@ class Agent(BaseAgent):
|
||||
# Log step - this must happen before messages are persisted
|
||||
step = self.step_manager.log_step(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
provider_name=self.agent_state.llm_config.model_endpoint_type,
|
||||
model=self.agent_state.llm_config.model,
|
||||
model_endpoint=self.agent_state.llm_config.model_endpoint,
|
||||
|
@ -4,7 +4,6 @@ import time
|
||||
from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
|
||||
import letta.utils
|
||||
from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT
|
||||
@ -29,7 +28,7 @@ from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
|
@ -33,6 +33,7 @@ class Step(SqlalchemyBase):
|
||||
job_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified of the job run that triggered this step"
|
||||
)
|
||||
agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
|
||||
provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
|
||||
model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
|
||||
model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.")
|
||||
|
@ -18,6 +18,7 @@ class Step(StepBase):
|
||||
job_id: Optional[str] = Field(
|
||||
None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
|
||||
)
|
||||
agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.")
|
||||
provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
|
||||
model: Optional[str] = Field(None, description="The name of the model used for this step.")
|
||||
model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.")
|
||||
|
@ -19,7 +19,6 @@ from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUni
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.source import Source
|
||||
|
@ -20,6 +20,7 @@ def list_steps(
|
||||
start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
|
||||
agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@ -42,6 +43,7 @@ def list_steps(
|
||||
limit=limit,
|
||||
order=order,
|
||||
model=model,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from letta.log import get_logger
|
||||
|
@ -33,10 +33,15 @@ class StepManager:
|
||||
limit: Optional[int] = 50,
|
||||
order: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> List[PydanticStep]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
with self.session_maker() as session:
|
||||
filter_kwargs = {"organization_id": actor.organization_id, "model": model}
|
||||
filter_kwargs = {"organization_id": actor.organization_id}
|
||||
if model:
|
||||
filter_kwargs["model"] = model
|
||||
if agent_id:
|
||||
filter_kwargs["agent_id"] = agent_id
|
||||
|
||||
steps = StepModel.list(
|
||||
db_session=session,
|
||||
@ -54,6 +59,7 @@ class StepManager:
|
||||
def log_step(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
provider_name: str,
|
||||
model: str,
|
||||
model_endpoint: Optional[str],
|
||||
@ -65,6 +71,7 @@ class StepManager:
|
||||
step_data = {
|
||||
"origin": None,
|
||||
"organization_id": actor.organization_id,
|
||||
"agent_id": agent_id,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_name,
|
||||
"model": model,
|
||||
|
@ -26,7 +26,7 @@ from letta.schemas.letta_message import (
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.letta_response import LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
@ -24,7 +24,7 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage, UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
@ -3391,13 +3391,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
|
||||
def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user):
|
||||
"""Test adding and retrieving job usage statistics."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
# Add usage statistics
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
@ -3441,13 +3442,14 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
|
||||
def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user):
|
||||
"""Test adding multiple usage statistics entries for a job."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
# Add first usage statistics entry
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
@ -3463,6 +3465,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
||||
|
||||
# Add second usage statistics entry
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
@ -3489,6 +3492,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
||||
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
|
||||
assert len(steps) == 2
|
||||
|
||||
# get agent steps
|
||||
steps = step_manager.list_steps(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(steps) == 2
|
||||
|
||||
|
||||
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
"""Test getting usage statistics for a nonexistent job."""
|
||||
@ -3498,12 +3505,13 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user)
|
||||
|
||||
|
||||
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
|
||||
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test adding usage statistics for a nonexistent job."""
|
||||
step_manager = server.step_manager
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
|
Loading…
Reference in New Issue
Block a user