-
Notifications
You must be signed in to change notification settings - Fork 873
Improve backward data convolution with non-unit strides #23976
Description
Problem
Backward-data convolutions with non-unit strides (stride 2, stride 3) produce a strided tensor.insert_slice into a zero-filled buffer, followed by a convolution that reads the scattered result. On the current main branch this lowers to:
- Memset — zero-fill the output buffer via DMA
- slow_memcpy — copy the source elements at strided positions via DMA
- Conv dispatch — run the convolution on the padded buffer
The slow_memcpy dispatch uses hardware DMA copy engines, which have high bandwidth for moderate tensor sizes (~100–200 MB) but degrade significantly for large tensors (>1 GB). The tensor.insert_slice cannot be fused into the convolution dispatch, so the scatter always runs as a separate dispatch before the convolution. For 1x1 backward convolutions this is particularly wasteful because the scatter commutes with the contraction and could be deferred to after it, operating on a much smaller result tensor.
Current IR after dispatch formation
1x1 backward conv (convfp16 -n 32 -c 2048 -k 1024 -y 1 -x 1 -u 2 -v 2 -g 1):
// IR after FormDispatchRegionsPass — insert_slice stays outside the dispatch
util.func public @...(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x50x50x2048xf16>
%0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<32x25x25x2048xf16>
%1 = iree_tensor_ext.compute_barrier.start %0 : tensor<32x25x25x2048xf16> -> tensor<32x25x25x2048xf16>
%2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<2048x1x1x1024xf16>
// insert_slice becomes Memset + slow_memcpy — not fused with conv dispatch
%inserted_slice = tensor.insert_slice %1 into %cst_0[0, 0, 0, 0] [32, 25, 25, 2048] [1, 2, 2, 1] : tensor<32x25x25x2048xf16> into tensor<32x50x50x2048xf16>
%3 = tensor.empty() : tensor<32x50x50x1024xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<32x50x50x1024xf32>) -> tensor<32x50x50x1024xf32>
%collapsed = tensor.collapse_shape %2 [[0, 1, 2], [3]] : tensor<2048x1x1x1024xf16> into tensor<2048x1024xf16>
%5 = iree_tensor_ext.compute_barrier.start %collapsed : tensor<2048x1024xf16> -> tensor<2048x1024xf16>
%6 = tensor.empty() : tensor<1024x2048xf16>
// Preprocessing: transpose rhs.
%7 = flow.dispatch.region -> (tensor<1024x2048xf16>) {
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<2048x1024xf16>) outs(%6 : tensor<1024x2048xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<1024x2048xf16>
flow.return %14 : tensor<1024x2048xf16>
}
%8 = iree_tensor_ext.compute_barrier.start %7 : tensor<1024x2048xf16> -> tensor<1024x2048xf16>
%9 = tensor.empty() : tensor<32x50x50x1024xf16>
// Contraction on the LARGE scattered tensor (32x50x50 spatial — 4x more work than needed)
%10 = flow.dispatch.region -> (tensor<32x50x50x1024xf16>) {
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%inserted_slice, %8 : tensor<32x50x50x2048xf16>, tensor<1024x2048xf16>) outs(%4 : tensor<32x50x50x1024xf32>) {
^bb0(%in: f16, %in_1: f16, %out: f32):
%16 = arith.extf %in : f16 to f32
%17 = arith.extf %in_1 : f16 to f32
%18 = arith.mulf %16, %17 : f32
%19 = arith.addf %out, %18 : f32
linalg.yield %19 : f32
} -> tensor<32x50x50x1024xf32>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%14 : tensor<32x50x50x1024xf32>) outs(%9 : tensor<32x50x50x1024xf16>) {
^bb0(%in: f32, %out: f16):
%16 = arith.truncf %in : f32 to f16
linalg.yield %16 : f16
} -> tensor<32x50x50x1024xf16>
flow.return %15 : tensor<32x50x50x1024xf16>
}
%11 = iree_tensor_ext.compute_barrier.end %10 : tensor<32x50x50x1024xf16> -> tensor<32x50x50x1024xf16>
%12 = hal.tensor.barrier join(%11 : tensor<32x50x50x1024xf16>) => %arg4 : !hal.fence
%13 = hal.tensor.export %12 : tensor<32x50x50x1024xf16> -> !hal.buffer_view
util.return %13 : !hal.buffer_view
}

3x3 grouped backward conv (convfp16 -n 32 -c 1024 -k 1024 -y 3 -x 3 -u 2 -v 2 -g 32):
// IR after FormDispatchRegionsPass — insert_slice stays outside the dispatch
util.func public @...(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> !hal.buffer_view {
%cst = arith.constant 0.000000e+00 : f32
%c2 = arith.constant 2 : index
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x52x52x32x32xf16>
%0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<32x25x25x1024xf16>
%1 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<1024x3x3x32xf16>
%expanded = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [32, 25, 25, 32, 32] : tensor<32x25x25x1024xf16> into tensor<32x25x25x32x32xf16>
%2 = iree_tensor_ext.compute_barrier.start %expanded : tensor<32x25x25x32x32xf16> -> tensor<32x25x25x32x32xf16>
%expanded_1 = tensor.expand_shape %1 [[0, 1], [2], [3], [4]] output_shape [32, 32, 3, 3, 32] : tensor<1024x3x3x32xf16> into tensor<32x32x3x3x32xf16>
%3 = iree_tensor_ext.compute_barrier.start %expanded_1 : tensor<32x32x3x3x32xf16> -> tensor<32x32x3x3x32xf16>
// insert_slice becomes Memset + slow_memcpy — not fused with conv dispatch
%inserted_slice = tensor.insert_slice %2 into %cst_0[0, 1, 1, 0, 0] [32, 25, 25, 32, 32] [1, 2, 2, 1, 1] : tensor<32x25x25x32x32xf16> into tensor<32x52x52x32x32xf16>
%4 = tensor.empty() : tensor<32x50x50x32x32xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<32x50x50x32x32xf32>) -> tensor<32x50x50x32x32xf32>
%6 = tensor.empty() : tensor<32x32x3x3x32xf16>
// Filter flip dispatch
%7 = flow.dispatch.region -> (tensor<32x32x3x3x32xf16>) {
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%6 : tensor<32x32x3x3x32xf16>) {
^bb0(%out: f16):
%15 = linalg.index 0 : index
%16 = linalg.index 4 : index
%17 = linalg.index 2 : index
%18 = linalg.index 3 : index
%19 = linalg.index 1 : index
%20 = arith.subi %c2, %17 : index
%21 = arith.subi %c2, %18 : index
%extracted = tensor.extract %3[%15, %16, %20, %21, %19] : tensor<32x32x3x3x32xf16>
linalg.yield %extracted : f16
} -> tensor<32x32x3x3x32xf16>
flow.return %14 : tensor<32x32x3x3x32xf16>
}
%8 = iree_tensor_ext.compute_barrier.start %7 : tensor<32x32x3x3x32xf16> -> tensor<32x32x3x3x32xf16>
%9 = tensor.empty() : tensor<32x50x50x32x32xf16>
// Grouped convolution on the LARGE scattered tensor (32x52x52 spatial)
%10 = flow.dispatch.region -> (tensor<32x50x50x32x32xf16>) {
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice, %8 : tensor<32x52x52x32x32xf16>, tensor<32x32x3x3x32xf16>) outs(%5 : tensor<32x50x50x32x32xf32>) {
^bb0(%in: f16, %in_2: f16, %out: f32):
%16 = arith.extf %in : f16 to f32
%17 = arith.extf %in_2 : f16 to f32
%18 = arith.mulf %16, %17 : f32
%19 = arith.addf %out, %18 : f32
linalg.yield %19 : f32
} -> tensor<32x50x50x32x32xf32>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%14 : tensor<32x50x50x32x32xf32>) outs(%9 : tensor<32x50x50x32x32xf16>) {
^bb0(%in: f32, %out: f16):
%16 = arith.truncf %in : f32 to f16
linalg.yield %16 : f16
} -> tensor<32x50x50x32x32xf16>
flow.return %15 : tensor<32x50x50x32x32xf16>
}
%11 = iree_tensor_ext.compute_barrier.end %10 : tensor<32x50x50x32x32xf16> -> tensor<32x50x50x32x32xf16>
%collapsed = tensor.collapse_shape %11 [[0], [1], [2], [3, 4]] : tensor<32x50x50x32x32xf16> into tensor<32x50x50x1024xf16>
%12 = hal.tensor.barrier join(%collapsed : tensor<32x50x50x1024xf16>) => %arg4 : !hal.fence
%13 = hal.tensor.export %12 : tensor<32x50x50x1024xf16> -> !hal.buffer_view
util.return %13 : !hal.buffer_view
}

Proposed improvements
1x1 backward convolutions: swap scatter with contraction
For 1x1 backward convolutions, the strided scatter commutes with the contraction because the reduction dimensions (kernel H, W) have loop bound 1. The key insight is that the insert_slice + contraction can be reordered: compute the contraction first on the small (un-scattered) source, then scatter the smaller result afterward.
This avoids running the contraction on the large stride-padded tensor (e.g., 32x50x50 instead of 32x25x25), reducing FLOPs by ~4x for stride 2 in 2 spatial dims. The post-contraction scatter is a simple data movement on the smaller result tensor, which is much cheaper than scattering the full input before the contraction.
Expected IR after preprocessing:
// Contraction on un-scattered source (32x25x25 spatial — 4x smaller than 32x50x50)
%small_result = linalg.generic {
indexing_maps = [
affine_map<(d0,d1,d2,d3,d4,d5,d6) -> (d0, d1+d5, d2+d6, d4)>, // src
affine_map<(d0,d1,d2,d3,d4,d5,d6) -> (d4, d5, d6, d3)>, // filter
affine_map<(d0,d1,d2,d3,d4,d5,d6) -> (d0, d1, d2, d3)> // result
],
iterator_types = ["parallel","parallel","parallel","parallel","reduction","reduction","reduction"]
} ins(%src, %filter : tensor<32x25x25x2048xf16>, tensor<2048x1x1x1024xf16>)
outs(%fill : tensor<32x25x25x1024xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
%0 = arith.extf %in : f16 to f32
%1 = arith.extf %in_0 : f16 to f32
%2 = arith.mulf %0, %1 : f32
%3 = arith.addf %out, %2 : f32
linalg.yield %3 : f32
} -> tensor<32x25x25x1024xf32>
// Truncate
%small_trunced = linalg.generic { ... truncf ... }
ins(%small_result : tensor<32x25x25x1024xf32>) -> tensor<32x25x25x1024xf16>
// Scatter the small result — data movement only, on a 4x smaller tensor
%output = tensor.insert_slice %small_trunced into %zeros[0,0,0,0] [32,25,25,1024] [1,2,2,1]
  : tensor<32x25x25x1024xf16> into tensor<32x50x50x1024xf16>

3x3 backward convolutions: convert strided insert_slice to linalg.generic
For 3x3 (and other non-1x1) backward convolutions, the scatter cannot be swapped with the contraction because the reduction dimensions (kernel H, W) have loop bound > 1. Instead, we can replace the tensor.insert_slice with a linalg.generic that computes the strided scatter in a single pass using index arithmetic.
For each output position, the generic checks whether the position maps to a valid source element: (pos - offset) must be non-negative, divisible by stride, and the quotient must be in-bounds. Power-of-2 strides use bitwise ops (and/shift) instead of expensive div/mod. The result is selected via arith.select for branchless GPU execution.
This replaces the Memset + slow_memcpy dispatch pair with a single compute dispatch. For large tensors where slow_memcpy bandwidth degrades, the compute-based scatter is faster.
Expected IR after preprocessing (3x3 grouped conv, g=32, cpg=32):
// Scatter generic — replaces Memset + slow_memcpy
%scatter = linalg.generic {
indexing_maps = [affine_map<(d0,d1,d2,d3,d4) -> (d0,d1,d2,d3,d4)>],
iterator_types = ["parallel","parallel","parallel","parallel","parallel"]
} outs(%empty : tensor<32x52x52x32x32xf16>) {
^bb0(%out: f16):
%idx_h = linalg.index 1
%shifted_h = arith.subi %idx_h, %c1 // offset = 1
%rem_h = arith.andi %shifted_h, %c1_mask // stride 2 is pow2: use bitwise
%src_h = arith.shrsi %shifted_h, %c1_shift
%ge_zero_h = arith.cmpi sge, %shifted_h, %c0
%rem_zero_h = arith.cmpi eq, %rem_h, %c0
%lt_size_h = arith.cmpi slt, %src_h, %c25
%valid_h = arith.andi %ge_zero_h, %rem_zero_h
%valid_h2 = arith.andi %valid_h, %lt_size_h
// ... same for dim 2 (W) ...
%valid = arith.andi %valid_h2, %valid_w2
%clamped_h = arith.maxsi %src_h, %c0
%clamped_h2 = arith.minsi %clamped_h, %c24 // clamp for safe extract
// ...
%extracted = tensor.extract %src[%idx_n, %clamped_h2, %clamped_w2, %idx_g, %idx_c]
%result = arith.select %valid, %extracted, %zero
linalg.yield %result : f16
} -> tensor<32x52x52x32x32xf16>
// Filter flip (reverse kernel spatial dims)
%flipped = linalg.generic { ... } ins(%filter) -> tensor<32x32x3x3x32xf16>
// Grouped convolution reads the scattered buffer
%conv = linalg.generic {
indexing_maps = [
affine_map<(d0,d1,d2,d3,d4,d5,d6,d7) -> (d0, d1+d5, d2+d6, d3, d7)>,
affine_map<(d0,d1,d2,d3,d4,d5,d6,d7) -> (d3, d4, d5, d6, d7)>,
affine_map<(d0,d1,d2,d3,d4,d5,d6,d7) -> (d0, d1, d2, d3, d4)>
],
iterator_types = ["parallel","parallel","parallel","parallel","parallel",
"reduction","reduction","reduction"]
} ins(%scatter, %flipped : tensor<32x52x52x32x32xf16>, tensor<32x32x3x3x32xf16>)
outs(%fill : tensor<32x50x50x32x32xf32>) { ... }
-> tensor<32x50x50x32x32xf32>
// Truncate + collapse group dims
%trunced = linalg.generic { ... truncf ... } -> tensor<32x50x50x32x32xf16>
%result = tensor.collapse_shape %trunced [[0],[1],[2],[3,4]]
  : tensor<32x50x50x32x32xf16> into tensor<32x50x50x1024xf16>

Some benchmark results on Mi355x
1x1 backward convolutions see the largest gains (up to 6.8x) from the scatter-contraction swap:
| Shape | Baseline | Improved | Speedup |
|---|---|---|---|
| bf16 n=4 c=32 H=470 W=725 k=224 g=1 | 469 us | 69 us | 6.82x |
| bf16 n=12 c=896 H=59 W=91 k=2016 g=1 | 650 us | 193 us | 3.37x |
| fp16 n=32 c=1024 H=50 W=50 k=2048 g=1 | 648 us | 220 us | 2.94x |
| fp16 n=32 c=512 H=100 W=100 k=1024 g=1 | 831 us | 292 us | 2.85x |
| bf16 n=10 c=448 H=118 W=182 k=896 g=1 | 502 us | 179 us | 2.81x |
3x3 backward convolutions see gains from the compute-based scatter replacing slow_memcpy.
| Shape | Baseline | Improved | Speedup |
|---|---|---|---|
| fp16 n=32 c=256 H=200 W=200 k=256 g=32 | 4580 us | 2580 us | 1.77x |
| bf16 n=1 c=3 H=940 W=1450 k=32 g=1 | 341 us | 274 us | 1.25x |
| bf16 n=5 c=3 H=940 W=1450 k=32 g=1 | 1575 us | 1293 us | 1.22x |
| bf16 n=3 c=3 H=940 W=1450 k=32 g=1 | 931 us | 783 us | 1.19x |
Old discussions: #20710