Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def fused_topk(
return topk_weights, topk_ids


# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk(
hidden_states: torch.Tensor,
Expand All @@ -84,10 +83,17 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

scores = torch.softmax(gating_output, dim=-1)
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Scoring function '{scoring_func}' is not supported.")

num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
Expand All @@ -111,6 +117,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


# DeepSeek V2/V3/R1 uses biased_grouped_top
@torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -165,7 +172,7 @@ def select_experts(
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
):
# DeekSeekv2 uses grouped_top_k
# DeepSeek V2/V3/R1 uses biased_grouped_top
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
Expand Down
Loading