
[CuTeDSL] Flash Attention v2 for SM120 (Blackwell GeForce)#3030

Open
blake-snc wants to merge 7 commits into NVIDIA:main from blake-snc:sm120-flash-attention-v2

Conversation

@blake-snc

@blake-snc blake-snc commented Feb 13, 2026

Summary

Adds CuTe DSL Flash Attention v2 forward pass kernels for SM120 (RTX 5090, GB10 / DGX Spark), addressing the lack of high-performance FA kernels for this architecture.

Three implementations included:

BF16 Kernels (flash_attention_v2.py)

  1. CpAsync variant (FlashAttentionForwardSm120): All threads perform both loads and compute using cp.async — the Ampere-era approach, tuned for SM120's 101 KB shared memory.

  2. TMA variant (FlashAttentionForwardSm120Tma): Uses TMA (cp.async.bulk) with warp specialization — 1 dedicated DMA warp handles TMA loads while N MMA warps compute, enabling load/compute overlap via multi-stage KV pipelining with PipelineTmaAsync mbarrier synchronization.

Both use SM80-compatible tensor core instructions (mma.sync.aligned.m16n8k16) since SM120 lacks tcgen05 MMA. Supports FP16/BF16, causal/non-causal, configurable tile sizes, asymmetric Q/K lengths, online softmax fusion, and register pipelining.
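For readers unfamiliar with the online-softmax fusion, the FA2 recurrence can be sketched in plain Python (a scalar illustration of the math only, not CuTe DSL code; function name is hypothetical):

```python
import math

def online_softmax_attention(score_tiles, value_tiles):
    """Online softmax for one query row: process KV tiles one at a time,
    keeping a running max (m), a running denominator (l), and an
    unnormalized output accumulator (acc) that is rescaled whenever the
    running max changes -- the Flash Attention v2 recurrence."""
    m = -math.inf
    l = 0.0
    acc = 0.0
    for scores, values in zip(score_tiles, value_tiles):
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)   # rescale old state to the new max
        l *= scale
        acc *= scale
        for s, v in zip(scores, values):
            p = math.exp(s - m_new)
            l += p
            acc += p * v
        m = m_new
    return acc / l                    # normalize once at the end
```

The point of the recurrence is that `l` and `acc` never need the full score row at once, so softmax can be fused into the KV-tile loop without materializing S×S attention scores.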

FP8 Kernel (fp8_flash_attention.py)

  1. FP8 optimized variant (FP8FlashAttentionSm120Opt): Uses SM120's native FP8 MMA instruction (mma.sync.aligned.kind::f8f6f4.m16n8k32) via inline PTX, with CpAsync double-buffered K/V pipeline, vectorized 4×4 byte transpose via prmt.b32, and bank-conflict-free SMEM layout (+16 byte row padding).

FP8 kernel features:

  • 4-warp design with register-level O accumulation (no SMEM round-trip for output)
  • Inline PTX MMA as workaround for CUTLASS MmaAtomSM80Type segfault with FP8 types (CuTe DSL: FP8 MMA segfaults on SM120 — MmaAtomSM80Type missing kind::f8f6f4 lowering #3044)
  • CpAsync pipelining: V load overlaps QK GEMM, next K prefetch overlaps transpose + PV GEMM
  • Vectorized V transpose using 8-instruction prmt.b32 4×4 byte shuffle (no ldmatrix.b8 on SM120)
  • Bank-conflict-free SMEM: +16 bytes/row padding eliminates 8-way bank conflicts for stride-D access
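The 8-instruction transpose can be illustrated with a pure-Python model of prmt.b32's byte-selection semantics. The selector constants below are one valid 8-op sequence (two interleave rounds); the actual constants used in the kernel may differ:

```python
def prmt(a, b, sel):
    """Model of PTX prmt.b32 (default mode): form an 8-byte pool from
    {a, b} (a supplies bytes 0-3, b supplies bytes 4-7) and pick one
    pool byte per result byte according to the selector nibbles."""
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)] + \
           [(b >> (8 * i)) & 0xFF for i in range(4)]
    out = 0
    for i in range(4):
        out |= pool[(sel >> (4 * i)) & 0x7] << (8 * i)
    return out

def transpose_4x4_bytes(r0, r1, r2, r3):
    """Transpose a 4x4 byte tile held in four 32-bit registers with
    8 prmt operations: interleave pairs of rows, then pairs of pairs."""
    t0 = prmt(r0, r1, 0x5140)  # [a0, b0, a1, b1]
    t1 = prmt(r0, r1, 0x7362)  # [a2, b2, a3, b3]
    t2 = prmt(r2, r3, 0x5140)  # [c0, d0, c1, d1]
    t3 = prmt(r2, r3, 0x7362)  # [c2, d2, c3, d3]
    return (prmt(t0, t2, 0x5410),   # [a0, b0, c0, d0]
            prmt(t0, t2, 0x7632),   # [a1, b1, c1, d1]
            prmt(t1, t3, 0x5410),   # [a2, b2, c2, d2]
            prmt(t1, t3, 0x7632))   # [a3, b3, c3, d3]
```

Each result register collects one byte column of the input tile, which is exactly what the PV GEMM needs from the V tile when no ldmatrix.b8 is available.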

TMA variant details

  • 4D TMA descriptors (seq, dim, head, batch) for native multi-batch support
  • TMA-compatible Swizzle(B, 4, 3) pattern (M=4 required by TMA hardware; the CpAsync version uses Swizzle(B, 3, 3) which is not valid for TMA)
  • Separate K and V pipelines for independent scheduling
  • Configurable kv_stages for double-buffering (default 2, falls back to 1 for large head_dim)
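The multi-stage K/V pipelining can be pictured with a simplified single-threaded model of the producer/consumer hand-off. The real PipelineTmaAsync runs the two roles on different warps and synchronizes with mbarrier phase bits; this sketch (function name hypothetical) only shows the circular-buffer scheduling:

```python
from collections import deque

def simulate_kv_pipeline(num_tiles, kv_stages):
    """Single-threaded model of a kv_stages-deep circular buffer shared
    between a DMA producer (issues loads while empty slots exist) and an
    MMA consumer (drains the oldest full slot). Returns the event trace."""
    smem = [None] * kv_stages
    full = deque()            # stage indices holding a loaded tile
    produced = consumed = 0
    trace = []
    while consumed < num_tiles:
        # producer: keep issuing loads until the ring is full
        while produced < num_tiles and len(full) < kv_stages:
            stage = produced % kv_stages
            smem[stage] = produced
            full.append(stage)
            trace.append(f"load tile {produced} -> stage {stage}")
            produced += 1
        # consumer: compute on the oldest loaded tile, then release its slot
        stage = full.popleft()
        trace.append(f"mma on tile {smem[stage]} (stage {stage})")
        consumed += 1
    return trace
```

With kv_stages=2 the load of tile i+1 is in flight while tile i is consumed; with kv_stages=1 (the large-head_dim fallback) load and compute serialize.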

Benchmark Results

FP8 vs BF16 Performance on DGX Spark (SM121a)

FP8 kernel: FP8FlashAttentionSm120Opt (CpAsync, bank-conflict-free, 4 warps, M=64, N=32)
BF16 kernel: FlashAttentionForwardSm120 (CpAsync, tiled MMA, M=128, N=64)

Batch SeqLen HeadDim Causal FP8 (μs) BF16 (μs) FP8 TFLOPS BF16 TFLOPS Ratio
1 512 64 no 46.8 32.0 18.38 26.84 0.68x
1 512 64 yes 26.8 19.1 16.03 22.53 0.71x
1 512 128 no 163.8 98.6 13.11 21.77 0.60x
1 512 128 yes 89.2 55.9 12.06 19.25 0.63x
1 1024 64 no 161.3 108.4 21.26 31.65 0.67x
1 1024 64 yes 85.8 57.3 20.01 29.97 0.67x
1 1024 128 no 545.0 308.4 15.76 27.86 0.57x
1 1024 128 yes 291.1 170.3 14.73 25.18 0.59x
1 2048 64 no 614.0 419.3 22.28 32.63 0.68x
1 2048 64 yes 316.1 213.1 21.68 32.16 0.67x
1 2048 128 no 2043.3 1140.2 16.84 30.18 0.56x
1 2048 128 yes 1053.2 601.3 16.35 28.65 0.57x
4 512 64 no 115.1 101.2 29.82 33.92 0.88x
4 512 128 no 243.9 348.1 35.26 24.71 1.43x
4 1024 64 no 382.0 349.9 35.81 39.09 0.92x
4 1024 128 no 810.7 1137.1 42.41 30.23 1.40x

Key findings:

  • FP8 peaks at 42.4 TFLOPS (D=128, B=4) and 35.8 TFLOPS (D=64, B=4)
  • FP8 outperforms BF16 by up to 1.43× at larger batch sizes with D=128
  • At B=1, BF16 is faster due to FP8's software V transpose overhead and P SMEM round-trip
  • The crossover point is around B=4 where FP8's 2× arithmetic throughput compensates for overhead

BF16 TMA vs CpAsync Performance on DGX Spark (SM121a)

TMA kernel: FlashAttentionForwardSm120Tma (warp-specialized, 3 MMA + 1 DMA warp, PipelineTmaAsync)
CpAsync kernel: FlashAttentionForwardSm120 (all threads load+compute, cp.async)
Both: M=128, N=64, 16 heads, 20 warmup iters, 100 timed iters
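The TFLOPS columns appear consistent with the standard forward-pass FLOP count of 4·B·H·Sq·Sk·D (two GEMMs at 2·M·N·K FLOPs each), halved for causal. A quick cross-check against the first table rows (helper name hypothetical):

```python
def attention_tflops(batch, heads, seqlen_q, seqlen_k, head_dim, causal, time_s):
    """Forward-pass attention throughput: the QK^T and PV GEMMs each
    contribute 2*Sq*Sk*D FLOPs per batch*head; causal masking skips
    roughly half the tiles."""
    flops = 4 * batch * heads * seqlen_q * seqlen_k * head_dim
    if causal:
        flops /= 2
    return flops / time_s / 1e12

# B=1, H=16, S=512, D=64 non-causal at 20.5 us -> ~52.4 TFLOPS,
# matching the table to rounding of the reported time.
```

The small residual differences from the table come from the times being rounded to 0.1 μs.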

Batch SeqLen HDim Causal CpAsync (μs) TMA (μs) CpAsync TFLOPS TMA TFLOPS Speedup
1 512 64 no 20.5 24.6 52.48 43.67 0.83x
1 512 64 yes 22.7 32.8 23.60 16.39 0.69x
1 512 128 no 41.0 41.0 52.40 52.38 1.00x
1 512 128 yes 32.8 34.8 32.77 30.85 0.94x
1 1024 64 no 82.2 65.5 52.26 65.55 1.25x
1 1024 64 yes 45.1 51.3 47.66 41.85 0.88x
1 1024 128 no 112.4 110.8 76.45 77.55 1.01x
1 1024 128 yes 85.4 86.1 50.32 49.90 0.99x
1 2048 64 no 217.8 281.5 78.86 61.03 0.77x
1 2048 64 yes 155.2 157.5 55.33 54.52 0.99x
1 2048 128 no 519.7 461.3 66.11 74.49 1.13x
1 2048 128 yes 317.6 285.9 54.09 60.10 1.11x
4 512 64 no 55.3 88.9 77.72 48.30 0.62x
4 512 128 no 175.7 179.3 48.88 47.91 0.98x
4 1024 64 no 226.9 318.2 75.72 53.98 0.71x
4 1024 128 no 598.8 533.4 57.38 64.41 1.12x
4 2048 64 no 1272.0 1136.9 54.02 60.45 1.12x
4 2048 128 no 1842.1 1865.4 74.61 73.68 0.99x
1 4096 128 no 1863.6 1901.2 73.75 72.29 0.98x
1 4096 128 yes 909.2 981.3 75.58 70.03 0.93x
1 8192 128 no 7425.3 8109.9 74.04 67.79 0.92x
1 8192 128 yes 3866.6 3648.6 71.09 75.34 1.06x
4 4096 128 no 6966.4 6840.3 78.91 80.37 1.02x

Geometric mean speedup: 0.95x · Min: 0.62x · Max: 1.25x

Note: These initial results contained JIT compilation artifacts (some configs inflated/deflated by first-compilation overhead) and a causal masking bug — see updated results below.

Key findings:

  • D=128, medium-to-large sequences: TMA wins 6-13% (S≥2048 causal, B=4 S≥1024) — warp specialization amortizes DMA warp overhead
  • D=64: TMA consistently loses (0.62x-0.88x) — tiles too small to amortize the dedicated DMA warp
  • Short sequences (S=512): Roughly neutral for D=128, overhead-dominated for D=64
  • kv_stages=1 (required by SMEM budget for D=128): Limits TMA pipelining benefit — no double-buffering of KV tiles
  • TMA's primary advantage is architectural: it demonstrates cp.async.bulk + warp specialization on SM120, a pattern that scales better with larger tile sizes and multi-stage pipelining

BF16 TMA vs CpAsync — Updated (post causal masking fix)

The initial results above had two issues:

  1. JIT compilation artifacts: The CuTe DSL JIT-compiles kernels on first invocation, which inflated/deflated some configs (e.g. the 1.25x outlier at B=1 S=1024 D=64 was a JIT artifact). The updated benchmark pre-warms all kernel variants before timing.

  2. Causal masking bug (fixed in 106e24b): The TMA variant applied expensive per-element causal masking to all KV tiles instead of only the ceil_div(m_block_size, n_block_size) tiles near the diagonal. The CpAsync variant correctly used a two-loop structure (masked loop + fast loop). This caused up to a 40% regression on causal configs (e.g. B=1 S=512 D=64 causal: 0.69x → 1.00x after fix).
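The two-loop split can be sketched as follows (a simplification assuming equal Q/K lengths and no block offset; helper names are hypothetical):

```python
def ceil_div(a, b):
    return -(-a // b)

def causal_block_schedule(m_block, m_block_size, n_block_size):
    """For query tile m_block, split the visible KV tiles into a fast
    loop (tiles fully below the diagonal, no masking needed) and a
    masked loop (tiles intersecting the diagonal, which need
    per-element causal masking)."""
    n_block_max = ceil_div((m_block + 1) * m_block_size, n_block_size)
    n_masked = min(ceil_div(m_block_size, n_block_size), n_block_max)
    n_fast = n_block_max - n_masked
    return n_fast, n_masked
```

With M=128 and N=64, query tile 2 visits 6 KV tiles: 4 in the fast loop and only 2 masked tiles near the diagonal, rather than masking all 6.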

Updated results with JIT pre-warming + causal fix:

Batch SeqLen HDim Causal CpAsync (μs) TMA (μs) CpAsync TFLOPS TMA TFLOPS Speedup
1 512 64 no 20.5 24.6 52.46 43.69 0.83x
1 512 64 yes 20.5 20.5 26.22 26.18 1.00x
1 512 128 no 40.9 41.0 52.44 52.40 1.00x
1 512 128 yes 32.5 34.8 33.08 30.84 0.93x
1 1024 64 no 53.3 65.4 80.52 65.63 0.82x
1 1024 64 yes 43.5 51.2 49.41 41.92 0.85x
1 1024 128 no 123.9 122.7 69.35 70.03 1.01x
1 1024 128 yes 81.6 85.6 52.64 50.18 0.95x
1 2048 64 no 201.2 240.1 85.38 71.54 0.84x
1 2048 64 yes 128.2 156.0 67.00 55.07 0.82x
1 2048 128 no 484.3 474.7 70.95 72.38 1.02x
1 2048 128 yes 269.7 281.1 63.71 61.11 0.96x
4 512 64 no 53.9 69.6 79.70 61.71 0.77x
4 512 128 no 160.9 206.6 53.39 41.58 0.78x
4 1024 64 no 259.4 272.6 66.22 63.01 0.95x
4 1024 128 no 556.5 504.5 61.74 68.11 1.10x
4 2048 64 no 818.2 1040.6 83.99 66.04 0.79x
4 2048 128 no 1784.1 1727.8 77.03 79.55 1.03x
1 4096 128 no 1629.5 1579.0 84.35 87.04 1.03x
1 4096 128 yes 923.3 928.0 74.43 74.05 0.99x
1 8192 128 no 6420.1 7276.4 85.63 75.55 0.88x
1 8192 128 yes 3668.0 3825.6 74.94 71.85 0.96x
4 4096 128 no 6848.6 6269.1 80.27 87.69 1.09x

Geometric mean speedup: 0.93x · Min: 0.77x · Max: 1.10x

What changed vs initial results:

  • Causal regression eliminated: B=1 S=512 D=64 causal went from 0.69x → 1.00x, B=1 S=512 D=128 causal from 0.94x → 0.93x (stable)
  • JIT artifacts removed: B=1 S=1024 D=64 non-causal went from 1.25x → 0.82x (the "win" was JIT noise); similarly B=4 S=2048 D=64 from 1.12x → 0.79x
  • Range tightened: Min improved from 0.62x → 0.77x; Max normalized from 1.25x → 1.10x

Updated key findings:

  • D=128 compute-heavy configs: TMA provides real 3-10% wins (B=4 S≥1024, B=1 S=4096) where warp specialization amortizes the dedicated DMA warp cost
  • D=64: TMA consistently loses 5-23% — the N=64 KV tile is too small to benefit from hardware-managed bulk loads vs thread-cooperative CpAsync
  • Causal configs: After the fix, causal is roughly neutral to slightly slower (0.93x-1.00x) — the runtime branch adds minor overhead vs CpAsync's compile-time two-loop structure
  • Overall: CpAsync remains the faster default for most configs on SM120. TMA's value is primarily as a reference implementation demonstrating cp.async.bulk + warp specialization patterns in the CuTe DSL

Test Results

Validated on NVIDIA GB10 (SM121a / DGX Spark) hardware at Second Nature Computing against PyTorch scaled_dot_product_attention:

BF16 CpAsync variant

dtype head_dim seqlen_q seqlen_k heads causal tile (m x n) result
FP16 64 128 128 4 no 64x64 PASS
BF16 128 256 256 8 no 64x64 PASS
FP16 128 512 512 16 yes 128x128 PASS
BF16 128 1024 1024 16 yes 128x64 PASS
BF16 64 512 1024 4 no 64x64 PASS

BF16 TMA variant (--use_tma)

dtype batch seqlen heads head_dim kv_stages causal result
BF16 1 128 1 64 1 no PASS
BF16 1 512 4 128 1 no PASS
BF16 1 512 4 64 1 yes PASS
BF16 1 1024 1 128 1 yes PASS
BF16 2 512 4 128 1 no PASS
BF16 2 256 8 64 2 no PASS
BF16 4 128 4 128 1 no PASS
BF16 3 512 4 64 1 no PASS
BF16 2 512 4 128 1 yes PASS
BF16 3 256 8 64 1 yes PASS
BF16 2 1024 1 128 1 yes PASS

FP8 kernel

variant batch seqlen_q seqlen_k heads head_dim causal result
POC 1 16 32 1 128 no PASS
POC 1 16 64 1 128 no PASS
POC 1 16 128 1 128 no PASS
POC 1 32 64 1 128 no PASS
POC 1 64 128 1 128 no PASS
POC 1 16 32 1 64 no PASS
POC 1 16 32 1 256 no PASS
POC 2 32 64 1 128 no PASS
POC 1 32 64 4 128 no PASS
POC 1 32 32 1 128 yes PASS
POC 1 64 64 1 128 yes PASS
Opt 1 64 64 1 128 no PASS
Opt 1 64 128 1 128 no PASS
Opt 1 128 256 1 128 no PASS
Opt 1 64 64 1 64 no PASS
Opt 2 128 128 1 128 no PASS
Opt 1 128 128 4 128 no PASS
Opt 1 64 64 1 128 yes PASS
Opt 1 128 128 1 128 yes PASS
Opt 1 256 256 1 128 yes PASS

Tolerance: atol=1e-02, rtol=1e-04

Motivation

There are currently few high-performance flash attention kernels available for SM120. The existing Blackwell FMHA (blackwell/fmha.py) targets SM100 and uses tcgen05 MMA + TMEM, which SM120 does not support. This implementation fills that gap by:

  1. Providing a working BF16 CpAsync baseline (same algorithmic approach as Ampere FA2)
  2. Adding TMA with warp specialization as a meaningful architectural improvement — hardware-managed bulk data movement frees compute warps from load duties
  3. Demonstrating FP8 flash attention using SM120's native f8f6f4 MMA instructions for up to 2× arithmetic throughput over BF16

Usage

# BF16 CpAsync variant (default)
python flash_attention_v2.py --batch_size 4 --seqlen_q 8192 --seqlen_k 8192 --num_head 16 --head_dim 128

# BF16 TMA variant with warp specialization
python flash_attention_v2.py --use_tma --batch_size 4 --seqlen_q 8192 --seqlen_k 8192 --num_head 16 --head_dim 128 --kv_stages 1

# FP8 Flash Attention
python fp8_flash_attention.py

# FP8 vs BF16 benchmark
python benchmark_fp8_vs_bf16.py --iters 50

Closes #2956


Contributed by Second Nature Computing — tested on DGX Spark hardware

Add a CuTe DSL Flash Attention v2 forward pass kernel targeting SM120
(RTX 5090, GB10 / DGX Spark). SM120 lacks tcgen05 MMA support, so
this implementation uses SM80-compatible tensor core instructions
(mma.sync.aligned.m16n8k16) with CpAsync for global-to-shared memory
transfers — the same proven approach as the Ampere FA2 example, tuned
for SM120's 101 KB shared memory capacity.

Features:
- FP16 and BF16 support
- Online softmax fusion (Flash Attention v2 algorithm)
- Causal masking support
- Configurable tile sizes (m_block_size, n_block_size)
- Register pipeline for smem-to-register overlap
- Predicated loads for boundary handling

Tested on NVIDIA GB10 (SM121a / DGX Spark) with multiple configs:
- head_dim=64/128, seqlen up to 2048, batch_size up to 4
- Both causal and non-causal modes
- Asymmetric Q/K sequence lengths
All verified against PyTorch scaled_dot_product_attention reference.

Closes NVIDIA#2956

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@johnnynunez

johnnynunez commented Feb 13, 2026

Summary

Adds a CuTe DSL Flash Attention v2 forward pass kernel for SM120 (RTX 5090, GB10 / DGX Spark), addressing the lack of high-performance FA kernels for this architecture.

  • SM120 lacks tcgen05 MMA (datacenter Blackwell only), so this uses SM80-compatible tensor core instructions (mma.sync.aligned.m16n8k16) with CpAsync — the same approach as the Ampere FA2 example, tuned for SM120's 101 KB shared memory
  • Supports FP16/BF16, causal/non-causal, configurable tile sizes, asymmetric Q/K lengths
  • Online softmax fusion following the Flash Attention v2 algorithm
  • Register pipeline for shared-to-register overlap

Test Results

Validated on NVIDIA GB10 (SM121a / DGX Spark) hardware at Second Nature Computing against PyTorch scaled_dot_product_attention:

dtype head_dim seqlen_q seqlen_k heads causal tile (m×n) result
FP16 64 128 128 4 no 64×64 PASS
BF16 128 256 256 8 no 64×64 PASS
FP16 128 512 512 16 yes 128×128 PASS
FP16 64 256 256 8 no 64×64 PASS
BF16 128 1024 1024 16 yes 128×64 PASS
FP16 128 2048 2048 8 yes 128×128 PASS
BF16 64 512 1024 4 no 64×64 PASS
Tolerance: atol=1e-02, rtol=1e-04

Motivation

There are currently few high-performance flash attention kernels available for SM120. The existing Blackwell FMHA (blackwell/fmha.py) targets SM100 and uses tcgen05 MMA + TMEM, which SM120 does not support. This implementation fills that gap by adapting the proven Ampere approach for SM120's instruction set.

Future Work

  • TMA optimization: SM120 supports TMA (cp.async.bulk) but 4D attention tensor layouts require careful descriptor construction. Upgrading from CpAsync to TMA could improve bandwidth utilization.
  • Warp specialization: Requires setmaxregister support in the DSL (not yet available in nvidia-cutlass-dsl 4.3.5).

Closes #2956

Contributed by Second Nature Computing — tested on DGX Spark hardware

🤖 Generated with Claude Code

Thanks.
Are you using the new MMA operations for Blackwell? I shared this:
https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py#L123

Are you using TMA?

In fact, DGX Spark should be Flash Attention 3 without warpgroups from Hopper

@blake-snc
Author

blake-snc commented Feb 14, 2026

@johnnynunez Thanks for taking a look! To answer your questions:

The link you shared points to MmaSM120BlockScaledOp, which targets FP4/FP8 block-scaled operations (mma.sync.aligned.block_scale). For FP16/BF16 attention, SM120 only has the SM80-compatible mma.sync.aligned.m16n8k16 via MmaF16BF16Op, which is what this kernel uses. There's no SM120-native FP16/BF16 MMA instruction beyond that, as far as I am aware.

No TMA yet. We're using CpAsync (cp.async) for global-to-shared data movement. SM120 does support TMA (cp.async.bulk), but constructing descriptors for 4D attention tensor layouts with the CuTe DSL requires some work. It's listed as future work and would be a meaningful bandwidth improvement.

Regarding FA3 without warpgroups: From what I know, FA3's core performance gains come from three techniques that are deeply tied to async WGMMA. Producer-consumer warp specialization needs WGMMA + TMA overlap. Pingpong scheduling needs warpgroups with barrier synchronization. And intra-warpgroup GEMM-softmax overlap relies on WGMMA executing asynchronously, which mma.sync does not. SM120 also doesn't have TMEM (requires tcgen05, confirmed in cute/arch/config.hpp). So adapting FA3 to SM120 would lose the features that make it faster than FA2 in the first place.

One FA3 improvement that does seem portable is FP8 block quantization, where SM120's block-scaled MMA would be a natural fit. That could be a good follow-up kernel using MmaSM120BlockScaledOp for FP8 attention specifically.

@drisspg
Contributor

drisspg commented Feb 14, 2026

@blake-snc
Author

blake-snc commented Feb 17, 2026

@drisspg Good question - the Dao-AILab SM80 path uses the same fundamental approach: mma.sync.aligned.m16n8k16 with CpAsync and online softmax. Their SM80 kernel is significantly more feature-complete (GQA packing, sliding window, block sparsity, variable-length sequences, score/mask hooks, persistent tile schedulers). It would likely run on SM120 out of the box since SM120 supports SM80 instructions.

This PR is more narrowly scoped - an SM120-tuned standalone example for the CUTLASS repo addressing #2956, with tile sizes configured for SM120's 101 KB shared memory. The main opportunity for differentiation would be adding TMA (cp.async.bulk), which SM120 supports but the Dao-AILab SM80 path doesn't use. That's listed as future work here.

@drisspg
Contributor

drisspg commented Feb 17, 2026

oh whoops sorry, you can ignore my comment, I thought this PR was in the FA repo, my bad

Add FlashAttentionForwardSm120Tma alongside the existing CpAsync
implementation, using TMA (cp.async.bulk) loads with a dedicated DMA
warp for compute/load overlap:

- 4D TMA descriptors (seq, dim, head, batch) for multi-batch support
- TMA-compatible Swizzle(B, 4, 3) pattern (M=4 required for TMA hardware)
- Warp specialization: 1 DMA warp + N MMA warps (default 4)
- PipelineTmaAsync with mbarrier-based producer/consumer synchronization
- Multi-stage KV double-buffering (configurable kv_stages)
- Separate K and V pipelines for independent scheduling
- SM80-compatible MMA (mma.sync.aligned.m16n8k16) unchanged

Validated on NVIDIA GB10 (SM121a / DGX Spark) against PyTorch SDPA:
  B=1..4, S=128..1024, H=1..8, D=64/128, causal/non-causal

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc
Author

Updated with a TMA variant (FlashAttentionForwardSm120Tma) alongside the existing CpAsync implementation:

  • Uses cp.async.bulk (TMA) with warp specialization: 1 DMA warp handles TMA loads while MMA warps compute
  • 4D TMA descriptors (seq, dim, head, batch) for native multi-batch support
  • PipelineTmaAsync with mbarrier synchronization for KV double-buffering
  • Separate K and V pipelines for independent scheduling

Key finding during development: TMA on SM120 requires Swizzle(B, 4, 3) patterns (M=4). The CpAsync version uses Swizzle(B, 3, 3) which works fine for cp.async but causes CUDA_ERROR_ILLEGAL_INSTRUCTION when used with TMA descriptors. This matches the swizzle patterns used by all TMA-based kernels in CUTLASS (warpgroup/helpers.py, tcgen05/helpers.py).
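For context, CuTe's Swizzle(B, M, S) XORs the B bits at bit position M+S of the byte offset into the B bits at position M. With M=4 the swizzle permutes whole 16-byte (128-bit) chunks, the granularity TMA's hardware swizzle modes operate on, while M=3 permutes 8-byte units, which matches no TMA mode. A small model (my reading of cute/swizzle.hpp, S ≥ 0 case):

```python
def swizzle(offset, b, m, s):
    """Apply a CuTe-style Swizzle<B, M, S> to a byte offset (S >= 0):
    XOR the B bits at position M+S into the B bits at position M."""
    msk = (1 << b) - 1
    return offset ^ (((offset >> (m + s)) & msk) << m)

# Swizzle(3, 4, 3): within a 1 KB tile, 16-byte chunk c of 128-byte row r
# moves to chunk c ^ r -- whole 16-byte chunks stay intact, as TMA requires.
for row in range(8):
    for chunk in range(8):
        off = row * 128 + chunk * 16
        assert swizzle(off, 3, 4, 3) == row * 128 + (chunk ^ row) * 16
```

The XOR-by-row pattern is what spreads same-column accesses across shared memory banks without breaking the 128-bit alignment TMA descriptors demand.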

Usage: python flash_attention_v2.py --use_tma --kv_stages 1

Verified on GB10 (SM121a) with B=1..4, S=128..1024, H=1..8, D=64/128, causal/non-causal.

@johnnynunez

super cool! I'm going to report internally

@IonThruster
Collaborator

Thanks for the PR - did you happen to collect any performance data vs the existing FAV2 non-TMA kernel?

@johnnynunez

johnnynunez commented Feb 20, 2026

@blake-snc I love your work... do you want to adapt it to https://github.com/Dao-AILab/flash-attention? If not, I will try

@blake-snc
Author

@johnnynunez Thanks, I really appreciate that! I'd love to take that on. I'll put up a PR to Dao-AILab/flash-attention with an SM120 path. They already have the CuTe DSL flash_fwd.py for SM80, so the structure is there to build on.

I'm also collecting TMA vs CpAsync benchmark numbers on GB10 for @IonThruster's question (I had not thought of that one) and will post them a bit later!

@blake-snc
Author

blake-snc commented Feb 20, 2026

@IonThruster Here are the benchmark results comparing the CpAsync (non-TMA) kernel vs. the TMA kernel from this PR, measured on DGX Spark (GB10 / SM121a).

Benchmark Configuration

  • Device: NVIDIA GB10 (DGX Spark), SM121a
  • Precision: BF16
  • MMA: mma.sync.aligned.m16n8k16 (SM80-compatible)
  • Heads: 16 (all configs)
  • Warmup: 10 iterations, Measured: 50 iterations
  • Metric: Wall-clock kernel time (us), lower is better. Speedup = CpAsync time / TMA time (>1 means TMA is faster).

Results

Batch SeqLen HeadDim Causal CpAsync (us) TMA (us) CpAsync TFLOPS TMA TFLOPS Speedup
1 512 64 no 18.0 20.5 59.62 52.43 0.88x
1 512 64 yes 16.5 17.3 32.54 31.07 0.95x
4 512 64 no 106.5 110.4 40.32 38.91 0.97x
4 512 64 yes 85.8 95.6 25.04 22.46 0.90x
1 512 128 no 89.4 92.3 24.02 23.28 0.97x
1 512 128 yes 32.8 36.7 32.78 29.26 0.89x
4 512 128 no 379.0 417.5 22.66 20.57 0.91x
4 512 128 yes 383.7 376.5 11.19 11.41 1.02x
1 1024 64 no 108.0 125.3 39.79 34.28 0.86x
1 1024 64 yes 88.9 94.0 24.16 22.85 0.95x
4 1024 64 no 601.9 938.4 28.54 18.31 0.64x
4 1024 64 yes 479.8 465.5 17.90 18.45 1.03x
1 1024 128 no 265.1 305.3 32.40 28.14 0.87x
1 1024 128 yes 185.8 203.5 23.11 21.11 0.91x
4 1024 128 no 1296.5 1274.6 26.50 26.96 1.02x
4 1024 128 yes 864.7 920.4 19.87 18.67 0.94x
1 2048 64 no 567.6 499.5 30.27 34.39 1.14x
1 2048 64 yes 195.6 303.9 43.92 28.27 0.64x
4 2048 64 no 2057.3 2450.4 33.40 28.04 0.84x
4 2048 64 yes 1073.5 1252.4 32.01 27.44 0.86x
1 2048 128 no 1126.8 1147.8 30.49 29.93 0.98x
1 2048 128 yes 665.9 635.9 25.80 27.02 1.05x
4 2048 128 no 4435.0 4241.3 30.99 32.40 1.05x
4 2048 128 yes 2212.8 2254.9 31.06 30.48 0.98x
1 4096 64 no 1954.6 2043.5 35.16 33.63 0.96x
1 4096 64 yes 1069.1 1343.5 32.14 25.57 0.80x
4 4096 64 no 8340.2 9976.9 32.96 27.55 0.84x
4 4096 64 yes 4259.2 4273.4 32.27 32.16 1.00x
1 4096 128 no 3666.2 3677.1 37.49 37.38 1.00x
1 4096 128 yes 2052.5 2063.7 33.48 33.30 0.99x
4 4096 128 no 14650.5 18450.9 37.52 29.80 0.79x

Note: Configs with SeqLen=8192 and (B=4, S=4096, D=128, causal=yes) failed with OOM on the GB10's unified memory.

Summary

  • CpAsync generally matches or beats TMA across the tested configurations. Out of 31 successful configs, TMA shows marginal wins (1.02x-1.14x) on 6 of them, while CpAsync wins by larger margins on many others.
  • TMA's advantage is limited on SM120. Our working hypothesis is that the overhead from warp specialization and TMA descriptor setup doesn't pay off here because the SM80-compatible mma.sync.aligned.m16n8k16 instructions are narrow enough that a monolithic kernel (all threads doing both loads and compute) keeps threads fully utilized. There isn't enough idle time during compute to justify dedicating warps to asynchronous loads.
  • TMA does show promise at larger problem sizes (e.g., B=1 S=2048 D=64 non-causal at 1.14x, and several ~1.02-1.05x wins at S>=2048 D=128), which suggests TMA could pay off once the working set is large enough to stress the memory subsystem.

Next Steps

We're investigating further optimizations to the TMA path:

  • K/V staging strategies — double-buffering K/V tiles through shared memory to better overlap TMA loads with MMA compute
  • Warp allocation ratios — experimenting with different producer/consumer warp ratios to find the optimal split on SM120
  • Profiling with Nsight Compute to identify the actual bottleneck (descriptor setup latency vs. memory bandwidth vs. warp scheduling)

Happy to share Nsight Compute profiles or run additional configs if useful.


Contributed by Second Nature Computing

SM120's FP8 MMA uses `mma.sync.aligned.kind::f8f6f4.m16n8k32`
(SM120_16x8x32_TN in mma_sm120.hpp), which differs from SM89's FP8
instruction and is not yet exposed in the CuTe Python DSL. Added a
NOTE documenting this for future FP8 FA enablement. Also fixed
run/run_tma to capture and print avg execution time.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc
Author

FP8 Flash Attention Update: SM120 FP8 MMA Validated via Inline PTX

We investigated adding FP8 flash attention using SM120's mma.sync.aligned.kind::f8f6f4.m16n8k32 instruction (2x theoretical throughput over BF16's m16n8k16).
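The 2x figure follows directly from the instruction shapes; as a plain-Python sanity check (not DSL code):

```python
def mma_flops(m, n, k):
    """FLOPs performed by one MMA instruction of shape m x n x k:
    one multiply and one add per multiply-accumulate."""
    return 2 * m * n * k

# FP8 m16n8k32 does twice the work per issue of BF16 m16n8k16.
assert mma_flops(16, 8, 32) == 2 * mma_flops(16, 8, 16)
```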

Key findings:

  1. MmaAtomSM80Type segfaults on SM120 with FP8 types — the MLIR backend doesn't have a lowering path for SM120's kind::f8f6f4 PTX variant. Filed as CuTe DSL: FP8 MMA segfaults on SM120 — MmaAtomSM80Type missing kind::f8f6f4 lowering #3044.

  2. Workaround: inline PTX assembly works perfectly. Using llvm.inline_asm() with @dsl_user_op, we can emit the kind::f8f6f4 instruction directly, bypassing the broken MLIR MMA lowering. Validated on GB10 (SM121a) — compilation succeeds and produces correct results (32.0 for K=32 all-ones FP8 inputs).

  3. FP8 FA kernel in progress — we're building an FP8 flash attention kernel using this inline PTX approach. Additional challenges solved:

    • FP8 SMEM alignment: CopyUniversalOp with 64-bit copies (CpAsync requires 128-bit alignment which FP8 can't provide with standard swizzles)
    • DLPack FP8: uint8 storage with element_type override
    • S2R loads: CopyUniversalOp with 32-bit copies (LdMatrix is 16-bit only, MMA partitioning gives only 4-element alignment)

Will update with the full FP8 FA kernel and benchmarks once complete.


Contributed by Second Nature Computing

@blake-snc
Author

FP8 Flash Attention Progress Update

Following up on the earlier FP8 investigation — the proof-of-concept FP8 flash attention kernel is now producing correct output across all test configurations.

What works:

  • Full FP8 attention pipeline validated: FP8 Q/K/V loads → kind::f8f6f4.m16n8k32 MMA → online softmax → FP32→FP8 conversion → PV GEMM → FP32 output
  • cvt.rn.satfinite.e4m3x2.f32 PTX instruction validated for FP32→FP8 conversion (the pair variant; single-element e4m3.f32 does not exist in PTX ISA)
  • Causal masking with per-element mask
  • All 11 test configs pass (varying batch, seq_len, heads, head_dim, causal/non-causal)
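To make the conversion semantics concrete, here is a pure-Python model of round-to-nearest-even E4M3 quantization with satfinite clamping — a numerical sketch of what cvt.rn.satfinite.e4m3x2.f32 does per element, ignoring the two-element packing and the NaN bit encoding (function name hypothetical):

```python
import math

def quantize_e4m3_satfinite(v):
    """Round v to the nearest OCP FP8 E4M3 value (3 mantissa bits,
    exponent bias 7, subnormals at 2^-6, max finite 448.0), saturating
    to the max finite value instead of overflowing -- 'satfinite'."""
    if v == 0.0 or math.isnan(v):
        return v
    sign = math.copysign(1.0, v)
    a = abs(v)
    if a >= 448.0:
        return sign * 448.0          # satfinite: clamp to max finite
    # exponent of the containing binade, clamped to the subnormal range
    e = max(math.floor(math.log2(a)), -6)
    ulp = 2.0 ** (e - 3)             # spacing of E4M3 values in this binade
    q = round(a / ulp) * ulp         # Python round() is ties-to-even
    return sign * min(q, 448.0)
```

For example, 0.1 quantizes to 0.1015625 and anything at or above 448 clamps to 448.0 rather than producing NaN.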

Current kernel architecture (single-warp POC):

The current kernel uses a minimal single-warp (32 threads) design with M=16, N=32 tiles and GMEM O accumulation. This validates correctness of the FP8 MMA register layout and the full softmax→conversion→GEMM pipeline, but is not performance-optimized — O accumulation through global memory is the dominant bottleneck.

Next step: register-tiled multi-warp kernel

Now working on a performance-optimized FP8 kernel matching the BF16 kernel's architecture (4 warps, M=128/N=64 tiles, register O accumulation). The kind::f8f6f4.m16n8k32 instruction provides 2x theoretical MMA throughput over BF16's m16n8k16, and FP8 inputs halve memory bandwidth requirements — so the optimized kernel should show meaningful gains. Will share benchmarks once the register-tiled version is ready.


Contributed by Second Nature Computing

@blake-snc
Author

FP8 Flash Attention — Performance Update

Added fp8_flash_attention.py with an FP8 Flash Attention kernel for SM120 using mma.sync.aligned.kind::f8f6f4.m16n8k32. This required inline PTX since MmaAtomSM80Type segfaults with FP8 types on SM120 (#3044).

Optimizations applied

  1. Register O accumulation — 4 warps (M=64, N=32), O stays in registers across N-tiles
  2. CpAsync pipeline — V load overlaps with QK GEMM, next K load overlaps with transpose+PV GEMM
  3. SMEM bank-conflict-free padding — +16 bytes/row for FP8 buffers, +4 F32/row for P buffer. Eliminates 8-way bank conflicts from row-major stride matching bank count
  4. Vectorized V transpose — 4×4 byte register transpose via prmt.b32 (8 instructions per 16-byte block, no SMEM scratch)
  5. FP32→FP8 conversion — cvt.rn.satfinite.e4m3x2.f32 pair variant with mov.b32 packing

Benchmark on DGX Spark (NVIDIA GB10, SM121a)

FP8 kernel: FP8FlashAttentionSm120Opt — CpAsync, bank-conflict-free SMEM, 4 warps (M=64, N=32)
BF16 kernel: FlashAttentionForwardSm120 — CpAsync, tiled MMA, 4 warps (M=128, N=64)

| Batch | SeqLen | HeadDim | Causal | FP8 (μs) | BF16 (μs) | FP8 TFLOPS | BF16 TFLOPS | Ratio |
|---|---|---|---|---|---|---|---|---|
| 1 | 512 | 64 | no | 25.3 | 20.5 | 42.41 | 52.44 | 0.81x |
| 1 | 512 | 64 | yes | 24.5 | 20.5 | 21.90 | 26.22 | 0.84x |
| 1 | 1024 | 64 | no | 181.7 | 108.2 | 23.63 | 39.71 | 0.60x |
| 1 | 1024 | 64 | yes | 128.6 | 103.0 | 16.71 | 20.84 | 0.80x |
| 1 | 2048 | 64 | no | 722.1 | 520.5 | 23.79 | 33.01 | 0.72x |
| 1 | 2048 | 64 | yes | 478.7 | 335.3 | 17.94 | 25.62 | 0.70x |
| 1 | 512 | 128 | no | 118.4 | 109.5 | 18.14 | 19.62 | 0.92x |
| 1 | 512 | 128 | yes | 107.3 | 91.2 | 10.00 | 11.77 | 0.85x |
| 1 | 1024 | 128 | no | 354.4 | 297.6 | 24.24 | 28.87 | 0.84x |
| 1 | 1024 | 128 | yes | 290.9 | 187.4 | 14.76 | 22.92 | 0.64x |
| 1 | 2048 | 128 | no | 1751.6 | 1146.6 | 19.62 | 29.97 | 0.65x |
| 1 | 2048 | 128 | yes | 827.7 | 679.9 | 20.76 | 25.27 | 0.82x |
| 4 | 512 | 128 | no | 479.0 | 238.2 | 17.93 | 36.06 | 0.50x |
| 4 | 1024 | 128 | no | 974.4 | 1345.2 | 35.26 | 25.54 | 1.38x |

The FP8 kernel reaches 0.60–1.38x of the BF16 kernel's performance, peaking at 42 TFLOPS for D=64 and 35 TFLOPS for D=128. The B=4 S=1024 D=128 case beats BF16 by 38%.

Remaining performance gaps

The main gap vs BF16 comes from:

  • Software V transpose — SM120 lacks hardware 8-bit transpose (ldmatrix.b8 requires tcgen05, absent on SM120). Using prmt.b32 register transpose as workaround.
  • P SMEM round-trip — MMA CLayout → A-operand layout mismatch requires writing P to SMEM (as F32) and re-reading. BF16 keeps P in registers via cute.make_fragment_like.
  • Manual MMA — inline PTX MMA instead of CuTe's cute.gemm with automatic register management. Blocked by #3044 (MmaAtomSM80Type missing kind::f8f6f4 lowering on SM120).
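
For the software V transpose, the 4×4 byte tile held in four 32-bit registers can be transposed with 8 prmt ops, matching the "8 instructions per 16-byte block" count above. A Python model of prmt.b32's default mode and one valid selector schedule (the specific selectors here are this sketch's choice, not necessarily the kernel's):

```python
def prmt(a: int, b: int, sel: int) -> int:
    """Model of PTX prmt.b32 (default mode): each selector nibble picks one byte
    from the 8-byte pool {b, a}; nibble values 0-3 index a, 4-7 index b."""
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)] + \
           [(b >> (8 * i)) & 0xFF for i in range(4)]
    out = 0
    for i in range(4):
        out |= pool[(sel >> (4 * i)) & 0x7] << (8 * i)
    return out

def transpose_4x4_bytes(r0, r1, r2, r3):
    """Transpose a 4x4 byte tile held in four 32-bit 'registers', 8 prmt ops."""
    t0 = prmt(r0, r1, 0x5140)   # interleave bytes 0/1 of rows 0 and 1
    t1 = prmt(r0, r1, 0x7362)   # interleave bytes 2/3 of rows 0 and 1
    t2 = prmt(r2, r3, 0x5140)
    t3 = prmt(r2, r3, 0x7362)
    c0 = prmt(t0, t2, 0x5410)   # gather column 0
    c1 = prmt(t0, t2, 0x7632)   # column 1
    c2 = prmt(t1, t3, 0x5410)   # column 2
    c3 = prmt(t1, t3, 0x7632)   # column 3
    return c0, c1, c2, c3

# Rows hold bytes 0..15 (little-endian); columns come out as 0,4,8,12 etc.
rows = (0x03020100, 0x07060504, 0x0B0A0908, 0x0F0E0D0C)
assert transpose_4x4_bytes(*rows) == (0x0C080400, 0x0D090501, 0x0E0A0602, 0x0F0B0703)
```

No SMEM scratch is involved: every intermediate stays in registers, which is what makes this cheaper than a staged SMEM transpose despite the instruction count.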

Files added

| File | Description |
|---|---|
| fp8_flash_attention.py | FP8 FA with POC + optimized kernels, 20 correctness tests |
| benchmark_fp8_vs_bf16.py | FP8 vs BF16 sweep benchmark |
| fp8_gemm.py | Standalone FP8 GEMM validation |

Second Nature Computing

FP8 Flash Attention using mma.sync.aligned.kind::f8f6f4.m16n8k32
with CpAsync pipelining and bank-conflict-free SMEM layout.

- POC kernel: 1 warp, basic correctness validation
- Optimized kernel: 4 warps, register O accumulation,
  vectorized 4x4 byte transpose via prmt.b32,
  CpAsync double-buffered K/V pipeline,
  +16 byte SMEM row padding for bank conflict elimination
- FP8 GEMM helper with inline PTX MMA (workaround for NVIDIA#3044)
- Benchmark script comparing FP8 vs BF16 kernels

Performance on DGX Spark (SM121a):
- FP8 peaks at 42.4 TFLOPS (D=64) and 35.3 TFLOPS (D=128)
- FP8 outperforms BF16 by up to 1.38x at larger batch sizes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@johnnynunez

johnnynunez commented Feb 21, 2026

@blake-snc
Author

@johnnynunez Thanks for sharing the TMA results! We've seen gau-nernst's work — impressive numbers (94.4% peak on RTX 5090 with Ampere-era instructions).

Our TMA variant (FlashAttentionForwardSm120Tma) already uses cp.async.bulk with warp specialization, and we just pushed an FP8 kernel (FP8FlashAttentionSm120Opt) that hits 42.4 TFLOPS using SM120's native f8f6f4.m16n8k32 MMA — see the benchmark table in the updated PR description.

Also — we already have the Dao-AILab/flash-attention adaptation up at #2268 (just rebased onto latest main). Happy to coordinate if you want to help expand it (e.g. varlen, backward pass).

…nal tiles

The TMA consumer loop was passing in_mask_steps=True for every KV tile
when is_causal=True, applying expensive per-element causal masking
(identity tensor creation + column comparisons) to all tiles including
those fully below the diagonal. This caused up to 40% regression vs
CpAsync on causal workloads.

Fix: add a runtime check to only apply masking for the mask_steps tiles
near the causal diagonal (n_block >= n_block_max - ceil_div(m_block_size,
n_block_size)), matching the CpAsync variant's two-loop approach. Tiles
fully below the diagonal use in_mask_steps=False and skip the masking.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@blake-snc
Author

TMA Causal Masking Fix + Updated Benchmarks

Bug fix (commit 106e24b)

The TMA variant had a causal masking bug: it applied expensive per-element causal masking to all KV tiles when is_causal=True, instead of only the ceil_div(m_block_size, n_block_size) tiles near the diagonal. The CpAsync variant correctly used a two-loop structure (masked tiles near diagonal + fast unmasked tiles below).

Root cause: The consumer loop passed in_mask_steps=cutlass.const_expr(self._is_causal) — this was True for every tile in causal mode. The fix adds a runtime branch within the single loop:

if cutlass.const_expr(self._is_causal):
    causal_mask_boundary = n_block_max - cute.ceil_div(m_block_size, n_block_size)
    if n_block >= causal_mask_boundary:
        # Near diagonal — apply per-element causal mask
        self._softmax_rescale_O(..., in_mask_steps=True)
    else:
        # Below diagonal — skip masking (all elements valid)
        self._softmax_rescale_O(..., in_mask_steps=False)

Why a single loop with runtime branch instead of CpAsync's two-loop approach: CuTe DSL's MLIR IR has SSA dominance constraints — PipelineTmaAsync state objects produce new SSA values on .advance() that can't flow across separate loop boundaries. A two-loop approach produces operand #0 does not dominate this use errors at IR verification time.
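
A small standalone demo of the boundary formula (tile counts here are hypothetical; math.ceil stands in for cute.ceil_div): only the KV tiles whose columns can cross the diagonal of an m_block_size-row query tile take the masked path.

```python
import math

def masked_tiles(n_block_max, m_block_size, n_block_size):
    """Indices of KV tiles that need per-element causal masking, i.e. those
    with n_block >= n_block_max - ceil_div(m_block_size, n_block_size)."""
    boundary = n_block_max - math.ceil(m_block_size / n_block_size)
    return [n for n in range(n_block_max) if n >= boundary]

# With the default tiles (m=128, n=64) and 8 KV tiles in range, only the
# last ceil(128/64) = 2 tiles pay for per-element masking.
assert masked_tiles(8, 128, 64) == [6, 7]
```

Everything below the boundary runs the unmasked fast path, which is where the up-to-40% causal regression was being lost before the fix.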

Impact on causal configs

| Config | Before fix | After fix |
|---|---|---|
| B=1 S=512 D=64 causal | 0.69x | 1.00x |
| B=1 S=512 D=128 causal | 0.94x | 0.93x |
| B=1 S=1024 D=64 causal | 0.88x | 0.85x |
| B=1 S=1024 D=128 causal | 0.99x | 0.95x |

The biggest improvement is at short sequences where the masked-tile fraction is highest. Updated full benchmark results are in the PR description.

Two correctness fixes to FlashAttentionForwardSm120Tma:

1. Non-causal OOB masking: the last K tile is now masked when seqlen_k
   is not divisible by n_block_size. TMA zero-fills OOB positions during
   load, but softmax must treat them as -inf (not 0) to avoid corrupting
   the normalization. The consumer loop now passes in_mask_steps=True for
   n_block == n_block_max - 1 in the non-causal path, and
   _softmax_rescale_O handles both causal and non-causal masking when
   in_mask_steps=True.

2. SMEM capacity check: the previous estimate used 3*1024 bytes of
   alignment overhead, which over-counted by ~2 KB. The mbar region
   (< 200 B) rounds up to 1024 B before sQ; sQ and sKV are typically
   multiples of 1024 B for standard tile configs, so they need no further
   padding. Updated can_implement() to use the actual layout arithmetic,
   allowing the default config (m=128, n=64, d=128, kv_stages=2, bf16)
   which uses 97.0 KB of SM120's 99.0 KB SMEM budget.

Validated on SM121a with 8 configs:
  - TMA: default non-causal, default causal, seqlen_k non-divisible
    (non-causal and causal), head_dim=64 — all PASS
  - CpAsync: same 4 configs — all PASS (no regressions)

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc
Author

Update: Two additional correctness fixes (commit 6e193de)

1. Non-causal OOB masking

The TMA variant was missing out-of-bounds masking for the last K tile when seqlen_k is not divisible by n_block_size. TMA zero-fills OOB positions during load, but softmax must treat those positions as -inf (not 0) to avoid contributing to the normalization.

Fix: the consumer loop now passes in_mask_steps=True for n_block == n_block_max - 1 in the non-causal path, and _softmax_rescale_O applies OOB masking in that path.

Previously only visible with seqlen_k % n_block_size != 0 — not triggered by the default test (seqlen_k=8192 is divisible by 64). All 8 configs now verified correct on SM121a.
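
A toy demo of why zero-filled padding corrupts the result (scores here are hypothetical): padding with 0 leaks probability mass into the OOB positions, while padding with -inf reproduces the softmax over the valid scores exactly.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    t = sum(w)
    return [x / t for x in w]

valid = [2.0, 1.0]                            # hypothetical QK^T row
zero_pad = softmax(valid + [0.0, 0.0])        # TMA zero-fill left as-is: wrong
inf_pad = softmax(valid + [-math.inf] * 2)    # OOB masked to -inf: correct

# -inf padding contributes exp(-inf) = 0 to the normalization, so the valid
# entries match the unpadded softmax; zero padding shrinks them.
assert inf_pad[:2] == softmax(valid)
assert zero_pad[0] < softmax(valid)[0]
```

The same reasoning applies per row inside _softmax_rescale_O: the zero-filled columns must be forced to -inf before the running max/sum update, not merely excluded from the output.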

2. SMEM capacity check in can_implement

The previous estimate used 3 * 1024 bytes for alignment overhead, which over-counted by ~2 KB. The actual layout is:

[mbar arrays (80 B)] [pad to 1024] [sQ] [sK stages] [sV stages]

The mbar region (< 200 B) rounds up to 1024 B before sQ. sQ and sKV are multiples of 1024 B for typical tile sizes, so they need no additional alignment padding. The corrected formula uses the actual layout arithmetic, allowing the default config (m=128, n=64, d=128, kv_stages=2, BF16) which needs 97.0 KB of SM120's 99.0 KB SMEM.
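
The corrected arithmetic can be reproduced with a short sketch (byte sizes derived from the stated tile shapes; the helper name and 1 KiB mbar rounding are this sketch's assumptions):

```python
def smem_bytes(m, n, d, kv_stages, elem_bytes, mbar_align=1024):
    """SMEM footprint per CTA for the layout
    [mbar arrays][pad to mbar_align][sQ][sK stages][sV stages]."""
    mbar = mbar_align                        # < 200 B of mbarriers, rounded to 1 KiB
    sQ = m * d * elem_bytes                  # 128 * 128 * 2 B = 32 KiB (default)
    sK = kv_stages * n * d * elem_bytes      # 2 * 64 * 128 * 2 B = 32 KiB
    sV = kv_stages * n * d * elem_bytes      # same shape as sK
    return mbar + sQ + sK + sV

total = smem_bytes(m=128, n=64, d=128, kv_stages=2, elem_bytes=2)  # BF16
assert total == 99328 and total / 1024 == 97.0  # fits in the 99.0 KiB budget
```

Since sQ, sK, and sV are each already multiples of 1024 B for these shapes, no inter-buffer padding appears in the sum, which is exactly the ~2 KB the old 3 * 1024 estimate over-counted.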

Add BSD-3-Clause license header to benchmark_fp8_vs_bf16.py (was missing
entirely). Normalize license header format in fp8_flash_attention.py to
match the canonical CUTLASS style used across all other example files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@johnnynunez

cc @Junkai-Wu


Successfully merging this pull request may close these issues.

[FEA] A flash-attention cuteDSL kernel for sm120
