python/sglang/srt/models/glm4_moe.py (50 changes: 47 additions & 3 deletions)
@@ -448,12 +448,56 @@ def forward(
     ) -> torch.Tensor:

         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(
-                hidden_states, should_allreduce_fusion, use_reduce_scatter
-            )
+            if (
+                self.alt_stream is not None
+                and hidden_states.shape[0] > 0
+                and get_is_capture_mode()
+            ):
+                return self.forward_normal_dual_stream(
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                )
+            else:
+                return self.forward_normal(
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+        with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+
+            final_hidden_states = self.experts(hidden_states, topk_output)
+            if not _is_cuda and not _use_aiter:
+                # fused in biased_grouped_topk so we can skip here
+                final_hidden_states *= self.routed_scaling_factor
+
+        current_stream.wait_stream(self.alt_stream)
+        with use_symmetric_memory(
+            parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
+        ):
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
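Aside (not part of the diff): the new method overlaps the shared-expert MLP, launched on the current CUDA stream, with the router/top-k/routed-expert path launched on `self.alt_stream`, and joins the two streams before the partial outputs are added. A minimal, self-contained sketch of that overlap idiom, using made-up `shared_mlp`/`routed_mlp` modules instead of the PR's classes:

```python
import torch
from torch import nn


def overlapped_forward(
    x: torch.Tensor,
    shared_mlp: nn.Module,
    routed_mlp: nn.Module,
    alt_stream: torch.cuda.Stream,
) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    # Make sure alt_stream sees x (and all prior work) before reading it.
    alt_stream.wait_stream(current_stream)
    shared_out = shared_mlp(x)          # runs on the current stream
    with torch.cuda.stream(alt_stream):
        routed_out = routed_mlp(x)      # overlaps with shared_mlp above
    # Join: the current stream must not read routed_out until alt_stream is done.
    current_stream.wait_stream(alt_stream)
    return shared_out + routed_out


if torch.cuda.is_available():
    shared = nn.Linear(1024, 1024).cuda()
    routed = nn.Linear(1024, 1024).cuda()
    x = torch.randn(64, 1024, device="cuda")
    out = overlapped_forward(x, shared, routed, torch.cuda.Stream())
    torch.cuda.synchronize()
    print(out.shape)
```

The two `wait_stream` calls are what make this safe: the first guarantees `alt_stream` observes `x` before reading it, the second guarantees the current stream observes `routed_out` before the addition.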
Comment on lines +466 to +499
Contributor

Severity: medium

This new method `forward_normal_dual_stream` duplicates a significant amount of code from the existing `forward_normal` method. The core MoE computations and the final combination/all-reduce steps are very similar.

To improve maintainability and reduce redundancy, consider refactoring the common logic into helper methods. For example, you could have a helper for the routed expert calculations and another for the final combination and all-reduce step. This would make both `forward_normal` and `forward_normal_dual_stream` simpler and highlight their differences (i.e., the stream management) more clearly.
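One possible shape for that refactor, sketched roughly below. The helper names `_forward_routed_experts` and `_combine_and_all_reduce` are hypothetical, and the bodies simply reuse names that already appear in this file:

```python
# Hypothetical helpers; the names are illustrative, not part of the PR.
def _forward_routed_experts(self, hidden_states):
    # router_logits: (num_tokens, n_experts)
    router_logits = self.gate(hidden_states)
    topk_output = self.topk(hidden_states, router_logits)
    final_hidden_states = self.experts(hidden_states, topk_output)
    if not _is_cuda and not _use_aiter:
        # fused in biased_grouped_topk so we can skip here
        final_hidden_states *= self.routed_scaling_factor
    return final_hidden_states


def _combine_and_all_reduce(
    self, final_hidden_states, shared_output, should_allreduce_fusion, use_reduce_scatter
):
    with use_symmetric_memory(
        parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
    ):
        out = torch.empty_like(final_hidden_states)
        torch.add(final_hidden_states, shared_output, out=out)
    if (
        self.tp_size > 1
        and not should_allreduce_fusion
        and not use_reduce_scatter
        and not should_use_flashinfer_cutlass_moe_fp4_allgather()
    ):
        out = tensor_model_parallel_all_reduce(out)
    return out


def forward_normal_dual_stream(
    self, hidden_states, should_allreduce_fusion=False, use_reduce_scatter=False
):
    # Only the stream management differs from forward_normal.
    current_stream = torch.cuda.current_stream()
    self.alt_stream.wait_stream(current_stream)
    shared_output = self._forward_shared_experts(hidden_states)
    with torch.cuda.stream(self.alt_stream):
        final_hidden_states = self._forward_routed_experts(hidden_states)
    current_stream.wait_stream(self.alt_stream)
    return self._combine_and_all_reduce(
        final_hidden_states, shared_output, should_allreduce_fusion, use_reduce_scatter
    )
```

`forward_normal` could then be expressed with roughly the same helpers, minus the `alt_stream` bookkeeping, so the stream management becomes the only visible difference between the two paths.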


     def forward_normal(
         self,
         hidden_states: torch.Tensor,