Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RadialBasisFunctions"
uuid = "79ee0514-adf7-4479-8807-6f72ea8967e8"
authors = ["Kyle Beggs"]
version = "0.2.4"
version = "0.2.5"

[deps]
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
Expand Down
3 changes: 2 additions & 1 deletion src/RadialBasisFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ export find_neighbors, reorder_points!
include("solve.jl")

include("operators/operators.jl")
export RadialBasisOperator, update_weights!
export RadialBasisOperator, ScalarValuedOperator, VectorValuedOperator
export update_weights!, is_cache_valid

include("operators/partial.jl")
export Partial, partial
Expand Down
14 changes: 2 additions & 12 deletions src/basis/inverse_multiquadric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,11 @@ function ∇²(rbf::IMQ)
end
end

function shape_param_franke(x)
# modified Franke's formula for double precision as initial guess
D = 2.0 * euclidean(first(x), last(x))
N = size(points, 2)
return D / (0.8 * (N^0.25))
end

function Base.show(io::IO, rbf::IMQ)
print(io, "Inverse Multiquadrics, 1/sqrt((r*ε)²+1)")
print(io, "\n├─Shape factor: ε = $(rbf.ε)")
if rbf.poly_deg < 0
print(io, "\n└─No Monomial augmentation")
else
print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)")
end
print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)")
return nothing
end

print_basis(rbf::IMQ) = "Inverse Multiquadric (ε = $(rbf.ε))"
7 changes: 2 additions & 5 deletions src/basis/polyharmonic_spline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,8 @@ end

function Base.show(io::IO, rbf::R) where {R<:AbstractPHS}
print(io, print_basis(rbf))
if rbf.poly_deg < 0
print(io, "\n└─No Monomial augmentation")
else
print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)")
end
print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)")
return nothing
end

print_basis(::PHS1) = "Polyharmonic spline (r¹)"
Expand Down
3 changes: 1 addition & 2 deletions src/operators/directional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ end
function update_weights!(op::RadialBasisOperator{<:Directional})
v = op.ℒ.v
N = length(first(op.data))
@assert length(v) == N || length(v) == size(op)[1] "wrong size for v"
if length(v) == N
op.weights .= mapreduce(+, enumerate(op.ℒ.ℒ)) do (i, ℒ)
_build_weights(ℒ, op) * v[i]
Expand All @@ -96,7 +95,7 @@ function update_weights!(op::RadialBasisOperator{<:Directional})
Diagonal(vv[i]) * _build_weights(ℒ, op)
end
end
validate_cache(op)
validate_cache!(op)
return nothing
end

Expand Down
2 changes: 0 additions & 2 deletions src/operators/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,5 @@ function gradient(
return RadialBasisOperator(ℒ, data, eval_points, basis; k=k, adjl=adjl)
end

Base.size(op::RadialBasisOperator{<:Gradient}) = size(first(op.weights))

# pretty printing
print_op(op::Gradient) = "Gradient (∇f)"
119 changes: 30 additions & 89 deletions src/operators/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,123 +81,66 @@ function RadialBasisOperator(
return RadialBasisOperator(ℒ, weights, data, eval_points, adjl, basis, true)
end

# extend Base methods
Base.length(op::RadialBasisOperator) = length(op.adjl)
Base.size(op::RadialBasisOperator) = size(op.weights)
Base.size(op::RadialBasisOperator, dim::Int) = size(op.weights, dim)
function Base.size(op::RadialBasisOperator{<:VectorValuedOperator})
return ntuple(i -> size(op.weights[i]), embeddim(op))
end
Base.getindex(op::RadialBasisOperator, i) = nonzeros(op.weights[i, :])
function Base.getindex(op::RadialBasisOperator{VectorValuedOperator}, i)
return ntuple(j -> nonzeros(op.weights[j][i, :]), embeddim(op))
end

# convienience methods
embeddim(op::RadialBasisOperator) = length(first(op.data))
dim(op::RadialBasisOperator) = length(first(op.data))

# caching
invalidate_cache(op::RadialBasisOperator) = op.valid_cache[] = false
validate_cache(op::RadialBasisOperator) = op.valid_cache[] = true
invalidate_cache!(op::RadialBasisOperator) = op.valid_cache[] = false
validate_cache!(op::RadialBasisOperator) = op.valid_cache[] = true
is_cache_valid(op::RadialBasisOperator) = op.valid_cache[]

# dispatches for evaluation
_eval_op(op::RadialBasisOperator, x::AbstractVector) = _eval_op(op.weights, x)

function _eval_op(op::RadialBasisOperator{<:VectorValuedOperator}, x::AbstractVector)
return ntuple(i -> _eval_op(op.weights[i], x), embeddim(op))
end

function _eval_op(
op::RadialBasisOperator{<:VectorValuedOperator,W}, x::AbstractVector
) where {W<:Tuple}
if first(op.weights) isa Vector{<:Vector}
return ntuple(i -> _eval_op(op.weights[i], x, op.adjl), embeddim(op))
else
return ntuple(i -> _eval_op(op.weights[i], x), embeddim(op))
end
end

function _eval_op(
op::RadialBasisOperator{L,W}, x::AbstractVector
) where {L,W<:Vector{<:Vector}}
return _eval_op(op.weights, x, op.adjl)
end

_eval_op(w::AbstractMatrix, x::AbstractVector) = w * x

function _eval_op(w::AbstractVector{<:AbstractVector{T}}, x::AbstractVector, adjl) where {T}
y = zeros(T, length(w))
Threads.@threads for i in eachindex(adjl)
@views y[i] = w[i] ⋅ x[adjl[i]]
end
return y
end
"""
function (op::RadialBasisOperator)(x)

# evaluate
Evaluate the operator at `x`.
"""
function (op::RadialBasisOperator)(x)
!is_cache_valid(op) && update_weights!(op)
return _eval_op(op, x)
end

function LinearAlgebra.mul!(
y::AbstractVecOrMat, op::RadialBasisOperator, x::AbstractVecOrMat
)
"""
function (op::RadialBasisOperator)(y, x)

Evaluate the operator at `x` in-place and store the result in `y`.
"""
function (op::RadialBasisOperator)(y, x)
!is_cache_valid(op) && update_weights!(op)
return mul!(y, op.weights, x)
return _eval_op(op, y, x)
end
function LinearAlgebra.mul!(
y::AbstractVecOrMat, op::RadialBasisOperator, x::AbstractVecOrMat, α, β
)
!is_cache_valid(op) && update_weights!(op)
return mul!(y, op.weights, x, α, β)

# dispatches for evaluation
_eval_op(op::RadialBasisOperator, x) = op.weights * x
_eval_op(op::RadialBasisOperator, y, x) = mul!(y, op.weights, x)

function _eval_op(op::RadialBasisOperator{<:VectorValuedOperator}, x)
return ntuple(i -> op.weights[i] * x, dim(op))
end
function _eval_op(op::RadialBasisOperator{<:VectorValuedOperator}, y, x)
for i in eachindex(op.weights)
mul!(y[i], op.weights[i], x)
end
end

# LinearAlgebra methods
function LinearAlgebra.:⋅(
op::RadialBasisOperator{<:VectorValuedOperator}, x::AbstractVector
)
!is_cache_valid(op) && update_weights!(op)
return sum(op(x))
end

function LinearAlgebra.mul!(
y::AbstractVector{<:Real},
op::RadialBasisOperator{<:VectorValuedOperator},
x::AbstractVector,
)
!is_cache_valid(op) && update_weights!(op)
for i in eachindex(op.weights)
mul!(y[i], op.weights[i], x)
end
end

# update weights
function update_weights!(op::RadialBasisOperator)
op.weights .= _build_weights(op.ℒ.ℒ, op)
validate_cache(op)
validate_cache!(op)
return nothing
end

function update_weights!(op::RadialBasisOperator{<:VectorValuedOperator})
return _update_weights!(op, op.weights)
end

function _update_weights!(op, weights::NTuple{N,AbstractMatrix}) where {N}
for (i, ℒ) in enumerate(op.ℒ.ℒ)
weights[i] .= _build_weights(ℒ, op)
op.weights[i] .= _build_weights(ℒ, op)
end
validate_cache(op)
return nothing
end

function _update_weights!(op, weights::NTuple{N,AbstractVector}) where {N}
for (i, ℒ) in enumerate(op.ℒ.ℒ)
w = _build_weights(ℒ, op)
for j in eachindex(weights[i])
weights[i][j] .= w[j]
end
end
validate_cache(op)
validate_cache!(op)
return nothing
end

Expand All @@ -215,5 +158,3 @@ function Base.show(io::IO, op::RadialBasisOperator)
" with degree $(op.basis.poly_deg) polynomial augmentation",
)
end

print_op(op) = "$op"
8 changes: 0 additions & 8 deletions src/operators/partial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ function ∂(basis::AbstractBasis, order::T, dim::T) where {T<:Int}
elseif order == 2
return ∂²(basis, dim)
else
return _higher_order_partial(basis, order, dim)
throw(
ArgumentError(
"Only first and second order derivatives are supported right now. You may use the custom operator.",
Expand All @@ -66,12 +65,5 @@ function ∂(basis::AbstractBasis, order::T, dim::T) where {T<:Int}
end
end

function _higher_order_partial(basis::MonomialBasis, order::T, dim::T) where {T<:Int}
return _∂(basis, order, Val(dim))
end
function _higher_order_partial(_, _, _)
throw(ArgumentError("Higher order partials are not supported for RBFs yet."))
end

# pretty printing
print_op(op::Partial) = "∂ⁿf/∂xᵢ (n = $(op.order), i = $(op.dim))"
64 changes: 64 additions & 0 deletions test/operators/operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using RadialBasisFunctions
import RadialBasisFunctions as RBF
using StaticArraysCore
using LinearAlgebra
using Statistics
using HaltonSequences

N = 100
x = SVector{2}.(HaltonPoint(2)[1:N])

@testset "Base Methods" begin
∂ = partial(x, 1, 1)
@test is_cache_valid(∂)
RBF.invalidate_cache!(∂)
@test !is_cache_valid(∂)
end

@testset "Operator Evaluation" begin
∂ = partial(x, 1, 1)
y = rand(N)
z = rand(N)
∂(y, z)
@test y ≈ ∂.weights * z

∇ = gradient(x, PHS(3; poly_deg=2))
y = (rand(N), rand(N))
∇(y, z)
@test y[1] ≈ ∇.weights[1] * z
@test y[2] ≈ ∇.weights[2] * z

@test ∇ ⋅ z ≈ (∇.weights[1] * z) .+ (∇.weights[2] * z)
end

@testset "Operator Update" begin
∂ = partial(x, 1, 1)
correct_weights = copy(∂.weights)
∂.weights .= rand(size(∂.weights))
update_weights!(∂)
@test ∂.weights ≈ correct_weights
@test is_cache_valid(∂)

∇ = gradient(x, PHS(3; poly_deg=2))
correct_weights = copy.(∇.weights)
∇.weights[1] .= rand(size(∇.weights[1]))
∇.weights[2] .= rand(size(∇.weights[2]))
update_weights!(∇)
@test ∇.weights[1] ≈ correct_weights[1]
@test ∇.weights[2] ≈ correct_weights[2]
@test is_cache_valid(∇)
end

@testset "Printing" begin
∂ = partial(x, 1, 1)
@test repr(∂) == """
RadialBasisOperator
├─Operator: ∂ⁿf/∂xᵢ (n = 1, i = 1)
├─Data type: StaticArraysCore.SVector{2, Float64}
├─Number of points: 100
├─Stencil size: 12
└─Basis: Polyharmonic spline (r³) with degree 2 polynomial augmentation
"""

@test RBF.print_op(∂.ℒ) == "∂ⁿf/∂xᵢ (n = 1, i = 1)"
end
4 changes: 4 additions & 0 deletions test/operators/partial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ end
@test mean_percent_error(∂y(y), df_dy.(x2)) < 10
end

@testset "Higher Order Derivatives" begin
@test_throws ArgumentError partial(x, 3, 1)
end

@testset "Printing" begin
∂ = Partial(identity, 1, 2)
@test RadialBasisFunctions.print_op(∂) == "∂ⁿf/∂xᵢ (n = 1, i = 2)"
Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using SafeTestsets

@safetestset "Basis - General Utils" begin
@safetestset "Basis - General" begin
include("basis/basis.jl")
end

Expand All @@ -20,6 +20,10 @@ end
include("basis/monomial.jl")
end

@safetestset "Operators" begin
include("operators/operators.jl")
end

@safetestset "Partial Derivatives" begin
include("operators/partial.jl")
end
Expand Down