Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
21c9497
add support the number of group expert from the original 32 to 128 im…
ltaodream Jun 7, 2025
93e5ad1
modify the MAX_VPT limit in topk.py
ltaodream Jun 7, 2025
68ff1d3
Fix the size of row_chunk and bias_chunk to 32 to prevent register ov…
ltaodream Jun 7, 2025
5fbde7f
add test and cycle fusion Optimization
ltaodream Jun 9, 2025
92f418b
add bench_moe_fused_gate for kimi-vl
ltaodream Jun 10, 2025
17d13a6
Merge branch 'main' into main
BBuf Jul 22, 2025
46a93f4
Merge branch 'sgl-project:main' into main
ltaodream Jul 23, 2025
0f908a6
[Kimi K2] add support the number of group expert from the original 32…
ltaodream Jul 23, 2025
3c36fa3
Merge branch 'main' into main
BBuf Jul 23, 2025
8126d99
fix prefetch_global、fix alignment error with cutlass alignedarray
ltaodream Jul 23, 2025
8672b60
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream Jul 23, 2025
88e1d30
fix prefetch
ltaodream Jul 23, 2025
9cbd2d1
fix kimi k2 config
ltaodream Jul 23, 2025
407ea8a
feat: Optimize MoE routing with efficient shared memory CUDA kernel
ltaodream Jul 24, 2025
c909014
test: Enhance MoE routing tests with comprehensive group selection ve…
ltaodream Jul 24, 2025
c431d17
add test_moe_fused_gate param
ltaodream Jul 24, 2025
5aa6f04
Merge branch 'sgl-project:main' into main
ltaodream Jul 24, 2025
c9b2738
fix: Adjust MOE routing test tolerances for different expert configur…
ltaodream Jul 25, 2025
33c8651
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream Jul 25, 2025
9c0ee2b
Merge branch 'main' into main
BBuf Jul 27, 2025
98abd24
fix: Optimize MOE Fused Gate kernel to use non-tiled path for VPT≤32 …
ltaodream Jul 27, 2025
5944ff4
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream Jul 27, 2025
585c3ac
fix: correct function call from biased_grouped_topk to biased_grouped…
ltaodream Jul 27, 2025
7dd53f6
feat: use native topk for single expert group with group size > 128, …
ltaodream Jul 27, 2025
cdec8aa
fix
ltaodream Jul 27, 2025
eb92d47
Merge branch 'main' into main
BBuf Jul 27, 2025
a464ac8
fix lint issues
ltaodream Jul 28, 2025
c9eb5f6
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream Jul 28, 2025
60985c2
Merge branch 'main' into main
ltaodream Jul 28, 2025
0c7cb09
fix lint
ltaodream Jul 28, 2025
e29454c
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream Jul 28, 2025
6e14a95
Merge branch 'main' into main
ltaodream Jul 28, 2025
01e75e4
Merge branch 'main' into main
BBuf Jul 29, 2025
513e1e0
Refactor dynamic kernel launch code to remove duplication between sma…
ltaodream Jul 30, 2025
cfead57
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream Jul 30, 2025
71894df
Refactor dynamic kernel launch logic for moe_fused_gate, unify macro
ltaodream Jul 30, 2025
5dac4cd
fix
ltaodream Jul 30, 2025
bfa5a7b
Recovery
ltaodream Jul 30, 2025
ef890cc
Consolidate duplicated logic for small and large VPT cases into a uni…
ltaodream Aug 1, 2025
93636fb
fix tile
ltaodream Aug 2, 2025
4d56ecc
pre-commit
ltaodream Aug 2, 2025
b7f875a
fix
ltaodream Aug 4, 2025
5072b14
fix index out of bounds
ltaodream Aug 8, 2025
3d5368d
fix num_expert_group=1
ltaodream Aug 8, 2025
af449a1
fix Expert repeated selection
ltaodream Aug 9, 2025
876a5a2
pre-commit run
ltaodream Aug 9, 2025
c55adda
fix clear bias
ltaodream Aug 9, 2025
cf186d8
upd moe_fuse_gate_tiling_more_experts v1
Aug 11, 2025
2c4b629
upd
Aug 12, 2025
22545d4
fix grouptopk only set 1-shared_experts problem
Aug 12, 2025
d26f836
upd
Aug 12, 2025
6c0bddb
add tiled dynamic benchmark
Aug 12, 2025
17f6813
upd pytest style
ttaohe Aug 12, 2025
71f42c3
add bench_moe_fused_gate_tiled script
ttaohe Aug 12, 2025
4530634
add specific examples like kimi-vl and kimi-k2 moe config
ttaohe Aug 12, 2025
5fe73f8
add kimi-vl 64
ttaohe Aug 12, 2025
bdec8b3
bench tile add more test case
ttaohe Aug 12, 2025
11b9570
change specific style
ttaohe Aug 12, 2025
fef21b5
commit last bench test
ttaohe Aug 12, 2025
29ed296
Merge pull request #1 from ttaohe/main
ltaodream Aug 12, 2025
08752c9
pre-commit run --all-files
ltaodream Aug 12, 2025
82f6f48
Merge branch 'main' into main
ltaodream Aug 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,20 +469,31 @@ def grouped_topk_gpu(
sorted=(True if num_fused_shared_experts > 0 else False),
)
if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
assert (
topk >= num_fused_shared_experts
), "topk must be >= num_fused_shared_experts"
# Assign the last N ids to all shared expert ids [num_experts, num_experts+N)
shared_ids = torch.arange(
num_experts,
num_experts + num_fused_shared_experts,
device=topk_ids.device,
dtype=topk_ids.dtype,
)
topk_ids[:, -num_fused_shared_experts:] = shared_ids.unsqueeze(0).expand(
topk_ids.size(0), -1
)
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
# Set each shared expert's weight to sum(real_experts)/routed_scaling_factor
real_sum = topk_weights[:, :-num_fused_shared_experts].sum(dim=-1)
shared_weight = real_sum / routed_scaling_factor
topk_weights[:, -num_fused_shared_experts:] = shared_weight.unsqueeze(
-1
).expand(-1, num_fused_shared_experts)

if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
else topk_weights[:, :-num_fused_shared_experts].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
if apply_routed_scaling_factor_on_output:
Expand Down Expand Up @@ -571,20 +582,31 @@ def biased_grouped_topk_impl(
topk_weights = scores.gather(1, topk_ids)

if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
assert (
topk >= num_fused_shared_experts
), "topk must be >= num_fused_shared_experts"
# Assign the last N ids to all shared expert ids [num_experts, num_experts+N)
shared_ids = torch.arange(
num_experts,
num_experts + num_fused_shared_experts,
device=topk_ids.device,
dtype=topk_ids.dtype,
)
topk_ids[:, -num_fused_shared_experts:] = shared_ids.unsqueeze(0).expand(
topk_ids.size(0), -1
)
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
# Set each shared expert's weight to sum(real_experts)/routed_scaling_factor
real_sum = topk_weights[:, :-num_fused_shared_experts].sum(dim=-1)
shared_weight = real_sum / routed_scaling_factor
topk_weights[:, -num_fused_shared_experts:] = shared_weight.unsqueeze(
-1
).expand(-1, num_fused_shared_experts)

if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
else topk_weights[:, :-num_fused_shared_experts].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
if apply_routed_scaling_factor_on_output:
Expand Down Expand Up @@ -640,8 +662,10 @@ def biased_grouped_topk_gpu(
if (
_is_cuda
and gating_output.shape[1] // num_expert_group
<= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
and is_power_of_two(correction_bias.shape[0])
<= 512 # moe_fused_gate kernel now supports MAX_VPT up to 512, including Kimi K2's 384 experts
and (
is_power_of_two(correction_bias.shape[0]) or correction_bias.shape[0] == 384
) # Kimi K2 has 384 experts
):
topk_weights, topk_ids = moe_fused_gate(
gating_output.to(dtype=torch.float32),
Expand Down
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ set(SOURCES
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_fused_gate_tile_more_experts.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/moe/nvfp4_blockwise_moe.cu"
"csrc/moe/fp8_blockwise_moe_kernel.cu"
Expand Down
13 changes: 7 additions & 6 deletions sgl-kernel/benchmark/bench_moe_fused_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import triton.language as tl
from sgl_kernel import moe_fused_gate

from sglang.srt.layers.moe.topk import biased_grouped_topk
from sglang.srt.layers.moe.topk import biased_grouped_topk_impl


def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
return biased_grouped_topk(
return biased_grouped_topk_impl(
scores,
scores,
bias,
Expand All @@ -29,12 +29,14 @@ def biased_grouped_topk_org_fuse_kernel(


seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range]
configs = [(sq, 256, 8, 4, 8) for sq in seq_length_range] # original config
# configs = ([(sq, 64, 1, 1, 6) for sq in seq_length_range]) # kimi vl config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why comment kimi config?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines are kept commented out so that any one of the three configurations can be uncommented to benchmark it separately.

# configs = ([(sq, 384, 1, 1, 8) for sq in seq_length_range]) # Kimi K2 config


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["seq_length"],
x_names=["seq_length", "num_experts", "num_expert_group", "topk_group", "topk"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["original", "kernel"],
Expand All @@ -45,10 +47,9 @@ def biased_grouped_topk_org_fuse_kernel(
args={},
)
)
def benchmark(seq_length, provider):
def benchmark(seq_length, num_experts, num_expert_group, topk_group, topk, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8

scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
bias = torch.rand(num_experts, device=device, dtype=dtype)
Expand Down
140 changes: 140 additions & 0 deletions sgl-kernel/benchmark/bench_moe_fused_gate_tiled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import itertools
import math

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from sgl_kernel import moe_fused_gate


def biased_grouped_topk_ref_impl(scores, bias, num_expert_group, topk_group, topk):
    """Pure-PyTorch reference for the fused MoE gate (no shared-expert handling).

    Mirrors biased_grouped_topk_impl: sigmoid the router logits, add the
    per-expert bias for *selection only*, pick the top `topk_group` expert
    groups by the sum of each group's top-2 biased scores, then pick `topk`
    experts inside the selected groups. Returned weights are the unbiased
    sigmoid scores of the chosen experts, renormalized per row.

    Args:
        scores: [N, E] router logits.
        bias: [E] per-expert correction bias (used for selection only).
        num_expert_group: number of expert groups G; E must be divisible by G.
        topk_group: number of groups kept per token.
        topk: number of experts kept per token.

    Returns:
        (topk_weights [N, topk], topk_ids [N, topk])
    """
    n, e = scores.shape
    scores_sig = scores.sigmoid()
    scores_for_choice = scores_sig + bias.unsqueeze(0)

    # Group selection: score each group by the sum of its top-2 experts.
    g = num_expert_group
    per_group = e // g
    view = scores_for_choice.view(n, g, per_group)
    top2 = torch.topk(view, k=2, dim=-1).values
    group_scores = top2.sum(dim=-1)  # [n, g]
    group_idx = torch.topk(
        group_scores, k=topk_group, dim=-1, sorted=False
    ).indices  # [n, topk_group]

    # Mask experts of unselected groups with -inf, then top-k over the rest.
    # Use a native bool mask instead of a float mask cast via .bool().
    group_mask = torch.zeros(n, g, dtype=torch.bool, device=scores.device)
    group_mask.scatter_(1, group_idx, True)
    score_mask = group_mask.unsqueeze(-1).expand(n, g, per_group).reshape(n, e)
    tmp_scores = scores_for_choice.masked_fill(~score_mask, float("-inf"))

    # Only the indices matter here; weights come from the unbiased scores.
    topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False).indices
    topk_weights = scores_sig.gather(1, topk_ids)

    # Renormalize; the epsilon guards against an all-zero row.
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-12)
    return topk_weights, topk_ids


# wrap reference with different compile modes
def make_ref_fn(compile_mode: str):
    """Return the PyTorch reference gate fn, optionally wrapped in torch.compile.

    Supported modes: "eager", "compile-static", "compile-dynamic".
    """
    base = biased_grouped_topk_ref_impl
    if compile_mode == "eager":
        return base
    if compile_mode in ("compile-static", "compile-dynamic"):
        return torch.compile(base, dynamic=(compile_mode == "compile-dynamic"))
    raise ValueError(f"Unknown compile_mode: {compile_mode}")


def moe_fused_gate_kernel(scores, bias, num_expert_group, topk_group, topk):
    # Thin wrapper over the sgl_kernel CUDA op so the benchmark can call the
    # kernel path with the same argument shape as the PyTorch reference fns.
    return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)


# Sequence-length sweep, consistent with the existing bench_moe_fused_gate style.
seq_length_range = [2048, 3072, 4096, 10240, 15360, 20480]

# Only configs that exercise the tiled kernel path (VPT = num_experts /
# num_expert_group > 32). Tuple layout:
# (seq_length, num_experts, num_expert_group, topk_group, topk).
configs = []
configs += [(sq, 64, 1, 1, 6) for sq in seq_length_range]  # Kimi VL: VPT = 64/1 = 64
configs += [(sq, 384, 1, 1, 8) for sq in seq_length_range]  # Kimi K2: VPT = 384/1 = 384


def _bench_template(dtype: torch.dtype, plot_suffix: str):
    """Build a triton perf_report benchmark for one dtype.

    Compares the PyTorch reference (eager / compile-static / compile-dynamic)
    against the fused CUDA kernel across the module-level `configs` sweep.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=[
                "seq_length",
                "num_experts",
                "num_expert_group",
                "topk_group",
                "topk",
            ],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=[
                "orig-eager",
                "orig-compile-static",
                "orig-compile-dynamic",
                "kernel",
            ],
            line_names=[
                "Original-Eager",
                "Original-Compile-Static",
                "Original-Compile-Dynamic",
                "SGL Kernel",
            ],
            styles=[("blue", "-"), ("green", "-"), ("orange", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"moe-fused-gate-tiled-performance-{plot_suffix}",
            args={},
        )
    )
    def benchmark(
        seq_length, num_experts, num_expert_group, topk_group, topk, provider
    ):
        dev = torch.device("cuda")
        gate_scores = torch.randn((seq_length, num_experts), device=dev, dtype=dtype)
        gate_bias = torch.rand(num_experts, device=dev, dtype=dtype)

        # Resolve the implementation under test from the provider label.
        if provider == "kernel":
            run_fn = moe_fused_gate_kernel
        elif provider.startswith("orig"):
            run_fn = make_ref_fn(provider.replace("orig-", ""))
        else:
            raise ValueError(provider)

        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: run_fn(
                gate_scores.clone(),
                gate_bias.clone(),
                num_expert_group,
                topk_group,
                topk,
            ),
            quantiles=[0.5, 0.2, 0.8],
        )

        # Report in microseconds.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


# One benchmark object per dtype; names are module-level for external use.
benchmark_bf16 = _bench_template(torch.bfloat16, "bf16")
benchmark_fp16 = _bench_template(torch.float16, "fp16")
benchmark_fp32 = _bench_template(torch.float32, "fp32")


if __name__ == "__main__":
    # Run all dtype sweeps back to back, printing the result tables.
    for bench in (benchmark_bf16, benchmark_fp16, benchmark_fp32):
        bench.run(print_data=True)
38 changes: 27 additions & 11 deletions sgl-kernel/csrc/moe/moe_fused_gate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include <cfloat>
#include <type_traits>
#include <vector>

#include "moe_fused_gate_tiled.h"
template <typename T, int N>
using AlignedArray = cutlass::AlignedArray<T, N>;
using bfloat16_t = cutlass::bfloat16_t;
Expand Down Expand Up @@ -398,8 +401,11 @@ std::vector<at::Tensor> moe_fused_gate(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);

// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
// Check 1: Ensure that num_experts is a power of 2, except allow 384 (Kimi K2) as a special case.
TORCH_CHECK(
((num_experts & (num_experts - 1)) == 0) || (num_experts == 384),
"num_experts must be a power of 2 or 384 (Kimi K2), but got ",
num_experts);

// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
TORCH_CHECK(
Expand All @@ -410,15 +416,8 @@ std::vector<at::Tensor> moe_fused_gate(
num_expert_group);

int computed_vpt = num_experts / num_expert_group;
// Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
// threads we can process.
TORCH_CHECK(
computed_vpt <= MAX_VPT,
"Per group experts: num_experts / num_expert_group = (",
computed_vpt,
") exceeds the maximum supported (",
MAX_VPT,
")");
// Ensure subgroup width fits within a warp for shuffle operations
TORCH_CHECK(num_expert_group <= WARP_SIZE, "num_expert_group must be <= ", WARP_SIZE, ", but got ", num_expert_group);

// Dispatch to templated kernel for known compile-time configurations.
// We currently only support for:
Expand All @@ -427,6 +426,18 @@ std::vector<at::Tensor> moe_fused_gate(
// Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
bool dispatched = false;
switch (num_experts) {
case 384:
if (num_expert_group == 1) {
// Static tiled specialization for THREADS_PER_ROW==1
LAUNCH_MOE_GATE_TILED_CONFIG(384, 1, 32);
}
break;
case 64:
if (num_expert_group == 1) {
// Static tiled specialization for THREADS_PER_ROW==1
LAUNCH_MOE_GATE_TILED_CONFIG(64, 1, 32);
}
break;
case 256:
if (num_expert_group == 8)
// This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
Expand Down Expand Up @@ -468,6 +479,11 @@ std::vector<at::Tensor> moe_fused_gate(
default:
break;
}
// If VPT exceeds native path (32), dispatch to tiled kernel which supports larger VPT
if (computed_vpt > MAX_VPT) {
return moe_fused_gate_tiled(
input, bias, num_expert_group, topk_group, topk, num_fused_shared_experts, routed_scaling_factor);
}
if (!dispatched) {
// Fallback to the dynamic kernel if none of the supported combinations match.
// currently only support num_experts / num_expert_group <= 32 for dynamic kernels
Expand Down
Loading