[WIP] Add Triton kernel get_last_loc & assign_extend_cache_locs kernel #261
hw-csong wants to merge 2 commits into sgl-project:main from
Conversation
Summary of Changes: Hello @hw-csong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces two new Triton kernels aimed at optimizing memory cache operations for token management.
Code Review
This pull request introduces two new Triton kernels, get_last_loc and assign_extend_cache_locs. My review focuses on improving their performance and correctness.
For get_last_loc_kernel, I've suggested adding Triton's autotuning feature to find the optimal BLOCK_SIZE, as the current hardcoded value is very small and could lead to underutilization of the hardware.
For assign_extend_cache_locs, I've identified a critical performance issue in how offsets are calculated, which has quadratic complexity. I've recommended pre-calculating these offsets on the host. I also suggested improving its autotuning configuration to tune BLOCK_SIZE and to use a proper key to avoid unnecessary kernel recompilations.
length_offset = tl.arange(0, bs_upper)
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
start = start.to(tl.float32)
end = end.to(tl.float32)
out_offset = tl.sum(end - start, axis=0)
out_offset = out_offset.to(tl.int64)
The current implementation for calculating out_offset is highly inefficient. It computes a prefix sum within each kernel program, which results in a quadratic number of memory loads with respect to the number of requests (O(bs_upper^2) total loads). This will be a major performance bottleneck for larger batch sizes.
This calculation should be performed efficiently on the host before launching the kernel. You can compute the offsets using torch.cumsum and pass them as a new tensor argument to the kernel.
Example of host-side computation:

lengths = end_offset - start_offset
# Exclusive prefix sum to get the start offset for each item in the output.
out_offsets = torch.cumsum(lengths, dim=0) - lengths

Then, in the kernel, you would replace lines 26-32 with a simple load:

out_offset = tl.load(out_offsets_ptr + pid)
This would require changing the kernel signature to accept the out_offsets tensor.
Additionally, the type conversion to tl.float32 and back to tl.int64 on lines 29-32 is unnecessary, as tl.sum supports integer types. This should be removed.
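For illustration, here is a minimal sketch of what the reworked launch could look like, assuming the kernel gains an out_offsets pointer argument and is launched with one program per request; num_requests, the argument position of out_offsets, and the wrapper-side variable names here are assumptions for the sketch, not code from this PR:

# Host side (sketch): precompute the exclusive prefix sum once per batch.
lengths = end_offset - start_offset
out_offsets = torch.cumsum(lengths, dim=0) - lengths  # exclusive prefix sum

num_requests = req_pool_indices.shape[0]
assign_extend_cache_locs[(num_requests,)](
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_offsets,  # new tensor argument; replaces the per-program in-kernel reduction
    out_cache_loc,
    pool_len,
    bs_upper,
)

Inside the kernel, each program then reads its own starting position with a single tl.load(out_offsets + pid), so the total number of loads for the offsets becomes linear in the batch size.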
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 4
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)
The BLOCK_SIZE is hardcoded to a very small value of 4. This is likely to underutilize the hardware and lead to poor performance. It's better to use Triton's autotuning feature to find an optimal BLOCK_SIZE for different inputs.
To implement this, you should decorate the get_last_loc_kernel with @triton.autotune, define a set of BLOCK_SIZE values to test, and adjust the host-side wrapper get_last_loc_triton to launch the kernel correctly. This will require reordering the functions so that the kernel is defined before the wrapper function that calls it.
Here is a suggested implementation:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['num_tokens'],
)
@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens
    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)

    def grid(meta):
        return (triton.cdiv(num_tokens, meta['BLOCK_SIZE']),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
    )
    return result


@triton.autotune(
    configs=get_autotune_config(),
    key=[],
)
@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 32
The autotuning configuration can be significantly improved for better performance.
- Empty key: An empty key (key=[]) causes Triton to recompile the kernel for every new input shape, which is inefficient. The key should include arguments that affect the kernel's performance characteristics, such as tensor dimensions that are not known at compile time. For instance, you could use key=['bs_upper'].
- Hardcoded BLOCK_SIZE: The BLOCK_SIZE is hardcoded to 32 inside the kernel. This is a key parameter for performance and should be tuned. It should be moved to the autotuning configuration and passed as a constexpr argument to the kernel.
These changes would make the get_autotune_config function unnecessary. Here's a suggested implementation that applies these improvements:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 32}, num_warps=2, multibuffer=True),
        triton.Config({'BLOCK_SIZE': 64}, num_warps=2, multibuffer=True),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4, multibuffer=True),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4, multibuffer=True),
    ],
    key=['bs_upper'],
)
@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
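Once BLOCK_SIZE is supplied by the autotuner, the host-side launch must stop hardcoding it as well, mirroring the get_last_loc change above. Below is a minimal sketch of a matching launcher, assuming a wrapper named assign_extend_cache_locs_triton (hypothetical) that runs one program per request; the wrapper name, the use of req_to_token.stride(0) for pool_len, and the power-of-2 bound for bs_upper are assumptions for the sketch, not code from this PR:

def assign_extend_cache_locs_triton(
    req_pool_indices: torch.Tensor,
    req_to_token: torch.Tensor,
    start_offset: torch.Tensor,
    end_offset: torch.Tensor,
    out_cache_loc: torch.Tensor,
) -> None:
    num_requests = req_pool_indices.shape[0]
    # One program per request; BLOCK_SIZE is injected by @triton.autotune,
    # so it is neither hardcoded in the kernel nor passed here.
    grid = (num_requests,)
    assign_extend_cache_locs[grid](
        req_pool_indices,
        req_to_token,
        start_offset,
        end_offset,
        out_cache_loc,
        req_to_token.stride(0),                   # pool_len (assumed: row stride of req_to_token)
        triton.next_power_of_2(num_requests),     # bs_upper (assumed: power-of-2 bound for tl.arange)
    )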
No description provided.