Skip to content

Commit 5e2bcbf

Browse files
committed
Add support for nonuniform data structures [perf]
1 parent ee39fa6 commit 5e2bcbf

22 files changed

Lines changed: 757 additions & 951 deletions

docs/src/APIs/datalayouts_api.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,14 @@ DataLayouts.IHF
1818
DataLayouts.IJHF
1919
DataLayouts.VIHF
2020
DataLayouts.VIJHF
21+
DataLayouts.parent_array_type
22+
DataLayouts.promote_parent_array_type
23+
DataLayouts.default_basetype
24+
DataLayouts.check_basetype
25+
DataLayouts.checked_valid_basetype
26+
DataLayouts.storage_length
27+
DataLayouts.struct_field_view
28+
DataLayouts.set_struct!
29+
DataLayouts.get_struct
30+
DataLayouts.bitcast_struct
2131
```

docs/src/APIs/utilities_api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CurrentModule = ClimaCore
77
```@docs
88
Utilities.PlusHalf
99
Utilities.half
10+
Utilities.replace_type_parameter
1011
```
1112

1213
## Utilities.Cache

ext/cuda/data_layouts.jl

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,51 +12,27 @@ import ClimaCore.DataLayouts: fused_copyto!
1212
import Adapt
1313
import CUDA
1414

15-
parent_array_type(::Type{<:CUDA.CuArray{T, N, B} where {N}}) where {T, B} =
16-
CUDA.CuArray{T, N, B} where {N}
17-
18-
# allow on-device use of lazy broadcast objects
15+
# Ensure that all CuArrays have the same memory buffer type.
1916
parent_array_type(
20-
::Type{<:CUDA.CuDeviceArray{T, N, A} where {N}},
21-
) where {T, A} = CUDA.CuDeviceArray{T, N, A} where {N}
22-
23-
# Ensure that both parent array types have the same memory buffer type.
24-
promote_parent_array_type(
25-
::Type{CUDA.CuArray{T1, N, B} where {N}},
26-
::Type{CUDA.CuArray{T2, N, B} where {N}},
27-
) where {T1, T2, B} = CUDA.CuArray{promote_type(T1, T2), N, B} where {N}
28-
29-
# allow on-device use of lazy broadcast objects
17+
::Type{<:CUDA.CuArray{<:Any, <:Any, B}},
18+
::Type{T},
19+
) where {T, B} = CUDA.CuArray{T, <:Any, B}
3020
promote_parent_array_type(
31-
::Type{CUDA.CuDeviceArray{T1, N, B} where {N}},
32-
::Type{CUDA.CuDeviceArray{T2, N, B} where {N}},
33-
) where {T1, T2, B} = CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N}
21+
::Type{CUDA.CuArray{T1, <:Any, B}},
22+
::Type{CUDA.CuArray{T2, <:Any, B}},
23+
) where {T1, T2, B} = CUDA.CuArray{promote_type(T1, T2), <:Any, B}
3424

35-
# allow on-device use of lazy broadcast objects with different type params
36-
promote_parent_array_type(
37-
::Type{CUDA.CuDeviceArray{T1, N, B1} where {N}},
38-
::Type{CUDA.CuDeviceArray{T2, N, B2} where {N}},
39-
) where {T1, T2, B1, B2} =
40-
CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
41-
42-
# allow on-device use of lazy broadcast objects with different type params
25+
# Allow on-device use of lazy broadcast objects.
26+
parent_array_type(::Type{<:CUDA.CuDeviceArray}, ::Type{T}) where {T} =
27+
CUDA.CuDeviceArray{T}
4328
promote_parent_array_type(
4429
::Type{CUDA.CuDeviceArray{T1}},
45-
::Type{CUDA.CuDeviceArray{T2, N, B2} where {N}},
46-
) where {T1, T2, B2} =
47-
CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
48-
49-
promote_parent_array_type(
50-
::Type{CUDA.CuDeviceArray{T1, N, B1} where {N}},
51-
::Type{CUDA.CuDeviceArray{T2} where {N}},
52-
) where {T1, T2, B1} =
53-
CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
30+
::Type{CUDA.CuDeviceArray{T2}},
31+
) where {T1, T2} = CUDA.CuDeviceArray{promote_type(T1, T2)}
5432

5533
# Make `similar` accept our special `UnionAll` parent array type for CuArray.
56-
Base.similar(
57-
::Type{CUDA.CuArray{T, N′, B} where {N′}},
58-
dims::Dims{N},
59-
) where {T, N, B} = similar(CUDA.CuArray{T, N, B}, dims)
34+
Base.similar(::Type{CUDA.CuArray{T, <:Any, B}}, dims::Dims{N}) where {T, N, B} =
35+
similar(CUDA.CuArray{T, N, B}, dims)
6036

6137
unval(::Val{CI}) where {CI} = CI
6238
unval(CI) = CI

ext/cuda/operators_sem_shmem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ Base.@propagate_inbounds function operator_shmem(
9292
Nq = Quadratures.degrees_of_freedom(QS)
9393
# allocate temp output
9494
RT = operator_return_eltype(op, eltype(arg))
95-
Nf = DataLayouts.typesize(FT, RT)
95+
Nf = DataLayouts.storage_length(FT, RT)
9696
WJv¹ = CUDA.CuStaticSharedArray(RT, (Nq, Nvt))
9797
return (WJv¹,)
9898
end
@@ -107,7 +107,7 @@ Base.@propagate_inbounds function operator_shmem(
107107
Nq = Quadratures.degrees_of_freedom(QS)
108108
# allocate temp output
109109
RT = operator_return_eltype(op, eltype(arg))
110-
Nf = DataLayouts.typesize(FT, RT)
110+
Nf = DataLayouts.storage_length(FT, RT)
111111
WJv¹ = CUDA.CuStaticSharedArray(RT, (Nq, Nq, Nvt))
112112
WJv² = CUDA.CuStaticSharedArray(RT, (Nq, Nq, Nvt))
113113
return (WJv¹, WJv²)

0 commit comments

Comments (0)