diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_oai.py b/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_oai.py index fe093c383..54ad03e93 100644 --- a/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_oai.py +++ b/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_oai.py @@ -83,7 +83,7 @@ def swiglu_oai_triton( def swiglu_oai_native(layer, hidden_states): - E, N, _ = layer.w13_weight.size() + E, _, N = layer.w13_weight.size() gate_up = hidden_states.view(-1, N) alpha = layer.moe_runner_config.gemm1_alpha limit = layer.moe_runner_config.gemm1_clamp_limit @@ -98,7 +98,7 @@ def swiglu_oai_native(layer, hidden_states): def swiglu_oai(layer, hidden_states): return swiglu_oai_triton( hidden_states, - layer.w13_weight.shape[1], + layer.w13_weight.shape[2], layer.moe_runner_config.gemm1_alpha, layer.moe_runner_config.gemm1_clamp_limit, )