Commit 50177e0

rogeryoung authored and hxuebi committed

Optimize topk sigmoid in minimax_m2 (sgl-project#14047)

Co-authored-by: xuebi <[email protected]>

1 parent ce44ad6 · commit 50177e0

2 files changed: +38 −13 lines

python/sglang/srt/layers/moe/topk.py

Lines changed: 38 additions & 10 deletions
```diff
@@ -29,7 +29,6 @@
 )

 import torch
-import torch.nn.functional as F

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tp_group
@@ -81,7 +80,7 @@
     pass

 if _is_cuda or _is_hip:
-    from sgl_kernel import topk_softmax
+    from sgl_kernel import topk_sigmoid, topk_softmax
 if _use_aiter:
     try:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
@@ -109,6 +108,7 @@ class TopKConfig:
     apply_routed_scaling_factor_on_output: bool = False
     fused_shared_experts_scaling_factor: Optional[float] = None
     output_format: Optional[TopKOutputFormat] = None
+    scoring_func: str = "softmax"


 # -------------------------------- TopKOutput ---------------------------------------
@@ -244,6 +244,7 @@ def __init__(
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor,
             output_format=output_format,
+            scoring_func=scoring_func,
         )

     def forward_native(
@@ -430,10 +431,19 @@ def fused_topk_torch_native(
     topk: int,
     renormalize: bool,
     correction_bias: torch.Tensor = None,
+    scoring_func: str = "softmax",
 ):
+    def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor:
+        if scoring_func == "softmax":
+            return gating_output.softmax(dim=-1)
+        elif scoring_func == "sigmoid":
+            return gating_output.sigmoid()
+        else:
+            raise ValueError(f"Invalid scoring function: {scoring_func}")
+
     if correction_bias is not None:
         n_routed_experts = gating_output.shape[-1]
-        scores = gating_output.softmax(dim=-1)
+        scores = scoring_func_impl(gating_output)
         scores_for_choice = scores.view(
             -1, n_routed_experts
         ) + correction_bias.unsqueeze(0)
```
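The new `scoring_func_impl` closure only changes the activation that turns router logits into expert scores. A minimal standalone sketch of the two modes (toy tensor, not the fused kernel path):

```python
import torch

# Router logits for one token over four experts (illustrative values).
gating_output = torch.tensor([[2.0, -1.0, 0.5, 1.5]])

# "softmax" scoring: a normalized distribution over experts.
softmax_scores = gating_output.softmax(dim=-1)

# "sigmoid" scoring (what MiniMax M2 uses): each expert scored independently.
sigmoid_scores = gating_output.sigmoid()

print(softmax_scores.sum(dim=-1))  # tensor([1.])
print(sigmoid_scores.sum(dim=-1))  # generally != 1, hence the renormalize step
```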
```diff
@@ -448,7 +458,7 @@ def fused_topk_torch_native(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights = scoring_func_impl(gating_output.float())
     topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)

     if renormalize:
@@ -464,6 +474,7 @@ def fused_topk_cpu(
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     correction_bias: torch.Tensor = None,
+    scoring_func: str = "softmax",
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -494,8 +505,10 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: Optional[torch.Tensor] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -506,12 +519,23 @@ def fused_topk(
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)

-    topk_softmax(
-        topk_weights,
-        topk_ids,
-        gating_output,
-        renormalize,
-    )
+    if scoring_func == "softmax":
+        topk_softmax(
+            topk_weights,
+            topk_ids,
+            gating_output,
+            renormalize,
+        )
+    elif scoring_func == "sigmoid":
+        topk_sigmoid(
+            topk_weights,
+            topk_ids,
+            gating_output,
+            renormalize,
+            correction_bias,
+        )
+    else:
+        raise ValueError(f"Invalid scoring function: {scoring_func}")

     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
```
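On CUDA/HIP, `fused_topk` now dispatches to the fused `topk_sigmoid` kernel instead of always calling `topk_softmax`. The kernel internals are not part of this diff; below is a hedged pure-PyTorch reference for what the sigmoid branch computes, mirroring the correction-bias handling shown in `fused_topk_torch_native` above (gathering the routing weights from the unbiased scores is an assumption, since that step falls outside the shown hunk):

```python
import torch

def topk_sigmoid_reference(gating_output, topk, renormalize, correction_bias=None):
    # Sketch of the sigmoid branch; assumes the fused kernel matches the
    # torch-native path above.
    scores = gating_output.float().sigmoid()
    if correction_bias is not None:
        # The bias shifts which experts are selected; the routing weights are
        # assumed to come from the unbiased scores (outside the shown context).
        scores_for_choice = scores + correction_bias.unsqueeze(0)
        _, topk_ids = torch.topk(scores_for_choice, topk, dim=-1)
        topk_weights = scores.gather(-1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)
```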
```diff
@@ -916,6 +940,7 @@ def select_experts(
     fused_shared_experts_scaling_factor = (
         topk_config.fused_shared_experts_scaling_factor
     )
+    scoring_func = topk_config.scoring_func

     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -972,6 +997,7 @@ def select_experts(
             topk=num_routed_topk if _use_aiter else top_k,
             renormalize=renormalize,
             correction_bias=correction_bias,
+            scoring_func=scoring_func,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
@@ -981,8 +1007,10 @@ def select_experts(
             gating_output=router_logits,
             topk=num_routed_topk if _use_aiter else top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            scoring_func=scoring_func,
         )
     else:
         assert (
```
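With `scoring_func` threaded through `select_experts` → `fused_topk`, a caller can pick sigmoid routing at the same entry point. A hypothetical invocation (assumes a CUDA/HIP build where sgl_kernel exports `topk_sigmoid`, and that `fused_topk` returns `(topk_weights, topk_ids)` as the softmax path does; shapes are illustrative):

```python
import torch

from sglang.srt.layers.moe.topk import fused_topk

num_tokens, hidden, num_experts, topk = 8, 4096, 64, 2
hidden_states = torch.randn(num_tokens, hidden, device="cuda")
gating_output = torch.randn(num_tokens, num_experts, device="cuda")
correction_bias = torch.zeros(num_experts, device="cuda")

topk_weights, topk_ids = fused_topk(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=topk,
    renormalize=True,
    correction_bias=correction_bias,
    scoring_func="sigmoid",
)
```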

python/sglang/srt/models/minimax_m2.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -167,9 +167,6 @@ def __init__(
             top_k=config.num_experts_per_tok,
             renormalize=True,
             scoring_func=config.scoring_func,
-            use_grouped_topk=True,  # TODO: Use "grouped top-k" flag only for hardcoded sigmoid scoring
-            num_expert_group=1,
-            topk_group=1,
             correction_bias=self.e_score_correction_bias,
             routed_scaling_factor=1.0,
         )
```
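The dropped flags were a workaround (per the removed TODO) to reach sigmoid scoring through the grouped top-k path. With `num_expert_group=1` and `topk_group=1`, every expert sits in the single group and group selection is a no-op, so grouped top-k degenerates to plain top-k; now that `fused_topk` supports sigmoid scoring directly, the detour is unnecessary. A toy sketch of the degenerate case, following the usual grouped top-k recipe rather than sglang's exact implementation:

```python
import torch

tokens, experts, topk = 4, 16, 2
scores = torch.rand(tokens, experts)  # toy routing scores

# Grouped top-k with one group: the only group is always selected,
# so the group mask keeps every expert.
num_expert_group, topk_group = 1, 1
group_scores = scores.view(tokens, num_expert_group, -1).max(dim=-1).values
group_idx = group_scores.topk(topk_group, dim=-1).indices
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
masked = scores * group_mask.repeat_interleave(experts // num_expert_group, dim=1)
grouped = torch.topk(masked, topk, dim=-1)

# Plain top-k over the raw scores selects the same experts.
plain = torch.topk(scores, topk, dim=-1)
assert torch.equal(grouped.indices, plain.indices)
```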
