
feat: FP8 rollout in GRPO for MoE models #1175

Closed
guyueh1 wants to merge 13 commits into NVIDIA-NeMo:main from guyueh1:moe_fp8_rollout

Conversation

Contributor

@guyueh1 guyueh1 commented Sep 21, 2025

What does this PR do?

Support FP8 rollout (DeepSeek-style quantization) in GRPO for MoE models.

This PR depends on the following PRs, so do not merge until those are merged.

Usage

uv run examples/run_grpo_math.py \
--config examples/configs/grpo_math_qwen30ba3b_megatron.yaml \
cluster.num_nodes=2 \
policy.megatron_cfg.activation_checkpointing=True \
policy.generation.vllm_cfg.precision="fp8" \
policy.generation.vllm_cfg.use_deep_gemm=true \
loss_fn.use_importance_sampling_correction=True \
policy.generation.vllm_cfg.tensor_parallel_size=2 \
policy.megatron_cfg.pipeline_model_parallel_size=2

Additionally, you can choose to use end-to-end (e2e) FP8, in both rollout and policy; this depends on #971 being merged.

export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1

uv run examples/run_grpo_math.py \
--config examples/configs/grpo_math_qwen30ba3b_megatron.yaml \
cluster.num_nodes=2 \
policy.megatron_cfg.activation_checkpointing=True \
policy.generation.vllm_cfg.precision="fp8" \
policy.generation.vllm_cfg.use_deep_gemm=true \
loss_fn.use_importance_sampling_correction=True \
policy.generation.vllm_cfg.tensor_parallel_size=2 \
policy.megatron_cfg.pipeline_model_parallel_size=2 \
+policy.megatron_cfg.fp8_cfg.enabled=true \
+policy.megatron_cfg.fp8_cfg.fp8=e4m3 \
+policy.megatron_cfg.fp8_cfg.fp8_recipe=blockwise \
+policy.megatron_cfg.fp8_cfg.fp8_param=false \
policy.megatron_cfg.moe_router_dtype=fp32
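
For context, the DeepSeek-style quantization mentioned above is blockwise FP8: weights are cast to e4m3 with one scaling factor per block (128x128 in the DeepSeek recipe, which is also why the config comments below require divisibility by 128). The snippet is a minimal illustrative sketch of the weight-side quantization only; the helper name and block size here are assumptions for illustration and are not this PR's code (the actual path uses cast_tensor_to_fp8_blockwise with a configured weight_block_size, as seen in the review comments further down).

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quantize_blockwise(w: torch.Tensor, block: int = 128):
    """Quantize a 2D weight to float8_e4m3fn with one scale per (block x block) tile. Sketch only."""
    rows, cols = w.shape
    pad_r, pad_c = (-rows) % block, (-cols) % block
    w_pad = torch.nn.functional.pad(w.float(), (0, pad_c, 0, pad_r))
    # View the padded matrix as (row_blocks, block, col_blocks, block) tiles.
    tiles = w_pad.view(w_pad.shape[0] // block, block, w_pad.shape[1] // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX  # dequantization scale, analogous to the *_scale_inv tensors in vLLM
    q = (tiles / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    q = q.view_as(w_pad)[:rows, :cols]
    return q, scale.squeeze(1).squeeze(-1)  # scales have shape (rows/block, cols/block)

# Example: quantize a random weight matrix into FP8 data plus per-block scales.
w_fp8, w_scale_inv = quantize_blockwise(torch.randn(256, 512))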

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Added a ready-to-use GRPO configuration preset optimized for FP8 precision, including tensor/expert parallel settings, memory utilization tuning, and importance sampling correction.
    • Introduced FP8 support for MoE models, enabling loading and execution of FP8 MoE weights with optional accelerated backend integration for improved performance.
    • Relaxed initialization constraints to allow MoE configurations, improving compatibility across setups and simplifying FP8-enabled deployments.

@guyueh1 guyueh1 self-assigned this Sep 21, 2025
@guyueh1 guyueh1 added the r0.4.0 label Sep 21, 2025
@guyueh1 guyueh1 added r0.4.0 and removed r0.4.0 labels Sep 22, 2025
Contributor Author

guyueh1 commented Sep 22, 2025

The Qwen30B GRPO test with FP8 e2e is blocked on bumping vLLM to 0.10.1 because we have to include commit vllm-project/vllm@2212cd6; it also depends on #1163 and commit vllm-project/vllm@7e0b121.

Moonlight16B/DeepseekV3 tests depend on adding flashinfer to the dependencies.

@guyueh1 guyueh1 marked this pull request as ready for review September 22, 2025 05:17
@guyueh1 guyueh1 requested review from a team as code owners September 22, 2025 05:17
Contributor

coderabbitai Bot commented Sep 22, 2025

📝 Walkthrough

Adds a new GRPO FP8 config for Qwen2.5 30B A3/Megatron. Extends FP8 generation to support MoE in vLLM by patching FusedMoE weight handling, enabling FP8 MoE initialization, and adding FlashInfer MOE FP8 integration with guards and traversal updates.

Changes

  • Config: GRPO FP8 for Qwen2.5 30B A3 (examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml): New config enabling importance sampling correction; sets Megatron MoE dtype to fp32; enables FP8 (e4m3, blockwise, no param FP8); disables precision-aware optimizer; sets NVTE FP8 block scaling env; vLLM generation with TP=8, EP=8, precision fp8, deep GEMM, gpu mem util 0.6.
  • Runtime: FP8 MoE support in vLLM patching (nemo_rl/models/generation/fp8.py): Adds FusedMoE import; extends FP8 detection to MoE weights; adds process_weights_after_loading_moe; updates module traversal to short-circuit on FusedMoE; removes MoE expert-count assert in init; adds FlashInfer MOE FP8 gating/env flag; patches Fp8MoEMethod.process_weights_after_loading.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant FP8 as FP8 Init
  participant vLLM as vLLM Model
  participant MoE as FusedMoE
  participant Patch as FP8 Patches

  Trainer->>FP8: init_fp8(vllm_cfg, model_name, mp_size)
  alt use_flashinfer_moe_fp8 enabled
    FP8->>FP8: set VLLM_USE_FLASHINFER_MOE_FP8=1
  end
  FP8->>Patch: apply_fp8_patches()
  Patch->>vLLM: override Fp8MoEMethod.process_weights_after_loading

  Note over vLLM,MoE: Model/weights loading

  vLLM->>MoE: load weights
  alt module is FusedMoE and FP8 MOE path enabled
    vLLM->>Patch: process_weights_after_loading_moe(layer)
    Patch->>MoE: detect w13/w2 FP8, optional swap/integration
  else non-MoE or FP8 MOE disabled
    vLLM->>Patch: process_weights_after_loading(layer)
    Patch->>vLLM: standard FP8 linear handling
  end
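
In code, the alt branches above amount to a dispatch during post-load weight processing. The sketch below is illustrative only: the traversal shape follows the walkthrough's description (short-circuit on FusedMoE), but the attribute check and handler names are assumptions rather than the actual fp8.py implementation.

import torch.nn as nn

def postprocess_fp8_weights(module: nn.Module, moe_handler, linear_handler) -> None:
    """Recursively visit modules, stopping descent at FusedMoE layers (sketch)."""
    if type(module).__name__ == "FusedMoE":  # the real code checks against vLLM's FusedMoE class
        moe_handler(module)  # MoE-specific path, e.g. process_weights_after_loading_moe
        return  # short-circuit: do not descend into the fused expert weights
    if hasattr(module, "weight_scale_inv"):  # assumed marker for FP8 block-quantized linears
        linear_handler(module)  # standard FP8 linear handling
    for child in module.children():
        postprocess_fp8_weights(child, moe_handler, linear_handler)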

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Test Results For Major Changes (⚠️ Warning)
    Explanation: The PR introduces a major feature (FP8 rollout for MoE in GRPO, plus config enabling FP8 paths) that can affect both numerics/convergence and performance. From the provided PR summary and comments, there are usage examples and dependency notes, but no documented test results, no convergence validation, and no before/after performance numbers or measurement context. The author even notes tests are blocked pending a vLLM bump and additional dependencies, reinforcing that results are not yet available. Therefore, the PR description does not meet the testing and performance evidence requirements for major changes.
    Resolution: Please add: 1) brief functional test results demonstrating FP8 MoE generation works end-to-end (commands, model(s), hardware, seeds), 2) numerics/convergence checks showing no regression in GRPO rewards/loss curves vs FP16/BF16 baselines, and 3) performance benchmarks (tokens/s, latency, memory) before vs after with configuration details (vLLM version, tensor/expert parallel, FP8 recipe, FlashInfer settings, GPU type/count). If results depend on vLLM 0.10.1 and PR #1163, state exact versions/commits and include results after rebasing; otherwise mark the PR as draft until those dependencies land.
✅ Passed checks (5 passed)
  • Linked Issues Check (✅ Passed): The changes provided implement the core coding requirements from issue #978 by extending FP8 support to fused MoE modules: FusedMoE handling is added, FP8 weight detection is extended to MoE-specific tensors, a MoE-specific process_weights_after_loading_moe path is introduced, the MoE initialization assertion was relaxed, and environment gating for FlashInfer MOE FP8 was added; the example config demonstrates intended usage. These modifications directly map to the linked issue's objective of enabling FP8 generation for vLLM fused MoE layers.
  • Out of Scope Changes Check (✅ Passed): The diff appears focused: an example config was added and fp8.py was extended to support MoE FP8 paths and detection; there are no unrelated file edits or feature additions shown in the provided summaries, and the new import/function are directly tied to the FP8-for-MoE objective. I do not detect out-of-scope changes based on the supplied summaries.
  • Docstring Coverage (✅ Passed): No functions found in the changes. Docstring coverage check skipped.
  • Title Check (✅ Passed): The provided title succinctly and accurately captures the main change, namely adding FP8 rollout in GRPO for MoE models, and is both concise and directly relevant to the pull request contents.
  • Description Check (✅ Passed): Check skipped - CodeRabbit's high-level summary is enabled.


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
nemo_rl/models/generation/fp8.py (2)

127-135: Bug: list.append called with 3 args. Use extend to add multiple patchers.

This raises a TypeError at runtime, preventing patches from being applied.

Apply this diff:

-        fp8_state.vllm_patches.append(patcher2, patcher3, patcher4)
+        fp8_state.vllm_patches.extend([patcher2, patcher3, patcher4])
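
For readers unfamiliar with the patcher objects collected in fp8_state.vllm_patches, the snippet below shows one common way such patchers work (unittest.mock.patch.object). The class and functions are stand-ins, not the repo's actual code, and fp8.py may use a different patcher helper; the point is only the start/stop pattern.

from unittest import mock

class Fp8MoEMethodStub:  # stand-in for the vLLM class being patched
    def process_weights_after_loading(self, layer):
        print("original handling")

def patched_process_weights_after_loading(self, layer):
    print("FP8 MoE-aware handling")

# Starting a patcher applies the override; stopping it restores the original method.
patcher = mock.patch.object(
    Fp8MoEMethodStub,
    "process_weights_after_loading",
    patched_process_weights_after_loading,
)
patcher.start()
Fp8MoEMethodStub().process_weights_after_loading(layer=None)  # prints "FP8 MoE-aware handling"
patcher.stop()
Fp8MoEMethodStub().process_weights_after_loading(layer=None)  # prints "original handling"

Keeping every started patcher in a list makes it possible to stop them all later, which is why list.extend (not append with multiple arguments) is the right call above.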

248-276: Robustness: guard empty module_path and missing packed_modules_mapping; avoid premature AttributeError.

reversed_mapping is computed before the try block, and module_path[-1] may raise IndexError when the name has no path. Also, the model might not define packed_modules_mapping.

Apply this diff:

 def _get_module_from_param_name(model, name: str):
@@
-    path_parts = name.split(".")
-    module_path = path_parts[:-1]
-    # Replace with the fused model name
-    packed_modules_mapping = model.packed_modules_mapping
-    reversed_mapping = {
-        original_name: fused_name
-        for fused_name, original_names_list in packed_modules_mapping.items()
-        for original_name in original_names_list
-    }
-    if module_path[-1] in reversed_mapping.keys():
-        module_path[-1] = reversed_mapping[module_path[-1]]
+    path_parts = name.split(".")
+    module_path = path_parts[:-1]
+    # Replace with the fused model name (if present)
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) or {}
+    reversed_mapping = {
+        original_name: fused_name
+        for fused_name, original_names_list in packed_modules_mapping.items()
+        for original_name in original_names_list
+    }
+    if module_path:
+        last = module_path[-1]
+        if last in reversed_mapping:
+            module_path[-1] = reversed_mapping[last]
@@
-    except (AttributeError, IndexError, ValueError) as e:
-        print(f"Warning: Could not find module for parameter '{name}'. Error: {e}")
+    except (AttributeError, IndexError, ValueError) as e:
+        logger = __import__("logging").getLogger(__name__)
+        logger.debug("Could not find module for parameter '%s': %s", name, e)
     return current_module

Additionally add (outside this hunk) once near imports:

import logging
logger = logging.getLogger(__name__)
🧹 Nitpick comments (7)
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml (3)

23-29: Expose all FP8/MoE knobs used by code (documented defaults).

init_fp8 reads several keys from vllm_cfg (pow2_* flags, bf16 layer counts, async_engine, use_flashinfer_moe_fp8). Surface them here with explicit, documented defaults to keep YAML as the single source of truth per repo guidelines.

Apply this diff:

   generation:
     vllm_cfg:
       tensor_parallel_size: 8
-      expert_parallel_size: 8 # need to make moe_intermediate_size / expert_tensor_parallel_size % 128 == 0
+      expert_parallel_size: 8  # ensure (moe_intermediate_size / expert_parallel_size) % 128 == 0
       precision: "fp8"
       use_deep_gemm: true
       gpu_memory_utilization: 0.6
+      # FP8/MoE controls (documented defaults)
+      async_engine: false
+      use_flashinfer_moe_fp8: false  # set true when FlashInfer MoE FP8 kernels are available
+      pow2_weight_scaling_factors: false
+      pow2_activation_scaling_factors: false
+      num_first_layers_in_bf16: 0
+      num_last_layers_in_bf16: 0

26-26: Fix misleading comment (var name mismatch).

“expert_tensor_parallel_size” isn’t a field here; the knob is expert_parallel_size. Updated above.


29-29: Add newline at end of file.

YAMLLint flagged this. Keeps tooling happy.

Apply this diff:

-      gpu_memory_utilization: 0.6
\ No newline at end of file
+      gpu_memory_utilization: 0.6
+
nemo_rl/models/generation/fp8.py (4)

298-314: Prefer tuples for (name, tensor) pairs and keep dtypes explicit.

Small hygiene to match common loaders and avoid accidental mutation.

Apply this diff:

-        param_lp, param_scale = cast_tensor_to_fp8_blockwise(
-            v.to(torch.float),
+        param_lp, param_scale = cast_tensor_to_fp8_blockwise(
+            v.to(torch.float32),
             weight_block_size=FP8_BLOCK_QUANT_KWARGS["weight_block_size"],
         )
         param_scale = torch.squeeze(param_scale, dim=-1)
-        weights_quantized.append([k, param_lp])
-        weights_quantized.append([k + "_scale_inv", param_scale])
+        weights_quantized.append((k, param_lp))
+        weights_quantized.append((k + "_scale_inv", param_scale))

400-417: Add a brief docstring and no_grad guard; keep assignments safe.

Clarifies behavior and avoids autograd surprises.

Apply this diff:

-def process_weights_after_loading_moe(self, layer) -> None:
-    
+def process_weights_after_loading_moe(self, layer) -> None:
+    """Post-load fixups for FP8 MoE layers.
+
+    - Asserts FP8 block-quant serialized checkpoints with dynamic activation.
+    - When FlashInfer MoE FP8 is enabled, swaps W13 layout to W31.
+    """
@@
-    flashinfer_moe_enabled = envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()
-    if flashinfer_moe_enabled:
-        layer.w13_weight.data = _swap_w13_to_w31(layer.w13_weight.data)
-        layer.w13_weight_scale_inv.data = _swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+    flashinfer_moe_enabled = envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()
+    if flashinfer_moe_enabled:
+        import torch
+        with torch.no_grad():
+            layer.w13_weight.data = _swap_w13_to_w31(layer.w13_weight.data)
+            layer.w13_weight_scale_inv.data = _swap_w13_to_w31(layer.w13_weight_scale_inv.data)

19-19: Unused import?

ray is only needed when model_parallel_size > 1; if single‑GPU paths are dominant, consider importing locally in monkey_patch_vllm_ray_executor to speed cold import. Optional.


61-65: Global naming doesn’t follow repo guideline (G_ UPPER_SNAKE_CASE).

global_fp8_config, fp8_state, fp8_patches_applied, original_run_engine_core, original_init are globals. Consider renaming to G_FP8_CONFIG, G_FP8_STATE, G_FP8_PATCHES_APPLIED, G_ORIGINAL_RUN_ENGINE_CORE, G_ORIGINAL_COREENGINEPROC_INIT.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5faaea8 and a30c96d.

📒 Files selected for processing (2)
  • examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml (1 hunks)
  • nemo_rl/models/generation/fp8.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml

Files:

  • examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/models/generation/fp8.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/models/generation/fp8.py
🪛 YAMLlint (1.37.1)
examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml

[error] 29-29: no new line character at the end of file

(new-line-at-end-of-file)

🪛 GitHub Actions: CICD NeMo RL
nemo_rl/models/generation/fp8.py

[error] 1-1: Trailing whitespace found. pre-commit hook 'trailing-whitespace' failed; files were modified by this hook.


[error] 1-1: ruff hook failed: 2 issues found and fixed, some lint issues remain. See CI logs for details.


[error] 1-1: ruff-format hook failed: 1 file reformatted; 207 files left unchanged.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/generation/fp8.py (2)

156-160: Wiring the flag to the env var for FlashInfer MoE FP8 is fine, but ensure the config key exists in YAML.

The code reads use_flashinfer_moe_fp8; the exemplar YAML should expose it (default false) to avoid hidden defaults.

Please confirm the YAML change in examples/configs/grpo_math_qwen30ba3b_megatron_fp8.yaml includes use_flashinfer_moe_fp8 as proposed.


1-1: Run pre-commit and fix trailing-whitespace / ruff-format in nemo_rl/models/generation/fp8.py

CI reported trailing-whitespace and ruff-format edits for nemo_rl/models/generation/fp8.py. pre-commit is not available in the verification environment; run locally: pre-commit run -a and commit the auto-fixes (or run ruff --fix and strip trailing whitespace from the file).

@guyueh1 guyueh1 requested a review from a team as a code owner September 22, 2025 17:06
@guyueh1 guyueh1 requested a review from a team as a code owner September 23, 2025 20:17
parthchadha previously approved these changes Sep 23, 2025
@guyueh1 guyueh1 changed the title from "feat: FP8 rollout in GRPO for MoE models" to "feat: Bump to vllm 0.10.1.1 ray 2.48.0 and support FP8 rollout in GRPO for MoE models" on Sep 23, 2025
Comment thread pyproject.toml
"megatron-bridge",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
"vllm==0.10.0",
"vllm==0.10.1.1",
Contributor

@yuki-97 yuki-97 Sep 24, 2025


I suggest filing another PR to bump vllm to a newer version, and we should run some convergence tests to make sure the bump won't break anything.

Contributor Author


good idea

Contributor Author


We can work on this minor version bump in a separate PR (#1199), then rebase this PR after #1199 is merged.

Comment thread pyproject.toml
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved
"transformers>=4.51.0,<4.54.0",
"ray[default]==2.48.0",
"transformers",
Contributor


transformers will be updated to >=4.55.4 in #1115; is this version OK with the FP8 rollout feature?

Contributor Author


Yes, 4.55.4 is OK; I will rebase after #1115.

@guyueh1 guyueh1 changed the title from "feat: Bump to vllm 0.10.1.1 ray 2.48.0 and support FP8 rollout in GRPO for MoE models" to "feat: FP8 rollout in GRPO for MoE models" on Sep 24, 2025
Contributor Author

guyueh1 commented Oct 23, 2025

Closing this; will start a new PR based on #1334.

@guyueh1 guyueh1 closed this Oct 23, 2025