
add auto tma transpose scheduler #6018

Open
liqiangxl wants to merge 10 commits into main from llu/transpose_output_smem_auto

Conversation

@liqiangxl
Collaborator

To reduce the number of transpose ops, is_output_smem_transpose is added to control whether the transpose happens on the input or the output side:

1. When there are more inputs than outputs, is_output_smem_transpose = true:
TMA load without swizzle, TMA store with swizzle; the transpose happens on the regs --> output cached smem path.

2. When there are fewer inputs than outputs (or equal counts), is_output_smem_transpose = false:
TMA load with swizzle, register store; the transpose happens on the input cached smem -> regs path.

Current performance is in this doc.

@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl marked this pull request as ready for review February 27, 2026 15:40
@greptile-apps
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR implements the auto TMA transpose scheduler for nvFuser, choosing between two transpose strategies based on whether inputs or outputs are more numerous: when more inputs exist, the swizzle is applied to output shared memory and TMA store is used (output-smem-transpose path); when outputs are more numerous, the swizzle is applied to input shared memory via TMA load (input-smem-transpose path). A new EnableOption::TmaTranspose flag gates the feature, and the change integrates cleanly with existing scheduler infrastructure.

Key findings:

  • Schedule correctness risk (transpose_tma.cpp:327-331): tma_store_tvs (the output smem intermediate) is not added to skip_tvs before the Step 4 register propagation. Unlike tma_load_tvs, which are explicitly excluded so the register-level TransformPropagator does not overwrite their TMA/swizzle schedule, the output smem TVs scheduled by mma_utils::scheduleTMAStoreForMmaOutput in Step 2 may have their schedule silently overwritten by the propagation pass.
  • Dead swap condition (transpose_tma.cpp:180-183): The cached_inputs.size() > cached_outputs.size() guard is logically impossible when !is_output_smem_transpose holds (which already implies n_input ≤ n_output). This block is unreachable.
  • Missing bounds assertion (transpose_tma.cpp:259): tma_store_tvs.at(0) is dereferenced without first asserting the vector is non-empty; a failed TMA store setup would produce an opaque std::out_of_range exception.
  • Unused test variables (test_transpose.cpp:735-737): The structured binding auto [read_ways, write_ways] = ways; inside the bank-conflict loop is never read, producing compiler warnings. The assertion EXPECT_TRUE(bank_conflicts.empty()) is correctly placed outside the loop, making the loop body entirely dead.

Confidence Score: 2/5

  • Potentially unsafe to merge as-is due to a likely schedule-correctness bug where tma_store_tvs are not excluded from the register propagation pass, risking silent overwrite of the MMA swizzle schedule on the output smem path.
  • The core concern is that tma_store_tvs are missing from skip_tvs in the register propagation step, which by analogy with the existing tma_load_tvs exclusion should cause the output-smem-transpose path to have an incorrect final schedule. The dead swap condition and missing bound-check are lower-severity but also indicate the code path hasn't been exercised under edge conditions. The test suite covers the main cases but doesn't explicitly verify the scheduled IR structure, so a subtle schedule overwrite would not necessarily surface as a numerical test failure.
  • csrc/scheduler/transpose_tma.cpp requires the most attention, specifically the skip_tvs construction in Step 4 and the dead group-swap condition.

Important Files Changed

Filename Overview
csrc/scheduler/transpose_tma.cpp Core new scheduler implementation. Contains the dead group-swap condition, a missing bounds assertion before tma_store_tvs.at(0), and a likely schedule-correctness bug where tma_store_tvs is not excluded from register-propagation skip set, risking overwrite of the MMA swizzle applied in Step 2.
tests/cpp/test_transpose.cpp New parameterized tests covering multiple TMA transpose configurations, dtypes, and bank conflict checks. Contains a dead loop with unused structured bindings (read_ways/write_ways) that generates compiler warnings; otherwise good coverage.
csrc/scheduler/transpose_heuristic.h Adds use_tma_store, is_output_smem_transpose, chunks_per_thread, and elements_per_chunk fields with correct equality, hash, and toString implementations.
csrc/options.h Adds TmaTranspose to EnableOption enum, adds uint8_t underlying type to all option enums (all enums have well under 255 values), and fixes copy constructor initialisation order. All changes are clean.
csrc/options.cpp Registers tma_transpose string → EnableOption::TmaTranspose mapping and migrates std::sort to std::ranges::sort. Straightforward and correct.
csrc/scheduler/transpose.cpp Gates TMA heuristic behind EnableOption::TmaTranspose and extends schedule dispatch to also route use_tma_store paths through the TMA scheduler. Clean changes.
csrc/device_lower/analysis/tma.cpp Refactors getBatchableTmaLoads to skip trivial (extent=1) loop dimensions before checking for thread-parallel or serial parallelisation, enabling the TMA transpose scheduler to work correctly with extent-1 padding dims.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Fusion Inputs] --> B{is_output_smem_transpose?}

    B -- "true\n(n_input > n_output)" --> C1[TMA Load\nno swizzle\nInput → smem]
    C1 --> D1[smem → Registers]
    D1 --> E1[Compute / Transpose]
    E1 --> F1[Registers → smem\nswizzled MMA layout\nregs → output smem cache]
    F1 --> G1[TMA Store\nwith swizzle\nsmem → Output]

    B -- "false\n(n_input ≤ n_output)" --> C2[TMA Load\nwith swizzle\nInput → smem]
    C2 --> D2[smem → Registers\ntranspose via\nswizzled read]
    D2 --> E2[Compute]
    E2 --> F2[Register Store\nRegisters → Output]

    style C1 fill:#d4edda
    style G1 fill:#d4edda
    style C2 fill:#cce5ff
    style F2 fill:#fff3cd

Comments Outside Diff (4)

  1. tests/cpp/test_transpose.cpp, line 735-737 (link)

    Dead loop with unused variables

    The for loop declares read_ways and write_ways via structured binding but never uses them. The actual correctness check (EXPECT_TRUE(bank_conflicts.empty())) sits outside the loop, so the entire loop body is dead code and will generate compiler warnings about unused variables. The loop should be removed entirely.

  2. csrc/scheduler/transpose_tma.cpp, line 258-259 (link)

    Unchecked access into tma_store_tvs

    tma_store_tvs.at(0) is called to derive the TMA swizzle type without first asserting that tma_store_tvs is non-empty. This vector is populated in the use_tma_store block earlier; if cached_outputs happened to be empty (e.g., all outputs were non-TensorView), the vector would be empty and this line would throw std::out_of_range at runtime. A defensive assertion keeps the failure mode clear:

  3. csrc/scheduler/transpose_tma.cpp, line 178-183 (link)

    Dead swap condition — always false

    is_output_smem_transpose is set to n_input > n_output (line 59). So !is_output_smem_transpose holds only when n_input <= n_output. Since cached_inputs.size() mirrors n_input (both count TensorView inputs) and similarly for outputs, the second conjunct cached_inputs.size() > cached_outputs.size() is impossible when the first conjunct holds. This block of code is therefore dead and will never execute.

    If a swap is ever needed for an alternate configuration (e.g., manually overridden is_output_smem_transpose), the condition should be derived from is_output_smem_transpose alone rather than a redundant size comparison.

  4. csrc/scheduler/transpose_tma.cpp, line 327-332 (link)

    tma_store_tvs not excluded from register-propagation scope

    skip_tvs correctly excludes tma_load_tvs so that the register-level TransformPropagator in Step 4 does not overwrite the TMA load smem schedule applied in Step 3. However, tma_store_tvs (the smem intermediates for TMA store) are not added to skip_tvs. As a result, the MaxLogicalDomainInfoSpanningTree traversal can still reach these TVs and the TransformPropagator will re-apply ref_tv's register-level transforms, silently overwriting the MMA swizzle schedule set by mma_utils::scheduleTMAStoreForMmaOutput in Step 2.

    By analogy with how tma_load_tvs are excluded, tma_store_tvs should be excluded whenever use_tma_store is true:

Last reviewed commit: 872edfe

Contributor

@greptile-apps greptile-apps bot left a comment

7 files reviewed, 2 comments


NVF_ERROR(grouped_inputs_outputs.size() >= 2);

// When there are more inputs than outputs, output smem transpose should be
// used, however, if it is not, then input smem tranpose will be used, to
Contributor

tranpose should be transpose

const int64_t cta_per_sm =
dev_props->maxThreadsPerMultiProcessor / threads_per_cta;
const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
const int64_t bytes_per_tile = bytes_per_cta / n_input;
Contributor

Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.

Suggested change
const int64_t bytes_per_tile = bytes_per_cta / n_input;
NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose");
const int64_t bytes_per_tile = bytes_per_cta / n_input;

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@liqiangxl liqiangxl requested a review from rdspring1 February 27, 2026 17:24
@github-actions

github-actions bot commented Mar 2, 2026

Review updated until commit bc772db

Description

  • Implements automatic TMA (Tensor Memory Access) transpose scheduler with two paths: input smem transpose (swizzle on input) and output smem transpose (swizzle on output)

  • Adds new TmaTranspose enable option to toggle the feature; scheduler falls back to non-TMA when disabled

  • Introduces new parameters: use_tma_store, is_output_smem_transpose, chunks_per_thread, elements_per_chunk for flexible TMA configuration

  • Adds comprehensive tests covering different dtypes, transpose dimensions, and TMA parameter combinations

Changes walkthrough

Relevant files

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Potential TMA load restriction

The new code filters loop domains to only include non-trivial IDs (extent > 1 or non-const) before checking for thread/serial dims.
This is more restrictive than the original which checked all loop domains. This could potentially exclude valid TMA loads
where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break
existing TMA use cases.

auto non_trivial_ids =
    tv->getLoopDomain() | std::views::filter([](const IterDomain* id) {
      return !id->extent()->isConstScalar() ||
          id->extent()->evaluate().as<int64_t>() > 1;
    });
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
  return {};
}
Missing null check

In scheduleTranspose, when setting up TMA store (lines 165-172), the code accesses fusion->outputs()[output_idx] without
checking if output_idx is within bounds. While cached_outputs should correspond to outputs, a bounds check would be safer.

for (auto [cached_output, output_idx] : cached_outputs) {
  auto output = fusion->outputs()[output_idx]->as<TensorView>();
  output->definition()->as<LoadStoreOp>()->setOpType(
      LoadStoreOpType::CpAsyncBulkTensorTile);
  cached_output->setMemoryType(MemoryType::Shared);
  cached_output->cacheBefore();
  tma_store_tvs.push_back(cached_output);
}
Thread safety consideration

The copy constructor was modified to use a lambda that captures other.mutex_ and returns other.options_. While this appears
correct, the original implementation directly assigned options_. The new approach should be verified to maintain the same
thread-safety semantics under concurrent access patterns.

Options(const Options& other)
    : options_([&other]() {
        std::lock_guard<std::mutex> lock_other(other.mutex_);
        return other.options_;
      }()) {}

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (2)

csrc/scheduler/transpose_tma.cpp, line 106
Infinite loop when estimated_tile_size1 starts at zero

If bytes_per_tile < kTmaSwizzleBytes (line 91-92), integer division yields estimated_tile_size1 = 0. The while loop (line 104) then spins forever because 0 * 2 == 0 and get_chunks_per_thread() (line 98-102) stays at 0, which is always less than min_chunks_per_thread = 4.

On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when n_input > 64. Add an initialization guard before the loop:

  // Ensure we start from at least 1 to avoid multiplying 0 forever.
  if (estimated_tile_size1 == 0) {
    estimated_tile_size1 = 1;
  }
  while (get_chunks_per_thread() < min_chunks_per_thread) {
    estimated_tile_size1 *= 2;
  }

tests/cpp/test_transpose.cpp, line 1947
Unconditional debug output will pollute test logs

The std::cout block (lines 1945–1947) prints every bank conflict unconditionally. This makes test runner output noisy, especially since the BFloat16 path is expected to have bank conflicts. Consider wrapping the print in a debug flag or removing it:

      if (auto* ke = dynamic_cast<KernelExecutor*>(executor.get())) {
        auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
        if (dtype == DataType::Float) {
          EXPECT_TRUE(bank_conflicts.empty());
        } else {
          // TODO: update to EXPECT_TRUE once bf16 bank conflicts are resolved.
          EXPECT_FALSE(bank_conflicts.empty());
        }
      }

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Additional Comments (4)

csrc/scheduler/transpose_tma.cpp, line 107
Potential infinite loop when estimated_tile_size1 initializes to zero

If bytes_per_tile < kTmaSwizzleBytes (128), integer division yields estimated_tile_size1 = 0. The while loop then evaluates get_chunks_per_thread() as 0 (because the numerator is 0 * tile_size2 = 0) and multiplies: 0 * 2 = 0 — the loop never terminates.

This happens when bytes_per_cta / n_input < 128. With an SM90 GPU (maxThreadsPerMultiProcessor = 2048), cta_per_sm = 8, giving bytes_per_cta = 8192. So the loop infinite-hangs when n_input > 64.

While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise estimated_tile_size1 to at least 1:

int64_t estimated_tile_size1 =
    std::max(int64_t(1), bytes_per_tile / kTmaSwizzleBytes);

csrc/scheduler/transpose_tma.cpp, line 267
Missing guard before accessing tma_store_tvs when use_tma_store may be false

tma_store_tvs is only populated when tparams->use_tma_store == true (lines 164–173), but this block checks only tparams->is_output_smem_transpose. If is_output_smem_transpose = true but use_tma_store = false, then tma_store_tvs will be empty and .at(0) throws std::out_of_range.

Note the asymmetry: Step 3 already guards the analogous constraint with an explicit NVF_ERROR(tparams->use_tma_load, ...) at line 286-288. Adding the same guard here would be consistent:

if (tparams->is_output_smem_transpose) {
    NVF_ERROR(
        tparams->use_tma_store,
        "TMA store must be used when output smem is transposed");
    MmaInputSmemSwizzle swizzle =
        mma_utils::tmaSwizzleSharedMemory(tma_store_tvs.at(0));

tests/cpp/test_transpose.cpp, line 1949
Debug std::cout in test code — use GTest facilities instead

These std::cout lines will only fire when bank conflicts are detected (when the test is already failing). However, raw std::cout in tests is unconventional — GTest's ADD_FAILURE() / SCOPED_TRACE or just the EXPECT_TRUE failure message would be more idiomatic:

      for (auto& [expr, ways] : bank_conflicts) {
        auto [read_ways, write_ways] = ways;
        ADD_FAILURE() << "Bank conflict in: " << expr->toString()
                      << "  read=" << read_ways << "-way"
                      << ", write=" << write_ways << "-way";
      }


tests/cpp/test_transpose.cpp, line 1969
Typo "tranapose" should be "transpose" in multiple lines

// Test different combinations of TMA transpose parameters:
// (is_output_smem, use_tma_load, use_tma_store)
//   (false, true, false)  - input smem transpose, TMA load only
//   (false, true, true)   - input smem transpose, TMA load + TMA store
//   (true,  true, true)   - output smem transpose, TMA load + TMA store
//   (true,  false, true)  - output smem transpose, TMA store only

if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
      return id->isThreadDim() ||
          id->getParallelType() == ParallelType::Serial;
    })) {
Collaborator Author

trivial optimization of multiple-tma loads, doesn't have to be in this PR.


// When not using output smem transpose but inputs > outputs, swap groups
// so group 2 remains the swizzled side.
if (!tparams->is_output_smem_transpose &&
Collaborator Author

This branch is not used in current heuristics, but may use it in future tuning.

@liqiangxl
Collaborator Author

!test

2 similar comments
@liqiangxl
Collaborator Author

!test

@liqiangxl
Collaborator Author

!test

auto bank_conflicts = getBankConflictInfo(ke->compiledKernel()->kernel());
for (auto& [expr, ways] : bank_conflicts) {
  auto [read_ways, write_ways] = ways;
  std::cout << " Bank conflict: " << expr->toString()
Collaborator

Is std::cout necessary?

Collaborator Author

removed

Collaborator

@rdspring1 rdspring1 left a comment

Are the greptile concerns about the infinite loop in heuristics and the out-of-range crash with inconsistent params valid?

@liqiangxl
Collaborator Author

Are the greptile concerns about the infinite loop in heuristics and the out-of-range crash with inconsistent params valid?

Yes, they may happen in theory; added checks.
