[fsdp] feat: checkpoint input CPU offload for gradient checkpointing #5363
aoshen524 wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request introduces two significant memory optimization features for FSDP2 training: a per-layer GPU optimizer step and checkpoint input offload. The changes are well structured, with the new functionality encapsulated in PerLayerGPUOptimizerStep and CheckpointInputOffload classes, along with comprehensive unit tests and documentation. The integration into the existing FSDP worker logic is clean.
My review focuses on improving the long-term maintainability and robustness of the implementation. I've identified a critical issue regarding the use of a private PyTorch API in checkpoint_offload.py, which could break with future PyTorch updates. I've also suggested an improvement in fsdp_utils.py to use a more standard import path for the functional Adam optimizer.
Overall, this is a great contribution that should significantly improve training performance and memory efficiency.
verl/utils/checkpoint_offload.py

```python
def __enter__(self):
    self._reset_diag()
    self._enter_time = time.monotonic()
    torch._C._autograd._push_saved_tensors_default_hooks(self._pack, self._unpack)
    return self

def __exit__(self, *args):
    torch._C._autograd._pop_saved_tensors_default_hooks()
    if self.d2h_stream is not None:
        self.d2h_stream.synchronize()
    self._exit_time = time.monotonic()
```
The implementation uses private PyTorch C-APIs (_push_saved_tensors_default_hooks and _pop_saved_tensors_default_hooks). This is risky as private APIs can change without notice in future PyTorch versions, breaking this functionality.
PyTorch provides a public context manager, torch.autograd.graph.saved_tensors_hooks, which should be used instead for better forward compatibility and maintainability.
You should initialize the context manager in __init__ and use it in __enter__ and __exit__.
In `__init__`:

```python
import torch.autograd.graph
# ...
self._hook_context = torch.autograd.graph.saved_tensors_hooks(self._pack, self._unpack)
```

Then, `__enter__` and `__exit__` can be updated as suggested.
Suggested change:

```diff
 def __enter__(self):
     self._reset_diag()
     self._enter_time = time.monotonic()
-    torch._C._autograd._push_saved_tensors_default_hooks(self._pack, self._unpack)
+    self._hook_context.__enter__()
     return self

 def __exit__(self, *args):
-    torch._C._autograd._pop_saved_tensors_default_hooks()
+    self._hook_context.__exit__(*args)
     if self.d2h_stream is not None:
         self.d2h_stream.synchronize()
     self._exit_time = time.monotonic()
```
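For readers unfamiliar with the public hook API, here is a minimal, self-contained sketch of `torch.autograd.graph.saved_tensors_hooks` doing CPU offload of saved activations. It is illustrative only: the size threshold, buffer allocation, and filtering here are assumptions for the sketch, not the PR's actual `CheckpointInputOffload` logic.

```python
import torch
import torch.autograd.graph

def pack_to_cpu(tensor):
    # Keep non-CUDA and small tensors on-device; the 1 MiB cutoff is an
    # arbitrary illustration, not the threshold used by this PR.
    if not tensor.is_cuda or tensor.numel() * tensor.element_size() < (1 << 20):
        return tensor
    cpu_buf = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
    cpu_buf.copy_(tensor, non_blocking=True)  # async D2H on the current stream
    return (tensor.device, cpu_buf)

def unpack_from_cpu(packed):
    if isinstance(packed, torch.Tensor):  # tensor was never offloaded
        return packed
    device, cpu_buf = packed
    return cpu_buf.to(device, non_blocking=True)  # H2D when backward needs it

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = torch.matmul(x, x).sum()  # tensors saved for backward pass through pack_to_cpu
y.backward()                      # unpack_from_cpu restores them for gradient computation
```

PyTorch also ships a ready-made variant of this pattern, `torch.autograd.graph.save_on_cpu(pin_memory=True)`, built on the same public hook mechanism.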
verl/utils/fsdp_utils.py
Outdated
```python
def _run_adam_for_layer(self, gpu_states):
    """Call torch.optim.adam.adam() functional API on GPU tensors."""
    from torch.optim.adam import adam
```
The adam function is being imported from torch.optim.adam, which is not its canonical public location. For better forward-compatibility and adherence to PyTorch's API structure, it's recommended to import the functional optimizer implementations from torch.optim._functional.
Suggested change:

```diff
-    from torch.optim.adam import adam
+    from torch.optim._functional import adam
```
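For context, a rough sketch of what a call to the functional Adam entry point looks like. The shapes and hyperparameters are placeholders, and this is not the PR's `_run_adam_for_layer`; it only illustrates the argument layout of the functional API, assuming a recent PyTorch 2.x where `state_steps` are singleton tensors.

```python
import torch
from torch.optim._functional import adam  # import path suggested above

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy single-parameter state; a real per-layer step would pass that layer's
# parameters, gradients, and Adam moments gathered on the GPU.
param = torch.randn(4, 4, device=device)
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)
step = torch.zeros(())  # kept on CPU, as torch.optim.Adam does by default

adam(
    [param], [grad], [exp_avg], [exp_avg_sq],
    [],      # max_exp_avg_sqs is only populated when amsgrad=True
    [step],
    amsgrad=False,
    beta1=0.9,
    beta2=0.999,
    lr=1e-4,
    weight_decay=0.0,
    eps=1e-8,
    maximize=False,
)
# param, exp_avg, and exp_avg_sq are updated in place; step is now tensor(1.)
```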
Offload gradient checkpoint saved tensors to CPU pinned memory during forward, transfer back to GPU during backward. Reduces GPU memory by ~8.5 GB per forward pass for large models (e.g., Qwen2.5-VL-32B).

Key design:
- Exploit PyTorch saved_tensors_hooks "innermost wins" nesting so only checkpoint inputs are offloaded (not recomputation intermediates)
- Async D2H via dedicated CUDA stream with wait_stream synchronization
- Skip parameters, small tensors, and non-CUDA tensors
- Build-time ValueError when combined with PrefixGrouper (incompatible)

New config flag: actor_rollout_ref.actor.fsdp_config.checkpoint_input_offload=true

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
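The "async D2H via dedicated CUDA stream with wait_stream synchronization" bullet refers to the standard side-stream copy pattern; a minimal, hedged sketch of that pattern (not the PR's actual code, which wraps it inside `CheckpointInputOffload`) looks like:

```python
import torch

d2h_stream = torch.cuda.Stream()  # dedicated copy stream, separate from compute

def offload_async(gpu_tensor: torch.Tensor) -> torch.Tensor:
    cpu_buf = torch.empty(gpu_tensor.size(), dtype=gpu_tensor.dtype, pin_memory=True)
    # The copy stream must wait for whatever produced gpu_tensor on the compute
    # stream, then the D2H transfer runs off the critical path while compute continues.
    d2h_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(d2h_stream):
        cpu_buf.copy_(gpu_tensor, non_blocking=True)
        # Tell the caching allocator not to recycle this GPU block until
        # d2h_stream has finished reading it.
        gpu_tensor.record_stream(d2h_stream)
    return cpu_buf
```

Before the CPU buffers are trusted, the other side has to wait on the copy stream, either via `wait_stream(d2h_stream)` or a full `d2h_stream.synchronize()` as the `__exit__` shown earlier does.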
591a40b to 387b37a
Summary
Add checkpoint input CPU offload for FSDP2 training, reducing GPU memory usage during the forward pass.
New module (`verl/utils/checkpoint_offload.py`, class `CheckpointInputOffload`): offloads gradient checkpoint saved tensors to CPU pinned memory during forward using PyTorch's `saved_tensors_hooks` "innermost wins" nesting, and transfers them back to GPU during backward. Reduces forward memory delta from ~9 GB to ~0.5 GB for 32B models.

Key design:
- Exploit PyTorch `saved_tensors_hooks` "innermost wins" nesting so only checkpoint inputs are offloaded (not recomputation intermediates)
- Async D2H via a dedicated CUDA stream with `wait_stream` synchronization
- Skip parameters, small tensors, and non-CUDA tensors
- Build-time `ValueError` when combined with `PrefixGrouper` (incompatible)

Benchmark (Qwen2.5-VL-32B, 8×H100 80GB, FSDP2, SP_SIZE=2):
Configuration

New config flag (default `false`): `actor_rollout_ref.actor.fsdp_config.checkpoint_input_offload=true`
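Purely as an illustration of where the flag sits in the config tree (verl's real config plumbing is Hydra-based and more involved than this), the dotted path above can be exercised standalone with OmegaConf:

```python
from omegaconf import OmegaConf

# Hypothetical standalone illustration; only the dotted key path comes from the PR.
cfg = OmegaConf.create(
    {"actor_rollout_ref": {"actor": {"fsdp_config": {"checkpoint_input_offload": False}}}}
)
# Hydra/OmegaConf-style dotlist override, as one would pass on a launch command line.
cfg.merge_with(OmegaConf.from_dotlist(
    ["actor_rollout_ref.actor.fsdp_config.checkpoint_input_offload=true"]
))
assert cfg.actor_rollout_ref.actor.fsdp_config.checkpoint_input_offload
```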
Files Changed
- `verl/utils/checkpoint_offload.py` — new `CheckpointInputOffload` class (pack/unpack hooks, async D2H)
- `verl/workers/fsdp_workers.py`
- `verl/workers/actor/dp_actor.py`
- `verl/workers/config/engine.py` — `checkpoint_input_offload` config flag
- `verl/trainer/config/engine/fsdp.yaml`
- `tests/test_checkpoint_offload.py`
- `docs/perf/checkpoint_input_offload.md`

Test plan
- `pytest tests/test_checkpoint_offload.py -v` — numerical correctness, memory reduction, edge cases
- `checkpoint_input_offload=true` on a multi-GPU setup
- Defaults to `false`; no behavior change when disabled

🤖 Generated with Claude Code