[fsdp2] feat: add per-layer GPU optimizer step with async prefetch pipeline#5364
aoshen524 wants to merge 4 commits into verl-project:main
Conversation
…peline

When FSDP2 + optimizer offload trains large models (e.g., 32B), CPU Adam takes ~324s processing ~67GB of optimizer states. Loading all states to GPU at once causes OOM. The per-layer GPU optimizer step streams 1-2 layers at a time (~1.5GB each) to GPU using a 3-stream async pipeline (H2D/Compute/D2H), achieving a ~50-80x speedup (324s -> 4.4s). Supports both CPUOffloadPolicy=True (full offload) and False (optimizer-only offload) modes.

Changes:
- Add PerLayerGPUOptimizerStep class in fsdp_utils.py
- Add per_layer_optimizer_step, optimizer_step_prefetch_layers config
- Wire into fsdp_workers.py update_actor() with stepper lifecycle
- Add optimizer step dispatch + perf metrics in dp_actor.py
- Add 5 unit tests (layer grouping, correctness, multi-step, prefetch)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
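The 3-stream pipeline described above can be sketched as a schedule, with the H2D / compute / D2H work modeled as recorded events rather than real CUDA-stream operations (a hedged illustration, not the actual verl implementation; `per_layer_schedule` is a hypothetical name):

```python
def per_layer_schedule(n_layers, prefetch=1):
    """Return the event order of the per-layer optimizer step: prefetch up to
    `prefetch` layers ahead of the Adam update, and offload each layer's
    states right after its update."""
    events = []
    # Warm up: fetch the first layer(s) before any compute starts.
    for i in range(min(prefetch + 1, n_layers)):
        events.append(("h2d", i))
    for i in range(n_layers):
        events.append(("compute", i))    # Adam update for layer i on GPU
        nxt = i + prefetch + 1
        if nxt < n_layers:
            events.append(("h2d", nxt))  # overlap: fetch a future layer
        events.append(("d2h", i))        # stream updated states back to CPU
    return events
```

With 3 layers and `prefetch=1`, the fetch of layer 2 is issued while layer 0 is being updated, which is what lets the H2D traffic hide behind compute.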
Code Review
The pull request introduces a high-performance per-layer GPU optimizer step for FSDP2, which significantly reduces the time spent in Adam updates when using optimizer offload. The implementation uses a 3-stream async pipeline to overlap data transfers with computation. While the core logic is sound and provides impressive speedups, there are several critical efficiency and correctness issues that should be addressed:
- Performance bottlenecks: the use of `torch.cuda.empty_cache()` twice per iteration and the creation of new CUDA streams on every step will degrade performance.
- Resource management: the stepper is recreated every iteration, leading to redundant module scanning and parameter grouping.
- Correctness: the peak memory metric is not reset per step, and there is a potential crash if the optimizer's `step` state is a Python float.
- Async pipeline stalls: parameters and gradients are not pinned when on CPU, which makes the "async" transfers synchronous and stalls the pipeline.
```python
self.device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else torch.device(device_id)
self.prefetch_layers = prefetch_layers
self._layer_param_groups = self._build_layer_groups(model)
self._init_states_and_pin()
```
Creating new CUDA streams inside the `step()` method on every call is inefficient. It is better to initialize `h2d_stream` and `d2h_stream` once in the constructor and reuse them.
Suggested change:

```diff
  self.device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else torch.device(device_id)
  self.prefetch_layers = prefetch_layers
  self._layer_param_groups = self._build_layer_groups(model)
  self._init_states_and_pin()
+ self.h2d_stream = torch.cuda.Stream(device=self.device)
+ self.d2h_stream = torch.cuda.Stream(device=self.device)
```
```python
for key in ("exp_avg", "exp_avg_sq"):
    if key in state:
        local = self._get_local_tensor(state[key])
        if local.device.type != "cpu":
            state[key] = local.to("cpu")
# Pin optimizer state tensors for async transfers
for key in ("exp_avg", "exp_avg_sq", "step"):
    local = self._get_local_tensor(state[key])
    if local.device.type == "cpu" and not local.is_pinned():
        local.data = local.pin_memory()
```
This block has two issues:
- It doesn't handle the case where `state['step']` is a Python float (common in some Adam implementations), which will cause an `AttributeError` when accessing `.device` or a crash in `_prefetch_layer`.
- To ensure the 3-stream pipeline is truly asynchronous, the parameters and gradients should also be pinned if they reside on the CPU (offload mode). Otherwise, `to(non_blocking=True)` will behave synchronously.
Suggested change:

```diff
-for key in ("exp_avg", "exp_avg_sq"):
-    if key in state:
-        local = self._get_local_tensor(state[key])
-        if local.device.type != "cpu":
-            state[key] = local.to("cpu")
-# Pin optimizer state tensors for async transfers
-for key in ("exp_avg", "exp_avg_sq", "step"):
-    local = self._get_local_tensor(state[key])
-    if local.device.type == "cpu" and not local.is_pinned():
-        local.data = local.pin_memory()
+for key in ("exp_avg", "exp_avg_sq", "step"):
+    if key in state:
+        val = state[key]
+        if not isinstance(val, torch.Tensor):
+            state[key] = torch.tensor(float(val), dtype=torch.float32, device="cpu")
+        else:
+            local = self._get_local_tensor(val)
+            if local.device.type != "cpu":
+                state[key] = local.to("cpu")
+# Pin optimizer state tensors and param/grad for async transfers
+for key in ("exp_avg", "exp_avg_sq", "step"):
+    local = self._get_local_tensor(state[key])
+    if local.device.type == "cpu" and not local.is_pinned():
+        local.data = local.pin_memory()
+for t in (param, param.grad):
+    if t is not None:
+        local_t = self._get_local_tensor(t.data)
+        if local_t.device.type == "cpu" and not local_t.is_pinned():
+            local_t.data = local_t.pin_memory()
```
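The float-handling part of the fix above can be isolated as a small dependency-free sketch (`normalize_step_state` and `make_tensor` are hypothetical names, with `make_tensor` standing in for `torch.tensor(..., device="cpu")`):

```python
def normalize_step_state(state, make_tensor):
    """Some Adam implementations keep state['step'] as a plain Python number;
    wrap it via make_tensor so later code can assume a tensor-like value
    with a .device attribute. Tensor-like values pass through unchanged."""
    val = state.get("step")
    if val is not None and not hasattr(val, "device"):
        state["step"] = make_tensor(float(val))
    return state
```

The duck-typed `hasattr(val, "device")` check mirrors the `isinstance(val, torch.Tensor)` test in the suggestion while keeping the sketch runnable without torch.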
verl/utils/fsdp_utils.py (Outdated)
```python
h2d_stream = torch.cuda.Stream(device=self.device)
d2h_stream = torch.cuda.Stream(device=self.device)
compute_stream = torch.cuda.current_stream(self.device)
```
Reuse the persistent streams created in the constructor to avoid the overhead of stream creation on every step. Also, reset the peak memory statistics at the start of the step to ensure the `peak_memory_gb` metric reflects the usage of the current step rather than the global peak.
Suggested change:

```diff
-h2d_stream = torch.cuda.Stream(device=self.device)
-d2h_stream = torch.cuda.Stream(device=self.device)
-compute_stream = torch.cuda.current_stream(self.device)
+torch.cuda.reset_peak_memory_stats(self.device)
+h2d_stream = self.h2d_stream
+d2h_stream = self.d2h_stream
+compute_stream = torch.cuda.current_stream(self.device)
```
```python
# blocks to CUDA driver so forward/backward can't repurpose them.
# Without this, each optimizer step leaks ~1.7 GB of device memory
# because caching allocator blocks get "stolen" by gradient allocation.
torch.cuda.empty_cache()
```
`torch.cuda.empty_cache()` is an extremely expensive operation that synchronizes the GPU and defragments memory. Calling it at the end of every optimizer step will significantly degrade performance and negate the benefits of the async pipeline. If memory pressure is the concern, consider making this optional or investigating why the caching allocator isn't reusing blocks efficiently.
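The reviewer's "make this optional" suggestion could take the shape of an interval gate around the expensive call (a hedged sketch with hypothetical names; `release_fn` would be `torch.cuda.empty_cache` in practice):

```python
class CacheReleasePolicy:
    """Invoke an expensive, synchronizing release call only every
    `interval` optimizer steps instead of on every step."""

    def __init__(self, release_fn, interval=0):
        self.release_fn = release_fn  # e.g. torch.cuda.empty_cache
        self.interval = interval      # 0 disables releasing entirely
        self._steps = 0

    def maybe_release(self):
        """Count a step; fire release_fn on every `interval`-th step."""
        self._steps += 1
        if self.interval > 0 and self._steps % self.interval == 0:
            self.release_fn()
            return True
        return False
```

With `interval=0` the pipeline never pays the synchronization cost; a small positive interval trades a periodic stall for bounded allocator growth.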
verl/workers/fsdp_workers.py (Outdated)
```python
if _use_per_layer:
    stepper = PerLayerGPUOptimizerStep(
        model=self.actor_module_fsdp,
        optimizer=self.actor_optimizer,
        device_id=get_device_id(),
        prefetch_layers=self.config.actor.fsdp_config.get("optimizer_step_prefetch_layers", 1),
    )
    self.actor._per_layer_optimizer_stepper = stepper
```
Instantiating the stepper on every `update_actor` call is inefficient because it re-scans the model modules and re-groups parameters every iteration. The stepper should be cached in `self.actor` and reused.
Suggested change:

```diff
 if _use_per_layer:
-    stepper = PerLayerGPUOptimizerStep(
-        model=self.actor_module_fsdp,
-        optimizer=self.actor_optimizer,
-        device_id=get_device_id(),
-        prefetch_layers=self.config.actor.fsdp_config.get("optimizer_step_prefetch_layers", 1),
-    )
-    self.actor._per_layer_optimizer_stepper = stepper
+    stepper = getattr(self.actor, "_per_layer_optimizer_stepper", None)
+    if stepper is None:
+        stepper = PerLayerGPUOptimizerStep(
+            model=self.actor_module_fsdp,
+            optimizer=self.actor_optimizer,
+            device_id=get_device_id(),
+            prefetch_layers=self.config.actor.fsdp_config.get("optimizer_step_prefetch_layers", 1),
+        )
+        self.actor._per_layer_optimizer_stepper = stepper
```
verl/workers/fsdp_workers.py (Outdated)

```python
if _use_per_layer:
    self.actor._per_layer_optimizer_stepper = None
    torch.cuda.empty_cache()
```
Breaking change: PerLayerGPUOptimizerStep now requires params/grads on GPU (offload_policy=False). Using CPUOffloadPolicy raises ValueError.

Changes based on Gemini review feedback:
- Move CUDA stream creation to __init__ (reuse across step() calls)
- Handle state['step'] as Python float/int (defensive)
- Add reset_peak_memory_stats for accurate per-step peak metrics
- Fix decoupled_weight_decay: detect AdamW via isinstance instead of relying on param_group key (was silently wrong for AdamW)
- Cache stepper in fsdp_workers.py (avoid repeated module scan/pin)
- Remove per-call stepper destruction (cached for reuse)

Simplifications from GPU-only mode:
- _prefetch_layer: params/grads are references (no H2D copy needed)
- _offload_layer: only optimizer states D2H (params updated in-place)
- _validate_gpu_params: explicit check at init time
- Tests: replace CPU offload test with ValueError assertion test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…alidation

- Assert AdamW optimizer (TypeError if not)
- Add _validate_single_hyperparam_set(): ensures all param_groups have identical hyperparams, since the per-layer step processes by layer, not by group
- _run_adam_for_layer now mirrors Adam.step() (adam.py:248-270): reads all hyperparams from the group dict, computes has_complex dynamically
- Remove defensive state['step'] float/int handling
- Remove _decoupled_weight_decay instance var (use group dict directly)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
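The `_validate_single_hyperparam_set()` check described in this commit could look roughly like the following (a hedged sketch, not the actual verl code; the key set is an assumption for standard Adam/AdamW groups):

```python
def validate_single_hyperparam_set(param_groups,
                                   keys=("lr", "betas", "eps", "weight_decay")):
    """Raise if param groups differ in any hyperparameter: the per-layer
    step iterates by layer rather than by param group, so one shared
    hyperparameter set is required."""
    if not param_groups:
        return
    ref = {k: param_groups[0][k] for k in keys}
    for i, group in enumerate(param_groups[1:], start=1):
        for k in keys:
            if group[k] != ref[k]:
                raise ValueError(
                    f"param_group {i} differs in '{k}': {group[k]!r} != {ref[k]!r}"
                )
```

Failing fast at init time here is what makes it safe for the per-layer step to read hyperparameters from any one group dict.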
…e path)

Support per_layer_optimizer_step in the FSDPEngine (disable_legacy_worker=True) code path:
- initialize(): create PerLayerGPUOptimizerStep after _build_model_optimizer() while params are still on GPU, with offload_policy validation
- optimizer_step(): use stepper.step() instead of optimizer.step()
- to(): skip bulk optimizer load/offload when the per-layer stepper is active (the stepper manages its own optimizer states via async H2D/D2H)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…pipeline

Port of verl-project/verl#5364. Instead of materializing all optimizer states on GPU at once (~46.74GB for 7B), streams 1-2 layers at a time (~1.68GB) using a 3-stream async pipeline (H2D/compute/D2H), achieving a ~50-80x speedup over CPU Adam while avoiding OOM.

Changes:
- Add PerLayerGPUOptimizerStep class in fsdp_utils/optimizer.py
- Add per_layer_optimizer_step and optimizer_step_prefetch_layers config fields
- Wire into FSDPEngine.optimizer_step() with performance metrics
- Add 5 test cases covering layer grouping, correctness, multi-step, and prefetch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
For very large models, optimizer states can become a major memory bottleneck. In the default path, Adam initializes states lazily at the first `optimizer.step()`, which can trigger OOM. Also, optimizer states are not needed during forward/backward; keeping full optimizer states resident on GPU outside the optimizer step only increases peak memory pressure.
This PR introduces `per_layer_optimizer_step` to stream optimizer states layer-by-layer only during the optimizer step, instead of materializing full optimizer states on GPU at once.

Basic Usage
Enable the following flags in actor FSDP config:
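A sketch of what these flags might look like, using the key names from the PR description (the exact nesting under `actor.fsdp_config` is inferred from the worker code in this PR and may differ by deployment):

```yaml
actor:
  fsdp_config:
    per_layer_optimizer_step: True
    optimizer_step_prefetch_layers: 1  # layers fetched ahead of the Adam update
```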
Experiments
Comparison setup: baseline without `per_layer_optimizer_step` vs. with `per_layer_optimizer_step` enabled.

Results
With `per_layer_optimizer_step`, the optimizer-step memory increase is small and controlled.

Precision alignment (grad dump comparison)
To verify that the per-layer optimizer step does not affect training numerics, we compared per-parameter gradient tensors dumped before the optimizer step between the baseline and optimized paths.
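The per-parameter comparison can be sketched in miniature (a hedged, pure-Python stand-in for `tools/compare_grads.py`, operating on `{name: [floats]}` dumps instead of saved `.pt` tensors):

```python
def compare_grads(baseline, optimized):
    """Return {param_name: max_abs_diff} between two gradient dumps,
    each a {name: list-of-floats} mapping over the same parameter names."""
    report = {}
    for name, ref in baseline.items():
        test = optimized[name]
        # max absolute elementwise difference; 0.0 means bitwise identical
        report[name] = max((abs(a - b) for a, b in zip(ref, test)), default=0.0)
    return report
```

A max_diff of exactly 0.0 for every parameter is the bitwise-identical result reported below.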
Setup:
- `full_determinism=True`, `seed=42`
- hardcoded advantage (`VERL_HARDCODE_ADVANTAGE=1`) to eliminate reward randomness
- `VERL_PRECISION_DUMP_GRADS` (saves per-param grad tensors as `.pt` files)
- `tools/compare_grads.py` (per-param max_diff, mean_diff, cosine_sim)

Result (step 0, rank 0, 729 params):
All 729 parameters have bitwise identical gradients (`max_diff = 0`) between the baseline and optimized paths. This confirms that the per-layer optimizer step does not alter forward/backward computation; it only changes how optimizer states are streamed during the step itself.

Reviewer takeaway