python/sglang/srt/layers/moe/topk.py (48 changes: 38 additions & 10 deletions)
@@ -29,7 +29,6 @@
 )
 
 import torch
-import torch.nn.functional as F
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tp_group
@@ -81,7 +80,7 @@
     pass
 
 if _is_cuda or _is_hip:
-    from sgl_kernel import topk_softmax
+    from sgl_kernel import topk_sigmoid, topk_softmax
 if _use_aiter:
     try:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
@@ -109,6 +108,7 @@ class TopKConfig:
     apply_routed_scaling_factor_on_output: bool = False
     fused_shared_experts_scaling_factor: Optional[float] = None
     output_format: Optional[TopKOutputFormat] = None
+    scoring_func: str = "softmax"
 
 
 # -------------------------------- TopKOutput ---------------------------------------
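Note: a minimal sketch of how the new field is meant to be used. Only scoring_func is new in this PR; it defaults to "softmax", so existing configs are unaffected. The other fields shown (top_k, renormalize) are assumed from context, with made-up values.

    # Hypothetical construction; opts this config in to sigmoid scoring.
    config = TopKConfig(
        top_k=2,
        renormalize=True,
        scoring_func="sigmoid",
    )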
@@ -244,6 +244,7 @@ def __init__(
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor,
             output_format=output_format,
+            scoring_func=scoring_func,
         )
 
     def forward_native(
@@ -430,10 +431,19 @@ def fused_topk_torch_native(
     topk: int,
     renormalize: bool,
     correction_bias: torch.Tensor = None,
+    scoring_func: str = "softmax",
 ):
+    def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor:
+        if scoring_func == "softmax":
+            return gating_output.softmax(dim=-1)
+        elif scoring_func == "sigmoid":
+            return gating_output.sigmoid()
+        else:
+            raise ValueError(f"Invalid scoring function: {scoring_func}")
+
     if correction_bias is not None:
         n_routed_experts = gating_output.shape[-1]
-        scores = gating_output.softmax(dim=-1)
+        scores = scoring_func_impl(gating_output)
         scores_for_choice = scores.view(
             -1, n_routed_experts
         ) + correction_bias.unsqueeze(0)
@@ -448,7 +458,7 @@ def fused_topk_torch_native(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights = scoring_func_impl(gating_output.float())
     topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
 
     if renormalize:
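Note: an illustrative call against the torch-native fallback above; shapes and values are made up.

    import torch

    hidden_states = torch.randn(4, 16)  # [num_tokens, hidden_size]
    gating_output = torch.randn(4, 8)   # [num_tokens, num_experts]

    # Sigmoid scoring without a correction bias: top-k over raw sigmoid scores.
    topk_weights, topk_ids = fused_topk_torch_native(
        hidden_states,
        gating_output,
        topk=2,
        renormalize=True,
        scoring_func="sigmoid",
    )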
@@ -464,6 +474,7 @@ def fused_topk_cpu(
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     correction_bias: torch.Tensor = None,
+    scoring_func: str = "softmax",
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -494,8 +505,10 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: Optional[torch.Tensor] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -506,12 +519,23 @@
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
 
-    topk_softmax(
-        topk_weights,
-        topk_ids,
-        gating_output,
-        renormalize,
-    )
+    if scoring_func == "softmax":
+        topk_softmax(
+            topk_weights,
+            topk_ids,
+            gating_output,
+            renormalize,
+        )
+    elif scoring_func == "sigmoid":
+        topk_sigmoid(
+            topk_weights,
+            topk_ids,
+            gating_output,
+            renormalize,
+            correction_bias,
+        )
+    else:
+        raise ValueError(f"Invalid scoring function: {scoring_func}")
 
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
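Note: the exact semantics of the sgl_kernel topk_sigmoid kernel are not visible in this diff. The sketch below is a plain-PyTorch reference for what the sigmoid branch is assumed to compute, mirroring the torch-native path above: the correction bias shifts expert selection only, while the returned weights come from the unbiased scores.

    import torch

    def topk_sigmoid_reference(gating_output, topk, renormalize, correction_bias=None):
        scores = gating_output.float().sigmoid()
        if correction_bias is None:
            topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
        else:
            # Bias influences which experts are chosen, not their weights.
            topk_ids = torch.topk(scores + correction_bias.unsqueeze(0), topk, dim=-1).indices
            topk_weights = scores.gather(-1, topk_ids)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)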
@@ -916,6 +940,7 @@ def select_experts(
     fused_shared_experts_scaling_factor = (
         topk_config.fused_shared_experts_scaling_factor
     )
+    scoring_func = topk_config.scoring_func
 
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -972,6 +997,7 @@ def select_experts(
             topk=num_routed_topk if _use_aiter else top_k,
             renormalize=renormalize,
             correction_bias=correction_bias,
+            scoring_func=scoring_func,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
@@ -981,8 +1007,10 @@ def select_experts(
             gating_output=router_logits,
             topk=num_routed_topk if _use_aiter else top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            scoring_func=scoring_func,
         )
     else:
         assert (
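Note: end to end, a model config with scoring_func="sigmoid" now flows from TopKConfig through select_experts into fused_topk, which dispatches to the topk_sigmoid kernel. An illustrative direct call follows; it assumes a CUDA or HIP build where sgl_kernel exports topk_sigmoid, and assumes fused_topk still returns (topk_weights, topk_ids).

    import torch

    hidden_states = torch.randn(4, 16, device="cuda")
    router_logits = torch.randn(4, 8, device="cuda")
    bias = torch.zeros(8, device="cuda")  # e.g. an e_score_correction_bias

    topk_weights, topk_ids = fused_topk(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=2,
        renormalize=True,
        correction_bias=bias,
        scoring_func="sigmoid",
    )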
python/sglang/srt/models/minimax_m2.py (3 changes: 0 additions & 3 deletions)
@@ -167,9 +167,6 @@ def __init__(
             top_k=config.num_experts_per_tok,
             renormalize=True,
             scoring_func=config.scoring_func,
-            use_grouped_topk=True,  # TODO: Use "grouped top-k" flag only for hardcoded sigmoid scoring
-            num_expert_group=1,
-            topk_group=1,
             correction_bias=self.e_score_correction_bias,
             routed_scaling_factor=1.0,
         )