Skip to content

Commit 6ed92f2

Browse files
committed
fix bugs in gather kv for mha one shot
1 parent cd3b3d0 commit 6ed92f2

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

python/sglang/srt/model_executor/model_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2657,10 +2657,6 @@ def forward_extend(
26572657
)
26582658
dcp_kv_indptr[1:] = forward_batch.seq_lens.cumsum(dim=0)
26592659
dcp_kv_indptr = dcp_kv_indptr[: (len(forward_batch.seq_lens) + 1)]
2660-
forward_batch.dcp_kv_indptr = dcp_kv_indptr
2661-
forward_batch.dcp_local_prefix_kv_indices = (
2662-
dcp_prefix_kv_indices[::8] // get_dcp_world_size()
2663-
)
26642660
dcp_kv_indices = torch.zeros(
26652661
forward_batch.seq_lens_sum,
26662662
dtype=torch.int32,
@@ -2735,7 +2731,10 @@ def create_dcp_kv_indices(
27352731
)
27362732
forward_batch.dcp_kv_indptr = dcp_kv_indptr
27372733
forward_batch.dcp_local_prefix_kv_indices = (
2738-
dcp_prefix_kv_indices[::8] // get_dcp_world_size()
2734+
dcp_prefix_kv_indices[
2735+
dcp_prefix_kv_indices % get_dcp_world_size() == get_dcp_rank()
2736+
]
2737+
// get_dcp_world_size()
27392738
)
27402739
forward_batch.dcp_kv_buffer = torch.empty(
27412740
(

python/sglang/srt/models/deepseek_v2.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,8 +1723,32 @@ def forward_normal_prepare(
17231723
)
17241724
prefix_kv_a = self._all_gather_dcp_kv_cache(prefix_kv_a.squeeze(1))
17251725
prefix_k_pe = self._all_gather_dcp_kv_cache(prefix_k_pe)
1726-
kv_a = torch.cat((prefix_kv_a, kv_a), dim=0)
1727-
k_pe = torch.cat((prefix_k_pe, k_pe), dim=0)
1726+
# re-organize kv with query orders
1727+
prefix_lens_cu = torch.zeros(
1728+
len(forward_batch.seq_lens) + 1,
1729+
dtype=torch.int32,
1730+
device=kv_a.device,
1731+
)
1732+
extend_lens_cu = torch.zeros_like(prefix_lens_cu)
1733+
prefix_lens_cu[1:] = torch.cumsum(
1734+
forward_batch.extend_prefix_lens, dim=0
1735+
)
1736+
extend_lens_cu[1:] = torch.cumsum(
1737+
forward_batch.extend_seq_lens, dim=0
1738+
)
1739+
kv_a_tuple = ()
1740+
k_pe_tuple = ()
1741+
for i in range(len(forward_batch.seq_lens)):
1742+
kv_a_tuple += (
1743+
prefix_kv_a[prefix_lens_cu[i] : prefix_lens_cu[i + 1]],
1744+
kv_a[extend_lens_cu[i] : extend_lens_cu[i + 1]],
1745+
)
1746+
k_pe_tuple += (
1747+
prefix_k_pe[prefix_lens_cu[i] : prefix_lens_cu[i + 1]],
1748+
k_pe[extend_lens_cu[i] : extend_lens_cu[i + 1]],
1749+
)
1750+
kv_a = torch.cat(kv_a_tuple, dim=0)
1751+
k_pe = torch.cat(k_pe_tuple, dim=0)
17281752
else:
17291753
# BF16/FP16 path: directly fetch from cache
17301754
kv_a, k_pe = self._get_mla_kv_buffer(
@@ -2716,16 +2740,15 @@ def forward_absorb_fused_mla_rope_cpu_core(
27162740
def _all_gather_dcp_kv_cache(self, kv_a):
27172741
dcp_world_size = get_dcp_world_size()
27182742
dcp_rank = get_dcp_rank()
2719-
gathered_kv_a = torch.empty(
2743+
gathered_kv_a = torch.zeros(
27202744
(kv_a.shape[0] * get_dcp_world_size(), *kv_a.shape[1:]),
27212745
dtype=kv_a.dtype,
27222746
device=kv_a.device,
27232747
)
27242748
idxs = torch.arange(kv_a.shape[0] * dcp_world_size)
27252749
mask = idxs % dcp_world_size == dcp_rank
27262750
gathered_kv_a[mask] = kv_a
2727-
get_dcp_group().all_reduce(gathered_kv_a)
2728-
return gathered_kv_a
2751+
return get_dcp_group().all_reduce(gathered_kv_a)
27292752

27302753
def _chunked_prefix_attn_mha(
27312754
self,

0 commit comments

Comments (0)