Skip to content
6 changes: 5 additions & 1 deletion python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0]
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
in [
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLMNextN",
]
and getattr(config, "index_topk", None) is not None
)

Expand Down
33 changes: 23 additions & 10 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def _get_topk_paged(
)

blocksize = page_size
seqlens_32 = metadata.get_seqlens_int32()
if forward_batch.forward_mode.is_target_verify():
seqlens_32 = metadata.get_seqlens_expanded()
else:
seqlens_32 = metadata.get_seqlens_int32()
# NOTE(dark): self.sm_count is the device's SM count (e.g. 132 on H200), not a magic number
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
seqlens_32, blocksize, self.sm_count
Expand Down Expand Up @@ -317,8 +320,9 @@ def _get_topk_ragged(
k_fp8_list = []
k_scale_list = []
ks_list = []
ke_list = []
offset = 0

seq_lens_expanded = metadata.get_seqlens_expanded()
block_tables = metadata.get_page_table_64()

assert (
Expand All @@ -341,30 +345,34 @@ def _get_topk_ragged(
)
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
k_fp8_list.append(k_fp8)
k_scale_list.append(k_scale)
ks_list.append(ks)
ke_list.append(ke)
offset += extend_seq_len

k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
kv_fp8 = (k_fp8, k_scale)
ks = torch.cat(ks_list, dim=0)
seq_lens_expanded = metadata.get_seqlens_expanded()
ke = ks + seq_lens_expanded
ke = torch.cat(ke_list, dim=0)

logits = deep_gemm.fp8_mqa_logits(
q_fp8,
q_fp8[:offset],
kv_fp8,
weights,
weights[:offset],
ks,
ke,
clean_logits=False,
)

token_nums, _, _ = q_fp8.shape
assert logits.shape[0] == len(seq_lens_expanded)
topk_result = metadata.topk_transform(logits, self.index_topk)

raw_topk_result = metadata.topk_transform(logits, self.index_topk)
topk_result = torch.full(
(token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
)
topk_result[:offset] = raw_topk_result
return topk_result

def forward_indexer(
Expand Down Expand Up @@ -500,6 +508,8 @@ def forward_cuda(
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
# k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
# k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
if not forward_batch.out_cache_loc.is_contiguous():
forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
Expand All @@ -521,7 +531,10 @@ def forward_cuda(
(x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
)

if forward_batch.forward_mode.is_decode_or_idle():
if (
forward_batch.forward_mode.is_decode_or_idle()
or forward_batch.forward_mode.is_target_verify()
):
topk_result = self._get_topk_paged(
forward_batch, layer_id, q_fp8, weights, metadata
)
Expand Down
Loading
Loading