From bacec3be7b33ac6f67912638c923c77049d809c0 Mon Sep 17 00:00:00 2001 From: ispobock Date: Wed, 10 Dec 2025 14:20:30 +0000 Subject: [PATCH] apply --- .../moe/fused_moe_triton/fused_marlin_moe.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 2f753800f00d..312ccb1ff08a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -7,7 +7,7 @@ _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import silu_and_mul + from sgl_kernel import moe_sum_reduce, silu_and_mul def get_scalar_type(num_bits: int, has_zp: bool): @@ -204,9 +204,15 @@ def fused_marlin_moe( ).view(-1, topk, K) output = hidden_states if inplace else torch.empty_like(hidden_states) - torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output) - if routed_scaling_factor is not None: - output *= routed_scaling_factor + + if routed_scaling_factor is None: + routed_scaling_factor = 1.0 + + moe_sum_reduce( + intermediate_cache3, + output, + routed_scaling_factor, + ) return output