@@ -1723,8 +1723,32 @@ def forward_normal_prepare(
17231723 )
17241724 prefix_kv_a = self ._all_gather_dcp_kv_cache (prefix_kv_a .squeeze (1 ))
17251725 prefix_k_pe = self ._all_gather_dcp_kv_cache (prefix_k_pe )
1726- kv_a = torch .cat ((prefix_kv_a , kv_a ), dim = 0 )
1727- k_pe = torch .cat ((prefix_k_pe , k_pe ), dim = 0 )
1726+ # re-organize kv with query orders
1727+ prefix_lens_cu = torch .zeros (
1728+ len (forward_batch .seq_lens ) + 1 ,
1729+ dtype = torch .int32 ,
1730+ device = kv_a .device ,
1731+ )
1732+ extend_lens_cu = torch .zeros_like (prefix_lens_cu )
1733+ prefix_lens_cu [1 :] = torch .cumsum (
1734+ forward_batch .extend_prefix_lens , dim = 0
1735+ )
1736+ extend_lens_cu [1 :] = torch .cumsum (
1737+ forward_batch .extend_seq_lens , dim = 0
1738+ )
1739+ kv_a_tuple = ()
1740+ k_pe_tuple = ()
1741+ for i in range (len (forward_batch .seq_lens )):
1742+ kv_a_tuple += (
1743+ prefix_kv_a [prefix_lens_cu [i ] : prefix_lens_cu [i + 1 ]],
1744+ kv_a [extend_lens_cu [i ] : extend_lens_cu [i + 1 ]],
1745+ )
1746+ k_pe_tuple += (
1747+ prefix_k_pe [prefix_lens_cu [i ] : prefix_lens_cu [i + 1 ]],
1748+ k_pe [extend_lens_cu [i ] : extend_lens_cu [i + 1 ]],
1749+ )
1750+ kv_a = torch .cat (kv_a_tuple , dim = 0 )
1751+ k_pe = torch .cat (k_pe_tuple , dim = 0 )
17281752 else :
17291753 # BF16/FP16 path: directly fetch from cache
17301754 kv_a , k_pe = self ._get_mla_kv_buffer (
@@ -2716,16 +2740,15 @@ def forward_absorb_fused_mla_rope_cpu_core(
27162740 def _all_gather_dcp_kv_cache (self , kv_a ):
27172741 dcp_world_size = get_dcp_world_size ()
27182742 dcp_rank = get_dcp_rank ()
2719- gathered_kv_a = torch .empty (
2743+ gathered_kv_a = torch .zeros (
27202744 (kv_a .shape [0 ] * get_dcp_world_size (), * kv_a .shape [1 :]),
27212745 dtype = kv_a .dtype ,
27222746 device = kv_a .device ,
27232747 )
27242748 idxs = torch .arange (kv_a .shape [0 ] * dcp_world_size )
27252749 mask = idxs % dcp_world_size == dcp_rank
27262750 gathered_kv_a [mask ] = kv_a
2727- get_dcp_group ().all_reduce (gathered_kv_a )
2728- return gathered_kv_a
2751+ return get_dcp_group ().all_reduce (gathered_kv_a )
27292752
27302753 def _chunked_prefix_attn_mha (
27312754 self ,
0 commit comments