diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..36b0782c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,22 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +enable-beta-ecosystems: true # Julia ecosystem +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" + ignore: + - dependency-name: "crate-ci/typos" + update-types: ["version-update:semver-patch", "version-update:semver-minor"] + - package-ecosystem: "julia" + directories: + - "/" + - "/docs" + - "/test" + schedule: + interval: "daily" + groups: + all-julia-packages: + patterns: + - "*" diff --git a/.github/workflows/dependabot.yml b/.github/workflows/dependabot.yml deleted file mode 100644 index 6dd46e93..00000000 --- a/.github/workflows/dependabot.yml +++ /dev/null @@ -1,20 +0,0 @@ -# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates -version: 2 -enable-beta-ecosystems: true # Julia ecosystem -updates: - - package-ecosystem: "github-actions" - directory: "/" # Location of package manifests - schedule: - interval: "weekly" - - package-ecosystem: "julia" - directories: # Location of Julia projects - - "/RadialBasisFunctions" - - "/RadialBasisFunctions/docs" - - "/RadialBasisFunctions/test" - schedule: - interval: "daily" - groups: - # Group all Julia package updates into a single PR: - all-julia-packages: - patterns: - - "*" diff --git a/.github/workflows/downgrade.yml b/.github/workflows/downgrade.yml index 381a56c9..39b2775c 100644 --- a/.github/workflows/downgrade.yml +++ b/.github/workflows/downgrade.yml @@ -32,7 +32,7 @@ jobs: - uses: julia-actions/cache@v2 - uses: julia-actions/julia-downgrade-compat@v1 with: - skip: LinearAlgebra, Random, Statistics + skip: LinearAlgebra, Random, Statistics, Enzyme, EnzymeCore, Mooncake, ChainRulesCore, 
FiniteDifferences projects: ., test, docs - uses: julia-actions/julia-buildpkg@v1 env: diff --git a/CLAUDE.md b/CLAUDE.md index 24d4de09..fa6ac2ec 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -103,6 +103,7 @@ The package requires `Vector{AbstractVector}` input format (not matrices). Each - The package requires Julia 1.10+ (see Project.toml compatibility) - Uses KernelAbstractions.jl for GPU/CPU parallelization - Data must be in `Vector{AbstractVector}` format (not matrices) - each point needs inferrable dimension (e.g., `SVector{2,Float64}`) +- **Autodiff examples**: Always use DifferentiationInterface.jl for AD examples in docs and tests. This provides a unified interface over Enzyme.jl and Mooncake.jl backends. --- diff --git a/Project.toml b/Project.toml index 65fc5f9a..319bcfd1 100644 --- a/Project.toml +++ b/Project.toml @@ -19,10 +19,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymRCM = "286e6d88-80af-4590-acc9-0001b223b9bd" [weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] RadialBasisFunctionsChainRulesCoreExt = "ChainRulesCore" +RadialBasisFunctionsEnzymeExt = ["Enzyme", "EnzymeCore"] RadialBasisFunctionsMooncakeExt = ["ChainRulesCore", "Mooncake"] [compat] @@ -30,6 +33,8 @@ ChainRulesCore = "1.20" ChunkSplitters = "3" Combinatorics = "1" Distances = "0.9, 0.10" +Enzyme = "0.13" +EnzymeCore = "0.8" FiniteDifferences = "0.12.33" KernelAbstractions = "0.9.34" LinearAlgebra = "1" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 9cf457e1..e32f6f97 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -7,5 +7,5 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [compat] BenchmarkTools = "1.5" -RadialBasisFunctions = "0.2.5" +RadialBasisFunctions = "0.3" julia = "1.9" diff --git a/docs/Project.toml b/docs/Project.toml index 7a6bb4b3..a9c0c771 100644 --- a/docs/Project.toml +++ 
b/docs/Project.toml @@ -1,9 +1,16 @@ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" RadialBasisFunctions = "79ee0514-adf7-4479-8807-6f72ea8967e8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[compat] +Documenter = "1" +DocumenterVitepress = "0.3" diff --git a/docs/make.jl b/docs/make.jl index 6ef18fb8..f5cb7869 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,6 +21,7 @@ makedocs(; pages = [ "Home" => "index.md", "Getting Started" => "getting_started.md", + "Automatic Differentiation" => "autodiff.md", "Quick Reference" => "quickref.md", "Theory" => "theory.md", "Internals" => "internals.md", diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 33c85b4f..7b9dd86d 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -1,9 +1,11 @@ import { defineConfig } from 'vitepress' import { tabsMarkdownPlugin } from 'vitepress-plugin-tabs' -import mathjax3 from "markdown-it-mathjax3"; +import { mathjaxPlugin } from './mathjax-plugin' import footnote from "markdown-it-footnote"; import path from 'path' +const mathjax = mathjaxPlugin() + function getBaseRepository(base: string): string { if (!base || base === '/') return '/'; const parts = base.split('/').filter(Boolean); @@ -18,6 +20,8 @@ const navTemp = { nav: [ { text: 'Home', link: '/index' }, { text: 'Getting Started', link: '/getting_started' }, + { text: 'Autodiff', link: '/autodiff' }, + { text: 'Quick Reference', link: '/quickref' }, { text: 'Theory', link: '/theory' }, { text: 'Internals', link: 
'/internals' }, { text: 'API', link: '/api' } @@ -45,6 +49,9 @@ export default defineConfig({ ], vite: { + plugins: [ + mathjax.vitePlugin, + ], define: { __DEPLOY_ABSPATH__: JSON.stringify('/'), }, @@ -68,11 +75,10 @@ export default defineConfig({ }, }, markdown: { - math: true, config(md) { - md.use(tabsMarkdownPlugin), - md.use(mathjax3), - md.use(footnote) + md.use(tabsMarkdownPlugin); + md.use(footnote); + mathjax.markdownConfig(md); }, theme: { light: "github-light", @@ -92,6 +98,8 @@ export default defineConfig({ sidebar: [ { text: 'Home', link: '/index' }, { text: 'Getting Started', link: '/getting_started' }, + { text: 'Autodiff', link: '/autodiff' }, + { text: 'Quick Reference', link: '/quickref' }, { text: 'Theory', link: '/theory' }, { text: 'Internals', link: '/internals' }, { text: 'API', link: '/api' } diff --git a/docs/src/.vitepress/mathjax-plugin.ts b/docs/src/.vitepress/mathjax-plugin.ts new file mode 100644 index 00000000..36641e81 --- /dev/null +++ b/docs/src/.vitepress/mathjax-plugin.ts @@ -0,0 +1,161 @@ +// adapter from https://github.com/orgs/vuepress-theme-hope/discussions/5178#discussioncomment-15642629 +// mathjax-plugin.ts +// @ts-ignore +import MathJax from '@mathjax/src' +import type { Plugin as VitePlugin } from 'vite' +import type MarkdownIt from 'markdown-it' +import { tex as mdTex } from '@mdit/plugin-tex' + +const mathjaxStyleModuleID = 'virtual:mathjax-styles.css' + +interface MathJaxOptions { + font?: string +} + +async function initializeMathJax(options: MathJaxOptions = {}) { + const font = options.font || 'mathjax-newcm' + + const config: any = { + loader: { + load: [ + 'input/tex', + 'output/svg', + '[tex]/boldsymbol', + '[tex]/braket', + '[tex]/mathtools', + ], + }, + tex: { + tags: 'ams', + packages: { + '[+]': ['boldsymbol', 'braket', 'mathtools'], + }, + }, + output: { + font, + displayOverflow: 'linebreak', + mtextInheritFont: true, + }, + svg: { + fontCache: 'none', // critical: avoids async font loading + }, + 
} + + await MathJax.init(config) + + const fontData = MathJax.config.svg?.fontData + + if (fontData?.dynamicFiles) { + const dynamicFiles = fontData.dynamicFiles + const dynamicPrefix: string = + fontData.OPTIONS?.dynamicPrefix || fontData.options?.dynamicPrefix + + if (dynamicPrefix) { + await Promise.all( + Object.keys(dynamicFiles).map(async (name) => { + try { + await import(/* @vite-ignore */ `${dynamicPrefix}/${name}.js`) + dynamicFiles[name]?.setup?.(MathJax.startup.output.font) + } catch { + // Silently ignore missing dynamic files + } + }), + ) + } + } +} + +export function mathjaxPlugin(options: MathJaxOptions = {}) { + let adaptor: any + let initialized = false + + async function ensureInitialized() { + if (!initialized) { + await initializeMathJax(options) + adaptor = MathJax.startup.adaptor + initialized = true + } + } + + function renderMath(content: string, displayMode: boolean): string { + if (!initialized) { + throw new Error('MathJax not initialized') + } + + const node = MathJax.tex2svg(content, { display: displayMode }) + + // Prevent Vue from touching MathJax output + adaptor.setAttribute(node, 'v-pre', '') + + let html = adaptor.outerHTML(node) + + // Preserve spaces inside mjx-break (SVG only) + html = html.replace( + /(.*?)<\/mjx-break>/g, + (_: string, attr: string, inner: string) => + `${inner.replace(/ /g, ' ')}`, + ) + + // Wrap only display equations (not inline math) + html = html.replace( + /(]*display="true"[^>]*>)([\s\S]*?)(<\/mjx-container>)/, + '
$1$2$3
' + ) + + return html + } + + function getMathJaxStyles(): string { + return initialized + ? adaptor.textContent(MathJax.svgStylesheet()) || '' + : '' + } + + function resetMathJax(): void { + if (!initialized) return + MathJax.texReset() + MathJax.typesetClear() + } + + function viteMathJax(): VitePlugin { + const virtualModuleID = '\0' + mathjaxStyleModuleID + + return { + name: 'mathjax-styles', + + resolveId(id) { + if (id === mathjaxStyleModuleID) { + return virtualModuleID + } + }, + + async load(id) { + if (id === virtualModuleID) { + await ensureInitialized() + return getMathJaxStyles() + } + }, + } + } + + function mdMathJax(md: MarkdownIt): void { + mdTex(md, { + render: renderMath, + }) + + const orig = md.render + md.render = function (...args) { + resetMathJax() + return orig.apply(this, args) + } + } + + const init = ensureInitialized() + + return { + vitePlugin: viteMathJax(), + markdownConfig: mdMathJax, + styleModuleID: mathjaxStyleModuleID, + init, + } +} diff --git a/docs/src/autodiff.md b/docs/src/autodiff.md new file mode 100644 index 00000000..033ca10d --- /dev/null +++ b/docs/src/autodiff.md @@ -0,0 +1,197 @@ +# Automatic Differentiation + +RadialBasisFunctions.jl supports automatic differentiation (AD) through two package extensions: + +- **Mooncake.jl** - Reverse-mode AD with support for mutation +- **Enzyme.jl** - Native EnzymeRules for high-performance reverse-mode AD + +All examples use [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) which provides a unified API over different AD backends. + +## Compatibility + +Enzyme.jl currently has known issues with Julia 1.12+. If you encounter problems, use Julia < 1.12 or switch to Mooncake.jl. + +See: [Enzyme.jl#2699](https://github.com/EnzymeAD/Enzyme.jl/issues/2699) + +## Implementation Status + +Enzyme.jl has native `EnzymeRules` for all supported operations. 
Mooncake.jl support currently uses the `@from_rrule` macro to convert ChainRulesCore rules. Native Mooncake rules are planned for a future release. + +## Differentiating Through Operators + +The most common use case is differentiating a loss function with respect to field values while keeping the operator fixed. Create the operator once outside the loss function, then differentiate through its application. + +```@example autodiff +using RadialBasisFunctions +using StaticArrays +import DifferentiationInterface as DI +import Mooncake + +# Create points and operator (outside loss function) +N = 49 +points = [SVector{2}(0.1 + 0.8 * i / 7, 0.1 + 0.8 * j / 7) for i in 1:7 for j in 1:7] +values = sin.(getindex.(points, 1)) .+ cos.(getindex.(points, 2)) + +lap = laplacian(points) + +# Loss function: minimize squared Laplacian +function loss(v) + result = lap(v) + return sum(result .^ 2) +end + +# Compute gradient using DifferentiationInterface +backend = DI.AutoMooncake(; config=nothing) # also supports DI.AutoEnzyme() +grad = DI.gradient(loss, backend, values) +grad[1:5] # Show first 5 gradient values +``` + +This works with any operator type: + +```@example autodiff +# Gradient operator (vector-valued) +∇f = gradient(points) + +function loss_grad(v) + result = ∇f(v) + return sum(result .^ 2) +end + +grad = DI.gradient(loss_grad, backend, values) +grad[1:5] +``` + +```@example autodiff +# Partial derivative operator +∂x = partial(points, 1, 1) + +function loss_partial(v) + result = ∂x(v) + return sum(result .^ 2) +end + +grad = DI.gradient(loss_partial, backend, values) +grad[1:5] +``` + +## Differentiating Through Interpolators + +When differentiating through interpolation, the `Interpolator` must be constructed inside the loss function since changing the input values changes the interpolation weights. 
+ +```@example autodiff +N_interp = 30 +points_interp = [SVector{2}(0.5 + 0.4 * cos(2π * i / N_interp), 0.5 + 0.4 * sin(2π * i / N_interp)) for i in 1:N_interp] +values_interp = sin.(getindex.(points_interp, 1)) +eval_points = [SVector{2}(0.5, 0.5), SVector{2}(0.6, 0.6)] + +# Loss function - must rebuild interpolator inside +function loss_interp(v) + interp = Interpolator(points_interp, v) + result = interp(eval_points) + return sum(result .^ 2) +end + +grad = DI.gradient(loss_interp, backend, values_interp) +grad[1:5] +``` + +## Differentiating Basis Functions Directly + +For low-level control, you can differentiate basis function evaluations directly. This is useful for custom applications or understanding the underlying derivatives. + +```@example autodiff +x = [0.5, 0.5] +xi = [0.3, 0.4] + +# PHS basis +phs = PHS(3) +function loss_phs(xv) + return phs(xv, xi)^2 +end + +grad = DI.gradient(loss_phs, backend, x) +``` + +All basis types are supported: + +```@example autodiff +# IMQ basis +imq = IMQ(1.0) +function loss_imq(xv) + return imq(xv, xi)^2 +end + +grad = DI.gradient(loss_imq, backend, x) +``` + +```@example autodiff +# Gaussian basis +gauss = Gaussian(1.0) +function loss_gauss(xv) + return gauss(xv, xi)^2 +end + +grad = DI.gradient(loss_gauss, backend, x) +``` + +## Differentiating Weight Construction + +For advanced use cases like mesh optimization or shape parameter tuning, you can differentiate through the weight construction process using the internal `_build_weights` function. + +```@example autodiff +# Using Mooncake for weight construction +N_weights = 25 +points_weights = [SVector{2}(0.1 + 0.8 * i / 5, 0.1 + 0.8 * j / 5) for i in 1:5 for j in 1:5] +adjl = RadialBasisFunctions.find_neighbors(points_weights, 10) +basis = PHS(3; poly_deg=2) +ℒ = Partial(1, 1) # First derivative in x + +# Loss function w.r.t. 
point positions +function loss_weights(pts) + pts_vec = [SVector{2}(pts[2*i-1], pts[2*i]) for i in 1:N_weights] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) +end + +pts_flat = vcat([collect(p) for p in points_weights]...) +grad = DI.gradient(loss_weights, backend, pts_flat) +grad[1:6] # Gradients for first 3 points (x,y pairs) +``` + +This also works with the Laplacian operator and different basis types: + +```@example autodiff +ℒ_lap = Laplacian() +basis_imq = IMQ(1.0; poly_deg=2) + +function loss_weights_lap(pts) + pts_vec = [SVector{2}(pts[2*i-1], pts[2*i]) for i in 1:N_weights] + W = RadialBasisFunctions._build_weights(ℒ_lap, pts_vec, pts_vec, adjl, basis_imq) + return sum(W.nzval .^ 2) +end + +grad = DI.gradient(loss_weights_lap, backend, pts_flat) +grad[1:6] +``` + +## Supported Components + +| Component | Enzyme | Mooncake | +|-----------|:------:|:--------:| +| Operator evaluation (`op(values)`) | ✓ | ✓ | +| Interpolator evaluation | ✓ | ✓ | +| Basis functions (PHS, IMQ, Gaussian) | ✓ | ✓ | +| Weight construction (`_build_weights`) | ✓ | ✓ | +| Shape parameter (ε) differentiation | ✓ | ✓ | + +## Using Enzyme Backend + +Switch to Enzyme by changing the backend (requires Julia < 1.12): + +```julia +import DifferentiationInterface as DI +import Enzyme + +backend = DI.AutoEnzyme() +grad = DI.gradient(loss, backend, values) +``` diff --git a/ext/RadialBasisFunctionsChainRulesCoreExt/RadialBasisFunctionsChainRulesCoreExt.jl b/ext/RadialBasisFunctionsChainRulesCoreExt/RadialBasisFunctionsChainRulesCoreExt.jl index 691a8e03..1783bfe0 100644 --- a/ext/RadialBasisFunctionsChainRulesCoreExt/RadialBasisFunctionsChainRulesCoreExt.jl +++ b/ext/RadialBasisFunctionsChainRulesCoreExt/RadialBasisFunctionsChainRulesCoreExt.jl @@ -13,7 +13,7 @@ Includes rrules for: - Basis function evaluation (basis_rules.jl) - Operator application (operator_rules.jl) - Interpolator evaluation (interpolation_rules.jl) -- Weight 
construction for shape optimization (build_weights_*.jl) +- Weight construction for shape optimization (build_weights_rrule.jl) """ module RadialBasisFunctionsChainRulesCoreExt @@ -29,11 +29,24 @@ import RadialBasisFunctions: _build_collocation_matrix!, _build_rhs! # Import types we need import RadialBasisFunctions: RadialBasisOperator, Interpolator -import RadialBasisFunctions: AbstractRadialBasis, PHS1, PHS3, PHS5, PHS7, IMQ, Gaussian +import RadialBasisFunctions: AbstractRadialBasis, PHS, PHS1, PHS3, PHS5, PHS7, IMQ, Gaussian import RadialBasisFunctions: VectorValuedOperator, ScalarValuedOperator import RadialBasisFunctions: MonomialBasis, BoundaryData import RadialBasisFunctions: Partial, Laplacian +# Import backward pass support from main package +import RadialBasisFunctions: StencilForwardCache, WeightsBuildForwardCache +import RadialBasisFunctions: backward_linear_solve!, backward_collocation! +import RadialBasisFunctions: backward_rhs_partial!, backward_rhs_laplacian! +import RadialBasisFunctions: backward_stencil_partial!, backward_stencil_laplacian! +import RadialBasisFunctions: backward_stencil_partial_with_ε!, backward_stencil_laplacian_with_ε! 
+import RadialBasisFunctions: _forward_with_cache +import RadialBasisFunctions: grad_applied_partial_wrt_x, grad_applied_partial_wrt_xi +import RadialBasisFunctions: grad_applied_laplacian_wrt_x, grad_applied_laplacian_wrt_xi + +# Import shape parameter derivative functions +import RadialBasisFunctions: ∂φ_∂ε, ∂Laplacian_φ_∂ε, ∂Partial_φ_∂ε + # Import the gradient function for basis functions (not exported from main module) const ∇ = RadialBasisFunctions.∇ const ∂ = RadialBasisFunctions.∂ @@ -44,9 +57,6 @@ include("basis_rules.jl") include("interpolation_rules.jl") # Shape optimization support: rrules for _build_weights -include("build_weights_cache.jl") -include("operator_second_derivatives.jl") -include("build_weights_backward.jl") include("build_weights_rrule.jl") end # module diff --git a/ext/RadialBasisFunctionsChainRulesCoreExt/basis_rules.jl b/ext/RadialBasisFunctionsChainRulesCoreExt/basis_rules.jl index 7fce370b..08cb27aa 100644 --- a/ext/RadialBasisFunctionsChainRulesCoreExt/basis_rules.jl +++ b/ext/RadialBasisFunctionsChainRulesCoreExt/basis_rules.jl @@ -107,3 +107,55 @@ function ChainRulesCore.rrule(basis::Gaussian, x::AbstractVector, xi::AbstractVe return y, gaussian_pullback end + +# ============================================================================= +# Constructor rrules for shape parameter differentiation +# ============================================================================= +# These rules enable gradients to flow through basis construction +# from Tangent{Gaussian/IMQ}(ε=Δε) back to the input ε. 
+ +# Single rrule with keyword argument that handles all cases +# (Julia desugars positional calls to keyword calls when defaults exist) +function ChainRulesCore.rrule(::Type{Gaussian}, ε::T; poly_deg::Int = 2) where {T} + basis = Gaussian(ε; poly_deg = poly_deg) + + function gaussian_constructor_pullback(Δbasis) + # Extract ε gradient from struct tangent + if Δbasis isa ChainRulesCore.Tangent + Δε = Δbasis.ε + # Handle NoTangent case + if Δε isa NoTangent + return NoTangent(), zero(T) + end + return NoTangent(), Δε + elseif Δbasis isa NoTangent || Δbasis isa ZeroTangent + return NoTangent(), zero(T) + else + return NoTangent(), zero(T) + end + end + + return basis, gaussian_constructor_pullback +end + +function ChainRulesCore.rrule(::Type{IMQ}, ε::T; poly_deg::Int = 2) where {T} + basis = IMQ(ε; poly_deg = poly_deg) + + function imq_constructor_pullback(Δbasis) + # Extract ε gradient from struct tangent + if Δbasis isa ChainRulesCore.Tangent + Δε = Δbasis.ε + # Handle NoTangent case + if Δε isa NoTangent + return NoTangent(), zero(T) + end + return NoTangent(), Δε + elseif Δbasis isa NoTangent || Δbasis isa ZeroTangent + return NoTangent(), zero(T) + else + return NoTangent(), zero(T) + end + end + + return basis, imq_constructor_pullback +end diff --git a/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_backward.jl b/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_backward.jl deleted file mode 100644 index 6ff1add5..00000000 --- a/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_backward.jl +++ /dev/null @@ -1,318 +0,0 @@ -#= -Backward pass functions for _build_weights rrule. - -The backward pass computes: - Given Δw (cotangent of weights), compute Δdata and Δeval_points - -Key steps per stencil: -1. Pad cotangent: Δλ = [Δw; 0] -2. Solve adjoint: η = A⁻ᵀ Δλ -3. Compute: ΔA = -η λᵀ, Δb = η -4. Chain through RHS: accumulate to Δeval_point and Δdata[neighbors] -5. 
Chain through collocation: accumulate to Δdata[neighbors] -=# - -using LinearAlgebra: dot - -""" - backward_linear_solve!(ΔA, Δb, Δw, cache) - -Compute cotangents of collocation matrix A and RHS vector b -from cotangent of weights Δw. - -Given: Aλ = b, w = λ[1:k] -We have: Δλ = [Δw; 0] (padded with zeros for monomial part) - -Using implicit function theorem: - η = A⁻ᵀ Δλ - ΔA = -η λᵀ - Δb = η -""" -function backward_linear_solve!( - ΔA::AbstractMatrix{T}, - Δb::AbstractVecOrMat{T}, - Δw::AbstractVecOrMat{T}, - cache::StencilForwardCache{T}, - ) where {T} - k = cache.k - nmon = cache.nmon - n = k + nmon - num_ops = size(cache.lambda, 2) - - # Pad Δw with zeros for monomial part - Δλ = zeros(T, n, num_ops) - Δλ[1:k, :] .= Δw - - # Solve adjoint system: A'η = Δλ - # The matrix is symmetric, so A' = A - η = cache.A_mat \ Δλ - - # ΔA = -η * λᵀ (outer product, accumulated across operators) - # For symmetric A, we need to account for the structure - fill!(ΔA, zero(T)) - for op_idx in 1:num_ops - η_vec = view(η, :, op_idx) - λ_vec = view(cache.lambda, :, op_idx) - for j in 1:n - for i in 1:n - ΔA[i, j] -= η_vec[i] * λ_vec[j] - end - end - end - - # Δb = η - Δb .= η - - return nothing -end - -""" - backward_collocation!(Δdata, ΔA, neighbors, data, basis, mon, k) - -Chain rule through collocation matrix construction. - -The collocation matrix has structure: - A[i,j] = φ(xi, xj) for i,j ≤ k (RBF block) - A[i,k+j] = pⱼ(xi) for i ≤ k (polynomial block) - -For RBF block (using ∇φ from existing basis_rules): - Δxi += ΔA[i,j] * ∇φ(xi, xj) - Δxj -= ΔA[i,j] * ∇φ(xi, xj) (by symmetry of φ(x-y)) - -For polynomial block: - Δxi += ΔA[i,k+j] * ∇pⱼ(xi) - -Note: A is symmetric, so we need to handle both triangles. 
-""" -function backward_collocation!( - Δdata::Vector{Vector{T}}, - ΔA::AbstractMatrix{T}, - neighbors::Vector{Int}, - data::AbstractVector, - basis::AbstractRadialBasis, - mon::MonomialBasis{Dim, Deg}, - k::Int, - ) where {T, Dim, Deg} - grad_φ = ∇(basis) - n = k + binomial(Dim + Deg, Deg) - - # RBF block: accumulate gradients from symmetric matrix - # Only upper triangle stored, but gradients flow both ways - @inbounds for j in 1:k - xj = data[neighbors[j]] - for i in 1:(j - 1) # Skip diagonal (i == j) since φ(x,x) = 0 always, no gradient contribution - xi = data[neighbors[i]] - - # Get gradient of basis function - ∇φ_ij = grad_φ(xi, xj) - - # ΔA[i,j] contributes to both Δxi and Δxj - # For symmetric matrix, ΔA[i,j] == ΔA[j,i] conceptually - # We need to sum contributions from both triangles - scale = ΔA[i, j] + ΔA[j, i] - - # φ depends on xi - xj, so: - # ∂φ/∂xi = ∇φ, ∂φ/∂xj = -∇φ - Δdata[neighbors[i]] .+= scale .* ∇φ_ij - Δdata[neighbors[j]] .-= scale .* ∇φ_ij - end - end - - # Polynomial block: A[i, k+j] = pⱼ(xi) - # Need gradient of monomial basis w.r.t. xi - if Deg > -1 - nmon = binomial(Dim + Deg, Deg) - ∇p = zeros(T, nmon, Dim) - - @inbounds for i in 1:k - xi = data[neighbors[i]] - - # Compute gradient of all monomials at xi - ∇mon = RadialBasisFunctions.∇(mon) - ∇mon(∇p, xi) - - # Accumulate gradient from polynomial block - for j in 1:nmon - # ΔA[i, k+j] contributes to Δxi via ∇pⱼ - # Also ΔA[k+j, i] from transpose block - scale = ΔA[i, k + j] + ΔA[k + j, i] - Δdata[neighbors[i]] .+= scale .* view(∇p, j, :) - end - end - end - - return nothing -end - -""" - backward_rhs!(Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, ℒ, k, dim) - -Chain rule through RHS vector construction for Partial operator. 
- -RHS structure: - b[i] = ℒφ(eval_point, xi) for i = 1:k - b[k+j] = ℒpⱼ(eval_point) for j = 1:nmon - -For RBF section, we need: - ∂/∂eval_point [ℒφ(eval_point, xi)] - ∂/∂xi [ℒφ(eval_point, xi)] - -For polynomial section, we need: - ∂/∂eval_point [ℒpⱼ(eval_point)] -""" -function backward_rhs_partial!( - Δdata::Vector{Vector{T}}, - Δeval_point::Vector{T}, - Δb::AbstractVecOrMat{T}, - neighbors::Vector{Int}, - eval_point, - data::AbstractVector, - basis::AbstractRadialBasis, - dim::Int, - k::Int, - ) where {T} - num_ops = size(Δb, 2) - - # Get gradient functions for the applied partial operator - grad_Lφ_x = grad_applied_partial_wrt_x(basis, dim) - grad_Lφ_xi = grad_applied_partial_wrt_xi(basis, dim) - - # RBF section: b[i] = ∂φ/∂x[dim](eval_point, xi) - @inbounds for i in 1:k - xi = data[neighbors[i]] - - # Gradient w.r.t. eval_point - ∇Lφ_x = grad_Lφ_x(eval_point, xi) - # Gradient w.r.t. xi - ∇Lφ_xi = grad_Lφ_xi(eval_point, xi) - - # Accumulate across operators - for op_idx in 1:num_ops - Δb_val = Δb[i, op_idx] - Δeval_point .+= Δb_val .* ∇Lφ_x - Δdata[neighbors[i]] .+= Δb_val .* ∇Lφ_xi - end - end - - # Polynomial section: b[k+j] = ∂pⱼ/∂x[dim](eval_point) - # The gradient of ∂pⱼ/∂x[dim] w.r.t. eval_point is the second derivative of pⱼ - # This is non-trivial and depends on polynomial degree - # For now, we skip this contribution (it's typically small for low-degree polynomials) - # TODO: Implement polynomial second derivatives if needed - - return nothing -end - -""" - backward_rhs_laplacian!(...) - -Chain rule through RHS for Laplacian operator. 
-""" -function backward_rhs_laplacian!( - Δdata::Vector{Vector{T}}, - Δeval_point::Vector{T}, - Δb::AbstractVecOrMat{T}, - neighbors::Vector{Int}, - eval_point, - data::AbstractVector, - basis::AbstractRadialBasis, - k::Int, - ) where {T} - num_ops = size(Δb, 2) - - # Get gradient functions for the applied Laplacian operator - grad_Lφ_x = grad_applied_laplacian_wrt_x(basis) - grad_Lφ_xi = grad_applied_laplacian_wrt_xi(basis) - - # RBF section: b[i] = ∇²φ(eval_point, xi) - @inbounds for i in 1:k - xi = data[neighbors[i]] - - # Gradient w.r.t. eval_point - ∇Lφ_x = grad_Lφ_x(eval_point, xi) - # Gradient w.r.t. xi - ∇Lφ_xi = grad_Lφ_xi(eval_point, xi) - - # Accumulate across operators - for op_idx in 1:num_ops - Δb_val = Δb[i, op_idx] - Δeval_point .+= Δb_val .* ∇Lφ_x - Δdata[neighbors[i]] .+= Δb_val .* ∇Lφ_xi - end - end - - return nothing -end - -""" - backward_stencil!(Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, ℒ, mon, k, op_info) - -Complete backward pass for a single stencil. - -Combines: -1. backward_linear_solve! - compute ΔA, Δb from Δw -2. backward_collocation! - chain ΔA to Δdata -3. backward_rhs! 
- chain Δb to Δdata and Δeval_point -""" -function backward_stencil_partial!( - Δdata::Vector{Vector{T}}, - Δeval_point::Vector{T}, - Δw::AbstractVecOrMat{T}, - cache::StencilForwardCache{T}, - neighbors::Vector{Int}, - eval_point, - data::AbstractVector, - basis::AbstractRadialBasis, - mon::MonomialBasis{Dim, Deg}, - k::Int, - dim::Int, # Partial derivative dimension - ) where {T, Dim, Deg} - n = k + cache.nmon - - # Allocate workspace for ΔA and Δb - ΔA = zeros(T, n, n) - Δb = zeros(T, n, size(Δw, 2)) - - # Step 1: Backprop through linear solve - backward_linear_solve!(ΔA, Δb, Δw, cache) - - # Step 2: Backprop through collocation matrix - backward_collocation!(Δdata, ΔA, neighbors, data, basis, mon, k) - - # Step 3: Backprop through RHS - backward_rhs_partial!( - Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, dim, k - ) - - return nothing -end - -function backward_stencil_laplacian!( - Δdata::Vector{Vector{T}}, - Δeval_point::Vector{T}, - Δw::AbstractVecOrMat{T}, - cache::StencilForwardCache{T}, - neighbors::Vector{Int}, - eval_point, - data::AbstractVector, - basis::AbstractRadialBasis, - mon::MonomialBasis{Dim, Deg}, - k::Int, - ) where {T, Dim, Deg} - n = k + cache.nmon - - # Allocate workspace for ΔA and Δb - ΔA = zeros(T, n, n) - Δb = zeros(T, n, size(Δw, 2)) - - # Step 1: Backprop through linear solve - backward_linear_solve!(ΔA, Δb, Δw, cache) - - # Step 2: Backprop through collocation matrix - backward_collocation!(Δdata, ΔA, neighbors, data, basis, mon, k) - - # Step 3: Backprop through RHS - backward_rhs_laplacian!(Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, k) - - return nothing -end diff --git a/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_rrule.jl b/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_rrule.jl index c1e1e60e..6a435867 100644 --- a/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_rrule.jl +++ b/ext/RadialBasisFunctionsChainRulesCoreExt/build_weights_rrule.jl @@ -12,100 +12,6 @@ The 
rrule defines: using LinearAlgebra: Symmetric using SparseArrays: sparse, SparseMatrixCSC, findnz -import RadialBasisFunctions: _build_weights, _build_collocation_matrix!, _build_rhs! -import RadialBasisFunctions: BoundaryData, MonomialBasis, AbstractRadialBasis -import RadialBasisFunctions: Partial, Laplacian - -""" - _forward_with_cache(data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, ℒ) - -Forward pass that builds weights while caching intermediate results for backward pass. - -Returns: (W, cache) where W is the sparse weight matrix and cache contains -per-stencil factorizations and solutions needed for the pullback. -""" -function _forward_with_cache( - data::AbstractVector, - eval_points::AbstractVector, - adjl::AbstractVector, - basis::AbstractRadialBasis, - ℒrbf, - ℒmon, - mon::MonomialBasis{Dim, Deg}, - ::Type{ℒType}, - ) where {Dim, Deg, ℒType} - TD = eltype(first(data)) - k = length(first(adjl)) - nmon = Deg >= 0 ? binomial(Dim + Deg, Deg) : 0 - n = k + nmon - N_eval = length(eval_points) - N_data = length(data) - - # Determine number of operators (1 for scalar operators) - num_ops = 1 - - # Allocate COO arrays for sparse matrix - nnz = k * N_eval - I = Vector{Int}(undef, nnz) - J = Vector{Int}(undef, nnz) - V = Vector{TD}(undef, nnz) - - # Allocate stencil caches - stencil_caches = Vector{StencilForwardCache{TD, Matrix{TD}}}(undef, N_eval) - - # Process each evaluation point - pos = 1 - for eval_idx in 1:N_eval - neighbors = adjl[eval_idx] - eval_point = eval_points[eval_idx] - - # Get local data for this stencil - local_data = [data[i] for i in neighbors] - - # Build collocation matrix - A_full = zeros(TD, n, n) - A = Symmetric(A_full, :U) - _build_collocation_matrix!(A, local_data, basis, mon, k) - - # Build RHS vector - b = zeros(TD, n, num_ops) - b_vec = view(b, :, 1) - _build_rhs!(b_vec, ℒrbf, ℒmon, local_data, eval_point, basis, mon, k) - - # Solve (symmetric matrix, not positive definite due to zero block) - λ = Symmetric(A_full, :U) \ b - - # 
Extract weights (first k entries) - w = λ[1:k, :] - - # Store in COO format - for (local_idx, global_idx) in enumerate(neighbors) - I[pos] = eval_idx - J[pos] = global_idx - V[pos] = w[local_idx, 1] - pos += 1 - end - - # Cache for backward pass - store full symmetric matrix - A_full_symmetric = copy(A_full) - # Fill lower triangle from upper - for j in 1:n - for i in (j + 1):n - A_full_symmetric[i, j] = A_full[j, i] - end - end - stencil_caches[eval_idx] = StencilForwardCache(copy(λ), A_full_symmetric, k, nmon) - end - - # Construct sparse matrix - W = sparse(I, J, V, N_eval, N_data) - - # Build global cache - cache = WeightsBuildForwardCache(stencil_caches, k, nmon, num_ops) - - return W, cache -end - """ materialize_sparse_tangent(ΔW_raw, W::SparseMatrixCSC) @@ -154,14 +60,18 @@ function ChainRulesCore.rrule( basis::AbstractRadialBasis, ) # Build monomial basis and apply operator (same as forward pass) - dim = length(first(data)) - mon = MonomialBasis(dim, basis.poly_deg) + dim_space = length(first(data)) + mon = MonomialBasis(dim_space, basis.poly_deg) ℒmon = ℒ(mon) ℒrbf = ℒ(basis) # Forward pass with caching W, cache = _forward_with_cache(data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, Partial) + # Get gradient functions for the partial derivative direction + grad_Lφ_x = grad_applied_partial_wrt_x(basis, ℒ.dim) + grad_Lφ_xi = grad_applied_partial_wrt_xi(basis, ℒ.dim) + function _build_weights_partial_pullback(ΔW_raw) TD = eltype(first(data)) PT = eltype(data) # Point type (e.g., SVector{2,Float64}) @@ -174,6 +84,7 @@ function ChainRulesCore.rrule( # Initialize gradient accumulators (use mutable vectors for accumulation) Δdata_raw = [zeros(TD, length(first(data))) for _ in 1:N_data] Δeval_points_raw = [zeros(TD, length(first(eval_points))) for _ in 1:N_eval] + Δε_acc = Ref(zero(TD)) # Shape parameter gradient accumulator # Process each stencil for eval_idx in 1:N_eval @@ -193,10 +104,11 @@ function ChainRulesCore.rrule( Δlocal_data = [zeros(TD, 
length(first(data))) for _ in 1:k] Δeval_pt = zeros(TD, length(eval_point)) - # Run backward pass for this stencil - backward_stencil_partial!( + # Run backward pass for this stencil (with ε gradient) + backward_stencil_partial_with_ε!( Δlocal_data, Δeval_pt, + Δε_acc, Δw, stencil_cache, collect(1:k), # Local indices @@ -206,6 +118,8 @@ function ChainRulesCore.rrule( mon, k, ℒ.dim, + grad_Lφ_x, + grad_Lφ_xi, ) # Accumulate to global gradients @@ -216,6 +130,9 @@ function ChainRulesCore.rrule( end end + # Build basis tangent (only for bases with shape parameter) + Δbasis = _make_basis_tangent(basis, Δε_acc[]) + # Convert to match input types (required for Mooncake compatibility) return ( NoTangent(), # function @@ -223,7 +140,7 @@ function ChainRulesCore.rrule( [PT(Δdata_raw[i]) for i in 1:N_data], # data [PT(Δeval_points_raw[i]) for i in 1:N_eval], # eval_points NoTangent(), # adjl (discrete, non-differentiable) - NoTangent(), # basis + Δbasis, # basis ) end @@ -243,15 +160,17 @@ function ChainRulesCore.rrule( basis::AbstractRadialBasis, ) # Build monomial basis and apply operator - dim = length(first(data)) - mon = MonomialBasis(dim, basis.poly_deg) + dim_space = length(first(data)) + mon = MonomialBasis(dim_space, basis.poly_deg) ℒmon = ℒ(mon) ℒrbf = ℒ(basis) # Forward pass with caching - W, cache = _forward_with_cache( - data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, Laplacian - ) + W, cache = _forward_with_cache(data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, Laplacian) + + # Get gradient functions for the Laplacian + grad_Lφ_x = grad_applied_laplacian_wrt_x(basis) + grad_Lφ_xi = grad_applied_laplacian_wrt_xi(basis) function _build_weights_laplacian_pullback(ΔW_raw) TD = eltype(first(data)) @@ -265,6 +184,7 @@ function ChainRulesCore.rrule( # Initialize gradient accumulators (use mutable vectors for accumulation) Δdata_raw = [zeros(TD, length(first(data))) for _ in 1:N_data] Δeval_points_raw = [zeros(TD, length(first(eval_points))) for _ in 1:N_eval] + 
Δε_acc = Ref(zero(TD)) # Shape parameter gradient accumulator # Process each stencil for eval_idx in 1:N_eval @@ -284,10 +204,11 @@ function ChainRulesCore.rrule( Δlocal_data = [zeros(TD, length(first(data))) for _ in 1:k] Δeval_pt = zeros(TD, length(eval_point)) - # Run backward pass for this stencil - backward_stencil_laplacian!( + # Run backward pass for this stencil (with ε gradient) + backward_stencil_laplacian_with_ε!( Δlocal_data, Δeval_pt, + Δε_acc, Δw, stencil_cache, collect(1:k), @@ -296,6 +217,8 @@ function ChainRulesCore.rrule( basis, mon, k, + grad_Lφ_x, + grad_Lφ_xi, ) # Accumulate to global gradients @@ -306,6 +229,9 @@ function ChainRulesCore.rrule( end end + # Build basis tangent (only for bases with shape parameter) + Δbasis = _make_basis_tangent(basis, Δε_acc[]) + # Convert to match input types (required for Mooncake compatibility) return ( NoTangent(), # function @@ -313,9 +239,14 @@ function ChainRulesCore.rrule( [PT(Δdata_raw[i]) for i in 1:N_data], # data [PT(Δeval_points_raw[i]) for i in 1:N_eval], # eval_points NoTangent(), # adjl - NoTangent(), # basis + Δbasis, # basis ) end return W, _build_weights_laplacian_pullback end + +# Helper to construct appropriate tangent for different basis types +_make_basis_tangent(::AbstractRadialBasis, Δε) = NoTangent() # Default for PHS +_make_basis_tangent(::Gaussian, Δε) = Tangent{Gaussian}(; ε = Δε, poly_deg = NoTangent()) +_make_basis_tangent(::IMQ, Δε) = Tangent{IMQ}(; ε = Δε, poly_deg = NoTangent()) diff --git a/ext/RadialBasisFunctionsChainRulesCoreExt/interpolation_rules.jl b/ext/RadialBasisFunctionsChainRulesCoreExt/interpolation_rules.jl index 0dcc1da5..2bf78939 100644 --- a/ext/RadialBasisFunctionsChainRulesCoreExt/interpolation_rules.jl +++ b/ext/RadialBasisFunctionsChainRulesCoreExt/interpolation_rules.jl @@ -1,16 +1,99 @@ #= -Differentiation rules for Interpolator evaluation. +Differentiation rules for Interpolator construction and evaluation. 
-The interpolator computes: - f(x) = Σᵢ wᵢ φ(x, xᵢ) + Σⱼ wⱼ pⱼ(x) +Construction: Interpolator(x, y, basis) + Solves A * w = [y; 0] where A is the collocation matrix. + Forward: build A, factor, solve for w + Pullback: Δy = (A⁻ᵀ Δw)[1:k] via adjoint solve -where φ is the RBF kernel and pⱼ are polynomial basis functions. +Evaluation: interp(x) + Computes f(x) = Σᵢ wᵢ φ(x, xᵢ) + Σⱼ wⱼ pⱼ(x) + Pullback returns gradients w.r.t. both evaluation point and weights +=# -For reverse-mode AD, we compute: - ∂f/∂x = Σᵢ wᵢ ∇φ(x, xᵢ) + Σⱼ wⱼ ∇pⱼ(x) +using LinearAlgebra: Symmetric, factorize + +# ============================================================================ +# Interpolator Construction Rule +# ============================================================================ + +""" + rrule for Interpolator construction. + +Differentiates through w = A \\ [y; 0] via adjoint solve. +Backward: Δy = (A⁻ᵀ Δw)[1:k] +""" +function ChainRulesCore.rrule( + ::Type{Interpolator}, + x::AbstractVector, + y::AbstractVector, + basis::AbstractRadialBasis, + ) + # Forward pass: build collocation matrix, factor, solve + dim = length(first(x)) + k = length(x) + npoly = binomial(dim + basis.poly_deg, basis.poly_deg) + n = k + npoly + mon = MonomialBasis(dim, basis.poly_deg) + T = promote_type(eltype(first(x)), eltype(y)) + + A = Symmetric(zeros(T, n, n)) + _build_collocation_matrix!(A, x, basis, mon, k) + + # Factor for reuse in backward pass + A_factored = factorize(A) + b = vcat(y, zeros(T, npoly)) + w = A_factored \ b + + interp = Interpolator(x, y, w[1:k], w[(k + 1):end], basis, mon) + + function interpolator_construction_pullback(Δinterp_raw) + Δinterp = unthunk(Δinterp_raw) + + # Handle NoTangent or ZeroTangent + if Δinterp isa NoTangent || Δinterp isa ZeroTangent + return NoTangent(), NoTangent(), ZeroTangent(), NoTangent() + end -The weights (wᵢ, wⱼ) and data points (xᵢ) are treated as constants. 
-=# + # Extract weight tangents from Interpolator tangent + # For Tangent types, use getproperty which handles backing correctly + Δw_rbf = hasproperty(Δinterp, :rbf_weights) ? getproperty(Δinterp, :rbf_weights) : ZeroTangent() + Δw_mon = hasproperty(Δinterp, :monomial_weights) ? getproperty(Δinterp, :monomial_weights) : ZeroTangent() + + # Handle ZeroTangent for both components + if Δw_rbf isa ZeroTangent && Δw_mon isa ZeroTangent + return NoTangent(), NoTangent(), ZeroTangent(), NoTangent() + end + + # Combine into full Δw vector + Δw_rbf_vec = Δw_rbf isa ZeroTangent ? zeros(T, k) : collect(Δw_rbf) + Δw_mon_vec = Δw_mon isa ZeroTangent ? zeros(T, npoly) : collect(Δw_mon) + Δw = vcat(Δw_rbf_vec, Δw_mon_vec) + + # Adjoint solve: Δb = A⁻ᵀ Δw + Δb = A_factored' \ Δw + + # Extract Δy (first k elements of b correspond to y values) + Δy = Δb[1:k] + + return NoTangent(), NoTangent(), Δy, NoTangent() + end + + return interp, interpolator_construction_pullback +end + +# Convenience wrapper with default basis +function ChainRulesCore.rrule( + ::Type{Interpolator}, + x::AbstractVector, + y::AbstractVector, + ) + return ChainRulesCore.rrule(Interpolator, x, y, PHS()) +end + +# ============================================================================ +# Interpolator Evaluation Rules +# ============================================================================ function ChainRulesCore.rrule(interp::Interpolator, x::AbstractVector) y = interp(x) @@ -18,76 +101,115 @@ function ChainRulesCore.rrule(interp::Interpolator, x::AbstractVector) function interpolator_pullback(Δy) Δy_real = unthunk(Δy) - # Initialize gradient accumulator + # Initialize gradient accumulators Δx = zero(x) + T = eltype(x) # RBF contribution: Σᵢ wᵢ ∇φ(x, xᵢ) + # Gradient w.r.t. 
weights: ∂y/∂wᵢ = φ(x, xᵢ) + k = length(interp.rbf_weights) + Δw_rbf = zeros(T, k) grad_fn = ∇(interp.rbf_basis) for i in eachindex(interp.rbf_weights) + φ_val = interp.rbf_basis(x, interp.x[i]) + Δw_rbf[i] = Δy_real * φ_val ∇φ = grad_fn(x, interp.x[i]) Δx = Δx .+ (interp.rbf_weights[i] * Δy_real) .* ∇φ end # Polynomial contribution: Σⱼ wⱼ ∇pⱼ(x) + # Gradient w.r.t. weights: ∂y/∂wⱼ = pⱼ(x) + n_mon = length(interp.monomial_weights) + Δw_mon = zeros(T, n_mon) if !isempty(interp.monomial_weights) dim = length(x) - n_terms = length(interp.monomial_weights) + + # Get polynomial values at x + poly_vals = zeros(T, n_mon) + interp.monomial_basis(poly_vals, x) # Get the gradient operator for the monomial basis - # ∇(monomial_basis) returns a callable that fills a matrix ∇mon = ∇(interp.monomial_basis) - ∇p = zeros(eltype(x), n_terms, dim) + ∇p = zeros(T, n_mon, dim) ∇mon(∇p, x) - # Accumulate: Σⱼ wⱼ ∇pⱼ(x) for j in eachindex(interp.monomial_weights) + Δw_mon[j] = Δy_real * poly_vals[j] Δx = Δx .+ (interp.monomial_weights[j] * Δy_real) .* view(∇p, j, :) end end - return NoTangent(), Δx + # Build Tangent for Interpolator with weight gradients + Δinterp = Tangent{Interpolator}(; + rbf_weights = Δw_rbf, + monomial_weights = Δw_mon, + ) + + return Δinterp, Δx end return y, interpolator_pullback end # Batch evaluation: interp([x1, x2, ...]) returns [f(x1), f(x2), ...] -# The pullback needs to accumulate gradients for each input point. +# The pullback needs to accumulate gradients for each input point and weights. 
function ChainRulesCore.rrule(interp::Interpolator, xs::Vector{<:AbstractVector}) ys = interp(xs) function interpolator_batch_pullback(Δys) Δys_real = unthunk(Δys) + T = eltype(first(xs)) + + # Initialize weight gradient accumulators + k_rbf = length(interp.rbf_weights) + n_mon = length(interp.monomial_weights) + Δw_rbf = zeros(T, k_rbf) + Δw_mon = zeros(T, n_mon) # Compute gradient for each input point Δxs = similar(xs) + grad_fn = ∇(interp.rbf_basis) + for (i, x) in enumerate(xs) Δx = zero(x) + Δy_i = Δys_real[i] # RBF contribution - grad_fn = ∇(interp.rbf_basis) for j in eachindex(interp.rbf_weights) + φ_val = interp.rbf_basis(x, interp.x[j]) + Δw_rbf[j] += Δy_i * φ_val ∇φ = grad_fn(x, interp.x[j]) - Δx = Δx .+ (interp.rbf_weights[j] * Δys_real[i]) .* ∇φ + Δx = Δx .+ (interp.rbf_weights[j] * Δy_i) .* ∇φ end # Polynomial contribution if !isempty(interp.monomial_weights) dim = length(x) - n_terms = length(interp.monomial_weights) + + # Get polynomial values at x + poly_vals = zeros(T, n_mon) + interp.monomial_basis(poly_vals, x) + ∇mon = ∇(interp.monomial_basis) - ∇p = zeros(eltype(x), n_terms, dim) + ∇p = zeros(T, n_mon, dim) ∇mon(∇p, x) - for k in eachindex(interp.monomial_weights) - Δx = Δx .+ (interp.monomial_weights[k] * Δys_real[i]) .* view(∇p, k, :) + for j in eachindex(interp.monomial_weights) + Δw_mon[j] += Δy_i * poly_vals[j] + Δx = Δx .+ (interp.monomial_weights[j] * Δy_i) .* view(∇p, j, :) end end Δxs[i] = Δx end - return NoTangent(), Δxs + # Build Tangent for Interpolator with weight gradients + Δinterp = Tangent{Interpolator}(; + rbf_weights = Δw_rbf, + monomial_weights = Δw_mon, + ) + + return Δinterp, Δxs end return ys, interpolator_batch_pullback diff --git a/ext/RadialBasisFunctionsChainRulesCoreExt/operator_rules.jl b/ext/RadialBasisFunctionsChainRulesCoreExt/operator_rules.jl index 6429def8..99834d80 100644 --- a/ext/RadialBasisFunctionsChainRulesCoreExt/operator_rules.jl +++ b/ext/RadialBasisFunctionsChainRulesCoreExt/operator_rules.jl @@ 
-65,3 +65,46 @@ function ChainRulesCore.rrule( return result, _eval_op_inplace_pullback end + +# ============================================================================= +# rrules for operator call syntax: op(x) +# ============================================================================= +# These rules handle the (op::RadialBasisOperator)(x) call directly, +# bypassing the cache check which can cause issues with some AD backends. + +# Scalar-valued operator call: op(x) +function ChainRulesCore.rrule(op::RadialBasisOperator, x::AbstractVector) + # Ensure weights are computed + !RadialBasisFunctions.is_cache_valid(op) && RadialBasisFunctions.update_weights!(op) + y = _eval_op(op, x) + + function op_call_pullback(Δy) + Δy_unthunked = unthunk(Δy) + Δx = op.weights' * Δy_unthunked + return NoTangent(), Δx + end + + return y, op_call_pullback +end + +# Vector-valued operator call: op(x) for gradient/jacobian +function ChainRulesCore.rrule( + op::RadialBasisOperator{<:VectorValuedOperator{D}}, + x::AbstractVector, + ) where {D} + # Ensure weights are computed + !RadialBasisFunctions.is_cache_valid(op) && RadialBasisFunctions.update_weights!(op) + y = _eval_op(op, x) + + function op_call_vector_pullback(Δy) + Δy_unthunked = unthunk(Δy) + Δx = similar(x) + fill!(Δx, zero(eltype(Δx))) + for d in 1:D + Δx .+= op.weights[d]' * view(Δy_unthunked, :, d) + end + return NoTangent(), Δx + end + + return y, op_call_vector_pullback +end diff --git a/ext/RadialBasisFunctionsEnzymeExt/RadialBasisFunctionsEnzymeExt.jl b/ext/RadialBasisFunctionsEnzymeExt/RadialBasisFunctionsEnzymeExt.jl new file mode 100644 index 00000000..377130d7 --- /dev/null +++ b/ext/RadialBasisFunctionsEnzymeExt/RadialBasisFunctionsEnzymeExt.jl @@ -0,0 +1,1004 @@ +""" + RadialBasisFunctionsEnzymeExt + +Package extension that provides native Enzyme.jl AD support for RadialBasisFunctions.jl +using EnzymeRules (augmented_primal + reverse). + +This extension requires Enzyme to be loaded. 
+ +Native rules are provided for: +- Basis function evaluation: PHS1, PHS3, PHS5, PHS7, IMQ, Gaussian +- Operator evaluation: `_eval_op(op, x)` for scalar and vector-valued operators +- Operator call syntax: `op(x)` +- Interpolator evaluation: single point and batch +- Weight construction: `_build_weights` for Partial and Laplacian operators +""" +module RadialBasisFunctionsEnzymeExt + +using RadialBasisFunctions +using Enzyme +using EnzymeCore +using EnzymeCore.EnzymeRules +using LinearAlgebra +using SparseArrays + +# Import internal functions +import RadialBasisFunctions: _eval_op, RadialBasisOperator, Interpolator +import RadialBasisFunctions: PHS1, PHS3, PHS5, PHS7, IMQ, Gaussian +import RadialBasisFunctions: AbstractRadialBasis, VectorValuedOperator +import RadialBasisFunctions: _build_weights, Partial, Laplacian +import RadialBasisFunctions: MonomialBasis + +# Import backward pass support from main package +import RadialBasisFunctions: StencilForwardCache, WeightsBuildForwardCache +import RadialBasisFunctions: backward_stencil_partial!, backward_stencil_laplacian! +import RadialBasisFunctions: backward_stencil_partial_with_ε!, backward_stencil_laplacian_with_ε! 
+import RadialBasisFunctions: _forward_with_cache +import RadialBasisFunctions: grad_applied_partial_wrt_x, grad_applied_partial_wrt_xi +import RadialBasisFunctions: grad_applied_laplacian_wrt_x, grad_applied_laplacian_wrt_xi + +# Import gradient function +const ∇ = RadialBasisFunctions.∇ + +# ============================================================================= +# Basis Function Rules +# ============================================================================= + +# Helper macro to define basis function rules for a given type +macro define_basis_rule(BasisType) + return quote + # Both x and xi are Duplicated (both being differentiated) + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{<:$(esc(BasisType))}, + ::Type{<:EnzymeCore.Active}, + x::EnzymeCore.Duplicated, + xi::EnzymeCore.Duplicated, + ) + basis = func.val + y = basis(x.val, xi.val) + tape = (copy(x.val), copy(xi.val)) + return EnzymeRules.AugmentedReturn(y, nothing, tape) + end + + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{<:$(esc(BasisType))}, + dret::EnzymeCore.Active, + tape, + x::EnzymeCore.Duplicated, + xi::EnzymeCore.Duplicated, + ) + basis = func.val + x_val, xi_val = tape + grad_fn = ∇(basis) + ∇φ = grad_fn(x_val, xi_val) + x.dval .+= dret.val .* ∇φ + xi.dval .-= dret.val .* ∇φ + return (nothing, nothing) + end + + # x is Duplicated, xi is Const (xi captured in closure) + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{<:$(esc(BasisType))}, + ::Type{<:EnzymeCore.Active}, + x::EnzymeCore.Duplicated, + xi::EnzymeCore.Const, + ) + basis = func.val + y = basis(x.val, xi.val) + tape = (copy(x.val), copy(xi.val)) + return EnzymeRules.AugmentedReturn(y, nothing, tape) + end + + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{<:$(esc(BasisType))}, + dret::EnzymeCore.Active, + tape, 
+ x::EnzymeCore.Duplicated, + xi::EnzymeCore.Const, + ) + basis = func.val + x_val, xi_val = tape + grad_fn = ∇(basis) + ∇φ = grad_fn(x_val, xi_val) + x.dval .+= dret.val .* ∇φ + return (nothing, nothing) + end + end +end + +@define_basis_rule PHS1 +@define_basis_rule PHS3 +@define_basis_rule PHS5 +@define_basis_rule PHS7 +@define_basis_rule IMQ +@define_basis_rule Gaussian + +# ============================================================================= +# Operator Evaluation Rules: _eval_op(op, x) +# ============================================================================= + +# Scalar-valued operator: y = W * x +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_eval_op)}, + ::Type{RT}, + op::EnzymeCore.Const{<:RadialBasisOperator}, + x::EnzymeCore.Duplicated, + ) where {RT} + y = _eval_op(op.val, x.val) + shadow = RT <: EnzymeCore.Duplicated ? zero(y) : nothing + return EnzymeRules.AugmentedReturn(y, shadow, (op.val, shadow)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_eval_op)}, + dret, + tape, + op::EnzymeCore.Const{<:RadialBasisOperator}, + x::EnzymeCore.Duplicated, + ) + operator, shadow = tape + dy = dret isa EnzymeCore.Active ? dret.val : shadow + if dy !== nothing + x.dval .+= operator.weights' * dy + shadow !== nothing && fill!(shadow, 0) + end + return (nothing, nothing) +end + +# Vector-valued operator: y[:,d] = W[d] * x +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_eval_op)}, + ::Type{RT}, + op::EnzymeCore.Const{<:RadialBasisOperator{<:VectorValuedOperator{D}}}, + x::EnzymeCore.Duplicated, + ) where {D, RT} + y = _eval_op(op.val, x.val) + shadow = RT <: EnzymeCore.Duplicated ? 
zero(y) : nothing + return EnzymeRules.AugmentedReturn(y, shadow, (op.val, D, shadow)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_eval_op)}, + dret, + tape, + op::EnzymeCore.Const{<:RadialBasisOperator{<:VectorValuedOperator}}, + x::EnzymeCore.Duplicated, + ) + operator, D, shadow = tape + dy = dret isa EnzymeCore.Active ? dret.val : shadow + if dy !== nothing + for d in 1:D + x.dval .+= operator.weights[d]' * view(dy, :, d) + end + shadow !== nothing && fill!(shadow, 0) + end + return (nothing, nothing) +end + +# ============================================================================= +# Operator Call Rules: op(x) +# ============================================================================= + +# Scalar-valued operator call +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + op::EnzymeCore.Const{<:RadialBasisOperator}, + ::Type{RT}, + x::EnzymeCore.Duplicated, + ) where {RT} + operator = op.val + y = operator.weights * x.val + shadow = RT <: EnzymeCore.Duplicated ? zero(y) : nothing + return EnzymeRules.AugmentedReturn(y, shadow, (operator, shadow)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + op::EnzymeCore.Const{<:RadialBasisOperator}, + dret, + tape, + x::EnzymeCore.Duplicated, + ) + operator, shadow = tape + dy = dret isa EnzymeCore.Active ? dret.val : shadow + if dy !== nothing + x.dval .+= operator.weights' * dy + shadow !== nothing && fill!(shadow, 0) + end + return (nothing,) +end + +# Vector-valued operator call +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + op::EnzymeCore.Const{<:RadialBasisOperator{<:VectorValuedOperator{D}}}, + ::Type{RT}, + x::EnzymeCore.Duplicated, + ) where {D, RT} + operator = op.val + y = _eval_op(operator, x.val) + shadow = RT <: EnzymeCore.Duplicated ? 
zero(y) : nothing + return EnzymeRules.AugmentedReturn(y, shadow, (operator, D, shadow)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + op::EnzymeCore.Const{<:RadialBasisOperator{<:VectorValuedOperator}}, + dret, + tape, + x::EnzymeCore.Duplicated, + ) + operator, D, shadow = tape + dy = dret isa EnzymeCore.Active ? dret.val : shadow + if dy !== nothing + for d in 1:D + x.dval .+= operator.weights[d]' * view(dy, :, d) + end + shadow !== nothing && fill!(shadow, 0) + end + return (nothing,) +end + +# ============================================================================= +# Interpolator Rules +# ============================================================================= + +# Single point evaluation +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + interp::EnzymeCore.Const{<:Interpolator}, + ::Type{<:EnzymeCore.Active}, + x::EnzymeCore.Duplicated, + ) + y = interp.val(x.val) + tape = (interp.val, copy(x.val)) + return EnzymeRules.AugmentedReturn(y, nothing, tape) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + interp_const::EnzymeCore.Const{<:Interpolator}, + dret::EnzymeCore.Active, + tape, + x::EnzymeCore.Duplicated, + ) + interp, x_val = tape + Δy = dret.val + + # RBF contribution: Σᵢ wᵢ ∇φ(x, xᵢ) + grad_fn = ∇(interp.rbf_basis) + for i in eachindex(interp.rbf_weights) + ∇φ = grad_fn(x_val, interp.x[i]) + x.dval .+= (interp.rbf_weights[i] * Δy) .* ∇φ + end + + # Polynomial contribution: Σⱼ wⱼ ∇pⱼ(x) + if !isempty(interp.monomial_weights) + dim = length(x_val) + n_terms = length(interp.monomial_weights) + ∇mon = ∇(interp.monomial_basis) + ∇p = zeros(eltype(x_val), n_terms, dim) + ∇mon(∇p, x_val) + + for j in eachindex(interp.monomial_weights) + x.dval .+= (interp.monomial_weights[j] * Δy) .* view(∇p, j, :) + end + end + + return (nothing,) +end + +# Batch evaluation +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + 
interp::EnzymeCore.Const{<:Interpolator}, + ::Type{<:EnzymeCore.Active}, + xs::EnzymeCore.Duplicated{<:Vector{<:AbstractVector}}, + ) + ys = interp.val(xs.val) + tape = (interp.val, deepcopy(xs.val)) + return EnzymeRules.AugmentedReturn(ys, nothing, tape) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + interp_const::EnzymeCore.Const{<:Interpolator}, + dret::EnzymeCore.Active, + tape, + xs::EnzymeCore.Duplicated{<:Vector{<:AbstractVector}}, + ) + interp, xs_val = tape + Δys = dret.val + + for (i, x_val) in enumerate(xs_val) + Δy = Δys[i] + + # RBF contribution + grad_fn = ∇(interp.rbf_basis) + for j in eachindex(interp.rbf_weights) + ∇φ = grad_fn(x_val, interp.x[j]) + xs.dval[i] = xs.dval[i] + (interp.rbf_weights[j] * Δy) .* ∇φ + end + + # Polynomial contribution + if !isempty(interp.monomial_weights) + dim = length(x_val) + n_terms = length(interp.monomial_weights) + ∇mon = ∇(interp.monomial_basis) + ∇p = zeros(eltype(x_val), n_terms, dim) + ∇mon(∇p, x_val) + + for k in eachindex(interp.monomial_weights) + xs.dval[i] = xs.dval[i] + (interp.monomial_weights[k] * Δy) .* view(∇p, k, :) + end + end + end + + return (nothing,) +end + +# ============================================================================= +# _build_weights Rules for Shape Optimization +# ============================================================================= + +# Helper to materialize sparse matrix tangent for Enzyme +function materialize_sparse_tangent_enzyme(ΔW, W::SparseMatrixCSC) + if ΔW isa SparseMatrixCSC + return ΔW + end + # If ΔW is dense or some other form, convert appropriately + return SparseMatrixCSC(W.m, W.n, copy(W.colptr), copy(W.rowval), copy(nonzeros(W))) +end + +# Extract stencil cotangent from sparse matrix +function extract_stencil_cotangent_enzyme( + ΔW::AbstractMatrix{T}, eval_idx::Int, neighbors::Vector{Int}, k::Int, num_ops::Int + ) where {T} + Δw = zeros(T, k, num_ops) + for (local_idx, global_idx) in enumerate(neighbors) + 
Δw[local_idx, 1] = ΔW[eval_idx, global_idx] + end + return Δw +end + +# ============================================================================= +# Unified _build_weights rule generation via macro +# ============================================================================= + +# Helper functions to get gradient functions and call backward stencil for each operator type +_get_grad_funcs(::Type{<:Partial}, basis, ℒ) = ( + grad_applied_partial_wrt_x(basis, ℒ.dim), + grad_applied_partial_wrt_xi(basis, ℒ.dim), +) +_get_grad_funcs(::Type{<:Laplacian}, basis, ℒ) = ( + grad_applied_laplacian_wrt_x(basis), + grad_applied_laplacian_wrt_xi(basis), +) + +# Backward stencil dispatch (without ε) +function _call_backward_stencil!( + ::Type{<:Partial}, Δlocal_data, Δeval_pt, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, ℒ, grad_Lφ_x, grad_Lφ_xi + ) + return backward_stencil_partial!( + Δlocal_data, Δeval_pt, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, ℒ.dim, grad_Lφ_x, grad_Lφ_xi + ) +end + +function _call_backward_stencil!( + ::Type{<:Laplacian}, Δlocal_data, Δeval_pt, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, ℒ, grad_Lφ_x, grad_Lφ_xi + ) + return backward_stencil_laplacian!( + Δlocal_data, Δeval_pt, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi + ) +end + +# Backward stencil dispatch (with ε) +function _call_backward_stencil_with_ε!( + ::Type{<:Partial}, Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, ℒ, grad_Lφ_x, grad_Lφ_xi + ) + return backward_stencil_partial_with_ε!( + Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, ℒ.dim, grad_Lφ_x, grad_Lφ_xi + ) +end + +function _call_backward_stencil_with_ε!( + ::Type{<:Laplacian}, Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, ℒ, grad_Lφ_x, 
grad_Lφ_xi + ) + return backward_stencil_laplacian_with_ε!( + Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, neighbors, + eval_point, local_data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi + ) +end + +""" +Generate augmented_primal and reverse rules for _build_weights with different argument activities. + +Arguments: +- OpType: Partial or Laplacian +- data_activity: :Duplicated or :Const +- basis_activity: :Const, :Active, or :Duplicated +""" +macro define_build_weights_rule(OpType, data_activity, basis_activity) + # Determine type annotations for signature + data_type = data_activity == :Duplicated ? :(EnzymeCore.Duplicated) : :(EnzymeCore.Const) + eval_type = data_activity == :Duplicated ? :(EnzymeCore.Duplicated) : :(EnzymeCore.Const) + + if basis_activity == :Const + basis_type = :(EnzymeCore.Const{<:AbstractRadialBasis}) + basis_type_param = nothing + elseif basis_activity == :Active + basis_type = :(EnzymeCore.Active{B}) + basis_type_param = :(where{B <: AbstractRadialBasis}) + else # :Duplicated + basis_type = :(EnzymeCore.Duplicated{B}) + basis_type_param = :(where{B <: AbstractRadialBasis}) + end + + # Determine if we need RT type parameter for shadow allocation + needs_rt = basis_activity != :Const || data_activity == :Const + rt_param = needs_rt ? :(::Type{RT}) : :(::Type{<:EnzymeCore.Active}) + rt_where = needs_rt ? 
:(where{RT}) : nothing + + # Build augmented_primal signature + aug_sig = if basis_activity == :Const && !needs_rt + quote + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + $rt_param, + ℒ_arg::EnzymeCore.Const{<:$OpType}, + data::$data_type, + eval_points::$eval_type, + adjl::EnzymeCore.Const, + basis::$basis_type, + ) + end + end + else + quote + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + $rt_param, + ℒ_arg::EnzymeCore.Const{<:$OpType}, + data::$data_type, + eval_points::$eval_type, + adjl::EnzymeCore.Const, + basis::$basis_type, + ) + return $rt_where $ (basis_type_param...) + end + end + end + + # Generate the actual code + return quote + # Augmented primal + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + $(needs_rt ? :(::Type{RT}) : :(::Type{<:EnzymeCore.Active})), + ℒ_arg::EnzymeCore.Const{<:$(esc(OpType))}, + data::$(esc(data_type)), + eval_points::$(esc(eval_type)), + adjl::EnzymeCore.Const, + basis::$(esc(basis_type)), + ) + $( + needs_rt && basis_activity != :Const ? :(where{RT, B <: AbstractRadialBasis}) : + needs_rt ? :(where{RT}) : + basis_activity != :Const ? 
:(where{B <: AbstractRadialBasis}) : nothing + ) + + op_val = ℒ_arg.val + data_val = data.val + eval_points_val = eval_points.val + adjl_val = adjl.val + basis_val = basis.val + + dim_space = length(first(data_val)) + mon = MonomialBasis(dim_space, basis_val.poly_deg) + op_mon = op_val(mon) + op_rbf = op_val(basis_val) + + W, cache = _forward_with_cache( + data_val, eval_points_val, adjl_val, basis_val, op_rbf, op_mon, mon, $(esc(OpType)) + ) + + $( + if basis_activity == :Active + quote + shadow = _make_shadow_for_return(RT, W) + tape = (op_val, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow) + return EnzymeRules.AugmentedReturn(W, shadow, tape) + end + elseif data_activity == :Duplicated + quote + tape = (op_val, cache, adjl_val, basis_val, mon, deepcopy(data_val), deepcopy(eval_points_val)) + return EnzymeRules.AugmentedReturn(W, nothing, tape) + end + else # Const data, Const or Duplicated basis + quote + tape = (op_val, cache, adjl_val, basis_val, mon, data_val, eval_points_val) + return EnzymeRules.AugmentedReturn(W, nothing, tape) + end + end + ) + end + + # Reverse pass + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + $(basis_activity == :Active ? :dret : :(dret::EnzymeCore.Active)), + tape, + ℒ_arg::EnzymeCore.Const{<:$(esc(OpType))}, + data::$(esc(data_type)), + eval_points::$(esc(eval_type)), + adjl::EnzymeCore.Const, + basis::$(esc(basis_type)), + ) + $(basis_activity != :Const ? 
:(where{B <: AbstractRadialBasis}) : nothing) + + $( + if basis_activity == :Active + :((op_cached, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow) = tape) + else + :((op_cached, cache, adjl_val, basis_val, mon, data_val, eval_points_val) = tape) + end + ) + + $( + if basis_activity == :Active + :(ΔW = _extract_dret_with_shadow(dret, shadow)) + else + :(ΔW = dret.val) + end + ) + + TD = eltype(first(data_val)) + N_eval = length(eval_points_val) + k = cache.k + + grad_Lφ_x, grad_Lφ_xi = _get_grad_funcs($(esc(OpType)), basis_val, op_cached) + + $( + if basis_activity != :Const + :(Δε_acc = Ref(zero(TD))) + else + nothing + end + ) + + for eval_idx in 1:N_eval + neighbors = adjl_val[eval_idx] + eval_point = eval_points_val[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = extract_stencil_cotangent_enzyme(ΔW, eval_idx, neighbors, k, cache.num_ops) + + if sum(abs, Δw) > 0 + local_data = [data_val[i] for i in neighbors] + Δlocal_data = [zeros(TD, length(first(data_val))) for _ in 1:k] + Δeval_pt = zeros(TD, length(eval_point)) + + $( + if basis_activity == :Const + quote + _call_backward_stencil!( + $(esc(OpType)), Δlocal_data, Δeval_pt, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis_val, mon, k, op_cached, grad_Lφ_x, grad_Lφ_xi + ) + end + else + quote + _call_backward_stencil_with_ε!( + $(esc(OpType)), Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis_val, mon, k, op_cached, grad_Lφ_x, grad_Lφ_xi + ) + end + end + ) + + $( + if data_activity == :Duplicated + quote + for (local_idx, global_idx) in enumerate(neighbors) + data.dval[global_idx] = data.dval[global_idx] + Δlocal_data[local_idx] + end + eval_points.dval[eval_idx] = eval_points.dval[eval_idx] + Δeval_pt + end + else + nothing + end + ) + end + end + + $( + if basis_activity == :Active + quote + Δbasis = _make_enzyme_tangent(B, basis_val, Δε_acc[]) + return (nothing, nothing, nothing, nothing, Δbasis) + end + 
elseif basis_activity == :Duplicated + quote + _accumulate_basis_gradient!(basis.dval, Δε_acc[]) + return (nothing, nothing, nothing, nothing, nothing) + end + else + :(return (nothing, nothing, nothing, nothing, nothing)) + end + ) + end + end +end + +# ============================================================================= +# Helper functions (must be defined before macro invocation) +# ============================================================================= + +# For Duplicated return types, we need to allocate a shadow matrix +function _make_shadow_for_return(::Type{<:EnzymeCore.Duplicated}, W::SparseMatrixCSC) + return SparseMatrixCSC(W.m, W.n, copy(W.colptr), copy(W.rowval), zeros(eltype(W), length(W.nzval))) +end +_make_shadow_for_return(::Type, _W) = nothing + +# Helper to extract cotangent from dret (differs between Active and Duplicated return) +_extract_dret_with_shadow(dret::EnzymeCore.Active, _shadow) = dret.val +_extract_dret_with_shadow(::Type, shadow::AbstractMatrix) = shadow +_extract_dret_with_shadow(::Type, ::Nothing) = nothing + +# Helper to construct Enzyme tangent for basis types +_make_enzyme_tangent(::Type{<:AbstractRadialBasis}, _basis, _Δε) = nothing # PHS has no ε + +function _make_enzyme_tangent(::Type{Gaussian{T, D}}, _basis::Gaussian{T, D}, Δε) where {T, D} + return Gaussian(convert(T, Δε); poly_deg = D(0)) +end + +function _make_enzyme_tangent(::Type{IMQ{T, D}}, _basis::IMQ{T, D}, Δε) where {T, D} + return IMQ(convert(T, Δε); poly_deg = D(0)) +end + +# Helper to accumulate gradient into basis shadow +_accumulate_basis_gradient!(::Gaussian{T}, _Δε) where {T <: Number} = nothing +_accumulate_basis_gradient!(::IMQ{T}, _Δε) where {T <: Number} = nothing + +function _accumulate_basis_gradient!(shadow::Gaussian{T}, Δε) where {T <: AbstractVector} + shadow.ε[1] += Δε + return nothing +end + +function _accumulate_basis_gradient!(shadow::IMQ{T}, Δε) where {T <: AbstractVector} + shadow.ε[1] += Δε + return nothing +end + 
+_accumulate_basis_gradient!(_shadow, _Δε) = nothing + +# ============================================================================= +# Explicit rules (replacing macro due to Julia version compatibility issues) +# ============================================================================= + +# Partial with Duplicated data, Const basis +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + ::Type{RT}, + ℒ_arg::EnzymeCore.Const{<:Partial}, + data::EnzymeCore.Duplicated, + eval_points::EnzymeCore.Duplicated, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Const{<:AbstractRadialBasis}, + ) where {RT} + op_val = ℒ_arg.val + data_val = data.val + eval_points_val = eval_points.val + adjl_val = adjl.val + basis_val = basis.val + + dim_space = length(first(data_val)) + mon = MonomialBasis(dim_space, basis_val.poly_deg) + op_mon = op_val(mon) + op_rbf = op_val(basis_val) + + W, cache = _forward_with_cache( + data_val, eval_points_val, adjl_val, basis_val, op_rbf, op_mon, mon, Partial + ) + + shadow = _make_shadow_for_return(RT, W) + tape = (op_val, cache, adjl_val, basis_val, mon, deepcopy(data_val), deepcopy(eval_points_val), shadow) + return EnzymeRules.AugmentedReturn(W, shadow, tape) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + dret, + tape, + ℒ_arg::EnzymeCore.Const{<:Partial}, + data::EnzymeCore.Duplicated, + eval_points::EnzymeCore.Duplicated, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Const{<:AbstractRadialBasis}, + ) + op_cached, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow = tape + ΔW = _extract_dret_with_shadow(dret, shadow) + + TD = eltype(first(data_val)) + N_eval = length(eval_points_val) + k = cache.k + + grad_Lφ_x, grad_Lφ_xi = _get_grad_funcs(Partial, basis_val, op_cached) + + for eval_idx in 1:N_eval + neighbors = adjl_val[eval_idx] + eval_point = 
eval_points_val[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = extract_stencil_cotangent_enzyme(ΔW, eval_idx, neighbors, k, cache.num_ops) + + if sum(abs, Δw) > 0 + local_data = [data_val[i] for i in neighbors] + Δlocal_data = [zeros(TD, length(first(data_val))) for _ in 1:k] + Δeval_pt = zeros(TD, length(eval_point)) + + _call_backward_stencil!( + Partial, Δlocal_data, Δeval_pt, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis_val, mon, k, op_cached, grad_Lφ_x, grad_Lφ_xi + ) + + for (local_idx, global_idx) in enumerate(neighbors) + data.dval[global_idx] = data.dval[global_idx] + Δlocal_data[local_idx] + end + eval_points.dval[eval_idx] = eval_points.dval[eval_idx] + Δeval_pt + end + end + + return (nothing, nothing, nothing, nothing, nothing) +end + +# Laplacian with Duplicated data, Const basis +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + ::Type{RT}, + ℒ_arg::EnzymeCore.Const{<:Laplacian}, + data::EnzymeCore.Duplicated, + eval_points::EnzymeCore.Duplicated, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Const{<:AbstractRadialBasis}, + ) where {RT} + op_val = ℒ_arg.val + data_val = data.val + eval_points_val = eval_points.val + adjl_val = adjl.val + basis_val = basis.val + + dim_space = length(first(data_val)) + mon = MonomialBasis(dim_space, basis_val.poly_deg) + op_mon = op_val(mon) + op_rbf = op_val(basis_val) + + W, cache = _forward_with_cache( + data_val, eval_points_val, adjl_val, basis_val, op_rbf, op_mon, mon, Laplacian + ) + + shadow = _make_shadow_for_return(RT, W) + tape = (op_val, cache, adjl_val, basis_val, mon, deepcopy(data_val), deepcopy(eval_points_val), shadow) + return EnzymeRules.AugmentedReturn(W, shadow, tape) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + dret, + tape, + ℒ_arg::EnzymeCore.Const{<:Laplacian}, + 
data::EnzymeCore.Duplicated, + eval_points::EnzymeCore.Duplicated, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Const{<:AbstractRadialBasis}, + ) + op_cached, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow = tape + ΔW = _extract_dret_with_shadow(dret, shadow) + + TD = eltype(first(data_val)) + N_eval = length(eval_points_val) + k = cache.k + + grad_Lφ_x, grad_Lφ_xi = _get_grad_funcs(Laplacian, basis_val, op_cached) + + for eval_idx in 1:N_eval + neighbors = adjl_val[eval_idx] + eval_point = eval_points_val[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = extract_stencil_cotangent_enzyme(ΔW, eval_idx, neighbors, k, cache.num_ops) + + if sum(abs, Δw) > 0 + local_data = [data_val[i] for i in neighbors] + Δlocal_data = [zeros(TD, length(first(data_val))) for _ in 1:k] + Δeval_pt = zeros(TD, length(eval_point)) + + _call_backward_stencil!( + Laplacian, Δlocal_data, Δeval_pt, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis_val, mon, k, op_cached, grad_Lφ_x, grad_Lφ_xi + ) + + for (local_idx, global_idx) in enumerate(neighbors) + data.dval[global_idx] = data.dval[global_idx] + Δlocal_data[local_idx] + end + eval_points.dval[eval_idx] = eval_points.dval[eval_idx] + Δeval_pt + end + end + + return (nothing, nothing, nothing, nothing, nothing) +end + +# Partial with Const data, Active basis (for shape parameter) +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + ::Type{RT}, + ℒ_arg::EnzymeCore.Const{<:Partial}, + data::EnzymeCore.Const, + eval_points::EnzymeCore.Const, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Active{B}, + ) where {RT, B <: AbstractRadialBasis} + op_val = ℒ_arg.val + data_val = data.val + eval_points_val = eval_points.val + adjl_val = adjl.val + basis_val = basis.val + + dim_space = length(first(data_val)) + mon = MonomialBasis(dim_space, basis_val.poly_deg) + op_mon = op_val(mon) + op_rbf = op_val(basis_val) + + W, 
cache = _forward_with_cache( + data_val, eval_points_val, adjl_val, basis_val, op_rbf, op_mon, mon, Partial + ) + + shadow = _make_shadow_for_return(RT, W) + tape = (op_val, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow) + return EnzymeRules.AugmentedReturn(W, shadow, tape) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + dret, + tape, + ℒ_arg::EnzymeCore.Const{<:Partial}, + data::EnzymeCore.Const, + eval_points::EnzymeCore.Const, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Active{B}, + ) where {B <: AbstractRadialBasis} + op_cached, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow = tape + ΔW = _extract_dret_with_shadow(dret, shadow) + + TD = eltype(first(data_val)) + N_eval = length(eval_points_val) + k = cache.k + + grad_Lφ_x, grad_Lφ_xi = _get_grad_funcs(Partial, basis_val, op_cached) + Δε_acc = Ref(zero(TD)) + + for eval_idx in 1:N_eval + neighbors = adjl_val[eval_idx] + eval_point = eval_points_val[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = extract_stencil_cotangent_enzyme(ΔW, eval_idx, neighbors, k, cache.num_ops) + + if sum(abs, Δw) > 0 + local_data = [data_val[i] for i in neighbors] + Δlocal_data = [zeros(TD, length(first(data_val))) for _ in 1:k] + Δeval_pt = zeros(TD, length(eval_point)) + + _call_backward_stencil_with_ε!( + Partial, Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis_val, mon, k, op_cached, grad_Lφ_x, grad_Lφ_xi + ) + end + end + + Δbasis = _make_enzyme_tangent(B, basis_val, Δε_acc[]) + return (nothing, nothing, nothing, nothing, Δbasis) +end + +# Laplacian with Const data, Active basis (for shape parameter) +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + ::Type{RT}, + ℒ_arg::EnzymeCore.Const{<:Laplacian}, + data::EnzymeCore.Const, + eval_points::EnzymeCore.Const, 
+ adjl::EnzymeCore.Const, + basis::EnzymeCore.Active{B}, + ) where {RT, B <: AbstractRadialBasis} + op_val = ℒ_arg.val + data_val = data.val + eval_points_val = eval_points.val + adjl_val = adjl.val + basis_val = basis.val + + dim_space = length(first(data_val)) + mon = MonomialBasis(dim_space, basis_val.poly_deg) + op_mon = op_val(mon) + op_rbf = op_val(basis_val) + + W, cache = _forward_with_cache( + data_val, eval_points_val, adjl_val, basis_val, op_rbf, op_mon, mon, Laplacian + ) + + shadow = _make_shadow_for_return(RT, W) + tape = (op_val, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow) + return EnzymeRules.AugmentedReturn(W, shadow, tape) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::EnzymeCore.Const{typeof(_build_weights)}, + dret, + tape, + ℒ_arg::EnzymeCore.Const{<:Laplacian}, + data::EnzymeCore.Const, + eval_points::EnzymeCore.Const, + adjl::EnzymeCore.Const, + basis::EnzymeCore.Active{B}, + ) where {B <: AbstractRadialBasis} + op_cached, cache, adjl_val, basis_val, mon, data_val, eval_points_val, shadow = tape + ΔW = _extract_dret_with_shadow(dret, shadow) + + TD = eltype(first(data_val)) + N_eval = length(eval_points_val) + k = cache.k + + grad_Lφ_x, grad_Lφ_xi = _get_grad_funcs(Laplacian, basis_val, op_cached) + Δε_acc = Ref(zero(TD)) + + for eval_idx in 1:N_eval + neighbors = adjl_val[eval_idx] + eval_point = eval_points_val[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = extract_stencil_cotangent_enzyme(ΔW, eval_idx, neighbors, k, cache.num_ops) + + if sum(abs, Δw) > 0 + local_data = [data_val[i] for i in neighbors] + Δlocal_data = [zeros(TD, length(first(data_val))) for _ in 1:k] + Δeval_pt = zeros(TD, length(eval_point)) + + _call_backward_stencil_with_ε!( + Laplacian, Δlocal_data, Δeval_pt, Δε_acc, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis_val, mon, k, op_cached, grad_Lφ_x, grad_Lφ_xi + ) + end + end + + Δbasis = _make_enzyme_tangent(B, 
basis_val, Δε_acc[]) + return (nothing, nothing, nothing, nothing, Δbasis) +end + +end # module diff --git a/ext/RadialBasisFunctionsMooncakeExt/RadialBasisFunctionsMooncakeExt.jl b/ext/RadialBasisFunctionsMooncakeExt/RadialBasisFunctionsMooncakeExt.jl index 31775e0f..2f9e47fe 100644 --- a/ext/RadialBasisFunctionsMooncakeExt/RadialBasisFunctionsMooncakeExt.jl +++ b/ext/RadialBasisFunctionsMooncakeExt/RadialBasisFunctionsMooncakeExt.jl @@ -44,6 +44,41 @@ function Mooncake.increment_and_get_rdata!( return Mooncake.NoRData() end +# ============================================================================= +# increment_and_get_rdata! for Gaussian/IMQ Tangent types from ChainRulesCore +# ============================================================================= +# When rrules return ChainRulesCore.Tangent{Gaussian/IMQ,...}, Mooncake needs to +# know how to accumulate these into its internal RData representation. + +# Gaussian tangent: extract ε from ChainRulesCore.Tangent and add to RData +function Mooncake.increment_and_get_rdata!( + ::Mooncake.NoFData, + r::Mooncake.RData{@NamedTuple{ε::T, poly_deg::Mooncake.NoRData}}, + t::ChainRulesCore.Tangent{<:Gaussian}, + ) where {T} + # Extract ε from ChainRulesCore tangent and add to Mooncake RData + Δε = t.ε + if !(Δε isa ChainRulesCore.NoTangent) && !(Δε isa ChainRulesCore.ZeroTangent) + new_ε = r.data.ε + T(Δε) + return Mooncake.RData{@NamedTuple{ε::T, poly_deg::Mooncake.NoRData}}((ε = new_ε, poly_deg = Mooncake.NoRData())) + end + return r +end + +# IMQ tangent: same pattern as Gaussian +function Mooncake.increment_and_get_rdata!( + ::Mooncake.NoFData, + r::Mooncake.RData{@NamedTuple{ε::T, poly_deg::Mooncake.NoRData}}, + t::ChainRulesCore.Tangent{<:IMQ}, + ) where {T} + Δε = t.ε + if !(Δε isa ChainRulesCore.NoTangent) && !(Δε isa ChainRulesCore.ZeroTangent) + new_ε = r.data.ε + T(Δε) + return Mooncake.RData{@NamedTuple{ε::T, poly_deg::Mooncake.NoRData}}((ε = new_ε, poly_deg = Mooncake.NoRData())) + end + return r 
+end + # Import ChainRulesCore rules into Mooncake using @from_rrule # The DefaultCtx is used for standard (non-debug) differentiation @@ -62,6 +97,13 @@ Mooncake.@from_rrule( Tuple{typeof(_eval_op), RadialBasisOperator{<:VectorValuedOperator}, Vector{Float64}} ) +# Operator call syntax: op(x) - bypasses cache check issues +Mooncake.@from_rrule(Mooncake.DefaultCtx, Tuple{RadialBasisOperator, Vector{Float64}}) + +Mooncake.@from_rrule( + Mooncake.DefaultCtx, Tuple{RadialBasisOperator{<:VectorValuedOperator}, Vector{Float64}} +) + # Basis function rules for common types (Float64 vectors) # These enable differentiating through weight computation if needed @@ -77,9 +119,15 @@ Mooncake.@from_rrule(Mooncake.DefaultCtx, Tuple{IMQ, Vector{Float64}, Vector{Flo Mooncake.@from_rrule(Mooncake.DefaultCtx, Tuple{Gaussian, Vector{Float64}, Vector{Float64}}) -# Interpolator rules +# Interpolator evaluation rules Mooncake.@from_rrule(Mooncake.DefaultCtx, Tuple{Interpolator, Vector{Float64}}) +# Interpolator construction rules (for differentiating through construction) +Mooncake.@from_rrule( + Mooncake.DefaultCtx, + Tuple{Type{Interpolator}, AbstractVector, AbstractVector, AbstractRadialBasis} +) + # _build_weights rules for shape optimization # These enable differentiating through operator construction w.r.t. 
point positions @@ -133,4 +181,32 @@ Mooncake.@from_rrule( } ) +# IMQ basis with Partial and Laplacian operators +Mooncake.@from_rrule( + Mooncake.DefaultCtx, + Tuple{typeof(_build_weights), Partial, AbstractVector, AbstractVector, AbstractVector, IMQ} +) + +Mooncake.@from_rrule( + Mooncake.DefaultCtx, + Tuple{ + typeof(_build_weights), Laplacian, AbstractVector, AbstractVector, AbstractVector, IMQ, + } +) + +# Gaussian basis with Partial and Laplacian operators +Mooncake.@from_rrule( + Mooncake.DefaultCtx, + Tuple{ + typeof(_build_weights), Partial, AbstractVector, AbstractVector, AbstractVector, Gaussian, + } +) + +Mooncake.@from_rrule( + Mooncake.DefaultCtx, + Tuple{ + typeof(_build_weights), Laplacian, AbstractVector, AbstractVector, AbstractVector, Gaussian, + } +) + end # module diff --git a/src/RadialBasisFunctions.jl b/src/RadialBasisFunctions.jl index 6b109eb0..cdf758d1 100644 --- a/src/RadialBasisFunctions.jl +++ b/src/RadialBasisFunctions.jl @@ -32,6 +32,13 @@ include("solve/assembly.jl") include("solve/execution.jl") include("solve/api.jl") +# Backward pass support for AD (used by ChainRulesCore and Enzyme extensions) +include("solve/backward_cache.jl") +include("solve/operator_second_derivatives.jl") +include("solve/shape_parameter_derivatives.jl") +include("solve/backward.jl") +include("solve/forward_cache.jl") + include("operators/operators.jl") export RadialBasisOperator, ScalarValuedOperator, VectorValuedOperator export update_weights!, is_cache_valid diff --git a/src/solve/backward.jl b/src/solve/backward.jl new file mode 100644 index 00000000..f10bcfa9 --- /dev/null +++ b/src/solve/backward.jl @@ -0,0 +1,709 @@ +#= +Backward pass functions for _build_weights differentiation rules. + +The backward pass computes: + Given Δw (cotangent of weights), compute Δdata and Δeval_points + +Key steps per stencil: +1. Pad cotangent: Δλ = [Δw; 0] +2. Solve adjoint: η = A⁻ᵀ Δλ +3. Compute: ΔA = -η λᵀ, Δb = η +4. 
Chain through RHS: accumulate to Δeval_point and Δdata[neighbors] +5. Chain through collocation: accumulate to Δdata[neighbors] +=# + +using LinearAlgebra: dot, mul!, axpy! + +""" + backward_linear_solve!(ΔA, Δb, Δw, cache) + +Compute cotangents of collocation matrix A and RHS vector b +from cotangent of weights Δw. + +Given: Aλ = b, w = λ[1:k] +We have: Δλ = [Δw; 0] (padded with zeros for monomial part) + +Using implicit function theorem: + η = A⁻ᵀ Δλ + ΔA = -η λᵀ + Δb = η +""" +function backward_linear_solve!( + ΔA::AbstractMatrix{T}, + Δb::AbstractVecOrMat{T}, + Δw::AbstractVecOrMat{T}, + cache::StencilForwardCache{T}, + ) where {T} + k = cache.k + nmon = cache.nmon + n = k + nmon + num_ops = size(cache.lambda, 2) + + # Pad Δw with zeros for monomial part + Δλ = zeros(T, n, num_ops) + Δλ[1:k, :] .= Δw + + # Solve adjoint system: A'η = Δλ + # The matrix is symmetric, so A' = A + η = cache.A_mat \ Δλ + + # ΔA = -η * λᵀ (outer product, accumulated across operators) + # Use BLAS mul! for O(n²) instead of scalar triple loop + fill!(ΔA, zero(T)) + for op_idx in 1:num_ops + η_vec = view(η, :, op_idx) + λ_vec = view(cache.lambda, :, op_idx) + # Rank-1 update: ΔA -= η_vec * λ_vec' (outer product) + mul!(ΔA, η_vec, λ_vec', -one(T), one(T)) + end + + # Δb = η + Δb .= η + + return nothing +end + +""" + backward_collocation!(Δdata, ΔA, neighbors, data, basis, mon, k) + +Chain rule through collocation matrix construction. + +The collocation matrix has structure: + A[i,j] = φ(xi, xj) for i,j ≤ k (RBF block) + A[i,k+j] = pⱼ(xi) for i ≤ k (polynomial block) + +For RBF block (using ∇φ from existing basis_rules): + Δxi += ΔA[i,j] * ∇φ(xi, xj) + Δxj -= ΔA[i,j] * ∇φ(xi, xj) (by symmetry of φ(x-y)) + +For polynomial block: + Δxi += ΔA[i,k+j] * ∇pⱼ(xi) + +Note: A is symmetric, so we need to handle both triangles. 
+""" +function backward_collocation!( + Δdata::Vector{Vector{T}}, + ΔA::AbstractMatrix{T}, + neighbors::Vector{Int}, + data::AbstractVector, + basis::AbstractRadialBasis, + mon::MonomialBasis{Dim, Deg}, + k::Int, + ) where {T, Dim, Deg} + grad_φ = ∇(basis) + n = k + binomial(Dim + Deg, Deg) + + # RBF block: accumulate gradients from symmetric matrix + # Only upper triangle stored, but gradients flow both ways + @inbounds for j in 1:k + xj = data[neighbors[j]] + Δdata_j = Δdata[neighbors[j]] + for i in 1:(j - 1) # Skip diagonal (i == j) since φ(x,x) = 0 always, no gradient contribution + xi = data[neighbors[i]] + Δdata_i = Δdata[neighbors[i]] + + # Get gradient of basis function + ∇φ_ij = grad_φ(xi, xj) + + # ΔA[i,j] contributes to both Δxi and Δxj + # For symmetric matrix, ΔA[i,j] == ΔA[j,i] conceptually + # We need to sum contributions from both triangles + scale = ΔA[i, j] + ΔA[j, i] + + # φ depends on xi - xj, so: ∂φ/∂xi = ∇φ, ∂φ/∂xj = -∇φ + # In-place accumulation avoids broadcast allocation + for d in eachindex(∇φ_ij) + Δdata_i[d] += scale * ∇φ_ij[d] + Δdata_j[d] -= scale * ∇φ_ij[d] + end + end + end + + # Polynomial block: A[i, k+j] = pⱼ(xi) + # Need gradient of monomial basis w.r.t. xi + if Deg > -1 + nmon = binomial(Dim + Deg, Deg) + ∇p = zeros(T, nmon, Dim) + # Hoist functor construction outside loop + ∇mon = ∇(mon) + + @inbounds for i in 1:k + xi = data[neighbors[i]] + ∇mon(∇p, xi) + Δdata_i = Δdata[neighbors[i]] + + # Accumulate gradient from polynomial block + for j in 1:nmon + # ΔA[i, k+j] contributes to Δxi via ∇pⱼ + # Also ΔA[k+j, i] from transpose block + scale = ΔA[i, k + j] + ΔA[k + j, i] + # Use axpy! pattern for in-place accumulation without broadcast allocation + for d in 1:Dim + Δdata_i[d] += scale * ∇p[j, d] + end + end + end + end + + return nothing +end + +""" + backward_rhs_partial!(Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, dim, k, grad_Lφ_x, grad_Lφ_xi) + +Chain rule through RHS vector construction for Partial operator. 
+ +RHS structure: + b[i] = ℒφ(eval_point, xi) for i = 1:k + b[k+j] = ℒpⱼ(eval_point) for j = 1:nmon + +For RBF section, we need: + ∂/∂eval_point [ℒφ(eval_point, xi)] + ∂/∂xi [ℒφ(eval_point, xi)] + +For polynomial section, we need: + ∂/∂eval_point [ℒpⱼ(eval_point)] + +Note: Unlike Laplacian where ∇²p gives constants, Partial operator ∂p/∂x[dim] +produces terms that depend on eval_point (e.g., ∂(x²)/∂x = 2x, ∂(xy)/∂x = y), +so the polynomial section gradient is NON-ZERO and must be computed. +""" +function backward_rhs_partial!( + Δdata::Vector{Vector{T}}, + Δeval_point::Vector{T}, + Δb::AbstractVecOrMat{T}, + neighbors::Vector{Int}, + eval_point, + data::AbstractVector, + basis::AbstractRadialBasis, + dim::Int, + k::Int, + grad_Lφ_x, + grad_Lφ_xi, + ) where {T} + num_ops = size(Δb, 2) + n = size(Δb, 1) + nmon = n - k + + # RBF section: b[i] = ∂φ/∂x[dim](eval_point, xi) + @inbounds for i in 1:k + xi = data[neighbors[i]] + Δdata_i = Δdata[neighbors[i]] + + # Gradient w.r.t. eval_point and xi + ∇Lφ_x = grad_Lφ_x(eval_point, xi) + ∇Lφ_xi = grad_Lφ_xi(eval_point, xi) + + # Accumulate across operators with in-place scalar ops + for op_idx in 1:num_ops + Δb_val = Δb[i, op_idx] + for d in eachindex(∇Lφ_x) + Δeval_point[d] += Δb_val * ∇Lφ_x[d] + Δdata_i[d] += Δb_val * ∇Lφ_xi[d] + end + end + end + + # Polynomial section: b[k+j] = ∂pⱼ/∂x[dim](eval_point) + # The gradient is ∂²pⱼ/∂x[dim]∂x[d] which is non-zero for some monomials + if nmon > 0 + _backward_partial_polynomial_section!(Δeval_point, Δb, k, nmon, dim, eval_point, num_ops) + end + + return nothing +end + +""" + _backward_partial_polynomial_section!(Δeval_point, Δb, k, nmon, dim, eval_point, num_ops) + +Backward pass through the polynomial section of the RHS for Partial operator. + +For monomials in 2D with poly_deg=2 (1, x, y, xy, x², y²): + ∂/∂x gives: 0, 1, 0, y, 2x, 0 + +The gradients of these w.r.t. 
eval_point are: + ∂(0)/∂(x,y) = (0, 0) + ∂(1)/∂(x,y) = (0, 0) + ∂(0)/∂(x,y) = (0, 0) + ∂(y)/∂(x,y) = (0, 1) -> b[4] contributes to Δeval_point[2] + ∂(2x)/∂(x,y) = (2, 0) -> b[5] contributes 2 to Δeval_point[1] + ∂(0)/∂(x,y) = (0, 0) + +This is equivalent to computing the mixed second derivatives ∂²pⱼ/∂x[dim]∂x[d]. +""" +function _backward_partial_polynomial_section!( + Δeval_point::Vector{T}, + Δb::AbstractVecOrMat{T}, + k::Int, + nmon::Int, + dim::Int, + eval_point, + num_ops::Int, + ) where {T} + D = length(eval_point) + + # The contribution depends on spatial dimension and polynomial degree + # For poly_deg=2, we have known patterns of non-zero second derivatives + if D == 2 + _backward_partial_poly_2d!(Δeval_point, Δb, k, nmon, dim, num_ops) + elseif D == 3 + _backward_partial_poly_3d!(Δeval_point, Δb, k, nmon, dim, num_ops) + elseif D == 1 + _backward_partial_poly_1d!(Δeval_point, Δb, k, nmon, dim, num_ops) + end + # For higher dimensions, would need additional implementations + return nothing +end + +"""Backward pass for polynomial section in 1D.""" +function _backward_partial_poly_1d!( + Δeval_point::Vector{T}, + Δb::AbstractVecOrMat{T}, + k::Int, + nmon::Int, + dim::Int, + num_ops::Int, + ) where {T} + # 1D monomials up to degree 2: 1, x, x² + # ∂/∂x gives: 0, 1, 2x + # Second derivatives ∂²/∂x²: 0, 0, 2 + # Only x² term contributes, at index k+3 (if nmon >= 3) + if nmon >= 3 + @inbounds for op_idx in 1:num_ops + Δeval_point[1] += Δb[k + 3, op_idx] * 2 + end + end + return nothing +end + +"""Backward pass for polynomial section in 2D.""" +function _backward_partial_poly_2d!( + Δeval_point::Vector{T}, + Δb::AbstractVecOrMat{T}, + k::Int, + nmon::Int, + dim::Int, + num_ops::Int, + ) where {T} + # 2D monomials with poly_deg=2: 1, x, y, xy, x², y² (nmon=6) + # 2D monomials with poly_deg=1: 1, x, y (nmon=3) + # 2D monomials with poly_deg=0: 1 (nmon=1) + + if nmon < 4 + # poly_deg <= 1: all second derivatives are zero + return nothing + end + + # For poly_deg=2 
(nmon=6): + # ∂/∂x gives: 0, 1, 0, y, 2x, 0 + # ∂/∂y gives: 0, 0, 1, x, 0, 2y + # + # Second derivatives for ∂/∂x (dim=1): + # ∂(y)/∂y = 1 at index k+4, contributes to Δeval_point[2] + # ∂(2x)/∂x = 2 at index k+5, contributes to Δeval_point[1] + # + # Second derivatives for ∂/∂y (dim=2): + # ∂(x)/∂x = 1 at index k+4, contributes to Δeval_point[1] + # ∂(2y)/∂y = 2 at index k+6, contributes to Δeval_point[2] + + if dim == 1 # ∂/∂x operator + @inbounds for op_idx in 1:num_ops + Δeval_point[2] += Δb[k + 4, op_idx] # from xy term: ∂(y)/∂y = 1 + Δeval_point[1] += Δb[k + 5, op_idx] * 2 # from x² term: ∂(2x)/∂x = 2 + end + elseif dim == 2 # ∂/∂y operator + @inbounds for op_idx in 1:num_ops + Δeval_point[1] += Δb[k + 4, op_idx] # from xy term: ∂(x)/∂x = 1 + Δeval_point[2] += Δb[k + 6, op_idx] * 2 # from y² term: ∂(2y)/∂y = 2 + end + end + + return nothing +end + +"""Backward pass for polynomial section in 3D.""" +function _backward_partial_poly_3d!( + Δeval_point::Vector{T}, + Δb::AbstractVecOrMat{T}, + k::Int, + nmon::Int, + dim::Int, + num_ops::Int, + ) where {T} + # 3D monomials with poly_deg=2: 1, x, y, z, xy, xz, yz, x², y², z² (nmon=10) + # 3D monomials with poly_deg=1: 1, x, y, z (nmon=4) + + if nmon < 5 + # poly_deg <= 1: all second derivatives are zero + return nothing + end + + # For poly_deg=2 (nmon=10): + # Monomial order: 1, x, y, z, xy, xz, yz, x², y², z² + # 1 2 3 4 5 6 7 8 9 10 + # + # ∂/∂x gives: 0, 1, 0, 0, y, z, 0, 2x, 0, 0 + # ∂/∂y gives: 0, 0, 1, 0, x, 0, z, 0, 2y, 0 + # ∂/∂z gives: 0, 0, 0, 1, 0, x, y, 0, 0, 2z + + if dim == 1 # ∂/∂x operator + @inbounds for op_idx in 1:num_ops + Δeval_point[2] += Δb[k + 5, op_idx] # from xy: ∂(y)/∂y = 1 + Δeval_point[3] += Δb[k + 6, op_idx] # from xz: ∂(z)/∂z = 1 + Δeval_point[1] += Δb[k + 8, op_idx] * 2 # from x²: ∂(2x)/∂x = 2 + end + elseif dim == 2 # ∂/∂y operator + @inbounds for op_idx in 1:num_ops + Δeval_point[1] += Δb[k + 5, op_idx] # from xy: ∂(x)/∂x = 1 + Δeval_point[3] += Δb[k + 7, op_idx] # from yz: ∂(z)/∂z 
= 1 + Δeval_point[2] += Δb[k + 9, op_idx] * 2 # from y²: ∂(2y)/∂y = 2 + end + elseif dim == 3 # ∂/∂z operator + @inbounds for op_idx in 1:num_ops + Δeval_point[1] += Δb[k + 6, op_idx] # from xz: ∂(x)/∂x = 1 + Δeval_point[2] += Δb[k + 7, op_idx] # from yz: ∂(y)/∂y = 1 + Δeval_point[3] += Δb[k + 10, op_idx] * 2 # from z²: ∂(2z)/∂z = 2 + end + end + + return nothing +end + +""" + backward_rhs_laplacian!(Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, k, grad_Lφ_x, grad_Lφ_xi) + +Chain rule through RHS for Laplacian operator. +""" +function backward_rhs_laplacian!( + Δdata::Vector{Vector{T}}, + Δeval_point::Vector{T}, + Δb::AbstractVecOrMat{T}, + neighbors::Vector{Int}, + eval_point, + data::AbstractVector, + basis::AbstractRadialBasis, + k::Int, + grad_Lφ_x, + grad_Lφ_xi, + ) where {T} + num_ops = size(Δb, 2) + + # RBF section: b[i] = ∇²φ(eval_point, xi) + @inbounds for i in 1:k + xi = data[neighbors[i]] + Δdata_i = Δdata[neighbors[i]] + + # Gradient w.r.t. eval_point and xi + ∇Lφ_x = grad_Lφ_x(eval_point, xi) + ∇Lφ_xi = grad_Lφ_xi(eval_point, xi) + + # Accumulate across operators with in-place scalar ops + for op_idx in 1:num_ops + Δb_val = Δb[i, op_idx] + for d in eachindex(∇Lφ_x) + Δeval_point[d] += Δb_val * ∇Lφ_x[d] + Δdata_i[d] += Δb_val * ∇Lφ_xi[d] + end + end + end + + return nothing +end + +""" + backward_stencil!(Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi, backward_rhs!) + +Generic backward pass for a single stencil, parameterized by RHS backward function. + +Combines: +1. backward_linear_solve! - compute ΔA, Δb from Δw +2. backward_collocation! - chain ΔA to Δdata +3. backward_rhs! 
- chain Δb to Δdata and Δeval_point (operator-specific) +""" +function backward_stencil!( + Δdata::Vector{Vector{T}}, + Δeval_point::Vector{T}, + Δw::AbstractVecOrMat{T}, + cache::StencilForwardCache{T}, + neighbors::Vector{Int}, + eval_point, + data::AbstractVector, + basis::AbstractRadialBasis, + mon::MonomialBasis{Dim, Deg}, + k::Int, + grad_Lφ_x, + grad_Lφ_xi, + backward_rhs!::F, + ) where {T, Dim, Deg, F} + n = k + cache.nmon + + # Allocate workspace for ΔA and Δb + ΔA = zeros(T, n, n) + Δb = zeros(T, n, size(Δw, 2)) + + # Step 1: Backprop through linear solve + backward_linear_solve!(ΔA, Δb, Δw, cache) + + # Step 2: Backprop through collocation matrix + backward_collocation!(Δdata, ΔA, neighbors, data, basis, mon, k) + + # Step 3: Backprop through RHS (operator-specific) + backward_rhs!(Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, k, grad_Lφ_x, grad_Lφ_xi) + + return nothing +end + +""" + backward_stencil_partial!(Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, mon, k, dim, grad_Lφ_x, grad_Lφ_xi) + +Complete backward pass for a single stencil with Partial operator. +Dispatches to generic backward_stencil! with partial-specific RHS backward. 
"""
    backward_stencil_partial!(Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, mon, k, dim, grad_Lφ_x, grad_Lφ_xi)

Full backward pass for one stencil under a Partial operator.

Forwards to the generic `backward_stencil!`, wrapping the partial-derivative
RHS pullback in a closure that captures the differentiation dimension `dim`.
"""
function backward_stencil_partial!(
        Δdata::Vector{Vector{T}},
        Δeval_point::Vector{T},
        Δw::AbstractVecOrMat{T},
        cache::StencilForwardCache{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        mon::MonomialBasis{Dim, Deg},
        k::Int,
        dim::Int,
        grad_Lφ_x,
        grad_Lφ_xi,
    ) where {T, Dim, Deg}
    # Adapt the dim-aware partial RHS pullback to the generic signature.
    rhs_pullback! = (Δd, Δe, Δb, nbrs, ep, d, b, kk, gx, gxi) ->
        backward_rhs_partial!(Δd, Δe, Δb, nbrs, ep, d, b, dim, kk, gx, gxi)
    return backward_stencil!(
        Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, mon, k,
        grad_Lφ_x, grad_Lφ_xi, rhs_pullback!
    )
end

"""
    backward_stencil_laplacian!(Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi)

Full backward pass for one stencil under a Laplacian operator.

Forwards to the generic `backward_stencil!` with the Laplacian RHS pullback.
"""
function backward_stencil_laplacian!(
        Δdata::Vector{Vector{T}},
        Δeval_point::Vector{T},
        Δw::AbstractVecOrMat{T},
        cache::StencilForwardCache{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        mon::MonomialBasis{Dim, Deg},
        k::Int,
        grad_Lφ_x,
        grad_Lφ_xi,
    ) where {T, Dim, Deg}
    return backward_stencil!(
        Δdata, Δeval_point, Δw, cache, neighbors, eval_point, data, basis, mon, k,
        grad_Lφ_x, grad_Lφ_xi, backward_rhs_laplacian!
    )
end

# =============================================================================
# Shape parameter (ε) gradient computation
# =============================================================================

"""
    backward_collocation_ε!(Δε_acc, ΔA, neighbors, data, basis, k)

Accumulate the shape-parameter gradient contribution of the collocation matrix.

Implicit differentiation: `Δε += Σᵢⱼ ΔA[i,j] * ∂A[i,j]/∂ε` over the RBF block,
where `A[i,j] = φ(xi, xj)`.
"""
function backward_collocation_ε!(
        Δε_acc::Base.RefValue{T},
        ΔA::AbstractMatrix{T},
        neighbors::Vector{Int},
        data::AbstractVector,
        basis::AbstractRadialBasis,
        k::Int,
    ) where {T}
    # The RBF block is symmetric, so visit only i < j and fold in both
    # cotangent entries (ΔA[i,j] and ΔA[j,i]) at once.
    @inbounds for col in 1:k
        x_col = data[neighbors[col]]
        for row in 1:(col - 1)
            x_row = data[neighbors[row]]
            sens = ∂φ_∂ε(basis, x_row, x_col)
            Δε_acc[] += (ΔA[row, col] + ΔA[col, row]) * sens
        end
    end
    return nothing
end

"""
    backward_rhs_laplacian_ε!(Δε_acc, Δb, neighbors, eval_point, data, basis, k)

Accumulate the shape-parameter gradient contribution of the Laplacian RHS:
`Δε += Σᵢ Δb[i] * ∂(∇²φ)/∂ε`.
"""
function backward_rhs_laplacian_ε!(
        Δε_acc::Base.RefValue{T},
        Δb::AbstractVecOrMat{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        k::Int,
    ) where {T}
    n_ops = size(Δb, 2)
    @inbounds for row in 1:k
        x_row = data[neighbors[row]]
        sens = ∂Laplacian_φ_∂ε(basis, eval_point, x_row)
        for op in 1:n_ops
            Δε_acc[] += Δb[row, op] * sens
        end
    end
    return nothing
end

"""
    backward_rhs_partial_ε!(Δε_acc, Δb, neighbors, eval_point, data, basis, dim, k)

Accumulate the shape-parameter gradient contribution of the Partial RHS:
`Δε += Σᵢ Δb[i] * ∂(∂φ/∂x_dim)/∂ε`.
"""
function backward_rhs_partial_ε!(
        Δε_acc::Base.RefValue{T},
        Δb::AbstractVecOrMat{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        dim::Int,
        k::Int,
    ) where {T}
    n_ops = size(Δb, 2)
    @inbounds for row in 1:k
        x_row = data[neighbors[row]]
        sens = ∂Partial_φ_∂ε(basis, dim, eval_point, x_row)
        for op in 1:n_ops
            Δε_acc[] += Δb[row, op] * sens
        end
    end
    return nothing
end

"""
    backward_stencil_with_ε!(Δdata, Δeval_point, Δε_acc, Δw, cache, neighbors, eval_point, data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi, backward_rhs!, backward_rhs_ε!)

Generic single-stencil backward pass that also accumulates the shape-parameter
gradient. The operator-specific behavior is injected via the two RHS pullbacks
(`backward_rhs!` for point gradients, `backward_rhs_ε!` for the ε gradient).
"""
function backward_stencil_with_ε!(
        Δdata::Vector{Vector{T}},
        Δeval_point::Vector{T},
        Δε_acc::Base.RefValue{T},
        Δw::AbstractVecOrMat{T},
        cache::StencilForwardCache{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        mon::MonomialBasis{Dim, Deg},
        k::Int,
        grad_Lφ_x,
        grad_Lφ_xi,
        backward_rhs!::F1,
        backward_rhs_ε!::F2,
    ) where {T, Dim, Deg, F1, F2}
    # Workspace sized for the augmented system (RBF block + monomial tail).
    n_total = k + cache.nmon
    ΔA = zeros(T, n_total, n_total)
    Δb = zeros(T, n_total, size(Δw, 2))

    # Pullback through the linear solve to obtain ΔA and Δb.
    backward_linear_solve!(ΔA, Δb, Δw, cache)

    # Pullback through the collocation matrix: point gradients, then ε.
    backward_collocation!(Δdata, ΔA, neighbors, data, basis, mon, k)
    backward_collocation_ε!(Δε_acc, ΔA, neighbors, data, basis, k)

    # Pullback through the RHS: point gradients, then ε (operator-specific).
    backward_rhs!(Δdata, Δeval_point, Δb, neighbors, eval_point, data, basis, k, grad_Lφ_x, grad_Lφ_xi)
    backward_rhs_ε!(Δε_acc, Δb, neighbors, eval_point, data, basis, k)

    return nothing
end

"""
    backward_stencil_laplacian_with_ε!(Δdata, Δeval_point, Δε_acc, Δw, cache, ...)

Single-stencil backward pass for the Laplacian operator including the
shape-parameter gradient; delegates to `backward_stencil_with_ε!`.
"""
function backward_stencil_laplacian_with_ε!(
        Δdata::Vector{Vector{T}},
        Δeval_point::Vector{T},
        Δε_acc::Base.RefValue{T},
        Δw::AbstractVecOrMat{T},
        cache::StencilForwardCache{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        mon::MonomialBasis{Dim, Deg},
        k::Int,
        grad_Lφ_x,
        grad_Lφ_xi,
    ) where {T, Dim, Deg}
    return backward_stencil_with_ε!(
        Δdata, Δeval_point, Δε_acc, Δw, cache, neighbors, eval_point, data, basis, mon, k,
        grad_Lφ_x, grad_Lφ_xi,
        backward_rhs_laplacian!,
        backward_rhs_laplacian_ε!
    )
end

"""
    backward_stencil_partial_with_ε!(Δdata, Δeval_point, Δε_acc, Δw, cache, ...)

Single-stencil backward pass for the Partial operator including the
shape-parameter gradient; delegates to `backward_stencil_with_ε!` with
closures that capture the differentiation dimension `dim`.
"""
function backward_stencil_partial_with_ε!(
        Δdata::Vector{Vector{T}},
        Δeval_point::Vector{T},
        Δε_acc::Base.RefValue{T},
        Δw::AbstractVecOrMat{T},
        cache::StencilForwardCache{T},
        neighbors::Vector{Int},
        eval_point,
        data::AbstractVector,
        basis::AbstractRadialBasis,
        mon::MonomialBasis{Dim, Deg},
        k::Int,
        dim::Int,
        grad_Lφ_x,
        grad_Lφ_xi,
    ) where {T, Dim, Deg}
    rhs_pullback! = (Δd, Δe, Δb, nbrs, ep, d, b, kk, gx, gxi) ->
        backward_rhs_partial!(Δd, Δe, Δb, nbrs, ep, d, b, dim, kk, gx, gxi)
    rhs_ε_pullback! = (Δε, Δb, nbrs, ep, d, b, kk) ->
        backward_rhs_partial_ε!(Δε, Δb, nbrs, ep, d, b, dim, kk)
    return backward_stencil_with_ε!(
        Δdata, Δeval_point, Δε_acc, Δw, cache, neighbors, eval_point, data, basis, mon, k,
        grad_Lφ_x, grad_Lφ_xi, rhs_pullback!, rhs_ε_pullback!
    )
end
#=
Forward pass with caching for backward pass of _build_weights differentiation rules.

This builds weights while storing intermediate results needed for the pullback.
=#

using LinearAlgebra: Symmetric
using SparseArrays: sparse

"""
    _forward_with_cache(data, eval_points, adjl, basis, ℒrbf, ℒmon, mon, ℒType)

Build the sparse stencil-weight matrix while recording, per stencil, the
solution vector and the densified symmetric collocation matrix needed by the
backward pass.

Returns `(W, cache)` where `W` is the sparse weight matrix and `cache` is a
`WeightsBuildForwardCache` of per-stencil intermediates.
"""
function _forward_with_cache(
        data::AbstractVector,
        eval_points::AbstractVector,
        adjl::AbstractVector,
        basis::AbstractRadialBasis,
        ℒrbf,
        ℒmon,
        mon::MonomialBasis{Dim, Deg},
        ::Type{ℒType},
    ) where {Dim, Deg, ℒType}
    TD = eltype(first(data))
    k = length(first(adjl))
    nmon = Deg >= 0 ? binomial(Dim + Deg, Deg) : 0
    n = k + nmon
    n_eval = length(eval_points)
    n_data = length(data)
    num_ops = 1  # scalar operators ⇒ single RHS column

    # COO storage: exactly k nonzeros per evaluation point.
    total_nnz = k * n_eval
    rows = Vector{Int}(undef, total_nnz)
    cols = Vector{Int}(undef, total_nnz)
    vals = Vector{TD}(undef, total_nnz)

    stencil_caches = Vector{StencilForwardCache{TD, Matrix{TD}}}(undef, n_eval)

    cursor = 1
    for eval_idx in 1:n_eval
        neighbors = adjl[eval_idx]
        eval_point = eval_points[eval_idx]
        local_data = [data[i] for i in neighbors]

        # Collocation matrix is filled through a Symmetric (upper) view.
        A_full = zeros(TD, n, n)
        _build_collocation_matrix!(Symmetric(A_full, :U), local_data, basis, mon, k)

        # Operator RHS for this stencil (single column).
        b = zeros(TD, n, num_ops)
        _build_rhs!(view(b, :, 1), ℒrbf, ℒmon, local_data, eval_point, basis, mon, k)

        # Symmetric but indefinite system (zero polynomial block), so use a
        # generic `\` solve rather than Cholesky.
        λ = Symmetric(A_full, :U) \ b

        # The first k solution entries are the stencil weights; scatter them
        # into the COO triplets.
        for (local_idx, global_idx) in enumerate(neighbors)
            rows[cursor] = eval_idx
            cols[cursor] = global_idx
            vals[cursor] = λ[local_idx, 1]
            cursor += 1
        end

        # Densify: mirror the upper triangle into the lower half so the
        # backward pass gets a plain full symmetric matrix.
        A_dense = copy(A_full)
        for j in 1:n, i in (j + 1):n
            A_dense[i, j] = A_full[j, i]
        end
        stencil_caches[eval_idx] = StencilForwardCache(copy(λ), A_dense, k, nmon)
    end

    W = sparse(rows, cols, vals, n_eval, n_data)
    cache = WeightsBuildForwardCache(stencil_caches, k, nmon, num_ops)
    return W, cache
end
using Distances: euclidean

# ============================================================================
# Dispatch functions to get correct second derivative based on operator/basis
#
# grad_applied_partial_wrt_x(basis, dim) / grad_applied_laplacian_wrt_x(basis)
# return closures computing ∂/∂x[j] of the applied operator for that basis;
# the *_wrt_xi variants differentiate w.r.t. the stencil point instead.
# ============================================================================

grad_applied_laplacian_wrt_xi(::PHS7) = grad_laplacian_phs7_wrt_xi()

# ============================================================================
# IMQ: φ(r) = 1/√(1 + (εr)²)
# With s = ε²r² + 1: φ = s^(-1/2), ∂φ/∂x[dim] = -ε² δ_d / s^(3/2)
# ============================================================================

"""
    grad_partial_imq_wrt_x(ε, dim)

Return a closure computing ∂/∂x[j] of `∂φ/∂x[dim]` for the IMQ basis.

With s = ε²r² + 1 and δ = x - xi:

    ∂²φ/∂x[j]∂x[dim] = -ε² δ_{j,dim}/s^(3/2) + 3ε⁴ δ_d δ_j / s^(5/2)
"""
function grad_partial_imq_wrt_x(ε::T, dim::Int) where {T}
    e2 = ε^2
    e4 = ε^4
    return function (x, xi)
        r2 = sqeuclidean(x, xi)
        s = e2 * r2 + 1
        s32 = sqrt(s^3)
        s52 = sqrt(s^5)
        δ = x .- xi
        δd = δ[dim]

        out = similar(x, eltype(x))
        @inbounds for j in eachindex(x)
            # Off-diagonal term is always present; the diagonal picks up the
            # extra -ε²/s^(3/2) from the Kronecker delta.
            out[j] = if j == dim
                3 * e4 * δd^2 / s52 + -e2 / s32
            else
                3 * e4 * δd * δ[j] / s52
            end
        end
        return out
    end
end

"""
    grad_partial_imq_wrt_xi(ε, dim)

Return a closure computing ∂/∂xi[j] of `∂φ/∂x[dim]` for IMQ.
Since everything depends on x - xi, this is the negation of the x-gradient.
"""
function grad_partial_imq_wrt_xi(ε::T, dim::Int) where {T}
    inner = grad_partial_imq_wrt_x(ε, dim)
    return (x, xi) -> -inner(x, xi)
end

# ============================================================================
# IMQ: Laplacian ∇²φ = 3ε⁴r²/s^(5/2) - D ε²/s^(3/2), s = ε²r² + 1, D = dim
# ============================================================================

"""
    grad_laplacian_imq_wrt_x(ε)

Return a closure computing ∂/∂x[j] of `∇²φ` for IMQ.

Differentiating ∇²φ = 3ε⁴r²/s^(5/2) - D ε²/s^(3/2) gives

    ∂(∇²φ)/∂x[j] = δ_j [3(D+2)ε⁴/s^(5/2) - 15ε⁶r²/s^(7/2)]
"""
function grad_laplacian_imq_wrt_x(ε::T) where {T}
    e2 = ε^2
    e4 = ε^4
    e6 = ε^6
    return function (x, xi)
        D = length(x)
        r2 = sqeuclidean(x, xi)
        s = e2 * r2 + 1
        scale = 3 * (D + 2) * e4 / sqrt(s^5) - 15 * e6 * r2 / sqrt(s^7)
        return scale .* (x .- xi)
    end
end

"""
    grad_laplacian_imq_wrt_xi(ε)

Return a closure computing ∂/∂xi[j] of `∇²φ` for IMQ (negated x-gradient).
"""
function grad_laplacian_imq_wrt_xi(ε::T) where {T}
    inner = grad_laplacian_imq_wrt_x(ε)
    return (x, xi) -> -inner(x, xi)
end

# ============================================================================
# Gaussian: φ(r) = exp(-(εr)²), ∂φ/∂x[dim] = -2ε² δ_d φ
# ============================================================================

"""
    grad_partial_gaussian_wrt_x(ε, dim)

Return a closure computing ∂/∂x[j] of `∂φ/∂x[dim]` for the Gaussian basis.

    ∂²φ/∂x[j]∂x[dim] = φ [-2ε² δ_{j,dim} + 4ε⁴ δ_d δ_j]
"""
function grad_partial_gaussian_wrt_x(ε::T, dim::Int) where {T}
    e2 = ε^2
    e4 = ε^4
    return function (x, xi)
        r2 = sqeuclidean(x, xi)
        φ = exp(-e2 * r2)
        δ = x .- xi
        δd = δ[dim]

        out = similar(x, eltype(x))
        @inbounds for j in eachindex(x)
            out[j] = if j == dim
                φ * (4 * e4 * δd^2 - 2 * e2)
            else
                φ * 4 * e4 * δd * δ[j]
            end
        end
        return out
    end
end

"""
    grad_partial_gaussian_wrt_xi(ε, dim)

Return a closure computing ∂/∂xi[j] of `∂φ/∂x[dim]` for the Gaussian basis
(negated x-gradient, by dependence on x - xi only).
"""
function grad_partial_gaussian_wrt_xi(ε::T, dim::Int) where {T}
    inner = grad_partial_gaussian_wrt_x(ε, dim)
    return (x, xi) -> -inner(x, xi)
end

# ============================================================================
# Gaussian: Laplacian ∇²φ = (4ε⁴r² - 2ε²D) φ, D = dimension
# ============================================================================

"""
    grad_laplacian_gaussian_wrt_x(ε)

Return a closure computing ∂/∂x[j] of `∇²φ` for the Gaussian basis.

Product rule on (4ε⁴r² - 2ε²D) φ gives

    ∂(∇²φ)/∂x[j] = φ δ_j · 4ε⁴ (2 + D - 2ε²r²)
"""
function grad_laplacian_gaussian_wrt_x(ε::T) where {T}
    e2 = ε^2
    e4 = ε^4
    return function (x, xi)
        D = length(x)
        r2 = sqeuclidean(x, xi)
        φ = exp(-e2 * r2)
        scale = 4 * e4 * (2 + D - 2 * e2 * r2)
        return φ * scale .* (x .- xi)
    end
end

"""
    grad_laplacian_gaussian_wrt_xi(ε)

Return a closure computing ∂/∂xi[j] of `∇²φ` for the Gaussian basis
(negated x-gradient).
"""
function grad_laplacian_gaussian_wrt_xi(ε::T) where {T}
    inner = grad_laplacian_gaussian_wrt_x(ε)
    return (x, xi) -> -inner(x, xi)
end

# ============================================================================
# Dispatch functions for IMQ and Gaussian
# ============================================================================

grad_applied_partial_wrt_x(basis::IMQ, dim::Int) = grad_partial_imq_wrt_x(basis.ε, dim)
grad_applied_partial_wrt_xi(basis::IMQ, dim::Int) = grad_partial_imq_wrt_xi(basis.ε, dim)
grad_applied_laplacian_wrt_x(basis::IMQ) = grad_laplacian_imq_wrt_x(basis.ε)
grad_applied_laplacian_wrt_xi(basis::IMQ) = grad_laplacian_imq_wrt_xi(basis.ε)

grad_applied_partial_wrt_x(basis::Gaussian, dim::Int) = grad_partial_gaussian_wrt_x(basis.ε, dim)
grad_applied_partial_wrt_xi(basis::Gaussian, dim::Int) = grad_partial_gaussian_wrt_xi(basis.ε, dim)
grad_applied_laplacian_wrt_x(basis::Gaussian) = grad_laplacian_gaussian_wrt_x(basis.ε)
grad_applied_laplacian_wrt_xi(basis::Gaussian) = grad_laplacian_gaussian_wrt_xi(basis.ε)

#=
Derivatives of basis functions and applied operators with respect to shape parameter ε.

These functions are used in the backward pass of _build_weights to compute ∂W/∂ε.

For a basis function φ(r; ε) where r = |x - xi|:
- ∂φ/∂ε: derivative of basis function w.r.t. shape parameter
- ∂(ℒφ)/∂ε: derivative of applied operator w.r.t. shape parameter

Currently supported: Gaussian and IMQ (PHS has no shape parameter ⇒ zero).
=#

# =============================================================================
# Gaussian basis: φ(r) = exp(-ε²r²)
# =============================================================================

"""
    ∂φ_∂ε(basis::Gaussian, x, xi)

ε-derivative of the Gaussian basis: ∂φ/∂ε = -2εr² exp(-ε²r²).
"""
function ∂φ_∂ε(basis::Gaussian, x, xi)
    ε = basis.ε
    r2 = sqeuclidean(x, xi)
    return -2 * ε * r2 * exp(-ε^2 * r2)
end

"""
    ∂Laplacian_φ_∂ε(basis::Gaussian, x, xi)

ε-derivative of the Gaussian Laplacian. With D = dimension:

    ∇²φ = (-2ε²D + 4ε⁴r²) exp(-ε²r²)
    ∂(∇²φ)/∂ε = exp(-ε²r²) [-4εD + 16ε³r² + 4ε³r²D - 8ε⁵r⁴]
"""
function ∂Laplacian_φ_∂ε(basis::Gaussian, x, xi)
    ε = basis.ε
    e2 = ε^2
    e3 = ε^3
    e5 = ε^5
    r2 = sqeuclidean(x, xi)
    D = length(x)
    φ = exp(-e2 * r2)
    return φ * (-4 * ε * D + 16 * e3 * r2 + 4 * e3 * r2 * D - 8 * e5 * r2^2)
end

"""
    ∂Partial_φ_∂ε(basis::Gaussian, dim::Int, x, xi)

ε-derivative of the Gaussian first partial:

    ∂φ/∂x_dim = -2ε²(x_dim - xi_dim) exp(-ε²r²)
    ∂/∂ε[∂φ/∂x_dim] = 4ε(x_dim - xi_dim)(ε²r² - 1) exp(-ε²r²)
"""
function ∂Partial_φ_∂ε(basis::Gaussian, dim::Int, x, xi)
    ε = basis.ε
    e2 = ε^2
    r2 = sqeuclidean(x, xi)
    Δd = x[dim] - xi[dim]
    φ = exp(-e2 * r2)
    return 4 * ε * Δd * (e2 * r2 - 1) * φ
end

# =============================================================================
# IMQ basis: φ(r) = 1/√(ε²r² + 1)
# =============================================================================

"""
    ∂φ_∂ε(basis::IMQ, x, xi)

ε-derivative of the IMQ basis: with s = ε²r² + 1,

    ∂φ/∂ε = -εr² s^(-3/2)
"""
function ∂φ_∂ε(basis::IMQ, x, xi)
    ε = basis.ε
    r2 = sqeuclidean(x, xi)
    s = ε^2 * r2 + 1
    return -ε * r2 / sqrt(s^3)
end

"""
    ∂Laplacian_φ_∂ε(basis::IMQ, x, xi)

ε-derivative of the IMQ Laplacian. With s = ε²r² + 1 and D = dimension,

    ∇²φ = -ε²D s^(-3/2) + 3ε⁴r² s^(-5/2)

differentiated term-by-term in ε below.
"""
function ∂Laplacian_φ_∂ε(basis::IMQ, x, xi)
    ε = basis.ε
    e2 = ε^2
    e3 = ε^3
    e5 = ε^5
    r2 = sqeuclidean(x, xi)
    r4 = r2^2
    D = length(x)
    s = e2 * r2 + 1

    # d/dε[-ε²D s^(-3/2)] = -2εD s^(-3/2) + 3ε³D r² s^(-5/2)
    part_a = -2 * ε * D / sqrt(s^3) + 3 * e3 * D * r2 / sqrt(s^5)
    # d/dε[3ε⁴r² s^(-5/2)] = 12ε³r² s^(-5/2) - 15ε⁵r⁴ s^(-7/2)
    part_b = 12 * e3 * r2 / sqrt(s^5) - 15 * e5 * r4 / sqrt(s^7)

    return part_a + part_b
end

"""
    ∂Partial_φ_∂ε(basis::IMQ, dim::Int, x, xi)

ε-derivative of the IMQ first partial. With s = ε²r² + 1:

    ∂φ/∂x_dim = ε²(xi_dim - x_dim) s^(-3/2)
    ∂/∂ε[∂φ/∂x_dim] = (xi_dim - x_dim)[2ε s^(-3/2) - 3ε³r² s^(-5/2)]
"""
function ∂Partial_φ_∂ε(basis::IMQ, dim::Int, x, xi)
    ε = basis.ε
    e3 = ε^3
    r2 = sqeuclidean(x, xi)
    Δd = xi[dim] - x[dim]  # IMQ partial carries the opposite sign convention
    s = ε^2 * r2 + 1
    return Δd * (2 * ε / sqrt(s^3) - 3 * e3 * r2 / sqrt(s^5))
end

# =============================================================================
# PHS basis: no shape parameter, so all ε-gradients vanish
# =============================================================================

∂φ_∂ε(::AbstractPHS, x, xi) = zero(eltype(x))
∂Laplacian_φ_∂ε(::AbstractPHS, x, xi) = zero(eltype(x))
∂Partial_φ_∂ε(::AbstractPHS, ::Int, x, xi) = zero(eltype(x))
new file mode 100644 index 00000000..4ecd1626 --- /dev/null +++ b/test/extensions/autodiff_di.jl @@ -0,0 +1,588 @@ +using RadialBasisFunctions +using StaticArraysCore +using FiniteDifferences +using LinearAlgebra +using Test +import DifferentiationInterface as DI +using Enzyme: Enzyme +using Mooncake: Mooncake +using Random: MersenneTwister + +const FD = FiniteDifferences + +# Version compatibility - Enzyme.jl has known issues with Julia 1.12+ +# See: https://github.com/EnzymeAD/Enzyme.jl/issues/2699 +const ENZYME_SUPPORTED = VERSION < v"1.12" + +# Backend configuration +const ENZYME_BACKEND = DI.AutoEnzyme(; function_annotation = Enzyme.Const) +const MOONCAKE_BACKEND = DI.AutoMooncake(; config = nothing) + +# Build backend registry (only include supported backends) +const AD_BACKENDS = Pair{String, Any}[] +ENZYME_SUPPORTED && push!(AD_BACKENDS, "Enzyme" => ENZYME_BACKEND) +push!(AD_BACKENDS, "Mooncake" => MOONCAKE_BACKEND) + +""" + test_gradient_vs_fd(f, x, backend; rtol=1e-4, name="") + +Test that DI.gradient matches finite differences for function f at point x. 
+""" +function test_gradient_vs_fd(f, x, backend; rtol = 1.0e-4, name = "") + di_grad = DI.gradient(f, backend, x) + fd_grad = FD.grad(FD.central_fdm(5, 1), f, x)[1] + @test !all(iszero, di_grad) + return @test isapprox(di_grad, fd_grad; rtol = rtol) +end + +@testset "Autodiff via DifferentiationInterface" begin + @testset "Operator Differentiation" begin + N = 50 + points = [ + SVector{2}(0.1 + 0.8 * i / N, 0.1 + 0.8 * j / N) for i in 1:isqrt(N) + for j in 1:isqrt(N) + ] + N = length(points) + values = sin.(getindex.(points, 1)) .+ cos.(getindex.(points, 2)) + + @testset "Laplacian Operator" begin + lap = laplacian(points) + loss_lap(v) = sum(lap(v) .^ 2) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_lap, values, backend; rtol = 1.0e-4) + end + end + end + + @testset "Gradient Operator" begin + grad_op = RadialBasisFunctions.gradient(points) + loss_grad(v) = sum(grad_op(v) .^ 2) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + # Vector-valued operators have known issues on Julia 1.12+ with Enzyme + if name == "Enzyme" && VERSION >= v"1.12" + @test_skip "Vector-valued operators have known issues on Julia 1.12+" + else + test_gradient_vs_fd(loss_grad, values, backend; rtol = 1.0e-4) + end + end + end + end + + @testset "Partial Derivative Operator" begin + partial_x = partial(points, 1, 1) + loss_partial(v) = sum(partial_x(v) .^ 2) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_partial, values, backend; rtol = 1.0e-4) + end + end + end + end + + @testset "Interpolator Differentiation" begin + rng = MersenneTwister(789) + N = 30 + points = [SVector{2}(rand(rng), rand(rng)) for _ in 1:N] + values = sin.(getindex.(points, 1)) + eval_points = [SVector{2}(rand(rng), rand(rng)) for _ in 1:10] + + @testset "Construction w.r.t. 
values (PHS default)" begin + function loss_interp(v) + interp_local = Interpolator(points, v) + return sum(interp_local(eval_points) .^ 2) + end + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + if name == "Enzyme" + @test_skip "Enzyme cannot differentiate through Interpolator constructor (factorize Union types)" + else + test_gradient_vs_fd(loss_interp, values, backend; rtol = 1.0e-3) + end + end + end + end + + @testset "Construction with IMQ basis" begin + function loss_interp_imq(v) + interp_local = Interpolator(points, v, IMQ(1.0)) + return sum(interp_local(eval_points) .^ 2) + end + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + if name == "Enzyme" + @test_skip "Enzyme cannot differentiate through Interpolator constructor (factorize Union types)" + else + test_gradient_vs_fd(loss_interp_imq, values, backend; rtol = 1.0e-3) + end + end + end + end + + @testset "Construction with Gaussian basis" begin + function loss_interp_gauss(v) + interp_local = Interpolator(points, v, Gaussian(1.0)) + return sum(interp_local(eval_points) .^ 2) + end + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + if name == "Enzyme" + @test_skip "Enzyme cannot differentiate through Interpolator constructor (factorize Union types)" + else + test_gradient_vs_fd(loss_interp_gauss, values, backend; rtol = 1.0e-3) + end + end + end + end + + @testset "Single point evaluation" begin + single_eval_point = SVector{2}(0.5, 0.5) + + function loss_interp_single(v) + interp_local = Interpolator(points, v) + return interp_local(single_eval_point)^2 + end + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + if name == "Enzyme" + @test_skip "Enzyme cannot differentiate through Interpolator constructor (factorize Union types)" + else + test_gradient_vs_fd(loss_interp_single, values, backend; rtol = 1.0e-3) + end + end + end + end + end + + @testset "Basis Function Differentiation" begin + x = [0.5, 0.5] + xi = [0.3, 0.4] + + @testset "PHS 
Basis Functions" begin + for (order, phs_type) in [(1, PHS(1)), (3, PHS(3)), (5, PHS(5)), (7, PHS(7))] + loss_phs(xv) = phs_type(xv, xi)^2 + + for (name, backend) in AD_BACKENDS + @testset "PHS($order) - $name" begin + test_gradient_vs_fd(loss_phs, x, backend; rtol = 1.0e-4) + end + end + end + end + + @testset "IMQ Basis Function" begin + imq = IMQ(1.0) + loss_imq(xv) = imq(xv, xi)^2 + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_imq, x, backend; rtol = 1.0e-4) + end + end + end + + @testset "Gaussian Basis Function" begin + gauss = Gaussian(1.0) + loss_gauss(xv) = gauss(xv, xi)^2 + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_gauss, x, backend; rtol = 1.0e-4) + end + end + end + end + + @testset "_build_weights Differentiation" begin + N = 25 + points = [SVector{2}(0.1 + 0.8 * i / 5, 0.1 + 0.8 * j / 5) for i in 1:5 for j in 1:5] + adjl = RadialBasisFunctions.find_neighbors(points, 10) + + @testset "Partial operator with PHS3" begin + basis = PHS(3; poly_deg = 2) + ℒ = Partial(1, 1) + + function loss_partial_weights(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_partial_weights, pts_flat, backend; rtol = 1.0e-3) + end + end + end + + @testset "Laplacian operator with PHS3" begin + basis = PHS(3; poly_deg = 2) + ℒ = Laplacian() + + function loss_laplacian_weights(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) 
+ + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_laplacian_weights, pts_flat, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "Different PHS orders for Laplacian" begin + for n in [1, 3, 5, 7] + basis = PHS(n; poly_deg = 1) + ℒ = Laplacian() + + function loss_laplacian_phs_order(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "PHS($n) - $name" begin + di_grad = DI.gradient(loss_laplacian_phs_order, backend, pts_flat) + fd_grad = FD.grad(FD.central_fdm(5, 1), loss_laplacian_phs_order, pts_flat)[1] + # PHS1 may have zero gradient for some configurations + @test !all(iszero, di_grad) || n == 1 + @test isapprox(di_grad, fd_grad; rtol = 1.0e-2) + end + end + end + end + + @testset "1D Partial operator with PHS3" begin + N_1d = 10 + points_1d = [SVector{1}(0.1 + 0.8 * i / N_1d) for i in 1:N_1d] + adjl_1d = RadialBasisFunctions.find_neighbors(points_1d, 5) + basis_1d = PHS(3; poly_deg = 2) + ℒ_1d = Partial(1, 1) + + function loss_partial_weights_1d(pts) + pts_vec = [SVector{1}(pts[i]) for i in 1:N_1d] + W = RadialBasisFunctions._build_weights( + ℒ_1d, pts_vec, pts_vec, adjl_1d, basis_1d + ) + return sum(W.nzval .^ 2) + end + + pts_flat_1d = vcat([collect(p) for p in points_1d]...) 
+ + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_partial_weights_1d, pts_flat_1d, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "3D Partial operator with PHS3" begin + # Use Halton-like sequence to avoid singular stencils on regular grids + N_3d = 64 + points_3d = [ + SVector{3}( + 0.1 + 0.8 * ((i * 7 + 3) % N_3d) / N_3d, + 0.1 + 0.8 * ((i * 11 + 5) % N_3d) / N_3d, + 0.1 + 0.8 * ((i * 13 + 7) % N_3d) / N_3d, + ) for i in 1:N_3d + ] + adjl_3d = RadialBasisFunctions.find_neighbors(points_3d, 20) + basis_3d = PHS(3; poly_deg = 2) + ℒ_3d = Partial(1, 1) + + function loss_partial_weights_3d(pts) + pts_vec = [ + SVector{3}(pts[3 * i - 2], pts[3 * i - 1], pts[3 * i]) for i in 1:N_3d + ] + W = RadialBasisFunctions._build_weights( + ℒ_3d, pts_vec, pts_vec, adjl_3d, basis_3d + ) + return sum(W.nzval .^ 2) + end + + pts_flat_3d = vcat([collect(p) for p in points_3d]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_partial_weights_3d, pts_flat_3d, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "3D Partial(1,2) operator with PHS3" begin + N_3d_y = 64 + points_3d_y = [ + SVector{3}( + 0.1 + 0.8 * ((i * 7 + 3) % N_3d_y) / N_3d_y, + 0.1 + 0.8 * ((i * 11 + 5) % N_3d_y) / N_3d_y, + 0.1 + 0.8 * ((i * 13 + 7) % N_3d_y) / N_3d_y, + ) for i in 1:N_3d_y + ] + adjl_3d_y = RadialBasisFunctions.find_neighbors(points_3d_y, 20) + basis_3d_y = PHS(3; poly_deg = 2) + ℒ_3d_y = Partial(1, 2) + + function loss_partial_weights_3d_y(pts) + pts_vec = [ + SVector{3}(pts[3 * i - 2], pts[3 * i - 1], pts[3 * i]) for + i in 1:N_3d_y + ] + W = RadialBasisFunctions._build_weights( + ℒ_3d_y, pts_vec, pts_vec, adjl_3d_y, basis_3d_y + ) + return sum(W.nzval .^ 2) + end + + pts_flat_3d_y = vcat([collect(p) for p in points_3d_y]...) 
+ + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_partial_weights_3d_y, pts_flat_3d_y, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "3D Partial(1,3) operator with PHS3" begin + N_3d_z = 64 + points_3d_z = [ + SVector{3}( + 0.1 + 0.8 * ((i * 7 + 3) % N_3d_z) / N_3d_z, + 0.1 + 0.8 * ((i * 11 + 5) % N_3d_z) / N_3d_z, + 0.1 + 0.8 * ((i * 13 + 7) % N_3d_z) / N_3d_z, + ) for i in 1:N_3d_z + ] + adjl_3d_z = RadialBasisFunctions.find_neighbors(points_3d_z, 20) + basis_3d_z = PHS(3; poly_deg = 2) + ℒ_3d_z = Partial(1, 3) + + function loss_partial_weights_3d_z(pts) + pts_vec = [ + SVector{3}(pts[3 * i - 2], pts[3 * i - 1], pts[3 * i]) for + i in 1:N_3d_z + ] + W = RadialBasisFunctions._build_weights( + ℒ_3d_z, pts_vec, pts_vec, adjl_3d_z, basis_3d_z + ) + return sum(W.nzval .^ 2) + end + + pts_flat_3d_z = vcat([collect(p) for p in points_3d_z]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_partial_weights_3d_z, pts_flat_3d_z, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "2D Partial(1,2) operator with PHS3" begin + basis = PHS(3; poly_deg = 2) + ℒ_y = Partial(1, 2) + + function loss_partial_y_weights(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ_y, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) 
+ + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_partial_y_weights, pts_flat, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "Different PHS orders" begin + for n in [1, 3, 5, 7] + basis = PHS(n; poly_deg = 1) + ℒ = Partial(1, 1) + + function loss_phs_order(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "PHS($n) - $name" begin + di_grad = DI.gradient(loss_phs_order, backend, pts_flat) + fd_grad = FD.grad(FD.central_fdm(5, 1), loss_phs_order, pts_flat)[1] + # PHS1 may have zero gradient for some configurations + @test !all(iszero, di_grad) || n == 1 + @test isapprox(di_grad, fd_grad; rtol = 1.0e-2) + end + end + end + end + + @testset "IMQ basis with Partial operator" begin + basis = IMQ(1.0; poly_deg = 2) + ℒ = Partial(1, 1) + + function loss_imq_partial(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_imq_partial, pts_flat, backend; rtol = 1.0e-3) + end + end + end + + @testset "IMQ basis with Laplacian operator" begin + basis = IMQ(1.0; poly_deg = 2) + ℒ = Laplacian() + + function loss_imq_laplacian(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) 
+ + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd(loss_imq_laplacian, pts_flat, backend; rtol = 1.0e-3) + end + end + end + + @testset "Gaussian basis with Partial operator" begin + basis = Gaussian(1.0; poly_deg = 2) + ℒ = Partial(1, 1) + + function loss_gaussian_partial(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_gaussian_partial, pts_flat, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "Gaussian basis with Laplacian operator" begin + basis = Gaussian(1.0; poly_deg = 2) + ℒ = Laplacian() + + function loss_gaussian_laplacian(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights(ℒ, pts_vec, pts_vec, adjl, basis) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_gaussian_laplacian, pts_flat, backend; rtol = 1.0e-3 + ) + end + end + end + + @testset "Different shape parameters" begin + for ε in [0.5, 1.0, 2.0] + @testset "IMQ with ε=$ε" begin + basis = IMQ(ε; poly_deg = 2) + ℒ = Partial(1, 1) + + function loss_imq_shape(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights( + ℒ, pts_vec, pts_vec, adjl, basis + ) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) 
+ + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_imq_shape, pts_flat, backend; rtol = 1.0e-2 + ) + end + end + end + + @testset "Gaussian with ε=$ε" begin + basis = Gaussian(ε; poly_deg = 2) + ℒ = Partial(1, 1) + + function loss_gaussian_shape(pts) + pts_vec = [SVector{2}(pts[2 * i - 1], pts[2 * i]) for i in 1:N] + W = RadialBasisFunctions._build_weights( + ℒ, pts_vec, pts_vec, adjl, basis + ) + return sum(W.nzval .^ 2) + end + + pts_flat = vcat([collect(p) for p in points]...) + + for (name, backend) in AD_BACKENDS + @testset "$name" begin + test_gradient_vs_fd( + loss_gaussian_shape, pts_flat, backend; rtol = 1.0e-2 + ) + end + end + end + end + end + end + + # Test that extensions load correctly + @testset "Extension Loading" begin + @test Base.find_package("Enzyme") !== nothing + @test Base.find_package("Mooncake") !== nothing + @test Base.find_package("DifferentiationInterface") !== nothing + end +end diff --git a/test/extensions/chainrules_ext.jl b/test/extensions/chainrules_ext.jl new file mode 100644 index 00000000..e424996d --- /dev/null +++ b/test/extensions/chainrules_ext.jl @@ -0,0 +1,302 @@ +using RadialBasisFunctions +using ChainRulesCore +using StaticArraysCore +using FiniteDifferences +using LinearAlgebra: Symmetric, dot +using Random: MersenneTwister +using Test + +const FD = FiniteDifferences + +@testset "ChainRulesCore - Interpolator Construction Rules" begin + N = 30 + points = [SVector{2}(0.1 + 0.8 * i / 6, 0.1 + 0.8 * j / 6) for i in 1:6 for j in 1:5] + values = sin.(getindex.(points, 1)) + eval_points = [SVector{2}(0.2 + 0.6 * i / 3, 0.2 + 0.6 * j / 3) for i in 1:3 for j in 1:3] + + @testset "End-to-end gradient test via rrules" begin + # Forward pass: construct and evaluate + interp, construction_pb = ChainRulesCore.rrule(Interpolator, points, values) + ys, eval_pb = ChainRulesCore.rrule(interp, eval_points) + + # Backward pass: start with loss gradient + Δys = 2 .* ys # gradient of sum(y^2) 
w.r.t. y + Δinterp, Δxs = eval_pb(Δys) + + # Propagate through construction + result = construction_pb(Δinterp) + Δv = result[3] + + # Compare with finite differences + function loss(v) + interp_local = Interpolator(points, v) + return sum(interp_local(eval_points) .^ 2) + end + + fd_grad = FD.grad(FD.central_fdm(5, 1), loss, values)[1] + + @test !all(iszero, Δv) + @test isapprox(Δv, fd_grad; rtol = 1.0e-4) + end + + @testset "Different bases" begin + # Use slightly relaxed tolerance for Gaussian basis which can be more sensitive + for (name, basis, tol) in [("PHS", PHS(), 1.0e-3), ("IMQ", IMQ(1.0), 1.0e-3), ("Gaussian", Gaussian(1.0), 5.0e-2)] + @testset "$name basis" begin + interp, construction_pb = ChainRulesCore.rrule(Interpolator, points, values, basis) + ys, eval_pb = ChainRulesCore.rrule(interp, eval_points) + + Δys = 2 .* ys + Δinterp, _ = eval_pb(Δys) + result = construction_pb(Δinterp) + Δv = result[3] + + function loss_basis(v) + interp_local = Interpolator(points, v, basis) + return sum(interp_local(eval_points) .^ 2) + end + + fd_grad = FD.grad(FD.central_fdm(5, 1), loss_basis, values)[1] + + @test !all(iszero, Δv) + @test isapprox(Δv, fd_grad; rtol = tol) + end + end + end + + @testset "Single point evaluation" begin + single_point = SVector{2}(0.5, 0.5) + + interp, construction_pb = ChainRulesCore.rrule(Interpolator, points, values) + y, eval_pb = ChainRulesCore.rrule(interp, single_point) + + Δy = 2 * y # gradient of y^2 + Δinterp, _ = eval_pb(Δy) + result = construction_pb(Δinterp) + Δv = result[3] + + function loss_single(v) + interp_local = Interpolator(points, v) + return interp_local(single_point)^2 + end + + fd_grad = FD.grad(FD.central_fdm(5, 1), loss_single, values)[1] + + @test !all(iszero, Δv) + @test isapprox(Δv, fd_grad; rtol = 1.0e-4) + end +end + +@testset "ChainRulesCore - Evaluation Rules" begin + N = 30 + points = [SVector{2}(0.1 + 0.8 * i / 6, 0.1 + 0.8 * j / 6) for i in 1:6 for j in 1:5] + values = sin.(getindex.(points, 1)) + 
interp = Interpolator(points, values) + + @testset "Single point evaluation gradient w.r.t. x" begin + x = SVector{2}(0.5, 0.5) + + y, eval_pb = ChainRulesCore.rrule(interp, x) + Δy = 1.0 + Δinterp, Δx = eval_pb(Δy) + + # Compare with finite differences + function eval_loss(xv) + return interp(SVector{2}(xv[1], xv[2])) + end + + fd_grad = FD.grad(FD.central_fdm(5, 1), eval_loss, collect(x))[1] + + @test isapprox(collect(Δx), fd_grad; rtol = 1.0e-4) + end + + @testset "Batch evaluation gradient w.r.t. x" begin + xs = [SVector{2}(0.3, 0.3), SVector{2}(0.5, 0.5), SVector{2}(0.7, 0.7)] + + ys, eval_pb = ChainRulesCore.rrule(interp, xs) + Δys = ones(length(xs)) + Δinterp, Δxs = eval_pb(Δys) + + # Compare each point's gradient with finite differences + for (i, x) in enumerate(xs) + function eval_loss_i(xv) + xs_mod = copy(xs) + xs_mod[i] = SVector{2}(xv[1], xv[2]) + return sum(interp(xs_mod)) + end + + fd_grad = FD.grad(FD.central_fdm(5, 1), eval_loss_i, collect(x))[1] + @test isapprox(collect(Δxs[i]), fd_grad; rtol = 1.0e-4) + end + end + + @testset "Weight gradients from evaluation" begin + x = SVector{2}(0.5, 0.5) + + y, eval_pb = ChainRulesCore.rrule(interp, x) + Δy = 1.0 + Δinterp, _ = eval_pb(Δy) + + # Check that weight gradients are returned + @test Δinterp isa Tangent{Interpolator} + @test hasproperty(Δinterp, :rbf_weights) + @test hasproperty(Δinterp, :monomial_weights) + @test !all(iszero, Δinterp.rbf_weights) + end +end + +@testset "Direct backward stencil functions" begin + # Test backward_stencil_partial! and backward_stencil_laplacian! directly. + # These are the non-ε variants used by the Enzyme extension, which is skipped + # on Julia 1.12+ so they need direct coverage. 
+ + N = 25 + points = [SVector{2}(0.1 + 0.8 * i / 5, 0.1 + 0.8 * j / 5) for i in 1:5 for j in 1:5] + adjl = RadialBasisFunctions.find_neighbors(points, 10) + k = 10 + rng = MersenneTwister(42) + + # Helper: compute stencil weights from flat local_data + eval_point vector + function _compute_stencil_weights(flat_data, dim_space, k, basis, mon, ℒrbf, ℒmon) + nmon = binomial(dim_space + basis.poly_deg, basis.poly_deg) + n = k + nmon + ld = [SVector{dim_space}(flat_data[(dim_space * (i - 1) + 1):(dim_space * i)]...) for i in 1:k] + ep_start = dim_space * k + 1 + ep = SVector{dim_space}(flat_data[ep_start:(ep_start + dim_space - 1)]...) + A_full = zeros(Float64, n, n) + A = Symmetric(A_full, :U) + RadialBasisFunctions._build_collocation_matrix!(A, ld, basis, mon, k) + b = zeros(Float64, n) + RadialBasisFunctions._build_rhs!(b, ℒrbf, ℒmon, ld, ep, basis, mon, k) + λ = Symmetric(A_full, :U) \ b + return λ[1:k] + end + + @testset "backward_stencil_partial!" begin + basis = PHS(3; poly_deg = 2) + ℒ = Partial(1, 1) + mon = MonomialBasis(2, basis.poly_deg) + ℒmon = ℒ(mon) + ℒrbf = ℒ(basis) + + W, cache = RadialBasisFunctions._forward_with_cache( + points, points, adjl, basis, ℒrbf, ℒmon, mon, Partial + ) + + eval_idx = 13 + neighbors = adjl[eval_idx] + eval_point = points[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = randn(rng, k, 1) + + local_data = [points[i] for i in neighbors] + Δlocal_data = [zeros(Float64, 2) for _ in 1:k] + Δeval_pt = zeros(Float64, 2) + + grad_Lφ_x = RadialBasisFunctions.grad_applied_partial_wrt_x(basis, ℒ.dim) + grad_Lφ_xi = RadialBasisFunctions.grad_applied_partial_wrt_xi(basis, ℒ.dim) + + RadialBasisFunctions.backward_stencil_partial!( + Δlocal_data, Δeval_pt, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis, mon, k, ℒ.dim, grad_Lφ_x, grad_Lφ_xi + ) + + # Compare with with-ε version (should give identical results for PHS) + Δlocal_data_ε = [zeros(Float64, 2) for _ in 1:k] + Δeval_pt_ε = zeros(Float64, 2) + 
Δε_acc = Ref(0.0) + + RadialBasisFunctions.backward_stencil_partial_with_ε!( + Δlocal_data_ε, Δeval_pt_ε, Δε_acc, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis, mon, k, ℒ.dim, grad_Lφ_x, grad_Lφ_xi + ) + + for i in 1:k + @test isapprox(Δlocal_data[i], Δlocal_data_ε[i]; atol = 1.0e-12) + end + @test isapprox(Δeval_pt, Δeval_pt_ε; atol = 1.0e-12) + + # Verify against finite differences + flat_data = vcat([collect(d) for d in local_data]..., collect(eval_point)) + loss_stencil(x) = dot( + Δw[:, 1], _compute_stencil_weights(x, 2, k, basis, mon, ℒrbf, ℒmon) + ) + fd_grad = FD.grad(FD.central_fdm(5, 1), loss_stencil, flat_data)[1] + + backward_grad = zeros(2 * k + 2) + for i in 1:k + backward_grad[2 * i - 1] = Δlocal_data[i][1] + backward_grad[2 * i] = Δlocal_data[i][2] + end + backward_grad[2 * k + 1] = Δeval_pt[1] + backward_grad[2 * k + 2] = Δeval_pt[2] + + @test !all(iszero, backward_grad) + @test isapprox(backward_grad, fd_grad; rtol = 1.0e-3) + end + + @testset "backward_stencil_laplacian!" 
begin + basis = PHS(3; poly_deg = 2) + ℒ = Laplacian() + mon = MonomialBasis(2, basis.poly_deg) + ℒmon = ℒ(mon) + ℒrbf = ℒ(basis) + + W, cache = RadialBasisFunctions._forward_with_cache( + points, points, adjl, basis, ℒrbf, ℒmon, mon, Laplacian + ) + + eval_idx = 13 + neighbors = adjl[eval_idx] + eval_point = points[eval_idx] + stencil_cache = cache.stencil_caches[eval_idx] + + Δw = randn(rng, k, 1) + + local_data = [points[i] for i in neighbors] + Δlocal_data = [zeros(Float64, 2) for _ in 1:k] + Δeval_pt = zeros(Float64, 2) + + grad_Lφ_x = RadialBasisFunctions.grad_applied_laplacian_wrt_x(basis) + grad_Lφ_xi = RadialBasisFunctions.grad_applied_laplacian_wrt_xi(basis) + + RadialBasisFunctions.backward_stencil_laplacian!( + Δlocal_data, Δeval_pt, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi + ) + + # Compare with with-ε version (should give identical results for PHS) + Δlocal_data_ε = [zeros(Float64, 2) for _ in 1:k] + Δeval_pt_ε = zeros(Float64, 2) + Δε_acc = Ref(0.0) + + RadialBasisFunctions.backward_stencil_laplacian_with_ε!( + Δlocal_data_ε, Δeval_pt_ε, Δε_acc, Δw, stencil_cache, collect(1:k), + eval_point, local_data, basis, mon, k, grad_Lφ_x, grad_Lφ_xi + ) + + for i in 1:k + @test isapprox(Δlocal_data[i], Δlocal_data_ε[i]; atol = 1.0e-12) + end + @test isapprox(Δeval_pt, Δeval_pt_ε; atol = 1.0e-12) + + # Verify against finite differences + flat_data = vcat([collect(d) for d in local_data]..., collect(eval_point)) + loss_stencil(x) = dot( + Δw[:, 1], _compute_stencil_weights(x, 2, k, basis, mon, ℒrbf, ℒmon) + ) + fd_grad = FD.grad(FD.central_fdm(5, 1), loss_stencil, flat_data)[1] + + backward_grad = zeros(2 * k + 2) + for i in 1:k + backward_grad[2 * i - 1] = Δlocal_data[i][1] + backward_grad[2 * i] = Δlocal_data[i][2] + end + backward_grad[2 * k + 1] = Δeval_pt[1] + backward_grad[2 * k + 2] = Δeval_pt[2] + + @test !all(iszero, backward_grad) + @test isapprox(backward_grad, fd_grad; rtol = 1.0e-3) + end +end 
diff --git a/test/operators/directional.jl b/test/operators/directional.jl index 7ba61f61..feb9bf4e 100644 --- a/test/operators/directional.jl +++ b/test/operators/directional.jl @@ -3,6 +3,9 @@ using StaticArraysCore using Statistics using HaltonSequences using LinearAlgebra +using Random: MersenneTwister + +rng = MersenneTwister(456) include("../test_utils.jl") @@ -26,7 +29,7 @@ end @testset "Direction Vector for Each Data Center" begin v = map(1:length(x)) do i - v = SVector{2}(rand(2)) + v = SVector{2}(rand(rng, 2)) return v /= norm(v) end ∇v = directional(x, v, PHS3(2)) @@ -37,7 +40,7 @@ end @testset "Different Evaluation Points" begin x2 = SVector{2}.(HaltonPoint(2)[1:N]) v = map(1:length(x2)) do i - v = SVector{2}(rand(2)) + v = SVector{2}(rand(rng, 2)) return v /= norm(v) end ∇v = directional(x, x2, v, PHS3(2)) @@ -54,7 +57,7 @@ end # number of geometrical vectors don't match number of points when using a different # geometrical vector for each point v = map(1:(length(x) - 1)) do i - v = SVector{2}(rand(2)) + v = SVector{2}(rand(rng, 2)) return v /= norm(v) end @test_throws DomainError directional(x, v) diff --git a/test/operators/gradient.jl b/test/operators/gradient.jl index 3503ff1a..20fbacc7 100644 --- a/test/operators/gradient.jl +++ b/test/operators/gradient.jl @@ -2,6 +2,9 @@ using RadialBasisFunctions using StaticArraysCore using Statistics using HaltonSequences +using Random: MersenneTwister + +rng = MersenneTwister(345) include("../test_utils.jl") @@ -25,7 +28,7 @@ y = f.(x) end @testset "Different Evaluation Points" begin - x2 = map(x -> SVector{2}(rand(2)), 1:100) + x2 = map(x -> SVector{2}(rand(rng, 2)), 1:100) ∇ = gradient(x, x2, PHS(3; poly_deg = 2)) ∇y = ∇(y) @test ∇y isa Matrix diff --git a/test/operators/operators.jl b/test/operators/operators.jl index 9cf2b781..f239cb73 100644 --- a/test/operators/operators.jl +++ b/test/operators/operators.jl @@ -4,6 +4,9 @@ using StaticArraysCore using LinearAlgebra using Statistics using 
HaltonSequences +using Random: MersenneTwister + +rng = MersenneTwister(123) N = 100 x = SVector{2}.(HaltonPoint(2)[1:N]) @@ -17,8 +20,8 @@ end @testset "Operator Evaluation" begin ∂ = partial(x, 1, 1) - y = rand(N) - z = rand(N) + y = rand(rng, N) + z = rand(rng, N) ∂(y, z) @test y ≈ ∂.weights * z @@ -34,15 +37,15 @@ end @testset "Operator Update" begin ∂ = partial(x, 1, 1) correct_weights = copy(∂.weights) - ∂.weights .= rand(size(∂.weights)) + ∂.weights .= rand(rng, size(∂.weights)) update_weights!(∂) @test ∂.weights ≈ correct_weights @test is_cache_valid(∂) ∇ = gradient(x, PHS(3; poly_deg = 2)) correct_weights = copy.(∇.weights) - ∇.weights[1] .= rand(size(∇.weights[1])) - ∇.weights[2] .= rand(size(∇.weights[2])) + ∇.weights[1] .= rand(rng, size(∇.weights[1])) + ∇.weights[2] .= rand(rng, size(∇.weights[2])) update_weights!(∇) @test ∇.weights[1] ≈ correct_weights[1] @test ∇.weights[2] ≈ correct_weights[2] @@ -63,7 +66,7 @@ end @test RBF.dim(∂) == 2 # 3D test - x3d = [SVector{3}(rand(3)) for _ in 1:50] + x3d = [SVector{3}(rand(rng, 3)) for _ in 1:50] ∂3d = partial(x3d, 1, 1) @test RBF.dim(∂3d) == 3 end diff --git a/test/operators/partial.jl b/test/operators/partial.jl index bc662829..7af166b3 100644 --- a/test/operators/partial.jl +++ b/test/operators/partial.jl @@ -2,6 +2,9 @@ using RadialBasisFunctions using StaticArraysCore using Statistics using HaltonSequences +using Random: MersenneTwister + +rng = MersenneTwister(234) include("../test_utils.jl") @@ -62,7 +65,7 @@ end end @testset "Different Evaluation Points" begin - x2 = map(x -> SVector{2}(rand(2)), 1:100) + x2 = map(x -> SVector{2}(rand(rng, 2)), 1:100) ∂x = partial(x, x2, 1, 1, PHS(3; poly_deg = 2)) ∂y = partial(x, x2, 1, 2, PHS(3; poly_deg = 2)) @test mean_percent_error(∂x(y), df_dx.(x2)) < 10 diff --git a/test/operators/regrid.jl b/test/operators/regrid.jl index 327ce2a9..7d22a0ac 100644 --- a/test/operators/regrid.jl +++ b/test/operators/regrid.jl @@ -2,6 +2,9 @@ using RadialBasisFunctions using 
StaticArraysCore using Statistics using HaltonSequences +using Random: MersenneTwister + +rng = MersenneTwister(567) include("../test_utils.jl") @@ -11,7 +14,7 @@ N = 10_000 x = SVector{2}.(HaltonPoint(2)[1:N]) y = f.(x) -x2 = map(x -> SVector{2}(rand(2)), 1:100) +x2 = map(x -> SVector{2}(rand(rng, 2)), 1:100) @testset "Positional Basis Constructor" begin r = regrid(x, x2, PHS(3; poly_deg = 2)) diff --git a/test/operators/virtual.jl b/test/operators/virtual.jl index febc1342..a25f9966 100644 --- a/test/operators/virtual.jl +++ b/test/operators/virtual.jl @@ -2,6 +2,9 @@ using RadialBasisFunctions using StaticArraysCore using Statistics using HaltonSequences +using Random: MersenneTwister + +rng = MersenneTwister(678) include("../test_utils.jl") @@ -40,7 +43,7 @@ y = f.(x) end @testset "Different Evaluation Points" begin - x2 = map(x -> SVector{2}(rand(2)), 1:100) + x2 = map(x -> SVector{2}(rand(rng, 2)), 1:100) ∂x = ∂virtual(x, x2, 1, Δ, PHS(3; poly_deg = 2)) ∂y = ∂virtual(x, x2, 2, Δ, PHS(3; poly_deg = 2)) @test mean_percent_error(∂x(y), df_dx.(x2)) < 10 diff --git a/test/runtests.jl b/test/runtests.jl index 3ab6be84..fd143f34 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -94,3 +94,15 @@ end include("solve/integration/partial_x_end_to_end.jl") include("solve/integration/partial_y_end_to_end.jl") end + +if Base.find_package("ChainRulesCore") !== nothing + @safetestset "ChainRulesCore Extension" begin + include("extensions/chainrules_ext.jl") + end +end + +if Base.find_package("DifferentiationInterface") !== nothing + @safetestset "Autodiff via DifferentiationInterface" begin + include("extensions/autodiff_di.jl") + end +end diff --git a/test/solve/integration/hermite_test_utils.jl b/test/solve/integration/hermite_test_utils.jl index b023a4fc..45905876 100644 --- a/test/solve/integration/hermite_test_utils.jl +++ b/test/solve/integration/hermite_test_utils.jl @@ -1,8 +1,11 @@ using Test using StaticArraysCore using LinearAlgebra +using Random: 
MersenneTwister -function create_2d_unit_square_domain(spacing::Float64 = 0.05; randomize = false) +function create_2d_unit_square_domain( + spacing::Float64 = 0.05; randomize = false, rng = MersenneTwister(101) + ) domain_2d = SVector{2, Float64}[] for x in 0.0:spacing:1.0 for y in 0.0:spacing:1.0 @@ -11,8 +14,8 @@ function create_2d_unit_square_domain(spacing::Float64 = 0.05; randomize = false if randomize && !is_on_boundary # Only add noise to interior points - noise_x = (rand() - 0.5) * spacing * 0.3 - noise_y = (rand() - 0.5) * spacing * 0.3 + noise_x = (rand(rng) - 0.5) * spacing * 0.3 + noise_y = (rand(rng) - 0.5) * spacing * 0.3 push!(domain_2d, SVector(x + noise_x, y + noise_y)) else # Keep boundary points exact