|
| 1 | +#pragma once |
| 2 | +#include <cuda.h> |
| 3 | + |
| 4 | +#include "cute/tensor.hpp" |
| 5 | +#include "cutlass/util/packed_stride.hpp" |
| 6 | +#include "es_sm100_mxfp8_blockscaled_traits.cuh" |
| 7 | + |
| 8 | +namespace expert_specialization { |
| 9 | + |
| 10 | +using namespace cute; |
| 11 | + |
| 12 | +template <typename GemmTraits> |
| 13 | +struct Sm100Mxfp8BlockScaledOffsetFunctor { |
| 14 | + using Gemm = typename GemmTraits::Gemm; |
| 15 | + using ElementA = typename Gemm::ElementA; |
| 16 | + using ElementB = typename Gemm::ElementB; |
| 17 | + using ElementSF = typename GemmTraits::ElementSF; |
| 18 | + using ElementD = typename GemmTraits::ElementOutput; |
| 19 | + // Input |
| 20 | + int* expert_offsets{nullptr}; |
| 21 | + int* blockscale_offsets{nullptr}; |
| 22 | + // Output |
| 23 | + ElementA* a_base{nullptr}; |
| 24 | + ElementB* b_base{nullptr}; |
| 25 | + ElementSF* sfa_base{nullptr}; |
| 26 | + ElementSF* sfb_base{nullptr}; |
| 27 | + ElementD* d_base{nullptr}; |
| 28 | + ElementA** a_offsets{nullptr}; |
| 29 | + ElementB** b_offsets{nullptr}; |
| 30 | + ElementSF** sfa_offsets{nullptr}; |
| 31 | + ElementSF** sfb_offsets{nullptr}; |
| 32 | + ElementD** d_offsets{nullptr}; |
| 33 | + |
| 34 | + Sm100Mxfp8BlockScaledOffsetFunctor() = default; |
| 35 | + Sm100Mxfp8BlockScaledOffsetFunctor( |
| 36 | + int* _expert_offsets, |
| 37 | + int* _blockscale_offsets, |
| 38 | + ElementA* _a_base, |
| 39 | + ElementB* _b_base, |
| 40 | + ElementSF* _sfa_base, |
| 41 | + ElementSF* _sfb_base, |
| 42 | + ElementD* _d_base, |
| 43 | + ElementA** _a_offsets, |
| 44 | + ElementB** _b_offsets, |
| 45 | + ElementSF** _sfa_offsets, |
| 46 | + ElementSF** _sfb_offsets, |
| 47 | + ElementD** _d_offsets) |
| 48 | + : expert_offsets{_expert_offsets}, |
| 49 | + blockscale_offsets{_blockscale_offsets}, |
| 50 | + a_base(_a_base), |
| 51 | + b_base(_b_base), |
| 52 | + sfa_base(_sfa_base), |
| 53 | + sfb_base(_sfb_base), |
| 54 | + d_base(_d_base), |
| 55 | + a_offsets(_a_offsets), |
| 56 | + b_offsets(_b_offsets), |
| 57 | + sfa_offsets(_sfa_offsets), |
| 58 | + sfb_offsets(_sfb_offsets), |
| 59 | + d_offsets(_d_offsets) {} |
| 60 | + |
| 61 | + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { |
| 62 | + int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]); |
| 63 | + int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[expert_id]); |
| 64 | + int64_t a_stride = expert_offset * k; |
| 65 | + int64_t b_stride = expert_id * k * n; |
| 66 | + int64_t d_stride = expert_offset * n; |
| 67 | + int64_t sfa_stride = blockscale_offset * (k / 32); |
| 68 | + int64_t sfb_stride = expert_id * n * (k / 32); |
| 69 | + |
| 70 | + a_offsets[expert_id] = a_base + a_stride; |
| 71 | + b_offsets[expert_id] = b_base + b_stride; |
| 72 | + sfa_offsets[expert_id] = sfa_base + sfa_stride; |
| 73 | + sfb_offsets[expert_id] = sfb_base + sfb_stride; |
| 74 | + d_offsets[expert_id] = d_base + d_stride; |
| 75 | + } |
| 76 | +}; |
| 77 | + |
| 78 | +template <typename GemmTraits> |
| 79 | +struct Sm100Mxfp8BlockScaledLayoutFunctor { |
| 80 | + using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig; |
| 81 | + using LayoutSFA = typename GemmTraits::LayoutSFA; |
| 82 | + using LayoutSFB = typename GemmTraits::LayoutSFB; |
| 83 | + LayoutSFA* layout_sfa_base{nullptr}; |
| 84 | + LayoutSFB* layout_sfb_base{nullptr}; |
| 85 | + |
| 86 | + Sm100Mxfp8BlockScaledLayoutFunctor() = default; |
| 87 | + Sm100Mxfp8BlockScaledLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base) |
| 88 | + : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {} |
| 89 | + |
| 90 | + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { |
| 91 | + LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id; |
| 92 | + LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id; |
| 93 | + *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); |
| 94 | + *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); |
| 95 | + } |
| 96 | +}; |
| 97 | + |
| 98 | +template <typename GemmTraits> |
| 99 | +struct Sm100Mxfp8BlockScaledStrideFunctor { |
| 100 | + using StrideA = typename GemmTraits::StrideA; |
| 101 | + using StrideB = typename GemmTraits::StrideB; |
| 102 | + using StrideD = typename GemmTraits::StrideD; |
| 103 | + StrideA* stride_A_base{nullptr}; |
| 104 | + StrideB* stride_B_base{nullptr}; |
| 105 | + StrideD* stride_D_base{nullptr}; |
| 106 | + |
| 107 | + Sm100Mxfp8BlockScaledStrideFunctor() = default; |
| 108 | + Sm100Mxfp8BlockScaledStrideFunctor(StrideA* _stride_A_base, StrideB* _stride_B_base, StrideD* _stride_D_base) |
| 109 | + : stride_A_base(_stride_A_base), stride_B_base(_stride_B_base), stride_D_base(_stride_D_base) {} |
| 110 | + |
| 111 | + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { |
| 112 | + StrideA* stride_A = stride_A_base + expert_id; |
| 113 | + StrideB* stride_B = stride_B_base + expert_id; |
| 114 | + StrideD* stride_D = stride_D_base + expert_id; |
| 115 | + *stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); |
| 116 | + *stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); |
| 117 | + *stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); |
| 118 | + } |
| 119 | +}; |
| 120 | + |
| 121 | +template <typename OffsetFunctor, typename LayoutFunctor, typename StrideFunctor> |
| 122 | +__global__ void sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel( |
| 123 | + int* problem_sizes, OffsetFunctor offset_functor, LayoutFunctor layout_functor, StrideFunctor stride_functor) { |
| 124 | + int64_t expert_id = static_cast<int64_t>(threadIdx.x); |
| 125 | + int m = problem_sizes[expert_id * 3 + 0]; |
| 126 | + int n = problem_sizes[expert_id * 3 + 1]; |
| 127 | + int k = problem_sizes[expert_id * 3 + 2]; |
| 128 | + |
| 129 | + offset_functor(expert_id, m, n, k); |
| 130 | + layout_functor(expert_id, m, n, k); |
| 131 | + stride_functor(expert_id, m, n, k); |
| 132 | +} |
| 133 | + |
| 134 | +} // namespace expert_specialization |
0 commit comments