
[autoWS] fix causal mode of Blackwell/FA/dp#993

Open
Sibylau wants to merge 4 commits into main from jieeliu/fa-causal

Conversation

Contributor

@Sibylau Sibylau commented Feb 26, 2026

This PR makes the causal mode of the tritonbench kernel blackwell_triton_fused_attention_dp.py work with autoWS. It also depends on the backport fix PRs #959 and #989.

Test plan

In tritonbench, run

CUDA_VISIBLE_DEVICES=4 TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_DUMP=1 TRITON_DUMP_DIR=/tmp/triton_dumps TRITON_USE_META_WS=1 bash ~/fbsource/fbcode/ads_mkl/benchmarks/denoise.sh python run.py --op blackwell_attentions --seq-len 8192 --batch 4 --n-heads 32 --d-head 128 --mode fwd  --causal --rep 3000 --sleep 1.0 --metrics tflops --simple-output --only triton_tutorial_flash_dp_persistent_blackwell --force

Result using devgpu-31.atn1:

Processing GPU 4...
→ Locking power cap to 750 W and SM clock to 1965 MHz on GPU 4
/home/jieeliu/envs/conda/envs/autows/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
TMA benchmarks will be running with experimental grid constant TMA descriptor.
INFO:root:TMA benchmarks will be running with experimental grid constant TMA descriptor.
WARNING:tritonbench.utils.triton_op:First-k mode: Selected 1 sequential inputs starting from index 0 (total available: 1)
WARNING:tritonbench.utils.triton_op:Input IDs to run: [0]
WARNING:tritonbench.utils.triton_op:Running input ID 0:
(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)
----------------------------------------------------
(4, 32, 32, 8192, 8192, 128) Causal fwd
INFO:tritonbench.utils.triton_op:Took 0.11ms to get benchmark function for triton_tutorial_flash_dp_persistent_blackwell
WARNING:tritonbench.utils.triton_op:Completed input ID 0:
(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)
----------------------------------------------------
(4, 32, 32, 8192, 8192, 128) Causal fwd
  (Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)    triton_tutorial_flash_dp_persistent_blackwell-tflops
----------------------------------------------------  ------------------------------------------------------
             (4, 32, 32, 8192, 8192, 128) Causal fwd                                                 664.033

Tested that the change does not impact the non-causal version:
In tritonbench, run

CUDA_VISIBLE_DEVICES=4 TRITON_ALWAYS_COMPILE=1 TRITON_KERNEL_DUMP=1 TRITON_DUMP_DIR=/tmp/triton_dumps TRITON_USE_META_WS=1 bash ~/fbsource/fbcode/ads_mkl/benchmarks/denoise.sh python run.py --op blackwell_attentions --seq-len 8192 --batch 4 --n-heads 32 --d-head 128 --mode fwd --rep 3000 --sleep 1.0 --metrics tflops --simple-output --only triton_tutorial_flash_dp_persistent_blackwell --force 

Result using devgpu-31.atn1:

Processing GPU 4...
→ Locking power cap to 750 W and SM clock to 1965 MHz on GPU 4
/home/jieeliu/envs/conda/envs/autows/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
TMA benchmarks will be running with experimental grid constant TMA descriptor.
INFO:root:TMA benchmarks will be running with experimental grid constant TMA descriptor.
WARNING:tritonbench.utils.triton_op:First-k mode: Selected 1 sequential inputs starting from index 0 (total available: 1)
WARNING:tritonbench.utils.triton_op:Input IDs to run: [0]
WARNING:tritonbench.utils.triton_op:Running input ID 0:
(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)
----------------------------------------------------
(4, 32, 32, 8192, 8192, 128) fwd
INFO:tritonbench.utils.triton_op:Took 0.10ms to get benchmark function for triton_tutorial_flash_dp_persistent_blackwell
WARNING:tritonbench.utils.triton_op:Completed input ID 0:
(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)
----------------------------------------------------
(4, 32, 32, 8192, 8192, 128) fwd
  (Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)    triton_tutorial_flash_dp_persistent_blackwell-tflops
----------------------------------------------------  ------------------------------------------------------
                    (4, 32, 32, 8192, 8192, 128) fwd                                                 918.906

@meta-cla meta-cla bot added the CLA Signed label Feb 26, 2026
@Sibylau Sibylau changed the title [autoWS] fix causal mode fof Blackwell/FA/dp [autoWS] fix causal mode of Blackwell/FA/dp Feb 27, 2026
// loop has one shared reuse accumCnt, not one per inner loop.
SmallVector<Operation *> dummy;
// Track seen ops for the reuse group section.
DenseSet<Operation *> seenOps;
Contributor Author

@Sibylau Sibylau Feb 27, 2026


The fix handles accumCnts separately for the per-region and the per-reuse-group cases, which lets us remove seenOps.
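The split can be sketched as follows. This is an illustrative sketch only: `AccumCntTable`, the string keys, and the counter helpers are hypothetical stand-ins, not the actual data structures of the pass.

```cpp
#include <map>
#include <string>

// Sketch: keep one accumCnt per region and one shared accumCnt per reuse
// group, so duplicates no longer need to be filtered through a shared
// `seenOps` set. All names here are hypothetical.
struct AccumCntTable {
  std::map<std::string, unsigned> perRegion;     // independent per-region counters
  std::map<std::string, unsigned> perReuseGroup; // one shared counter per group

  unsigned bumpRegion(const std::string &region) { return ++perRegion[region]; }
  unsigned bumpGroup(const std::string &group) { return ++perReuseGroup[group]; }
};
```

Two inner loops that belong to the same reuse group then advance a single shared counter, while each region still counts independently.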

// Use rowSize for the N-dimension extent, since colOffset is an
// N-dimension position.
Interval candSizeRange = {colOffset, colOffset + cand->rowSize};
Interval allocSizeRange = {alloc->colOffset,
Contributor Author

@Sibylau Sibylau Feb 27, 2026


colOffset is an N-dimension (TMEM row) position, while colSize is an M-dimension extent. Using rowSize (the N-dimension extent) to advance the position resolves the "can't find TMEM space" issue.
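The mismatch can be illustrated with a toy interval check. The `Interval` struct and `candRange` helper below are stand-ins, not the compiler's classes: the point is only that an interval built from an N-dimension position must use an N-dimension extent, or the overlap test compares quantities from different axes.

```cpp
// Toy half-open interval [start, end) along one TMEM axis.
struct Interval {
  int start, end;
  bool intersects(const Interval &o) const {
    return start < o.end && o.start < end;
  }
};

// Building the candidate's range from an N-dimension offset: the extent
// added to the offset must also be measured along N, otherwise the
// resulting interval mixes axes and the overlap answer is meaningless.
Interval candRange(int colOffset, int nExtent) {
  return {colOffset, colOffset + nExtent};
}
```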

Contributor


I am a little confused about the colSize vs. rowSize. Was it working for GEMM and non-causal fwd just by chance? :]

Contributor Author


Sorry, this is related to one of the backport PRs, #975:

I realized it is better to swap the arguments during instantiation, so I reverted the rowSize change and instead changed the call to ttng::TMemAllocation(allocSize.numRows, allocSize.numCols)}); in the new commit.
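The row/column mix-up can be sketched minimally. The structs below are illustrative, assuming a constructor whose parameter order differs from the field order of the size struct feeding it; the real ttng::TMemAllocation signature is not reproduced here.

```cpp
// Hypothetical size struct whose fields are ordered (rows, cols).
struct AllocSize { int numRows, numCols; };

// Hypothetical allocation type taking (numRows, numCols) in that order.
// If the producer and the constructor disagree on ordering, swapping the
// arguments at the call site, rather than changing the extent used for
// interval math, fixes the mismatch in one place.
struct TMemAllocationSketch {
  int rows, cols;
  TMemAllocationSketch(int numRows, int numCols)
      : rows(numRows), cols(numCols) {}
};
```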

if (!alloc->isOwnerOfSpace && alloc->reuseOwner == reuseOwner) {
if (sameLoop(alloc) ||
bufferRange[alloc].intersects(bufferRange[cand]))
if (bufferRange[alloc].intersects(bufferRange[cand])) {
Contributor Author


Relax the sameLoop(alloc) check here to support two sequential loops in causal FA.

Contributor


Need to refresh my memory on causal now :] For causal, do we call allocateTMem twice? Once on the first innermost loop, then on the 2nd innermost loop. So when running on the 2nd innermost loop, checking against allocs for the 1st inner loop, sameLoop should return false?

Contributor Author


I see. I reverted this change in the new commits. It seems to be a simplification rather than a necessary part of the fix: for cross-loop allocs, sameLoop is false, so only intersects is evaluated; for same-loop allocs, in our use case intersects returns true for local allocs that are live through the loop body, so removing sameLoop did not break anything.
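That equivalence argument can be written down directly. This is a sketch with plain booleans standing in for the real sameLoop/intersects predicates, just to show where the two forms agree and where they diverge.

```cpp
// Old predicate: a candidate conflicts if it shares the loop or overlaps.
bool conflictsOld(bool sameLoop, bool intersects) {
  return sameLoop || intersects;
}

// Simplified predicate: overlap alone decides.
bool conflictsNew(bool /*sameLoop*/, bool intersects) {
  return intersects;
}
```

The two forms agree whenever same-loop allocs always intersect; they diverge only for a same-loop alloc with no live-range overlap, which is why the change was reverted rather than relied upon.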

idTypes[bufferIdInnermostSplit] = elemType;
++bufferId;
}
assignedId = bufferIdInnermostSplit;
Contributor Author


Enable sharing for split buffers to avoid an SMEM out-of-space error.

Contributor


The current logic in CodePartitioning is that if buffer.copy is 1, the reuse group uses separate barriers; if it is > 1, the reuse group shares the same array of barriers. CC @njriasan: did you add this logic for gemm? I wonder if we have a pytest for gemm to make sure this will not regress your use case. We can check in a shell script to run all the tests locally for now.
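The rule as described reduces to a one-line predicate; the helper below is an illustrative restatement, not the pass's actual API.

```cpp
// CodePartitioning rule as described in the thread: a reuse group with a
// single buffer copy uses separate barriers; with more than one copy the
// group shares one array of barriers.
bool sharesBarrierArray(unsigned numCopies) {
  return numCopies > 1;
}
```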

@Sibylau Sibylau requested review from manman-ren and njriasan and removed request for njriasan February 27, 2026 05:47
