Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ julia --project=benchmark benchmark/benchmarks.jl

2. **Operators** (`src/operators/`): Differential operators built on RBFs
- `RadialBasisOperator` - Main operator type with lazy weight computation
- Specific operators: `Partial`, `Gradient`, `Laplacian`, `Directional`, `Custom`
- Specific operators: `Partial`, `Jacobian`, `Laplacian`, `Directional`, `Custom`
- `operator_algebra.jl` - Composition and algebraic operations on operators
- Virtual operators for performance optimization

Expand Down
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
name = "RadialBasisFunctions"
uuid = "79ee0514-adf7-4479-8807-6f72ea8967e8"
authors = ["Kyle Beggs"]
version = "0.2.7"
authors = ["Kyle Beggs"]

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
SymRCM = "286e6d88-80af-4590-acc9-0001b223b9bd"

[weakdeps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
RadialBasisFunctionsChainRulesCoreExt = "ChainRulesCore"
RadialBasisFunctionsMooncakeExt = ["ChainRulesCore", "Mooncake"]

[compat]
ChainRulesCore = "1.20"
ChunkSplitters = "3"
Combinatorics = "1"
Distances = "0.9, 0.10"
FiniteDifferences = "0.12.33"
KernelAbstractions = "0.9.34"
LinearAlgebra = "1"
Mooncake = "0.4"
NearestNeighbors = "0.4.8"
PrecompileTools = "1.2"
StaticArrays = "1.9.15"
StaticArraysCore = "1.4"
SymRCM = "0.2"
julia = "1.10"
14 changes: 8 additions & 6 deletions docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,25 @@ lap(x) = 4
all(abs.(lap.(x) .- lap_rbf(y)) .< 1e-8)
```

### Gradient
### Gradient / Jacobian

We can also retrieve the gradient. This is really just a convenience wrapper around `Partial`.
The `jacobian` function computes all partial derivatives. For scalar fields, this is the gradient.
The `gradient` function is a convenience alias for `jacobian`.

```@example overview
grad = gradient(x)
op = jacobian(x) # or equivalently: gradient(x)
result = op(y) # Matrix of size (N, dim)

# define exacts
df_x(x) = 4*x[1]
df_y(x) = 3

# error
all(df_x.(x) .≈ grad(y)[1])
# error - access columns for each partial derivative
all(df_x.(x) .≈ result[:, 1])
```

```@example overview
all(df_y.(x) .≈ grad(y)[2])
all(df_y.(x) .≈ result[:, 2])
```

## Current Limitations
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
RadialBasisFunctionsChainRulesCoreExt

Package extension providing ChainRulesCore.jl custom differentiation rules for
RadialBasisFunctions.jl. These rules enable efficient automatic differentiation
with backends like Zygote.jl, Enzyme.jl (via @import_rrule), and others that
support ChainRulesCore.

The rules leverage the analytical derivatives already implemented in the package
(∂, ∇, ∇² methods) rather than relying on AD to trace through the computations.

Includes rrules for:
- Basis function evaluation (basis_rules.jl)
- Operator application (operator_rules.jl)
- Interpolator evaluation (interpolation_rules.jl)
- Weight construction for shape optimization (build_weights_*.jl)
"""
module RadialBasisFunctionsChainRulesCoreExt

using RadialBasisFunctions
using ChainRulesCore
using LinearAlgebra
using SparseArrays
using Combinatorics: binomial

# Import internal functions we need to extend
import RadialBasisFunctions: _eval_op, _build_weights
import RadialBasisFunctions: _build_collocation_matrix!, _build_rhs!

# Import types we need
import RadialBasisFunctions: RadialBasisOperator, Interpolator
import RadialBasisFunctions: AbstractRadialBasis, PHS1, PHS3, PHS5, PHS7, IMQ, Gaussian
import RadialBasisFunctions: VectorValuedOperator, ScalarValuedOperator
import RadialBasisFunctions: MonomialBasis, BoundaryData
import RadialBasisFunctions: Partial, Laplacian

# Import the gradient function for basis functions (not exported from main module)
const ∇ = RadialBasisFunctions.∇
const ∂ = RadialBasisFunctions.∂

# Existing rules
include("operator_rules.jl")
include("basis_rules.jl")
include("interpolation_rules.jl")

# Shape optimization support: rrules for _build_weights
include("build_weights_cache.jl")
include("operator_second_derivatives.jl")
include("build_weights_backward.jl")
include("build_weights_rrule.jl")

end # module
109 changes: 109 additions & 0 deletions ext/RadialBasisFunctionsChainRulesCoreExt/basis_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#=
Differentiation rules for radial basis function evaluations.

Each RBF type (PHS, IMQ, Gaussian) is callable: basis(x, xi) computes φ(||x - xi||).
The analytical gradients are already implemented via the ∇(basis) function, which
returns a function that computes the gradient vector ∇φ at a given (x, xi) pair.

For reverse-mode AD:
- d/dx[φ(x, xi)] = ∇φ(x, xi)
- d/dxi[φ(x, xi)] = -∇φ(x, xi) (by symmetry: φ depends on x - xi)
=#

# PHS basis functions (Polyharmonic Splines)
# Note: PHS ∇ functions accept an optional normal argument, but we use the
# default (nothing) for standard differentiation.

function ChainRulesCore.rrule(basis::PHS1, x::AbstractVector, xi::AbstractVector)
y = basis(x, xi)

function phs1_pullback(Δy)
Δy_real = unthunk(Δy)
grad_fn = ∇(basis)
∇φ = grad_fn(x, xi)
Δx = Δy_real .* ∇φ
Δxi = -Δx
return NoTangent(), Δx, Δxi
end

return y, phs1_pullback
end

function ChainRulesCore.rrule(basis::PHS3, x::AbstractVector, xi::AbstractVector)
y = basis(x, xi)

function phs3_pullback(Δy)
Δy_real = unthunk(Δy)
grad_fn = ∇(basis)
∇φ = grad_fn(x, xi)
Δx = Δy_real .* ∇φ
Δxi = -Δx
return NoTangent(), Δx, Δxi
end

return y, phs3_pullback
end

function ChainRulesCore.rrule(basis::PHS5, x::AbstractVector, xi::AbstractVector)
y = basis(x, xi)

function phs5_pullback(Δy)
Δy_real = unthunk(Δy)
grad_fn = ∇(basis)
∇φ = grad_fn(x, xi)
Δx = Δy_real .* ∇φ
Δxi = -Δx
return NoTangent(), Δx, Δxi
end

return y, phs5_pullback
end

function ChainRulesCore.rrule(basis::PHS7, x::AbstractVector, xi::AbstractVector)
y = basis(x, xi)

function phs7_pullback(Δy)
Δy_real = unthunk(Δy)
grad_fn = ∇(basis)
∇φ = grad_fn(x, xi)
Δx = Δy_real .* ∇φ
Δxi = -Δx
return NoTangent(), Δx, Δxi
end

return y, phs7_pullback
end

# IMQ (Inverse Multiquadric) basis function

function ChainRulesCore.rrule(basis::IMQ, x::AbstractVector, xi::AbstractVector)
y = basis(x, xi)

function imq_pullback(Δy)
Δy_real = unthunk(Δy)
grad_fn = ∇(basis)
∇φ = grad_fn(x, xi)
Δx = Δy_real .* ∇φ
Δxi = -Δx
return NoTangent(), Δx, Δxi
end

return y, imq_pullback
end

# Gaussian basis function

function ChainRulesCore.rrule(basis::Gaussian, x::AbstractVector, xi::AbstractVector)
y = basis(x, xi)

function gaussian_pullback(Δy)
Δy_real = unthunk(Δy)
grad_fn = ∇(basis)
∇φ = grad_fn(x, xi)
Δx = Δy_real .* ∇φ
Δxi = -Δx
return NoTangent(), Δx, Δxi
end

return y, gaussian_pullback
end
Loading
Loading