Conversation
|
!test |
|
!test |
Greptile SummaryThis PR implements the auto TMA transpose scheduler for nvFuser, choosing between two transpose strategies based on whether inputs or outputs are more numerous: when more inputs exist, the swizzle is applied to output shared memory and TMA store is used (output-smem-transpose path); when outputs are more numerous, the swizzle is applied to input shared memory via TMA load (input-smem-transpose path). A new Key findings:
Confidence Score: 2/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Fusion Inputs] --> B{is_output_smem_transpose?}
B -- "true\n(n_input > n_output)" --> C1[TMA Load\nno swizzle\nInput → smem]
C1 --> D1[smem → Registers]
D1 --> E1[Compute / Transpose]
E1 --> F1[Registers → smem\nswizzled MMA layout\nregs → output smem cache]
F1 --> G1[TMA Store\nwith swizzle\nsmem → Output]
B -- "false\n(n_input ≤ n_output)" --> C2[TMA Load\nwith swizzle\nInput → smem]
C2 --> D2[smem → Registers\ntranspose via\nswizzled read]
D2 --> E2[Compute]
E2 --> F2[Register Store\nRegisters → Output]
style C1 fill:#d4edda
style G1 fill:#d4edda
style C2 fill:#cce5ff
style F2 fill:#fff3cd
|
csrc/scheduler/transpose_tma.cpp
Outdated
| NVF_ERROR(grouped_inputs_outputs.size() >= 2); | ||
|
|
||
| // When there are more inputs than outputs, output smem transpose should be | ||
| // used, however, if it is not, then input smem tranpose will be used, to |
There was a problem hiding this comment.
tranpose should be transpose
| const int64_t cta_per_sm = | ||
| dev_props->maxThreadsPerMultiProcessor / threads_per_cta; | ||
| const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm; | ||
| const int64_t bytes_per_tile = bytes_per_cta / n_input; |
There was a problem hiding this comment.
Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.
| const int64_t bytes_per_tile = bytes_per_cta / n_input; | |
| NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose"); | |
| const int64_t bytes_per_tile = bytes_per_cta / n_input; |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
Review updated until commit bc772db Description
|
| Relevant files |
|---|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Potential TMA load restriction
This is more restrictive than the original which checked all loop domains. This could potentially exclude valid TMA loads where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break existing TMA use cases. |
Additional Comments (2)
If On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when
The |
Additional Comments (4)
If This happens when While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise
Note the asymmetry: Step 3 already guards the analogous constraint with an explicit
These Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
| if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) { | ||
| return id->isThreadDim() || | ||
| id->getParallelType() == ParallelType::Serial; | ||
| })) { |
There was a problem hiding this comment.
trivial optimization of multiple-tma loads, doesn't have to be in this PR.
|
|
||
| // When not using output smem transpose but inputs > outputs, swap groups | ||
| // so group 2 remains the swizzled side. | ||
| if (!tparams->is_output_smem_transpose && |
There was a problem hiding this comment.
This branch is not used in current heuristics, but may use it in future tuning.
|
!test |
2 similar comments
|
!test |
|
!test |
tests/cpp/test_transpose.cpp
Outdated
| auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel()); | ||
| for (auto& [expr, ways] : bank_conflicts) { | ||
| auto [read_ways, write_ways] = ways; | ||
| std::cout << " Bank conflict: " << expr->toString() |
rdspring1
left a comment
There was a problem hiding this comment.
Are the Infinite loop in heuristics and Out-of-range crash with inconsistent params greptile concerns valid?
Yes, they may happen in theory, added checks. |
To reduce number of tranpose ops,
is_output_smem_transposeis added to control input/output transpose:1. When there are more inputs than outputs,
is_output_smem_transpose = TrueTMA load without swizzle, TMA store with swizzle, transpose at
regs --> output cached smem2. When there are less inputs than outputs,
is_output_smem_transpose = FalseTMA load with swizzle, register store, transpose at
input cached smem -> regsCurrent performance is in this doc.