[fsdp] feat: checkpoint input CPU offload for gradient checkpointing #5363

Open

aoshen524 wants to merge 1 commit into verl-project:main from aoshen524:feat/checkpoint-input-offload

Conversation

aoshen524 (Contributor) commented Feb 21, 2026

Summary

Add checkpoint input CPU offload for FSDP2 training with parameter/optimizer CPU offload, reducing GPU memory usage during the forward pass.

  • Checkpoint input CPU offload (CheckpointInputOffload): offloads gradient-checkpoint saved tensors to CPU pinned memory during the forward pass, exploiting PyTorch's saved_tensors_hooks "innermost wins" nesting, and transfers them back to GPU during the backward pass. Reduces the forward memory delta from ~9 GB to ~0.5 GB for 32B models.

Key design (a minimal code sketch follows this list):

  • Async D2H via dedicated CUDA stream with wait_stream synchronization
  • Skip parameters, small tensors, and non-CUDA tensors automatically
  • Build-time ValueError when combined with PrefixGrouper (incompatible)
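
To make the "innermost wins" offload mechanism concrete, here is a minimal sketch of the pack/unpack hook pattern using the public torch.autograd.graph.saved_tensors_hooks API. The class name, the min_bytes threshold, and the hooks() helper are illustrative assumptions, not the PR's actual CheckpointInputOffload implementation.

import torch
from torch.autograd.graph import saved_tensors_hooks

class CheckpointInputOffloadSketch:
    """Illustrative pack/unpack hooks that offload large saved tensors to pinned CPU memory.

    Names and thresholds are assumptions; the real CheckpointInputOffload in
    verl/utils/checkpoint_offload.py adds diagnostics and synchronizes the D2H
    stream at context exit.
    """

    def __init__(self, min_bytes=1 << 20):
        self.min_bytes = min_bytes              # skip tensors smaller than this (assumed cutoff)
        self.d2h_stream = torch.cuda.Stream()   # dedicated copy stream (assumes a CUDA device)

    def _pack(self, t):
        # Skip parameters, small tensors, and non-CUDA tensors: keep them saved as-is.
        if isinstance(t, torch.nn.Parameter) or not t.is_cuda:
            return t
        if t.numel() * t.element_size() < self.min_bytes:
            return t
        # Async D2H copy into pinned host memory on the dedicated stream.
        self.d2h_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.d2h_stream):
            cpu = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
            cpu.copy_(t, non_blocking=True)
        t.record_stream(self.d2h_stream)        # keep t's memory alive until the copy finishes
        return (t.device, cpu)

    def _unpack(self, packed):
        if isinstance(packed, torch.Tensor):
            return packed                       # tensor was never offloaded
        device, cpu = packed
        # Ensure the D2H copy has completed before reading the host buffer back.
        torch.cuda.current_stream().wait_stream(self.d2h_stream)
        return cpu.to(device, non_blocking=True)

    def hooks(self):
        # Nest this around the model forward. Tensors saved inside torch.utils.checkpoint
        # regions are handled by the checkpoint's own (innermost) hooks, so only the
        # checkpoint inputs reach _pack / _unpack here.
        return saved_tensors_hooks(self._pack, self._unpack)

Usage would wrap only the forward pass, e.g. "with offload.hooks(): loss = model(**batch)"; the unpack hook then fires during backward, bringing each offloaded input back to GPU just before its checkpoint region is recomputed.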

Benchmark (Qwen2.5-VL-32B, 8×H100 80GB, FSDP2, SP_SIZE=2):

  • Forward memory delta (with offload): +0.31 ~ +0.64 GB
  • Tensors offloaded per micro-batch: 115
  • Data offloaded per step: 24 ~ 44 GB

Configuration

# Checkpoint input offload
actor_rollout_ref.actor.fsdp_config.checkpoint_input_offload=true

# Typically used with param/optimizer offload
actor_rollout_ref.actor.fsdp_config.param_offload=true
actor_rollout_ref.actor.fsdp_config.optimizer_offload=true

Files Changed

  • verl/utils/checkpoint_offload.py: new CheckpointInputOffload class (pack/unpack hooks, async D2H)
  • verl/workers/fsdp_workers.py: integration (build-time init, PrefixGrouper conflict check)
  • verl/workers/actor/dp_actor.py: integration (offload context manager wrapping the model forward)
  • verl/workers/config/engine.py: config field checkpoint_input_offload
  • verl/trainer/config/engine/fsdp.yaml: default config entry
  • tests/test_checkpoint_offload.py: 13 unit tests
  • docs/perf/checkpoint_input_offload.md: documentation

Test plan

  • pytest tests/test_checkpoint_offload.py -v — numerical correctness, memory reduction, edge cases
  • End-to-end training with checkpoint_input_offload=true on a multi-GPU setup
  • Verify backward compatibility: default false, no behavior change when disabled

🤖 Generated with Claude Code

gemini-code-assist bot left a comment

Code Review

This pull request introduces two significant memory optimization features for FSDP2 training: per-layer GPU optimizer step and checkpoint input offload. The changes are well-structured, with new functionalities encapsulated in PerLayerGPUOptimizerStep and CheckpointInputOffload classes, along with comprehensive unit tests and documentation. The integration into the existing FSDP worker logic is clean.

My review focuses on improving the long-term maintainability and robustness of the implementation. I've identified a critical issue regarding the use of a private PyTorch API in checkpoint_offload.py, which could break with future PyTorch updates. I've also suggested an improvement in fsdp_utils.py to use a more standard import path for the functional Adam optimizer.

Overall, this is a great contribution that should significantly improve training performance and memory efficiency.

Comment on lines +146 to +156 of verl/utils/checkpoint_offload.py
def __enter__(self):
    self._reset_diag()
    self._enter_time = time.monotonic()
    torch._C._autograd._push_saved_tensors_default_hooks(self._pack, self._unpack)
    return self

def __exit__(self, *args):
    torch._C._autograd._pop_saved_tensors_default_hooks()
    if self.d2h_stream is not None:
        self.d2h_stream.synchronize()
    self._exit_time = time.monotonic()

Severity: critical

The implementation uses private PyTorch C-APIs (_push_saved_tensors_default_hooks and _pop_saved_tensors_default_hooks). This is risky as private APIs can change without notice in future PyTorch versions, breaking this functionality.

PyTorch has provided a public context manager, torch.autograd.graph.saved_tensors_hooks, since version 2.0; it should be used instead for better forward compatibility and maintainability.

You should initialize the context manager in __init__ and use it in __enter__ and __exit__.

In __init__:

import torch.autograd.graph
# ...
self._hook_context = torch.autograd.graph.saved_tensors_hooks(self._pack, self._unpack)

Then, __enter__ and __exit__ can be updated as suggested.

Suggested change
def __enter__(self):
    self._reset_diag()
    self._enter_time = time.monotonic()
    self._hook_context.__enter__()
    return self

def __exit__(self, *args):
    self._hook_context.__exit__(*args)
    if self.d2h_stream is not None:
        self.d2h_stream.synchronize()
    self._exit_time = time.monotonic()


def _run_adam_for_layer(self, gpu_states):
    """Call torch.optim.adam.adam() functional API on GPU tensors."""
    from torch.optim.adam import adam

Severity: high

The adam function is being imported from torch.optim.adam, which is not its canonical public location. For better forward-compatibility and adherence to PyTorch's API structure, it's recommended to import the functional optimizer implementations from torch.optim._functional.

Suggested change
    from torch.optim._functional import adam
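
For reference, a sketch of how the functional adam API is typically driven over per-layer state lists (recent PyTorch 2.x signature). The function name, list layout, and hyperparameters below are illustrative assumptions, not taken from this PR's helper.

from torch.optim._functional import adam  # functional entry point suggested above

def run_adam_for_layer(params, grads, exp_avgs, exp_avg_sqs, state_steps, lr=1e-5):
    # Parallel lists of GPU tensors for one layer's parameters (hypothetical layout;
    # the PR's gpu_states may be organized differently).
    adam(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        [],            # max_exp_avg_sqs is unused when amsgrad=False
        state_steps,   # per-parameter step counters (singleton tensors in recent PyTorch)
        amsgrad=False,
        beta1=0.9,
        beta2=0.999,
        lr=lr,
        weight_decay=0.0,
        eps=1e-8,
        maximize=False,
    )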

Commit message

Offload gradient checkpoint saved tensors to CPU pinned memory during
forward, transfer back to GPU during backward. Reduces GPU memory by
~8.5 GB per forward pass for large models (e.g., Qwen2.5-VL-32B).

Key design:
- Exploit PyTorch saved_tensors_hooks "innermost wins" nesting so only
  checkpoint inputs are offloaded (not recomputation intermediates)
- Async D2H via dedicated CUDA stream with wait_stream synchronization
- Skip parameters, small tensors, and non-CUDA tensors
- Build-time ValueError when combined with PrefixGrouper (incompatible)

New config flag:
  actor_rollout_ref.actor.fsdp_config.checkpoint_input_offload=true

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
aoshen524 force-pushed the feat/checkpoint-input-offload branch from 591a40b to 387b37a on February 21, 2026 05:20
aoshen524 changed the title from "[fsdp] feat: per-layer GPU optimizer step and checkpoint input offload" to "[fsdp] feat: checkpoint input CPU offload for gradient checkpointing" on Feb 21, 2026