Skip to content

EP Support for Piecewise Cuda Graph#14164

Merged
ispobock merged 36 commits intosgl-project:mainfrom
Oasis-Git:ep-support
Dec 19, 2025
Merged

EP Support for Piecewise Cuda Graph#14164
ispobock merged 36 commits intosgl-project:mainfrom
Oasis-Git:ep-support

Conversation

@Oasis-Git
Copy link
Collaborator

@Oasis-Git Oasis-Git commented Nov 30, 2025

Motivation

Support EP for Piecewise Cuda Graph

Modifications

Following discussions with @ch-wan, @ispobock, and @BBuf, the Piecewise CUDA Graph wrapper for the MoE layer has been scoped to focus on Fused MoE and its derived classes. Currently, we primarily support standard top-k output, handling the necessary packing and unpacking operations for Torch library registration.

Also credit to: @byjiang1996

Accuracy Tests & Benchmark Output

For Deepseek-v2 model:

$ (EP=2 with pcg)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:14<00:00, 93.96it/s]
Accuracy: 0.379
Invalid: 0.005
Latency: 14.082 s
Output throughput: 11308.761 token/s
$ (TP=2 with pcg)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:14<00:00, 91.64it/s]
Accuracy: 0.381
Invalid: 0.005
Latency: 14.441 s
Output throughput: 11134.007 token/s

For gpt-oss 20B model:

$ (EP=2 with pcg)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:14<00:00, 93.96it/s]
Accuracy: 0.522
Invalid: 0.120
Latency: 35.192 s
Output throughput: 16782.753 token/s
$ (TP=2 with pcg)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:14<00:00, 91.64it/s]
Accuracy: 0.518
Invalid: 0.149
Latency: 33.683 s
Output throughput: 16732.154 token/s

For Qwen3-30B-A3B:

$ (EP=2 with pcg)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:14<00:00, 93.96it/s]
Accuracy: 0.917
Invalid: 0.000
Latency: 14.053 s
Output throughput: 10801.772 token/s
$ (TP=2 with pcg)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:14<00:00, 91.64it/s]
Accuracy: 0.917
Invalid: 0.000
Latency: 15.722 s
Output throughput: 9668.987 token/s

Checklist

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Oasis-Git, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the SGLang runtime by integrating Mixture-of-Experts (MoE) layers into the piecewise CUDA graph execution pipeline. The primary goal is to leverage CUDA graphs for improved performance of Fused MoE operations, with a focus on standard top-k output formats for compatibility. Additionally, the changes include a refactoring of MoE output handling for better type management and the introduction of NPU-specific optimizations for the Qwen3MoE model, broadening hardware support and efficiency.

Highlights

  • Piecewise CUDA Graph Support for MoE Layers: Introduced support for running Mixture-of-Experts (MoE) layers within a piecewise CUDA graph, specifically targeting Fused MoE and its derived classes. This involves modifying the forward pass of MoE layers to conditionally use a CUDA graph implementation.
  • Standard Top-K Output Enforcement: When operating within the piecewise CUDA graph context, the system now strictly enforces the use of the standard top-k output format for MoE layers, with assertions added to prevent non-standard formats.
  • Refactoring of TopKOutputFormat: The TopKOutputFormat enum has been refactored to an IntEnum, and type-checking methods in TopKOutputChecker were updated to use isinstance for improved type safety and clarity.
  • NPU Support for Qwen3MoE: Added NPU-specific optimizations and dispatch logic for the Qwen3MoE model's attention mechanism, including a new split_qkv_rmsnorm_rope kernel for NPU, and updated MoE backend checks to include is_ascend_fuseep().
  • Context Management for MoE Layers: The ForwardContext and set_forward_context mechanisms were extended to include MoE layers, allowing them to be properly managed and accessed within the piecewise CUDA graph execution flow.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Expert Parallelism in Piecewise CUDA Graphs, which is a significant feature for optimizing MoE models. The changes are mostly well-structured, touching upon the forward context, MoE layers, and model execution runners. My review highlights a critical bug that could lead to infinite recursion, a TypeError due to incorrect API usage, and several medium-severity maintainability issues related to code duplication and brittle logic. Addressing these points will improve the robustness and long-term health of the codebase.

Comment on lines 869 to +893
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if is_in_piecewise_cuda_graph():
assert TopKOutputChecker.format_is_standard(
topk_output
), "Only standard topk output is supported for piecewise cuda graph"
return torch.ops.sglang.moe_forward_piecewise_cuda_graph_impl(
hidden_states,
topk_output.topk_weights,
topk_output.topk_ids,
topk_output.router_logits,
self.layer_id,
)
else:
return self.forward_impl(hidden_states, topk_output)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This forward method is identical to the one in the base class FusedMoE. You can remove this duplicated implementation from FlashInferFusedMoE and FlashInferFP4MoE to improve maintainability. The logic will be inherited from FusedMoE.

    def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        origin_hidden_states_dim = hidden_states.shape[-1]
        assert self.quant_method is not None

        dispatch_output = self.dispatcher.dispatch(
            hidden_states=hidden_states, topk_output=topk_output
        )
        if _use_aiter and self.dispatcher.local_expert_mapping is not None:
            self.expert_mask_gpu = (
                (
                    (self.dispatcher.local_expert_mapping >= 0)
                    & (self.dispatcher.local_expert_mapping < self.num_local_experts)
                )
                .to(torch.int32)
                .to(device="cuda")
            )

        combine_input = self.run_moe_core(
            dispatch_output=dispatch_output,
        )

Comment on lines 1019 to +1058
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if is_in_piecewise_cuda_graph():
assert (
topk_output.format() == TopKOutputFormat.STANDARD
), "Only standard topk output is supported for piecewise cuda graph"
return torch.ops.sglang.moe_forward_piecewise_cuda_graph_impl(
hidden_states,
topk_output.topk_weights,
topk_output.topk_ids,
topk_output.router_logits,
self.layer_id,
)
else:
return self.forward_impl(hidden_states, topk_output)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This forward method is identical to the one in the base class FusedMoE. You can remove this duplicated implementation to improve maintainability. The logic will be inherited from FusedMoE.

Comment on lines 1093 to +1147
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if is_in_piecewise_cuda_graph():
assert (
topk_output.format() == TopKOutputFormat.STANDARD
), "Only standard topk output is supported for piecewise cuda graph"
return torch.ops.sglang.moe_forward_piecewise_cuda_graph_impl(
hidden_states,
topk_output.topk_weights,
topk_output.topk_ids,
topk_output.router_logits,
self.layer_id,
)
else:
return self.forward_impl(hidden_states, topk_output)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This forward method is identical to the one in the base class FusedMoE. You can remove this duplicated implementation to improve maintainability. The logic will be inherited from FusedMoE.

Comment on lines +365 to +374
moe_block = None
if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
moe_block = layer.mlp.experts
if hasattr(layer, "block_sparse_moe") and hasattr(
layer.block_sparse_moe, "experts"
):
moe_block = layer.block_sparse_moe.experts
if hasattr(layer, "moe") and hasattr(layer.moe, "experts"):
moe_block = layer.moe.experts
self.moe_layers.append(moe_block)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This series of if statements to find the moe_block can be brittle. If a layer happens to have more than one of these MoE attributes (e.g., both mlp.experts and block_sparse_moe.experts), the moe_block variable will be overwritten, and only the last one found will be used. Using if/elif/else would make the priority explicit and the code more robust.

Suggested change
moe_block = None
if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
moe_block = layer.mlp.experts
if hasattr(layer, "block_sparse_moe") and hasattr(
layer.block_sparse_moe, "experts"
):
moe_block = layer.block_sparse_moe.experts
if hasattr(layer, "moe") and hasattr(layer.moe, "experts"):
moe_block = layer.moe.experts
self.moe_layers.append(moe_block)
moe_block = None
if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
moe_block = layer.mlp.experts
elif hasattr(layer, "block_sparse_moe") and hasattr(
layer.block_sparse_moe, "experts"
):
moe_block = layer.block_sparse_moe.experts
elif hasattr(layer, "moe") and hasattr(layer.moe, "experts"):
moe_block = layer.moe.experts
self.moe_layers.append(moe_block)

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@ispobock
Copy link
Collaborator

/tag-and-rerun-ci

@ispobock
Copy link
Collaborator

@Oasis-Git Could you add a unit test?

Copy link
Collaborator

@ch-wan ch-wan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@Oasis-Git Oasis-Git changed the title [WIP]EP Support for Piecewise Cuda Graph EP Support for Piecewise Cuda Graph Dec 12, 2025
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@Oasis-Git
Copy link
Collaborator Author

/tag-and-rerun-ci

Oasis-Git and others added 19 commits December 13, 2025 19:43
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Ann <yuanlai444@gmail.com>
Signed-off-by: Ann <yuanlai444@gmail.com>
Signed-off-by: Ann <yuanlai444@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@ispobock ispobock merged commit 9d0347b into sgl-project:main Dec 19, 2025
224 of 249 checks passed
@Oasis-Git Oasis-Git deleted the ep-support branch December 20, 2025 23:59
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 23, 2025
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Comment on lines 63 to +66
_forward_context.set_forward_batch(forward_batch)
_forward_context.set_attention_layers(attention_layers)
_forward_context.set_quant_config(quant_config)
_forward_context.set_moe_layers(moe_layers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These setters are redundant. Can you just put them in the __init__ of ForwardContext?

hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
if is_in_piecewise_cuda_graph():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally, we put the comment branch as the first branch of the if/else statements to increase the code readability.
So in this case,
It should be

if not is_in_piecewise_cuda_graph():
    ...
else:
    ...

f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
)

def can_run_piecewise_cuda_graph(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put these checks to server_args.py?
We would like to put the checks as early as possible and make the printed ServerArgs reflect most things.

Comment on lines +270 to +271
with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
with set_compiled(True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can merge the two with in a single line.

YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants