
Guard fused Marlin MoE atomic add usage to prevent crashes #13262

Closed
0xymoro wants to merge 1 commit into sgl-project:main from 0xymoro:main

Conversation

@0xymoro (Contributor) commented Nov 14, 2025

Motivation

Marlin-quantized MoE layers crash because sgl_kernel.fused_moe always forces the experimental atomic-add reduction path. Large (M, N, K) shapes exceed the safe window for the Marlin kernel and trigger illegal memory accesses. We already have a proven heuristic elsewhere in the codebase; this PR reuses it so atomic add is only enabled when the shape/device combination is known to be safe.

Modifications

  • Import should_use_atomic_add_reduce inside sgl_kernel/python/sgl_kernel/fused_moe.py.
  • Evaluate the heuristic separately for the w1 and w2 GEMMs (the first stage uses (M, 2N, K), the second uses (M·topk, K, N)) and pass the result as use_atomic_add to each moe_wna16_marlin_gemm call; see the sketch after this list.
  • No changes to CUDA kernels; we simply stop invoking them with unsafe settings.
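
For illustration, here is a minimal sketch of the resulting guard. It assumes should_use_atomic_add_reduce keeps the (m, n, k, device, dtype) signature referenced in this PR; the wrapper name and import path are hypothetical, not the actual sgl_kernel code.

# Hypothetical wrapper around the heuristic; the real change calls
# should_use_atomic_add_reduce directly inside fused_marlin_moe.
# The import path below is an assumption.
from sglang.srt.layers.quantization.marlin_utils import should_use_atomic_add_reduce

def marlin_moe_atomic_add_flags(M, N, K, topk, device, dtype):
    # First Marlin GEMM runs the (M, 2N, K) problem.
    use_atomic_add_w1 = should_use_atomic_add_reduce(
        m=M, n=2 * N, k=K, device=device, dtype=dtype
    )
    # Second Marlin GEMM runs the (M * topk, K, N) problem.
    use_atomic_add_w2 = should_use_atomic_add_reduce(
        m=M * topk, n=K, k=N, device=device, dtype=dtype
    )
    # Each flag is passed as use_atomic_add to the matching moe_wna16_marlin_gemm call.
    return use_atomic_add_w1, use_atomic_add_w2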

@gemini-code-assist (Contributor)

Summary of Changes

Hello @0xymoro, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical stability issue within the Marlin-quantized Mixture-of-Experts (MoE) implementation. It introduces a mechanism to guard the usage of an experimental atomic-add reduction path, which previously caused crashes when invoked with large matrix shapes that exceeded safe operational limits. By integrating a proven heuristic, the system now intelligently determines when to enable this reduction path, ensuring robust and safe execution of MoE operations without altering the core CUDA kernels.

Highlights

  • Crash Prevention: Addresses crashes in Marlin-quantized MoE by preventing the use of an experimental atomic-add reduction path with unsafe large matrix shapes.
  • Heuristic Integration: Integrates an existing, proven heuristic (should_use_atomic_add_reduce) to dynamically determine when it is safe to use the atomic-add reduction.
  • Conditional Atomic Add: Applies the safety heuristic independently to both the w1 and w2 GEMM operations within the fused_marlin_moe function.
  • No Kernel Changes: The fix is implemented purely in Python logic, without modifying the underlying CUDA kernels.

@gemini-code-assist bot left a comment

Code Review

This pull request effectively addresses a crash in Marlin-quantized MoE layers by introducing a crucial safety check for the atomic-add reduction path. The change correctly replaces a simplistic condition with a more robust heuristic, should_use_atomic_add_reduce, to determine when it's safe to use atomic add. This check is thoughtfully applied to each of the two GEMM stages with their respective tensor shapes. The implementation is clean, correct, and directly solves the illegal memory access issue. The changes are excellent and I have no suggestions for improvement.

@FlamingoPg self-assigned this Nov 14, 2025

@FlamingoPg (Collaborator) commented:

Thanks, could you please help fix the lint?

@0xymoro (Contributor, Author) commented Nov 14, 2025

Actually, I'm investigating something else - this may not be the root cause. Turning this into a draft for now.

@0xymoro marked this pull request as draft November 14, 2025 18:40
@0xymoro (Contributor, Author) commented Nov 14, 2025

This may actually be the fix - still running some tests. It's challenging since it needs high batch sizes over the course of hours to see whether it nondeterministically crashes. Will update/report back if it works.

@0xymoro (Contributor, Author) commented Nov 15, 2025

I got to a point where the Marlin kernel overflowed with large batch sizes. I'm tracing through it, but it's a slog; I worked off Hopper, and yes, it is reproducible at large batch sizes.

Current best guess and attempted fix:

Root cause: the second Marlin GEMM (the moe_wna16_marlin_gemm call that multiplies expert outputs back into the model space) was receiving extremely large batches, e.g. M * topk = 131072 tokens with N = 7168, K = 256. The kernel relies on a workspace/lock buffer sized roughly by sms * 4, which was fine for moderate M but far too small for those huge routed batches. Once the workspace was exhausted, the kernel wrote past the end of the buffer, corrupting intermediate_cache3 before moe_sum_reduce ever ran. The later moe_sum_reduce/torch.sum change can't help because the memory was already trashed by the time the reduction ran.

Our fix: scale the workspace with the actual amount of routed work, max(2N, K) // 64 * (number of routed token blocks), and only fall back to sms * 4 if the computed size is small. That keeps the per-token lock/temporary storage available even when M * topk explodes, so the Marlin kernel never hits overlap/corruption in the first place. We left the CUDA binaries untouched; the Python glue now allocates enough space before calling them. A rough sketch of this sizing rule is below.
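
The sketch below assumes 64 tokens per routed block; marlin_moe_workspace_entries and block_size_m are hypothetical names, and the actual allocation in the Python glue may differ.

def marlin_moe_workspace_entries(M, N, K, topk, sms, block_size_m=64):
    # Number of routed token blocks handled by the second GEMM.
    routed_tokens = M * topk
    token_blocks = (routed_tokens + block_size_m - 1) // block_size_m
    # Scale the lock/temporary storage with the routed work instead of the
    # fixed sms * 4, and only fall back to sms * 4 when the scaled size is small.
    scaled = max(2 * N, K) // 64 * token_blocks
    return max(scaled, sms * 4)

With M * topk = 131072, N = 7168, K = 256, and block_size_m = 64 this gives 2048 token blocks and 14336 // 64 * 2048 = 458752 workspace entries, versus only a few hundred (sms * 4) before the fix.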

[2025-11-15 02:04:00 TP2] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP5] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP4] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP0] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP3] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP6] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP1] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP7] marlin stage2: M=131072 N=7168 K=256 use_atomic=False
[2025-11-15 02:04:00 TP3] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2725, in run_scheduler_process
scheduler.event_loop_overlap()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1019, in event_loop_overlap
batch_result = self.run_batch(batch)
^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2021, in run_batch
batch_result = self.model_worker.forward_batch_generation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 371, in forward_batch_generation
logits_output, can_run_cuda_graph = self.model_runner.forward(

@pdasgup (Contributor) commented Nov 18, 2025

Is there a query to reproduce this, or a fix for this issue? We are observing this crash (#13234) on 8xB200 with image 0.5.5.post3 for the Kimi K2 Thinking model.

@ispobock (Collaborator) commented:

@0xymoro Could you update the change based on #13596? We moved fused_marlin_moe out of sgl-kernel.

@0xymoro (Contributor, Author) commented Nov 20, 2025

I don't think this change is the crux of the fix, though (since in some tests 0.5.5.post2 with this change still crashed); it's likely something else. Still investigating.

@0xymoro closed this by deleting the head repository Jan 5, 2026