Commit cb9d91e

feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode (#7762)
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
1 parent 6a6e0bb commit cb9d91e

10 files changed, +1006 −9 lines

python/sglang/srt/configs/model_config.py

Lines changed: 12 additions & 1 deletion
@@ -359,7 +359,17 @@ def _parse_quant_hf_config(self):
         if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
             quant_cfg = modelopt_quant_config
         elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
-            quant_cfg = modelopt_quant_config
+            quant_config_file = os.path.join(
+                self.model_path, "hf_quant_config.json"
+            )
+            with open(quant_config_file) as f:
+                quant_config_dict = json.load(f)
+            json_quant_configs = quant_config_dict["quantization"]
+            quant_algo = json_quant_configs.get("quant_algo", None)
+            if quant_algo == "MIXED_PRECISION":
+                quant_cfg = {"quant_method": "w4afp8"}
+            else:
+                quant_cfg = modelopt_quant_config
         return quant_cfg
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
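
For reference, the new branch keys entirely off the quantization.quant_algo field of hf_quant_config.json. A minimal sketch of the lookup, with hypothetical config contents invented for illustration (only the two keys the code above actually reads are shown):

import json

# Hypothetical minimal hf_quant_config.json contents; the key names match
# what _parse_quant_hf_config reads above.
sample = json.loads('{"quantization": {"quant_algo": "MIXED_PRECISION"}}')

quant_algo = sample["quantization"].get("quant_algo", None)
quant_cfg = {"quant_method": "w4afp8"} if quant_algo == "MIXED_PRECISION" else None
print(quant_cfg)  # -> {'quant_method': 'w4afp8'}

Any other quant_algo value falls back to the ModelOpt config, as before.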
@@ -389,6 +399,7 @@ def _verify_quantization(self) -> None:
             "w8a8_fp8",
             "moe_wna16",
             "qoq",
+            "w4afp8",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
"""Cutlass W4A8 MoE kernel."""
from typing import Optional

import torch
from sgl_kernel import (
    cutlass_w4a8_moe_mm,
    get_cutlass_w4a8_moe_mm_data,
    sgl_per_tensor_quant_fp8,
    silu_and_mul,
)

from sglang.srt.layers.moe.ep_moe.kernels import (
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel_for_cutlass_moe,
    run_cutlass_moe_ep_preproess,
)


def cutlass_w4a8_moe(
    start_expert_id: int,
    end_expert_id: int,
    total_num_experts: int,
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    local_topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [1, K]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [1, N]
    - apply_router_weight_on_input (bool): When true, the topk weights are
        applied directly on the inputs. This is only applicable when topk is 1.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert (
        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
        and w1_scale.shape[2] == w1_q.shape[1] * 4
    ), "W1 scale shape mismatch"
    assert (
        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
        and w2_scale.shape[2] == w2_q.shape[1] * 4
    ), "W2 scale shape mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)

    if apply_router_weight_on_input:
        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"

    device = a.device

    _, src2dst, _ = run_cutlass_moe_ep_preproess(
        local_topk_ids,
        num_experts,
    )

    gateup_input = torch.empty(
        (m * topk, k),
        device=device,
        dtype=torch.float8_e4m3fn,
    )

    # Scatter tokens into expert-grouped order, quantizing to fp8 on the fly.
    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
        a,
        gateup_input,
        src2dst,
        local_topk_ids,
        a1_scale,
        total_num_experts,
        topk,
        k,
        BLOCK_SIZE=512,
    )

    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel;
    # they are kept to allow for a quick switch of the permutation logic
    # from the current triton kernel implementation to the cutlass-based one if needed.
    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)

    # First grouped gemm: fused gate/up projection into c1 = [gate | up].
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale.float(),
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )

    # intermediate = SiLU(gate) * up, then re-quantize to fp8 for the second gemm.
    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)

    # Second grouped gemm: down projection.
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale.float(),
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    # Gather expert outputs back to token order, weighting by topk_weights.
    output = torch.empty_like(a)
    post_reorder_triton_kernel[(m,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        k,
        0,
        BLOCK_SIZE=512,
    )
    return output
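
To make the packed layouts concrete, here is a standalone sketch of the shape relations the assertions above encode, with toy sizes assumed purely for illustration (none of these values come from the commit). Two int4 values are packed per int8 byte, so the last weight dimension is half of the logical size:

import torch

# Assumed toy sizes: 4 experts, hidden size K=1024, intermediate size N=512, 8 tokens.
E, K, N, M = 4, 1024, 512, 8

a = torch.randn(M, K)                                   # activations [M, K]
w1_q = torch.zeros(E, N * 2, K // 2, dtype=torch.int8)  # gate+up weights, int4-packed
w2_q = torch.zeros(E, K, N // 2, dtype=torch.int8)      # down weights, int4-packed
w1_scale = torch.zeros(E, K // 512, N * 8)              # fp32 group scales for w1_q
w2_scale = torch.zeros(E, N // 512, K * 4)              # fp32 group scales for w2_q

# The same relations cutlass_w4a8_moe asserts:
assert a.shape[1] // 2 == w1_q.shape[2]         # K // 2 packed bytes per row
assert w1_q.shape[2] * 2 == w2_q.shape[1]       # K recovered from the packing
assert w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
assert w1_scale.shape[2] == w1_q.shape[1] * 4
k = w1_q.size(2) * 2  # 1024: unpacked hidden size
n = w2_q.size(2) * 2  # 512: unpacked intermediate size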

python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 58 additions & 0 deletions
@@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
 
 
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
@@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
     return reorder_topk_ids, src2dst, seg_indptr
 
 
+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
+    return reorder_topk_ids, src2dst, seg_indptr
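
The src2dst mapping this helper produces can be reproduced with a plain-PyTorch reference (a sanity-check sketch, not part of the commit): it is simply the inverse of the stable-sort permutation, sending each flattened (token, slot) index to its row in the expert-grouped layout.

import torch

local_topk_ids = torch.tensor([[2, 0], [1, 2], [0, 1]])  # 3 tokens, topk = 2
reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)

# Invert the permutation: src2dst[reorder_ids[j]] = j.
src2dst = torch.empty_like(reorder_ids, dtype=torch.int32)
src2dst[reorder_ids] = torch.arange(reorder_ids.numel(), dtype=torch.int32)

print(reorder_topk_ids)  # tensor([0, 0, 1, 1, 2, 2]) -- slots grouped by expert
print(src2dst)           # tensor([4, 0, 2, 5, 1, 3]) -- destination row per slot

The new Triton kernel below then uses this mapping to scatter tokens into that layout: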
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,
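
For readers less familiar with Triton, the kernel above behaves like this unvectorized PyTorch sketch (assuming a1_scale is a per-tensor scalar, as in the w4a8 caller; slots routed away from local experts are zero-filled here, whereas the kernel simply leaves those rows unwritten):

import torch

def pre_reorder_reference(a, src2dst, topk_ids, a1_scale, num_experts, out_dtype):
    # Scatter each token's hidden vector to its top-k destination rows,
    # scaled by 1 / a1_scale and cast to the (fp8) output dtype.
    m, hidden_size = a.shape
    topk = topk_ids.shape[1]
    out = torch.zeros((m * topk, hidden_size), dtype=out_dtype, device=a.device)
    scale = 1.0 / a1_scale.item() if a1_scale is not None else 1.0
    for src_idx in range(m):
        for j in range(topk):
            if topk_ids[src_idx, j] != num_experts:  # num_experts marks a non-local slot
                dst_idx = src2dst[src_idx * topk + j]
                out[dst_idx] = (a[src_idx].float() * scale).to(out_dtype)
    return out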
