2 changes: 2 additions & 0 deletions sgl-kernel/CMakeLists.txt
@@ -280,6 +280,8 @@ set(SOURCES
"csrc/elementwise/rope.cu"
"csrc/elementwise/topk.cu"
"csrc/expert_specialization/es_fp8_blockwise.cu"
"csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu"
"csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu"

"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
8 changes: 8 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
@@ -576,6 +576,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
"()");
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
m.def(
"es_sm100_mxfp8_blockscaled_grouped_mm(Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor d, Tensor "
"problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets) -> ()");
m.impl("es_sm100_mxfp8_blockscaled_grouped_mm", &es_sm100_mxfp8_blockscaled_grouped_mm);
m.def(
"es_sm100_mxfp8_blockscaled_grouped_quant(Tensor input, Tensor problem_sizes, Tensor expert_offsets, Tensor "
"blockscale_offsets, Tensor quant_output, Tensor scale_factor) -> () ");
m.impl("es_sm100_mxfp8_blockscaled_grouped_quant", &es_sm100_mxfp8_blockscaled_grouped_quant);

/*
* From fast-hadamard-transform
40 changes: 40 additions & 0 deletions sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
@@ -0,0 +1,40 @@
#include <torch/all.h>

#include "es_sm100_mxfp8_blockscaled_launcher.cuh"

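// Grouped (per-expert) MXFP8 block-scaled GEMM entry point for SM100 (Blackwell).
//   a   : (num_tokens, k) row-major activations, grouped by expert via expert_offsets
//   b   : (num_experts, k, n) column-major expert weights
//   sfa : per-32-element scale factors for a; blockscale_offsets gives each expert's
//         starting row in sfa
//   sfb : per-32-element scale factors for b, one contiguous block per expert
//   d   : (num_tokens, n) bf16/fp16 output
// Only shapes and layouts are checked here; the data and scale-factor dtypes are
// assumed to be validated further down in the dispatch path.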
void es_sm100_mxfp8_blockscaled_grouped_mm(
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& sfa,
const torch::Tensor& sfb,
torch::Tensor& d,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
TORCH_CHECK(b.dim() == 3, "b must be a 3D tensor of shape (num_experts, k, n)");
TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, "k must match between a and b and be a multiple of 128");
TORCH_CHECK(b.size(2) % 128 == 0, "n must be a multiple of 128");
TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
TORCH_CHECK(b.strides()[1] == 1, "b must be column major");

auto stream = at::cuda::getCurrentCUDAStream();
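// Dispatch on the requested output dtype; the underlying CUTLASS kernel is
// instantiated separately for bf16 and fp16 outputs.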
if (d.dtype() == torch::kBFloat16) {
expert_specialization::es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::bfloat16_t>(
a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets, stream);
} else if (d.dtype() == torch::kFloat16) {
expert_specialization::es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::half_t>(
a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets, stream);
} else {
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
TORCH_CHECK(false, "No implemented es_sm100_mxfp8_blockscaled_grouped_mm for current device");
#endif
}
@@ -0,0 +1,134 @@
#pragma once
#include <cuda.h>

#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "es_sm100_mxfp8_blockscaled_traits.cuh"

namespace expert_specialization {

using namespace cute;

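// Fills the per-expert base pointers for A, B, SFA, SFB and D that the CUTLASS grouped
// GEMM consumes. The pointer arrays (a_offsets ... d_offsets) are written on device by
// the precompute kernel at the bottom of this file.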
template <typename GemmTraits>
struct Sm100Mxfp8BlockScaledOffsetFunctor {
using Gemm = typename GemmTraits::Gemm;
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementSF = typename GemmTraits::ElementSF;
using ElementD = typename GemmTraits::ElementOutput;
// Input
int* expert_offsets{nullptr};
int* blockscale_offsets{nullptr};
// Output
ElementA* a_base{nullptr};
ElementB* b_base{nullptr};
ElementSF* sfa_base{nullptr};
ElementSF* sfb_base{nullptr};
ElementD* d_base{nullptr};
ElementA** a_offsets{nullptr};
ElementB** b_offsets{nullptr};
ElementSF** sfa_offsets{nullptr};
ElementSF** sfb_offsets{nullptr};
ElementD** d_offsets{nullptr};

Sm100Mxfp8BlockScaledOffsetFunctor() = default;
Sm100Mxfp8BlockScaledOffsetFunctor(
int* _expert_offsets,
int* _blockscale_offsets,
ElementA* _a_base,
ElementB* _b_base,
ElementSF* _sfa_base,
ElementSF* _sfb_base,
ElementD* _d_base,
ElementA** _a_offsets,
ElementB** _b_offsets,
ElementSF** _sfa_offsets,
ElementSF** _sfb_offsets,
ElementD** _d_offsets)
: expert_offsets{_expert_offsets},
blockscale_offsets{_blockscale_offsets},
a_base(_a_base),
b_base(_b_base),
sfa_base(_sfa_base),
sfb_base(_sfb_base),
d_base(_d_base),
a_offsets(_a_offsets),
b_offsets(_b_offsets),
sfa_offsets(_sfa_offsets),
sfb_offsets(_sfb_offsets),
d_offsets(_d_offsets) {}

void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
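// Each MXFP8 scale factor covers 32 consecutive elements along k, hence the k / 32
// factors below. A, D and SFA are offset by the expert's starting row (expert_offsets /
// blockscale_offsets); B and SFB are dense per-expert tensors indexed by expert_id.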
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[expert_id]);
int64_t a_stride = expert_offset * k;
int64_t b_stride = expert_id * k * n;
int64_t d_stride = expert_offset * n;
int64_t sfa_stride = blockscale_offset * (k / 32);
int64_t sfb_stride = expert_id * n * (k / 32);

a_offsets[expert_id] = a_base + a_stride;
b_offsets[expert_id] = b_base + b_stride;
sfa_offsets[expert_id] = sfa_base + sfa_stride;
sfb_offsets[expert_id] = sfb_base + sfb_stride;
d_offsets[expert_id] = d_base + d_stride;
}
};

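// Writes each expert's CUTLASS scale-factor layouts (LayoutSFA / LayoutSFB), derived
// from its (m, n, k) problem size via the Sm1xx block-scaled config.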
template <typename GemmTraits>
struct Sm100Mxfp8BlockScaledLayoutFunctor {
using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig;
using LayoutSFA = typename GemmTraits::LayoutSFA;
using LayoutSFB = typename GemmTraits::LayoutSFB;
LayoutSFA* layout_sfa_base{nullptr};
LayoutSFB* layout_sfb_base{nullptr};

Sm100Mxfp8BlockScaledLayoutFunctor() = default;
Sm100Mxfp8BlockScaledLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base)
: layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}

void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id;
*layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
*layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
}
};

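// Writes each expert's packed CUTE strides for A (m x k), B (n x k) and D (m x n).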
template <typename GemmTraits>
struct Sm100Mxfp8BlockScaledStrideFunctor {
using StrideA = typename GemmTraits::StrideA;
using StrideB = typename GemmTraits::StrideB;
using StrideD = typename GemmTraits::StrideD;
StrideA* stride_A_base{nullptr};
StrideB* stride_B_base{nullptr};
StrideD* stride_D_base{nullptr};

Sm100Mxfp8BlockScaledStrideFunctor() = default;
Sm100Mxfp8BlockScaledStrideFunctor(StrideA* _stride_A_base, StrideB* _stride_B_base, StrideD* _stride_D_base)
: stride_A_base(_stride_A_base), stride_B_base(_stride_B_base), stride_D_base(_stride_D_base) {}

void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
StrideA* stride_A = stride_A_base + expert_id;
StrideB* stride_B = stride_B_base + expert_id;
StrideD* stride_D = stride_D_base + expert_id;
*stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
*stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
*stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
}
};

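// Precompute kernel: one thread per expert. Thread i reads problem_sizes[i] = (m, n, k)
// and fills that expert's pointer, scale-factor layout and stride entries. There is no
// bounds check, so it is expected to be launched with exactly num_experts threads in a
// single block.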
template <typename OffsetFunctor, typename LayoutFunctor, typename StrideFunctor>
__global__ void sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel(
int* problem_sizes, OffsetFunctor offset_functor, LayoutFunctor layout_functor, StrideFunctor stride_functor) {
int64_t expert_id = static_cast<int64_t>(threadIdx.x);
int m = problem_sizes[expert_id * 3 + 0];
int n = problem_sizes[expert_id * 3 + 1];
int k = problem_sizes[expert_id * 3 + 2];

offset_functor(expert_id, m, n, k);
layout_functor(expert_id, m, n, k);
stride_functor(expert_id, m, n, k);
}

} // namespace expert_specialization
39 changes: 39 additions & 0 deletions sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
@@ -0,0 +1,39 @@
#include <torch/all.h>

#include "es_sm100_mxfp8_blockscaled_group_quant.cuh"

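// Grouped MXFP8 quantization for SM100: converts bf16/fp16 activations, grouped by
// expert via expert_offsets, into quantized values (quant_output) plus per-32-element
// scale factors (scale_factor); blockscale_offsets gives each expert's starting row in
// the scale-factor tensor. Output shapes and dtypes are assumed to be validated by the
// launcher in es_sm100_mxfp8_blockscaled_group_quant.cuh.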
void es_sm100_mxfp8_blockscaled_grouped_quant(
const torch::Tensor& input,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& blockscale_offsets,
torch::Tensor& quant_output,
torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
TORCH_CHECK(input.size(1) % 128 == 0, "k must be a multiple of 128");
TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");

auto groups = problem_sizes.size(0);
TORCH_CHECK(
expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
"expert_offsets must be 1D and have size equal to the number of groups");
TORCH_CHECK(
blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
"blockscale_offsets must be 1D and have size equal to the number of groups");

auto stream = at::cuda::getCurrentCUDAStream();
if (input.dtype() == torch::kBFloat16) {
expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__nv_bfloat16>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
} else if (input.dtype() == torch::kFloat16) {
expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__half>(
input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
} else {
TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
}
#else
TORCH_CHECK(false, "No implemented es_sm100_mxfp8_blockscaled_grouped_mm for current device");
#endif
}