Improve backward data convolution with non-unit strides #23976

@yzhang93

Description

Problem

Backward-data convolutions with non-unit strides (e.g., stride 2 or stride 3) produce a strided tensor.insert_slice into a zero-filled buffer, followed by a convolution that reads the scattered result. On the current main branch this lowers to:

  1. Memset — zero-fill the output buffer via DMA
  2. slow_memcpy — copy the source elements at strided positions via DMA
  3. Conv dispatch — run the convolution on the padded buffer

The slow_memcpy dispatch uses hardware DMA copy engines, which deliver high bandwidth for moderate tensor sizes (~100–200 MB) but degrade significantly for large tensors (>1 GB). Because the tensor.insert_slice cannot be fused into the convolution dispatch, the scatter always runs as a separate dispatch before the convolution. For 1x1 backward convolutions this is particularly wasteful: the scatter commutes with the contraction and could be deferred until after it, where it operates on a much smaller result tensor.
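
As a mental model, steps 1 and 2 amount to a zero-fill plus a strided copy. A minimal NumPy sketch, with hypothetical small shapes standing in for the 32x25x25x2048 example below:

```python
import numpy as np

# Hypothetical small shapes standing in for the 32x25x25x2048 example.
src = np.arange(2 * 3 * 3 * 4, dtype=np.float16).reshape(2, 3, 3, 4)

dst = np.zeros((2, 6, 6, 4), dtype=np.float16)  # 1. Memset: zero-fill
dst[:, ::2, ::2, :] = src                       # 2. slow_memcpy: strided copy
# 3. Conv dispatch: the convolution then reads the scattered `dst`
```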

Current IR after dispatch formation

1x1 backward conv (convfp16 -n 32 -c 2048 -k 1024 -y 1 -x 1 -u 2 -v 2 -g 1):

// IR after FormDispatchRegionsPass — insert_slice stays outside the dispatch
util.func public @...(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x50x50x2048xf16>
  %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<32x25x25x2048xf16>
  %1 = iree_tensor_ext.compute_barrier.start %0 : tensor<32x25x25x2048xf16> -> tensor<32x25x25x2048xf16>
  %2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<2048x1x1x1024xf16>
  // insert_slice becomes Memset + slow_memcpy — not fused with conv dispatch
  %inserted_slice = tensor.insert_slice %1 into %cst_0[0, 0, 0, 0] [32, 25, 25, 2048] [1, 2, 2, 1] : tensor<32x25x25x2048xf16> into tensor<32x50x50x2048xf16>
  %3 = tensor.empty() : tensor<32x50x50x1024xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<32x50x50x1024xf32>) -> tensor<32x50x50x1024xf32>
  %collapsed = tensor.collapse_shape %2 [[0, 1, 2], [3]] : tensor<2048x1x1x1024xf16> into tensor<2048x1024xf16>
  %5 = iree_tensor_ext.compute_barrier.start %collapsed : tensor<2048x1024xf16> -> tensor<2048x1024xf16>
  %6 = tensor.empty() : tensor<1024x2048xf16>
  // Preprocessing: transpose rhs.
  %7 = flow.dispatch.region -> (tensor<1024x2048xf16>) {
    %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<2048x1024xf16>) outs(%6 : tensor<1024x2048xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<1024x2048xf16>
    flow.return %14 : tensor<1024x2048xf16>
  }
  %8 = iree_tensor_ext.compute_barrier.start %7 : tensor<1024x2048xf16> -> tensor<1024x2048xf16>
  %9 = tensor.empty() : tensor<32x50x50x1024xf16>
  // Contraction on the LARGE scattered tensor (32x50x50 spatial — 4x more work than needed)
  %10 = flow.dispatch.region -> (tensor<32x50x50x1024xf16>) {
    %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%inserted_slice, %8 : tensor<32x50x50x2048xf16>, tensor<1024x2048xf16>) outs(%4 : tensor<32x50x50x1024xf32>) {
    ^bb0(%in: f16, %in_1: f16, %out: f32):
      %16 = arith.extf %in : f16 to f32
      %17 = arith.extf %in_1 : f16 to f32
      %18 = arith.mulf %16, %17 : f32
      %19 = arith.addf %out, %18 : f32
      linalg.yield %19 : f32
    } -> tensor<32x50x50x1024xf32>
    %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%14 : tensor<32x50x50x1024xf32>) outs(%9 : tensor<32x50x50x1024xf16>) {
    ^bb0(%in: f32, %out: f16):
      %16 = arith.truncf %in : f32 to f16
      linalg.yield %16 : f16
    } -> tensor<32x50x50x1024xf16>
    flow.return %15 : tensor<32x50x50x1024xf16>
  }
  %11 = iree_tensor_ext.compute_barrier.end %10 : tensor<32x50x50x1024xf16> -> tensor<32x50x50x1024xf16>
  %12 = hal.tensor.barrier join(%11 : tensor<32x50x50x1024xf16>) => %arg4 : !hal.fence
  %13 = hal.tensor.export %12 : tensor<32x50x50x1024xf16> -> !hal.buffer_view
  util.return %13 : !hal.buffer_view
}

3x3 grouped backward conv (convfp16 -n 32 -c 1024 -k 1024 -y 3 -x 3 -u 2 -v 2 -g 32):

// IR after FormDispatchRegionsPass — insert_slice stays outside the dispatch
util.func public @...(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view {
  %cst = arith.constant 0.000000e+00 : f32
  %c2 = arith.constant 2 : index
  %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x52x52x32x32xf16>
  %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<32x25x25x1024xf16>
  %1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<1024x3x3x32xf16>
  %expanded = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [32, 25, 25, 32, 32] : tensor<32x25x25x1024xf16> into tensor<32x25x25x32x32xf16>
  %2 = iree_tensor_ext.compute_barrier.start %expanded : tensor<32x25x25x32x32xf16> -> tensor<32x25x25x32x32xf16>
  %expanded_1 = tensor.expand_shape %1 [[0, 1], [2], [3], [4]] output_shape [32, 32, 3, 3, 32] : tensor<1024x3x3x32xf16> into tensor<32x32x3x3x32xf16>
  %3 = iree_tensor_ext.compute_barrier.start %expanded_1 : tensor<32x32x3x3x32xf16> -> tensor<32x32x3x3x32xf16>
  // insert_slice becomes Memset + slow_memcpy — not fused with conv dispatch
  %inserted_slice = tensor.insert_slice %2 into %cst_0[0, 1, 1, 0, 0] [32, 25, 25, 32, 32] [1, 2, 2, 1, 1] : tensor<32x25x25x32x32xf16> into tensor<32x52x52x32x32xf16>
  %4 = tensor.empty() : tensor<32x50x50x32x32xf32>
  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<32x50x50x32x32xf32>) -> tensor<32x50x50x32x32xf32>
  %6 = tensor.empty() : tensor<32x32x3x3x32xf16>
  // Filter flip dispatch
  %7 = flow.dispatch.region -> (tensor<32x32x3x3x32xf16>) {
    %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%6 : tensor<32x32x3x3x32xf16>) {
    ^bb0(%out: f16):
      %15 = linalg.index 0 : index
      %16 = linalg.index 4 : index
      %17 = linalg.index 2 : index
      %18 = linalg.index 3 : index
      %19 = linalg.index 1 : index
      %20 = arith.subi %c2, %17 : index
      %21 = arith.subi %c2, %18 : index
      %extracted = tensor.extract %3[%15, %16, %20, %21, %19] : tensor<32x32x3x3x32xf16>
      linalg.yield %extracted : f16
    } -> tensor<32x32x3x3x32xf16>
    flow.return %14 : tensor<32x32x3x3x32xf16>
  }
  %8 = iree_tensor_ext.compute_barrier.start %7 : tensor<32x32x3x3x32xf16> -> tensor<32x32x3x3x32xf16>
  %9 = tensor.empty() : tensor<32x50x50x32x32xf16>
  // Grouped convolution on the LARGE scattered tensor (32x52x52 spatial)
  %10 = flow.dispatch.region -> (tensor<32x50x50x32x32xf16>) {
    %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %8 : tensor<32x52x52x32x32xf16>, tensor<32x32x3x3x32xf16>) outs(%5 : tensor<32x50x50x32x32xf32>) {
    ^bb0(%in: f16, %in_2: f16, %out: f32):
      %16 = arith.extf %in : f16 to f32
      %17 = arith.extf %in_2 : f16 to f32
      %18 = arith.mulf %16, %17 : f32
      %19 = arith.addf %out, %18 : f32
      linalg.yield %19 : f32
    } -> tensor<32x50x50x32x32xf32>
    %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%14 : tensor<32x50x50x32x32xf32>) outs(%9 : tensor<32x50x50x32x32xf16>) {
    ^bb0(%in: f32, %out: f16):
      %16 = arith.truncf %in : f32 to f16
      linalg.yield %16 : f16
    } -> tensor<32x50x50x32x32xf16>
    flow.return %15 : tensor<32x50x50x32x32xf16>
  }
  %11 = iree_tensor_ext.compute_barrier.end %10 : tensor<32x50x50x32x32xf16> -> tensor<32x50x50x32x32xf16>
  %collapsed = tensor.collapse_shape %11 [[0], [1], [2], [3, 4]] : tensor<32x50x50x32x32xf16> into tensor<32x50x50x1024xf16>
  %12 = hal.tensor.barrier join(%collapsed : tensor<32x50x50x1024xf16>) => %arg4 : !hal.fence
  %13 = hal.tensor.export %12 : tensor<32x50x50x1024xf16> -> !hal.buffer_view
  util.return %13 : !hal.buffer_view
}

Proposed improvements

1x1 backward convolutions: swap scatter with contraction

For 1x1 backward convolutions, the strided scatter commutes with the contraction because the reduction dimensions (kernel H, W) have loop bound 1. The key insight is that the insert_slice + contraction can be reordered: compute the contraction first on the small (un-scattered) source, then scatter the smaller result afterward.

This avoids running the contraction on the large stride-padded tensor (e.g., 32x50x50 instead of 32x25x25), reducing FLOPs by ~4x for stride 2 in 2 spatial dims. The post-contraction scatter is a simple data movement on the smaller result tensor, which is much cheaper than scattering the full input before the contraction.
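
The commutation argument can be checked numerically. In this NumPy sketch (hypothetical small shapes in place of 32x25x25x2048 and 2048x1024), contracting the un-scattered source and then scattering the small result matches scattering first and contracting on the large tensor:

```python
import numpy as np

# Hypothetical small shapes standing in for the real 1x1 backward conv.
N, H, W, C, K, stride = 2, 5, 5, 8, 4, 2
rng = np.random.default_rng(0)
src = rng.standard_normal((N, H, W, C)).astype(np.float32)
filt = rng.standard_normal((C, K)).astype(np.float32)

# Baseline: scatter first, then contract on the large (mostly zero) tensor.
scattered = np.zeros((N, H * stride, W * stride, C), dtype=np.float32)
scattered[:, ::stride, ::stride, :] = src
baseline = np.einsum("nhwc,ck->nhwk", scattered, filt)

# Improved: contract on the small source, then scatter the small result.
small = np.einsum("nhwc,ck->nhwk", src, filt)
improved = np.zeros((N, H * stride, W * stride, K), dtype=np.float32)
improved[:, ::stride, ::stride, :] = small

assert np.allclose(baseline, improved)  # the two orders agree for 1x1
```

This holds only because a 1x1 convolution is pointwise over the spatial dims, so zero inputs produce zero outputs at exactly the positions the scatter would leave zero.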

Expected IR after preprocessing:

// Contraction on un-scattered source (32x25x25 spatial — 4x smaller than 32x50x50)
%small_result = linalg.generic {
  indexing_maps = [
    affine_map<(d0,d1,d2,d3,d4,d5,d6) -> (d0, d1+d5, d2+d6, d4)>,  // src
    affine_map<(d0,d1,d2,d3,d4,d5,d6) -> (d4, d5, d6, d3)>,          // filter
    affine_map<(d0,d1,d2,d3,d4,d5,d6) -> (d0, d1, d2, d3)>           // result
  ],
  iterator_types = ["parallel","parallel","parallel","parallel","reduction","reduction","reduction"]
} ins(%src, %filter : tensor<32x25x25x2048xf16>, tensor<2048x1x1x1024xf16>)
  outs(%fill : tensor<32x25x25x1024xf32>) {
  ^bb0(%in: f16, %in_0: f16, %out: f32):
    %0 = arith.extf %in : f16 to f32
    %1 = arith.extf %in_0 : f16 to f32
    %2 = arith.mulf %0, %1 : f32
    %3 = arith.addf %out, %2 : f32
    linalg.yield %3 : f32
} -> tensor<32x25x25x1024xf32>
// Truncate
%small_trunced = linalg.generic { ... truncf ... }
  ins(%small_result : tensor<32x25x25x1024xf32>) -> tensor<32x25x25x1024xf16>
// Scatter the small result — data movement only, on a 4x smaller tensor
%output = tensor.insert_slice %small_trunced into %zeros[0,0,0,0] [32,25,25,1024] [1,2,2,1]
  : tensor<32x25x25x1024xf16> into tensor<32x50x50x1024xf16>

3x3 backward convolutions: convert strided insert_slice to linalg.generic

For 3x3 (and other non-1x1) backward convolutions, the scatter cannot be swapped with the contraction because the reduction dimensions (kernel H, W) have loop bound > 1. Instead, we can replace the tensor.insert_slice with a linalg.generic that computes the strided scatter in a single pass using index arithmetic.

For each output position, the generic checks whether the position maps to a valid source element: (pos - offset) must be non-negative, divisible by stride, and the quotient must be in-bounds. Power-of-2 strides use bitwise ops (and/shift) instead of expensive div/mod. The result is selected via arith.select for branchless GPU execution.

This replaces the Memset + slow_memcpy dispatch pair with a single compute dispatch. For large tensors where slow_memcpy bandwidth degrades, the compute-based scatter is faster.
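
The per-dimension validity check can be modeled in a few lines of Python. `src_index` is a hypothetical helper mirroring the index arithmetic in the generic (stride 2, offset 1, source size 25, as in the 3x3 example below):

```python
# Hypothetical 1-D model of the per-dimension check in the scatter generic:
# stride 2 (power of 2), offset 1, source size 25, as in the 3x3 example.
def src_index(pos, offset=1, stride_log2=1, size=25):
    shifted = pos - offset
    rem = shifted & ((1 << stride_log2) - 1)  # bitwise mask instead of mod
    idx = shifted >> stride_log2              # shift instead of div
    valid = shifted >= 0 and rem == 0 and idx < size
    clamped = min(max(idx, 0), size - 1)      # clamp for a safe extract
    return clamped, valid                     # invalid -> select the zero value

assert src_index(1) == (0, True)    # output pos 1 reads source 0
assert src_index(2) == (0, False)   # falls between samples -> zero
assert src_index(49) == (24, True)  # last valid output position
```

The clamp keeps the tensor.extract in-bounds even when the position is invalid; the arith.select then discards the loaded value, keeping execution branchless.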

Expected IR after preprocessing (3x3 grouped conv, g=32, cpg=32):

// Scatter generic — replaces Memset + slow_memcpy
%scatter = linalg.generic {
  indexing_maps = [affine_map<(d0,d1,d2,d3,d4) -> (d0,d1,d2,d3,d4)>],
  iterator_types = ["parallel","parallel","parallel","parallel","parallel"]
} outs(%empty : tensor<32x52x52x32x32xf16>) {
^bb0(%out: f16):
  %idx_h = linalg.index 1
  %shifted_h = arith.subi %idx_h, %c1             // offset = 1
  %rem_h = arith.andi %shifted_h, %c1_mask        // stride 2 is pow2: use bitwise
  %src_h = arith.shrsi %shifted_h, %c1_shift
  %ge_zero_h = arith.cmpi sge, %shifted_h, %c0
  %rem_zero_h = arith.cmpi eq, %rem_h, %c0
  %lt_size_h = arith.cmpi slt, %src_h, %c25
  %valid_h = arith.andi %ge_zero_h, %rem_zero_h
  %valid_h2 = arith.andi %valid_h, %lt_size_h
  // ... same for dim 2 (W) ...
  %valid = arith.andi %valid_h2, %valid_w2
  %clamped_h = arith.maxsi %src_h, %c0
  %clamped_h2 = arith.minsi %clamped_h, %c24      // clamp for safe extract
  // ...
  %extracted = tensor.extract %src[%idx_n, %clamped_h2, %clamped_w2, %idx_g, %idx_c]
  %result = arith.select %valid, %extracted, %zero
  linalg.yield %result : f16
} -> tensor<32x52x52x32x32xf16>

// Filter flip (reverse kernel spatial dims)
%flipped = linalg.generic { ... } ins(%filter) -> tensor<32x32x3x3x32xf16>

// Grouped convolution reads the scattered buffer
%conv = linalg.generic {
  indexing_maps = [
    affine_map<(d0,d1,d2,d3,d4,d5,d6,d7) -> (d0, d1+d5, d2+d6, d3, d7)>,
    affine_map<(d0,d1,d2,d3,d4,d5,d6,d7) -> (d3, d4, d5, d6, d7)>,
    affine_map<(d0,d1,d2,d3,d4,d5,d6,d7) -> (d0, d1, d2, d3, d4)>
  ],
  iterator_types = ["parallel","parallel","parallel","parallel","parallel",
                    "reduction","reduction","reduction"]
} ins(%scatter, %flipped : tensor<32x52x52x32x32xf16>, tensor<32x32x3x3x32xf16>)
  outs(%fill : tensor<32x50x50x32x32xf32>) { ... }
  -> tensor<32x50x50x32x32xf32>

// Truncate + collapse group dims
%trunced = linalg.generic { ... truncf ... } -> tensor<32x50x50x32x32xf16>
%result = tensor.collapse_shape %trunced [[0],[1],[2],[3,4]]
  : tensor<32x50x50x32x32xf16> into tensor<32x50x50x1024xf16>

Some benchmark results on MI355X

1x1 backward convolutions see the largest gains (up to 6.8x) from the scatter-contraction swap:

| Shape | Baseline | Improved | Speedup |
| --- | --- | --- | --- |
| bf16 n=4 c=32 H=470 W=725 k=224 g=1 | 469 us | 69 us | 6.82x |
| bf16 n=12 c=896 H=59 W=91 k=2016 g=1 | 650 us | 193 us | 3.37x |
| fp16 n=32 c=1024 H=50 W=50 k=2048 g=1 | 648 us | 220 us | 2.94x |
| fp16 n=32 c=512 H=100 W=100 k=1024 g=1 | 831 us | 292 us | 2.85x |
| bf16 n=10 c=448 H=118 W=182 k=896 g=1 | 502 us | 179 us | 2.81x |

3x3 backward convolutions see gains from the compute-based scatter replacing slow_memcpy:

| Shape | Baseline | Improved | Speedup |
| --- | --- | --- | --- |
| fp16 n=32 c=256 H=200 W=200 k=256 g=32 | 4580 us | 2580 us | 1.77x |
| bf16 n=1 c=3 H=940 W=1450 k=32 g=1 | 341 us | 274 us | 1.25x |
| bf16 n=5 c=3 H=940 W=1450 k=32 g=1 | 1575 us | 1293 us | 1.22x |
| bf16 n=3 c=3 H=940 W=1450 k=32 g=1 | 931 us | 783 us | 1.19x |

Old discussions: #20710
