Feature: Implement aprune() / prune() with keep_last=N strategy for interrupt-safe checkpoint pruning #159

@Mukhsin0508

Description

Summary

AsyncRedisSaver.aprune() and RedisSaver.prune() currently raise NotImplementedError. We'd like to implement them — and propose a keep_last: int parameter alongside the existing strategy to make pruning safe for multi-tool interrupt chains.

We've been running a production LangGraph deployment on Redis and built our own pruning layer to work around the missing implementation. This issue documents what we faced, what we built, and the PR we intend to open.


The Problem

LangGraph's AsyncRedisSaver creates a new checkpoint after every agent step (model call, tool execution, state update). For a thread with many turns and multi-tool interrupts this produces unbounded Redis growth. The fix seems obvious — call aprune() — but it raises NotImplementedError.

The non-obvious danger: multi-tool interrupt chains

Simply keeping only the latest checkpoint (what strategy="keep_latest" implies) breaks human-in-the-loop flows.

When an AIMessage contains N tool calls, each raises its own sequential interrupt. LangGraph tracks which tools in the batch have already completed via checkpoint_write entries stored under intermediate checkpoints. If you delete those intermediate checkpoints before all tool calls in the batch have resolved, LangGraph loses that tracking state and re-executes already-completed tools on the next resume — producing duplicate messages, duplicate task creation, duplicate sends, etc.

This is easy to reproduce: trigger an agent turn with 3–5 simultaneous tool calls, interrupt mid-batch, prune to keep_last=1, then resume. Every completed tool fires again.


What We Built (Workaround)

We implemented a custom cleanup_old_checkpoints utility that:

  1. Scans all checkpoint:* keys for a thread via raw Redis SCAN + JSON.GET pipeline (your implementation could use checkpoints_index search instead for O(log n))
  2. Sorts checkpoints by checkpoint_ts descending — handling both new-format (ISO 8601 / ms integer) and old-format (UUIDv6 Gregorian hex timestamp fallback)
  3. Retains the latest keep_last checkpoints (default: 20), which comfortably covers any in-flight multi-tool interrupt batch
  4. Deletes evicted checkpoints along with their associated checkpoint_write:* and write_keys_zset:* keys
  5. Handles edge cases: keys that expire between SCAN and pipeline (silently skipped), malformed/missing checkpoint_ts (UUIDv6 lexicographic sort fallback), all-orphan threads
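Steps 2, 3, and 5 above boil down to a pure ordering-and-selection problem that can be sketched independently of Redis. This is an illustrative sketch, not our production code; `checkpoint_sort_key` and `select_evictable` are hypothetical names:

```python
from datetime import datetime

def checkpoint_sort_key(cp: dict):
    """Sortable key for one checkpoint record: the parsed checkpoint_ts
    when usable, else the checkpoint_id itself (UUIDv6 hex places its
    timestamp bits first, so lexicographic order is chronological)."""
    ts = cp.get("checkpoint_ts")
    if isinstance(ts, str) and ts:
        try:
            # New format: ISO 8601 string
            return (1, datetime.fromisoformat(ts).timestamp())
        except ValueError:
            pass
        if ts.isdigit():
            # New format: millisecond integer serialized as a string
            return (1, int(ts) / 1000.0)
    # Old format / malformed or missing checkpoint_ts: fall back to
    # lexicographic checkpoint_id order.
    return (0, cp["checkpoint_id"])

def select_evictable(checkpoints: list[dict], keep_last: int) -> list[dict]:
    """Sort newest-first and return everything past the keep_last window."""
    ordered = sorted(checkpoints, key=checkpoint_sort_key, reverse=True)
    return ordered[keep_last:]
```

The tuple key keeps timestamped and fallback records comparable in one `sorted()` call without ever comparing a float to a string.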

The dynamic window we use: keep_last = max(10, n_tool_calls * 3 + 5) — 3 checkpoints per interrupted tool + 5 overhead buffer. For a 5-recipient send: keep_last=20. For a single-tool turn: keep_last=10.
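For reference, the window formula is a one-liner (`dynamic_keep_last` is a hypothetical name for illustration):

```python
def dynamic_keep_last(n_tool_calls: int) -> int:
    """3 checkpoints per interrupted tool call plus a 5-checkpoint
    overhead buffer, never dropping below a floor of 10."""
    return max(10, n_tool_calls * 3 + 5)

# dynamic_keep_last(5) -> 20  (5-recipient send)
# dynamic_keep_last(1) -> 10  (single-tool turn; floor applies)
```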


Proposed API

We'd like to open a PR implementing aprune() in AsyncRedisSaver (and prune() in RedisSaver) using the existing checkpoints_index search — consistent with how adelete_thread is implemented — and proposing a keep_last parameter on the base interface:

async def aprune(
    self,
    thread_ids: Sequence[str],
    *,
    strategy: str = "keep_latest",
    keep_last: int = 1,
) -> None:
    """
    Strategies:
      "keep_latest"  — retain only the single most recent checkpoint (keep_last=1)
      "keep_last"    — retain the N most recent checkpoints (keep_last=N),
                       safe for multi-tool interrupt chains
      "delete"       — remove all checkpoints
    """

The keep_last parameter on the base class would require a companion PR to langchain-ai/langgraph (libs/checkpoint/). We're happy to open both, or to discuss an alternative API shape (e.g. strategy="keep_last:20") if you prefer keeping the base signature stable.
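If the combined strategy="keep_last:20" shape is preferred, the token parses cleanly and keeps the base signature unchanged. A minimal sketch (`parse_strategy` is hypothetical, not an existing langgraph helper):

```python
def parse_strategy(strategy: str) -> tuple[str, int]:
    """Split a strategy token into (name, keep_n).

    "keep_latest"  -> ("keep_latest", 1)
    "delete"       -> ("delete", 0)
    "keep_last:N"  -> ("keep_last", N)
    """
    name, _, arg = strategy.partition(":")
    if name == "keep_latest":
        return name, 1
    if name == "delete":
        return name, 0
    if name == "keep_last":
        keep_n = int(arg) if arg else 1
        if keep_n < 1:
            raise ValueError(f"keep_last requires N >= 1, got {keep_n}")
        return name, keep_n
    raise ValueError(f"Unknown strategy: {strategy!r}")
```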


Implementation Sketch

Using the existing index infrastructure (consistent with adelete_thread):

async def aprune(
    self,
    thread_ids: Sequence[str],
    *,
    strategy: str = "keep_latest",
    keep_last: int = 1,
) -> None:
    for thread_id in thread_ids:
        storage_safe_thread_id = to_storage_safe_id(thread_id)

        # 1. Fetch all checkpoints for thread via index
        query = FilterQuery(
            filter_expression=Tag("thread_id") == storage_safe_thread_id,
            return_fields=["checkpoint_ns", "checkpoint_id"],
            num_results=10000,
        )
        results = await self.checkpoints_index.search(query)
        docs = results.docs

        if strategy == "delete":
            keep_n = 0
        elif strategy in ("keep_latest", "keep_last"):
            keep_n = 1 if strategy == "keep_latest" else keep_last
        else:
            raise ValueError(f"Unknown strategy: {strategy!r}")

        # 2. Sort by checkpoint_id descending (UUIDv7/ULID = time-ordered)
        docs_sorted = sorted(docs, key=lambda d: d.checkpoint_id, reverse=True)
        to_evict = docs_sorted[keep_n:]

        if not to_evict:
            continue  # nothing to prune for this thread; move on to the next

        keys_to_delete = []
        for doc in to_evict:
            ns, cp_id = doc.checkpoint_ns, doc.checkpoint_id

            # checkpoint key
            keys_to_delete.append(
                self._make_redis_checkpoint_key(storage_safe_thread_id, ns, cp_id)
            )
            # associated write keys
            writes_query = FilterQuery(
                filter_expression=(
                    (Tag("thread_id") == storage_safe_thread_id)
                    & (Tag("checkpoint_id") == cp_id)
                ),
                return_fields=["checkpoint_ns", "checkpoint_id", "task_id", "idx"],
                num_results=10000,
            )
            for wdoc in (await self.checkpoint_writes_index.search(writes_query)).docs:
                keys_to_delete.append(
                    self._make_redis_checkpoint_writes_key(
                        storage_safe_thread_id,
                        wdoc.checkpoint_ns,
                        wdoc.checkpoint_id,
                        wdoc.task_id,
                        int(wdoc.idx),
                    )
                )
            # write_keys_zset (key registry)
            if self._key_registry:
                keys_to_delete.append(
                    self._key_registry.make_write_keys_zset_key(thread_id, ns, cp_id)
                )

        # 3. Bulk delete
        if self.cluster_mode:
            for key in keys_to_delete:
                await self._redis.delete(key)
        else:
            pipeline = self._redis.pipeline()
            for key in keys_to_delete:
                pipeline.delete(key)
            await pipeline.execute()

Open Questions for Maintainers

  1. Is keep_last: int acceptable as a new param on BaseCheckpointSaver.aprune(), or do you prefer a different API shape?
  2. Should checkpoint_blob:* keys also be pruned per evicted checkpoint, or are blobs intentionally shared across checkpoints via channel versioning?
  3. Any preference on where conformance tests for this should live (libs/checkpoint-conformance)?

We're ready to open the PR once we align on the API. Happy to iterate.
