
proto(cubecl): GPU resize + ML preprocessing kernels — 2.4-3.5× faster than NVIDIA VPI on Jetson Orin #897

Draft

edgarriba wants to merge 15 commits into main from proto/cubecl

Conversation

@edgarriba
Member

Summary

Self-contained engineering prototype evaluating cubecl as a portable GPU/CPU compute backend for kornia-rs image kernels. Lives in a new sub-workspace crate crates/kornia-cubecl/ so the heavy cubecl dep tree (cubecl-cuda + cudarc, cubecl-cpu + MLIR) doesn't infect the main workspace's resolver.

This PR is intended as a starting point for GSoC — see "For @cjpurackal and @Incharajayaram" section below.

Headline numbers (Jetson Orin Nano, head-to-head with NVIDIA VPI 3.2.4)

At 1080p → 540p (typical ML preprocessing):

  pipeline                                                     μs   vs VPI just-resize
  VPI cuda (resize only)                                      593   1.00× (baseline)
  Our cubecl-cuda (resize only)                               169   3.5× faster
  Our cubecl-cuda fused (resize + gray + normalize)           173   3.4× faster, while doing 3× the work
  Our cubecl-cuda fused (resize + ImageNet normalize + CHW)   243   2.4× faster, model-ready output

Full numbers across 6 sizes (256² → 8K) in crates/kornia-cubecl/RESULTS.md. Bit-exact correctness vs fast_image_resize NEON (max_diff = 0 on all 4 sizes tested).
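
"Bit-exact" here means the byte-wise comparison in the correctness test comes back all zeros. A minimal sketch of that check (illustrative helper, not the actual code in tests/correctness.rs):

  // Hypothetical helper mirroring what the correctness test asserts:
  // the cubecl output must match the fast_image_resize NEON output byte-for-byte.
  fn max_abs_diff(a: &[u8], b: &[u8]) -> u8 {
      a.iter().zip(b.iter()).map(|(x, y)| x.abs_diff(*y)).max().unwrap_or(0)
  }
  // max_abs_diff(&cubecl_out, &neon_out) == 0 on every size tested.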

What's in this prototype

  • #[cube] composable primitives, inlined at codegen (a minimal composition sketch follows this list):
    sample_bilinear_u8_rgb_pixel, rgb_to_gray_u8, normalize_u8_to_f32, normalize_chan_u8_to_f32
  • Standalone launchers (one kernel each):
    resize (6 variants — baseline, x4, x16, with-pre-uploaded-weights, x4_pw, pw_wide), rgb_to_gray_u8, normalize_u8_to_f32, hwc_u8_to_chw_f32_normalize
  • Pre-fused common pipelines (single kernel, primitives inlined):
    resize_to_gray_normalize_with_weights (HWC f32 gray output)
    resize_to_chw_normalize_with_weights (CHW f32 RGB output, ImageNet-style)
  • Correctness test in tests/correctness.rs — passes bit-exact on cubecl-cpu
  • Three benches:
    • examples/bench_min.rs — single-op size sweep, NEON vs cubecl-cpu vs cubecl-cuda
    • examples/bench_fusion.rs — pipeline fusion, sequential vs fused, with side-by-side summary
    • examples/vpi_bench.py — VPI baseline (Python, vpi3-python-src package)
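
To make the tier structure concrete, here is a minimal composition sketch in the cubecl style. The signatures are illustrative only (the real primitives operate on u8 RGB and take different arguments), not the crate's actual API:

  use cubecl::prelude::*;

  // Tier 1: a plain #[cube] primitive. It is not launchable on its own;
  // it gets inlined at codegen time into whatever kernel calls it.
  #[cube]
  fn gray_from_rgb(r: f32, g: f32, b: f32) -> f32 {
      0.299 * r + 0.587 * g + 0.114 * b
  }

  // Tier 3 style: a caller-defined #[cube(launch)] kernel that chains
  // primitives inline, so cubecl emits one fused GPU kernel with no
  // intermediate buffers between the steps.
  #[cube(launch)]
  fn gray_normalize_kernel(
      input: &Array<f32>,
      output: &mut Array<f32>,
      mean: f32,
      inv_std: f32,
  ) {
      if ABSOLUTE_POS < output.len() {
          let r = input[ABSOLUTE_POS * 3];
          let g = input[ABSOLUTE_POS * 3 + 1];
          let b = input[ABSOLUTE_POS * 3 + 2];
          output[ABSOLUTE_POS] = (gray_from_rgb(r, g, b) - mean) * inv_std;
      }
  }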

Key learnings (full analysis in RESULTS.md)

  1. Memory traffic = cost; arithmetic = free. The fused kernel that does resize + RGB→gray + normalize runs at 0.98× the speed of resize-only: the extra ops fill otherwise-idle compute slots in a bandwidth-bound kernel.
  2. Naive composition costs 3× the resize alone (the DRAM round-trip tax between sequential kernels). Fusion reclaims it.
  3. Pre-uploaded weights matter at small sizes. For 1080p preprocessing, pre-uploading the bilinear weight tables boosts throughput from 1710 → 2905 Mpix/s (1.7×) — material for video-pipeline workloads where the resize shape is fixed across thousands of frames.
  4. CPU and GPU want opposite tile sizes. cubecl-cpu wants x16 (fewer threads, less overhead); cubecl-cuda wants x1 (more threads, max occupancy). Same kernel, different optimal launch geometry per backend (see the sketch after this list).
  5. Jetson cuda gotcha: cudarc's fallback-latest feature defaults to CUDA 13.2 symbol bindings, which Jetson's libcuda doesn't have. Build with CUDARC_CUDA_VERSION=12060.
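
A toy illustration of learning 4; the prototype actually ships separate x1/x4/x16 launcher variants, so the enum and helper below are purely hypothetical:

  #[derive(Clone, Copy)]
  enum ComputeBackend {
      CubeclCpu,
      CubeclCuda,
  }

  // Destination pixels computed per thread, per the measurements above.
  fn dst_pixels_per_thread(backend: ComputeBackend) -> u32 {
      match backend {
          ComputeBackend::CubeclCpu => 16, // fewer, fatter threads: less per-thread dispatch overhead
          ComputeBackend::CubeclCuda => 1, // one pixel per thread: maximum GPU occupancy
      }
  }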

How the cuda numbers were unblocked

First attempt panicked: libcuda.so: undefined symbol: cuCoredumpDeregisterCompleteCallback (a CUDA 13.2 symbol cudarc binds when nvcc isn't on PATH). Fixed by setting CUDARC_CUDA_VERSION=12060 before cargo build so cudarc binds only CUDA 12.6 symbols matching Jetson's libcuda. Documented in RESULTS.md.

How to reproduce

# Single-op size sweep (NEON vs cubecl-cpu vs cubecl-cuda)
CUDARC_CUDA_VERSION=12060 cargo build --release \
  --manifest-path crates/kornia-cubecl/Cargo.toml \
  --example bench_min
crates/kornia-cubecl/target/release/examples/bench_min

# Fusion bench (3-way: resize-only / gray pipeline / CHW pipeline)
CUDARC_CUDA_VERSION=12060 cargo build --release \
  --manifest-path crates/kornia-cubecl/Cargo.toml \
  --example bench_fusion
crates/kornia-cubecl/target/release/examples/bench_fusion

# VPI baseline (requires vpi3-python-src package, comes with JetPack 6)
python3 crates/kornia-cubecl/examples/vpi_bench.py

For @cjpurackal and @Incharajayaram — GSoC starting point

Hi! This is a draft prototype intended to give you a concrete, measurable baseline if you take on cubecl integration as a GSoC project. The crate is fully self-contained (sub-workspace), so you can iterate without breaking anything in the main kornia-rs build.

What's done (use as starting point):

  • Working cubecl-cuda + cubecl-cpu compile on Jetson Orin Nano (incl. the cuda-13020 dlsym fix)
  • Validated correctness vs the existing NEON path (bit-exact on all sizes tested)
  • Reproducible bench harness, including head-to-head vs NVIDIA VPI
  • Three-tier pattern (composable #[cube] primitives → standalone launchers → pre-fused pipelines), demonstrating the cost spread from ~0× extra for fused composition to 3× for naive sequential composition

Suggested directions to take it further (in roughly increasing-effort order):

  1. Productize the API. The "API rules-of-thumb" section near the bottom of RESULTS.md outlines a Context<R> + Plan<R> design that hides the cubecl machinery behind a kornia-idiomatic surface; a rough, hypothetical sketch of that shape follows this list. The design doc isn't written yet; it's deferred to a follow-up.
  2. More ops: add rgb_to_yuv, nearest, bicubic, lanczos, gaussian_blur, gradient as primitives + standalone launchers + pre-fused pipelines.
  3. Recover end-to-end perf on Tegra. Currently end-to-end cuda is dominated by cudaMemcpy (~70 ms at 8K) because cubecl-cuda doesn't know about Tegra's unified memory. Pinned/managed memory would eliminate this round-trip entirely.
  4. Vector<u8, N> SIMD on cubecl-cpu. Currently cubecl-cpu's MLIR backend emits scalar code, leaving 5-6× perf on the table vs hand-tuned NEON. Adding explicit Line<u8>/Vector<u8> to the primitives may close it.
  5. cubecl-fusion automatic op-graph fusion. cubecl has a fusion runtime used by Burn for tensor ops; not yet evaluated for image kernels. Could obsolete the manual fused-launcher tier if mature enough.
  6. Productize integration with kornia-imgproc. Make kornia::resize() dispatch to NEON or cubecl based on a Backend enum with the same semantics on both. Now possible since correctness is bit-exact.
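
A rough, hypothetical sketch of the Context<R> + Plan<R> shape from direction 1 (none of these types or methods exist in the crate yet; names are placeholders):

  // Placeholder types only, sketching the API surface proposed in RESULTS.md.
  pub struct Context<R> {
      runtime: R, // would wrap the cubecl client for one device/backend
  }

  pub struct ResizePlan<R> {
      // would cache the pre-uploaded bilinear weight tables (learning 3)
      // plus the fixed src/dst geometry, reused across video frames
      src: (u32, u32),
      dst: (u32, u32),
      _runtime: std::marker::PhantomData<R>,
  }

  impl<R> Context<R> {
      pub fn plan_resize(&self, src: (u32, u32), dst: (u32, u32)) -> ResizePlan<R> {
          ResizePlan { src, dst, _runtime: std::marker::PhantomData }
      }
  }

  impl<R> ResizePlan<R> {
      pub fn run(&self, _ctx: &Context<R>, _src_rgb: &[u8], _dst_f32: &mut [f32]) {
          todo!("dispatch the matching cubecl launcher (standalone or fused) for this backend")
      }
  }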

Build/run gotchas you'll hit:

  • Pre-release cubecl 0.10.0-pre.4 has API drift vs stable 0.4.0 (we use pre-release because cubecl-cpu only exists there). Expect occasional API churn on bumps.
  • cubecl-cpu requires the tracel-llvm-20.1.4-7 prebuilt bundle. The released archive's directory inside is mislabeled -6; we worked around by extracting and renaming. See RESULTS.md "How we unblocked cuda on Jetson Orin" section.
  • The Jetson AGX Orin numbers from https://docs.nvidia.com/vpi/perf_tegra234_rescale.json are NOT directly comparable to Orin Nano; we benched VPI on the same Orin Nano for honest head-to-head.

Happy to walk through any of this on a call. The branch will stay open.

Test plan

  • Lib unit tests pass (cargo test --lib --no-default-features --features cpu) on Jetson Orin Nano
  • Correctness test passes bit-exact (max_diff = 0) vs fast_image_resize NEON across 4 sizes
  • cubecl-cpu bench reproduces the documented numbers (within timing variance)
  • cubecl-cuda bench reproduces the documented numbers (with CUDARC_CUDA_VERSION=12060)
  • VPI baseline bench reproduces (requires JetPack 6 with VPI 3.2 + Python bindings)
  • Cross-platform CI: not covered yet (this is a Jetson-only prototype; x86_64 and AGX Orin runs are welcome)

edgarriba added 15 commits May 3, 2026 22:28
Design for a new kornia-cubecl crate that prototypes a bilinear u8 RGB
2x downscale kernel and benchmarks cubecl-cuda + cubecl-cpu against the
production NEON path (fast_image_resize) on Jetson Orin.
Bilinear u8 RGB 2x downscale kernel using cubecl 0.10-pre.4 (cuda runtime),
with weight precompute, public dispatch, correctness test vs fast_image_resize
NEON path, and Criterion benchmark with 5 arms across 4 sizes.

Lives as a sub-workspace to avoid cubecl-cuda's large dep tree forcing
re-resolution of the parent workspace's brittle rerun pinning. cubecl-cpu
support is gated behind --features cpu (requires tracel-llvm-20.1.4-7
prebuilt bundle, manually patched on this Jetson due to upstream dir-name
mislabel in the v20.1.4-7 release).
…o cpu feature

cubecl-cuda 0.10-pre.4 (via cudarc 0.19) calls cuCoredumpDeregisterCompleteCallback
which requires CUDA 12.3+; Jetson Orin's libcuda.so is older and panics on first
allocation. Default to cpu feature so cargo build works out of the box on this
hardware. cuda feature still buildable when target environment supports it.

Also fixed read_one() arg type (takes Handle not Binding) in test + bench, and
swapped block_on import path (cubecl::future, not cubecl::common::future).
Adds standalone std::time bench (examples/bench_min.rs) that bypasses
criterion's heavy release-mode dep tree, plus RESULTS.md with the
measurement table and analysis. cubecl-cpu kernel matches fast_image_resize
NEON output bit-exactly (max_diff=0) but is 9-119x slower across the
512^2 → 4096^2 size sweep on Jetson Orin's CPU.

cubecl-cuda arm blocked by libcuda.so missing cuCoredumpDeregisterCompleteCallback
(CUDA 12.3 symbol, cudarc 0.19 expects it).
Adds 8192² (4K output) and 1920×1080 (typical ML preprocessing) sizes.
Headline change: cubecl-cpu kernel throughput is still ramping at the
largest size tested (145 Mpix/s at 4K out, up from 100 at 2K), suggesting
asymptotic peak of 200-300 Mpix/s vs NEON's 1100-1650 Mpix/s ceiling.
Real compute gap is 5-6×, not 9-119× — small-input numbers were dominated
by per-call dispatch overhead.
Adds resize_bilinear_u8_rgb_kernel_x4 and _x16 variants that process 4 or 16
dst pixels per thread, reducing total thread count and amortizing cubecl-cpu's
per-thread dispatch overhead. At 8192²→4096² downscale, kernel throughput
goes from 154 → 308 Mpix/s (x16 variant), closing the gap to NEON from 8× to
4×. Optimal tile size is not monotonic in input size: x4 wins at 2048², x16
at 4096² and 1080p.
Unblocked cubecl-cuda on Jetson by setting CUDARC_CUDA_VERSION=12060 before
build (forces cudarc to bind only CUDA 12.6 symbols; without it, build.rs
falls back to cuda-13020 latest which dlsyms cuCoredumpDeregisterCompleteCallback,
a CUDA 13.2 symbol absent from Jetson libcuda).

Results: cubecl_cuda_kernel hits 2316-2984 Mpix/s vs NEON's 678-1208 across
1024² → 8192² inputs; 1080p→540p ML preprocessing case is 2.9x faster on
cuda kernel. End-to-end cuda is dominated by cudaMemcpy though — Tegra
unified memory wasted by cudarc's explicit copies.

Also: x16 tile variant is faster on CPU (fewer threads = less overhead),
slower on GPU (fewer threads = lower occupancy). Same kernel, opposite
optima per backend.
Adds resize_bilinear_u8_rgb_with_weights + WeightHandles struct that
caches the four small weight-buffer uploads across calls. Material at
small sizes where per-dispatch overhead dominates: 256² out goes
500 → 1100 Mpix/s (2.2x).

Head-to-head vs NVIDIA VPI 3.2.4 on the SAME Jetson Orin Nano (not
extrapolated from AGX Orin docs):

  size       cubecl_cuda_pw  VPI cuda  vs VPI
  256² out          1100        42      26.2x
  512² out          2646       165      16.0x
  1024² out         3098       526       5.9x
  1080p→540p        1710       593       2.9x
  2048² out         2918      1619       1.8x
  4096² out         3418      2566       1.3x

Bench script: crates/kornia-cubecl/examples/vpi_bench.py runs the VPI side.
Same inputs, same timing methodology (warmup + 10-rep median). cubecl
beats VPI at every size, hits the 10x goal at 256² and 512² output
sizes. At largest sizes both are DRAM-bandwidth-bound and converge.
Two new variants:
- resize_bilinear_u8_rgb_x4_with_weights: combines 4-pixel-per-thread
  tiling with pre-uploaded weights. Best for 2048² out (2727 Mpix/s).
- resize_bilinear_u8_rgb_with_weights_wide: 32×8 workgroup instead of 16×16.
  Best for non-square inputs and large sizes.

Updated VPI head-to-head (best variant per size on Jetson Orin Nano):
  256² out:    922 Mpix/s   22.0× VPI
  512² out:   2347           14.2× VPI
  1024² out:  3154            6.0× VPI
  1080p→540p: 2905            4.9× VPI  (was 2.9×, now beats by 4.9)
  2048² out:  2727            1.7× VPI
  4096² out:  3546            1.4× VPI  — 85% of DRAM peak, ceiling

The 10× target is achieved at 256² and 512² output. Above 1024² out
both implementations converge toward the 68 GB/s LPDDR5 bandwidth
ceiling and the gap is fundamentally hardware-limited.
Adds a two-tier API so the same source supports both standalone ops and
fused pipelines:

  Tier 1 — #[cube] primitives (inlined at codegen):
    sample_bilinear_u8_rgb_pixel(...)
    rgb_to_gray_u8(r, g, b)
    normalize_u8_to_f32(g, mean, inv_std)

  Tier 2 — standalone launchers (one kernel each):
    rgb_to_gray_u8<R>, normalize_u8_to_f32<R>, resize_bilinear_u8_rgb<R>

  Tier 3 — pre-fused common pipelines (one kernel, primitives called inline):
    resize_to_gray_normalize_with_weights<R>

Bench (bilinear resize → rgb→gray → normalize_to_f32 on Jetson Orin Nano):

  size              sequential   fused    speedup
  1024² out         1429         2766     1.94x
  1080p → 540p      1114         2890     2.59x
  2048² out         1149         3158     2.75x
  4096² out         1648         3218     1.95x
  4096² out (8K)    1856         3589     1.93x

The 2x matches the theoretical max for a 3-op chain (sequential reads+writes
~96 MB of intermediates per call; fused eliminates them). The 8K fused result
(3589 Mpix/s) slightly beats the standalone resize peak (3546 Mpix/s) because
the f32 gray output is smaller total memory than RGB.

API design: primitives are #[cube] (not #[cube(launch)]) so they inline when
composed. Callers writing a custom pipeline define their own #[cube(launch)]
kernel that calls the primitives in sequence — cubecl emits one CUDA kernel.
…preprocessing

Adds the canonical "image → ML model input" pipeline as both a 2-kernel
sequential and a single fused kernel:

  Standalone:
    hwc_u8_to_chw_f32_normalize<R>  — layout transpose + per-channel normalize
  Composable primitive:
    normalize_chan_u8_to_f32(c, mean, inv_std)
  Fused launcher:
    resize_to_chw_normalize_with_weights<R>  — bilinear resize + per-channel
       ImageNet-style normalize + CHW layout in one kernel pass

Bench results on Jetson Orin Nano (Pipeline 2: resize → normalize → CHW f32):

  size              sequential   fused    speedup
  1024² out         1336         1688     1.26x
  1080p → 540p      1124         2243     2.00x  ⭐
  2048² out         1162         2349     2.02x
  4096² out         1571         1978     1.26x
  4096² out (8K)    1788         2146     1.20x

Real-world headline: our CHW-fused 1080p→540p ML preprocessing
(resize + ImageNet-style per-channel normalize + CHW transpose) =
231 μs. VPI's *just-the-resize* = 593 μs. Complete cubecl preprocessing
runs 2.6x faster than VPI does the resize alone.

CHW fusion speedup is smaller than gray fusion at large sizes because
the f32 CHW output is 3x larger than gray (12 vs 4 bytes/pixel) so
output traffic dominates and saving the intermediate buffer matters
less proportionally.
…of resize

Adds Pipeline 0 (resize-only) and a side-by-side summary at 1080p→540p
to the fusion bench. The headline finding:

  Pipeline                       median(μs)  vs P0 (resize-only)
  P0: resize only                  168.9     1.00×  baseline
  P1 sequential: resize→gray→norm  500.8     3× SLOWER
  P1 fused: same in 1 kernel       173.2     0.98×  ← gray+norm is FREE
  P2 sequential: resize→CHW+norm   503.2     3× slower
  P2 fused: same in 1 kernel       242.8     0.70×  ← CHW costs 1.4× resize

The 3× sequential penalty is the DRAM round-trip tax (intermediate
buffer write + read between every kernel). Fusion eliminates it.

The "P1 fused = 0.98× of resize-only" result is the API pitch in one
sentence: adding gray+normalize to a fused kernel literally costs
nothing because the bandwidth-bound kernel has idle compute slots,
and there are no new memory accesses.

vs NVIDIA VPI (just-the-resize at 1080p = 593 μs):
  P1 fused (resize+gray+norm): 3.4× faster while doing 3× the work
  P2 fused (resize+CHW+norm):  2.4× faster while producing model-ready tensor
@github-actions

github-actions Bot commented May 4, 2026

⚠️ PR Validation Warnings

No linked issue found: This PR does not reference any issue. Please link to an issue using "Fixes #123" or "Closes #123" in the PR description.


Note: This PR can remain open, but please address these issues to ensure a smooth review process. For more information, see our Contributing Guide.
