mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Implement forward deletes during undo + checkpoint (#1493)
This commit is contained in:
parent
9d803f1fd0
commit
ed77771929
@ -1,8 +1,6 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.enums import ActorType
|
||||
@ -152,28 +150,42 @@ class BlockManager:
|
||||
block_id: str,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
use_preloaded_block: Optional[BlockModel] = None, # TODO: Useful for testing concurrency
|
||||
use_preloaded_block: Optional[BlockModel] = None, # For concurrency tests
|
||||
) -> PydanticBlock:
|
||||
"""
|
||||
Create a new checkpoint for the given Block by copying its
|
||||
current state into BlockHistory, using SQLAlchemy's built-in
|
||||
version_id_col for concurrency checks.
|
||||
|
||||
Note: We only have a single commit at the end, to avoid weird intermediate states.
|
||||
e.g. created a BlockHistory, but the block update failed
|
||||
- If the block was undone to an earlier checkpoint, we remove
|
||||
any "future" checkpoints beyond the current state to keep a
|
||||
strictly linear history.
|
||||
- A single commit at the end ensures atomicity.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# 1) Load the block via the ORM
|
||||
# 1) Load the Block
|
||||
if use_preloaded_block is not None:
|
||||
block = session.merge(use_preloaded_block)
|
||||
else:
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
|
||||
# 2) Create a new sequence number for BlockHistory
|
||||
current_max_seq = session.query(func.max(BlockHistory.sequence_number)).filter(BlockHistory.block_id == block_id).scalar()
|
||||
next_seq = (current_max_seq or 0) + 1
|
||||
# 2) Identify the block's current checkpoint (if any)
|
||||
current_entry = None
|
||||
if block.current_history_entry_id:
|
||||
current_entry = session.get(BlockHistory, block.current_history_entry_id)
|
||||
|
||||
# 3) Create a snapshot in BlockHistory
|
||||
# The current sequence, or 0 if no checkpoints exist
|
||||
current_seq = current_entry.sequence_number if current_entry else 0
|
||||
|
||||
# 3) Truncate any future checkpoints
|
||||
# If we are at seq=2, but there's a seq=3 or higher from a prior "redo chain",
|
||||
# remove those, so we maintain a strictly linear undo/redo stack.
|
||||
session.query(BlockHistory).filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq).delete()
|
||||
|
||||
# 4) Determine the next sequence number
|
||||
next_seq = current_seq + 1
|
||||
|
||||
# 5) Create a new BlockHistory row reflecting the block's current state
|
||||
history_entry = BlockHistory(
|
||||
organization_id=actor.organization_id,
|
||||
block_id=block.id,
|
||||
@ -188,14 +200,13 @@ class BlockManager:
|
||||
)
|
||||
history_entry.create(session, actor=actor, no_commit=True)
|
||||
|
||||
# 4) Update the block’s pointer
|
||||
# 6) Update the block’s pointer to the new checkpoint
|
||||
block.current_history_entry_id = history_entry.id
|
||||
|
||||
# 5) Now just flush; SQLAlchemy will:
|
||||
# 7) Flush changes, then commit once
|
||||
block = block.update(db_session=session, actor=actor, no_commit=True)
|
||||
session.commit()
|
||||
|
||||
# Return the block’s new state
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@ -259,5 +270,4 @@ class BlockManager:
|
||||
block = block.update(db_session=session, actor=actor, no_commit=True)
|
||||
session.commit()
|
||||
|
||||
# 6) Return the block’s new state in Pydantic form
|
||||
return block.to_pydantic()
|
||||
|
@ -2727,6 +2727,42 @@ def test_checkpoint_concurrency_stale(server: SyncServer, default_user):
|
||||
)
|
||||
|
||||
|
||||
def test_checkpoint_no_future_states(server: SyncServer, default_user):
|
||||
"""
|
||||
Ensures that if the block is already at the highest sequence,
|
||||
creating a new checkpoint does NOT delete anything.
|
||||
"""
|
||||
|
||||
block_manager = BlockManager()
|
||||
|
||||
# 1) Create block with "v1" and checkpoint => seq=1
|
||||
block_v1 = block_manager.create_or_update_block(PydanticBlock(label="no_future_test", value="v1"), actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# 2) Create "v2" and checkpoint => seq=2
|
||||
updated_data = PydanticBlock(**block_v1.dict())
|
||||
updated_data.value = "v2"
|
||||
block_manager.create_or_update_block(updated_data, actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# So we have seq=1: v1, seq=2: v2. No "future" states.
|
||||
# 3) Another checkpoint (no changes made) => should become seq=3, not delete anything
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
# We expect 3 rows in block_history, none removed
|
||||
history_rows = (
|
||||
session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all()
|
||||
)
|
||||
# Should be seq=1, seq=2, seq=3
|
||||
assert len(history_rows) == 3
|
||||
assert history_rows[0].value == "v1"
|
||||
assert history_rows[1].value == "v2"
|
||||
# The last is also "v2" if we didn't change it, or the same current fields
|
||||
assert history_rows[2].sequence_number == 3
|
||||
# There's no leftover row that was deleted
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests - Undo
|
||||
# ======================================================================================================================
|
||||
@ -2764,6 +2800,71 @@ def test_undo_checkpoint_block(server: SyncServer, default_user):
|
||||
assert undone_block.label == "undo_test", "Label should also revert if changed (or remain the same if unchanged)"
|
||||
|
||||
|
||||
def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default_user):
|
||||
"""
|
||||
Verifies that once we've undone to an earlier checkpoint, creating a new
|
||||
checkpoint removes any leftover 'future' states that existed beyond that sequence.
|
||||
"""
|
||||
block_manager = BlockManager()
|
||||
|
||||
# 1) Create block
|
||||
block_init = PydanticBlock(label="test_truncation", value="v1")
|
||||
block_v1 = block_manager.create_or_update_block(block_init, actor=default_user)
|
||||
# Checkpoint => seq=1
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# 2) Update to "v2", checkpoint => seq=2
|
||||
block_v2 = PydanticBlock(**block_v1.dict())
|
||||
block_v2.value = "v2"
|
||||
block_manager.create_or_update_block(block_v2, actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# 3) Update to "v3", checkpoint => seq=3
|
||||
block_v3 = PydanticBlock(**block_v1.dict())
|
||||
block_v3.value = "v3"
|
||||
block_manager.create_or_update_block(block_v3, actor=default_user)
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
# We now have three states in history: seq=1 (v1), seq=2 (v2), seq=3 (v3).
|
||||
|
||||
# Undo from seq=3 -> seq=2
|
||||
block_undo_1 = block_manager.undo_checkpoint_block(block_v1.id, actor=default_user)
|
||||
assert block_undo_1.value == "v2"
|
||||
|
||||
# Undo from seq=2 -> seq=1
|
||||
block_undo_2 = block_manager.undo_checkpoint_block(block_v1.id, actor=default_user)
|
||||
assert block_undo_2.value == "v1"
|
||||
|
||||
# 4) Now we are at seq=1. If we checkpoint again, we should remove the old seq=2,3
|
||||
# because the new code truncates future states beyond seq=1.
|
||||
|
||||
# Let's do a new edit: "v1.5"
|
||||
block_v1_5 = PydanticBlock(**block_undo_2.dict())
|
||||
block_v1_5.value = "v1.5"
|
||||
block_manager.create_or_update_block(block_v1_5, actor=default_user)
|
||||
|
||||
# 5) Checkpoint => new seq=2, removing the old seq=2 and seq=3
|
||||
block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user)
|
||||
|
||||
with db_context() as session:
|
||||
# Let's see which BlockHistory rows remain
|
||||
history_entries = (
|
||||
session.query(BlockHistory).filter(BlockHistory.block_id == block_v1.id).order_by(BlockHistory.sequence_number.asc()).all()
|
||||
)
|
||||
|
||||
# We expect two rows: seq=1 => "v1", seq=2 => "v1.5"
|
||||
assert len(history_entries) == 2, f"Expected 2 entries, got {len(history_entries)}"
|
||||
assert history_entries[0].sequence_number == 1
|
||||
assert history_entries[0].value == "v1"
|
||||
assert history_entries[1].sequence_number == 2
|
||||
assert history_entries[1].value == "v1.5"
|
||||
|
||||
# No row should contain "v2" or "v3"
|
||||
existing_values = {h.value for h in history_entries}
|
||||
assert "v2" not in existing_values, "Old seq=2 should have been removed."
|
||||
assert "v3" not in existing_values, "Old seq=3 should have been removed."
|
||||
|
||||
|
||||
def test_undo_no_history(server: SyncServer, default_user):
|
||||
"""
|
||||
If a block has never been checkpointed (no current_history_entry_id),
|
||||
|
Loading…
Reference in New Issue
Block a user