feat: Implement forward deletes during undo + checkpoint (#1493)

Matthew Zhou 2025-03-31 17:04:48 -07:00 committed by GitHub
parent 9d803f1fd0
commit ed77771929
2 changed files with 125 additions and 14 deletions
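At a glance: checkpointing now truncates any "future" history rows left behind by an undo, so the block's undo/redo history stays strictly linear. A rough sketch of the resulting behavior, using the BlockManager calls exercised in the tests below (setup such as default_user is illustrative, not part of this diff):

block_manager = BlockManager()

# v1 -> checkpoint (seq=1), then v2 -> checkpoint (seq=2)
block = block_manager.create_or_update_block(PydanticBlock(label="demo", value="v1"), actor=default_user)
block_manager.checkpoint_block(block_id=block.id, actor=default_user)

updated = PydanticBlock(**block.dict())
updated.value = "v2"
block_manager.create_or_update_block(updated, actor=default_user)
block_manager.checkpoint_block(block_id=block.id, actor=default_user)

# Undo back to seq=1; the seq=2 row is left in place for now
block_manager.undo_checkpoint_block(block.id, actor=default_user)

# Edit again and checkpoint: the stale seq=2 row ("v2") is deleted and
# replaced by a new seq=2 row holding the current state ("v1.5")
edited = PydanticBlock(**block.dict())
edited.value = "v1.5"
block_manager.create_or_update_block(edited, actor=default_user)
block_manager.checkpoint_block(block_id=block.id, actor=default_user)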

View File

@@ -1,8 +1,6 @@
import os
from typing import List, Optional
from sqlalchemy import func
from letta.orm.block import Block as BlockModel
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType
@@ -152,28 +150,42 @@ class BlockManager:
        block_id: str,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        use_preloaded_block: Optional[BlockModel] = None,  # TODO: Useful for testing concurrency
        use_preloaded_block: Optional[BlockModel] = None,  # For concurrency tests
    ) -> PydanticBlock:
        """
        Create a new checkpoint for the given Block by copying its
        current state into BlockHistory, using SQLAlchemy's built-in
        version_id_col for concurrency checks.
        Note: We only have a single commit at the end, to avoid weird intermediate states.
        e.g. created a BlockHistory, but the block update failed
        - If the block was undone to an earlier checkpoint, we remove
          any "future" checkpoints beyond the current state to keep a
          strictly linear history.
        - A single commit at the end ensures atomicity.
        """
        with self.session_maker() as session:
            # 1) Load the block via the ORM
            # 1) Load the Block
            if use_preloaded_block is not None:
                block = session.merge(use_preloaded_block)
            else:
                block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
            # 2) Create a new sequence number for BlockHistory
            current_max_seq = session.query(func.max(BlockHistory.sequence_number)).filter(BlockHistory.block_id == block_id).scalar()
            next_seq = (current_max_seq or 0) + 1
            # 2) Identify the block's current checkpoint (if any)
            current_entry = None
            if block.current_history_entry_id:
                current_entry = session.get(BlockHistory, block.current_history_entry_id)
            # 3) Create a snapshot in BlockHistory
            # The current sequence, or 0 if no checkpoints exist
            current_seq = current_entry.sequence_number if current_entry else 0
            # 3) Truncate any future checkpoints
            # If we are at seq=2, but there's a seq=3 or higher from a prior "redo chain",
            # remove those, so we maintain a strictly linear undo/redo stack.
            session.query(BlockHistory).filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq).delete()
            # 4) Determine the next sequence number
            next_seq = current_seq + 1
            # 5) Create a new BlockHistory row reflecting the block's current state
            history_entry = BlockHistory(
                organization_id=actor.organization_id,
                block_id=block.id,
@@ -188,14 +200,13 @@ class BlockManager:
            )
            history_entry.create(session, actor=actor, no_commit=True)
            # 4) Update the block's pointer
            # 6) Update the block's pointer to the new checkpoint
            block.current_history_entry_id = history_entry.id
            # 5) Now just flush; SQLAlchemy will:
            # 7) Flush changes, then commit once
            block = block.update(db_session=session, actor=actor, no_commit=True)
            session.commit()
            # Return the block's new state
            return block.to_pydantic()

    @enforce_types
@@ -259,5 +270,4 @@ class BlockManager:
            block = block.update(db_session=session, actor=actor, no_commit=True)
            session.commit()
            # 6) Return the block's new state in Pydantic form
            return block.to_pydantic()
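The truncation rule itself is small enough to model without a database. A minimal, self-contained sketch of the same idea in plain Python (illustrative only, not the library's API):

def checkpoint(history, current_seq, snapshot):
    # Drop any "future" entries left over from earlier undos, mirroring the
    # bulk delete on sequence_number > current_seq, then append the new state.
    history[:] = [entry for entry in history if entry["seq"] <= current_seq]
    next_seq = current_seq + 1
    history.append({"seq": next_seq, **snapshot})
    return next_seq

history = []
seq = checkpoint(history, 0, {"value": "v1"})      # history: v1
seq = checkpoint(history, seq, {"value": "v2"})    # history: v1, v2
seq = 1                                            # simulate an undo back to seq=1
seq = checkpoint(history, seq, {"value": "v1.5"})  # old seq=2 dropped, new seq=2 added
assert [entry["value"] for entry in history] == ["v1", "v1.5"]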

View File

@@ -2727,6 +2727,42 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user):
    )

def test_checkpoint_no_future_states(server: SyncServer, default_user):
    """
    Ensures that if the block is already at the highest sequence,
    creating a new checkpoint does NOT delete anything.
    """
    block_manager = BlockManager()
    # 1) Create block with "v1" and checkpoint => seq=1
    block_v1 = block_manager.create_or_update_block(PydanticBlock(label="no_future_test", value="v1"), actor=default_user)
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    # 2) Create "v2" and checkpoint => seq=2
    updated_data = PydanticBlock(**block_v1.dict())
    updated_data.value = "v2"
    block_manager.create_or_update_block(updated_data, actor=default_user)
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    # So we have seq=1: v1, seq=2: v2. No "future" states.
    # 3) Another checkpoint (no changes made) => should become seq=3, not delete anything
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    with db_context() as session:
        # We expect 3 rows in block_history, none removed
        history_rows = (
            session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all()
        )
        # Should be seq=1, seq=2, seq=3
        assert len(history_rows) == 3
        assert history_rows[0].value == "v1"
        assert history_rows[1].value == "v2"
        # The last snapshot is also "v2", since the block was not changed before the third checkpoint
        assert history_rows[2].sequence_number == 3
        # No rows were removed
# ======================================================================================================================
# Block Manager Tests - Undo
# ======================================================================================================================
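The undo implementation itself is not part of this diff; the tests below only rely on it stepping the block's pointer back one sequence number and restoring that snapshot's fields, while leaving later BlockHistory rows in place (those are only removed by the next checkpoint). A hedged sketch of that assumed behavior, not the actual code:

def undo_checkpoint_block_sketch(session, block, actor):
    # Assumed behavior: move the pointer back one sequence number and restore
    # the block's value from that snapshot. Rows with higher sequence numbers
    # are intentionally NOT deleted here; checkpoint_block truncates them later.
    current = session.get(BlockHistory, block.current_history_entry_id)
    previous = (
        session.query(BlockHistory)
        .filter(
            BlockHistory.block_id == block.id,
            BlockHistory.sequence_number == current.sequence_number - 1,
        )
        .one()
    )
    block.current_history_entry_id = previous.id
    block.value = previous.value
    return block.update(db_session=session, actor=actor, no_commit=True)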
@@ -2764,6 +2800,71 @@ def test_undo_checkpoint_block(server: SyncServer, default_user):
    assert undone_block.label == "undo_test", "Label should also revert if changed (or remain the same if unchanged)"

def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default_user):
    """
    Verifies that once we've undone to an earlier checkpoint, creating a new
    checkpoint removes any leftover 'future' states that existed beyond that sequence.
    """
    block_manager = BlockManager()
    # 1) Create block
    block_init = PydanticBlock(label="test_truncation", value="v1")
    block_v1 = block_manager.create_or_update_block(block_init, actor=default_user)
    # Checkpoint => seq=1
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    # 2) Update to "v2", checkpoint => seq=2
    block_v2 = PydanticBlock(**block_v1.dict())
    block_v2.value = "v2"
    block_manager.create_or_update_block(block_v2, actor=default_user)
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    # 3) Update to "v3", checkpoint => seq=3
    block_v3 = PydanticBlock(**block_v1.dict())
    block_v3.value = "v3"
    block_manager.create_or_update_block(block_v3, actor=default_user)
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    # We now have three states in history: seq=1 (v1), seq=2 (v2), seq=3 (v3).
    # Undo from seq=3 -> seq=2
    block_undo_1 = block_manager.undo_checkpoint_block(block_v1.id, actor=default_user)
    assert block_undo_1.value == "v2"
    # Undo from seq=2 -> seq=1
    block_undo_2 = block_manager.undo_checkpoint_block(block_v1.id, actor=default_user)
    assert block_undo_2.value == "v1"
    # 4) Now we are at seq=1. If we checkpoint again, we should remove the old seq=2,3
    # because the new code truncates future states beyond seq=1.
    # Let's do a new edit: "v1.5"
    block_v1_5 = PydanticBlock(**block_undo_2.dict())
    block_v1_5.value = "v1.5"
    block_manager.create_or_update_block(block_v1_5, actor=default_user)
    # 5) Checkpoint => new seq=2, removing the old seq=2 and seq=3
    block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
    with db_context() as session:
        # Let's see which BlockHistory rows remain
        history_entries = (
            session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all()
        )
        # We expect two rows: seq=1 => "v1", seq=2 => "v1.5"
        assert len(history_entries) == 2, f"Expected 2 entries, got {len(history_entries)}"
        assert history_entries[0].sequence_number == 1
        assert history_entries[0].value == "v1"
        assert history_entries[1].sequence_number == 2
        assert history_entries[1].value == "v1.5"
        # No row should contain "v2" or "v3"
        existing_values = {h.value for h in history_entries}
        assert "v2" not in existing_values, "Old seq=2 should have been removed."
        assert "v3" not in existing_values, "Old seq=3 should have been removed."

def test_undo_no_history(server: SyncServer, default_user):
    """
    If a block has never been checkpointed (no current_history_entry_id),