33import torch
44import torch .nn .functional as F
55from tqdm import tqdm
6- from vllm .model_executor .layers .fused_moe import fused_moe as fused_moe_vllm
76
87from sglang .srt .layers .activation import SiluAndMul
98from sglang .srt .layers .moe .fused_moe_triton .fused_moe import fused_moe
@@ -45,20 +44,49 @@ def get_tolerance(self, dtype):
4544 else :
4645 return 1e-2 , 1e-2 # Default values for other types
4746
def torch_naive_moe(
    self,
    a,
    w1,
    w2,
    score,
    topk,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    """Compute a mixture-of-experts forward pass with plain PyTorch ops.

    Serves as the unfused reference against which the fused kernel output
    is compared.

    Args:
        a: 2-D input activations; unpacked as ``(B, D)``.
        w1: Per-expert first-projection weights, indexed by expert on dim 0.
        w2: Per-expert second-projection weights, indexed by expert on
            dim 0; ``w2.shape[1]`` is the output width.
        score: Router logits; softmaxed over the last dim in float32.
        topk: Number of experts selected per input row.
        w1_scale: Optional scale broadcast over dim 0 of ``w1``.
        w2_scale: Optional scale broadcast over dim 0 of ``w2``.
        a1_scale: Optional scale multiplied into the activations.
        a2_scale: Optional scale multiplied into the activations.

    The four scale arguments are only consulted when ``w1`` is stored as
    ``torch.float8_e4m3fn``; otherwise they are ignored.

    Returns:
        Tensor of shape ``(B, w2.shape[1])``: the top-k expert outputs for
        each row, weighted by their softmax scores and summed.
    """
    B, D = a.shape
    # Duplicate every input row `topk` times so each (row, expert) pair
    # gets its own slot in the flattened batch.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    if w1.dtype == torch.float8_e4m3fn:
        # Upcast FP8 weights to the activation dtype so the matmuls below
        # run in regular precision.
        w1_compute = w1.to(a.dtype)
        w2_compute = w2.to(a.dtype)

        if w1_scale is not None:
            w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
        if w2_scale is not None:
            w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
        if a1_scale is not None:
            a = (a * a1_scale).to(a.dtype)
        # NOTE(review): a2_scale is applied to the *input* activations here,
        # not to the intermediate activations between the two projections.
        # Confirm this matches the fused kernel's scaling convention.
        if a2_scale is not None:
            a = (a * a2_scale).to(a.dtype)
    else:
        w1_compute = w1
        w2_compute = w2

    # Build the activation module once instead of once per expert.
    act = SiluAndMul()
    for i in range(w1_compute.shape[0]):
        mask = topk_ids == i
        # Skip experts that no row was routed to.
        if mask.any():
            hidden = act(a[mask] @ w1_compute[i].transpose(0, 1))
            out[mask] = hidden @ w2_compute[i].transpose(0, 1)

    # Weight each expert's output by its routing score and sum over the
    # topk slots of every original row.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
@@ -98,21 +126,12 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
98126 a2_scale = a2_scale ,
99127 )
100128
101- vllm_output = fused_moe_vllm (
102- a ,
103- w1 ,
104- w2 ,
105- score ,
106- topk ,
107- renormalize = False ,
108- use_fp8_w8a8 = True ,
109- w1_scale = w1_scale ,
110- w2_scale = w2_scale ,
111- a1_scale = a1_scale ,
112- a2_scale = a2_scale ,
129+ torch_output = self .torch_naive_moe (
130+ a , w1 , w2 , score , topk , w1_scale , w2_scale , a1_scale , a2_scale
131+ )
132+ torch .testing .assert_close (
133+ sglang_output , torch_output , rtol = rtol , atol = atol
113134 )
114-
115- torch .testing .assert_close (sglang_output , vllm_output , rtol = rtol , atol = atol )
116135
117136 else :
118137 a = self .create_random_cuda_tensor ((m , k ), dtype )
@@ -127,8 +146,8 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
127146 )
128147
129148 def test_various_configurations (self ):
130- m_values = [1 , 33 , 64 , 222 , 1024 * 128 ]
131- n_values = [128 , 1024 , 2048 ]
149+ m_values = [1 , 33 , 64 , 222 ]
150+ n_values = [128 , 1024 ]
132151 k_values = [128 , 511 , 1024 ]
133152 dtypes = [torch .float16 , torch .bfloat16 ]
134153 fp8_modes = [False , True ]
@@ -171,6 +190,7 @@ def test_various_configurations(self):
171190 dtype ,
172191 use_fp8_w8a8 = use_fp8_w8a8 ,
173192 )
193+ torch .cuda .empty_cache ()
174194 pbar .update (1 )
175195
176196
0 commit comments