@@ -287,8 +287,10 @@ def apply(
287287 """
288288 Triton compute Fused MoE.
289289 """
290- gate_out = gate (x .cast ("float32" ))
291290 token_num = x .shape [0 ]
291+ if token_num == 0 :
292+ return paddle .zeros ([token_num , layer .hidden_size ], dtype = x .dtype )
293+ gate_out = gate (x .cast ("float32" ))
292294 top_k = layer .top_k
293295 num_local_experts = layer .num_local_experts
294296 top_k = layer .top_k
@@ -669,8 +671,10 @@ def apply(
669671 """
670672 Triton compute Fused MoE.
671673 """
672- gate_out = gate (x .cast ("float32" ))
673674 token_num = x .shape [0 ]
675+ if token_num == 0 :
676+ return paddle .zeros ([token_num , layer .hidden_size ], dtype = x .dtype )
677+ gate_out = gate (x .cast ("float32" ))
674678 top_k = layer .top_k
675679 num_local_experts = layer .num_local_experts
676680 moe_intermediate_size = layer .moe_intermediate_size
@@ -959,8 +963,10 @@ def apply(
959963 """
960964 Triton compute Fused MoE.
961965 """
962- gate_out = gate (x .cast ("float32" ))
963966 token_num = x .shape [0 ]
967+ if token_num == 0 :
968+ return paddle .zeros ([token_num , layer .hidden_size ], dtype = x .dtype )
969+ gate_out = gate (x .cast ("float32" ))
964970 top_k = layer .top_k
965971 num_local_experts = layer .num_local_experts
966972 moe_intermediate_size = layer .moe_intermediate_size
@@ -1480,8 +1486,10 @@ def apply(
14801486 """
14811487 Triton compute Fused MoE.
14821488 """
1483- gate_out = gate (x .cast ("float32" ))
14841489 token_num = x .shape [0 ]
1490+ if token_num == 0 :
1491+ return paddle .zeros ([token_num , layer .hidden_size ], dtype = x .dtype )
1492+ gate_out = gate (x .cast ("float32" ))
14851493 top_k = layer .top_k
14861494 num_local_experts = layer .num_local_experts
14871495 moe_intermediate_size = layer .moe_intermediate_size
0 commit comments