diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 3b04422b1e8b..0267a9593855 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -448,12 +448,56 @@ def forward(
     ) -> torch.Tensor:
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(
-                hidden_states, should_allreduce_fusion, use_reduce_scatter
-            )
+            if (
+                self.alt_stream is not None
+                and hidden_states.shape[0] > 0
+                and get_is_capture_mode()
+            ):
+                return self.forward_normal_dual_stream(
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                )
+            else:
+                return self.forward_normal(
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+        with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+
+            final_hidden_states = self.experts(hidden_states, topk_output)
+            if not _is_cuda and not _use_aiter:
+                # fused in biased_grouped_topk so we can skip here
+                final_hidden_states *= self.routed_scaling_factor
+
+        current_stream.wait_stream(self.alt_stream)
+        with use_symmetric_memory(
+            parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
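
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the dual-stream overlap that `forward_normal_dual_stream` implements: the shared-expert MLP runs on the current CUDA stream while routing and the routed experts run concurrently on an alternate stream, with two `wait_stream` fences ordering both sides. The names `shared_mlp`, `router`, and `routed_mlp` are hypothetical stand-ins for `self._forward_shared_experts`, `self.gate`/`self.topk`, and `self.experts`; this illustrates the stream semantics only, not the SGLang code itself.

```python
import torch

def dual_stream_forward(x, shared_mlp, router, routed_mlp, alt_stream):
    # Fence 1: alt_stream must not start until the work that produced x is queued.
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)

    # Shared-expert MLP stays on the current stream...
    shared_out = shared_mlp(x)

    # ...while gating + routed experts are queued on alt_stream and can
    # execute concurrently with shared_mlp above.
    with torch.cuda.stream(alt_stream):
        router_logits = router(x)
        routed_out = routed_mlp(x, router_logits)

    # Fence 2: the current stream must not read routed_out until alt_stream is done.
    current_stream.wait_stream(alt_stream)
    return shared_out + routed_out

if torch.cuda.is_available():
    dev = torch.device("cuda")
    shared_mlp = torch.nn.Linear(64, 64, device=dev)
    router = torch.nn.Linear(64, 8, device=dev)

    def routed_mlp(x, logits):
        # Toy stand-in: scale the activation by the strongest router weight.
        return torch.nn.functional.gelu(x) * logits.softmax(-1).amax(-1, keepdim=True)

    y = dual_stream_forward(
        torch.randn(4, 64, device=dev),
        shared_mlp, router, routed_mlp,
        alt_stream=torch.cuda.Stream(),
    )
```

Note how the diff gates this path in `forward`: the dual-stream variant is taken only when `alt_stream` exists, the batch is non-empty, and `get_is_capture_mode()` is true (i.e., during CUDA graph capture); otherwise it falls back to the existing `forward_normal`.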