[CuTeDSL] Flash Attention v2 for SM120 (Blackwell GeForce) #3030

blake-snc wants to merge 7 commits into NVIDIA:main
Conversation
Add a CuTe DSL Flash Attention v2 forward pass kernel targeting SM120 (RTX 5090, GB10 / DGX Spark). SM120 lacks tcgen05 MMA support, so this implementation uses SM80-compatible tensor core instructions (mma.sync.aligned.m16n8k16) with CpAsync for global-to-shared memory transfers — the same proven approach as the Ampere FA2 example, tuned for SM120's 101 KB shared memory capacity.

Features:
- FP16 and BF16 support
- Online softmax fusion (Flash Attention v2 algorithm)
- Causal masking support
- Configurable tile sizes (m_block_size, n_block_size)
- Register pipeline for smem-to-register overlap
- Predicated loads for boundary handling

Tested on NVIDIA GB10 (SM121a / DGX Spark) with multiple configs:
- head_dim=64/128, seqlen up to 2048, batch_size up to 4
- Both causal and non-causal modes
- Asymmetric Q/K sequence lengths

All verified against PyTorch scaled_dot_product_attention reference.

Closes NVIDIA#2956

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
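The online-softmax fusion named above can be sketched as a NumPy reference model (illustrative only, not the CuTe DSL kernel; the tile size and function names are mine):

```python
import math
import numpy as np

def flash_attention_fwd(q, k, v, n_block_size=64, causal=False):
    """Reference model of the FA2 forward pass: K/V are processed in tiles
    while a running row max (m) and normalizer (l) keep softmax stable."""
    S_q, D = q.shape
    scale = 1.0 / math.sqrt(D)
    o = np.zeros((S_q, D))
    m = np.full(S_q, -np.inf)   # running row maxima
    l = np.zeros(S_q)           # running softmax denominators
    for j0 in range(0, k.shape[0], n_block_size):
        kj, vj = k[j0:j0 + n_block_size], v[j0:j0 + n_block_size]
        s = (q @ kj.T) * scale
        if causal:  # masked scores must be -inf, not 0
            rows = np.arange(S_q)[:, None]
            cols = np.arange(j0, j0 + kj.shape[0])[None, :]
            s = np.where(cols <= rows, s, -np.inf)
        m_new = np.maximum(m, s.max(axis=1))
        p = np.exp(s - m_new[:, None])
        corr = np.exp(m - m_new)  # rescale previously accumulated O and l
        l = corr * l + p.sum(axis=1)
        o = corr[:, None] * o + p @ vj
        m = m_new
    return o / l[:, None]
```

This matches a naive softmax(QKᵀ/√D)V while never materializing the full score matrix, which is the property the fused kernel exploits.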
Thanks. Are you using TMA? In fact, DGX Spark should support Flash Attention 3 without warpgroups from Hopper.
@johnnynunez Thanks for taking a look! To answer your questions: no TMA yet. We're using CpAsync (`cp.async`) for global-to-shared transfers, the same approach as the Ampere FA2 example.

Regarding FA3 without warpgroups: from what I know, FA3's core performance gains come from three techniques that are deeply tied to async WGMMA. Producer-consumer warp specialization needs WGMMA + TMA overlap. Pingpong scheduling needs warpgroups with barrier synchronization. And intra-warpgroup GEMM-softmax overlap relies on WGMMA executing asynchronously, which SM120's synchronous `mma.sync` cannot do.

One FA3 improvement that does seem portable is FP8 block quantization, where SM120's block-scaled MMA would be a natural fit. That could be a good follow-up kernel.
How does this compare to the existing sm80 kernel here: https://github.com/Dao-AILab/flash-attention/blob/c4d8b0630eb81cf88206e0cc9e9bff4e7806d88f/flash_attn/cute/flash_fwd.py#L52
In fact, if you execute https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py, it runs on DGX Spark. The thing is, the results shared at https://github.com/gau-nernst/learn-cuda/tree/main/02c_matmul_sm120 are related to TMA.
@drisspg Good question - the Dao-AILab SM80 path uses the same fundamental approach: SM80-compatible `mma.sync` tensor core instructions with `cp.async` global-to-shared copies. This PR is more narrowly scoped - an SM120-tuned standalone example for the CUTLASS repo addressing #2956, with tile sizes configured for SM120's 101 KB shared memory. The main opportunity for differentiation would be adding TMA (`cp.async.bulk`) loads.
Oh whoops, sorry, you can ignore my comment. I thought this PR was in the FA repo, my bad.
Add FlashAttentionForwardSm120Tma alongside the existing CpAsync implementation, using TMA (cp.async.bulk) loads with a dedicated DMA warp for compute/load overlap:

- 4D TMA descriptors (seq, dim, head, batch) for multi-batch support
- TMA-compatible Swizzle(B, 4, 3) pattern (M=4 required for TMA hardware)
- Warp specialization: 1 DMA warp + N MMA warps (default 4)
- PipelineTmaAsync with mbarrier-based producer/consumer synchronization
- Multi-stage KV double-buffering (configurable kv_stages)
- Separate K and V pipelines for independent scheduling
- SM80-compatible MMA (mma.sync.aligned.m16n8k16) unchanged

Validated on NVIDIA GB10 (SM121a / DGX Spark) against PyTorch SDPA: B=1..4, S=128..1024, H=1..8, D=64/128, causal/non-causal

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Updated with a TMA variant (`FlashAttentionForwardSm120Tma`).

Key finding during development: TMA on SM120 requires the `Swizzle(B, 4, 3)` pattern (M=4); the CpAsync path's `Swizzle(B, 3, 3)` is not valid for TMA hardware.

Usage:

Verified on GB10 (SM121a) with B=1..4, S=128..1024, H=1..8, D=64/128, causal/non-causal.
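For readers unfamiliar with the swizzle notation: a small model of the byte permutation can make the M=4 requirement concrete. The bit semantics are assumed from CuTe's `Swizzle<B, M, S>` convention (XOR bits `[M+S, M+S+B)` of the byte offset into bits `[M, M+B)`); the helper below is illustrative, not the DSL API:

```python
def swizzle(offset, b=3, m=4, s=3):
    """Toy model of CuTe Swizzle<B, M, S> on a byte offset: XOR bits
    [m+s, m+s+b) into bits [m, m+b). Defaults model Swizzle(3, 4, 3),
    the TMA-compatible pattern with a 16-byte (M=4) base."""
    mask = (1 << b) - 1
    return offset ^ (((offset >> (m + s)) & mask) << m)

# With 128-byte rows, the first 16-byte chunk of each of 8 consecutive
# rows maps to a different chunk slot, spreading accesses across banks.
chunk_slots = {(swizzle(r * 128) >> 4) & 7 for r in range(8)}
```

Because the XORed source bits are themselves untouched, applying the swizzle twice returns the original offset, so it is a bijection on any power-of-two extent.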
super cool! I'm going to report internally |
Thanks for the PR - did you happen to collect any performance data vs the existing FAV2 non-TMA kernel?
@blake-snc I love your work... do you want to adapt it to https://github.com/Dao-AILab/flash-attention? If not, I will try.
@johnnynunez Thanks, I really appreciate that! I'd love to take that on. I'll put up a PR to Dao-AILab/flash-attention with an SM120 path; they already have the CuTe DSL flash_fwd.py for SM80, so the structure is there to build on. I'm also collecting TMA vs CpAsync benchmark numbers on GB10 for @IonThruster's question (I hadn't thought of that one) and will post those a bit later!
@IonThruster Here are the benchmark results comparing the CpAsync (non-TMA) kernel vs. the TMA kernel from this PR, measured on DGX Spark (GB10 / SM121a).

### Benchmark Configuration
### Results
Note: Configs with SeqLen=8192 and (B=4, S=4096, D=128, causal=yes) failed with OOM on the GB10's unified memory.

### Summary
### Next Steps

We're investigating further optimizations to the TMA path:
Happy to share Nsight Compute profiles or run additional configs if useful.

Contributed by Second Nature Computing
SM120's FP8 MMA uses `mma.sync.aligned.kind::f8f6f4.m16n8k32` (SM120_16x8x32_TN in mma_sm120.hpp), which differs from SM89's FP8 instruction and is not yet exposed in the CuTe Python DSL. Added a NOTE documenting this for future FP8 FA enablement. Also fixed run/run_tma to capture and print avg execution time. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
### FP8 Flash Attention Update: SM120 FP8 MMA Validated via Inline PTX

We investigated adding FP8 flash attention using SM120's `mma.sync.aligned.kind::f8f6f4.m16n8k32` instruction.

Key findings:
Will update with the full FP8 FA kernel and benchmarks once complete.

Contributed by Second Nature Computing
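For context on what FP8 inputs look like numerically, here is a simplified model of e4m3 quantization with a per-tensor scale (illustrative only: it rounds to 3 mantissa bits and clamps to e4m3's ±448 range, ignoring subnormals and NaN encodings; the helper names are mine):

```python
import math

E4M3_MAX = 448.0  # largest finite e4m3 magnitude

def e4m3_round(x):
    """Round x to the nearest value with 3 mantissa bits, clamped to
    +-448 (simplified: no subnormal or NaN handling)."""
    if x == 0.0:
        return 0.0
    a = abs(x)
    e = math.floor(math.log2(a))
    step = 2.0 ** (e - 3)          # spacing of e4m3 values near |x|
    q = round(a / step) * step
    return math.copysign(min(q, E4M3_MAX), x)

def quantize(values):
    """Per-tensor scaling: map the max magnitude onto e4m3's range,
    then round each scaled value to the e4m3 grid."""
    amax = max(abs(v) for v in values) or 1.0
    scale = amax / E4M3_MAX
    return [e4m3_round(v / scale) for v in values], scale
```

Dequantizing (multiplying each quantized value by `scale`) recovers the inputs to within e4m3's roughly 1/16 relative precision, which is why FP8 attention needs the looser tolerances used elsewhere in this thread.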
### FP8 Flash Attention Progress Update

Following up on the earlier FP8 investigation — the proof-of-concept FP8 flash attention kernel is now producing correct output across all test configurations.

What works:
### Current kernel architecture (single-warp POC)

The current kernel uses a minimal single-warp (32 threads) design with M=16, N=32 tiles and GMEM O accumulation. This validates correctness of the FP8 MMA register layout and the full softmax→conversion→GEMM pipeline, but is not performance-optimized — O accumulation through global memory is the dominant bottleneck.

### Next step: register-tiled multi-warp kernel

Now working on a performance-optimized FP8 kernel matching the BF16 kernel's architecture (4 warps, M=128/N=64 tiles, register O accumulation).

Contributed by Second Nature Computing
### FP8 Flash Attention — Performance Update

Added an optimized FP8 kernel, `FP8FlashAttentionSm120Opt`.

Optimizations applied:
### Benchmark on DGX Spark (NVIDIA GB10, SM121a)

FP8 kernel:
The FP8 kernel reaches 0.60–1.38x of the BF16 kernel's performance, peaking at 42 TFLOPS for D=64 and 35 TFLOPS for D=128. The B=4 S=1024 D=128 case beats BF16 by 38%.

### Remaining performance gaps

The main gap vs BF16 comes from:
### Files added
FP8 Flash Attention using mma.sync.aligned.kind::f8f6f4.m16n8k32 with CpAsync pipelining and bank-conflict-free SMEM layout.

- POC kernel: 1 warp, basic correctness validation
- Optimized kernel: 4 warps, register O accumulation, vectorized 4x4 byte transpose via prmt.b32, CpAsync double-buffered K/V pipeline, +16 byte SMEM row padding for bank conflict elimination
- FP8 GEMM helper with inline PTX MMA (workaround for NVIDIA#3044)
- Benchmark script comparing FP8 vs BF16 kernels

Performance on DGX Spark (SM121a):
- FP8 peaks at 42.4 TFLOPS (D=64) and 35.3 TFLOPS (D=128)
- FP8 outperforms BF16 by up to 1.38x at larger batch sizes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
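The 4x4 byte transpose mentioned in the commit can be modeled in plain Python to check the byte movement (this models only the data movement; the kernel itself builds each output word from `prmt.b32` byte selects, since SM120 has no `ldmatrix.b8`):

```python
def transpose4x4_bytes(rows):
    """Transpose a 4x4 byte tile held as four little-endian 32-bit words.
    Output word c collects byte c of every input word, the same effect the
    kernel achieves with prmt.b32 byte-select instructions."""
    out = []
    for c in range(4):
        w = 0
        for r in range(4):
            w |= ((rows[r] >> (8 * c)) & 0xFF) << (8 * r)
        out.append(w)
    return out
```

Transposing twice is the identity, which makes the movement easy to sanity-check against the register layout the FP8 MMA expects.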
@johnnynunez Thanks for sharing the TMA results! We've seen gau-nernst's work — impressive numbers (94.4% of peak on RTX 5090 with Ampere-era instructions). Our TMA variant (`FlashAttentionForwardSm120Tma`) uses `cp.async.bulk` with warp specialization.

Also — we already have the Dao-AILab/flash-attention adaptation up at #2268 (just rebased onto latest main). Happy to coordinate if you want to help expand it (e.g. varlen, backward pass).
…nal tiles

The TMA consumer loop was passing in_mask_steps=True for every KV tile when is_causal=True, applying expensive per-element causal masking (identity tensor creation + column comparisons) to all tiles, including those fully below the diagonal. This caused up to 40% regression vs CpAsync on causal workloads.

Fix: add a runtime check to only apply masking for the mask_steps tiles near the causal diagonal (n_block >= n_block_max - ceil_div(m_block_size, n_block_size)), matching the CpAsync variant's two-loop approach. Tiles fully below the diagonal use in_mask_steps=False and skip the masking.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
### TMA Causal Masking Fix + Updated Benchmarks

Bug fix (commit `106e24b`): the TMA variant applied per-element causal masking to every KV tile; it now masks only the tiles near the causal diagonal. TMA-vs-CpAsync speedup on causal configs, before and after:
| Config | Before fix | After fix |
|---|---|---|
| B=1 S=512 D=64 causal | 0.69x | 1.00x |
| B=1 S=512 D=128 causal | 0.94x | 0.93x |
| B=1 S=1024 D=64 causal | 0.88x | 0.85x |
| B=1 S=1024 D=128 causal | 0.99x | 0.95x |
The biggest improvement is at short sequences where the masked-tile fraction is highest. Updated full benchmark results are in the PR description.
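The two-loop structure behind the fix can be sketched as a tiny planning helper (illustrative; the parameter names follow the PR's tile sizes, not the actual kernel API):

```python
def ceil_div(a, b):
    return -(-a // b)

def causal_kv_tile_plan(m_block, m_block_size, n_block_size):
    """For one Q tile, split KV tiles into a fast unmasked loop and a
    masked loop covering only tiles that straddle the causal diagonal."""
    # Number of KV tiles this Q tile can attend to at all.
    n_block_max = ceil_div((m_block + 1) * m_block_size, n_block_size)
    # Only this many tiles can cross the diagonal for one Q tile.
    mask_steps = ceil_div(m_block_size, n_block_size)
    first_masked = max(n_block_max - mask_steps, 0)
    unmasked = list(range(0, first_masked))            # in_mask_steps=False
    masked = list(range(first_masked, n_block_max))    # in_mask_steps=True
    return unmasked, masked
```

For m_block_size=128 and n_block_size=64, Q tile 3 attends to 8 KV tiles but only the last 2 need per-element masking, matching the fixed runtime check `n_block >= n_block_max - ceil_div(m_block_size, n_block_size)`.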
Two correctness fixes to FlashAttentionForwardSm120Tma:
1. Non-causal OOB masking: the last K tile is now masked when seqlen_k
is not divisible by n_block_size. TMA zero-fills OOB positions during
load, but softmax must treat them as -inf (not 0) to avoid corrupting
the normalization. The consumer loop now passes in_mask_steps=True for
n_block == n_block_max - 1 in the non-causal path, and
_softmax_rescale_O handles both causal and non-causal masking when
in_mask_steps=True.
2. SMEM capacity check: the previous estimate used 3*1024 bytes of
alignment overhead, which over-counted by ~2 KB. The mbar region
(< 200 B) rounds up to 1024 B before sQ; sQ and sKV are typically
multiples of 1024 B for standard tile configs, so they need no further
padding. Updated can_implement() to use the actual layout arithmetic,
allowing the default config (m=128, n=64, d=128, kv_stages=2, bf16)
which uses 97.0 KB of SM120's 99.0 KB SMEM budget.
Validated on SM121a with 8 configs:
- TMA: default non-causal, default causal, seqlen_k non-divisible
(non-causal and causal), head_dim=64 — all PASS
- CpAsync: same 4 configs — all PASS (no regressions)
Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
### Update: Two additional correctness fixes (commit `6e193de`)

**1. Non-causal OOB masking**

The TMA variant was missing out-of-bounds masking for the last K tile when `seqlen_k` is not divisible by `n_block_size`. TMA zero-fills OOB positions during the load, but softmax must treat them as `-inf`, not 0, to avoid corrupting the normalization.

Fix: the consumer loop now passes `in_mask_steps=True` for the last K tile in the non-causal path, and `_softmax_rescale_O` handles both causal and non-causal masking when `in_mask_steps=True`. Previously only visible with `seqlen_k % n_block_size != 0`.

**2. SMEM capacity check in `can_implement()`**

The previous estimate used 3*1024 bytes of alignment overhead, over-counting by ~2 KB; the check now uses the actual layout arithmetic, allowing the default config (m=128, n=64, d=128, kv_stages=2, bf16), which uses 97.0 KB of SM120's 99.0 KB SMEM budget.
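The capacity arithmetic from the commit can be reproduced in a few lines (a sketch under the stated layout assumptions: a 1 KB-rounded mbarrier region, then sQ, then `kv_stages` pairs of K/V tiles; `smem_bytes` is an illustrative helper, not the kernel's `can_implement()`):

```python
def smem_bytes(m_block, n_block, head_dim, kv_stages, dtype_bytes=2,
               mbar_region=1024):
    """Shared-memory footprint: mbarriers (rounded up to 1 KB), the Q
    tile, and kv_stages double-buffered K and V tiles."""
    s_q = m_block * head_dim * dtype_bytes
    s_kv = kv_stages * 2 * (n_block * head_dim * dtype_bytes)  # K + V per stage
    return mbar_region + s_q + s_kv

# Default config from the commit: m=128, n=64, d=128, kv_stages=2, bf16.
default_usage = smem_bytes(128, 64, 128, 2)
```

This gives exactly 97.0 KB, fitting the 99.0 KB budget quoted in the commit, whereas the old estimate's extra ~2 KB of assumed padding rejected the same config.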
Add BSD-3-Clause license header to benchmark_fp8_vs_bf16.py (was missing entirely). Normalize license header format in fp8_flash_attention.py to match the canonical CUTLASS style used across all other example files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
cc @Junkai-Wu |
## Summary
Adds CuTe DSL Flash Attention v2 forward pass kernels for SM120 (RTX 5090, GB10 / DGX Spark), addressing the lack of high-performance FA kernels for this architecture.
Three implementations included:
### BF16 Kernels (`flash_attention_v2.py`)

- **CpAsync variant (`FlashAttentionForwardSm120`)**: All threads perform both loads and compute using `cp.async` — the Ampere-era approach, tuned for SM120's 101 KB shared memory.
- **TMA variant (`FlashAttentionForwardSm120Tma`)**: Uses TMA (`cp.async.bulk`) with warp specialization — 1 dedicated DMA warp handles TMA loads while N MMA warps compute, enabling load/compute overlap via multi-stage KV pipelining with `PipelineTmaAsync` mbarrier synchronization.

Both use SM80-compatible tensor core instructions (`mma.sync.aligned.m16n8k16`) since SM120 lacks tcgen05 MMA. Supports FP16/BF16, causal/non-causal, configurable tile sizes, asymmetric Q/K lengths, online softmax fusion, and register pipelining.

### FP8 Kernel (`fp8_flash_attention.py`)

Optimized kernel (`FP8FlashAttentionSm120Opt`): uses SM120's native FP8 MMA instruction (`mma.sync.aligned.kind::f8f6f4.m16n8k32`) via inline PTX, with a CpAsync double-buffered K/V pipeline, vectorized 4×4 byte transpose via `prmt.b32`, and a bank-conflict-free SMEM layout (+16 byte row padding).

FP8 kernel features:

- Inline PTX MMA as a workaround for the `MmaAtomSM80Type` segfault with FP8 types (CuTe DSL: FP8 MMA segfaults on SM120 — `MmaAtomSM80Type` missing `kind::f8f6f4` lowering, #3044)
- `prmt.b32` 4×4 byte shuffle (no `ldmatrix.b8` on SM120)

### TMA variant details

- 4D TMA descriptors `(seq, dim, head, batch)` for native multi-batch support
- TMA-compatible `Swizzle(B, 4, 3)` pattern (M=4 required by TMA hardware; the CpAsync version uses `Swizzle(B, 3, 3)`, which is not valid for TMA)
- Configurable `kv_stages` for double-buffering (default 2, falls back to 1 for large head_dim)

## Benchmark Results
### FP8 vs BF16 Performance on DGX Spark (SM121a)

- FP8 kernel: `FP8FlashAttentionSm120Opt` (CpAsync, bank-conflict-free, 4 warps, M=64, N=32)
- BF16 kernel: `FlashAttentionForwardSm120` (CpAsync, tiled MMA, M=128, N=64)

Key findings:
### BF16 TMA vs CpAsync Performance on DGX Spark (SM121a)

- TMA kernel: `FlashAttentionForwardSm120Tma` (warp-specialized, 3 MMA + 1 DMA warp, `PipelineTmaAsync`)
- CpAsync kernel: `FlashAttentionForwardSm120` (all threads load+compute, `cp.async`)
- Both: M=128, N=64, 16 heads, 20 warmup iters, 100 timed iters
Geometric mean speedup: 0.95x · Min: 0.62x · Max: 1.25x
Key findings:
- `cp.async.bulk` + warp specialization on SM120, a pattern that scales better with larger tile sizes and multi-stage pipelining

### BF16 TMA vs CpAsync — Updated (post causal masking fix)

The initial results above had two issues:
1. **JIT compilation artifacts**: The CuTe DSL JIT-compiles kernels on first invocation, which inflated/deflated some configs (e.g. the 1.25x outlier at B=1 S=1024 D=64 was a JIT artifact). The updated benchmark pre-warms all kernel variants before timing.
2. **Causal masking bug** (fixed in `106e24b`): The TMA variant applied expensive per-element causal masking to all KV tiles instead of only the `ceil_div(m_block, n_block)` tiles near the diagonal. The CpAsync variant correctly used a two-loop structure (masked loop + fast loop). This caused up to 40% regression on causal configs (e.g. B=1 S=512 D=64 causal: 0.69x → 1.00x after fix).

Updated results with JIT pre-warming + causal fix:
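The pre-warming fix amounts to a standard two-phase timing loop (a generic sketch; the real benchmark additionally synchronizes the GPU around the timed region):

```python
import time

def bench(fn, warmup=20, iters=100):
    """Run fn a few times first so one-time costs (e.g. CuTe DSL JIT
    compilation on first invocation) do not leak into the measurement,
    then report average seconds per call over the timed iterations."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters
```

Without the warmup phase, the first timed call of each variant pays its compilation cost, which is exactly what produced the 1.25x outlier above.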
Geometric mean speedup: 0.93x · Min: 0.77x · Max: 1.10x
What changed vs initial results:
Updated key findings:
- `cp.async.bulk` + warp specialization patterns in the CuTe DSL

## Test Results
Validated on NVIDIA GB10 (SM121a / DGX Spark) hardware at Second Nature Computing against PyTorch `scaled_dot_product_attention`:

- BF16 CpAsync variant
- BF16 TMA variant (`--use_tma`)
- FP8 kernel

Tolerance: `atol=1e-02`, `rtol=1e-04`

## Motivation
There are currently few high-performance flash attention kernels available for SM120. The existing Blackwell FMHA (`blackwell/fmha.py`) targets SM100 and uses tcgen05 MMA + TMEM, which SM120 does not support. This implementation fills that gap by:

- using `f8f6f4` MMA instructions for up to 2× arithmetic throughput over BF16

## Usage
Closes #2956
Contributed by Second Nature Computing — tested on DGX Spark hardware