Check tensor parallelism degree in parallelize_fn #2574

@gali-leilei

Description

I found that Qwen3.5 MoE, at both 35B and 297B, has n_kv_heads == 2, which hard-constrains tensor_parallelism_degree. The user should get a clear error message at startup rather than a crash deep in the forward pass.

Root Cause

ColwiseParallel on wk/wv shards the output dimension (n_kv_heads * head_dim). The constraint is:

n_kv_heads % TP == 0

If TP > n_kv_heads (or n_kv_heads is not divisible by TP), each rank gets a fractional number of KV heads. The view(bs, seqlen, -1, head_dim) reshape in the attention forward then produces a non-integer head count and crashes.
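The arithmetic above can be checked with a small sketch. The numbers below (n_kv_heads = 2, head_dim = 128) are illustrative, not read from the repo's configs:

```python
# Illustrative numbers: Qwen3.5-MoE-style n_kv_heads, a typical head_dim.
n_kv_heads, head_dim = 2, 128

for tp in (1, 2, 4):
    # ColwiseParallel shards the wk/wv output dim (n_kv_heads * head_dim) across TP ranks.
    per_rank_out = n_kv_heads * head_dim / tp
    # The attention forward then views that as (-1, head_dim) heads per rank.
    heads_per_rank = per_rank_out / head_dim
    valid = n_kv_heads % tp == 0
    print(f"TP={tp}: heads/rank={heads_per_rank}, divisible={valid}")
```

For TP=4, each rank ends up with 0.5 KV heads, so the `view(bs, seqlen, -1, head_dim)` reshape cannot produce an integer head count.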

Affected Models

| Model | n_kv_heads | Max TP | Risk |
| --- | --- | --- | --- |
| qwen3_5_moe (all configs) | 2 | TP ≤ 2 | High — all production configs bottlenecked at 2 |
| qwen3 (most configs) | 8 | TP ≤ 8 | Low — users are unlikely to request TP > 8 |
| qwen3 (0.6B / 1.7B) | 4 | TP ≤ 4 | Medium |
| llama3 / llama4 (all configs) | 8 | TP ≤ 8 | Low — same as qwen3 |

No Validation Anywhere in Training Paths

None of the four parallelize_*.py files validate n_kv_heads % tp_degree == 0. The only guard in the repo is in the RL inference experiment (torchtitan/experiments/rl/unified/models/attention.py:158), not in training.

The parallelize functions for all models have the identical blind spot:

  • llama3/parallelize.py, apply_tp(): "attention.wk": colwise_parallel() — no check
  • llama4/parallelize.py, apply_non_moe_tp(): same pattern — no check
  • qwen3/parallelize.py, apply_non_moe_tp(): same pattern — no check
  • qwen3_5_moe/parallelize.py, apply_non_moe_tp(): same pattern — no check

Where to Add Validation

The check should be a ValueError in each model's parallelize_* entry point (before apply_non_moe_tp is called), or factored into a shared helper.
Example for parallelize_qwen3_5_moe:

```python
n_kv_heads = model.config.layer.attention.n_kv_heads
tp = parallel_dims.tp
if parallel_dims.tp_enabled and n_kv_heads % tp != 0:
    raise ValueError(
        f"Tensor parallel degree ({tp}) must divide n_kv_heads ({n_kv_heads}). "
        f"For Qwen3.5-MoE, max supported TP is {n_kv_heads}."
    )
```

The same check should be added to parallelize_qwen3, parallelize_llama (llama3), and llama4's entry point. Without it, users hitting TP=4 on qwen3_5_moe get a cryptic reshape error deep in the forward pass rather than a clear message at startup.
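Since the check is identical across all four models, the shared-helper route mentioned above could look roughly like this. The helper name and signature are hypothetical, not existing repo code:

```python
def validate_tp_divides_kv_heads(n_kv_heads: int, tp_degree: int, model_name: str) -> None:
    """Hypothetical shared helper: fail fast at startup if TP cannot evenly
    shard the KV heads that ColwiseParallel splits on wk/wv."""
    if tp_degree > 1 and n_kv_heads % tp_degree != 0:
        raise ValueError(
            f"Tensor parallel degree ({tp_degree}) must divide n_kv_heads "
            f"({n_kv_heads}) for {model_name}; max supported TP is {n_kv_heads}."
        )
```

Each model's `parallelize_*` entry point would then call it once before `apply_non_moe_tp` (or `apply_tp`), passing its own config values.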
