From 816ee1fda95194670c633b7d0d1291a19fdb7199 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Mon, 17 Nov 2025 16:00:46 -0800 Subject: [PATCH] upd --- python/sglang/srt/layers/moe/topk.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 30b7cc5da49..68eaee72d43 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -74,6 +74,29 @@ if _is_cuda: from sgl_kernel import kimi_k2_moe_fused_gate, moe_fused_gate + @torch.library.register_fake("sgl_kernel::kimi_k2_moe_fused_gate") + def _kimi_k2_moe_fused_gate( + input_tensor, + bias, + topk, + renormalize, + routed_scaling_factor, + apply_routed_scaling_factor_on_output, + ): + num_rows = input_tensor.shape[0] + topk_weights = input_tensor.new_empty( + num_rows, + topk, + dtype=torch.float32, + ) + topk_ids = input_tensor.new_empty( + num_rows, + topk, + dtype=torch.int32, + ) + return topk_weights, topk_ids + + if _is_cuda or _is_hip: from sgl_kernel import topk_softmax if _use_aiter: