[Feature] Add tile-based method of supporting large VPT in moe_fused_gate kernel #6946
Closed
Changes from all commits (62 commits)
21c9497
add support the number of group expert from the original 32 to 128 im…
ltaodream 93e5ad1
modify the MAX_VPT limit in topk.py
ltaodream 68ff1d3
Fix the size of row_chunk and bias_chunk to 32 to prevent register ov…
ltaodream 5fbde7f
add test and cycle fusion Optimization
ltaodream 92f418b
add bench_moe_fused_gate for kimi-vl
ltaodream 17d13a6
Merge branch 'main' into main
BBuf 46a93f4
Merge branch 'sgl-project:main' into main
ltaodream 0f908a6
[Kimi K2] add support the number of group expert from the original 32…
ltaodream 3c36fa3
Merge branch 'main' into main
BBuf 8126d99
fix prefetch_global, fix alignment error with cutlass alignedarray
ltaodream 8672b60
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream 88e1d30
fix prefetch
ltaodream 9cbd2d1
fix kimi k2 config
ltaodream 407ea8a
feat: Optimize MoE routing with efficient shared memory CUDA kernel
ltaodream c909014
test: Enhance MoE routing tests with comprehensive group selection ve…
ltaodream c431d17
add test_moe_fused_gate param
ltaodream 5aa6f04
Merge branch 'sgl-project:main' into main
ltaodream c9b2738
fix: Adjust MOE routing test tolerances for different expert configur…
ltaodream 33c8651
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream 9c0ee2b
Merge branch 'main' into main
BBuf 98abd24
fix: Optimize MOE Fused Gate kernel to use non-tiled path for VPT≤32 …
ltaodream 5944ff4
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream 585c3ac
fix: correct function call from biased_grouped_topk to biased_grouped…
ltaodream 7dd53f6
feat: use native topk for single expert group with group size > 128, …
ltaodream cdec8aa
fix
ltaodream eb92d47
Merge branch 'main' into main
BBuf a464ac8
fix lint issues
ltaodream c9eb5f6
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream 60985c2
Merge branch 'main' into main
ltaodream 0c7cb09
fix lint
ltaodream e29454c
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream 6e14a95
Merge branch 'main' into main
ltaodream 01e75e4
Merge branch 'main' into main
BBuf 513e1e0
Refactor dynamic kernel launch code to remove duplication between sma…
ltaodream cfead57
Merge branch 'main' of ssh://github.com/ltaodream/sglang
ltaodream 71894df
Refactor dynamic kernel launch logic for moe_fused_gate, unify macro
ltaodream 5dac4cd
fix
ltaodream bfa5a7b
Recovery
ltaodream ef890cc
Consolidate duplicated logic for small and large VPT cases into a uni…
ltaodream 93636fb
fix tile
ltaodream 4d56ecc
pre-commit
ltaodream b7f875a
fix
ltaodream 5072b14
fix index out of bounds
ltaodream 3d5368d
fix num_expert_group=1
ltaodream af449a1
fix Expert repeated selection
ltaodream 876a5a2
pre-commit run
ltaodream c55adda
fix clear bias
ltaodream cf186d8
upd moe_fuse_gate_tiling_more_experts v1
2c4b629
upd
22545d4
fix grouptopk only set 1-shared_experts problem
d26f836
upd
6c0bddb
add tiled dynamic benchmark
17f6813
upd pytest style
ttaohe 71f42c3
add bench_moe_fused_gate_tiled script
ttaohe 4530634
add specific examples like kimi-vl and kimi-k2 moe config
ttaohe 5fe73f8
add kimi-vl 64
ttaohe bdec8b3
bench tile add more test case
ttaohe 11b9570
change specific style
ttaohe fef21b5
commit last bench test
ttaohe 29ed296
Merge pull request #1 from ttaohe/main
ltaodream 08752c9
pre-commit run --all-files
ltaodream 82f6f48
Merge branch 'main' into main
ltaodream
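For context: per the commit history above, the kernel keeps the original non-tiled path for VPT ≤ 32 and takes the new tile-based path only above that. A minimal sketch of that dispatch decision, assuming (for illustration only) that VPT is the number of experts each thread scans, i.e. `num_experts / num_expert_group` — the names here are hypothetical, not the kernel's actual symbols:

```python
def choose_kernel_path(num_experts: int, num_expert_group: int, max_vpt: int = 32) -> str:
    # VPT (values per thread): experts handled per thread when one thread
    # covers one expert group (simplifying assumption for this sketch).
    vpt = num_experts // num_expert_group
    # VPT <= 32 keeps the original register-resident path; larger VPT
    # would overflow registers, so the tile-based path is used instead.
    return "tiled" if vpt > max_vpt else "non-tiled"
```

Under this assumption, the two benchmark configs below (Kimi VL with 64 experts in 1 group, Kimi K2 with 384 experts in 1 group) both land on the tiled path.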
New benchmark script (bench_moe_fused_gate_tiled):

```python
import itertools
import math

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from sgl_kernel import moe_fused_gate


def biased_grouped_topk_ref_impl(scores, bias, num_expert_group, topk_group, topk):
    # Pure PyTorch reference to avoid implicit kernel paths and control compile modes.
    # Logic mirrors biased_grouped_topk_impl without shared experts handling (set to 0 for bench).
    # scores: [N, E], bias: [E]
    n, e = scores.shape
    scores_sig = scores.sigmoid()
    scores_for_choice = scores_sig + bias.unsqueeze(0)

    # group selection via top2 sum
    g = num_expert_group
    per_group = e // g
    view = scores_for_choice.view(n, g, per_group)
    top2 = torch.topk(view, k=2, dim=-1).values
    group_scores = top2.sum(dim=-1)  # [n, g]
    group_idx = torch.topk(
        group_scores, k=topk_group, dim=-1, sorted=False
    ).indices  # [n, topk_group]

    # mask and topk within selected groups
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(n, g, per_group).reshape(n, e)
    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))

    topk_vals, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    topk_weights = scores_sig.gather(1, topk_ids)

    # renormalize
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-12)
    return topk_weights, topk_ids


# wrap reference with different compile modes
def make_ref_fn(compile_mode: str):
    fn = biased_grouped_topk_ref_impl
    if compile_mode == "eager":
        return fn
    if compile_mode == "compile-static":
        return torch.compile(fn, dynamic=False)
    if compile_mode == "compile-dynamic":
        return torch.compile(fn, dynamic=True)
    raise ValueError(f"Unknown compile_mode: {compile_mode}")


def moe_fused_gate_kernel(scores, bias, num_expert_group, topk_group, topk):
    return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)


# Choose a sequence length sweep consistent with existing benchmark style
seq_length_range = [2048, 3072, 4096, 10240, 15360, 20480]

# Focus on tiled-path configs (VPT > 32)
configs = []
configs += [(sq, 64, 1, 1, 6) for sq in seq_length_range]  # Kimi VL: VPT=64
configs += [(sq, 384, 1, 1, 8) for sq in seq_length_range]  # Kimi K2: VPT=384


def _bench_template(dtype: torch.dtype, plot_suffix: str):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=[
                "seq_length",
                "num_experts",
                "num_expert_group",
                "topk_group",
                "topk",
            ],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=[
                "orig-eager",
                "orig-compile-static",
                "orig-compile-dynamic",
                "kernel",
            ],
            line_names=[
                "Original-Eager",
                "Original-Compile-Static",
                "Original-Compile-Dynamic",
                "SGL Kernel",
            ],
            styles=[("blue", "-"), ("green", "-"), ("orange", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"moe-fused-gate-tiled-performance-{plot_suffix}",
            args={},
        )
    )
    def benchmark(
        seq_length, num_experts, num_expert_group, topk_group, topk, provider
    ):
        device = torch.device("cuda")

        scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
        bias = torch.rand(num_experts, device=device, dtype=dtype)

        quantiles = [0.5, 0.2, 0.8]

        if provider.startswith("orig"):
            mode = provider.replace("orig-", "")
            ref = make_ref_fn(mode)
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: ref(
                    scores.clone(), bias.clone(), num_expert_group, topk_group, topk
                ),
                quantiles=quantiles,
            )
        elif provider == "kernel":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: moe_fused_gate_kernel(
                    scores.clone(), bias.clone(), num_expert_group, topk_group, topk
                ),
                quantiles=quantiles,
            )
        else:
            raise ValueError(provider)

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


benchmark_bf16 = _bench_template(torch.bfloat16, "bf16")
benchmark_fp16 = _bench_template(torch.float16, "fp16")
benchmark_fp32 = _bench_template(torch.float32, "fp32")


if __name__ == "__main__":
    benchmark_bf16.run(print_data=True)
    benchmark_fp16.run(print_data=True)
    benchmark_fp32.run(print_data=True)
```
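The reference implementation scores each expert group by the sum of its two highest biased scores and keeps the best `topk_group` groups before the final per-expert top-k. That group-selection step can be sketched in plain Python (a torch-free illustration of the logic for one token, not the kernel's actual code):

```python
def select_groups(scores, num_expert_group, topk_group):
    """Pick topk_group groups, scoring each group by the sum of its top-2 experts.

    scores: flat list of per-expert scores for one token (bias already added).
    Returns the indices of the selected groups, in ascending order.
    """
    per_group = len(scores) // num_expert_group
    group_scores = []
    for g in range(num_expert_group):
        # top-2 sum within this group's contiguous slice of experts
        chunk = sorted(scores[g * per_group : (g + 1) * per_group], reverse=True)
        group_scores.append(sum(chunk[:2]))
    # indices of the topk_group highest-scoring groups
    ranked = sorted(range(num_expert_group), key=lambda g: group_scores[g], reverse=True)
    return sorted(ranked[:topk_group])
```

Experts in non-selected groups are then masked to `-inf` so the final top-k only ever picks from the surviving groups, which is what the `masked_fill` in the reference does.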
Why is the kimi config commented out?
I uncomment these lines when I want to benchmark one of the three configurations separately.