Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 318 additions & 2 deletions python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class GetK:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
return cls.triton(*args, **kwargs)

@classmethod
def slow(
Expand Down Expand Up @@ -67,11 +67,28 @@ def torch_fast(
out = flat_buf[flat_indices]
return out.view(-1, 128)

@classmethod
def triton(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
Triton implementation for gathering K data from paged buffer.
:param page_indices: (num_pages,), int32/int64
:return: (seq_len, index_head_dim), uint8
"""
return _get_k_triton(
buf=buf,
page_indices=page_indices,
seq_len=seq_len,
page_size=pool.page_size,
index_head_dim=pool.index_head_dim,
)


class GetS:
@classmethod
def execute(cls, *args, **kwargs):
return cls.torch_fast(*args, **kwargs)
return cls.triton(*args, **kwargs)

@classmethod
def slow(
Expand Down Expand Up @@ -119,6 +136,48 @@ def torch_fast(
out = flat_buf[flat_indices]
return out.view(-1, 4)

@classmethod
def triton(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
Triton implementation for gathering S (scale) data from paged buffer.
:param page_indices: (num_pages,), int32/int64
:return: (seq_len, 4), uint8
"""
return _get_s_triton(
buf=buf,
page_indices=page_indices,
seq_len=seq_len,
page_size=pool.page_size,
index_head_dim=pool.index_head_dim,
)


class GetKAndS:
@classmethod
def execute(cls, *args, **kwargs):
return cls.triton(*args, **kwargs)

@classmethod
def triton(
cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
):
"""
Triton implementation for gathering both K and S data from paged buffer in a single call.
:param page_indices: (num_pages,), int32/int64
:return: tuple of (k_fp8, k_scale) where
k_fp8: (seq_len, index_head_dim), uint8
k_scale: (seq_len, 4), uint8
"""
return _get_k_and_s_triton(
buf=buf,
page_indices=page_indices,
seq_len=seq_len,
page_size=pool.page_size,
index_head_dim=pool.index_head_dim,
)


class SetK:
@classmethod
Expand Down Expand Up @@ -363,3 +422,260 @@ def _set_k_and_s_triton_kernel(

tl.store(buf_fp8_ptr + out_k_offsets, k)
tl.store(buf_fp32_ptr + out_s_offset, k_scale)


def _get_k_triton(
buf: torch.Tensor,
page_indices: torch.Tensor,
seq_len: int,
page_size: int,
index_head_dim: int,
):
"""
Gather K (key) data from paged buffer using Triton.

:param buf: (num_pages, page_size * 128 + page_size * 4), uint8
:param page_indices: (num_pages,), int32/int64
:param seq_len: int, number of tokens to gather
:param page_size: int, typically 64
:param index_head_dim: int, typically 128
:return: (seq_len, index_head_dim), uint8
"""
num_pages, buf_numel_per_page = buf.shape

# Allocate output
out = torch.empty((seq_len, index_head_dim), dtype=torch.uint8, device=buf.device)

# Launch kernel with one thread per token
grid = (seq_len,)
_get_k_triton_kernel[grid](
buf,
page_indices,
out,
seq_len,
page_size,
buf_numel_per_page,
index_head_dim,
BLOCK_SIZE=128,
)

return out


@triton.jit
def _get_k_triton_kernel(
buf_ptr,
page_indices_ptr,
out_ptr,
seq_len: tl.constexpr,
page_size: tl.constexpr,
buf_numel_per_page: tl.constexpr,
index_head_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Each program handles one token (seq_len tokens total).
Loads 128 bytes from the appropriate page.
"""
token_id = tl.program_id(0)

# Calculate which page and offset within page
page_idx = token_id // page_size
token_offset_in_page = token_id % page_size

# Load the page index from page_indices
page_index = tl.load(page_indices_ptr + page_idx)

# Calculate source offset in buf
# buf[page_index, token_offset_in_page * index_head_dim : ...]
src_base_offset = (
page_index * buf_numel_per_page + token_offset_in_page * index_head_dim
)

# Load 128 bytes (index_head_dim elements)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < index_head_dim
data = tl.load(buf_ptr + src_base_offset + offsets, mask=mask)

# Store to output
dst_offset = token_id * index_head_dim
tl.store(out_ptr + dst_offset + offsets, data, mask=mask)


def _get_s_triton(
buf: torch.Tensor,
page_indices: torch.Tensor,
seq_len: int,
page_size: int,
index_head_dim: int,
):
"""
Gather S (scale) data from paged buffer using Triton.

:param buf: (num_pages, page_size * 128 + page_size * 4), uint8
:param page_indices: (num_pages,), int32/int64
:param seq_len: int, number of tokens to gather
:param page_size: int, typically 64
:param index_head_dim: int, typically 128
:return: (seq_len, 4), uint8 (representing fp32 scale)
"""
num_pages, buf_numel_per_page = buf.shape
s_offset_in_page = page_size * index_head_dim # Scales start after K data

# Allocate output
out = torch.empty((seq_len, 4), dtype=torch.uint8, device=buf.device)

# Launch kernel with one thread per token
grid = (seq_len,)
_get_s_triton_kernel[grid](
buf,
page_indices,
out,
seq_len,
page_size,
buf_numel_per_page,
s_offset_in_page,
)

return out


@triton.jit
def _get_s_triton_kernel(
buf_ptr,
page_indices_ptr,
out_ptr,
seq_len: tl.constexpr,
page_size: tl.constexpr,
buf_numel_per_page: tl.constexpr,
s_offset_in_page: tl.constexpr,
):
"""
Each program handles one token (seq_len tokens total).
Loads 4 bytes (fp32 scale) from the appropriate page.
"""
token_id = tl.program_id(0)

# Calculate which page and offset within page
page_idx = token_id // page_size
token_offset_in_page = token_id % page_size

# Load the page index from page_indices
page_index = tl.load(page_indices_ptr + page_idx)

# Calculate source offset in buf
# Scales are stored after K data: page_size * index_head_dim offset
# buf[page_index, s_offset_in_page + token_offset_in_page * 4 : ...]
src_base_offset = (
page_index * buf_numel_per_page + s_offset_in_page + token_offset_in_page * 4
)

# Load 4 bytes (fp32 scale)
offsets = tl.arange(0, 4)
data = tl.load(buf_ptr + src_base_offset + offsets)

# Store to output
dst_offset = token_id * 4
tl.store(out_ptr + dst_offset + offsets, data)


def _get_k_and_s_triton(
buf: torch.Tensor,
page_indices: torch.Tensor,
seq_len: int,
page_size: int,
index_head_dim: int,
):
"""
Fused gather of both K (key) and S (scale) data from paged buffer using Triton.
This is more efficient than calling GetK and GetS separately.

:param buf: (num_pages, page_size * 128 + page_size * 4), uint8
:param page_indices: (num_pages,), int32/int64
:param seq_len: int, number of tokens to gather
:param page_size: int, typically 64
:param index_head_dim: int, typically 128
:return: tuple of (k_out, s_out) where
k_out: (seq_len, index_head_dim), uint8
s_out: (seq_len, 4), uint8
"""
num_pages, buf_numel_per_page = buf.shape
s_offset_in_page = page_size * index_head_dim # Scales start after K data

# Allocate outputs
k_out = torch.empty((seq_len, index_head_dim), dtype=torch.uint8, device=buf.device)
s_out = torch.empty((seq_len, 4), dtype=torch.uint8, device=buf.device)

# Launch kernel with one thread per token
grid = (seq_len,)
_get_k_and_s_triton_kernel[grid](
buf,
page_indices,
k_out,
s_out,
seq_len,
page_size,
buf_numel_per_page,
index_head_dim,
s_offset_in_page,
BLOCK_SIZE_K=128,
)

return k_out, s_out


@triton.jit
def _get_k_and_s_triton_kernel(
buf_ptr,
page_indices_ptr,
k_out_ptr,
s_out_ptr,
seq_len: tl.constexpr,
page_size: tl.constexpr,
buf_numel_per_page: tl.constexpr,
index_head_dim: tl.constexpr,
s_offset_in_page: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""
Fused kernel that gathers both K and S data in a single pass.
Each program handles one token (seq_len tokens total).
Loads 128 bytes (K) + 4 bytes (S) from the appropriate page.
"""
token_id = tl.program_id(0)

# Calculate which page and offset within page
page_idx = token_id // page_size
token_offset_in_page = token_id % page_size

# Load the page index from page_indices
page_index = tl.load(page_indices_ptr + page_idx)

# ===== Load K data (128 bytes) =====
# Calculate source offset for K in buf
k_src_base_offset = (
page_index * buf_numel_per_page + token_offset_in_page * index_head_dim
)

# Load 128 bytes (index_head_dim elements)
k_offsets = tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < index_head_dim
k_data = tl.load(buf_ptr + k_src_base_offset + k_offsets, mask=k_mask)

# Store K to output
k_dst_offset = token_id * index_head_dim
tl.store(k_out_ptr + k_dst_offset + k_offsets, k_data, mask=k_mask)

# ===== Load S data (4 bytes) =====
# Calculate source offset for S in buf
s_src_base_offset = (
page_index * buf_numel_per_page + s_offset_in_page + token_offset_in_page * 4
)

# Load 4 bytes (fp32 scale)
s_offsets = tl.arange(0, 4)
s_data = tl.load(buf_ptr + s_src_base_offset + s_offsets)

# Store S to output
s_dst_offset = token_id * 4
tl.store(s_out_ptr + s_dst_offset + s_offsets, s_data)
8 changes: 2 additions & 6 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,8 @@ def _get_topk_ragged(
for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens_cpu[i].item()
assert isinstance(seq_len, int)
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
layer_id,
seq_len,
block_tables[i],
)
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
# Use fused Triton kernel to get both K and scale in a single call
k_fp8, k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_buffer(
layer_id,
seq_len,
block_tables[i],
Expand Down
22 changes: 22 additions & 0 deletions python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,28 @@ def get_index_k_scale_continuous(
self, buf, seq_len=seq_len, page_indices=page_indices
)

def get_index_k_scale_buffer(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
"""
Fused method to get both index K and scale data in a single call using Triton.
More efficient than calling get_index_k_continuous and get_index_k_scale_continuous separately.

:param layer_id: Layer index
:param seq_len: Sequence length
:param page_indices: Page indices tensor
:return: tuple of (k_fp8, k_scale) where
k_fp8: (seq_len, index_head_dim), uint8
k_scale: (seq_len, 4), uint8
"""
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetKAndS.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)

def set_index_k_scale_buffer(
self,
layer_id: int,
Expand Down
Loading
Loading