@@ -186,32 +186,6 @@ def triton(
186186 index_head_dim = pool .index_head_dim ,
187187 )
188188
189- class GetKEAndKS :
190- @classmethod
191- def execute (cls , * args , ** kwargs ):
192- return cls .triton (* args , ** kwargs )
193-
194- @classmethod
195- def triton (
196- cls , pool : "NSATokenToKVPool" ,extend_sum_seq_len : int , seq_lens_tensor : torch .Tensor ,
197- extend_seq_lens_tensor : torch .Tensor , seq_lens_expanded_tensor : torch .Tensor ,
198- ):
199- """
200- Triton implementation for generate ke and ks data for all batch in a single call.
201-
202- :param page_indices: (num_pages,), int32/int64
203- :return: tuple of (k_fp8, k_scale) where
204- k_fp8: (seq_len, index_head_dim), uint8
205- k_scale: (seq_len, 4), uint8
206- """
207- return _get_ke_and_ks_triton (
208- extend_sum_seq_len = extend_sum_seq_len ,
209- seq_lens = seq_lens_tensor ,
210- extend_seq_lens = extend_seq_lens_tensor ,
211- seq_lens_expanded = seq_lens_expanded_tensor ,
212- )
213-
214-
215189class SetK :
216190 @classmethod
217191 def execute (cls , * args , buf , ** kwargs ):
@@ -740,108 +714,4 @@ def _get_k_and_s_triton_kernel(
740714
741715 # Store S to output
742716 s_dst_offset = token_id * 4
743- tl .store (s_out_ptr + s_dst_offset + s_offsets + s_offset_batch , s_data , mask = s_mask )
744-
745-
746- def _get_ke_and_ks_triton (
747- extend_sum_seq_len : int ,
748- seq_lens : torch .Tensor ,
749- extend_seq_lens : torch .Tensor ,
750- seq_lens_expanded : torch .Tensor
751- ):
752- """
753- Fused gather of both K (key) and S (scale) data from paged buffer using Triton.
754- This is more efficient than calling GetK and GetS separately.
755- for example:
756- seq_lens = [20, 30, 40, 50]
757- extend_seq_lens = [10, 15, 20, 10]
758- seq_lens_expanded = [10,11,...,19(bs0),15,16,...,29(bs1),40,41,...,49]
759-
760- prefix_sum = [0,10,25,45,55]
761- seq_lens_sum = [0, 20, 50, 90]
762-
763- :param extend_sum_seq_len: sum of all extend sequence len, int32
764- :param seq_lens: (num_pages, page_size * 128 + page_size * 4), int32
765- :param extend_seq_lens: (num_pages,), int32
766- :param seq_lens_expanded: int, number of tokens to gather
767- :return: tuple of (ks, ke) where
768- ks: (sum_extend_seq_len,), int32
769- ke: (sum_extend_seq_len,), int32
770- """
771-
772- ks = torch .empty ((extend_sum_seq_len ), dtype = torch .int32 , device = "cuda" )
773- ke = torch .empty ((extend_sum_seq_len ), dtype = torch .int32 , device = "cuda" )
774-
775- max_iter = math .ceil (math .log2 (extend_sum_seq_len )) + 1 if extend_sum_seq_len > 0 else 1
776-
777- BLOCK_SIZE = 256
778- grid = lambda meta : (triton .cdiv (extend_sum_seq_len , meta ['BLOCK_SIZE' ]),)
779- _get_ke_ks_triton_kernel [grid ](
780- seq_lens_ptr = seq_lens ,
781- extend_seq_lens_ptr = extend_seq_lens ,
782- seq_lens_expanded = seq_lens_expanded ,
783- ks_out_ptr = ks ,
784- ke_out_ptr = ke ,
785- seq_num = extend_seq_lens .shape [0 ],
786- extend_seq_lens_sum = extend_sum_seq_len ,
787- iter_num = max_iter ,
788- BLOCK_SIZE = BLOCK_SIZE ,
789- )
790-
791- return ks , ke
792-
793-
794- @triton .jit
795- def _get_ke_ks_triton_kernel (
796- seq_lens_ptr ,
797- extend_seq_lens_ptr ,
798- seq_lens_expanded ,
799- ks_out_ptr ,
800- ke_out_ptr ,
801- seq_num : tl .constexpr ,
802- extend_seq_lens_sum : tl .constexpr ,
803- iter_num : tl .constexpr ,
804- BLOCK_SIZE : tl .constexpr ,
805- ):
806- '''
807- Get ke and ks fuse kernel.
808- '''
809- pid = tl .program_id (axis = 0 )
810- if pid >= extend_seq_lens_sum :
811- return
812-
813- start_pos = pid * BLOCK_SIZE
814- out_pos = tl .arange (0 , BLOCK_SIZE ) + start_pos
815- pos_mask = out_pos < extend_seq_lens_sum
816-
817- low = tl .zeros ((BLOCK_SIZE ,), dtype = tl .int32 )
818- high = tl .full ((BLOCK_SIZE ,), seq_num , dtype = tl .int32 )
819- for _ in range (iter_num ):
820- mid = (low + high ) // 2
821-
822- prefix_mid = tl .zeros ((BLOCK_SIZE ,), dtype = tl .int32 )
823- for j in range (seq_num ):
824- j_lt_mid = (j < mid ) & pos_mask
825- extend_seq_len_j = tl .load (extend_seq_lens_ptr + j )
826- prefix_mid = tl .where (j_lt_mid , prefix_mid + extend_seq_len_j , prefix_mid )
827-
828- cond = out_pos >= prefix_mid
829- low = tl .where (cond , mid , low )
830- high = tl .where (~ cond , mid , high )
831-
832- i = low
833- out_mask = (i >= 0 ) & (i < seq_num ) & pos_mask
834-
835- seq_lens_sum_val = tl .zeros ((BLOCK_SIZE ,), dtype = tl .int32 )
836- for j in range (seq_num ):
837- j_lt_i = (j < i ) & out_mask
838-
839- seq_len_j = tl .load (seq_lens_ptr + j )
840- seq_len_j = tl .cast (seq_len_j , tl .int32 )
841- seq_lens_sum_val = tl .where (j_lt_i , seq_lens_sum_val + seq_len_j , seq_lens_sum_val )
842-
843- D_val = tl .load (seq_lens_expanded + out_pos , mask = out_mask )
844-
845- tl .store (ks_out_ptr + out_pos , seq_lens_sum_val , mask = out_mask )
846- store_val = seq_lens_sum_val + D_val
847- tl .store (ke_out_ptr + out_pos , store_val , mask = out_mask )
717+ tl .store (s_out_ptr + s_dst_offset + s_offsets + s_offset_batch , s_data , mask = s_mask )
0 commit comments