[Kernel] Add JIT apply_rope_with_cos_sin_cache_inplace #18155
BBuf merged 8 commits into sgl-project:main
Conversation
/tag-and-rerun-ci
size_t k_rope_stride_h = k_rope.stride(1);

auto query_dtype = q.dtype();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@DarkSharpness Should we avoid using torch ATen in jit_kernel? How can we replace at::cuda::getCurrentCUDAStream() in tvm-ffi?

Can we use const cudaStream_t stream = LaunchKernel::resolve_device(device); to replace the torch ATen CUDA stream?
/rerun-failed-ci
python/sglang/jit_kernel/rope.py (outdated)
@cache_once
def _jit_apply_rope_pos_ids_cos_sin_cache_module() -> Module:
    import flashinfer
Can we move this import to the top of the module?
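A minimal sketch of what the suggested change could look like, assuming flashinfer can safely be imported at module load time (if it is an optional dependency, the lazy in-function import may be intentional); the helper name below is hypothetical:

```python
import pathlib

import flashinfer  # hoisted from inside the cached builder, per the review suggestion


def _flashinfer_include_dir() -> pathlib.Path:
    """Hypothetical helper: locate FlashInfer's bundled headers for JIT compilation."""
    include_dir = pathlib.Path(flashinfer.__file__).parent.resolve() / "data" / "include"
    assert include_dir.exists(), f"FlashInfer headers not found at {include_dir}"
    return include_dir
```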
python/sglang/jit_kernel/rope.py (outdated)
import flashinfer

flashinfer_dir = pathlib.Path(flashinfer.__file__).parent.resolve()
assert (flashinfer_dir / "data" / "include").exists()
python/sglang/jit_kernel/rope.py (outdated)
# and torch.custom_op cannot express optional mutates_args reliably
@custom_op(
    "sgl_jit_kernel::apply_rope_pos_ids_cos_sin_cache_save_kv_cache",
    mutates_args="unknown",
One small question: do we really need to upstream all the kernels from flashinfer? In most cases, I guess we can directly use flashinfer kernels if our JIT kernels don't have a performance advantage.
It seems the goal is to build a fused RoPE + KV-cache-save kernel, which is why the BatchQKApplyRotaryPosIdsCosSinCacheEnhanced class is implemented. If KV-cache saving isn't needed, BatchQKApplyRotaryPosIdsCosSinCache should be able to call FlashInfer directly from Python.
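For that no-KV-cache-save path, here is a hedged sketch of calling FlashInfer directly from Python; the shapes are illustrative and the call assumes FlashInfer's apply_rope_with_cos_sin_cache_inplace(positions, query, key, head_size, cos_sin_cache, is_neox) signature:

```python
import torch
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

num_tokens, num_qo_heads, num_kv_heads, head_size = 16, 8, 2, 128

q = torch.randn(num_tokens, num_qo_heads * head_size, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="cuda")
positions = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
# cos_sin_cache: [max_position, head_size], cos in the first half, sin in the second.
cos_sin_cache = torch.randn(4096, head_size, dtype=torch.float32, device="cuda")

# Rotates q and k in place; no KV-cache write, so no fused kernel is needed here.
apply_rope_with_cos_sin_cache_inplace(positions, q, k, head_size, cos_sin_cache, is_neox=True)
```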
/rerun-failed-ci
@DarkSharpness Totally agree with you. Honestly speaking, I observed some perf regression for the JIT kernel, which introduces unnecessary compiling. It might not be implemented that way by design, but I'd say that if a JIT kernel is not applied well, it introduces more overhead than benefit. Take the following case for example:
(profiling screenshot)
Motivation
Add JIT-compiled CUDA kernels for apply_rope_with_cos_sin_cache_inplace
Modifications
Accuracy Tests
Verified against apply_rope_with_cos_sin_cache_inplace from sgl_kernel
Check python/sglang/jit_kernel/tests/test_rope.py
Run python -m pytest python/sglang/jit_kernel/tests/test_rope.py -s
Accuracy test on gsm8k after applying the JIT apply_rope_with_cos_sin_cache_inplace to srt
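A hypothetical sketch in the spirit of the accuracy test above (not the actual test file); the import paths and call signatures are assumptions and may differ from the PR:

```python
import torch


def test_jit_rope_matches_sgl_kernel():
    # Both import paths are assumed, for illustration only.
    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace as ref_op
    from sglang.jit_kernel.rope import apply_rope_with_cos_sin_cache_inplace as jit_op

    torch.manual_seed(0)
    num_tokens, head_size = 32, 128
    q_ref = torch.randn(num_tokens, 8 * head_size, dtype=torch.float16, device="cuda")
    k_ref = torch.randn(num_tokens, 2 * head_size, dtype=torch.float16, device="cuda")
    q_jit, k_jit = q_ref.clone(), k_ref.clone()
    positions = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
    cos_sin_cache = torch.randn(4096, head_size, dtype=torch.float32, device="cuda")

    # Run the reference kernel and the JIT kernel on identical inputs, then compare.
    ref_op(positions, q_ref, k_ref, head_size, cos_sin_cache, is_neox=True)
    jit_op(positions, q_jit, k_jit, head_size, cos_sin_cache, is_neox=True)

    torch.testing.assert_close(q_jit, q_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(k_jit, k_ref, rtol=1e-3, atol=1e-3)
```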
Benchmarking and Profiling
A10
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci