[sgl-kernel][Feat][B200][1/N] Support MXFP8 Grouped GEMM in Blackwell #13731
mickqian merged 10 commits into sgl-project:main
Conversation
Code Review
This pull request introduces support for CUTLASS-based MXFP8 Grouped GEMM on the Blackwell architecture, along with a corresponding group quantization kernel. The implementation is comprehensive and includes several performance optimizations. My review has identified a critical type-mismatch issue in the test code that should be addressed, as well as a high-severity issue related to incomplete input validation in one of the new kernels. I've also noted several medium-severity issues, such as typos in function names, incorrect error messages, and leftover TODO comments, which should be cleaned up to improve code quality and maintainability.
Review threads (resolved) on:
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
Awesome work! The design spirit is quite similar to the DeepSeek DSA Indexer, which uses a pre-compute kernel to avoid massive computation. Whereas the Indexer picks up the 2048 previous tokens to participate in the ongoing self-attention, making the compute linear in complexity, in this PR the "intelligent" pre-compute is used to dispatch the "real" grouped GEMM kernels by masking problem sizes. Moreover, you propagate the design to sm100 MXFP8. It's really a great design. Thanks a lot.
The basic design for Hopper follows the #11432 series.
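To make the masking idea above concrete, here is a minimal Python sketch (not taken from this PR or the #11432 series; the token-count threshold, the two-kernel split, and all names are assumptions for illustration). Each grouped GEMM sees the full expert list, but experts assigned to the other kernel get a zeroed M so they contribute no work:

```python
# Hypothetical sketch of "dispatch via masking problem sizes" (not the PR's actual code).
import torch


def split_problem_sizes(tokens_per_expert: torch.Tensor, n: int, k: int, threshold: int):
    """Build two (num_experts, 3) MNK problem-size tensors for two grouped GEMMs.

    An expert keeps its real M in exactly one of the two lists and gets M = 0 in the
    other, so each grouped GEMM only does real work for the experts assigned to it.
    """
    m = tokens_per_expert.to(torch.int32)
    num_experts = m.numel()
    nk = torch.tensor([n, k], dtype=torch.int32).repeat(num_experts, 1)

    small = m < threshold  # assumed dispatch criterion for illustration
    m_small = torch.where(small, m, torch.zeros_like(m))
    m_large = torch.where(small, torch.zeros_like(m), m)

    small_problems = torch.cat([m_small.unsqueeze(1), nk], dim=1)
    large_problems = torch.cat([m_large.unsqueeze(1), nk], dim=1)
    return small_problems, large_problems
```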
For sm100, SGLang has The kernel
Further review threads (resolved) on:
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
@yuan-luo That is a fantastic idea. I also believe that the choice of which kernel to use for an expert should not be based solely on the number of tokens, but rather on arithmetic intensity. Imagine a scenario where the N and K dimensions of an expert are very small: even with a large number of tokens, the overall arithmetic intensity of that expert is still very low, so it is memory bound. In that case, a kernel with higher TMA load efficiency should be chosen, rather than one designed for compute-bound workloads. I am still verifying this idea on SM90; if there is any progress, I will submit a PR for the optimization.
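A rough sketch of that arithmetic-intensity check (illustrative only; the FP8 peak-throughput and HBM-bandwidth defaults are assumed round numbers, not measured B200 values, and scale-factor traffic is ignored):

```python
# Roofline-style heuristic for one MXFP8 expert GEMM: A (m x k, fp8), B (n x k, fp8),
# D (m x n, bf16). Below the ridge point the expert is memory bound, so a kernel tuned
# for load efficiency can win even when the token count m is large.
def arithmetic_intensity(m: int, n: int, k: int) -> float:
    flops = 2 * m * n * k
    bytes_moved = m * k + n * k + 2 * m * n  # 1 B/elem fp8 inputs, 2 B/elem bf16 output
    return flops / bytes_moved


def is_memory_bound(m: int, n: int, k: int,
                    peak_fp8_tflops: float = 4500.0,  # assumed, not a measured value
                    hbm_tb_per_s: float = 8.0) -> bool:
    ridge = (peak_fp8_tflops * 1e12) / (hbm_tb_per_s * 1e12)
    return arithmetic_intensity(m, n, k) < ridge
```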
@yuan-luo This dispatch strategy has a very obvious problem.
Got it, this is one of the key designs in your PR. |
…sgl-project#13731) Co-authored-by: Yineng Zhang <me@zhyncs.com>
Motivation
This PR adds support for CUTLASS-based MXFP8 Grouped GEMM. In addition to the MXFP8 Grouped GEMM itself, it provides a Group Quant Kernel for computing the quantized inputs and scale factors. The Group Quant Kernel is implemented in C++ on top of CUTLASS CuTe and performs the quantization for all groups within a single kernel launch. Note that the scale factor layout is special: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
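For reference, here is a hedged sketch of the per-group MXFP8 numerics the Group Quant Kernel has to produce, assuming the standard MX recipe of 32-element blocks along K, one E8M0 (power-of-two) scale per block, and FP8 E4M3 elements; the swizzled scale-factor layout from the cuBLAS link above and this PR's exact rounding choice are not reproduced here:

```python
# Reference numerics only (not the PR's kernel); the rounding recipe is one common
# choice, not necessarily the one used in this PR.
import torch

FP8_E4M3_MAX = 448.0
MX_BLOCK = 32


def mxfp8_quant_reference(x: torch.Tensor):
    """x: (M, K) float tensor with K % 32 == 0. Returns (q_fp8, scale_exponents)."""
    m, k = x.shape
    blocks = x.float().reshape(m, k // MX_BLOCK, MX_BLOCK)
    amax = blocks.abs().amax(dim=-1).clamp(min=2.0 ** -126)
    # Round each block's scale up to a power of two so scaled elements fit in E4M3 range.
    exp = torch.ceil(torch.log2(amax / FP8_E4M3_MAX))
    scale = torch.exp2(exp)
    q = (blocks / scale.unsqueeze(-1)).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    q = q.reshape(m, k).to(torch.float8_e4m3fn)
    # An E8M0 scale-factor byte would store the biased exponent (exp + 127).
    return q, exp.to(torch.int32)
```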

In TensorRT-LLM, the scale factor is written to global memory using the STG.8 instruction, which is an inefficient way to store the data. In our implementation we therefore use the following three optimization techniques:
When there is sufficient data, our kernel can effectively utilize HBM bandwidth on B200 (up to 85%+):
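As a back-of-the-envelope illustration of how such an HBM-utilization number can be derived for the group quant kernel (the bf16 input dtype and the ~8 TB/s B200 HBM3e peak are assumptions for illustration, not values taken from this PR):

```python
# Achieved bandwidth = total bytes moved / kernel time, reported relative to an
# assumed peak. Scale-factor bytes assume one E8M0 byte per 32-element block.
def hbm_utilization(m: int, k: int, kernel_time_us: float, peak_tb_per_s: float = 8.0) -> float:
    bytes_read = m * k * 2              # bf16 activations read (assumed input dtype)
    bytes_written = m * k * 1           # fp8 quantized activations written
    bytes_scales = m * (k // 32) * 1    # one E8M0 scale byte per 32-element block
    total_bytes = bytes_read + bytes_written + bytes_scales
    achieved_tb_per_s = total_bytes / (kernel_time_us * 1e-6) / 1e12
    return achieved_tb_per_s / peak_tb_per_s
```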

Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist