diff --git a/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py b/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
index f8c87d48db76..b199a72ed6c1 100644
--- a/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
+++ b/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
@@ -8,8 +8,8 @@ import triton.language as tl
 
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
+from sglang.srt.layers.attention.flashinfer_backend import should_use_tensor_core
 from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
-from sglang.srt.utils import should_use_tensor_core
 
 
 def benchmark_forward(
@@ -54,9 +54,10 @@ def decode_attention_sglang(
     v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
     o = torch.empty_like(q)
     total_tokens = batch_size * kv_len
-    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
-    b_req_idx = torch.arange(0, batch_size).to(0).int()
     b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1 : batch_size + 1] = torch.cumsum(b_seq_len[:batch_size], dim=0)
+    kv_indices = torch.arange(total_tokens, device="cuda")
     max_len_in_batch = kv_len
     sm_scale = 1.0 / (head_dim**0.5)
 
@@ -72,9 +73,8 @@ def decode_attention_sglang(
         k_buffer,
         v_buffer,
         o,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
@@ -86,9 +86,8 @@ def decode_attention_sglang(
         k_buffer,
         v_buffer,
         o,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
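
For reference, a minimal standalone sketch (not part of the patch; hypothetical small sizes and CPU tensors instead of the benchmark's CUDA tensors) of the CSR-style kv_indptr/kv_indices layout that replaces the dense req_to_token/b_req_idx mapping above: kv_indptr[i]:kv_indptr[i+1] delimits the slice of kv_indices that belongs to request i.

import torch

# Illustration only: small hypothetical sizes, CPU tensors.
batch_size, kv_len = 3, 4
total_tokens = batch_size * kv_len

# Old layout: dense (batch_size, kv_len) request-to-token map plus per-request lengths.
req_to_token = torch.arange(total_tokens).view(batch_size, kv_len)
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32)

# New layout: CSR-style indexing, mirroring the construction in the patch.
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1 : batch_size + 1] = torch.cumsum(b_seq_len[:batch_size], dim=0)
kv_indices = torch.arange(total_tokens)

# Both layouts describe the same ownership: request i owns the tokens in
# kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
for i in range(batch_size):
    assert torch.equal(kv_indices[kv_indptr[i] : kv_indptr[i + 1]], req_to_token[i])

print(kv_indptr.tolist())   # [0, 4, 8, 12]
print(kv_indices.tolist())  # [0, 1, ..., 11]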