|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""Cutlass W4A8 MoE kernel.""" |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import torch |
| 6 | +from sgl_kernel import ( |
| 7 | + cutlass_w4a8_moe_mm, |
| 8 | + get_cutlass_w4a8_moe_mm_data, |
| 9 | + sgl_per_tensor_quant_fp8, |
| 10 | + silu_and_mul, |
| 11 | +) |
| 12 | + |
| 13 | +from sglang.srt.layers.moe.ep_moe.kernels import ( |
| 14 | + post_reorder_triton_kernel, |
| 15 | + pre_reorder_triton_kernel_for_cutlass_moe, |
| 16 | + run_cutlass_moe_ep_preproess, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +def cutlass_w4a8_moe( |
| 21 | + start_expert_id: int, |
| 22 | + end_expert_id: int, |
| 23 | + total_num_experts: int, |
| 24 | + a: torch.Tensor, |
| 25 | + w1_q: torch.Tensor, |
| 26 | + w2_q: torch.Tensor, |
| 27 | + w1_scale: torch.Tensor, |
| 28 | + w2_scale: torch.Tensor, |
| 29 | + topk_weights: torch.Tensor, |
| 30 | + topk_ids_: torch.Tensor, |
| 31 | + local_topk_ids: torch.Tensor, |
| 32 | + a_strides1: torch.Tensor, |
| 33 | + b_strides1: torch.Tensor, |
| 34 | + c_strides1: torch.Tensor, |
| 35 | + a_strides2: torch.Tensor, |
| 36 | + b_strides2: torch.Tensor, |
| 37 | + c_strides2: torch.Tensor, |
| 38 | + s_strides13: torch.Tensor, |
| 39 | + s_strides2: torch.Tensor, |
| 40 | + expert_offsets: torch.Tensor, |
| 41 | + problem_sizes1: torch.Tensor, |
| 42 | + problem_sizes2: torch.Tensor, |
| 43 | + a1_scale: Optional[torch.Tensor] = None, |
| 44 | + a2_scale: Optional[torch.Tensor] = None, |
| 45 | + apply_router_weight_on_input: bool = False, |
| 46 | +) -> torch.Tensor: |
| 47 | + """ |
| 48 | + This function computes a w4a8-quantized Mixture of Experts (MoE) layer |
| 49 | + using two sets of quantized weights, w1_q and w2_q, and top-k gating |
| 50 | + mechanism. The matrix multiplications are implemented with CUTLASS |
| 51 | + grouped gemm. |
| 52 | +
|
| 53 | + Parameters: |
| 54 | + - a (torch.Tensor): The input tensor to the MoE layer. |
| 55 | + Shape: [M, K] |
| 56 | + - w1_q (torch.Tensor): The first set of int4-quantized expert weights. |
| 57 | + Shape: [num_experts, N * 2, K // 2] |
| 58 | + (the weights are passed transposed and int4-packed) |
| 59 | + - w2_q (torch.Tensor): The second set of int4-quantized expert weights. |
| 60 | + Shape: [num_experts, K, N // 2] |
| 61 | + (the weights are passed transposed and int4-packed) |
| 62 | + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. |
| 63 | + Shape: [num_experts, K // 512, N * 8] |
| 64 | + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. |
| 65 | + Shape: [num_experts, N // 512, K * 4] |
| 66 | + - topk_weights (torch.Tensor): The weights of each token->expert mapping. |
| 67 | + - a_strides1 (torch.Tensor): The input strides of the first grouped gemm. |
| 68 | + - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm. |
| 69 | + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. |
| 70 | + - a_strides2 (torch.Tensor): The input strides of the second grouped gemm. |
| 71 | + - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm. |
| 72 | + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. |
| 73 | + - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm. |
| 74 | + - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm. |
| 75 | + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. |
| 76 | + Shape: scalar or [1, K] |
| 77 | + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to |
| 78 | + quantize the intermediate result between the gemms. |
| 79 | + Shape: scalar or [1, N] |
| 80 | + - apply_router_weight_on_input (bool): When true, the topk weights are |
| 81 | + applied directly on the inputs. This is only applicable when topk is 1. |
| 82 | +
|
| 83 | + Returns: |
| 84 | + - torch.Tensor: The fp8 output tensor after applying the MoE layer. |
| 85 | + """ |
| 86 | + assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" |
| 87 | + assert w1_q.dtype == torch.int8 |
| 88 | + assert w2_q.dtype == torch.int8 |
| 89 | + assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1" |
| 90 | + assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2" |
| 91 | + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" |
| 92 | + assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" |
| 93 | + assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" |
| 94 | + assert ( |
| 95 | + w1_scale.shape[1] == w1_q.shape[2] * 2 / 512 |
| 96 | + and w1_scale.shape[2] == w1_q.shape[1] * 4 |
| 97 | + ), "W1 scale shape mismatch" |
| 98 | + assert ( |
| 99 | + w2_scale.shape[1] == w2_q.shape[2] * 2 / 512 |
| 100 | + and w2_scale.shape[2] == w2_q.shape[1] * 4 |
| 101 | + ), "W2 scale shape mismatch" |
| 102 | + |
| 103 | + assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch" |
| 104 | + assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" |
| 105 | + assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" |
| 106 | + assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" |
| 107 | + num_experts = w1_q.size(0) |
| 108 | + m = a.size(0) |
| 109 | + k = w1_q.size(2) * 2 # w1_q is transposed and packed |
| 110 | + n = w2_q.size(2) * 2 # w2_q is transposed and packed |
| 111 | + topk = topk_ids_.size(1) |
| 112 | + |
| 113 | + if apply_router_weight_on_input: |
| 114 | + assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1" |
| 115 | + |
| 116 | + device = a.device |
| 117 | + |
| 118 | + _, src2dst, _ = run_cutlass_moe_ep_preproess( |
| 119 | + local_topk_ids, |
| 120 | + num_experts, |
| 121 | + ) |
| 122 | + |
| 123 | + gateup_input = torch.empty( |
| 124 | + (m * topk, k), |
| 125 | + device=device, |
| 126 | + dtype=torch.float8_e4m3fn, |
| 127 | + ) |
| 128 | + |
| 129 | + pre_reorder_triton_kernel_for_cutlass_moe[(m,)]( |
| 130 | + a, |
| 131 | + gateup_input, |
| 132 | + src2dst, |
| 133 | + local_topk_ids, |
| 134 | + a1_scale, |
| 135 | + total_num_experts, |
| 136 | + topk, |
| 137 | + k, |
| 138 | + BLOCK_SIZE=512, |
| 139 | + ) |
| 140 | + |
| 141 | + # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel, |
| 142 | + # they are kept to allow for a quick switch of the permutation logic |
| 143 | + # from the current triton kernel implementation to the cutlass-based one if needed. |
| 144 | + a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) |
| 145 | + c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) |
| 146 | + get_cutlass_w4a8_moe_mm_data( |
| 147 | + local_topk_ids, |
| 148 | + expert_offsets, |
| 149 | + problem_sizes1, |
| 150 | + problem_sizes2, |
| 151 | + a_map, |
| 152 | + c_map, |
| 153 | + num_experts, |
| 154 | + n, |
| 155 | + k, |
| 156 | + ) |
| 157 | + |
| 158 | + c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half) |
| 159 | + c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half) |
| 160 | + |
| 161 | + cutlass_w4a8_moe_mm( |
| 162 | + c1, |
| 163 | + gateup_input, |
| 164 | + w1_q, |
| 165 | + a1_scale.float(), |
| 166 | + w1_scale, |
| 167 | + expert_offsets[:-1], |
| 168 | + problem_sizes1, |
| 169 | + a_strides1, |
| 170 | + b_strides1, |
| 171 | + c_strides1, |
| 172 | + s_strides13, |
| 173 | + 128, |
| 174 | + topk, |
| 175 | + ) |
| 176 | + |
| 177 | + intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half) |
| 178 | + silu_and_mul(c1, intermediate) |
| 179 | + |
| 180 | + intermediate_q = torch.empty( |
| 181 | + intermediate.shape, dtype=torch.float8_e4m3fn, device=device |
| 182 | + ) |
| 183 | + sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True) |
| 184 | + |
| 185 | + cutlass_w4a8_moe_mm( |
| 186 | + c2, |
| 187 | + intermediate_q, |
| 188 | + w2_q, |
| 189 | + a2_scale.float(), |
| 190 | + w2_scale, |
| 191 | + expert_offsets[:-1], |
| 192 | + problem_sizes2, |
| 193 | + a_strides2, |
| 194 | + b_strides2, |
| 195 | + c_strides2, |
| 196 | + s_strides2, |
| 197 | + 128, |
| 198 | + topk, |
| 199 | + ) |
| 200 | + |
| 201 | + output = torch.empty_like(a) |
| 202 | + post_reorder_triton_kernel[(m,)]( |
| 203 | + c2, |
| 204 | + output, |
| 205 | + src2dst, |
| 206 | + topk_ids_, |
| 207 | + topk_weights, |
| 208 | + start_expert_id, |
| 209 | + end_expert_id, |
| 210 | + topk, |
| 211 | + k, |
| 212 | + 0, |
| 213 | + BLOCK_SIZE=512, |
| 214 | + ) |
| 215 | + return output |
0 commit comments