Fuse writing KV buffer into rope kernel (part 2: srt) #9014

zhyncs merged 116 commits into sgl-project:main from JeremieMelo:jiaqi/kv_rope_fuse
Conversation
Summary of Changes
Hello @JeremieMelo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've implemented a significant performance optimization by fusing the Key-Value (KV) cache buffer saving operation directly into the Rotary Positional Embedding (RoPE) computation. This change is specifically applied to the trtllm_mha attention mechanism. The primary goal is to reduce overhead and improve efficiency by performing these two related operations within a single, optimized kernel, rather than as separate steps. This involves modifications across the Python and C++ layers of the attention and rotary embedding components, ensuring seamless integration and leveraging CUDA stream synchronization for better performance, especially in CUDA graph capture mode.
Highlights
- **KV Buffer and RoPE Fusion:** Fused the `set_kv_buffer` operation, which saves key and value tensors to the KV cache, directly into the rotary positional embedding (RoPE) computation. This optimization specifically targets the `trtllm_mha` attention mechanism.
- **Streamlined KV Cache Updates:** The previous explicit calls to `set_kv_buffer` in the `trtllm_mha_backend` have been removed. The KV cache update now occurs implicitly as part of the `RotaryEmbedding`'s `forward_cuda` method, which calls a modified C++ kernel.
- **C++ Kernel Enhancements for Fused Operations:** The underlying C++ CUDA kernel for applying rotary embeddings (`apply_rope_pos_ids_cos_sin_cache`) has been enhanced to directly handle saving the processed K and V tensors to the KV cache buffers. This includes data type conversions, scaling, and optimized memory copies, particularly in CUDA graph capture mode, where an alternate stream is used for overlap.
- **Integration into Model Architecture:** The `GPTAttention` and `GPTLMHeadModel` classes have been updated to pass the necessary context (the attention layer, forward batch information, and the `start_layer` for memory pooling) to the rotary embedding function, enabling the fused operation.
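Conceptually, the change collapses two steps (RoPE, then a separate cache write) into one fused call. The sketch below illustrates this with hedged Python stand-ins: lists play the role of tensors, the "rotation" is a visible placeholder rather than real RoPE math, and all names are illustrative, not the actual sglang APIs.

```python
# Hedged stand-ins for the real sglang pieces; lists play the role of tensors.
def rotary_emb(q, k):
    # Placeholder rotation: negate k so the effect is observable (not real RoPE).
    return q, [-x for x in k]

def set_kv_buffer(k_buffer, v_buffer, loc, k, v):
    # Non-fused cache write: k_buffer[loc] = k, v_buffer[loc] = v.
    for i, slot in enumerate(loc):
        k_buffer[slot], v_buffer[slot] = k[i], v[i]

def fused_rope_and_save(q, k, v, k_buffer, v_buffer, loc):
    # After this PR, a single kernel performs both steps for trtllm_mha;
    # here the two Python calls stand in for that fused kernel.
    q, k = rotary_emb(q, k)
    set_kv_buffer(k_buffer, v_buffer, loc, k, v)
    return q, k
```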
Code Review
This pull request fuses the set_kv_buffer operation into the RoPE kernel for trtllm_mha attention, which is a good optimization. My review focuses on the correctness of the implementation. I've found a critical issue in the CUDA stream synchronization logic that could lead to race conditions or runtime errors, and a bug in how scaling factors are passed. I've also suggested some code cleanup by removing commented-out code blocks.
sgl-kernel/csrc/elementwise/rope.cu
Outdated
```cpp
if (is_capture_mode && alt_stream_ptr != 0) {
  cudaStream_t alt_stream = reinterpret_cast<cudaStream_t>(alt_stream_ptr);
  cudaStream_t main_stream = stream;

  // Wait for main stream to complete RoPE
  cudaStreamWaitEvent(alt_stream, nullptr, 0);

  // Copy K on main stream
  k_buffer_ptr.copy_(k_rope, /*non_blocking=*/true);

  // Copy V on alternate stream
  at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal(alt_stream, device.index()));
  v_buffer_ptr.copy_(v, /*non_blocking=*/true);

  // Main stream waits for alt stream
  cudaStreamWaitEvent(main_stream, nullptr, 0);
} else {
```
The stream synchronization logic here is incorrect and will lead to runtime errors or race conditions. cudaStreamWaitEvent is called with a nullptr event, which is invalid. You need to create and use a cudaEvent_t to properly synchronize between the main and alternate streams. The original Python code stream.wait_stream(other_stream) uses an event pool internally, which needs to be replicated here.
Here is a suggestion for a correct implementation:
```cpp
if (is_capture_mode && alt_stream_ptr != 0) {
  cudaStream_t alt_stream = reinterpret_cast<cudaStream_t>(alt_stream_ptr);
  cudaStream_t main_stream = stream;
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  // Wait for main stream to complete RoPE
  cudaEventRecord(event, main_stream);
  cudaStreamWaitEvent(alt_stream, event, 0);
  // Copy K on main stream
  k_buffer_ptr.copy_(k_rope, /*non_blocking=*/true);
  // Copy V on alternate stream
  {
    at::cuda::CUDAStreamGuard guard(at::cuda::getStreamFromExternal(alt_stream, device.index()));
    v_buffer_ptr.copy_(v, /*non_blocking=*/true);
  }
  // Main stream waits for alt stream
  cudaEventRecord(event, alt_stream);
  cudaStreamWaitEvent(main_stream, event, 0);
  cudaEventDestroy(event);
} else {
```
```python
# cache_loc = forward_batch.out_cache_loc
# if save_kv_cache and k is not None:
#     forward_batch.token_to_kv_pool.set_kv_buffer(
#         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
#     )
```
```python
## fused from memory_pool set_kv_buffer
"""
if layer_id_override is not None:
    layer_id = layer_id_override
else:
    layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
    if k_scale is not None:
        cache_k.div_(k_scale)
    if v_scale is not None:
        cache_v.div_(v_scale)
    cache_k = cache_k.to(self.dtype)
    cache_v = cache_v.to(self.dtype)

if self.store_dtype != self.dtype:
    cache_k = cache_k.view(self.store_dtype)
    cache_v = cache_v.view(self.store_dtype)

if get_is_capture_mode() and self.alt_stream is not None:
    # Overlap the copy of K and V cache for small batch size
    current_stream = self.device_module.current_stream()
    self.alt_stream.wait_stream(current_stream)
    self.k_buffer[layer_id - self.start_layer][loc] = cache_k
    with self.device_module.stream(self.alt_stream):
        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
    current_stream.wait_stream(self.alt_stream)
else:
    self.k_buffer[layer_id - self.start_layer][loc] = cache_k
    self.v_buffer[layer_id - self.start_layer][loc] = cache_v
"""
```
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
```python
forward_batch.token_to_kv_pool.set_kv_buffer(
    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
# cache_loc = forward_batch.out_cache_loc
```
(Don't forget to re-enable this code and add the branching at the correct places.)
```python
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
layer=layer,
forward_batch=forward_batch,
```
Maybe we should not pass such objects to this API. What about e.g.:

```python
def apply_rope_with_cos_sin_cache_inplace(
    ...,
    # in the non-fused version we do `k_buffer[loc] = data` etc.
    k_buffer: Tensor, v_buffer: Tensor, loc: Tensor,
)
```

If they are None, it means we do not save the KV cache; if non-None, then we need to save it.
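A minimal runnable sketch of this optional-buffer convention (pure Python, with lists as tensor stand-ins; the function name follows the suggestion above, everything else is illustrative):

```python
def apply_rope_with_cos_sin_cache_inplace(q, k, v, k_buffer=None, v_buffer=None, loc=None):
    # RoPE rotation omitted; this sketch only shows the None-vs-provided
    # convention for the cache buffers.
    if k_buffer is not None:
        # Fused path: also write K and V into the cache slots, mirroring the
        # non-fused `k_buffer[loc] = data` assignment.
        for i, slot in enumerate(loc):
            k_buffer[slot] = k[i]
            v_buffer[slot] = v[i]
    return q, k
```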
Confused: why do we have a new file instead of minor modifications?
I am unsure about the best practice for such modifications, e.g., ensuring compatibility. I currently use this style to avoid touching any old files/functions/logic, introducing fully standalone files to decouple. Any suggestion to make it more compatible/extensible for future updates is appreciated.
I personally think the fused part is only a dozen lines and thus can be inlined. Note that you can use C++ templates and `constexpr` if there are overheads.
```python
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
layer: Any = None,  # RadixAttention
forward_batch=None,
```
Maybe we should not put `layer`, `forward_batch`, etc. into such a low level. I personally suggest passing `k_buffer_ptr`, `v_buffer_ptr`, etc. down to here. If that is too verbose, maybe make a simple `@dataclass` to pack them.
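As a hedged illustration of that dataclass idea (the class and field names are assumptions for this sketch, not actual sglang types):

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class KVCacheWriteArgs:
    """Packs the low-level cache-write arguments so the rope API does not need
    `layer` or `forward_batch`. Hypothetical; names are illustrative."""
    k_buffer: Any                  # destination K cache tensor
    v_buffer: Any                  # destination V cache tensor
    loc: Any                       # slot indices (e.g. out_cache_loc)
    k_scale: Optional[float] = None
    v_scale: Optional[float] = None
```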
@JeremieMelo please fix the conflicts
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Motivation
Fuse `set_kv_buffer` into the sgl-kernel rope function, only for `trtllm_mha` attention.
(The notes below are from @fzyzcjy.)
Speed may be suboptimal (I have not done any ncu profiling or thorough optimization), but it is still faster than the non-fused version.
Accuracy looks good: 20B TP4, reasoning low, gpt_oss.evals gpqa: 55.1%.
Accuracy from @BBuf (he checked this agrees with main).
Speedup: @BBuf measured 345 -> 355.