
[sgl-kernel][Feat][B200][1/N] Support MXFP8 Grouped GEMM in Blackwell #13731

Merged
mickqian merged 10 commits into sgl-project:main from HydraQYH:dev_support_mxfp8_grouped_gemm on Dec 4, 2025
Conversation

@HydraQYH (Collaborator) commented Nov 21, 2025

Motivation

This PR adds CUTLASS-based MXFP8 Grouped GEMM. In addition to the Grouped GEMM itself, it provides a Grouped Quant Kernel that computes the quantized inputs and scale factors; the kernel is implemented in C++ on top of CUTLASS CuTe and quantizes all groups in a single launch (a reference sketch of the quantization math follows at the end of this section). Note that the scale factor layout is special: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
[Figure: cuBLAS 1D block scaling factor layout]
In TensorRT-LLM, the scale factors are written to global memory with the STG.8 instruction, which is an inefficient way to store them. Our implementation therefore uses the following three optimization techniques:

  1. 256bit Load
  2. 100% Occupancy
  3. Overlapping with TMA STORE

With sufficient data, the kernel makes effective use of HBM bandwidth on B200 (up to 85%+ of peak):
[Figure: achieved HBM bandwidth on B200]
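
To make the quantization step concrete, here is a minimal reference sketch (not this PR's kernel) of MXFP8 block-scaled quantization, assuming FP8 E4M3 elements with one power-of-two (UE8M0) scale per 32-element block; the swizzled scale-factor layout from the cuBLAS link and the 256-bit/TMA store optimizations are deliberately omitted, and the scale choice is one common convention rather than necessarily the kernel's.

  # Reference sketch only: per-row, 32-element-block MXFP8 quantization in PyTorch.
  import torch

  FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn
  BLOCK = 32            # MX block size

  def mxfp8_quant_ref(x: torch.Tensor):
      """x: [M, K] float tensor with K % 32 == 0.
      Returns (q, sf): q in float8_e4m3fn, sf one uint8 biased exponent per block."""
      M, K = x.shape
      xb = x.float().reshape(M, K // BLOCK, BLOCK)
      amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
      # Power-of-two scale chosen so every element of the block fits into E4M3.
      exp = torch.ceil(torch.log2(amax / FP8_E4M3_MAX))
      scale = torch.exp2(exp)
      q = (xb / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
      sf = (exp.squeeze(-1) + 127).to(torch.uint8)  # UE8M0 scale, row-major (not swizzled)
      return q.reshape(M, K), sf

A grouped version runs this math once per expert segment; the point of the Grouped Quant Kernel is to handle all segments in one launch and to emit sf directly in the swizzled layout the tensor cores expect.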

Modifications

  • sgl-kernel/CMakeLists.txt
  • sgl-kernel/include/sgl_kernel_ops.h
  • sgl-kernel/csrc/common_extension.cc
  • sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh
  • sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
  • sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
  • sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
  • sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh
  • sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
  • sgl-kernel/python/sgl_kernel/__init__.py
  • sgl-kernel/python/sgl_kernel/expert_specialization.py
  • sgl-kernel/tests/test_es_mxfp8_blockscaled_moe.py

Accuracy Tests

[Screenshot: mxfp8_unitest results]

Benchmarking and Profiling

Checklist

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @HydraQYH, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the sgl-kernel library by adding specialized support for MXFP8 Grouped GEMM and an optimized Group Quantization Kernel tailored for NVIDIA's Blackwell architecture. The primary goal is to accelerate mixed-precision operations in deep learning models by leveraging advanced hardware capabilities and efficient data handling, ensuring high performance and memory bandwidth utilization.

Highlights

  • MXFP8 Grouped GEMM Support: Introduces support for CUTLASS-based MXFP8 Grouped General Matrix Multiplication (GEMM) specifically for Blackwell (SM100) architecture, enabling highly efficient mixed-precision computations.
  • Group Quantization Kernel: Adds a new Group Quant Kernel, implemented in C++ using CUTLASS CuTe, which calculates quantization inputs and scale factors for all groups within a single kernel, optimizing the quantization process.
  • Performance Optimizations: Incorporates several optimization techniques for scale factor storage and processing, including 256-bit loads, achieving 100% occupancy, and overlapping with TMA STORE, leading to efficient HBM bandwidth utilization (up to 85%+).
  • PyTorch Integration and Testing: Integrates the new MXFP8 grouped GEMM and quantization kernels into the PyTorch extension, making them accessible via torch.ops.sgl_kernel, and includes comprehensive unit tests to verify accuracy on SM100 devices.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for CUTLASS-based MXFP8 Grouped GEMM on the Blackwell architecture, along with a corresponding group quantization kernel. The implementation is comprehensive and includes several performance optimizations. My review has identified a critical type-mismatch issue in the test code that should be addressed, as well as a high-severity issue related to incomplete input validation in one of the new kernels. I've also noted several medium-severity issues, such as typos in function names, incorrect error messages, and leftover TODO comments, which should be cleaned up to improve code quality and maintainability.

@HydraQYH force-pushed the dev_support_mxfp8_grouped_gemm branch from fe5b209 to a9e5243 on November 21, 2025, 14:35
@BBuf (Collaborator) left a comment


Great job! We could also improve the HBM bandwidth of many memory-bound kernels with 256-bit LDG/STS on B200.

@yuan-luo (Collaborator) commented Nov 22, 2025

Awesome work! The design spirit is quite similar to the DeepSeek DSA Indexer: a pre-compute kernel is used to avoid massive computation. Whereas the Indexer picks 2048 previous tokens to participate in the ongoing self-attention, making the compute linear in complexity, in this PR the "intelligent" pre-compute is used to dispatch the "real" grouped GEMM kernel by masking the problem sizes. Moreover, you propagate the design to SM100 MXFP8. It's really a great design. Thanks a lot.
One comment, which may not be mature: is it possible to introduce a pre-trained pre-compute kernel (something like the Indexer weights in DSv3.2) to mask the problem sizes better, instead of deciding by token count? That would be pretty cool.

@yuan-luo (Collaborator) commented

The basic design for Hopper follows the #11432 series.

@yuan-luo (Collaborator) commented Nov 23, 2025

For sm100, SGLang has sm100_fp8_blockwise_group_mm_dispatch_shape to do dispatch based on shape. The dispatch principle seems similar to this PR.

  if (a.size(0) <= 2048 && a.size(1) >= 2048) {
    run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
        expert_offsets,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        b_t,
        a_t,
        output_t,
        scales_b_t,
        scales_a_t,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        problem_sizes_transpose,
        true);
    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
        out_ptrs,
        a_ptrs,
        b_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes_transpose,
        expert_offsets,
        workspace);
    output = output_t.t();
  } else if (a.size(0) > 2048 && a.size(1) >= 2048) {
  ......
     // Dispatch to another kernel with another set of MmaConfig/LayoutSF.
  }

The launch_es_sm100_mxfp8_blockscaled_grouped_quant kernel, which does the grouped MXFP8 quantization with CuTe, is awesome! It is really an excellent example of CuTe.

@HydraQYH (Collaborator, Author) commented


@yuan-luo That is a fantastic idea. I also believe that the choice of kernel for an expert should not be based solely on the number of tokens, but rather on arithmetic intensity. Imagine a scenario where the N and K dimensions of an expert are very small: even with a large number of tokens, the overall arithmetic intensity of that expert is still very low, placing it in the memory-bound regime. In this case, a kernel with higher TMA load efficiency should be chosen, rather than a kernel designed for the compute-bound regime. I am still verifying this idea on SM90; if there is any progress, I will submit a PR for the optimization.
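
As a rough illustration of the arithmetic-intensity idea (not code from this PR; the byte counts and the ridge-point threshold are placeholders), the dispatch decision could look like this:

  # Hypothetical sketch: classify one expert's GEMM (m tokens, weight [k, n])
  # as memory- or compute-bound via arithmetic intensity instead of m alone.
  def arithmetic_intensity(m: int, n: int, k: int,
                           in_bytes: int = 1, out_bytes: int = 2) -> float:
      flops = 2.0 * m * n * k
      # FP8 A and B (1 byte/elem) plus a BF16 output (2 bytes/elem); scale factors ignored.
      bytes_moved = (m * k + n * k) * in_bytes + m * n * out_bytes
      return flops / bytes_moved

  def pick_kernel(m: int, n: int, k: int, ridge_point: float = 300.0) -> str:
      # ridge_point ~ peak FLOP/s divided by peak HBM bandwidth of the target GPU (placeholder).
      if arithmetic_intensity(m, n, k) < ridge_point:
          return "memory_bound_tile"   # favor TMA load efficiency
      return "compute_bound_tile"      # favor large MMA tiles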

@HydraQYH (Collaborator, Author) commented


@yuan-luo This dispatch strategy has a very obvious problem: a.size(0) <= 2048 means it dispatches solely on the total number of tokens. It does not even infer the average number of tokens processed by each expert from the number of experts. In #11432, when the batch size is between 512 and 1024, the total number of tokens is already greater than 2048, yet on average each expert only needs to process 16 to 32 tokens; a kernel with M=128 was still chosen, resulting in a lot of redundant computation.
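
For reference, a tiny illustration of the averaging argument (the expert count and top-k below are assumed, DeepSeek-V3-style values chosen only to reproduce the 16-32 figure above):

  # Average routed assignments per expert, assuming an even spread across experts.
  def avg_tokens_per_expert(num_tokens: int, top_k: int, num_experts: int) -> float:
      return num_tokens * top_k / num_experts

  avg_tokens_per_expert(512, 8, 256)   # -> 16.0  (assumed: 256 experts, top_k = 8)
  avg_tokens_per_expert(1024, 8, 256)  # -> 32.0, both far below an M = 128 tile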

@yuan-luo (Collaborator) commented


Got it, this is one of the key designs in your PR.

@FlamingoPg mentioned this pull request on Nov 25, 2025
@mickqian merged commit 16ff892 into sgl-project:main on Dec 4, 2025
135 of 140 checks passed