diff --git a/Project.toml b/Project.toml index a6735b1d..e3f7ac1a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RadialBasisFunctions" uuid = "79ee0514-adf7-4479-8807-6f72ea8967e8" authors = ["Kyle Beggs"] -version = "0.2.4" +version = "0.2.5" [deps] ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" diff --git a/src/RadialBasisFunctions.jl b/src/RadialBasisFunctions.jl index ebf99637..84c2fa03 100644 --- a/src/RadialBasisFunctions.jl +++ b/src/RadialBasisFunctions.jl @@ -24,7 +24,8 @@ export find_neighbors, reorder_points! include("solve.jl") include("operators/operators.jl") -export RadialBasisOperator, update_weights! +export RadialBasisOperator, ScalarValuedOperator, VectorValuedOperator +export update_weights!, is_cache_valid include("operators/partial.jl") export Partial, partial diff --git a/src/basis/inverse_multiquadric.jl b/src/basis/inverse_multiquadric.jl index 4a1dc8a3..d58dc988 100644 --- a/src/basis/inverse_multiquadric.jl +++ b/src/basis/inverse_multiquadric.jl @@ -47,21 +47,11 @@ function ∇²(rbf::IMQ) end end -function shape_param_franke(x) - # modified Franke's formula for double precision as initial guess - D = 2.0 * euclidean(first(x), last(x)) - N = size(points, 2) - return D / (0.8 * (N^0.25)) -end - function Base.show(io::IO, rbf::IMQ) print(io, "Inverse Multiquadrics, 1/sqrt((r*ε)²+1)") print(io, "\n├─Shape factor: ε = $(rbf.ε)") - if rbf.poly_deg < 0 - print(io, "\n└─No Monomial augmentation") - else - print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)") - end + print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)") + return nothing end print_basis(rbf::IMQ) = "Inverse Multiquadric (ε = $(rbf.ε))" diff --git a/src/basis/polyharmonic_spline.jl b/src/basis/polyharmonic_spline.jl index 07aa59f6..fa8af001 100644 --- a/src/basis/polyharmonic_spline.jl +++ b/src/basis/polyharmonic_spline.jl @@ -172,11 +172,8 @@ end function Base.show(io::IO, rbf::R) where {R<:AbstractPHS} print(io, print_basis(rbf)) - if rbf.poly_deg < 0 - print(io, "\n└─No Monomial augmentation") - else - print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)") - end + print(io, "\n└─Polynomial augmentation: degree $(rbf.poly_deg)") + return nothing end print_basis(::PHS1) = "Polyharmonic spline (r¹)" diff --git a/src/operators/directional.jl b/src/operators/directional.jl index d01a9150..b36654e8 100644 --- a/src/operators/directional.jl +++ b/src/operators/directional.jl @@ -85,7 +85,6 @@ end function update_weights!(op::RadialBasisOperator{<:Directional}) v = op.ℒ.v N = length(first(op.data)) - @assert length(v) == N || length(v) == size(op)[1] "wrong size for v" if length(v) == N op.weights .= mapreduce(+, enumerate(op.ℒ.ℒ)) do (i, ℒ) _build_weights(ℒ, op) * v[i] @@ -96,7 +95,7 @@ function update_weights!(op::RadialBasisOperator{<:Directional}) Diagonal(vv[i]) * _build_weights(ℒ, op) end end - validate_cache(op) + validate_cache!(op) return nothing end diff --git a/src/operators/gradient.jl b/src/operators/gradient.jl index 86fdedcd..9a0d9a39 100644 --- a/src/operators/gradient.jl +++ b/src/operators/gradient.jl @@ -41,7 +41,5 @@ function gradient( return RadialBasisOperator(ℒ, data, eval_points, basis; k=k, adjl=adjl) end -Base.size(op::RadialBasisOperator{<:Gradient}) = size(first(op.weights)) - # pretty printing print_op(op::Gradient) = "Gradient (∇f)" diff --git a/src/operators/operators.jl b/src/operators/operators.jl index 10adfb87..13cefc3b 100644 --- a/src/operators/operators.jl +++ b/src/operators/operators.jl @@ -81,78 +81,47 @@ function RadialBasisOperator( return RadialBasisOperator(ℒ, weights, data, eval_points, adjl, basis, true) end -# extend Base methods -Base.length(op::RadialBasisOperator) = length(op.adjl) -Base.size(op::RadialBasisOperator) = size(op.weights) -Base.size(op::RadialBasisOperator, dim::Int) = size(op.weights, dim) -function Base.size(op::RadialBasisOperator{<:VectorValuedOperator}) - return ntuple(i -> size(op.weights[i]), embeddim(op)) -end -Base.getindex(op::RadialBasisOperator, i) = nonzeros(op.weights[i, :]) -function Base.getindex(op::RadialBasisOperator{VectorValuedOperator}, i) - return ntuple(j -> nonzeros(op.weights[j][i, :]), embeddim(op)) -end - -# convienience methods -embeddim(op::RadialBasisOperator) = length(first(op.data)) +dim(op::RadialBasisOperator) = length(first(op.data)) # caching -invalidate_cache(op::RadialBasisOperator) = op.valid_cache[] = false -validate_cache(op::RadialBasisOperator) = op.valid_cache[] = true +invalidate_cache!(op::RadialBasisOperator) = op.valid_cache[] = false +validate_cache!(op::RadialBasisOperator) = op.valid_cache[] = true is_cache_valid(op::RadialBasisOperator) = op.valid_cache[] -# dispatches for evaluation -_eval_op(op::RadialBasisOperator, x::AbstractVector) = _eval_op(op.weights, x) - -function _eval_op(op::RadialBasisOperator{<:VectorValuedOperator}, x::AbstractVector) - return ntuple(i -> _eval_op(op.weights[i], x), embeddim(op)) -end - -function _eval_op( - op::RadialBasisOperator{<:VectorValuedOperator,W}, x::AbstractVector -) where {W<:Tuple} - if first(op.weights) isa Vector{<:Vector} - return ntuple(i -> _eval_op(op.weights[i], x, op.adjl), embeddim(op)) - else - return ntuple(i -> _eval_op(op.weights[i], x), embeddim(op)) - end -end - -function _eval_op( - op::RadialBasisOperator{L,W}, x::AbstractVector -) where {L,W<:Vector{<:Vector}} - return _eval_op(op.weights, x, op.adjl) -end - -_eval_op(w::AbstractMatrix, x::AbstractVector) = w * x - -function _eval_op(w::AbstractVector{<:AbstractVector{T}}, x::AbstractVector, adjl) where {T} - y = zeros(T, length(w)) - Threads.@threads for i in eachindex(adjl) - @views y[i] = w[i] ⋅ x[adjl[i]] - end - return y -end +""" + function (op::RadialBasisOperator)(x) -# evaluate +Evaluate the operator at `x`. +""" function (op::RadialBasisOperator)(x) !is_cache_valid(op) && update_weights!(op) return _eval_op(op, x) end -function LinearAlgebra.mul!( - y::AbstractVecOrMat, op::RadialBasisOperator, x::AbstractVecOrMat -) +""" + function (op::RadialBasisOperator)(y, x) + +Evaluate the operator at `x` in-place and store the result in `y`. +""" +function (op::RadialBasisOperator)(y, x) !is_cache_valid(op) && update_weights!(op) - return mul!(y, op.weights, x) + return _eval_op(op, y, x) end -function LinearAlgebra.mul!( - y::AbstractVecOrMat, op::RadialBasisOperator, x::AbstractVecOrMat, α, β -) - !is_cache_valid(op) && update_weights!(op) - return mul!(y, op.weights, x, α, β) + +# dispatches for evaluation +_eval_op(op::RadialBasisOperator, x) = op.weights * x +_eval_op(op::RadialBasisOperator, y, x) = mul!(y, op.weights, x) + +function _eval_op(op::RadialBasisOperator{<:VectorValuedOperator}, x) + return ntuple(i -> op.weights[i] * x, dim(op)) +end +function _eval_op(op::RadialBasisOperator{<:VectorValuedOperator}, y, x) + for i in eachindex(op.weights) + mul!(y[i], op.weights[i], x) + end end +# LinearAlgebra methods function LinearAlgebra.:⋅( op::RadialBasisOperator{<:VectorValuedOperator}, x::AbstractVector ) @@ -160,44 +129,18 @@ function LinearAlgebra.:⋅( return sum(op(x)) end -function LinearAlgebra.mul!( - y::AbstractVector{<:Real}, - op::RadialBasisOperator{<:VectorValuedOperator}, - x::AbstractVector, -) - !is_cache_valid(op) && update_weights!(op) - for i in eachindex(op.weights) - mul!(y[i], op.weights[i], x) - end -end - # update weights function update_weights!(op::RadialBasisOperator) op.weights .= _build_weights(op.ℒ.ℒ, op) - validate_cache(op) + validate_cache!(op) return nothing end function update_weights!(op::RadialBasisOperator{<:VectorValuedOperator}) - return _update_weights!(op, op.weights) -end - -function _update_weights!(op, weights::NTuple{N,AbstractMatrix}) where {N} for (i, ℒ) in enumerate(op.ℒ.ℒ) - weights[i] .= _build_weights(ℒ, op) + op.weights[i] .= _build_weights(ℒ, op) end - validate_cache(op) - return nothing -end - -function _update_weights!(op, weights::NTuple{N,AbstractVector}) where {N} - for (i, ℒ) in enumerate(op.ℒ.ℒ) - w = _build_weights(ℒ, op) - for j in eachindex(weights[i]) - weights[i][j] .= w[j] - end - end - validate_cache(op) + validate_cache!(op) return nothing end @@ -215,5 +158,3 @@ function Base.show(io::IO, op::RadialBasisOperator) " with degree $(op.basis.poly_deg) polynomial augmentation", ) end - -print_op(op) = "$op" diff --git a/src/operators/partial.jl b/src/operators/partial.jl index d9c07759..491d1bce 100644 --- a/src/operators/partial.jl +++ b/src/operators/partial.jl @@ -57,7 +57,6 @@ function ∂(basis::AbstractBasis, order::T, dim::T) where {T<:Int} elseif order == 2 return ∂²(basis, dim) else - return _higher_order_partial(basis, order, dim) throw( ArgumentError( "Only first and second order derivatives are supported right now. You may use the custom operator.", @@ -66,12 +65,5 @@ function ∂(basis::AbstractBasis, order::T, dim::T) where {T<:Int} end end -function _higher_order_partial(basis::MonomialBasis, order::T, dim::T) where {T<:Int} - return _∂(basis, order, Val(dim)) -end -function _higher_order_partial(_, _, _) - throw(ArgumentError("Higher order partials are not supported for RBFs yet.")) -end - # pretty printing print_op(op::Partial) = "∂ⁿf/∂xᵢ (n = $(op.order), i = $(op.dim))" diff --git a/test/operators/operators.jl b/test/operators/operators.jl new file mode 100644 index 00000000..124e4819 --- /dev/null +++ b/test/operators/operators.jl @@ -0,0 +1,64 @@ +using RadialBasisFunctions +import RadialBasisFunctions as RBF +using StaticArraysCore +using LinearAlgebra +using Statistics +using HaltonSequences + +N = 100 +x = SVector{2}.(HaltonPoint(2)[1:N]) + +@testset "Base Methods" begin + ∂ = partial(x, 1, 1) + @test is_cache_valid(∂) + RBF.invalidate_cache!(∂) + @test !is_cache_valid(∂) +end + +@testset "Operator Evaluation" begin + ∂ = partial(x, 1, 1) + y = rand(N) + z = rand(N) + ∂(y, z) + @test y ≈ ∂.weights * z + + ∇ = gradient(x, PHS(3; poly_deg=2)) + y = (rand(N), rand(N)) + ∇(y, z) + @test y[1] ≈ ∇.weights[1] * z + @test y[2] ≈ ∇.weights[2] * z + + @test ∇ ⋅ z ≈ (∇.weights[1] * z) .+ (∇.weights[2] * z) +end + +@testset "Operator Update" begin + ∂ = partial(x, 1, 1) + correct_weights = copy(∂.weights) + ∂.weights .= rand(size(∂.weights)) + update_weights!(∂) + @test ∂.weights ≈ correct_weights + @test is_cache_valid(∂) + + ∇ = gradient(x, PHS(3; poly_deg=2)) + correct_weights = copy.(∇.weights) + ∇.weights[1] .= rand(size(∇.weights[1])) + ∇.weights[2] .= rand(size(∇.weights[2])) + update_weights!(∇) + @test ∇.weights[1] ≈ correct_weights[1] + @test ∇.weights[2] ≈ correct_weights[2] + @test is_cache_valid(∇) +end + +@testset "Printing" begin + ∂ = partial(x, 1, 1) + @test repr(∂) == """ + RadialBasisOperator + ├─Operator: ∂ⁿf/∂xᵢ (n = 1, i = 1) + ├─Data type: StaticArraysCore.SVector{2, Float64} + ├─Number of points: 100 + ├─Stencil size: 12 + └─Basis: Polyharmonic spline (r³) with degree 2 polynomial augmentation + """ + + @test RBF.print_op(∂.ℒ) == "∂ⁿf/∂xᵢ (n = 1, i = 1)" +end diff --git a/test/operators/partial.jl b/test/operators/partial.jl index 545c1ca7..ad3ae6b4 100644 --- a/test/operators/partial.jl +++ b/test/operators/partial.jl @@ -70,6 +70,10 @@ end @test mean_percent_error(∂y(y), df_dy.(x2)) < 10 end +@testset "Higher Order Derivatives" begin + @test_throws ArgumentError partial(x, 3, 1) +end + @testset "Printing" begin ∂ = Partial(identity, 1, 2) @test RadialBasisFunctions.print_op(∂) == "∂ⁿf/∂xᵢ (n = 1, i = 2)" diff --git a/test/runtests.jl b/test/runtests.jl index 95e446d3..6018a719 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using SafeTestsets -@safetestset "Basis - General Utils" begin +@safetestset "Basis - General" begin include("basis/basis.jl") end @@ -20,6 +20,10 @@ end include("basis/monomial.jl") end +@safetestset "Operators" begin + include("operators/operators.jl") +end + @safetestset "Partial Derivatives" begin include("operators/partial.jl") end