Skip to content

feat(ad): replace CRC extension with native Enzyme and Mooncake rules#83

Merged
kylebeggs merged 11 commits intomainfrom
mooncake-native-rrule
Feb 17, 2026
Merged

feat(ad): replace CRC extension with native Enzyme and Mooncake rules#83
kylebeggs merged 11 commits intomainfrom
mooncake-native-rrule

Conversation

@kylebeggs
Copy link
Member

Summary

  • Replaced the ChainRulesCore-based AD extension with native EnzymeRules and Mooncake rrule!! implementations
  • Extracted shared backward pass logic into src/solve/backward.jl and src/solve/ad_shared.jl, eliminating duplication between extensions
  • Added shape parameter (ε) differentiation support for IMQ and Gaussian basis functions
  • Consolidated extension code and deduplicated AD tests with shared utilities (test/extensions/ad_test_utils.jl)
  • Removed the entire RadialBasisFunctionsChainRulesCoreExt extension (~960 lines deleted)

13 commits, 25 files changed, +2280 / -1269 lines. All tests green.

Test plan

  • All existing tests pass (Pkg.test())
  • Enzyme extension tests cover operators, interpolation, basis derivatives, and shape parameter gradients
  • Mooncake extension tests cover the same surface area with rrule!!-based differentiation
  • Shared backward pass tested through both extension paths

kylebeggs and others added 8 commits February 2, 2026 14:04
…ights

Mooncake's @from_rrule bridge fails with Vector{Vector{Float64}} tangents
from ChainRulesCore rrules. This commit replaces the bridged rules with
native Mooncake rrule!! implementations.

Key changes:
- Add @is_primitive declarations for all _build_weights signatures
- Implement native rrule!! for Laplacian/Partial with PHS/IMQ/Gaussian
- Use zero_fcodual() for proper FData output type
- Read gradients from FData.data.nzval (not RData) in pullback

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Move backward pass infrastructure (caches, shape parameter derivatives,
operator second derivatives) from extensions into the main package so
both Enzyme and Mooncake extensions can share it. Remove ChainRulesCore
extension entirely — AD is now provided via native EnzymeRules and
Mooncake rrule!! implementations.
Replace CRC-bridged rules with direct EnzymeRules for basis functions,
operators, interpolators, and _build_weights. Add shape parameter (ε)
differentiation tests for IMQ and Gaussian bases.
Replace @from_rrule CRC bridge with native Mooncake rrule!!
implementations for all AD entry points. Add shape parameter (ε)
differentiation tests for IMQ and Gaussian bases.
Replace Enzyme basis macro with single AbstractRadialBasis rule, unify
_build_weights rules via Union{Partial, Laplacian} and shared pullback
core, consolidate Mooncake @is_primitive declarations, and extract
shared test utilities to eliminate ~520 lines of duplication.
IMQ/Gaussian have active ε field, so NoRData() causes increment!!
type mismatch when basis is captured in a closure.
Keep our branch's native Enzyme/Mooncake rules, delete CRC extension
files and tests from main's enzyme PR (#76). Retain DI tests and
docs additions from main.
@github-actions
Copy link
Contributor

github-actions bot commented Feb 15, 2026

Benchmark Results

main 305b61f... main / 305b61f...
Directional 2.41 ± 0.11 ms 2.54 ± 0.094 ms 0.951 ± 0.055
Directional (per point) 2.4 ± 0.11 ms 2.49 ± 0.1 ms 0.963 ± 0.058
Gradient 8.44 ± 0.32 ms 8.54 ± 0.31 ms 0.987 ± 0.052
MonomialBasis/dim=1/deg=0 0.0473 ± 0.014 μs 0.0469 ± 0.013 μs 1.01 ± 0.42
MonomialBasis/dim=1/deg=1 0.0772 ± 0.022 μs 0.0816 ± 0.022 μs 0.946 ± 0.37
MonomialBasis/dim=1/deg=2 0.0873 ± 0.021 μs 0.0694 ± 0.021 μs 1.26 ± 0.5
MonomialBasis/dim=2/deg=0 0.035 ± 0.013 μs 25.1 ± 14 ns 1.4 ± 0.94
MonomialBasis/dim=2/deg=1 25.6 ± 15 ns 29.3 ± 14 ns 0.872 ± 0.64
MonomialBasis/dim=2/deg=2 0.0404 ± 0.014 μs 0.0462 ± 0.013 μs 0.874 ± 0.39
MonomialBasis/dim=3/deg=0 0.0358 ± 0.015 μs 0.0411 ± 0.014 μs 0.869 ± 0.46
MonomialBasis/dim=3/deg=1 0.0405 ± 0.014 μs 0.0467 ± 0.014 μs 0.866 ± 0.39
MonomialBasis/dim=3/deg=2 0.0484 ± 0.014 μs 0.0487 ± 0.014 μs 0.994 ± 0.41
Partial 2.41 ± 0.12 ms 2.48 ± 0.1 ms 0.971 ± 0.064
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∂ 9.68 ± 0.08 ns 9.68 ± 0.08 ns 1 ± 0.012
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∂² 10.2 ± 0.17 ns 10.1 ± 0.09 ns 1.01 ± 0.019
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∇ 17.1 ± 0.1 ns 17 ± 0.061 ns 1.01 ± 0.0069
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∇² 18.5 ± 0.08 ns 18.5 ± 0.17 ns 0.996 ± 0.01
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∂ 9.68 ± 0.061 ns 9.68 ± 0.08 ns 1 ± 0.01
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∂² 10.1 ± 0.2 ns 10.2 ± 0.11 ns 0.993 ± 0.022
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∇ 17.1 ± 0.14 ns 17 ± 0.061 ns 1 ± 0.009
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∇² 18.5 ± 0.19 ns 18.6 ± 0.2 ns 0.998 ± 0.015
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∂ 9.68 ± 0.08 ns 9.67 ± 0.09 ns 1 ± 0.012
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∂² 10.1 ± 0.18 ns 10.1 ± 0.091 ns 0.999 ± 0.02
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∇ 17.1 ± 0.13 ns 17 ± 0.061 ns 1 ± 0.0085
RBF/Gaussian, exp(-(ε*r)²)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∇² 18.5 ± 0.16 ns 18.5 ± 0.09 ns 1 ± 0.01
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∂ 6.32 ± 0.001 ns 6.27 ± 0.06 ns 1.01 ± 0.0096
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∂² 14.2 ± 0.021 ns 14.2 ± 0.09 ns 0.999 ± 0.0065
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∇ 8.56 ± 0.1 ns 8.6 ± 0.08 ns 0.995 ± 0.015
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 0/0/∇² 16 ± 0.09 ns 15.8 ± 0.08 ns 1.01 ± 0.0077
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∂ 6.32 ± 0.01 ns 6.27 ± 0.07 ns 1.01 ± 0.011
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∂² 14.2 ± 0.021 ns 14.2 ± 0.02 ns 1 ± 0.002
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∇ 8.55 ± 0.11 ns 8.6 ± 0.09 ns 0.994 ± 0.016
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 1/1/∇² 16 ± 0.079 ns 15.8 ± 0.07 ns 1.01 ± 0.0067
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∂ 6.32 ± 0.011 ns 6.27 ± 0.069 ns 1.01 ± 0.011
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∂² 14.2 ± 0.08 ns 14.2 ± 0.02 ns 1 ± 0.0058
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∇ 8.56 ± 0.1 ns 8.67 ± 0.15 ns 0.988 ± 0.021
RBF/Inverse Multiquadrics, 1/sqrt((r*ε)²+1)
├─Shape factor: ε = 1
└─Polynomial augmentation: degree 2/2/∇² 16 ± 0.1 ns 15.8 ± 0.07 ns 1.01 ± 0.0078
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 0/0/∂ 3.4 ± 0.039 ns 3.42 ± 0.001 ns 0.994 ± 0.011
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 0/0/∂² 4.7 ± 0.01 ns 4.7 ± 0.01 ns 1 ± 0.003
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 0/0/∇ 5.65 ± 0.02 ns 5.59 ± 0.041 ns 1.01 ± 0.0082
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 0/0/∇² 3.42 ± 0.01 ns 3.42 ± 0.01 ns 1 ± 0.0041
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 1/1/∂ 3.4 ± 0.04 ns 3.42 ± 0.01 ns 0.994 ± 0.012
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 1/1/∂² 4.7 ± 0.009 ns 4.7 ± 0.01 ns 1 ± 0.0029
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 1/1/∇ 5.65 ± 0.02 ns 5.59 ± 0.039 ns 1.01 ± 0.0079
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 1/1/∇² 3.42 ± 0.01 ns 3.42 ± 0.01 ns 1 ± 0.0041
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 2/2/∂ 3.4 ± 0.04 ns 3.42 ± 0.01 ns 0.994 ± 0.012
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 2/2/∂² 4.7 ± 0.01 ns 4.7 ± 0.01 ns 1 ± 0.003
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 2/2/∇ 5.65 ± 0.021 ns 5.59 ± 0.039 ns 1.01 ± 0.008
RBF/Polyharmonic spline (r³)
└─Polynomial augmentation: degree 2/2/∇² 3.42 ± 0.01 ns 3.42 ± 0.01 ns 1 ± 0.0041
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 0/0/∂ 4.27 ± 0.01 ns 4.27 ± 0.01 ns 1 ± 0.0033
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 0/0/∂² 5.58 ± 0.01 ns 5.56 ± 0.02 ns 1 ± 0.004
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 0/0/∇ 6.87 ± 0.12 ns 6.81 ± 0.011 ns 1.01 ± 0.018
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 0/0/∇² 4.27 ± 0.01 ns 4.27 ± 0.01 ns 1 ± 0.0033
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 1/1/∂ 4.27 ± 0.01 ns 4.27 ± 0.01 ns 1 ± 0.0033
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 1/1/∂² 5.58 ± 0.01 ns 5.56 ± 0.02 ns 1 ± 0.004
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 1/1/∇ 6.86 ± 0.1 ns 6.81 ± 0.01 ns 1.01 ± 0.015
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 1/1/∇² 4.27 ± 0.01 ns 4.27 ± 0.01 ns 1 ± 0.0033
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 2/2/∂ 4.27 ± 0.01 ns 4.27 ± 0.01 ns 1 ± 0.0033
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 2/2/∂² 5.58 ± 0.01 ns 5.56 ± 0.02 ns 1 ± 0.004
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 2/2/∇ 6.86 ± 0.11 ns 6.82 ± 0.01 ns 1.01 ± 0.016
RBF/Polyharmonic spline (r¹)
└─Polynomial augmentation: degree 2/2/∇² 4.27 ± 0.01 ns 4.27 ± 0.01 ns 1 ± 0.0033
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 0/0/∂ 4.65 ± 0.001 ns 4.96 ± 0.001 ns 0.937 ± 0.00028
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 0/0/∂² 4.96 ± 0.01 ns 4.96 ± 0.01 ns 1 ± 0.0029
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 0/0/∇ 6.78 ± 0.11 ns 6.12 ± 0.081 ns 1.11 ± 0.023
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 0/0/∇² 3.42 ± 0.001 ns 3.11 ± 0.009 ns 1.1 ± 0.0032
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 1/1/∂ 4.65 ± 0.001 ns 4.96 ± 0.001 ns 0.937 ± 0.00028
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 1/1/∂² 4.96 ± 0.01 ns 4.96 ± 0.01 ns 1 ± 0.0029
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 1/1/∇ 6.19 ± 0.011 ns 6.12 ± 0.089 ns 1.01 ± 0.015
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 1/1/∇² 3.42 ± 0.001 ns 3.11 ± 0.009 ns 1.1 ± 0.0032
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 2/2/∂ 4.65 ± 0.001 ns 4.96 ± 0.001 ns 0.937 ± 0.00028
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 2/2/∂² 4.96 ± 0.01 ns 4.96 ± 0.01 ns 1 ± 0.0029
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 2/2/∇ 6.19 ± 0.011 ns 6.12 ± 0.089 ns 1.01 ± 0.015
RBF/Polyharmonic spline (r⁵)
└─Polynomial augmentation: degree 2/2/∇² 3.42 ± 0.001 ns 3.11 ± 0.009 ns 1.1 ± 0.0032
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 0/0/∂ 10.4 ± 0.07 ns 10.3 ± 0.061 ns 1 ± 0.009
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 0/0/∂² 5.27 ± 0.01 ns 5.27 ± 0.01 ns 1 ± 0.0027
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 0/0/∇ 12.5 ± 0.08 ns 13 ± 0.08 ns 0.963 ± 0.0086
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 0/0/∇² 8.14 ± 0.04 ns 7.98 ± 0.12 ns 1.02 ± 0.016
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 1/1/∂ 10.4 ± 0.07 ns 10.3 ± 0.071 ns 1 ± 0.0097
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 1/1/∂² 5.27 ± 0.01 ns 5.27 ± 0.01 ns 1 ± 0.0027
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 1/1/∇ 12.5 ± 0.081 ns 12.5 ± 0.15 ns 1 ± 0.014
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 1/1/∇² 8.14 ± 0.03 ns 8 ± 0.14 ns 1.02 ± 0.018
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 2/2/∂ 10.4 ± 0.08 ns 10.3 ± 0.071 ns 1 ± 0.01
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 2/2/∂² 5.27 ± 0.01 ns 5.27 ± 0.01 ns 1 ± 0.0027
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 2/2/∇ 12.5 ± 0.08 ns 12.5 ± 0.14 ns 1 ± 0.013
RBF/Polyharmonic spline (r⁷)
└─Polynomial augmentation: degree 2/2/∇² 8.14 ± 0.05 ns 7.99 ± 0.11 ns 1.02 ± 0.015
time_to_load 0.661 ± 0.0055 s 0.676 ± 0.0013 s 0.978 ± 0.0084

Benchmark Plots

A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

The Interpolator(x, y, basis) constructor calls Symmetric(A) \ b which
hits unsupported LAPACK foreigncalls in Mooncake. This adds a primitive
that runs the solve opaquely in the forward pass and uses the implicit
function theorem (Δy = (A⁻¹ [Δw_rbf; Δw_mon])[1:k]) in the backward
pass to propagate cotangents back to y.
@codecov
Copy link

codecov bot commented Feb 15, 2026

Codecov Report

❌ Patch coverage is 99.04762% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/solve/backward.jl 96.77% 1 Missing ⚠️
Files with missing lines Coverage Δ
src/RadialBasisFunctions.jl 100.00% <ø> (ø)
src/interpolation_backward.jl 100.00% <100.00%> (ø)
src/solve/ad_shared.jl 100.00% <100.00%> (ø)
src/solve/operator_second_derivatives.jl 100.00% <100.00%> (ø)
src/solve/backward.jl 98.82% <96.77%> (-0.67%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

- Use ::Type{RT} where {RT} instead of ::Type{<:Active} for
  array-returning rules (operator call, _eval_op, batch interpolator)
  so Enzyme can match Duplicated return types
- Add xi::Const basis function rule variant for closures using
  function_annotation=Enzyme.Const
- Fix SVector immutability in _build_weights reverse by using
  assignment instead of .+= on immutable static array elements
- Add trailing comma in Mooncake @is_primitive to satisfy Runic
Remove RadialBasisOperator from Enzyme tape to avoid LLVM IR verification
errors on Julia 1.11. The operator is Const and accessible via op.val in
the reverse pass. Also handle DuplicatedNoNeed return type in shadow
allocation and remove update_weights! from augmented_primal since Const
operators have pre-computed weights.
@kylebeggs kylebeggs merged commit 4cba45c into main Feb 17, 2026
25 of 26 checks passed
@kylebeggs kylebeggs deleted the mooncake-native-rrule branch February 20, 2026 20:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant