[WIP] Add Triton kernel get_last_loc & assign_extend_cache_locs kernel #261

Draft
hw-csong wants to merge 2 commits into sgl-project:main from hw-csong:triton_kernel

Conversation

@hw-csong

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @hw-csong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces two new Triton kernels that optimize memory-cache operations for token management. The get_last_loc_kernel provides a fast way to identify the final token position for each request, while the assign_extend_cache_locs kernel handles allocating and extending cache locations for those tokens. These additions improve the efficiency of memory handling in contexts that require fast, flexible cache management, such as speculative decoding and large language model inference.

Highlights

  • New Triton Kernel for Last Token Location: Introduced get_last_loc_kernel using Triton to efficiently determine the last token's location for each request within a memory cache system.
  • New Triton Kernel for Cache Location Assignment: Added assign_extend_cache_locs Triton kernel, designed to manage and extend cache locations for tokens, likely supporting dynamic memory allocation strategies for speculative decoding or similar operations.
  • Triton Integration: Both new functionalities are implemented as Triton kernels, indicating a focus on high-performance, GPU-accelerated operations for memory management within the sgl_kernel_npu framework.



gemini-code-assist bot left a comment

Code Review

This pull request introduces two new Triton kernels, get_last_loc and assign_extend_cache_locs. My review focuses on improving their performance and correctness.

For get_last_loc_kernel, I've suggested adding Triton's autotuning feature to find the optimal BLOCK_SIZE, as the current hardcoded value is very small and could lead to underutilization of the hardware.

For assign_extend_cache_locs, I've identified a critical performance issue in how offsets are calculated, which has quadratic complexity. I've recommended pre-calculating these offsets on the host. I also suggested improving its autotuning configuration to tune BLOCK_SIZE and to use a proper key to avoid unnecessary kernel recompilations.

Comment on lines +26 to +32
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    start = start.to(tl.float32)
    end = end.to(tl.float32)
    out_offset = tl.sum(end - start, axis=0)
    out_offset = out_offset.to(tl.int64)

critical

The current implementation for calculating out_offset is highly inefficient. It computes a prefix sum within each kernel program, which results in a quadratic number of memory loads with respect to the number of requests (O(bs_upper^2) total loads). This will be a major performance bottleneck for larger batch sizes.

This calculation should be performed efficiently on the host before launching the kernel. You can compute the offsets using torch.cumsum and pass them as a new tensor argument to the kernel.

Example of host-side computation:

lengths = end_offset - start_offset
# Exclusive prefix sum to get the start offset for each item in the output.
out_offsets = torch.cumsum(lengths, dim=0) - lengths

Then, in the kernel, you would replace lines 26-32 with a simple load:
out_offset = tl.load(out_offsets_ptr + pid)

This would require changing the kernel signature to accept the out_offsets tensor.

Additionally, the type conversion to tl.float32 and back to tl.int64 on lines 29-32 is unnecessary, as tl.sum supports integer types. This should be removed.
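
For concreteness, a minimal numeric sketch of the exclusive prefix sum the reviewer suggests (the values are made up for illustration, and out_offsets is a hypothetical name for the new tensor argument):

import torch

start_offset = torch.tensor([0, 0, 4])
end_offset = torch.tensor([3, 5, 9])

lengths = end_offset - start_offset                   # tensor([3, 5, 5])
# Exclusive prefix sum: where each request's slice begins in the packed output.
out_offsets = torch.cumsum(lengths, dim=0) - lengths  # tensor([0, 3, 8])

Each kernel program would then fetch its own entry with a single tl.load(out_offsets_ptr + pid) instead of re-summing the lengths of all preceding requests.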

Comment on lines +5 to +47
def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 4
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result

@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)

high

The BLOCK_SIZE is hardcoded to a very small value of 4. This is likely to underutilize the hardware and lead to poor performance. It's better to use Triton's autotuning feature to find an optimal BLOCK_SIZE for different inputs.

To implement this, you should decorate the get_last_loc_kernel with @triton.autotune, define a set of BLOCK_SIZE values to test, and adjust the host-side wrapper get_last_loc_triton to launch the kernel correctly. This will require reordering the functions so that the kernel is defined before the wrapper function that calls it.

Here is a suggested implementation:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['num_tokens'],
)
@triton.jit
def get_last_loc_kernel(
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token_stride,
        BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)

def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)

    def grid(meta):
        return (triton.cdiv(num_tokens, meta['BLOCK_SIZE']),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
    )
    return result
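
As a quick sanity check, the Triton result can be compared against an eager PyTorch reference. This is an illustrative sketch, not part of the PR: the device string ("npu" for Ascend via torch_npu) and the toy shapes are assumptions, and get_last_loc_triton is assumed importable from the module this PR adds.

import torch

def get_last_loc_reference(req_to_token, req_pool_indices, prefix_lens):
    # Last prefix-token location per request; -1 when the request has no prefix.
    last_loc = torch.full_like(prefix_lens, -1)
    has_prefix = prefix_lens > 0
    last_loc[has_prefix] = req_to_token[
        req_pool_indices[has_prefix], prefix_lens[has_prefix] - 1
    ]
    return last_loc

req_to_token = torch.arange(8 * 16, device="npu").reshape(8, 16)
req_pool_indices = torch.tensor([2, 5, 7], device="npu")
prefix_lens = torch.tensor([0, 3, 16], device="npu")

expected = get_last_loc_reference(req_to_token, req_pool_indices, prefix_lens)
actual = get_last_loc_triton(req_to_token, req_pool_indices, prefix_lens)
assert torch.equal(actual, expected)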

Comment on lines +6 to +20
@triton.autotune(
    configs=get_autotune_config(),
    key=[],
)
@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 32

high

The autotuning configuration can be significantly improved for better performance.

  1. Empty key: With key=[], the autotuner benchmarks its configurations once and then reuses the winning config for every subsequent call, so the chosen configuration may be far from optimal when input shapes vary. The key should include arguments that affect the kernel's performance characteristics, such as tensor dimensions that are not known at compile time. For instance, you could use key=['bs_upper'].
  2. Hardcoded BLOCK_SIZE: The BLOCK_SIZE is hardcoded to 32 inside the kernel. This is a key parameter for performance and should be tuned. It should be moved to the autotuning configuration and passed as a constexpr argument to the kernel.

These changes would make the get_autotune_config function unnecessary. Here's a suggested implementation that applies these improvements:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 32}, num_warps=2, multibuffer=True),
        triton.Config({'BLOCK_SIZE': 64}, num_warps=2, multibuffer=True),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4, multibuffer=True),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4, multibuffer=True),
    ],
    key=['bs_upper'],
)
@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
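
With BLOCK_SIZE supplied by the autotuner, the call site no longer passes it, and get_autotune_config can be dropped. A hypothetical sketch of such a call site follows; the wrapper name, the grid of one program per request, and the use of triton.next_power_of_2 for bs_upper are assumptions (the actual launcher is not part of the quoted diff, and torch/triton imports are as in the files above):

def assign_extend_cache_locs_launch(
    req_pool_indices: torch.Tensor,
    req_to_token: torch.Tensor,
    start_offset: torch.Tensor,
    end_offset: torch.Tensor,
    out_cache_loc: torch.Tensor,
):
    bs = req_pool_indices.shape[0]
    # One program per request; the autotuner picks BLOCK_SIZE and num_warps.
    assign_extend_cache_locs[(bs,)](
        req_pool_indices,
        req_to_token,
        start_offset,
        end_offset,
        out_cache_loc,
        pool_len=req_to_token.shape[1],
        bs_upper=triton.next_power_of_2(bs),
    )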
