Merged
12 changes: 6 additions & 6 deletions sgl-kernel/csrc/elementwise/concat_mla.cu
@@ -18,11 +18,11 @@ __global__ void concat_mla_k_kernel(
     const nv_bfloat16* __restrict__ k_nope,
     const nv_bfloat16* __restrict__ k_rope,
     const int num_tokens,
-    const int k_stride_0,
+    const int64_t k_stride_0,
     const int k_stride_1,
-    const int k_nope_stride_0,
+    const int64_t k_nope_stride_0,
     const int k_nope_stride_1,
-    const int k_rope_stride_0) {
+    const int64_t k_rope_stride_0) {
   const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
   const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
   const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
@@ -126,11 +126,11 @@ __global__ void concat_mla_absorb_q_kernel(
     nv_bfloat16* out,
     const int num_items,
     const int dim_1,
-    const int a_stride_0,
+    const int64_t a_stride_0,
     const int a_stride_1,
-    const int b_stride_0,
+    const int64_t b_stride_0,
     const int b_stride_1,
-    const int out_stride_0,
+    const int64_t out_stride_0,
     const int out_stride_1) {
Comment on lines +129 to 134
Contributor

high

This change correctly prevents integer overflow for large inputs by using int64_t for strides.

A similar potential overflow issue seems to exist in concat_mla_k_kernel. The strides k_stride_0, k_nope_stride_0, and k_rope_stride_0 are defined as int (lines 21, 23, 25), but their product with token_id inside the kernel could overflow if num_tokens is large.

For example, with k_stride_0 being 24576, the 32-bit product token_id * k_stride_0 overflows once token_id exceeds 2,147,483,647 / 24576 ~= 87,381. While current tests might not reach this limit, it's a potential failure point for longer sequences.

It would be beneficial to also change these stride types to int64_t to prevent this issue and ensure consistency across the kernels.
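
To make the threshold concrete, here is a minimal host-side sketch of the offset arithmetic. It is not the kernel code: the helper names (`offset64`, `overflows_int32`, `first_overflowing_token`) are hypothetical and only illustrate why widening the stride parameter alone is enough, since `int * int64_t` is evaluated in 64 bits.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: element offset of a token's row.
// Because stride_0 is int64_t, the usual arithmetic conversions
// promote token_id to int64_t before the multiply.
constexpr int64_t offset64(int token_id, int64_t stride_0) {
    return token_id * stride_0;
}

// True if the same offset would not fit in a 32-bit int,
// i.e. the case where an all-int multiply would overflow.
constexpr bool overflows_int32(int token_id, int64_t stride_0) {
    return offset64(token_id, stride_0) > std::numeric_limits<int32_t>::max();
}

// Smallest token_id whose offset exceeds INT32_MAX for a given positive stride.
constexpr int first_overflowing_token(int64_t stride_0) {
    return static_cast<int>(std::numeric_limits<int32_t>::max() / stride_0) + 1;
}
```

With a stride of 24576, `first_overflowing_token` gives the first token index at which a 32-bit offset would overflow, while `offset64` keeps returning the correct 64-bit offset past that point.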

Collaborator

@bingps Agreed, it would also be beneficial to change the stride datatype in concat_mla_k_kernel.
Can you please fix it here?

Contributor Author

@Fridge003 Updated~ Tests are in the following comment.

const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
const int lane_id = get_lane_id();