
Commit 5be67ea

HydraQYHzhyncs authored and committed
[sgl-kernel][Feat][B200][1/N] Support MXFP8 Grouped GEMM in Blackwell (sgl-project#13731)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
1 parent 85358c7 commit 5be67ea

12 files changed: +1174 -1 lines changed

sgl-kernel/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -280,6 +280,8 @@ set(SOURCES
     "csrc/elementwise/rope.cu"
     "csrc/elementwise/topk.cu"
     "csrc/expert_specialization/es_fp8_blockwise.cu"
+    "csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu"
+    "csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu"

     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
sgl-kernel/csrc/common_extension.cc

Lines changed: 8 additions & 0 deletions
@@ -576,6 +576,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
       "()");
   m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
+  m.def(
+      "es_sm100_mxfp8_blockscaled_grouped_mm(Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor d, Tensor "
+      "problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets) -> ()");
+  m.impl("es_sm100_mxfp8_blockscaled_grouped_mm", &es_sm100_mxfp8_blockscaled_grouped_mm);
+  m.def(
+      "es_sm100_mxfp8_blockscaled_grouped_quant(Tensor input, Tensor problem_sizes, Tensor expert_offsets, Tensor "
+      "blockscale_offsets, Tensor quant_output, Tensor scale_factor) -> ()");
+  m.impl("es_sm100_mxfp8_blockscaled_grouped_quant", &es_sm100_mxfp8_blockscaled_grouped_quant);

   /*
    * From fast-hadamard-transform
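
For reference, the two schema strings registered above bind to host-side functions with the signatures below, taken verbatim from the new source files later in this diff. On the C++ side the output tensors (d, quant_output, scale_factor) are passed as mutable references and written in place, while all other arguments are read-only; common_extension.cc is assumed to see these declarations via a shared header.

// Host-side declarations matching the registered schemas (restated from the .cu files below).
void es_sm100_mxfp8_blockscaled_grouped_mm(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& sfa,
    const torch::Tensor& sfb,
    torch::Tensor& d,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets);

void es_sm100_mxfp8_blockscaled_grouped_quant(
    const torch::Tensor& input,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets,
    torch::Tensor& quant_output,
    torch::Tensor& scale_factor);
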
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
#include <torch/all.h>

#include "es_sm100_mxfp8_blockscaled_launcher.cuh"

void es_sm100_mxfp8_blockscaled_grouped_mm(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& sfa,
    const torch::Tensor& sfb,
    torch::Tensor& d,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(
      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
  TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
  TORCH_CHECK(b.dim() == 3, "b must be a 3D tensor of shape (num_experts, k, n)");
  TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, "k must be a multiple of 128");
  TORCH_CHECK(b.size(2) % 128 == 0, "n must be a multiple of 128");
  TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
  TORCH_CHECK(b.strides()[1] == 1, "b must be column major");

  auto stream = at::cuda::getCurrentCUDAStream();
  if (d.dtype() == torch::kBFloat16) {
    expert_specialization::es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::bfloat16_t>(
        a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets, stream);
  } else if (d.dtype() == torch::kFloat16) {
    expert_specialization::es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::half_t>(
        a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets, stream);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false, "es_sm100_mxfp8_blockscaled_grouped_mm is not implemented for the current device");
#endif
}
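
The offset functor in the next file interprets expert_offsets[e] as the starting row of expert e in a and d, and blockscale_offsets[e] as the starting (padded) row of expert e in the activation scale-factor tensor. A minimal host-side sketch of how these metadata tensors might be assembled is below; the helper name and the 128-row padding rule for blockscale_offsets are assumptions (the authoritative layout lives in the traits and quantization headers, which this diff does not show).

// Hypothetical helper (not part of the commit) that builds the grouped-GEMM metadata.
// Assumptions: expert_offsets is the prefix sum of per-expert row counts, and
// blockscale_offsets is the same prefix sum with each count rounded up to 128 rows to
// match the SM100 scale-factor tile.
#include <vector>
#include <torch/all.h>

struct GroupedGemmMeta {
  torch::Tensor problem_sizes;       // (num_experts, 3) int32 rows of {m_e, n, k}
  torch::Tensor expert_offsets;      // (num_experts,) int32, row offset of expert e in a / d
  torch::Tensor blockscale_offsets;  // (num_experts,) int32, padded row offset of expert e in sfa
};

inline GroupedGemmMeta make_grouped_gemm_meta(const std::vector<int>& tokens_per_expert, int n, int k) {
  const int64_t num_experts = static_cast<int64_t>(tokens_per_expert.size());
  const auto opts = torch::TensorOptions().dtype(torch::kInt32);
  GroupedGemmMeta meta{
      torch::empty({num_experts, 3}, opts),
      torch::empty({num_experts}, opts),
      torch::empty({num_experts}, opts)};
  auto ps = meta.problem_sizes.accessor<int32_t, 2>();
  auto eo = meta.expert_offsets.accessor<int32_t, 1>();
  auto bo = meta.blockscale_offsets.accessor<int32_t, 1>();
  int row = 0, padded_row = 0;
  for (int64_t e = 0; e < num_experts; ++e) {
    const int m = tokens_per_expert[e];
    ps[e][0] = m;  // rows of expert e in a
    ps[e][1] = n;  // output columns; per-expert weight is (k, n)
    ps[e][2] = k;  // reduction dim, must be a multiple of 128
    eo[e] = row;
    bo[e] = padded_row;
    row += m;
    padded_row += (m + 127) / 128 * 128;  // assumed 128-row scale-factor padding
  }
  return meta;  // move each tensor to the GPU (e.g. .cuda()) before calling the ops
}
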
Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@
#pragma once
#include <cuda.h>

#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "es_sm100_mxfp8_blockscaled_traits.cuh"

namespace expert_specialization {

using namespace cute;

template <typename GemmTraits>
struct Sm100Mxfp8BlockScaledOffsetFunctor {
  using Gemm = typename GemmTraits::Gemm;
  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementSF = typename GemmTraits::ElementSF;
  using ElementD = typename GemmTraits::ElementOutput;
  // Input
  int* expert_offsets{nullptr};
  int* blockscale_offsets{nullptr};
  // Output
  ElementA* a_base{nullptr};
  ElementB* b_base{nullptr};
  ElementSF* sfa_base{nullptr};
  ElementSF* sfb_base{nullptr};
  ElementD* d_base{nullptr};
  ElementA** a_offsets{nullptr};
  ElementB** b_offsets{nullptr};
  ElementSF** sfa_offsets{nullptr};
  ElementSF** sfb_offsets{nullptr};
  ElementD** d_offsets{nullptr};

  Sm100Mxfp8BlockScaledOffsetFunctor() = default;
  Sm100Mxfp8BlockScaledOffsetFunctor(
      int* _expert_offsets,
      int* _blockscale_offsets,
      ElementA* _a_base,
      ElementB* _b_base,
      ElementSF* _sfa_base,
      ElementSF* _sfb_base,
      ElementD* _d_base,
      ElementA** _a_offsets,
      ElementB** _b_offsets,
      ElementSF** _sfa_offsets,
      ElementSF** _sfb_offsets,
      ElementD** _d_offsets)
      : expert_offsets{_expert_offsets},
        blockscale_offsets{_blockscale_offsets},
        a_base(_a_base),
        b_base(_b_base),
        sfa_base(_sfa_base),
        sfb_base(_sfb_base),
        d_base(_d_base),
        a_offsets(_a_offsets),
        b_offsets(_b_offsets),
        sfa_offsets(_sfa_offsets),
        sfb_offsets(_sfb_offsets),
        d_offsets(_d_offsets) {}

  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
    int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[expert_id]);
    int64_t a_stride = expert_offset * k;
    int64_t b_stride = expert_id * k * n;
    int64_t d_stride = expert_offset * n;
    int64_t sfa_stride = blockscale_offset * (k / 32);
    int64_t sfb_stride = expert_id * n * (k / 32);

    a_offsets[expert_id] = a_base + a_stride;
    b_offsets[expert_id] = b_base + b_stride;
    sfa_offsets[expert_id] = sfa_base + sfa_stride;
    sfb_offsets[expert_id] = sfb_base + sfb_stride;
    d_offsets[expert_id] = d_base + d_stride;
  }
};

template <typename GemmTraits>
struct Sm100Mxfp8BlockScaledLayoutFunctor {
  using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig;
  using LayoutSFA = typename GemmTraits::LayoutSFA;
  using LayoutSFB = typename GemmTraits::LayoutSFB;
  LayoutSFA* layout_sfa_base{nullptr};
  LayoutSFB* layout_sfb_base{nullptr};

  Sm100Mxfp8BlockScaledLayoutFunctor() = default;
  Sm100Mxfp8BlockScaledLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base)
      : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}

  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id;
    LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id;
    *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
    *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
  }
};

template <typename GemmTraits>
struct Sm100Mxfp8BlockScaledStrideFunctor {
  using StrideA = typename GemmTraits::StrideA;
  using StrideB = typename GemmTraits::StrideB;
  using StrideD = typename GemmTraits::StrideD;
  StrideA* stride_A_base{nullptr};
  StrideB* stride_B_base{nullptr};
  StrideD* stride_D_base{nullptr};

  Sm100Mxfp8BlockScaledStrideFunctor() = default;
  Sm100Mxfp8BlockScaledStrideFunctor(StrideA* _stride_A_base, StrideB* _stride_B_base, StrideD* _stride_D_base)
      : stride_A_base(_stride_A_base), stride_B_base(_stride_B_base), stride_D_base(_stride_D_base) {}

  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    StrideA* stride_A = stride_A_base + expert_id;
    StrideB* stride_B = stride_B_base + expert_id;
    StrideD* stride_D = stride_D_base + expert_id;
    *stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
    *stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
    *stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
  }
};

template <typename OffsetFunctor, typename LayoutFunctor, typename StrideFunctor>
__global__ void sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel(
    int* problem_sizes, OffsetFunctor offset_functor, LayoutFunctor layout_functor, StrideFunctor stride_functor) {
  int64_t expert_id = static_cast<int64_t>(threadIdx.x);
  int m = problem_sizes[expert_id * 3 + 0];
  int n = problem_sizes[expert_id * 3 + 1];
  int k = problem_sizes[expert_id * 3 + 2];

  offset_functor(expert_id, m, n, k);
  layout_functor(expert_id, m, n, k);
  stride_functor(expert_id, m, n, k);
}

}  // namespace expert_specialization
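
The k/32 factors in the offset functor reflect MXFP8's microscaling layout: one E8M0 scale factor per 32 elements along k, so for k = 4096 each row of a carries 128 scale factors. The precompute kernel indexes problem_sizes only by threadIdx.x, so it is presumably launched with a single block of one thread per expert; the actual launch site lives in the dispatch code, which this diff does not show. A hedged sketch of such a launch:

// Hypothetical launch wrapper (not part of the commit); assumes the header above and the
// CUDA runtime are included. Each thread handles one expert: it writes that expert's
// A/B/SFA/SFB/D pointers, scale-factor layouts, and strides into the device-side
// argument arrays consumed by the grouped GEMM.
template <typename OffsetFunctor, typename LayoutFunctor, typename StrideFunctor>
void launch_precompute_sketch(
    int* problem_sizes,  // device pointer to (num_experts, 3) row-major int32
    int num_experts,
    OffsetFunctor offset_functor,
    LayoutFunctor layout_functor,
    StrideFunctor stride_functor,
    cudaStream_t stream) {
  expert_specialization::sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      problem_sizes, offset_functor, layout_functor, stride_functor);
}
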
sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
#include <torch/all.h>

#include "es_sm100_mxfp8_blockscaled_group_quant.cuh"

void es_sm100_mxfp8_blockscaled_grouped_quant(
    const torch::Tensor& input,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets,
    torch::Tensor& quant_output,
    torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
  TORCH_CHECK(input.size(1) % 128 == 0, "k must be a multiple of 128");
  TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");

  auto groups = problem_sizes.size(0);
  TORCH_CHECK(
      expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
      "expert_offsets must be 1D and have size equal to the number of groups");
  TORCH_CHECK(
      blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
      "blockscale_offsets must be 1D and have size equal to the number of groups");

  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.dtype() == torch::kBFloat16) {
    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__nv_bfloat16>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
  } else if (input.dtype() == torch::kFloat16) {
    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__half>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false, "es_sm100_mxfp8_blockscaled_grouped_quant is not implemented for the current device");
#endif
}
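
Taken together, the two ops presumably compose per MoE layer as: quantize the expert-grouped activations to MXFP8 plus per-32-element scale factors, then run the block-scaled grouped GEMM against pre-quantized weights. A hedged sketch follows; the weight tensors b_fp8 and sfb and all preallocation shapes are assumptions, since the commit itself only defines the two kernels and their argument checks.

// Hypothetical composition (not from the commit); assumes the two host functions above are
// declared in scope and that all tensors were preallocated on the GPU with matching shapes.
void run_mxfp8_grouped_gemm_sketch(
    const torch::Tensor& a_bf16,              // (num_tokens, k), bf16/fp16, rows grouped by expert
    const torch::Tensor& b_fp8,               // (num_experts, k, n), MXFP8 weights
    const torch::Tensor& sfb,                 // weight scale factors
    const torch::Tensor& problem_sizes,       // (num_experts, 3) int32
    const torch::Tensor& expert_offsets,      // (num_experts,) int32
    const torch::Tensor& blockscale_offsets,  // (num_experts,) int32
    torch::Tensor& a_fp8,                     // preallocated MXFP8 activations
    torch::Tensor& sfa,                       // preallocated activation scale factors
    torch::Tensor& d) {                       // (num_tokens, n) bf16/fp16 output
  // Step 1: per-expert MXFP8 quantization of the activations.
  es_sm100_mxfp8_blockscaled_grouped_quant(
      a_bf16, problem_sizes, expert_offsets, blockscale_offsets, a_fp8, sfa);
  // Step 2: block-scaled grouped GEMM over all experts in one launch.
  es_sm100_mxfp8_blockscaled_grouped_mm(
      a_fp8, b_fp8, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets);
}
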
