@@ -29,7 +29,6 @@
 )

 import torch
-import torch.nn.functional as F

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tp_group
@@ -81,7 +80,7 @@
     pass

 if _is_cuda or _is_hip:
-    from sgl_kernel import topk_softmax
+    from sgl_kernel import topk_sigmoid, topk_softmax
 if _use_aiter:
     try:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
@@ -109,6 +108,7 @@ class TopKConfig:
     apply_routed_scaling_factor_on_output: bool = False
     fused_shared_experts_scaling_factor: Optional[float] = None
     output_format: Optional[TopKOutputFormat] = None
+    scoring_func: str = "softmax"


 # -------------------------------- TopKOutput ---------------------------------------
@@ -244,6 +244,7 @@ def __init__(
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor,
             output_format=output_format,
+            scoring_func=scoring_func,
         )

     def forward_native(
@@ -430,10 +431,19 @@ def fused_topk_torch_native(
     topk: int,
     renormalize: bool,
     correction_bias: torch.Tensor = None,
+    scoring_func: str = "softmax",
 ):
+    def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor:
+        if scoring_func == "softmax":
+            return gating_output.softmax(dim=-1)
+        elif scoring_func == "sigmoid":
+            return gating_output.sigmoid()
+        else:
+            raise ValueError(f"Invalid scoring function: {scoring_func}")
+
     if correction_bias is not None:
         n_routed_experts = gating_output.shape[-1]
-        scores = gating_output.softmax(dim=-1)
+        scores = scoring_func_impl(gating_output)
         scores_for_choice = scores.view(
             -1, n_routed_experts
         ) + correction_bias.unsqueeze(0)
@@ -448,7 +458,7 @@ def fused_topk_torch_native(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights = scoring_func_impl(gating_output.float())
     topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)

     if renormalize:
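For context, here is a minimal standalone sketch of the behaviour the two hunks above add to the torch-native path: router logits are scored with either softmax or sigmoid, an optional correction bias influences only which experts are chosen (not the returned weights), and the selected weights can be renormalized. It uses only PyTorch; the `demo_topk` name and the tensor sizes are illustrative and not part of the repository.

import torch


def demo_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: torch.Tensor = None,
    scoring_func: str = "softmax",
):
    # Score the router logits with the selected activation.
    if scoring_func == "softmax":
        scores = gating_output.float().softmax(dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.float().sigmoid()
    else:
        raise ValueError(f"Invalid scoring function: {scoring_func}")

    if correction_bias is not None:
        # The bias only steers expert selection; the returned weights stay unbiased.
        scores_for_choice = scores + correction_bias.unsqueeze(0)
        topk_ids = torch.topk(scores_for_choice, topk, dim=-1).indices
        topk_weights = scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(4, 8)  # 4 tokens, 8 experts (illustrative sizes)
w_soft, ids_soft = demo_topk(logits, topk=2, renormalize=True)
w_sig, ids_sig = demo_topk(logits, topk=2, renormalize=True, scoring_func="sigmoid")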
@@ -464,6 +474,7 @@ def fused_topk_cpu(
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     correction_bias: torch.Tensor = None,
+    scoring_func: str = "softmax",
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -494,8 +505,10 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: Optional[torch.Tensor] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -506,12 +519,23 @@ def fused_topk(
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)

-    topk_softmax(
-        topk_weights,
-        topk_ids,
-        gating_output,
-        renormalize,
-    )
+    if scoring_func == "softmax":
+        topk_softmax(
+            topk_weights,
+            topk_ids,
+            gating_output,
+            renormalize,
+        )
+    elif scoring_func == "sigmoid":
+        topk_sigmoid(
+            topk_weights,
+            topk_ids,
+            gating_output,
+            renormalize,
+            correction_bias,
+        )
+    else:
+        raise ValueError(f"Invalid scoring function: {scoring_func}")

     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
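On the CUDA/HIP path, `scoring_func` now selects between the `topk_softmax` and `topk_sigmoid` kernels, and `correction_bias` is only passed to the sigmoid variant. A hedged usage sketch follows; the import path, tensor shapes, and dtypes are assumptions, and it requires a CUDA/HIP build whose `sgl_kernel` ships `topk_sigmoid`.

# Assumed import path, shown for illustration only.
import torch
from sglang.srt.layers.moe.topk import fused_topk

hidden_states = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
router_logits = torch.randn(16, 64, device="cuda", dtype=torch.float32)
bias = torch.zeros(64, device="cuda", dtype=torch.float32)

# Sigmoid scoring with an optional correction bias, dispatched to topk_sigmoid.
topk_weights, topk_ids = fused_topk(
    hidden_states,
    router_logits,
    topk=8,
    renormalize=True,
    correction_bias=bias,
    scoring_func="sigmoid",
)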
@@ -916,6 +940,7 @@ def select_experts(
     fused_shared_experts_scaling_factor = (
         topk_config.fused_shared_experts_scaling_factor
     )
+    scoring_func = topk_config.scoring_func

     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -972,6 +997,7 @@ def select_experts(
             topk=num_routed_topk if _use_aiter else top_k,
             renormalize=renormalize,
             correction_bias=correction_bias,
+            scoring_func=scoring_func,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
@@ -981,8 +1007,10 @@ def select_experts(
             gating_output=router_logits,
             topk=num_routed_topk if _use_aiter else top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            scoring_func=scoring_func,
         )
     else:
         assert (