Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,14 @@ Base.@propagate_inbounds function resolve_shmem!(
===#

if isactive
args = Operators._get_node(space, ij, slabidx, sbc.args)
operator_fill_shmem!(
sbc.op,
sbc.work,
space,
ij,
slabidx,
Operators._get_node(space, ij, slabidx, sbc.args...)...,
args...,
)
end
CUDA.sync_threads()
Expand Down
2 changes: 2 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ module DataLayouts
import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms
import UnrolledUtilities: unrolled_map, unrolled_all
import MultiBroadcastFusion as MBF
import Adapt
import UnrolledUtilities: unrolled_foreach, unrolled_all, unrolled_findfirst

import ..Utilities: PlusHalf, unionall_type
import ..DebugOnly: call_post_op_callback, post_op_callback
Expand Down
6 changes: 1 addition & 5 deletions src/DataLayouts/fused_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@ Base.@propagate_inbounds function rcopyto_at_linear!(
return nothing
end
Base.@propagate_inbounds function rcopyto_at_linear!(pairs::Tuple, I)
rcopyto_at_linear!(first(pairs), I)
rcopyto_at_linear!(Base.tail(pairs), I)
unrolled_foreach(Base.Fix2(rcopyto_at_linear!, I), pairs)
end
Base.@propagate_inbounds rcopyto_at_linear!(pairs::Tuple{<:Any}, I) =
rcopyto_at_linear!(first(pairs), I)
@inline rcopyto_at_linear!(pairs::Tuple{}, I) = nothing

# Fused multi-broadcast entry point for DataLayouts
function Base.copyto!(
Expand Down
40 changes: 11 additions & 29 deletions src/DataLayouts/has_uniform_datalayouts.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,20 @@
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
x1 = first_datalayout_in_bc(args[1], rargs...)
x1 isa AbstractData && return x1
return first_datalayout_in_bc(Base.tail(args), rargs...)
idx = unrolled_findfirst(Base.Fix2(isa, AbstractData), args)
return isnothing(idx) ? nothing : args[idx]
end

@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
first_datalayout_in_bc(args[1], rargs...)
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_datalayout_in_bc(x) = nothing
@inline first_datalayout_in_bc(x::AbstractData) = x

@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
first_datalayout_in_bc(bc.args)

@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
truesofar &&
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)

@inline _has_uniform_datalayouts_args(
truesofar,
start,
args::Tuple{Any},
rargs...,
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
truesofar

@inline _has_uniform_datalayouts_args(start, args::Tuple, rargs...) =
unrolled_all(args) do arg
_has_uniform_datalayouts(start, arg, rargs...)
end
@inline function _has_uniform_datalayouts(
truesofar,
start,
bc::Base.Broadcast.Broadcasted,
)
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
return _has_uniform_datalayouts_args(start, bc.args)
end
for DL in (
:IJKFVH,
Expand All @@ -50,11 +32,11 @@ for DL in (
:VIHF,
)
@eval begin
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
@inline _has_uniform_datalayouts(::$(DL), ::$(DL)) = true
end
end
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar
@inline _has_uniform_datalayouts(_, x::AbstractData) = false
@inline _has_uniform_datalayouts(_, x) = true

"""
has_uniform_datalayouts
Expand All @@ -69,6 +51,6 @@ Note: a broadcasted object can have different _types_,
function has_uniform_datalayouts end

@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)
_has_uniform_datalayouts_args(first_datalayout_in_bc(bc), bc.args)

@inline has_uniform_datalayouts(bc::AbstractData) = true
24 changes: 12 additions & 12 deletions src/DataLayouts/non_extruded_broadcasted.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#! format: off
# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
import Base.Broadcast: BroadcastStyle
import UnrolledUtilities: unrolled_map

struct NonExtrudedBroadcasted{
Style <: Union{Nothing, BroadcastStyle},
Axes,
Expand Down Expand Up @@ -54,13 +56,11 @@ end
NonExtrudedBroadcasted(bc.style, bc.f, to_non_extruded_broadcasted_args(bc.args), bc.axes)
@inline to_non_extruded_broadcasted(x) = x

@inline to_non_extruded_broadcasted_args(args::Tuple) = (
to_non_extruded_broadcasted(args[1]),
to_non_extruded_broadcasted_args(Base.tail(args))...,
)
@inline to_non_extruded_broadcasted_args(args::Tuple{Any}) =
(to_non_extruded_broadcasted(args[1]),)
@inline to_non_extruded_broadcasted_args(args::Tuple{}) = ()
@inline function to_non_extruded_broadcasted_args(args::Tuple)
unrolled_map(args) do arg
to_non_extruded_broadcasted(arg)
end
end

# CartesianIndex{0} is used for DataF and empty data cases
# And sometimes axes(bc) returns a (e.g.,) CenterFiniteDifferenceSpace
Expand Down Expand Up @@ -140,11 +140,11 @@ Base.@propagate_inbounds function _broadcast_getindex(
end
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
f(args...) # not propagate_inbounds
Base.@propagate_inbounds _getindex(args::Tuple, I) =
(_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) =
(_broadcast_getindex(args[1], I),)
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
Base.@propagate_inbounds function _getindex(args::Tuple, I)
unrolled_map(args) do arg
_broadcast_getindex(arg, I)
end
end

@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes)
_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes
Expand Down
18 changes: 9 additions & 9 deletions src/DataLayouts/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ Determines whether an object of type `S` can be stored in an array with elements
of type `T` by recursively checking whether the non-empty fields of `S` can be
stored in such an array. If `S` is empty, this is always true.
"""
is_valid_basetype(::Type{T}, ::Type{S}) where {T, S} =
function is_valid_basetype(::Type{T}, ::Type{S}) where {T, S}
sizeof(S) == 0 ||
(fieldcount(S) > 0 && is_valid_basetype(T, fieldtypes(S)...))
fieldcount(S) > 0 &&
unrolled_all(s -> is_valid_basetype(T, s), fieldtypes(S))
end
is_valid_basetype(::Type{T}, ::Type{<:T}) where {T} = true
is_valid_basetype(::Type{T}, ::Type{S}, Ss...) where {T, S} =
is_valid_basetype(T, S) && is_valid_basetype(T, Ss...)

"""
check_basetype(::Type{T}, ::Type{S})
Expand Down Expand Up @@ -61,13 +61,13 @@ Changes the type parameters of `S` to produce a new type `S′` such that, if
"""
replace_basetype(::Type{T}, ::Type{T′}, ::Type{S}) where {T, T′, S} =
length(S.parameters) == 0 ? S :
S.name.wrapper{replace_basetypes(T, T′, S.parameters...)...}
S.name.wrapper{replace_basetypes(T, T′, Tuple(S.parameters))...}
replace_basetype(::Type{T}, ::Type{T′}, ::Type{<:T}) where {T, T′} = T′
replace_basetype(::Type{T}, ::Type{T′}, value) where {T, T′} = value
replace_basetypes(::Type{T}, ::Type{T′}, value) where {T, T′} =
(replace_basetype(T, T′, value),)
replace_basetypes(::Type{T}, ::Type{T′}, value, values...) where {T, T′} =
(replace_basetype(T, T′, value), replace_basetypes(T, T′, values...)...)
replace_basetypes(::Type{T}, ::Type{T′}, values) where {T, T′} =
unrolled_map(values) do value
replace_basetype(T, T′, value)
end
# TODO: This could potentially lead to some annoying bugs, since it replaces
# type parameters instead of field types. So, if `S` has `Float64` as a
# parameter, `replace_basetype(Float64, Float32, S)` will replace that parameter
Expand Down
2 changes: 1 addition & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import ..Utilities: PlusHalf, half
using ..RecursiveApply
using ClimaComms
import Adapt
import UnrolledUtilities: unrolled_map, unrolled_mapreduce, unrolled_findfirst
import UnrolledUtilities: unrolled_map, unrolled_mapreduce, unrolled_findfirst, unrolled_all

import StaticArrays, LinearAlgebra, Statistics, InteractiveUtils

Expand Down
9 changes: 6 additions & 3 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ..DebugOnly: allow_mismatched_spaces_unsafe
import UnrolledUtilities: unrolled_map

"""
AbstractFieldStyle
Expand Down Expand Up @@ -180,9 +181,11 @@ end

# Return underlying DataLayout object, DataStyle of broadcasted
# for `Base.similar` of a Field
_todata_args(args::Tuple) = (todata(args[1]), _todata_args(Base.tail(args))...)
_todata_args(args::Tuple{Any}) = (todata(args[1]),)
_todata_args(::Tuple{}) = ()
# _todata_args(args::Tuple) = (todata(args[1]), _todata_args(Base.tail(args))...)
_todata_args(args::Tuple) =
unrolled_map(args) do arg
todata(arg)
end

todata(obj) = obj
todata(field::Field) = Fields.field_values(field)
Expand Down
57 changes: 16 additions & 41 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,69 +231,47 @@ function Spaces.weighted_dss!(
end

@inline function first_fieldvector_in_bc(args::Tuple, rargs...)
x1 = first_fieldvector_in_bc(args[1], rargs...)
x1 isa FieldVector && return x1
return first_fieldvector_in_bc(Base.tail(args), rargs...)
idx = unrolled_findfirst(args) do arg
!isnothing(first_fieldvector_in_bc(arg))
end
return isnothing(idx) ? nothing : first_fieldvector_in_bc(args[idx])
end

@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) =
first_fieldvector_in_bc(args[1], rargs...)
@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_fieldvector_in_bc(x) = nothing
@inline first_fieldvector_in_bc(x::FieldVector) = x

@inline first_fieldvector_in_bc(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) = first_fieldvector_in_bc(bc.args)
@inline first_fieldvector_in_bc(fv::FieldVector) = fv
@inline first_fieldvector_in_bc(x) = nothing

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple,
rargs...,
) where {TStart} =
truesofar &&
_is_diagonal_bc(truesofar, TStart, args[1], rargs...) &&
_is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{Any},
rargs...,
) where {TStart} =
truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...)
@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{},
rargs...,
) where {TStart} = truesofar
unrolled_all(args) do arg
_is_diagonal_bc(TStart, arg)
end

@inline function _is_diagonal_bc(
truesofar,
::Type{TStart},
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) where {TStart}
return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args)
return _is_diagonal_bc_args(TStart, bc.args)
end

@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
::TStart,
) where {TStart <: FieldVector} = true
@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
x::FieldVector,
) where {TStart} = false
@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar
@inline _is_diagonal_bc(::Type{TStart}, x) where {TStart} = true

# Find the first fieldvector in the broadcast expression (BCE),
# and compare against every other fieldvector in the BCE
@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) =
_is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args)
_is_diagonal_bc_args(typeof(first_fieldvector_in_bc(bc)), bc.args)

# Specialize on FieldVectorStyle to avoid inference failure
# in fieldvector broadcast expressions:
Expand All @@ -318,13 +296,10 @@ end

# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
@inline transform_bc_args(args::Tuple, inds...) = (
transform_broadcasted(args[1], inds...),
transform_bc_args(Base.tail(args), inds...)...,
)
@inline transform_bc_args(args::Tuple{Any}, inds...) =
(transform_broadcasted(args[1], inds...),)
@inline transform_bc_args(args::Tuple{}, inds...) = ()
@inline transform_bc_args(args::Tuple, inds...) =
unrolled_map(args) do arg
transform_broadcasted(arg, inds...)
end

@inline function transform_broadcasted(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
Expand Down
10 changes: 5 additions & 5 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ struct FieldNameDict{
end
end
function FieldNameDict{T}(key_entry_pairs::Pair{<:T}...) where {T}
keys = unrolled_map(pair -> pair[1], key_entry_pairs)
entries = unrolled_map(pair -> pair[2], key_entry_pairs)
keys = unrolled_map(first, key_entry_pairs)
entries = unrolled_map(last, key_entry_pairs)
return FieldNameDict(FieldNameSet{T}(keys), entries)
end

Expand Down Expand Up @@ -466,7 +466,7 @@ parent of the first entry in `dict` that is a `Fields.Field`. If no such entry
is found, `target_type` defaults to `Number`.
"""
function get_scalar_keys(dict::FieldMatrix)
first_field_idx = unrolled_findfirst(x -> x isa Fields.Field, dict.entries)
first_field_idx = unrolled_findfirst(Base.Fix2(isa, Fields.Field), dict.entries)
target_type = Val(
isnothing(first_field_idx) ? Number :
eltype(parent(dict.entries[first_field_idx])),
Expand Down Expand Up @@ -597,7 +597,7 @@ function check_diagonal_matrix(matrix, error_message_start = "The matrix")
!is_diagonal_matrix_entry(pair[2])
end
non_diagonal_entry_keys =
FieldMatrixKeys(unrolled_map(pair -> pair[1], non_diagonal_entry_pairs))
FieldMatrixKeys(unrolled_map(first, non_diagonal_entry_pairs))
isempty(non_diagonal_entry_keys) || error(
"$error_message_start has non-diagonal entries at the following keys: \
$(set_string(non_diagonal_entry_keys))",
Expand All @@ -611,7 +611,7 @@ Checks whether the `FieldNameDict` `dict` contains any un-materialized
`AbstractBroadcasted` entries.
"""
is_lazy(dict) =
unrolled_any(entry -> entry isa Base.AbstractBroadcasted, values(dict))
unrolled_any(Base.Fix2(isa, Base.AbstractBroadcasted), values(dict))

"""
lazy_main_diagonal(matrix)
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixFields/operator_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ has_affine_bc(op) = any(
op.bcs,
)

uses_extrapolate(op) = unrolled_any(bc -> bc isa Operators.Extrapolate, op.bcs)
uses_extrapolate(op) = unrolled_any(Base.Fix2(isa, Operators.Extrapolate), op.bcs)

################################################################################

Expand Down
Loading
Loading