@@ -12,51 +12,27 @@ import ClimaCore.DataLayouts: fused_copyto!
1212import Adapt
1313import CUDA
1414
15- parent_array_type (:: Type{<:CUDA.CuArray{T, N, B} where {N}} ) where {T, B} =
16- CUDA. CuArray{T, N, B} where {N}
17-
18- # allow on-device use of lazy broadcast objects
15+ # Ensure that all CuArrays have the same memory buffer type.
1916parent_array_type (
20- :: Type{<:CUDA.CuDeviceArray{T, N, A} where {N}} ,
21- ) where {T, A} = CUDA. CuDeviceArray{T, N, A} where {N}
22-
23- # Ensure that both parent array types have the same memory buffer type.
24- promote_parent_array_type (
25- :: Type{CUDA.CuArray{T1, N, B} where {N}} ,
26- :: Type{CUDA.CuArray{T2, N, B} where {N}} ,
27- ) where {T1, T2, B} = CUDA. CuArray{promote_type (T1, T2), N, B} where {N}
28-
29- # allow on-device use of lazy broadcast objects
17+ :: Type{<:CUDA.CuArray{<:Any, <:Any, B}} ,
18+ :: Type{T} ,
19+ ) where {T, B} = CUDA. CuArray{T, <: Any , B}
3020promote_parent_array_type (
31- :: Type{CUDA.CuDeviceArray {T1, N , B} where {N }} ,
32- :: Type{CUDA.CuDeviceArray {T2, N , B} where {N }} ,
33- ) where {T1, T2, B} = CUDA. CuDeviceArray {promote_type (T1, T2), N , B} where {N }
21+ :: Type{CUDA.CuArray {T1, <:Any , B}} ,
22+ :: Type{CUDA.CuArray {T2, <:Any , B}} ,
23+ ) where {T1, T2, B} = CUDA. CuArray {promote_type (T1, T2), <: Any , B}
3424
35- # allow on-device use of lazy broadcast objects with different type params
36- promote_parent_array_type (
37- :: Type{CUDA.CuDeviceArray{T1, N, B1} where {N}} ,
38- :: Type{CUDA.CuDeviceArray{T2, N, B2} where {N}} ,
39- ) where {T1, T2, B1, B2} =
40- CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N, B}
41-
42- # allow on-device use of lazy broadcast objects with different type params
25+ # Allow on-device use of lazy broadcast objects.
26+ parent_array_type (:: Type{<:CUDA.CuDeviceArray} , :: Type{T} ) where {T} =
27+ CUDA. CuDeviceArray{T}
4328promote_parent_array_type (
4429 :: Type{CUDA.CuDeviceArray{T1}} ,
45- :: Type{CUDA.CuDeviceArray{T2, N, B2} where {N}} ,
46- ) where {T1, T2, B2} =
47- CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N, B}
48-
49- promote_parent_array_type (
50- :: Type{CUDA.CuDeviceArray{T1, N, B1} where {N}} ,
51- :: Type{CUDA.CuDeviceArray{T2} where {N}} ,
52- ) where {T1, T2, B1} =
53- CUDA. CuDeviceArray{promote_type (T1, T2), N, B} where {N, B}
30+ :: Type{CUDA.CuDeviceArray{T2}} ,
31+ ) where {T1, T2} = CUDA. CuDeviceArray{promote_type (T1, T2)}
5432
5533# Make `similar` accept our special `UnionAll` parent array type for CuArray.
56- Base. similar (
57- :: Type{CUDA.CuArray{T, N′, B} where {N′}} ,
58- dims:: Dims{N} ,
59- ) where {T, N, B} = similar (CUDA. CuArray{T, N, B}, dims)
34+ Base. similar (:: Type{CUDA.CuArray{T, <:Any, B}} , dims:: Dims{N} ) where {T, N, B} =
35+ similar (CUDA. CuArray{T, N, B}, dims)
6036
6137unval (:: Val{CI} ) where {CI} = CI
6238unval (CI) = CI
0 commit comments