diff --git a/README.md b/README.md index cd320a1..1bcf32a 100644 --- a/README.md +++ b/README.md @@ -284,8 +284,9 @@ julia> xrange_long,yrange_long,zrange_long = 1:3000,1:3000,1:3000 julia> params_long = (xrange_long,yrange_long,zrange_long); -julia> ps_long = ProductSplit(params_long, 10, 4) -ProductSplit{Tuple{Int64,Int64,Int64},3,UnitRange{Int64}}((1:3000, 1:3000, 1:3000), (0, 3000, 9000000), 10, 4, 8100000001, 10800000000) +julia> ps = ProductSplit(params_long, 10, 3) +2700000000-element ProductSplit((1:3000, 1:3000, 1:3000), 10, 3) +[(1, 1, 601), ... , (3000, 3000, 900)] # Evaluate length using random ranges to avoid compiler optimizations julia> @btime length(p) setup = (n = rand(3000:4000); p = ProductSplit((1:n,1:n,1:n), 200, 2)); @@ -333,25 +334,24 @@ Another useful function is `whichproc` that returns the rank of the processor a julia> whichproc(params_long, val, 10) 4 -julia> @btime whichproc($params_long, $val, 10) - 1.264 μs (14 allocations: 448 bytes) -4 +julia> @btime whichproc($params_long, $val, 10); + 353.706 ns (0 allocations: 0 bytes) ``` ### Extrema -We can compute the ranges of each variable on any processor in `O(1)` time. +We may compute the ranges of each variable on any processor in `O(1)` time. ```julia -julia> extrema(ps, dim=2) # extrema of the second parameter on this processor +julia> extrema(ps, dim = 2) # extrema of the second parameter on this processor (3, 4) -julia> Tuple(extrema(ps, dim=i) for i in 1:3) +julia> Tuple(extrema(ps, dim = i) for i in 1:3) ((1, 3), (3, 4), (4, 4)) # Minimum and maximum work similarly -julia> (minimum(ps, dim=2), maximum(ps, dim=2)) +julia> (minimum(ps, dim = 2), maximum(ps, dim = 2)) (3, 4) julia> @btime extrema($ps_long, dim=2) diff --git a/src/errors.jl b/src/errors.jl index 986cac6..c1de4f5 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -1,17 +1,3 @@ -struct ProcessorNumberError <: Exception - p :: Int - np :: Int -end -function Base.showerror(io::IO,err::ProcessorNumberError) - print(io,"processor id $(err.p) does not lie in the range $(1:err.np)") -end - -struct DecreasingIteratorError <: Exception -end -function Base.showerror(io::IO,err::DecreasingIteratorError) - print(io,"all the iterators need to be strictly increasing") -end - struct TaskNotPresentError{T,U} <: Exception t :: T task :: U diff --git a/src/productsplit.jl b/src/productsplit.jl index d013050..a59abc8 100644 --- a/src/productsplit.jl +++ b/src/productsplit.jl @@ -7,42 +7,6 @@ abstract type AbstractConstrainedProduct{T,N} end Base.eltype(::AbstractConstrainedProduct{T}) where {T} = T Base.ndims(::AbstractConstrainedProduct{<:Any,N}) where {N} = N -""" - ProductSplit{T,N,Q} - -Iterator that loops over the outer product of ranges in -reverse-lexicographic order. The ranges need to be strictly -increasing. Given `N` ranges, -each element returned by the iterator will be -a tuple of length `N` with one element from each range. - -See also: [`ProductSection`](@ref) -""" -struct ProductSplit{T,N,Q} <: AbstractConstrainedProduct{T,N} - iterators :: Q - togglelevels :: NTuple{N,Int} - np :: Int - p :: Int - firstind :: Int - lastind :: Int - - function ProductSplit(iterators::Tuple{Vararg{AbstractRange,N}},togglelevels::NTuple{N,Int}, - np::Int,p::Int,firstind::Int,lastind::Int) where {N} - - 1 <= p <= np || throw(ProcessorNumberError(p,np)) - T = Tuple{map(eltype,iterators)...} - Q = typeof(iterators) - - # Ensure that all the iterators are strictly increasing - all(x->step(x)>0,iterators) || - throw(ArgumentError("all the iterators need to be strictly increasing")) - - new{T,N,Q}(iterators,togglelevels,np,p,firstind,lastind) - end -end - -workerrank(ps::ProductSplit) = ps.p - """ ProductSection{T,N,Q} @@ -61,28 +25,17 @@ struct ProductSection{T,N,Q} <: AbstractConstrainedProduct{T,N} firstind :: Int lastind :: Int - function ProductSection(iterators::Tuple{Vararg{AbstractRange,N}},togglelevels::NTuple{N,Int}, - firstind::Int,lastind::Int) where {N} - - T = Tuple{eltype.(iterators)...} - Q = typeof(iterators) + function ProductSection(iterators::Tuple{Vararg{AbstractRange,N}}, togglelevels::NTuple{N,Int}, + firstind::Int, lastind::Int) where {N} # Ensure that all the iterators are strictly increasing - all(x->step(x)>0,iterators) || + all(x->step(x)>0, iterators) || throw(ArgumentError("all the iterators need to be strictly increasing")) - new{T,N,Q}(iterators,togglelevels,firstind,lastind) - end -end + T = Tuple{eltype.(iterators)...} -function mwerepr(ps::ProductSplit) - "ProductSplit("*repr(ps.iterators)*", "*repr(ps.np)*", "*repr(ps.p)*")" -end -function Base.summary(io::IO, ps::ProductSplit) - print(io, length(ps),"-element ", mwerepr(ps)) -end -function Base.show(io::IO, ps::ProductSplit) - print(io, mwerepr(ps)) + new{T,N,typeof(iterators)}(iterators, togglelevels, firstind, lastind) + end end function _cumprod(len::Tuple) @@ -94,16 +47,65 @@ function _cumprod(n::Integer, tl::Tuple) (n,_cumprod(n*first(tl),Base.tail(tl))...) end -@deprecate ntasks(x::Tuple) prod(length, x) +""" + ProductSection(iterators::Tuple{Vararg{AbstractRange}}, inds::AbstractUnitRange) + +Construct a `ProductSection` iterator that represents a 1D view of the outer product +of the ranges provided in `iterators`, with the range of indices in the view being +specified by `inds`. +# Examples +```jldoctest +julia> p = ParallelUtilities.ProductSection((1:3,4:6), 5:8); + +julia> collect(p) +4-element Array{Tuple{Int64,Int64},1}: + (2, 5) + (3, 5) + (1, 6) + (2, 6) + +julia> collect(p) == collect(Iterators.product(1:3, 4:6))[5:8] +true +``` """ - ntasks(iterators::Tuple) +function ProductSection(iterators::Tuple{AbstractRange,Vararg{AbstractRange}}, + inds::AbstractUnitRange) + + firstind, lastind = first(inds), last(inds) + + len = length.(iterators) + Nel = prod(len) + 1 <= firstind || throw( + ArgumentError("the range of indices must start from a number ≥ 1")) + lastind <= Nel || throw( + ArgumentError("the maximum index must be less than or equal to the total number of elements = $Nel")) + togglelevels = _cumprod(len) + ProductSection(iterators, togglelevels, firstind, lastind) +end +ProductSection(::Tuple{}, ::AbstractUnitRange) = throw(ArgumentError("Need at least one iterator")) -The total number of elements in the outer product of the ranges contained in -`iterators`, equal to `prod(length, iterators)` """ -ntasks -ntasks(ps::AbstractConstrainedProduct) = ntasks(ps.iterators) + ProductSplit{T,N,Q} + +Iterator that loops over the outer product of ranges in +reverse-lexicographic order. The ranges need to be strictly +increasing. Given `N` ranges, +each element returned by the iterator will be +a tuple of length `N` with one element from each range. + +See also: [`ProductSection`](@ref) +""" +struct ProductSplit{T,N,Q <: ProductSection{T,N}} <: AbstractConstrainedProduct{T,N} + ps :: Q + np :: Int + p :: Int + + function ProductSplit(ps::ProductSection, np::Integer, p::Integer) + 1 <= p <= np || throw(ArgumentError("processor rank out of range")) + new{eltype(ps),ndims(ps),typeof(ps)}(ps, np, p) + end +end """ ProductSplit(iterators::Tuple{Vararg{AbstractRange}}, np::Integer, p::Integer) @@ -125,62 +127,53 @@ julia> ProductSplit((1:2,4:5), 2, 2) |> collect (2, 5) ``` """ -function ProductSplit(iterators::Tuple{Vararg{AbstractRange}}, np::Integer, p::Integer) - len = size.(iterators,1) - Nel = prod(len) - togglelevels = _cumprod(len) - d,r = divrem(Nel,np) +function ProductSplit(iterators::Tuple{AbstractRange,Vararg{AbstractRange}}, np::Integer, p::Integer) + d,r = divrem(prod(length, iterators), np) firstind = d*(p-1) + min(r,p-1) + 1 lastind = d*p + min(r,p) - ProductSplit(iterators,togglelevels,np,p,firstind,lastind) + ProductSplit(ProductSection(iterators, firstind:lastind), np, p) end -ProductSplit(::Tuple{},::Integer,::Integer) = throw(ArgumentError("Need at least one iterator")) +ProductSplit(::Tuple{}, ::Integer, ::Integer) = throw(ArgumentError("Need at least one iterator")) -""" - ProductSection(iterators::Tuple{Vararg{AbstractRange}}, inds::AbstractUnitRange) +workerrank(ps::ProductSplit) = ps.p -Construct a `ProductSection` iterator that represents a 1D view of the outer product -of the ranges provided in `iterators`, with the range of indices in the view being -specified by `inds`. +ProductSection(ps::ProductSection) = ps +ProductSection(ps::ProductSplit) = ps.ps -# Examples -```jldoctest -julia> p = ParallelUtilities.ProductSection((1:3,4:6), 5:8); +getiterators(ps::AbstractConstrainedProduct) = ProductSection(ps).iterators +togglelevels(ps::AbstractConstrainedProduct) = ProductSection(ps).togglelevels -julia> collect(p) -4-element Array{Tuple{Int64,Int64},1}: - (2, 5) - (3, 5) - (1, 6) - (2, 6) +function mwerepr(ps::ProductSplit) + "ProductSplit(" * repr(getiterators(ps)) * ", " * repr(ps.np) * ", " * repr(ps.p) * ")" +end +function mwerepr(ps::ProductSection) + "ProductSection(" * repr(getiterators(ps)) * ", " * repr(firstindexglobal(ps):lastindexglobal(ps)) * ")" +end +function Base.summary(io::IO, ps::AbstractConstrainedProduct) + print(io, length(ps),"-element ", mwerepr(ps)) + if !isempty(ps) + print(io, "\n[", repr(first(ps)) * ", ... , " * repr(last(ps)), "]") + end +end +function Base.show(io::IO, ps::AbstractConstrainedProduct) + print(io, summary(ps)) +end -julia> collect(p) == collect(Iterators.product(1:3, 4:6))[5:8] -true -``` -""" -function ProductSection(iterators::Tuple{Vararg{AbstractRange}}, - inds::AbstractUnitRange) +@deprecate ntasks(x::Tuple) prod(length, x) - isempty(inds) && throw(ArgumentError("range of indices must not be empty")) - firstind,lastind = extrema(inds) +""" + ntasks(iterators::Tuple) - len = size.(iterators,1) - Nel = prod(len) - 1 <= firstind || throw( - ArgumentError("the range of indices must start from a number ≥ 1")) - lastind <= Nel || throw( - ArgumentError("the maximum index must be less than or equal to the total number of elements = $Nel")) - togglelevels = _cumprod(len) - ProductSection(iterators,togglelevels,firstind,lastind) -end -function ProductSection(::Tuple{},::AbstractUnitRange) - throw(ArgumentError("Need at least one iterator")) -end +The total number of elements in the outer product of the ranges contained in +`iterators`, equal to `prod(length, iterators)` +""" +ntasks +ntasks(ps::AbstractConstrainedProduct) = ntasks(getiterators(ps)) -Base.isempty(ps::AbstractConstrainedProduct) = (ps.firstind > ps.lastind) +Base.isempty(ps::AbstractConstrainedProduct) = (firstindexglobal(ps) > lastindexglobal(ps)) function Base.first(ps::AbstractConstrainedProduct) - isempty(ps) ? nothing : @inbounds _first(ps.iterators, childindex(ps, ps.firstind)...) + isempty(ps) ? nothing : @inbounds _first(getiterators(ps), childindex(ps, firstindexglobal(ps))...) end Base.@propagate_inbounds function _first(t::Tuple, ind::Integer, rest::Integer...) @@ -190,7 +183,7 @@ end _first(::Tuple{}) = () function Base.last(ps::AbstractConstrainedProduct) - isempty(ps) ? nothing : @inbounds _last(ps.iterators, childindex(ps, ps.lastind)...) + isempty(ps) ? nothing : @inbounds _last(getiterators(ps), childindex(ps, lastindexglobal(ps))...) end Base.@propagate_inbounds function _last(t::Tuple, ind::Integer, rest::Integer...) @@ -199,42 +192,52 @@ Base.@propagate_inbounds function _last(t::Tuple, ind::Integer, rest::Integer... end _last(::Tuple{}) = () -Base.length(ps::AbstractConstrainedProduct) = ps.lastind - ps.firstind + 1 +Base.length(ps::AbstractConstrainedProduct) = lastindex(ps) Base.firstindex(ps::AbstractConstrainedProduct) = 1 -Base.lastindex(ps::AbstractConstrainedProduct) = ps.lastind - ps.firstind + 1 +Base.lastindex(ps::AbstractConstrainedProduct) = lastindexglobal(ps) - firstindexglobal(ps) + 1 + +firstindexglobal(ps::AbstractConstrainedProduct) = ProductSection(ps).firstind +lastindexglobal(ps::AbstractConstrainedProduct) = ProductSection(ps).lastind """ childindex(ps::AbstractConstrainedProduct, ind) -Return a tuple containing the indices of the individual iterators +Return a tuple containing the indices of the individual `AbstractRange`s corresponding to the element that is present at index `ind` in the -outer product of the iterators. +outer product of the ranges. + +!!! note + The index `ind` corresponds to the outer product of the ranges, and not to `ps`. # Examples ```jldoctest -julia> ps = ProductSplit((1:5, 2:4, 1:3), 7, 1); +julia> iters = (1:5, 2:4, 1:3); -julia> ParallelUtilities.childindex(ps, 6) +julia> ps = ProductSplit(iters, 7, 1); + +julia> ind = 6; + +julia> cinds = ParallelUtilities.childindex(ps, ind) (1, 2, 1) -julia> v = collect(Iterators.product(1:5, 2:4, 1:3)); +julia> v = collect(Iterators.product(iters...)); -julia> getindex.(ps.iterators, ParallelUtilities.childindex(ps,6)) == v[6] +julia> getindex.(iters, cinds) == v[ind] true ``` See also: [`childindexshifted`](@ref) """ function childindex(ps::AbstractConstrainedProduct, ind) - tl = reverse(Base.tail(ps.togglelevels)) + tl = reverse(Base.tail(togglelevels(ps))) reverse(childindex(tl,ind)) end function childindex(tl::Tuple, ind) t = first(tl) k = div(ind - 1, t) - (k+1,childindex(Base.tail(tl), ind - k*t)...) + (k+1, childindex(Base.tail(tl), ind - k*t)...) end # First iterator gets the final remainder @@ -246,21 +249,28 @@ childindex(::Tuple{}, ind) = (ind,) Return a tuple containing the indices in the individual iterators given an index of a `AbstractConstrainedProduct`. +If the ranges `(r1, r2, ...)` are used to generate +`ps`, then return `(i1, i2, ...)` such that `ps[ind] == (r1[i1], r2[i2], ...)`. + # Examples ```jldoctest -julia> ps = ProductSplit((1:5, 2:4, 1:3), 7, 3); +julia> iters = (1:5, 2:4, 1:3); -julia> cinds = ParallelUtilities.childindexshifted(ps, 3) -(2, 1, 2) +julia> ps = ProductSplit(iters, 7, 3); -julia> getindex.(ps.iterators, cinds) == ps[3] +julia> psind = 4; + +julia> cinds = ParallelUtilities.childindexshifted(ps, psind) +(3, 1, 2) + +julia> getindex.(iters, cinds) == ps[psind] true ``` See also: [`childindex`](@ref) """ function childindexshifted(ps::AbstractConstrainedProduct, ind) - childindex(ps, (ind - 1) + ps.firstind) + childindex(ps, (ind - 1) + firstindexglobal(ps)) end Base.@propagate_inbounds function Base.getindex(ps::AbstractConstrainedProduct, ind) @@ -274,12 +284,12 @@ end Base.@propagate_inbounds function _getindex(ps::AbstractConstrainedProduct{<:Any,N}, inds::Vararg{Integer,N}) where {N} - _getindex(ps.iterators,inds...) + _getindex(getiterators(ps), inds...) end -Base.@propagate_inbounds function _getindex(t::Tuple,ind::Integer,rest::Integer...) +Base.@propagate_inbounds function _getindex(t::Tuple, ind::Integer, rest::Integer...) @boundscheck (1 <= ind <= length(first(t))) || throw(BoundsError(first(t),ind)) - (@inbounds first(t)[ind], _getindex(Base.tail(t),rest...)...) + (@inbounds first(t)[ind], _getindex(Base.tail(t), rest...)...) end _getindex(::Tuple{}, ::Integer...) = () @@ -299,10 +309,10 @@ function Base.iterate(ps::AbstractConstrainedProduct{T}, state=(first(ps), 1)) w end function _firstlastalongdim(ps::AbstractConstrainedProduct, dim, - firstindchild::Tuple = childindex(ps, ps.firstind), - lastindchild::Tuple = childindex(ps, ps.lastind)) + firstindchild::Tuple = childindex(ps, firstindexglobal(ps)), + lastindchild::Tuple = childindex(ps, lastindexglobal(ps))) - iter = ps.iterators[dim] + iter = getiterators(ps)[dim] fic = firstindchild[dim] lic = lastindchild[dim] @@ -314,10 +324,10 @@ function _firstlastalongdim(ps::AbstractConstrainedProduct, dim, end function _checkrollover(ps::AbstractConstrainedProduct, dim, - firstindchild::Tuple=childindex(ps,ps.firstind), - lastindchild::Tuple=childindex(ps,ps.lastind)) + firstindchild::Tuple = childindex(ps, firstindexglobal(ps)), + lastindchild::Tuple = childindex(ps, lastindexglobal(ps))) - _checkrollover(ps.iterators,dim,firstindchild,lastindchild) + _checkrollover(getiterators(ps), dim, firstindchild, lastindchild) end function _checkrollover(t::Tuple, dim, firstindchild::Tuple, lastindchild::Tuple) @@ -341,7 +351,7 @@ _checknorollover(::Tuple{}, ::Tuple{}, ::Tuple{}) = true function _nrollovers(ps::AbstractConstrainedProduct, dim::Integer) dim == ndims(ps) && return 0 - nelements(ps, dim + 1) - 1 + nelements(ps; dim = dim + 1) - 1 end """ @@ -380,7 +390,7 @@ end function nelements(ps::AbstractConstrainedProduct; dim::Integer) 1 <= dim <= ndims(ps) || throw(ArgumentError("1 ⩽ dim ⩽ N=$(ndims(ps)) not satisfied for dim=$dim")) - iter = ps.iterators[dim] + iter = getiterators(ps)[dim] if _nrollovers(ps,dim) == 0 st = first(ps)[dim] @@ -431,8 +441,8 @@ function Base.maximum(ps::AbstractConstrainedProduct; dim::Integer) isempty(ps) && return nothing - firstindchild = childindex(ps, ps.firstind) - lastindchild = childindex(ps, ps.lastind) + firstindchild = childindex(ps, firstindexglobal(ps)) + lastindchild = childindex(ps, lastindexglobal(ps)) first_iter,last_iter = _firstlastalongdim(ps, dim, firstindchild, lastindchild) @@ -444,7 +454,7 @@ function Base.maximum(ps::AbstractConstrainedProduct; dim::Integer) end if _checkrollover(ps, dim, firstindchild, lastindchild) - iter = ps.iterators[dim] + iter = getiterators(ps)[dim] v = maximum(iter) end @@ -458,9 +468,9 @@ end function Base.maximum(ps::AbstractConstrainedProduct{<:Any,1}) isempty(ps) && return nothing - lastindchild = childindex(ps, ps.lastind) + lastindchild = childindex(ps, lastindexglobal(ps)) lic_dim = lastindchild[1] - iter = ps.iterators[1] + iter = getiterators(ps)[1] iter[lic_dim] end @@ -489,8 +499,8 @@ function Base.minimum(ps::AbstractConstrainedProduct; dim::Integer) isempty(ps) && return nothing - firstindchild = childindex(ps, ps.firstind) - lastindchild = childindex(ps, ps.lastind) + firstindchild = childindex(ps, firstindexglobal(ps)) + lastindchild = childindex(ps, lastindexglobal(ps)) first_iter,last_iter = _firstlastalongdim(ps, dim, firstindchild, lastindchild) @@ -502,7 +512,7 @@ function Base.minimum(ps::AbstractConstrainedProduct; dim::Integer) end if _checkrollover(ps, dim, firstindchild, lastindchild) - iter = ps.iterators[dim] + iter = getiterators(ps)[dim] v = minimum(iter) end @@ -516,9 +526,9 @@ end function Base.minimum(ps::AbstractConstrainedProduct{<:Any,1}) isempty(ps) && return nothing - firstindchild = childindex(ps,ps.firstind) + firstindchild = childindex(ps, firstindexglobal(ps)) fic_dim = firstindchild[1] - iter = ps.iterators[1] + iter = getiterators(ps)[1] iter[fic_dim] end @@ -547,8 +557,8 @@ function Base.extrema(ps::AbstractConstrainedProduct; dim::Integer) isempty(ps) && return nothing - firstindchild = childindex(ps, ps.firstind) - lastindchild = childindex(ps, ps.lastind) + firstindchild = childindex(ps, firstindexglobal(ps)) + lastindchild = childindex(ps, lastindexglobal(ps)) first_iter,last_iter = _firstlastalongdim(ps, dim, firstindchild, lastindchild) @@ -559,7 +569,7 @@ function Base.extrema(ps::AbstractConstrainedProduct; dim::Integer) end if _checkrollover(ps, dim, firstindchild, lastindchild) - iter = ps.iterators[dim] + iter = getiterators(ps)[dim] v = extrema(iter) end @@ -568,11 +578,11 @@ end function Base.extrema(ps::AbstractConstrainedProduct{<:Any,1}) isempty(ps) && return nothing - firstindchild = childindex(ps, ps.firstind) - lastindchild = childindex(ps, ps.lastind) + firstindchild = childindex(ps, firstindexglobal(ps)) + lastindchild = childindex(ps, lastindexglobal(ps)) fic_dim = firstindchild[1] lic_dim = lastindchild[1] - iter = ps.iterators[1] + iter = getiterators(ps)[1] (iter[fic_dim], iter[lic_dim]) end @@ -594,7 +604,7 @@ map(i -> extrema(ps, dim = i), 1:ndims(ps)) but it is implemented more efficiently. -Returns a `Tuple` containing the `(min, max)` pairs along each +Returns a `Tuple` containing the `(min,max)` pairs along each dimension, such that the `i`-th index of the result contains the `extrema` along the section of the `i`-th range contained locally. @@ -613,7 +623,7 @@ julia> extremadims(ps) """ function extremadims(ps::AbstractConstrainedProduct) Base.depwarn("extremadims will not be exported in a future release, please call it as ParallelUtilities.extremadims instead", :extremadims) - _extremadims(ps, 1, ps.iterators) + _extremadims(ps, 1, getiterators(ps)) end function _extremadims(ps::AbstractConstrainedProduct, dim::Integer, iterators::Tuple) @@ -682,7 +692,7 @@ function extrema_commonlastdim(ps::AbstractConstrainedProduct{<:Any,N}) where {N [(m,lastvar_min) for m in min_vals],[(m,lastvar_max) for m in max_vals] end -_infullrange(val::T, ps::AbstractConstrainedProduct{T}) where {T} = _infullrange(val,ps.iterators) +_infullrange(val::T, ps::AbstractConstrainedProduct{T}) where {T} = _infullrange(val,getiterators(ps)) function _infullrange(val, t::Tuple) first(val) in first(t) && _infullrange(Base.tail(val),Base.tail(t)) @@ -736,17 +746,17 @@ indexinproduct(::Tuple{}, ::Tuple) = throw(ArgumentError("need at least one iter function Base.in(val::T, ps::AbstractConstrainedProduct{T}) where {T} _infullrange(val,ps) || return false - ind = indexinproduct(ps.iterators, val) - ps.firstind <= ind <= ps.lastind + ind = indexinproduct(getiterators(ps), val) + firstindexglobal(ps) <= ind <= lastindexglobal(ps) end # This struct is just a wrapper to flip the tuples before comparing -struct ReverseLexicographicTuple{T} +struct ReverseLexicographicTuple{T<:Tuple} t :: T end Base.isless(a::ReverseLexicographicTuple{T}, b::ReverseLexicographicTuple{T}) where {T} = reverse(a.t) < reverse(b.t) -Base.isequal(a::ReverseLexicographicTuple{T}, b::ReverseLexicographicTuple{T}) where {T} = a.t == b.t +Base.isequal(a::ReverseLexicographicTuple, b::ReverseLexicographicTuple) = a.t == b.t """ whichproc(iterators::Tuple, val::Tuple, np::Integer) @@ -775,13 +785,15 @@ julia> whichproc(iters, (2,3), np) function whichproc(iterators, val, np::Integer) _infullrange(val,iterators) || return nothing + np >= 1 || throw(ArgumentError("np must be >= 1")) + np == 1 && return 1 # We may carry out a binary search as the iterators are sorted left,right = 1,np val_t = ReverseLexicographicTuple(val) - while left <= right + while left < right mid = div(left+right, 2) ps = ProductSplit(iterators, np, mid) @@ -798,14 +810,18 @@ function whichproc(iterators, val, np::Integer) return mid end end + + return left end whichproc(iterators, ::Nothing, np::Integer) = nothing +whichproc(ps::ProductSplit, val) = whichproc(getiterators(ps), val, ps.np) + # This function tells us the range of processors that would be involved # if we are to compute the tasks contained in the list ps on np_new processors. # The total list of tasks is contained in iterators, and might differ from -# ps.iterators (eg if ps contains a subsection of the parameter set) +# getiterators(ps) (eg if ps contains a subsection of the parameter set) """ procrange_recast(iterators::Tuple, ps::ProductSplit, np_new::Integer) @@ -867,7 +883,7 @@ julia> procrange_recast(ps, 10) # If `iters` were spread across 10 processes ``` """ function procrange_recast(ps::AbstractConstrainedProduct, np_new::Integer) - procrange_recast(ps.iterators, ps, np_new) + procrange_recast(getiterators(ps), ps, np_new) end """ @@ -895,8 +911,8 @@ function localindex(ps::AbstractConstrainedProduct{T}, val::T) where {T} (isempty(ps) || val ∉ ps) && return nothing - indflat = indexinproduct(ps.iterators, val) - indflat - ps.firstind + 1 + indflat = indexinproduct(getiterators(ps), val) + indflat - firstindexglobal(ps) + 1 end # this is only needed because first and last return nothing if the ProductSplit is empty @@ -969,7 +985,7 @@ julia> ParallelUtilities.dropleading(ps) |> collect """ function dropleading(ps::AbstractConstrainedProduct) isempty(ps) && throw(ArgumentError("need at least one iterator")) - iterators = Base.tail(ps.iterators) + iterators = Base.tail(getiterators(ps)) first_element = Base.tail(first(ps)) last_element = Base.tail(last(ps)) firstind = indexinproduct(iterators, first_element) diff --git a/src/utils.jl b/src/utils.jl index 9b8fbf7..5dc886c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,7 +6,7 @@ Number of workers required to contain the outer product of the iterators. function nworkersactive(iterators::Tuple) min(nworkers(), prod(length, iterators)) end -nworkersactive(ps::ProductSplit) = nworkersactive(ps.iterators) +nworkersactive(ps::AbstractConstrainedProduct) = nworkersactive(getiterators(ps)) nworkersactive(args::AbstractRange...) = nworkersactive(args) """ @@ -17,5 +17,5 @@ If `prod(length, iterators) < nworkers()` then the first `prod(length, iterators workers are chosen. """ workersactive(iterators::Tuple) = workers()[1:nworkersactive(iterators)] -workersactive(ps::ProductSplit) = workersactive(ps.iterators) +workersactive(ps::AbstractConstrainedProduct) = workersactive(getiterators(ps)) workersactive(args::AbstractRange...) = workersactive(args) \ No newline at end of file diff --git a/test/tests.jl b/test/tests.jl index 2d06feb..b39f63d 100644 --- a/test/tests.jl +++ b/test/tests.jl @@ -10,7 +10,7 @@ maybepvalput!, createbranchchannels, nworkersactive, workersactive, procs_node, leafrankfoldedtree, TopTreeNode, SubTreeNode, ProductSection, indexinproduct, dropleading, - nelements + nelements, getiterators, firstindexglobal, lastindexglobal end const future_release_warn = r"will not be exported in a future release"i @@ -66,23 +66,23 @@ end @test ndims(ps) == length(iters) @test collect(ps) == collect(split_product_across_processors_iterators(iters,np,p)) @test (@test_deprecated ntasks(ps)) == ntasks_total - @test prod(length, ps.iterators) == ntasks_total + @test prod(length, getiterators(ps)) == ntasks_total @test ParallelUtilities.workerrank(ps) == p end - @test_throws ParallelUtilities.ProcessorNumberError ProductSplit(iters,npmax,npmax+1) + @test_throws ArgumentError ProductSplit(iters, npmax, npmax+1) end @testset "0D" begin @test_throws ArgumentError ProductSplit((),2,1) end - @testset "cumprod" begin - @test ParallelUtilities._cumprod(1,()) == () - @test ParallelUtilities._cumprod(1,(2,)) == (1,) - @test ParallelUtilities._cumprod(1,(2,3)) == (1,2) - @test ParallelUtilities._cumprod(1,(2,3,4)) == (1,2,6) - end + @testset "cumprod" begin + @test ParallelUtilities._cumprod(1,()) == () + @test ParallelUtilities._cumprod(1,(2,)) == (1,) + @test ParallelUtilities._cumprod(1,(2,3)) == (1,2) + @test ParallelUtilities._cumprod(1,(2,3,4)) == (1,2,6) + end @testset "1D" begin iters = (1:10,) @@ -119,22 +119,22 @@ end for iters in [(1:10,),(1:2,Base.OneTo(4),1:3:10)] ps = ProductSplit(iters,2,1) @test firstindex(ps) == 1 - @test ps.firstind == 1 - @test ps.lastind == div(prod(length, iters),2) + @test firstindexglobal(ps) == 1 + @test lastindexglobal(ps) == div(prod(length, iters),2) @test lastindex(ps) == div(prod(length, iters),2) @test lastindex(ps) == length(ps) ps = ProductSplit(iters,2,2) - @test ps.firstind == div(prod(length, iters),2) + 1 + @test firstindexglobal(ps) == div(prod(length, iters),2) + 1 @test firstindex(ps) == 1 - @test ps.lastind == prod(length, iters) + @test lastindexglobal(ps) == prod(length, iters) @test lastindex(ps) == length(ps) for np in prod(length, iters)+1:prod(length, iters)+10, p in prod(length, iters)+1:np ps = ProductSplit(iters,np,p) - @test ps.firstind == prod(length, iters) + 1 - @test ps.lastind == prod(length, iters) + @test firstindexglobal(ps) == prod(length, iters) + 1 + @test lastindexglobal(ps) == prod(length, iters) end end end @@ -146,14 +146,18 @@ end @test ParallelUtilities.mwerepr(ps) == reprstr summarystr = "$(length(ps))-element "*reprstr - @test ParallelUtilities.summary(ps) == summarystr + @test occursin(summarystr, ParallelUtilities.summary(ps)) io = IOBuffer() summary(io,ps) - @test String(take!(io)) == summarystr + @test occursin(summarystr, String(take!(io))) show(io, ps) - @test String(take!(io)) == reprstr + @test occursin(summarystr, String(take!(io))) + + ps = ParallelUtilities.ProductSection(iters,4:5) + reprstr = "ProductSection("*repr(iters)*", " * repr(4:5) * ")" + @test ParallelUtilities.mwerepr(ps) == reprstr end end @@ -280,6 +284,8 @@ end iters = (1:10,4:6,1:4) ps = ProductSplit(iters,np,proc_id) @test whichproc(iters,first(ps),1) == 1 + @test whichproc(ps,first(ps)) == proc_id + @test whichproc(ps,last(ps)) == proc_id @test whichproc(iters,(100,100,100),1) === nothing @test (@test_deprecated future_release_warn procrange_recast(iters,ps,1)) == 1:1 @test (@test_deprecated future_release_warn procrange_recast(ps,1)) == 1:1 @@ -417,7 +423,6 @@ end end @test_throws ArgumentError ProductSection((),2:3) - @test_throws ArgumentError ProductSection((1:3,),1:0) end end @testset "dropleading" begin @@ -2626,14 +2631,6 @@ end; @testset "error" begin io = IOBuffer() - showerror(io,ParallelUtilities.ProcessorNumberError(5,2)) - strexp = "processor id 5 does not lie in the range 1:2" - @test String(take!(io)) == strexp - - showerror(io,ParallelUtilities.DecreasingIteratorError()) - strexp = "all the iterators need to be strictly increasing" - @test String(take!(io)) == strexp - showerror(io,ParallelUtilities.TaskNotPresentError((1:4,),(5,))) strexp = "could not find the task $((5,)) in the list $((1:4,))" @test String(take!(io)) == strexp