diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 42c28bc3f171..e63397ec17a8 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
     return (
         config.architectures is not None
         and config.architectures[0]
-        in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+        in [
+            "DeepseekV3ForCausalLM",
+            "DeepseekV32ForCausalLM",
+            "DeepseekV3ForCausalLMNextN",
+        ]
         and getattr(config, "index_topk", None) is not None
     )
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
index 93d7b61a6ccc..25a191e08f89 100644
--- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -266,7 +266,10 @@ def _get_topk_paged(
         )
         blocksize = page_size
-        seqlens_32 = metadata.get_seqlens_int32()
+        if forward_batch.forward_mode.is_target_verify():
+            seqlens_32 = metadata.get_seqlens_expanded()
+        else:
+            seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
         schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
             seqlens_32, blocksize, self.sm_count
@@ -317,8 +320,9 @@ def _get_topk_ragged(
         k_fp8_list = []
         k_scale_list = []
         ks_list = []
+        ke_list = []
         offset = 0
-
+        seq_lens_expanded = metadata.get_seqlens_expanded()
         block_tables = metadata.get_page_table_64()

         assert (
@@ -341,30 +345,34 @@ def _get_topk_ragged(
             )
             extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
             ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
             k_fp8_list.append(k_fp8)
             k_scale_list.append(k_scale)
             ks_list.append(ks)
+            ke_list.append(ke)
             offset += extend_seq_len

         k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
         k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
         kv_fp8 = (k_fp8, k_scale)
         ks = torch.cat(ks_list, dim=0)
-        seq_lens_expanded = metadata.get_seqlens_expanded()
-        ke = ks + seq_lens_expanded
+        ke = torch.cat(ke_list, dim=0)
         logits = deep_gemm.fp8_mqa_logits(
-            q_fp8,
+            q_fp8[:offset],
             kv_fp8,
-            weights,
+            weights[:offset],
             ks,
             ke,
             clean_logits=False,
         )
-
+        token_nums, _, _ = q_fp8.shape
         assert logits.shape[0] == len(seq_lens_expanded)
-        topk_result = metadata.topk_transform(logits, self.index_topk)
-
+        raw_topk_result = metadata.topk_transform(logits, self.index_topk)
+        topk_result = torch.full(
+            (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
+        )
+        topk_result[:offset] = raw_topk_result
         return topk_result

     def forward_indexer(
@@ -500,6 +508,8 @@ def forward_cuda(
         # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
         # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
         # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
         forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
             layer_id=layer_id,
             loc=forward_batch.out_cache_loc,
@@ -521,7 +531,10 @@ def forward_cuda(
             (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
         )

-        if forward_batch.forward_mode.is_decode_or_idle():
+        if (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             topk_result = self._get_topk_paged(
                 forward_batch, layer_id, q_fp8, weights, metadata
             )
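Note on the `_get_topk_ragged` change above: with padded batches, only the first `offset` query rows are real, so the logits call is sliced to `[:offset]` and the top-k output is padded back to the full token count with -1. A minimal runnable sketch of that bookkeeping, with toy sizes and without the `deep_gemm`/fp8 plumbing (all values hypothetical):

```python
import torch

# For each request i, queries in [offset, offset + extend_seq_len) attend to
# keys [ks, ke), where ke = ks + the per-token causal KV length.
extend_seq_lens = [2, 3]  # tokens per request
seq_lens_expanded = torch.tensor([4, 5, 7, 8, 9], dtype=torch.int32)

ks_list, ke_list, offset = [], [], 0
for extend_seq_len in extend_seq_lens:
    ks = torch.full((extend_seq_len,), offset, dtype=torch.int32)
    ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
    ks_list.append(ks)
    ke_list.append(ke)
    offset += extend_seq_len

ks, ke = torch.cat(ks_list), torch.cat(ke_list)

# Padding mirror: only the first `offset` rows hold real top-k indices; the
# rest stay -1 so the output shape matches the padded query count.
token_nums, index_topk = 8, 4  # token_nums may exceed offset
raw_topk = torch.zeros((offset, index_topk), dtype=torch.int32)
topk_result = torch.full((token_nums, index_topk), -1, dtype=torch.int32)
topk_result[:offset] = raw_topk
```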
diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index 6ec4652f415d..66b32b2c65ca 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -29,6 +29,7 @@
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.spec_info import SpecInput

+
 _is_hip = is_hip()

 if _is_hip:
@@ -148,7 +149,14 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:


 class NativeSparseAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        speculative_step_id=0,
+        topk=0,
+        speculative_num_steps=0,
+    ):
         super().__init__()
         self.forward_metadata: NSAMetadata
         self.device = model_runner.device
@@ -185,6 +193,14 @@ def __init__(self, model_runner: ModelRunner):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )

+        # Speculative decoding
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
+
     def get_device_int32_arange(self, l: int) -> torch.Tensor:
         if l > len(self._arange_buf):
             next_pow_of_2 = 1 << (l - 1).bit_length()
@@ -208,13 +224,15 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         batch_size = forward_batch.batch_size
         device = forward_batch.seq_lens.device

-        assert (
-            forward_batch.spec_info is None
-        ), "Spec decoding is not supported for NSA backend now"
-        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+        if forward_batch.forward_mode.is_target_verify():
+            draft_token_num = self.speculative_num_draft_tokens
+        else:
+            draft_token_num = 0
+
+        cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
         assert forward_batch.seq_lens_cpu is not None
-        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
         page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, :max_seqlen_k
         ]
@@ -224,6 +242,41 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             max_seqlen_q = 1
             cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
             seqlens_expanded = cache_seqlens_int32
+        elif forward_batch.forward_mode.is_target_verify():
+            max_seqlen_q = self.speculative_num_draft_tokens
+            nsa_max_seqlen_q = self.speculative_num_draft_tokens
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=device,
+            )
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
+            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in forward_batch.seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            page_table = torch.repeat_interleave(
+                page_table, repeats=self.speculative_num_draft_tokens, dim=0
+            )
         elif forward_batch.forward_mode.is_extend():
             assert (
                 forward_batch.extend_seq_lens_cpu is not None
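The `seqlens_expanded` construction above recurs throughout this patch: under target-verify, each request contributes `speculative_num_draft_tokens` query rows, and row j may attend to `kv_len - q_len + 1 + j` keys (causal masking among the draft tokens), hence one `arange` per request. A small self-contained sketch with made-up lengths:

```python
import torch

# Hypothetical: 2 requests with 5 and 9 cached tokens, 3 draft tokens each.
speculative_num_draft_tokens = 3
seq_lens_cpu = [5, 9]

extend_seq_lens_cpu = [speculative_num_draft_tokens] * len(seq_lens_cpu)
seqlens_int32_cpu = [speculative_num_draft_tokens + kv for kv in seq_lens_cpu]

# One causal KV length per draft-token row of each request.
seqlens_expanded = torch.cat(
    [
        torch.arange(kv_len - qo_len + 1, kv_len + 1, dtype=torch.int32)
        for qo_len, kv_len in zip(extend_seq_lens_cpu, seqlens_int32_cpu)
    ]
)
print(seqlens_expanded)  # tensor([ 6,  7,  8, 10, 11, 12], dtype=torch.int32)
```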
@@ -232,7 +285,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             ), "All of them must not be None"
             extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
             assert forward_batch.extend_seq_lens is not None
-            if any(forward_batch.extend_prefix_lens_cpu):
+
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 max_seqlen_q = max(extend_seq_lens_cpu)
                 cu_seqlens_q = compute_cu_seqlens(
                     forward_batch.extend_seq_lens.to(torch.int32)
@@ -277,7 +334,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             flashmla_metadata=(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens_int32,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
                 if NSA_DECODE_IMPL == "flashmla_decode"
                 else None
@@ -288,6 +345,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             nsa_seqlens_expanded=seqlens_expanded,
             nsa_extend_seq_lens_list=extend_seq_lens_cpu,
             real_page_table=self._transform_table_1_to_real(page_table),
+            nsa_max_seqlen_q=1,
         )
         self.forward_metadata = metadata
@@ -302,7 +360,9 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata: Dict = {
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cache_seqlens": torch.ones(
+                max_num_tokens, dtype=torch.int32, device=self.device
+            ),
             "cu_seqlens_q": torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=self.device
             ),
@@ -311,7 +371,7 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
             ),
             # fake page_table for sparse_prefill
             "page_table": torch.zeros(
-                max_bs,
+                max_num_tokens,
                 self.max_context_len,
                 dtype=torch.int32,
                 device=self.device,
@@ -319,9 +379,9 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
             "flashmla_metadata": (
                 self._compute_flashmla_metadata(
                     cache_seqlens=torch.ones(
-                        max_bs, dtype=torch.int32, device=self.device
+                        max_num_tokens, dtype=torch.int32, device=self.device
                     ),
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
                 if NSA_DECODE_IMPL == "flashmla_decode"
                 else None
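The `init_cuda_graph_state` changes above size the capture buffers by `max_num_tokens` instead of `max_bs`, because speculative decoding turns every draft token into its own metadata row. With hypothetical numbers:

```python
# Under target-verify, each request contributes one row per draft token, so
# the per-row metadata (cache_seqlens, page_table, flashmla splits) must be
# sized for the worst case, not just for the batch size.
max_bs = 8
speculative_num_draft_tokens = 4
max_num_tokens = max_bs * speculative_num_draft_tokens  # 32 rows, not 8
```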
@@ -339,50 +399,166 @@ def init_forward_metadata_capture_cuda_graph(
         spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
-        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
+        if forward_mode.is_decode_or_idle():
+            # Normal Decode
+            # Get sequence information
+            cache_seqlens_int32 = seq_lens.to(torch.int32)
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+
+            # Use max context length for seq_len_k
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+            max_seqlen_q = 1
+            max_seqlen_k = page_table_1.shape[1]

-        # Normal Decode
-        # Get sequence information
-        cache_seqlens_int32 = seq_lens.to(torch.int32)
-        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            # Precompute page table
+            # Precompute cumulative sequence lengths
+
+            # NOTE(dark): this is always arange, since we are decoding
+            cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
+            )
+
+            seqlens_expanded = cache_seqlens_int32
+            nsa_extend_seq_lens_list = [1] * num_tokens
+            if NSA_DECODE_IMPL == "flashmla_decode":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, num_tokens + 1))
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
+                )
+            else:
+                flashmla_metadata = None
+        elif forward_mode.is_target_verify():
+            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
+            )
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            max_seqlen_q = 1
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][
+                : bs * self.speculative_num_draft_tokens, :
+            ]
+            max_seqlen_k = page_table_1.shape[1]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=self.device,
+            )

-        # Use max context length for seq_len_k
-        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
-        max_seq_len_k = page_table_1.shape[1]
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs

-        # Precompute page table
-        # Precompute cumulative sequence lengths
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                seqlens_expanded, nsa_index_topk=self.nsa_index_topk
+            )
+            nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens

-        # NOTE(dark): this is always arange, since we are decoding
-        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
-        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
-            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
-        )
-        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
-        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
-        real_page_table = self._transform_table_1_to_real(page_table_1)
+            if NSA_DECODE_IMPL == "flashmla_decode":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))

-        if NSA_DECODE_IMPL == "flashmla_decode":
-            flashmla_metadata = self.decode_cuda_graph_metadata[
-                "flashmla_metadata"
-            ].slice(slice(0, bs + 1))
-            flashmla_metadata.copy_(
-                self._compute_flashmla_metadata(
-                    cache_seqlens=nsa_cache_seqlens_int32,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
                 )
+            else:
+                flashmla_metadata = None
+        elif forward_mode.is_draft_extend():
+            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
                 torch.int32
             )
-        else:
-            flashmla_metadata = None
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+            max_seqlen_k = page_table_1.shape[1]
+
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
+            extend_seq_lens = torch.full(
+                (bs,),
+                self.speculative_num_draft_tokens,
+                device=self.device,
+                dtype=torch.int32,
+            )
+
+            max_seqlen_q = max(extend_seq_lens_cpu)
+            cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                seqlens_expanded, nsa_index_topk=self.nsa_index_topk
+            )
+            nsa_extend_seq_lens_list = [1] * bs
+
+            if NSA_DECODE_IMPL == "flashmla_decode":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
+                # Since DeepGEMM does not support q_len = 3/4 in the indexer and
+                # every token has independent topk_indices, Q is shaped as
+                # [bs * speculative_num_draft_tokens, 1, head_nums, dim].
+                # So seq_len_q is 1 for flashmla_metadata in the target_verify
+                # and draft_extend modes.
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
+                )
+            else:
+                flashmla_metadata = None
+
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+        real_page_table = self._transform_table_1_to_real(page_table_1)

         metadata = NSAMetadata(
             page_size=self.real_page_size,
             cache_seqlens_int32=cache_seqlens_int32,
-            max_seq_len_q=1,
-            max_seq_len_k=max_seq_len_k,
+            max_seq_len_q=max_seqlen_q,
+            max_seq_len_k=max_seqlen_k,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
             page_table_1=page_table_1,
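On the `seq_len_q=1` comment in the hunk above: because every draft token carries its own independent top-k indices, each token is treated as its own query row rather than querying with `seq_len_q = num_draft_tokens`. A shape-only illustration (all sizes hypothetical):

```python
import torch

# Instead of [bs, num_draft_tokens, head_nums, dim] with seq_len_q = 3,
# each draft token becomes its own "batch" row with seq_len_q = 1.
bs, num_draft_tokens, head_nums, dim = 2, 3, 4, 8
q = torch.randn(bs, num_draft_tokens, head_nums, dim)
q_flat = q.reshape(bs * num_draft_tokens, 1, head_nums, dim)
assert q_flat.shape == (6, 1, 4, 8)
```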
@@ -390,9 +566,9 @@ def init_forward_metadata_capture_cuda_graph(
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
             nsa_cu_seqlens_q=nsa_cu_seqlens_q,
             nsa_cu_seqlens_k=nsa_cu_seqlens_k,
-            nsa_seqlens_expanded=cache_seqlens_int32,
+            nsa_seqlens_expanded=seqlens_expanded,
             real_page_table=real_page_table,
-            nsa_extend_seq_lens_list=[1] * bs,
+            nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
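The replay hunk below rewrites the captured buffers in place, since CUDA graphs replay fixed tensor addresses. A minimal mirror of the `cu_seqlens_k` update, with hypothetical lengths:

```python
import torch

# Captured buffer: allocated once at capture time, mutated in place at replay.
max_bs = 4
cu_seqlens_k = torch.zeros(max_bs + 1, dtype=torch.int32)

cache_seqlens = torch.tensor([5, 9, 7, 3], dtype=torch.int32)  # new lengths
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32))
print(cu_seqlens_k)  # tensor([ 0,  5, 14, 21, 24], dtype=torch.int32)
```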
@@ -411,33 +587,119 @@ def init_forward_metadata_replay_cuda_graph(
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         assert seq_lens_cpu is not None
-        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
+
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]

         # Normal Decode
         metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
-        max_len = int(seq_lens_cpu.max().item())
+        if forward_mode.is_decode_or_idle():
+            # Normal Decode
+            max_len = int(seq_lens_cpu.max().item())
+
+            cache_seqlens = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_len]
+            metadata.page_table_1[:, :max_len].copy_(page_indices)
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                cache_seqlens, nsa_index_topk=self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+            seqlens_expanded = cache_seqlens
+        elif forward_mode.is_target_verify():
+            max_seqlen_k = int(
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )

-        cache_seqlens = seq_lens.to(torch.int32)
-        metadata.cache_seqlens_int32.copy_(cache_seqlens)
-        metadata.cu_seqlens_k[1:].copy_(
-            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
-        )
-        page_indices = self.req_to_token[req_pool_indices, :max_len]
-        metadata.page_table_1[:, :max_len].copy_(page_indices)
+            cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
+            )
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
+            page_indices = torch.repeat_interleave(
+                page_indices, repeats=self.speculative_num_draft_tokens, dim=0
+            )
+            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                seqlens_expanded, self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+        elif forward_mode.is_draft_extend():
+            max_seqlen_k = int(seq_lens_cpu.max().item())
+            cache_seqlens = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
+            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
+            extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
+                seqlens_expanded
+            )
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                seqlens_expanded, self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
+                nsa_cache_seqlens
+            )
+        seqlens_expanded_size = seqlens_expanded.size(0)
         assert (
             metadata.nsa_cache_seqlens_int32 is not None
             and metadata.nsa_cu_seqlens_k is not None
             and self.nsa_index_topk is not None
         )
-        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
-        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
-        metadata.nsa_cu_seqlens_k[1:].copy_(
+
+        metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
             torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
         )
         # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
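In the draft-extend replay branch above, `spec_info.accept_length` varies per request, so only a prefix of each captured buffer is rewritten (`[: seqlens_expanded.size(0)]`). A reduced sketch of that partial in-place update, with made-up values:

```python
import torch

# Per-request accepted tokens vary, so the expanded-row count varies too.
accept_length = torch.tensor([2, 1, 3])
seq_lens_cpu = [7, 4, 9]
speculative_num_draft_tokens = 4

seqlens_expanded = torch.cat(
    [
        torch.arange(
            kv + speculative_num_draft_tokens - q + 1,
            kv + speculative_num_draft_tokens + 1,
            dtype=torch.int32,
        )
        for q, kv in zip(accept_length.tolist(), seq_lens_cpu)
    ]
)
size = seqlens_expanded.size(0)          # 6 rows here, not bs * 4
buffer = torch.zeros(12, dtype=torch.int32)  # captured, worst-case size
buffer[:size].copy_(seqlens_expanded)    # only the live prefix is rewritten
```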
@@ -451,10 +713,13 @@ def init_forward_metadata_replay_cuda_graph(
         assert metadata.real_page_table is metadata.page_table_1

         if NSA_DECODE_IMPL == "flashmla_decode":
-            metadata.flashmla_metadata.copy_(
+            flashmla_metadata = metadata.flashmla_metadata.slice(
+                slice(0, seqlens_expanded_size + 1)
+            )
+            flashmla_metadata.copy_(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens,
-                    seq_len_q=1,  # TODO handle MTP which is not 1
+                    seq_len_q=1,
                 )
             )
@@ -473,10 +738,7 @@ def forward_extend(
         k_rope: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        assert (
-            not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
-        ), "NSA backend doesn't support speculative decoding"
+
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -884,3 +1146,58 @@ def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int
             flashmla_metadata=flashmla_metadata,
             num_splits=num_splits,
         )
+
+
+class NativeSparseAttnMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                NativeSparseAttnBackend(
+                    model_runner,
+                    speculative_step_id=i,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
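`NativeSparseAttnMultiStepBackend` above fans metadata initialization out to one `NativeSparseAttnBackend` per draft step, so each step replays against its own captured buffers. The core pattern, reduced to a sketch (names illustrative):

```python
# One backend instance per speculative step; each owns its own CUDA-graph
# metadata, so steps never clobber each other's captured state.
class MultiStep:
    def __init__(self, make_backend, num_steps: int):
        self.backends = [make_backend(step) for step in range(num_steps)]

    def init_forward_metadata(self, batch):
        # Mirrors the loop above, which initializes only the first
        # num_steps - 1 backends for eager forward metadata.
        for backend in self.backends[:-1]:
            backend.init_forward_metadata(batch)
```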
diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py
index 4f3a4617e0ee..1a8c3b70e000 100644
--- a/python/sglang/srt/speculative/draft_utils.py
+++ b/python/sglang/srt/speculative/draft_utils.py
@@ -48,6 +48,7 @@ def create_decode_backend(self):
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
+            "nsa": self._create_nsa_decode_backend,
         }

         return self._create_backend(
@@ -70,6 +71,7 @@ def create_draft_extend_backend(self):
             "flashmla": self._create_flashmla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+            "nsa": self._create_nsa_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
@@ -82,6 +84,20 @@ def create_draft_extend_backend(self):
             "EAGLE is not supported in attention backend {backend_type}",
         )

+    def _create_nsa_decode_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import (
+            NativeSparseAttnMultiStepBackend,
+        )
+
+        return NativeSparseAttnMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_nsa_prefill_backend(self):
+        from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+        return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
+
     def _create_flashinfer_decode_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
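The two factory methods above plug NSA into the existing name-to-constructor dispatch used by `create_decode_backend` / `create_draft_extend_backend`. The dispatch pattern itself, reduced to a sketch (names illustrative):

```python
from typing import Callable, Dict

# Name -> zero-argument constructor; unknown names fail loudly.
def create_backend(name: str, registry: Dict[str, Callable[[], object]]) -> object:
    try:
        return registry[name]()
    except KeyError:
        raise ValueError(f"EAGLE is not supported in attention backend {name}")
```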
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index c82df4d2e920..c8f2b4e6639f 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -81,6 +81,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs

         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -92,6 +93,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
+            self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
             self.out_cache_loc = torch.zeros(
                 (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
             )
@@ -165,6 +167,9 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable):
         # Graph inputs
         req_pool_indices = self.req_pool_indices[:num_seqs]
         seq_lens = self.seq_lens[:num_seqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
+        extend_seq_lens = self.extend_seq_lens[:num_seqs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
         mrope_positions = self.mrope_positions[:, :num_tokens]
@@ -227,6 +232,9 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable):
             input_ids=None,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index 39d8e0f6a8a4..d54d86a8cb06 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -78,6 +78,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
+        self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs

         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -196,7 +197,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         extend_seq_lens = self.extend_seq_lens[:bs]
+        extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
@@ -254,6 +257,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
@@ -271,6 +275,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             capture_hidden_mode=CaptureHiddenMode.LAST,
             attn_backend=self.eagle_worker.draft_extend_attn_backend,
             extend_seq_lens=extend_seq_lens,
+            extend_seq_lens_cpu=extend_seq_lens_cpu,
             padded_static_len=self.padded_static_len,
         )
@@ -373,6 +378,9 @@ def replay(self, forward_batch: ForwardBatch):
         self.seq_lens_cpu.fill_(self.seq_len_fill_value)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

+        if forward_batch.extend_seq_lens_cpu is not None:
+            self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
+
         if bs != raw_bs:
             forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
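The CPU-side `extend_seq_lens_cpu` plumbing added to the two runners follows the usual capture/replay contract: capture with placeholder values sized for `max_bs`, then patch the live prefix at replay time. A toy illustration (sizes hypothetical):

```python
# Capture: placeholder list sized for the largest batch the graph supports.
max_bs, seq_len_fill_value = 4, 1
extend_seq_lens_cpu = [seq_len_fill_value] * max_bs

# Replay: overwrite only the prefix that corresponds to the real batch.
raw_bs = 2
live = [3, 5]  # from the real forward_batch
extend_seq_lens_cpu[:raw_bs] = live
assert extend_seq_lens_cpu == [3, 5, 1, 1]
```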