66 changes: 66 additions & 0 deletions docs/basic_usage/deepseek_v32.md
@@ -129,6 +129,13 @@ Latency: 25.109 s
Output throughput: 5226.235 token/s
```

To test long-context accuracy, run gsm8k with `--num-shots 20`. The results are very close to the 8-shot results:
```
Accuracy: 0.956
Invalid: 0.000
Latency: 29.545 s
Output throughput: 4418.617 token/s
```

### Accuracy Test with `gpqa-diamond`

@@ -143,6 +150,65 @@ Repeat: 8, mean: 0.797
Scores: ['0.808', '0.798', '0.808', '0.798', '0.783', '0.788', '0.803', '0.793']
```

### Accuracy Test with `aime25`

Prepare the environment by installing NeMo-Skills inside the Docker container or in your own virtual environment:

```
pip install git+https://github.com/NVIDIA/NeMo-Skills.git --ignore-installed blinker
```

Modify the [Jinja `chat_template`](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/tokenizer_config.json#L34) by replacing

```
{% set thinking = false %}
```
with
```
{% set thinking = true %}
```
and save it to `chat_template_thinking.jinja`.
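
If you prefer to script this edit, here is a minimal sketch that downloads `tokenizer_config.json`, flips the flag, and writes the template file. It assumes `huggingface_hub` is installed; the local file name mirrors the manual instructions above:
```
import json

from huggingface_hub import hf_hub_download

# Fetch the tokenizer config that contains the chat template
# (assumes network access and that `huggingface_hub` is installed).
config_path = hf_hub_download(
    repo_id="deepseek-ai/DeepSeek-V3.2-Exp", filename="tokenizer_config.json"
)
with open(config_path) as f:
    template = json.load(f)["chat_template"]

# Mirror the manual edit above: default to thinking mode.
template = template.replace(
    "{% set thinking = false %}", "{% set thinking = true %}"
)

with open("chat_template_thinking.jinja", "w") as f:
    f.write(template)
```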

Launch the SGLang server with the modified chat-template file:
```
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --chat-template chat_template_thinking.jinja
```
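
Before starting the full evaluation, you can optionally sanity-check that the server picked up the thinking template by sending a single request through the OpenAI-compatible endpoint. This is a minimal sketch, assuming the `openai` Python client is installed and the server runs on the default port 30000; how the reasoning trace is surfaced depends on the chat template and server flags:
```
from openai import OpenAI

# The SGLang server exposes an OpenAI-compatible API; any non-empty key works locally.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3.2-Exp",
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
    max_tokens=512,
)

# With thinking enabled, the reply should include the model's reasoning before the
# final answer (the exact formatting depends on the chat template).
print(response.choices[0].message.content)
```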

Run the following script to evaluate AIME 2025:
```
#!/bin/bash
export NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK=1

ns prepare_data aime25

PORT=30000
BACKEND=sglang
MODEL="deepseek-ai/DeepSeek-V3.2-Exp"
MODEL_NAME="dsv32-fp8"

echo "Starting AIME25 evaluation with model $MODEL on port $PORT using backend $BACKEND..."
ns eval \
--benchmarks=aime25:4 \
--server_type=$BACKEND \
--model=$MODEL \
--server_address=http://localhost:${PORT}/v1 \
--output_dir=nemo_skills_aime25_${MODEL_NAME}_output_${BACKEND}_$(date +%Y%m%d_%H%M%S) \
++max_concurrent_requests=512 \
++server.api_key=dummy \
++inference.tokens_to_generate=64000
```

Test results:


| evaluation_mode | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer |
|--------------------|-------------|------------|-------------|-----------------------|-----------|
| pass@1[avg-of-4] | 30 | 14410 | 1758 | 85.83% ± 4.19% | 0.00% |
| majority@4 | 30 | 14410 | 1758 | 90.00% | 0.00% |
| pass@4 | 30 | 14410 | 1758 | 93.33% | 0.00% |

Note that problem #3 (id `aime25-2`) is marked as incorrect by nemo-skills even though it is actually correct: nemo-skills fails to match the predicted answer `016` with the expected answer `16`. Adding 1/30 = 3.33% back to the scores brings the pass@1[avg-of-4] result in line with the reference value of 89.3.
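
The mismatch is purely a formatting issue; normalizing leading zeros before comparison would score it correctly, as in this hypothetical helper (not part of nemo-skills):
```
def answers_match(predicted: str, expected: str) -> bool:
    """Compare AIME-style integer answers, ignoring leading zeros (hypothetical helper)."""
    try:
        return int(predicted) == int(expected)
    except ValueError:
        return predicted.strip() == expected.strip()

assert answers_match("016", "16")  # the case mis-scored for id `aime25-2`
```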


## DSA long sequence context parallel optimization (experimental)

79 changes: 26 additions & 53 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -109,7 +109,6 @@ def __init__(
prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
fuse_wk_and_weights_proj: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
@@ -120,7 +119,6 @@ def __init__(
self.q_lora_rank = q_lora_rank
self.layer_id = layer_id
self.alt_stream = alt_stream
self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()
if self.nsa_enable_prefill_cp:
self.cp_size = get_attention_tp_size()
@@ -139,28 +137,22 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix("wq_b", prefix),
)
if self.fuse_wk_and_weights_proj:
self.fused_wk_and_weights_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim + self.n_heads,
bias=False,
prefix=add_prefix("fused_wk_and_weights_proj", prefix),
)
else:
self.wk = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wk", prefix),
)
# NOTE: weight_proj is not quantized
self.weights_proj = ReplicatedLinear(
self.hidden_size,
self.n_heads,
bias=False,
prefix=add_prefix("weights_proj", prefix),
)

self.wk = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("wk", prefix),
)
# NOTE: weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience
self.weights_proj = ReplicatedLinear(
self.hidden_size,
self.n_heads,
bias=False,
params_dtype=torch.float32,
prefix=add_prefix("weights_proj", prefix),
)
self.k_norm = LayerNorm(self.head_dim, dtype=torch.float32)
self.rotary_emb = get_rope_wrapper(
rope_head_dim,
@@ -176,7 +168,8 @@ def __init__(
self.softmax_scale = self.head_dim**-0.5

@torch.compile(dynamic=True)
def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x.float())
weights = weights * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
return weights
@@ -189,7 +182,6 @@ def _get_q_k_bf16(
enable_dual_stream: bool,
forward_batch: ForwardBatch,
):
weights = None
if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
@@ -206,12 +198,7 @@ def _get_q_k_bf16(
)
with torch.cuda.stream(self.alt_stream):
# TODO we should also put DeepGEMM half SM here?
if self.fuse_wk_and_weights_proj:
key, weights = self.fused_wk_and_weights_proj(x)[0].split(
[self.head_dim, self.n_heads], dim=-1
)
else:
key, _ = self.wk(x)
key, _ = self.wk(x)
key = self.k_norm(key)

k_rope, _ = torch.split(
@@ -224,17 +211,10 @@ def _get_q_k_bf16(
else:
query, _ = self.wq_b(q_lora)
query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)

q_rope, _ = torch.split(
query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)

if self.fuse_wk_and_weights_proj:
key, weights = self.fused_wk_and_weights_proj(x)[0].split(
[self.head_dim, self.n_heads], dim=-1
)
else:
key, _ = self.wk(x)
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -266,21 +246,16 @@ def _get_q_k_bf16(
query = rotate_activation(query)
key = rotate_activation(key)

return query, key, weights
return query, key

def _get_k_bf16(
self,
x: torch.Tensor,
positions: torch.Tensor,
enable_dual_stream: bool,
):
# Compute only key, skip query and weights (weights is discarded if fused)
if self.fuse_wk_and_weights_proj:
key, _ = self.fused_wk_and_weights_proj(x)[0].split(
[self.head_dim, self.n_heads], dim=-1
)
else:
key, _ = self.wk(x)
# Compute only key, skip query
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -779,7 +754,7 @@ def forward_cuda(
return_indices,
)

query, key, weights = self._get_q_k_bf16(
query, key = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
)

@@ -808,9 +783,7 @@ def forward_cuda(
index_k_scale=k_scale,
)

if not self.fuse_wk_and_weights_proj:
weights, _ = self.weights_proj(x)
weights = self._get_logits_head_gate(weights, q_scale)
weights = self._get_logits_head_gate(x, q_scale)

if is_cuda():
assert forward_batch.seq_lens_cpu is not None
@@ -1037,7 +1010,7 @@ def forward_npu(
past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)

x = x.view(-1, self.hidden_size)
weights = self.weights_proj(x)[0]
weights = self.weights_proj(x.float())[0]
block_table = (
block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
)
71 changes: 0 additions & 71 deletions python/sglang/srt/models/deepseek_v2.py
@@ -239,17 +239,6 @@ def add_forward_absorb_core_attention_backend(backend_name):
logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")


def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
"""
NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
"""
return (
is_deepseek_nsa(config)
and quant_config is not None
and quant_config.get_name() == "modelopt_fp4"
)


class AttnForwardMethod(IntEnum):
# Use multi-head attention
MHA = auto()
@@ -1226,9 +1215,6 @@ def __init__(
quant_config=quant_config,
layer_id=layer_id,
alt_stream=alt_stream,
fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
config, quant_config
),
)

self.kv_b_proj = ColumnParallelLinear(
@@ -3768,12 +3754,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
# Fuse wk and weights_proj when NSA Indexer is enabled and quant_config is FP4. For nextn, fp4 is disabled so we cannot fuse.
fuse_wk_and_weights_proj = (
is_nsa_indexer_wk_and_weights_proj_fused(self.config, self.quant_config)
and not is_nextn
)
cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None

if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3959,57 +3939,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
elif fuse_wk_and_weights_proj and (
"wk" in name or "weights_proj" in name
):
cached_wk_and_weights_proj[name] = loaded_weight
wk_name = (
name
if "wk" in name
else name.replace("weights_proj", "wk")
)
weights_proj_name = (
name
if "weights_proj" in name
else name.replace("wk", "weights_proj")
)

# When both wk and weights_proj has been cached, load the fused weight to parameter
if (
wk_name in cached_wk_and_weights_proj
and weights_proj_name in cached_wk_and_weights_proj
):
wk_weight = cached_wk_and_weights_proj[wk_name]
weights_proj_weight = cached_wk_and_weights_proj[
weights_proj_name
]
# todo dequantize wk for fp8
assert wk_weight.dtype == weights_proj_weight.dtype
fused_weight = torch.cat(
[wk_weight, weights_proj_weight], dim=0
)
param_name = (
name.replace("wk", "fused_wk_and_weights_proj")
if "wk" in name
else name.replace(
"weights_proj",
"fused_wk_and_weights_proj",
)
)
param = params_dict[param_name]

weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
maybe_executor_submit(
executor=executor,
futures=futures,
use_async=use_async_loading,
func=weight_loader,
func_args=(param, fused_weight),
)
cached_wk_and_weights_proj.pop(wk_name)
cached_wk_and_weights_proj.pop(weights_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name