feat: FP8 rollout in GRPO for MoE models #1175
guyueh1 wants to merge 13 commits into NVIDIA-NeMo:main from
Conversation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
The Qwen30B GRPO test with FP8 e2e is blocked on bumping vLLM to 0.10.1, because we have to include commit vllm-project/vllm@2212cd6; it also depends on #1163 and commit vllm-project/vllm@7e0b121. The Moonlight16B/DeepseekV3 tests depend on adding flashinfer to the dependencies.
📝 Walkthrough

Adds a new GRPO FP8 config for Qwen2.5 30B A3/Megatron. Extends FP8 generation to support MoE in vLLM by patching FusedMoE weight handling, enabling FP8 MoE initialization, and adding FlashInfer MoE FP8 integration with guards and traversal updates.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant Trainer
participant FP8 as FP8 Init
participant vLLM as vLLM Model
participant MoE as FusedMoE
participant Patch as FP8 Patches
Trainer->>FP8: init_fp8(vllm_cfg, model_name, mp_size)
alt use_flashinfer_moe_fp8 enabled
FP8->>FP8: set VLLM_USE_FLASHINFER_MOE_FP8=1
end
FP8->>Patch: apply_fp8_patches()
Patch->>vLLM: override Fp8MoEMethod.process_weights_after_loading
Note over vLLM,MoE: Model/weights loading
vLLM->>MoE: load weights
alt module is FusedMoE and FP8 MOE path enabled
vLLM->>Patch: process_weights_after_loading_moe(layer)
Patch->>MoE: detect w13/w2 FP8, optional swap/integration
else non-MoE or FP8 MOE disabled
vLLM->>Patch: process_weights_after_loading(layer)
Patch->>vLLM: standard FP8 linear handling
end
```
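For orientation, here is a minimal sketch of the flow in the diagram. The function signatures and the vLLM import path are assumptions based on the walkthrough, not copied from the PR; the real implementation lives in `nemo_rl/models/generation/fp8.py` and handles more cases (pow2 scaling, bf16 layer ranges, async engine).

```python
# Minimal sketch of the patch flow above, assuming the names used in the diagram.
import os


def process_weights_after_loading_moe(self, layer) -> None:
    """Stub standing in for the PR's MoE-aware post-load handler."""
    ...


def apply_fp8_patches() -> None:
    # Exact module path varies across vLLM versions; this is the 0.10.x layout.
    from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod

    # Route FP8 MoE post-load weight handling through the override.
    Fp8MoEMethod.process_weights_after_loading = process_weights_after_loading_moe


def init_fp8(vllm_cfg: dict, model_name: str, model_parallel_size: int) -> None:
    if vllm_cfg["use_flashinfer_moe_fp8"]:
        # Opt vLLM into FlashInfer MoE FP8 kernels before the engine starts.
        os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"
    apply_fp8_patches()
```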
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/models/generation/fp8.py (2)
127-135: Bug: list.append called with 3 args. Use extend to add multiple patchers.

This raises a TypeError at runtime, preventing patches from being applied.
Apply this diff:
```diff
- fp8_state.vllm_patches.append(patcher2, patcher3, patcher4)
+ fp8_state.vllm_patches.extend([patcher2, patcher3, patcher4])
```
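For context, `list.append` accepts exactly one argument, which is why the original call fails at runtime:

```python
patches: list[int] = []

try:
    patches.append(1, 2, 3)  # raises: append() takes exactly one argument (3 given)
except TypeError as e:
    print(e)

patches.extend([1, 2, 3])  # extend adds each element of the iterable
assert patches == [1, 2, 3]
```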
248-276: Robustness: guard empty module_path and missing packed_modules_mapping; avoid premature AttributeError.

reversed_mapping is computed before the try block, and module_path[-1] may raise IndexError when name has no path. Also, the model might not define packed_modules_mapping.
Apply this diff:
```diff
 def _get_module_from_param_name(model, name: str):
@@
-    path_parts = name.split(".")
-    module_path = path_parts[:-1]
-    # Replace with the fused model name
-    packed_modules_mapping = model.packed_modules_mapping
-    reversed_mapping = {
-        original_name: fused_name
-        for fused_name, original_names_list in packed_modules_mapping.items()
-        for original_name in original_names_list
-    }
-    if module_path[-1] in reversed_mapping.keys():
-        module_path[-1] = reversed_mapping[module_path[-1]]
+    path_parts = name.split(".")
+    module_path = path_parts[:-1]
+    # Replace with the fused model name (if present)
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) or {}
+    reversed_mapping = {
+        original_name: fused_name
+        for fused_name, original_names_list in packed_modules_mapping.items()
+        for original_name in original_names_list
+    }
+    if module_path:
+        last = module_path[-1]
+        if last in reversed_mapping:
+            module_path[-1] = reversed_mapping[last]
@@
-    except (AttributeError, IndexError, ValueError) as e:
-        print(f"Warning: Could not find module for parameter '{name}'. Error: {e}")
+    except (AttributeError, IndexError, ValueError) as e:
+        logger = __import__("logging").getLogger(__name__)
+        logger.debug("Could not find module for parameter '%s': %s", name, e)
     return current_module
```

Additionally add (outside this hunk) once near imports:
```python
import logging

logger = logging.getLogger(__name__)
```
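To illustrate the remapping this helper performs, a toy example (the parameter and module names here are hypothetical, chosen only to mirror vLLM's fused-projection convention; the real helper walks the loaded model, not strings):

```python
# Hypothetical illustration of fused-name remapping.
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
reversed_mapping = {
    original: fused
    for fused, originals in packed_modules_mapping.items()
    for original in originals
}

name = "model.layers.0.self_attn.q_proj.weight"
module_path = name.split(".")[:-1]
if module_path and module_path[-1] in reversed_mapping:
    module_path[-1] = reversed_mapping[module_path[-1]]

print(".".join(module_path))  # -> model.layers.0.self_attn.qkv_proj
```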
🧹 Nitpick comments (7)
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml (3)
23-29: Expose all FP8/MoE knobs used by code (documented defaults).

init_fp8 reads several keys from vllm_cfg (pow2_* flags, bf16 layer counts, async_engine, use_flashinfer_moe_fp8). Surface them here with explicit, documented defaults to keep YAML as the single source of truth per repo guidelines.
Apply this diff:
```diff
 generation:
   vllm_cfg:
     tensor_parallel_size: 8
-    expert_parallel_size: 8 # need to make moe_intermediate_size / expert_tensor_parallel_size % 128 == 0
+    expert_parallel_size: 8 # ensure (moe_intermediate_size / expert_parallel_size) % 128 == 0
     precision: "fp8"
     use_deep_gemm: true
     gpu_memory_utilization: 0.6
+    # FP8/MoE controls (documented defaults)
+    async_engine: false
+    use_flashinfer_moe_fp8: false # set true when FlashInfer MoE FP8 kernels are available
+    pow2_weight_scaling_factors: false
+    pow2_activation_scaling_factors: false
+    num_first_layers_in_bf16: 0
+    num_last_layers_in_bf16: 0
```
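The payoff of surfacing these keys is that code can index vllm_cfg directly instead of layering hidden defaults. A hedged sketch (the key set mirrors the review comment above, not a verified reading of init_fp8):

```python
# Sketch only: with every knob present in the exemplar YAML, config access
# stays direct and fails loudly if a key is missing.
def read_fp8_knobs(vllm_cfg: dict) -> dict:
    return {
        "use_flashinfer_moe_fp8": vllm_cfg["use_flashinfer_moe_fp8"],
        "pow2_weight_scaling_factors": vllm_cfg["pow2_weight_scaling_factors"],
        "pow2_activation_scaling_factors": vllm_cfg["pow2_activation_scaling_factors"],
        "num_first_layers_in_bf16": vllm_cfg["num_first_layers_in_bf16"],
        "num_last_layers_in_bf16": vllm_cfg["num_last_layers_in_bf16"],
    }
```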
26-26: Fix misleading comment (var name mismatch).

“expert_tensor_parallel_size” isn’t a field here; the knob is expert_parallel_size. Updated above.
29-29: Add newline at end of file.

YAMLLint flagged this. Keeps tooling happy.
Apply this diff:
```diff
-gpu_memory_utilization: 0.6
\ No newline at end of file
+gpu_memory_utilization: 0.6
```

nemo_rl/models/generation/fp8.py (4)
298-314: Prefer tuples for (name, tensor) pairs and keep dtypes explicit.

Small hygiene to match common loaders and avoid accidental mutation.
Apply this diff:
```diff
-        param_lp, param_scale = cast_tensor_to_fp8_blockwise(
-            v.to(torch.float),
+        param_lp, param_scale = cast_tensor_to_fp8_blockwise(
+            v.to(torch.float32),
             weight_block_size=FP8_BLOCK_QUANT_KWARGS["weight_block_size"],
         )
         param_scale = torch.squeeze(param_scale, dim=-1)
-        weights_quantized.append([k, param_lp])
-        weights_quantized.append([k + "_scale_inv", param_scale])
+        weights_quantized.append((k, param_lp))
+        weights_quantized.append((k + "_scale_inv", param_scale))
```
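A tiny illustration of the tuple suggestion (the weight name here is hypothetical): tuples keep each (name, tensor) record immutable while unpacking exactly the way list pairs do in downstream weight loaders.

```python
import torch

weights_quantized: list[tuple[str, torch.Tensor]] = []
weights_quantized.append(("w13_weight", torch.zeros(2, 2)))

for name, tensor in weights_quantized:  # unpacks identically, but read-only
    print(name, tuple(tensor.shape))
```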
400-417: Add a brief docstring and no_grad guard; keep assignments safe.

Clarifies behavior and avoids autograd surprises.
Apply this diff:
```diff
-def process_weights_after_loading_moe(self, layer) -> None:
-
+def process_weights_after_loading_moe(self, layer) -> None:
+    """Post-load fixups for FP8 MoE layers.
+
+    - Asserts FP8 block-quant serialized checkpoints with dynamic activation.
+    - When FlashInfer MoE FP8 is enabled, swaps W13 layout to W31.
+    """
@@
-    flashinfer_moe_enabled = envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()
-    if flashinfer_moe_enabled:
-        layer.w13_weight.data = _swap_w13_to_w31(layer.w13_weight.data)
-        layer.w13_weight_scale_inv.data = _swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+    flashinfer_moe_enabled = envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()
+    if flashinfer_moe_enabled:
+        import torch
+        with torch.no_grad():
+            layer.w13_weight.data = _swap_w13_to_w31(layer.w13_weight.data)
+            layer.w13_weight_scale_inv.data = _swap_w13_to_w31(layer.w13_weight_scale_inv.data)
```
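`_swap_w13_to_w31` itself is not shown in this excerpt. Assuming w13 stacks the w1 and w3 expert projections along the intermediate dimension (the usual FusedMoE layout), a plausible implementation would be:

```python
import torch


def _swap_w13_to_w31(w13: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: reorder stacked [w1; w3] into [w3; w1].

    Assumes shape (num_experts, 2 * intermediate_size, hidden_size); the
    PR's actual layout and implementation may differ.
    """
    w1, w3 = w13.chunk(2, dim=1)
    return torch.cat([w3, w1], dim=1)
```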
19-19: Unused import?

ray is only needed when model_parallel_size > 1; if single-GPU paths are dominant, consider importing locally in monkey_patch_vllm_ray_executor to speed cold import. Optional.
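If adopted, the lazy import would look like this sketch (function body elided):

```python
def monkey_patch_vllm_ray_executor():
    import ray  # imported here so single-GPU cold start skips the cost

    ...  # patch logic that touches ray only runs on the multi-GPU path
```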
61-65: Global naming doesn’t follow repo guideline (G_ UPPER_SNAKE_CASE).

global_fp8_config, fp8_state, fp8_patches_applied, original_run_engine_core, original_init are globals. Consider renaming to G_FP8_CONFIG, G_FP8_STATE, G_FP8_PATCHES_APPLIED, G_ORIGINAL_RUN_ENGINE_CORE, G_ORIGINAL_COREENGINEPROC_INIT.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml (1 hunks)
nemo_rl/models/generation/fp8.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/models/generation/fp8.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/models/generation/fp8.py
🪛 YAMLlint (1.37.1)
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml
[error] 29-29: no new line character at the end of file
(new-line-at-end-of-file)
🪛 GitHub Actions: CICD NeMo RL
nemo_rl/models/generation/fp8.py
[error] 1-1: Trailing whitespace found. pre-commit hook 'trailing-whitespace' failed; files were modified by this hook.
[error] 1-1: ruff hook failed: 2 issues found and fixed, some lint issues remain. See CI logs for details.
[error] 1-1: ruff-format hook failed: 1 file reformatted; 207 files left unchanged.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/generation/fp8.py (2)
156-160: Wire flag -> env var for FlashInfer MoE FP8 is fine, but ensure the config key exists in YAML.

The code reads use_flashinfer_moe_fp8; the exemplar YAML should expose it (default false) to avoid hidden defaults.
Please confirm the YAML change in examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml includes use_flashinfer_moe_fp8 as proposed.
1-1: Run pre-commit and fix trailing-whitespace / ruff-format in nemo_rl/models/generation/fp8.py

CI reported trailing-whitespace and ruff-format edits for nemo_rl/models/generation/fp8.py. pre-commit is not available in the verification environment; run locally:

```bash
pre-commit run -a
```

and commit the auto-fixes (or run `ruff --fix` and strip trailing whitespace from the file).
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
| "megatron-bridge", | ||
| # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved | ||
| "vllm==0.10.0", | ||
| "vllm==0.10.1.1", |
I suggest filing another PR to bump vLLM to a newer version, and we should run a convergence test to make sure the bump won't break anything.
We can work on this minor version bump in a separate PR, #1199, and then rebase this PR after #1199 is merged.
```diff
-    # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
-    "transformers>=4.51.0,<4.54.0",
     "ray[default]==2.48.0",
+    "transformers",
```
transformers will be updated to >=4.55.4 in #1115; is this version OK with the FP8 rollout feature?
Yes, 4.55.4 is OK; I will rebase after #1115.
Closing this; I will start a new PR based on #1334.
What does this PR do?
Support FP8 rollout (DeepSeek-style quantization) in GRPO for MoE models.
This PR depends on the following PRs, so do not merge until they are merged:
Usage
```bash
uv run examples/run_grpo_math.py \
    --config examples/configs/grpo_math_qwen30ba3b_megatron.yaml \
    cluster.num_nodes=2 \
    policy.megatron_cfg.activation_checkpointing=True \
    policy.generation.vllm_cfg.precision="fp8" \
    policy.generation.vllm_cfg.use_deep_gemm=true \
    loss_fn.use_importance_sampling_correction=True \
    policy.generation.vllm_cfg.tensor_parallel_size=2 \
    policy.megatron_cfg.pipeline_model_parallel_size=2
```

Additionally, you can choose to use e2e FP8 (in both rollout and policy); this depends on #971 being merged.
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit