Skip to content

Commit 7ed5427

Browse files
committed
rebase code
1 parent 5862faf commit 7ed5427

3 files changed

Lines changed: 1 addition & 659 deletions

File tree

python/sglang/srt/layers/attention/nsa/index_buf_accessor.py

Lines changed: 1 addition & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -186,32 +186,6 @@ def triton(
186186
index_head_dim=pool.index_head_dim,
187187
)
188188

189-
class GetKEAndKS:
    """Accessor that produces the per-token ``ks`` / ``ke`` index tensors.

    ``execute`` is the public entry point; it simply dispatches to the
    Triton-backed implementation.
    """

    @classmethod
    def execute(cls, *args, **kwargs):
        """Forward every argument unchanged to :meth:`triton`."""
        return cls.triton(*args, **kwargs)

    @classmethod
    def triton(
        cls,
        pool: "NSATokenToKVPool",
        extend_sum_seq_len: int,
        seq_lens_tensor: torch.Tensor,
        extend_seq_lens_tensor: torch.Tensor,
        seq_lens_expanded_tensor: torch.Tensor,
    ):
        """
        Triton implementation that generates ks and ke for the whole batch in a single call.

        :param pool: the KV pool the accessor is invoked on; not read by this method
        :param extend_sum_seq_len: total number of extend tokens in the batch
        :param seq_lens_tensor: forwarded as ``seq_lens``
        :param extend_seq_lens_tensor: forwarded as ``extend_seq_lens``
        :param seq_lens_expanded_tensor: forwarded as ``seq_lens_expanded``
        :return: the ``(ks, ke)`` pair returned by ``_get_ke_and_ks_triton``
        """
        launch_kwargs = {
            "extend_sum_seq_len": extend_sum_seq_len,
            "seq_lens": seq_lens_tensor,
            "extend_seq_lens": extend_seq_lens_tensor,
            "seq_lens_expanded": seq_lens_expanded_tensor,
        }
        return _get_ke_and_ks_triton(**launch_kwargs)
213-
214-
215189
class SetK:
216190
@classmethod
217191
def execute(cls, *args, buf, **kwargs):
@@ -740,108 +714,4 @@ def _get_k_and_s_triton_kernel(
740714

741715
# Store S to output
742716
s_dst_offset = token_id * 4
743-
tl.store(s_out_ptr + s_dst_offset + s_offsets + s_offset_batch, s_data, mask=s_mask)
744-
745-
746-
def _get_ke_and_ks_triton(
747-
extend_sum_seq_len: int,
748-
seq_lens: torch.Tensor,
749-
extend_seq_lens: torch.Tensor,
750-
seq_lens_expanded: torch.Tensor
751-
):
752-
"""
753-
Fused gather of both K (key) and S (scale) data from paged buffer using Triton.
754-
This is more efficient than calling GetK and GetS separately.
755-
for example:
756-
seq_lens = [20, 30, 40, 50]
757-
extend_seq_lens = [10, 15, 20, 10]
758-
seq_lens_expanded = [10,11,...,19(bs0),15,16,...,29(bs1),40,41,...,49]
759-
760-
prefix_sum = [0,10,25,45,55]
761-
seq_lens_sum = [0, 20, 50, 90]
762-
763-
:param extend_sum_seq_len: sum of all extend sequence len, int32
764-
:param seq_lens: (num_pages, page_size * 128 + page_size * 4), int32
765-
:param extend_seq_lens: (num_pages,), int32
766-
:param seq_lens_expanded: int, number of tokens to gather
767-
:return: tuple of (ks, ke) where
768-
ks: (sum_extend_seq_len,), int32
769-
ke: (sum_extend_seq_len,), int32
770-
"""
771-
772-
ks = torch.empty((extend_sum_seq_len), dtype=torch.int32, device="cuda")
773-
ke = torch.empty((extend_sum_seq_len), dtype=torch.int32, device="cuda")
774-
775-
max_iter = math.ceil(math.log2(extend_sum_seq_len)) + 1 if extend_sum_seq_len > 0 else 1
776-
777-
BLOCK_SIZE = 256
778-
grid = lambda meta: (triton.cdiv(extend_sum_seq_len, meta['BLOCK_SIZE']),)
779-
_get_ke_ks_triton_kernel[grid](
780-
seq_lens_ptr=seq_lens,
781-
extend_seq_lens_ptr=extend_seq_lens,
782-
seq_lens_expanded=seq_lens_expanded,
783-
ks_out_ptr=ks,
784-
ke_out_ptr=ke,
785-
seq_num=extend_seq_lens.shape[0],
786-
extend_seq_lens_sum=extend_sum_seq_len,
787-
iter_num=max_iter,
788-
BLOCK_SIZE=BLOCK_SIZE,
789-
)
790-
791-
return ks, ke
792-
793-
794-
@triton.jit
def _get_ke_ks_triton_kernel(
    seq_lens_ptr,
    extend_seq_lens_ptr,
    seq_lens_expanded,
    ks_out_ptr,
    ke_out_ptr,
    seq_num: tl.constexpr,
    extend_seq_lens_sum: tl.constexpr,
    iter_num: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused kernel that writes ks and ke for a block of extend-token positions.

    Each program handles BLOCK_SIZE consecutive output positions. For every
    position it bisects over the request index to find the request ``i`` whose
    extend range contains the position, then stores

        ks[pos] = sum(seq_lens[:i])
        ke[pos] = ks[pos] + seq_lens_expanded[pos]

    :param seq_lens_ptr: per-request values, read at indices 0..seq_num-1 and cast to int32
    :param extend_seq_lens_ptr: per-request extend lengths, read at indices 0..seq_num-1
    :param seq_lens_expanded: per-token values, read at the output positions
    :param ks_out_ptr: output, written at the output positions
    :param ke_out_ptr: output, written at the output positions
    :param seq_num: number of requests
    :param extend_seq_lens_sum: number of output positions
    :param iter_num: number of bisection steps to run
    :param BLOCK_SIZE: output positions handled per program
    """
    pid = tl.program_id(axis=0)
    start_pos = pid * BLOCK_SIZE
    # Bail out when the whole block lies past the last output position.
    # The comparison is done in element units (first position of the block),
    # not on the raw block index.
    if start_pos >= extend_seq_lens_sum:
        return

    out_pos = tl.arange(0, BLOCK_SIZE) + start_pos
    pos_mask = out_pos < extend_seq_lens_sum

    # Bisection over request indices: keeps prefix(low) <= out_pos for the
    # valid lanes, where prefix(m) is the sum of the first m extend lengths.
    low = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    high = tl.full((BLOCK_SIZE,), seq_num, dtype=tl.int32)
    for _ in range(iter_num):
        mid = (low + high) // 2

        # prefix_mid = sum of extend lengths of requests with index < mid,
        # recomputed from scratch per lane (no prefix array is materialised).
        prefix_mid = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
        for j in range(seq_num):
            j_lt_mid = (j < mid) & pos_mask
            extend_seq_len_j = tl.load(extend_seq_lens_ptr + j)
            prefix_mid = tl.where(j_lt_mid, prefix_mid + extend_seq_len_j, prefix_mid)

        cond = out_pos >= prefix_mid
        low = tl.where(cond, mid, low)
        high = tl.where(~cond, mid, high)

    i = low
    out_mask = (i >= 0) & (i < seq_num) & pos_mask

    # ks = sum of seq_lens over all requests that precede request i.
    seq_lens_sum_val = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    for j in range(seq_num):
        j_lt_i = (j < i) & out_mask

        seq_len_j = tl.load(seq_lens_ptr + j)
        seq_len_j = tl.cast(seq_len_j, tl.int32)
        seq_lens_sum_val = tl.where(j_lt_i, seq_lens_sum_val + seq_len_j, seq_lens_sum_val)

    # Masked-out lanes are never stored, so their loaded value is irrelevant.
    D_val = tl.load(seq_lens_expanded + out_pos, mask=out_mask)

    tl.store(ks_out_ptr + out_pos, seq_lens_sum_val, mask=out_mask)
    store_val = seq_lens_sum_val + D_val
    tl.store(ke_out_ptr + out_pos, store_val, mask=out_mask)
717+
tl.store(s_out_ptr + s_dst_offset + s_offsets + s_offset_batch, s_data, mask=s_mask)

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,31 +1753,6 @@ def get_index_k_scale_buffer(
17531753
self, buf, page_indices=page_indices, seq_len_tensor=seq_len_tensor,
17541754
seq_len_sum=seq_len_sum, max_seq_len=max_seq_len,
17551755
)
1756-
1757-
def get_ks_ke_buffer(
    self,
    extend_sum_seq_len: int,
    seq_lens_tensor: torch.Tensor,
    extend_seq_lens_tensor: torch.Tensor,
    seq_lens_expanded_tensor: torch.Tensor,
):
    """
    Obtain the ks / ke tensors for the current batch through the index-buffer accessor.

    This is a pass-through: the pool itself is handed over as the first
    positional argument and every other argument is forwarded by keyword.

    :param extend_sum_seq_len: total number of extend tokens in the batch
    :param seq_lens_tensor: forwarded unchanged to the accessor
    :param extend_seq_lens_tensor: forwarded unchanged to the accessor
    :param seq_lens_expanded_tensor: forwarded unchanged to the accessor
    :return: whatever ``index_buf_accessor.GetKEAndKS.execute`` returns
    """
    accessor = index_buf_accessor.GetKEAndKS
    return accessor.execute(
        self,
        extend_sum_seq_len=extend_sum_seq_len,
        seq_lens_tensor=seq_lens_tensor,
        extend_seq_lens_tensor=extend_seq_lens_tensor,
        seq_lens_expanded_tensor=seq_lens_expanded_tensor,
    )
17811756

17821757
def set_index_k_scale_buffer(
17831758
self,

0 commit comments

Comments
 (0)