[Move sgl-kernel Kernel to JIT] Add JIT concat MLA kernels #17889
Merged
Commits (7):
- 03e0b33 wip: add jit concat mla (celve)
- 5b35ae5 feat: add test (celve)
- 97084ed wip: different return style (celve)
- 299a76a wip: remove can use helper (celve)
- 8b3094e wip: add benchmark (celve)
- 79a1026 wip: align with utils (celve)
- c000169 wip: fix lint issues (celve)
New file (163 lines added): a Triton benchmark comparing the AOT (sgl_kernel), JIT (sglang.jit_kernel), and pure PyTorch implementations of the two concat MLA operations.

```python
import itertools

import torch
import triton
import triton.testing
from sgl_kernel import concat_mla_absorb_q as aot_absorb_q
from sgl_kernel import concat_mla_k as aot_k

from sglang.jit_kernel.benchmark.utils import is_in_ci
from sglang.jit_kernel.concat_mla import concat_mla_absorb_q as jit_absorb_q
from sglang.jit_kernel.concat_mla import concat_mla_k as jit_k

IS_CI = is_in_ci()

# Constants
NUM_LOCAL_HEADS = 128
QK_NOPE_HEAD_DIM = 128
QK_ROPE_HEAD_DIM = 64
K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM

A_LAST_DIM = 512
B_LAST_DIM = 64

DTYPE = torch.bfloat16
DEVICE = "cuda"


def aot_concat_mla_k(k, k_nope, k_rope):
    aot_k(k, k_nope, k_rope)


def jit_concat_mla_k(k, k_nope, k_rope):
    jit_k(k, k_nope, k_rope)


def torch_concat_mla_k(k, k_nope, k_rope):
    # PyTorch reference: copy k_nope and broadcast k_rope across all heads.
    nope_head_dim = k_nope.shape[-1]
    k[:, :, :nope_head_dim] = k_nope
    k[:, :, nope_head_dim:] = k_rope.expand(-1, k.shape[1], -1)


def aot_concat_mla_absorb_q(a, b):
    return aot_absorb_q(a, b)


def jit_concat_mla_absorb_q(a, b):
    return jit_absorb_q(a, b)


def torch_concat_mla_absorb_q(a, b, out):
    # PyTorch reference: concatenate a and b along the last dimension into out.
    a_last_dim = a.shape[-1]
    out[:, :, :a_last_dim] = a
    out[:, :, a_last_dim:] = b


# Use a smaller sweep in CI.
if IS_CI:
    NUM_TOKENS_VALS = [256, 1024]
else:
    NUM_TOKENS_VALS = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]

K_LINE_VALS = ["aot", "jit", "torch"]
K_LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch"]
K_STYLES = [("orange", "-"), ("blue", "--"), ("green", "-.")]


def _create_concat_mla_k_data(num_tokens):
    """Allocate oversized containers and slice to produce non-contiguous tensors."""
    k_nope_container = torch.randn(
        (num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM + 128),
        dtype=DTYPE,
        device=DEVICE,
    )
    k_nope = k_nope_container[:, :, :QK_NOPE_HEAD_DIM]

    k_rope_container = torch.randn(
        (num_tokens, 1, 128 + QK_ROPE_HEAD_DIM),
        dtype=DTYPE,
        device=DEVICE,
    )
    k_rope = k_rope_container[:, :, -QK_ROPE_HEAD_DIM:]

    k = torch.empty(
        (num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM),
        dtype=DTYPE,
        device=DEVICE,
    )
    return k, k_nope, k_rope


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=NUM_TOKENS_VALS,
        line_arg="provider",
        line_vals=K_LINE_VALS,
        line_names=K_LINE_NAMES,
        styles=K_STYLES,
        ylabel="us",
        plot_name="concat-mla-k-performance",
        args={},
    )
)
def bench_concat_mla_k(num_tokens: int, provider: str):
    k, k_nope, k_rope = _create_concat_mla_k_data(num_tokens)

    FN_MAP = {
        "aot": aot_concat_mla_k,
        "jit": jit_concat_mla_k,
        "torch": torch_concat_mla_k,
    }
    fn = lambda: FN_MAP[provider](k, k_nope, k_rope)
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
    # do_bench_cudagraph reports milliseconds; convert to microseconds for the plot.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if IS_CI:
    ABSORB_Q_VALS = list(itertools.product([4, 16], [16]))
else:
    ABSORB_Q_VALS = list(itertools.product([1, 4, 8, 16, 32], [1, 8, 32, 128]))

Q_LINE_VALS = ["aot", "jit", "torch"]
Q_LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch"]
Q_STYLES = [("orange", "-"), ("blue", "--"), ("green", "-.")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["dim_0", "dim_1"],
        x_vals=ABSORB_Q_VALS,
        line_arg="provider",
        line_vals=Q_LINE_VALS,
        line_names=Q_LINE_NAMES,
        styles=Q_STYLES,
        ylabel="us",
        plot_name="concat-mla-absorb-q-performance",
        args={},
    )
)
def bench_concat_mla_absorb_q(dim_0: int, dim_1: int, provider: str):
    a = torch.randn(dim_0, dim_1, A_LAST_DIM, dtype=DTYPE, device=DEVICE)
    b = torch.randn(dim_0, dim_1, B_LAST_DIM, dtype=DTYPE, device=DEVICE)

    if provider == "torch":
        # The PyTorch reference writes into a preallocated output tensor.
        out = torch.empty(
            dim_0, dim_1, A_LAST_DIM + B_LAST_DIM, dtype=DTYPE, device=DEVICE
        )
        fn = lambda: torch_concat_mla_absorb_q(a, b, out)
    else:
        FN_MAP = {
            "aot": aot_concat_mla_absorb_q,
            "jit": jit_concat_mla_absorb_q,
        }
        fn = lambda: FN_MAP[provider](a, b)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    bench_concat_mla_k.run(print_data=True)
    bench_concat_mla_absorb_q.run(print_data=True)
```
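The benchmark above measures latency only. As a complement (not part of this PR), a minimal cross-check that the three providers agree could reuse the same data helper. The sketch below assumes it lives next to the benchmark definitions above and that a CUDA device plus both the AOT and JIT kernels are available; `check_concat_mla_k` is a hypothetical helper name.

```python
def check_concat_mla_k(num_tokens: int = 256) -> None:
    """Sanity check (sketch): AOT, JIT, and PyTorch paths should produce identical k."""
    outputs = {}
    providers = [
        ("aot", aot_concat_mla_k),
        ("jit", jit_concat_mla_k),
        ("torch", torch_concat_mla_k),
    ]
    for name, fn in providers:
        # Re-seed so every provider sees the same random k_nope / k_rope inputs.
        torch.manual_seed(0)
        k, k_nope, k_rope = _create_concat_mla_k_data(num_tokens)
        fn(k, k_nope, k_rope)
        outputs[name] = k
    torch.testing.assert_close(outputs["aot"], outputs["torch"])
    torch.testing.assert_close(outputs["jit"], outputs["torch"])
```

Because every provider writes all elements of the preallocated k, re-seeding before each call makes the outputs directly comparable.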
New file (65 lines added): the sglang.jit_kernel.concat_mla module, which loads the CUDA kernels via load_jit (cached with cache_once) and exposes the two Python entry points.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit

if TYPE_CHECKING:
    from tvm_ffi.module import Module


@cache_once
def _jit_concat_mla_k_module() -> Module:
    return load_jit(
        "concat_mla_k",
        cuda_files=["elementwise/concat_mla.cuh"],
        cuda_wrappers=[("concat_mla_k", "ConcatMlaKKernel::run")],
    )


@cache_once
def _jit_concat_mla_absorb_q_module() -> Module:
    return load_jit(
        "concat_mla_absorb_q",
        cuda_files=["elementwise/concat_mla.cuh"],
        cuda_wrappers=[("concat_mla_absorb_q", "ConcatMlaAbsorbQKernel::run")],
    )


def concat_mla_k(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None:
    """
    Concatenate k_nope and k_rope into k for MLA (Multi-head Latent Attention).

    This kernel efficiently broadcasts k_rope across all heads while copying
    k_nope values directly.

    Args:
        k: Output tensor of shape [num_tokens, num_heads=128, k_head_dim=192], dtype=bfloat16
        k_nope: Input tensor of shape [num_tokens, num_heads=128, nope_head_dim=128], dtype=bfloat16
        k_rope: Input tensor of shape [num_tokens, 1, rope_head_dim=64], dtype=bfloat16
    """
    module = _jit_concat_mla_k_module()
    module.concat_mla_k(k, k_nope, k_rope)


def concat_mla_absorb_q(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Concatenate tensors a and b for MLA absorbed Q computation.

    Args:
        a: Input tensor of shape [dim_0, dim_1, a_last_dim], dtype=bfloat16
        b: Input tensor of shape [dim_0, dim_1, b_last_dim], dtype=bfloat16

    Returns:
        Output tensor of shape [dim_0, dim_1, a_last_dim + b_last_dim], dtype=bfloat16
    """
    out = torch.empty(
        (*a.shape[:-1], a.shape[-1] + b.shape[-1]),
        dtype=a.dtype,
        device=a.device,
    )
    module = _jit_concat_mla_absorb_q_module()
    module.concat_mla_absorb_q(a, b, out)
    return out
```
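For orientation, a call pattern along these lines exercises both entry points with the shapes documented in the docstrings above. The torch.cat expressions are only reference semantics for comparison; this is a sketch, not code from the PR, and it assumes a CUDA device with the JIT kernels built.

```python
import torch

from sglang.jit_kernel.concat_mla import concat_mla_absorb_q, concat_mla_k

num_tokens, num_heads, nope_dim, rope_dim = 32, 128, 128, 64
dtype, device = torch.bfloat16, "cuda"

# concat_mla_k fills a preallocated output tensor in place.
k_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device=device)
k_rope = torch.randn(num_tokens, 1, rope_dim, dtype=dtype, device=device)
k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device=device)
concat_mla_k(k, k_nope, k_rope)

# Reference semantics: broadcast k_rope across heads, then concatenate on the last dim.
ref_k = torch.cat([k_nope, k_rope.expand(-1, num_heads, -1)], dim=-1)
torch.testing.assert_close(k, ref_k)

# concat_mla_absorb_q allocates and returns its output.
a = torch.randn(4, 16, 512, dtype=dtype, device=device)
b = torch.randn(4, 16, 64, dtype=dtype, device=device)
q = concat_mla_absorb_q(a, b)
torch.testing.assert_close(q, torch.cat([a, b], dim=-1))
```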