Description
I found out that Qwen3.5 MoE, at both 35B and 297B, has n_kv_heads == 2, which hard-constrains tensor_parallelism_degree.
It would be nice if the user got a clear warning message at startup.
Root Cause
ColwiseParallel on wk/wv shards the output dimension (n_kv_heads * head_dim). The constraint is:
n_kv_heads % TP == 0
If TP > n_kv_heads (or n_kv_heads not divisible by TP), each rank gets a fractional number of KV heads. The view(bs, seqlen, -1, head_dim) reshape in the attention forward then produces a non-integer head count and crashes.
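The shard arithmetic above can be sketched in plain Python (no torch needed; head_dim=128 is an assumed illustrative value):

```python
# Sketch of the ColwiseParallel shard / reshape arithmetic.
# ColwiseParallel splits wk/wv's output dim (n_kv_heads * head_dim) evenly
# across tp ranks; view(bs, seqlen, -1, head_dim) then requires the local
# dim to be an integer multiple of head_dim.

def local_kv_heads(n_kv_heads: int, head_dim: int, tp: int) -> float:
    local_dim = n_kv_heads * head_dim / tp  # per-rank output dimension
    return local_dim / head_dim             # per-rank KV head count

print(local_kv_heads(2, 128, 2))  # 1.0 -> one KV head per rank, fine
print(local_kv_heads(2, 128, 4))  # 0.5 -> fractional head count, reshape crashes
```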
Affected Models
| Model | n_kv_heads | Max TP | Risk |
|---|---|---|---|
| qwen3_5_moe (all configs) | 2 | TP ≤ 2 | High — all production configs bottlenecked at 2 |
| qwen3 (most configs) | 8 | TP ≤ 8 | Low — users are unlikely to request TP > 8 |
| qwen3 (0.6B / 1.7B) | 4 | TP ≤ 4 | Medium |
| llama3 / llama4 (all configs) | 8 | TP ≤ 8 | Low — same as qwen3 |
No Validation Anywhere in Training Paths
None of the four parallelize_*.py files validate n_kv_heads % tp_degree == 0. The only guard in the repo is in the RL inference experiment (torchtitan/experiments/rl/unified/models/attention.py:158), not in training.
The parallelize functions for all models have the identical blind spot:
- llama3/parallelize.py → apply_tp(): "attention.wk": colwise_parallel() — no check
- llama4/parallelize.py → apply_non_moe_tp(): same pattern — no check
- qwen3/parallelize.py → apply_non_moe_tp(): same pattern — no check
- qwen3_5_moe/parallelize.py → apply_non_moe_tp(): same pattern — no check
Where to Add Validation
The check should be a ValueError in each model's parallelize_* entry point (before apply_non_moe_tp is called), or factored into a shared helper.
Example for parallelize_qwen3_5_moe:
```python
n_kv_heads = model.config.layer.attention.n_kv_heads
tp = parallel_dims.tp
if parallel_dims.tp_enabled and n_kv_heads % tp != 0:
    raise ValueError(
        f"Tensor parallel degree ({tp}) must divide n_kv_heads ({n_kv_heads}). "
        f"For Qwen3.5-MoE, max supported TP is {n_kv_heads}."
    )
```
The same check should be added to parallelize_qwen3, parallelize_llama (llama3), and llama4's entry point. Without it, users hitting TP=4 on qwen3_5_moe get a cryptic reshape error deep in the forward pass rather than a clear message at startup.
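If the check is factored out instead of duplicated, a shared helper could look like this (the function name and signature are illustrative, not existing torchtitan API):

```python
# Hypothetical shared helper that each model's parallelize_* entry point
# could call before building its TP plan. Fails fast at startup instead
# of crashing with a reshape error deep in the forward pass.

def validate_kv_head_divisibility(model_name: str, n_kv_heads: int, tp: int) -> None:
    if tp > 1 and n_kv_heads % tp != 0:
        raise ValueError(
            f"Tensor parallel degree ({tp}) must divide n_kv_heads "
            f"({n_kv_heads}) for {model_name}; max supported TP is {n_kv_heads}."
        )

validate_kv_head_divisibility("qwen3", 8, 4)         # ok: 8 % 4 == 0
# validate_kv_head_divisibility("qwen3_5_moe", 2, 4) # raises ValueError
```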