[Kernel] Add JIT apply_rope_with_cos_sin_cache_inplace #18155

Merged
BBuf merged 8 commits into sgl-project:main from pansicheng:jit-rope on Feb 5, 2026
Conversation

@pansicheng
Collaborator

Motivation

Add JIT-compiled CUDA kernels for apply_rope_with_cos_sin_cache_inplace

Modifications

Accuracy Tests

Verified against apply_rope_with_cos_sin_cache_inplace from sgl_kernel
Check python/sglang/jit_kernel/tests/test_rope.py
Run python -m pytest python/sglang/jit_kernel/tests/test_rope.py -s
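
For reference, a parity check of this kind boils down to something like the sketch below; the tensor shapes, dtypes, and the call signature are assumptions, and the authoritative test is python/sglang/jit_kernel/tests/test_rope.py.

```python
# Hedged sketch of a RoPE parity check (shapes, dtypes, and the rope call
# signature are assumptions; see test_rope.py for the real test).
import torch

def rope_parity_check(jit_rope, ref_rope, num_tokens=128, num_q_heads=32,
                      num_kv_heads=8, head_dim=128):
    device = "cuda"
    positions = torch.randint(0, 4096, (num_tokens,), device=device, dtype=torch.int64)
    cos_sin_cache = torch.randn(4096, head_dim, device=device, dtype=torch.float32)
    q = torch.randn(num_tokens, num_q_heads * head_dim, device=device, dtype=torch.bfloat16)
    k = torch.randn(num_tokens, num_kv_heads * head_dim, device=device, dtype=torch.bfloat16)
    q_jit, k_jit = q.clone(), k.clone()

    # Both implementations rotate q/k in place; compare the mutated tensors.
    ref_rope(positions, q, k, head_dim, cos_sin_cache, is_neox=True)
    jit_rope(positions, q_jit, k_jit, head_dim, cos_sin_cache, is_neox=True)

    torch.testing.assert_close(q_jit, q, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(k_jit, k, rtol=1e-2, atol=1e-2)
```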

Accuracy test on gsm8k after applying the JIT apply_rope_with_cos_sin_cache_inplace in srt:

python3 -m sglang.launch_server --model-path /path/to/Qwen3-8B
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
Accuracy: 0.902
Invalid: 0.000
Latency: 161.404 s
Output throughput: 1073.649 token/s

Benchmarking and Profiling

A10

Performance Test - Batch=8, SeqLen=256
JIT: 0.039167404ms, SGL: 0.039143562ms
Speedup (SGL/JIT): 1.00x

Performance Test - Batch=8, SeqLen=512
JIT: 0.077519417ms, SGL: 0.077657700ms
Speedup (SGL/JIT): 1.00x

Performance Test - Batch=8, SeqLen=1024
JIT: 0.156660080ms, SGL: 0.157065392ms
Speedup (SGL/JIT): 1.00x
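
The timings above can be reproduced with a plain CUDA-event loop; below is a minimal sketch, assuming `jit_rope` and `sgl_rope` are the two callables being compared (the names and arguments are placeholders).

```python
# Minimal CUDA-event timing helper; the callables and their arguments are placeholders.
import torch

def bench_ms(fn, warmup=10, iters=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

# jit_ms = bench_ms(lambda: jit_rope(positions, q, k, head_dim, cos_sin_cache))
# sgl_ms = bench_ms(lambda: sgl_rope(positions, q, k, head_dim, cos_sin_cache))
# print(f"JIT: {jit_ms:.9f}ms, SGL: {sgl_ms:.9f}ms, Speedup (SGL/JIT): {sgl_ms / jit_ms:.2f}x")
```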

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@BBuf
Collaborator

BBuf commented Feb 3, 2026

/tag-and-rerun-ci

github-actions bot added the run-ci label on Feb 3, 2026

size_t k_rope_stride_h = k_rope.stride(1);

auto query_dtype = q.dtype();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
BBuf (Collaborator) commented on Feb 3, 2026:

@DarkSharpness Should we avoid using torch ATen in jit_kernel? How can we replace at::cuda::getCurrentCUDAStream() in tvm-ffi?

A collaborator replied:

Can we use const cudaStream_t stream = LaunchKernel::resolve_device(device); to replace the torch ATen CUDA stream?

pansicheng (Author) replied:

updated

@BBuf
Collaborator

BBuf commented Feb 4, 2026

/rerun-failed-ci


@cache_once
def _jit_apply_rope_pos_ids_cos_sin_cache_module() -> Module:
    import flashinfer
A collaborator commented:

Can we move the line to the top?

pansicheng (Author) replied:

fixed

import flashinfer

flashinfer_dir = pathlib.Path(flashinfer.__file__).parent.resolve()
assert (flashinfer_dir / "data" / "include").exists()
A collaborator commented:

Add a check message?

pansicheng (Author) replied:

fixed
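
For context, the requested check message amounts to turning the bare assert into something like the sketch below; the exact message wording is an assumption.

```python
import pathlib

import flashinfer

flashinfer_dir = pathlib.Path(flashinfer.__file__).parent.resolve()
include_dir = flashinfer_dir / "data" / "include"
# Fail with an actionable message instead of a bare AssertionError.
assert include_dir.exists(), (
    f"flashinfer include directory not found at {include_dir}; "
    "the installed flashinfer package may not ship its CUDA headers"
)
```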

# and torch.custom_op cannot express optional mutates_args reliably
@custom_op(
    "sgl_jit_kernel::apply_rope_pos_ids_cos_sin_cache_save_kv_cache",
    mutates_args="unknown",
A collaborator commented:

Why unknown here?

pansicheng (Author) replied:

fixed
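
The comment in the diff notes that optional mutated arguments are hard to express with torch.custom_op, which is presumably why mutates_args="unknown" was used. For reference, torch.library.custom_op also accepts an explicit list of mutated argument names, which lets torch.compile reason about the op instead of treating it conservatively; a sketch of that form follows, with an op name and parameter names that are purely illustrative rather than the PR's actual signature.

```python
import torch
from torch.library import custom_op

# Illustrative only: explicit mutates_args instead of "unknown".
# The op name and parameter names below are assumptions, not the PR's real signature.
@custom_op(
    "sgl_jit_kernel::example_rope_save_kv_cache",
    mutates_args=("q", "k", "k_cache", "v_cache"),
)
def example_rope_save_kv_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    pos_ids: torch.Tensor,
    cos_sin_cache: torch.Tensor,
) -> None:
    ...  # kernel launch elided
```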

@DarkSharpness
Collaborator

One small question: do we really need to upstream all the kernels from flashinfer?

In most cases, I guess we can directly use flashinfer kernels if our JIT kernels don't have a performance advantage.

@pansicheng
Collaborator Author

> One small question: do we really need to upstream all the kernels from flashinfer?
>
> In most cases, I guess we can directly use flashinfer kernels if our JIT kernels don't have a performance advantage.

It seems the goal is to build a fused RoPE + KV-cache-save kernel, so the BatchQKApplyRotaryPosIdsCosSinCacheEnhanced class is implemented. If KV-cache saving isn't needed, BatchQKApplyRotaryPosIdsCosSinCache should be able to call flashinfer directly from Python.

#9077
#9014
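
For the plain (no KV-cache save) path, calling flashinfer from Python would look roughly like the sketch below; the shapes and the argument order of flashinfer's apply_rope_with_cos_sin_cache_inplace are written from memory and should be checked against the installed version.

```python
import torch
import flashinfer

# Sketch: RoPE without the fused KV-cache save, via flashinfer's Python API.
# Shapes and argument order are assumptions; verify against the flashinfer docs.
num_tokens, num_q_heads, num_kv_heads, head_dim = 16, 32, 8, 128
positions = torch.arange(num_tokens, device="cuda", dtype=torch.int64)
q = torch.randn(num_tokens, num_q_heads * head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(num_tokens, num_kv_heads * head_dim, device="cuda", dtype=torch.float16)
cos_sin_cache = torch.randn(4096, head_dim, device="cuda", dtype=torch.float32)

flashinfer.apply_rope_with_cos_sin_cache_inplace(
    positions, q, k, head_dim, cos_sin_cache, is_neox=True
)
```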

@BBuf
Collaborator

BBuf commented Feb 5, 2026

/rerun-failed-ci

BBuf merged commit 2eb4359 into sgl-project:main on Feb 5, 2026
104 of 112 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 9, 2026
@yuan-luo
Collaborator

> One small question: do we really need to upstream all the kernels from flashinfer?
>
> In most cases, I guess we can directly use flashinfer kernels if our JIT kernels don't have a performance advantage.

@DarkSharpness Totally agree with you. Honestly speaking, I have observed some perf regression with the JIT kernel, which introduces unnecessary compilation. It may not be by design, but if the JIT kernel is not applied carefully, it adds more overhead than it saves. Take the following case as an example:

[image attachment]

Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026