57 changes: 54 additions & 3 deletions python/sglang/srt/layers/attention/nsa_backend.py
@@ -254,7 +254,9 @@ def topk_transform(
assert False, f"Unsupported {self.topk_transform_method = }"


_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
_NSA_IMPL_T: TypeAlias = Literal[
"flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "trtllm"
]


class NativeSparseAttnBackend(
@@ -287,6 +289,9 @@ def __init__(
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim

assert model_runner.req_to_token_pool is not None
self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -318,8 +323,8 @@ def __init__(
self.device_capability = torch.cuda.get_device_capability()
self.device_sm_major = self.device_capability[0]

# Allocate global workspace buffer for TRTLLm ragged attention kernel (SM100/B200)
if self.device_sm_major >= 10:
# Allocate global workspace buffer for TRT-LLM kernels (ragged attention on SM100/B200, or trtllm decode)
if self.device_sm_major >= 10 or self.nsa_decode_impl == "trtllm":
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
@@ -1454,6 +1459,17 @@ def forward_decode(
bs=forward_batch.batch_size,
)

elif self.nsa_decode_impl == "trtllm":
if q_rope is not None:
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_trtllm(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
metadata=metadata,
sm_scale=layer.scaling,
)

else:
assert False, f"Unsupported {self.nsa_decode_impl = }"

@@ -1713,6 +1729,41 @@ def _forward_aiter(
# kv_cache = kv_cache.view(-1, 1, layer.head_dim)
return o

def _forward_trtllm(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
page_table_1: torch.Tensor,
metadata: NSAMetadata,
sm_scale: float,
) -> torch.Tensor:
"""Forward using TRT-LLM sparse MLA kernel."""
import flashinfer.decode

batch_size = page_table_1.shape[0]
_, num_heads, head_dim = q_all.shape

q = q_all.view(batch_size, 1, num_heads, head_dim)
kv = kv_cache.view(-1, 1, self.real_page_size, self.kv_cache_dim)
block_tables = page_table_1.unsqueeze(1)

out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv,
workspace_buffer=self.workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=block_tables,
seq_lens=metadata.cache_seqlens_int32,
max_seq_len=metadata.max_seq_len_k,
sparse_mla_top_k=self.nsa_index_topk,
bmm1_scale=sm_scale,
backend="trtllm-gen",
)
# Output: [batch, q_len=1, heads, v_dim] -> [batch, heads, v_dim]
return out.squeeze(1)

def _pad_topk_indices(
self, topk_indices: torch.Tensor, num_tokens: int
) -> torch.Tensor:
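A minimal standalone sketch (not part of the diff) of the tensor-layout handshake _forward_trtllm performs before calling flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla. The dimension values below (kv_lora_rank=512, qk_rope_head_dim=64, page size 64, 128 query heads) are assumed DeepSeek-V3-style placeholders, the paging granularity is simplified, and the kernel call itself is replaced by a dummy output tensor so the snippet runs anywhere:

import torch

# Assumed illustrative dimensions (not read from any config).
batch_size, num_heads = 2, 128
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim      # 576 = latent (nope) + rope parts
page_size, max_pages, num_pool_pages = 64, 8, 32
kv_cache_dim = head_dim                         # bf16 layout; fp8 cache layouts differ

# Flattened decode query: one token per request.
q_all = torch.randn(batch_size, num_heads, head_dim)
# Flattened paged KV pool and a per-request page table standing in for page_table_1.
kv_pool = torch.randn(num_pool_pages * page_size, kv_cache_dim)
page_table_1 = torch.randint(0, num_pool_pages, (batch_size, max_pages), dtype=torch.int32)

# Layouts handed to the TRT-LLM sparse MLA decode kernel:
q = q_all.view(batch_size, 1, num_heads, head_dim)   # [bs, q_len=1, heads, 576]
kv = kv_pool.view(-1, 1, page_size, kv_cache_dim)    # [num_pages, 1, page_size, dim]
block_tables = page_table_1.unsqueeze(1)             # [bs, 1, max_pages]

# The kernel returns [bs, q_len=1, heads, kv_lora_rank]; the backend squeezes q_len.
out = torch.empty(batch_size, 1, num_heads, kv_lora_rank)
print(out.squeeze(1).shape)                           # torch.Size([2, 128, 512])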
92 changes: 64 additions & 28 deletions python/sglang/srt/server_args.py
@@ -163,6 +163,7 @@
"fa3",
"tilelang",
"aiter",
"trtllm",
]

RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -436,8 +437,12 @@ class ServerArgs:
mm_attention_backend: Optional[str] = None
fp8_gemm_runner_backend: str = "auto"
fp4_gemm_runner_backend: str = "auto"
nsa_prefill_backend: str = "flashmla_sparse"
nsa_decode_backend: str = "fa3"
nsa_prefill_backend: Optional[str] = (
None # None = auto-detect based on hardware/kv_cache_dtype
)
nsa_decode_backend: Optional[str] = (
None  # None = auto-detect based on hardware/kv_cache_dtype
)
disable_flashinfer_autotune: bool = False

# Speculative decoding
@@ -1086,6 +1091,59 @@ def _generate_piecewise_cuda_graph_tokens(self):

return capture_sizes

def _set_default_nsa_kv_cache_dtype(self, major: int) -> None:
user_set_prefill = self.nsa_prefill_backend is not None
user_set_decode = self.nsa_decode_backend is not None

# If the user specified a backend but didn't explicitly set kv_cache_dtype,
# warn them to set kv_cache_dtype explicitly to avoid surprises.
if (user_set_prefill or user_set_decode) and self.kv_cache_dtype == "auto":
logger.warning(
f"When specifying --nsa-prefill-backend or --nsa-decode-backend, "
f"you should also explicitly set --kv-cache-dtype (e.g., 'fp8_e4m3' or 'bfloat16'). "
f"DeepSeek V3.2 defaults to FP8 KV cache which may not be compatible with all backends."
)

if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek DSA on SM{major} device."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek DSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"

def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> None:
user_set_prefill = self.nsa_prefill_backend is not None
user_set_decode = self.nsa_decode_backend is not None

if kv_cache_dtype == "fp8_e4m3":
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
if not user_set_prefill:
self.nsa_prefill_backend = "flashmla_auto"
if not user_set_decode:
self.nsa_decode_backend = "flashmla_kv"
else:
# set prefill/decode backends based on hardware architecture.
if major >= 10:
if not user_set_prefill:
self.nsa_prefill_backend = "flashmla_sparse"
if not user_set_decode:
self.nsa_decode_backend = "trtllm"
else:
# Hopper defaults for bfloat16
if not user_set_prefill:
self.nsa_prefill_backend = "flashmla_sparse"
if not user_set_decode:
self.nsa_decode_backend = "fa3"

logger.warning(
f"Set NSA backends for {self.kv_cache_dtype} KV Cache: prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend}."
)

def _handle_model_specific_adjustments(self):
from sglang.srt.configs.model_config import is_deepseek_nsa

@@ -1158,35 +1216,11 @@ def _handle_model_specific_adjustments(self):
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek DSA.")

# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
import torch

major, _ = torch.cuda.get_device_capability()
if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek DSA on SM{major} device."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek DSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"

if self.kv_cache_dtype == "fp8_e4m3":
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
self.nsa_prefill_backend = "flashmla_auto"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting DSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
)
else:
# set prefill/decode backends to flashmla_sparse for Blackwell.
# The default settings (P=flashmla_sparse, D=fa3) are for Hopper.
if major >= 10:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_decode_backend = "flashmla_sparse"
self._set_default_nsa_kv_cache_dtype(major)
self._set_default_nsa_backends(self.kv_cache_dtype, major)

if self.enable_nsa_prefill_context_parallel:
assert (
@@ -3586,12 +3620,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.nsa_prefill_backend,
type=str,
choices=NSA_CHOICES,
help="NSA prefill backend. If not specified, auto-detects based on hardware and kv_cache_dtype.",
)
parser.add_argument(
"--nsa-decode-backend",
default=ServerArgs.nsa_decode_backend,
type=str,
choices=NSA_CHOICES,
help="NSA decode backend. If not specified, auto-detects based on hardware and kv_cache_dtype.",
)
parser.add_argument(
"--fp8-gemm-backend",
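For reference, a condensed sketch (not the PR's code) of how the two new helpers compose into default choices; pick_nsa_defaults is a hypothetical pure-function rendering of the same branching:

from typing import Optional, Tuple

def pick_nsa_defaults(
    kv_cache_dtype: str,              # "auto", "bfloat16", or "fp8_e4m3"
    sm_major: int,                    # 9 = Hopper, 10 = Blackwell
    prefill: Optional[str] = None,    # user-specified --nsa-prefill-backend, if any
    decode: Optional[str] = None,     # user-specified --nsa-decode-backend, if any
) -> Tuple[str, str, str]:
    # KV cache dtype defaults follow the hardware: FP8 on SM100+, bf16 otherwise.
    if kv_cache_dtype == "auto":
        kv_cache_dtype = "fp8_e4m3" if sm_major >= 10 else "bfloat16"
    if kv_cache_dtype == "fp8_e4m3":
        prefill = prefill or "flashmla_auto"   # dispatches to flashmla_sparse/flashmla_kv
        decode = decode or "flashmla_kv"
    elif sm_major >= 10:                       # bf16 KV cache on Blackwell
        prefill = prefill or "flashmla_sparse"
        decode = decode or "trtllm"
    else:                                      # bf16 KV cache on Hopper
        prefill = prefill or "flashmla_sparse"
        decode = decode or "fa3"
    return kv_cache_dtype, prefill, decode

print(pick_nsa_defaults("auto", 10))      # ('fp8_e4m3', 'flashmla_auto', 'flashmla_kv')
print(pick_nsa_defaults("bfloat16", 10))  # ('bfloat16', 'flashmla_sparse', 'trtllm')
print(pick_nsa_defaults("auto", 9))       # ('bfloat16', 'flashmla_sparse', 'fa3')

On the command line, overriding a backend pairs naturally with an explicit dtype, e.g. --nsa-decode-backend trtllm together with --kv-cache-dtype bfloat16, which is why the new warning asks for --kv-cache-dtype whenever a backend flag is given.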