diff --git a/lib/cusparse/src/array.jl b/lib/cusparse/src/array.jl index bc370c144d..be6e2d631a 100644 --- a/lib/cusparse/src/array.jl +++ b/lib/cusparse/src/array.jl @@ -288,18 +288,25 @@ Base.similar(Mat::CuSparseMatrixCOO, T::Type) = CuSparseMatrixCOO(copy(Mat.rowIn Base.similar(Mat::CuSparseMatrixCSC, T::Type, N::Int, M::Int) = CuSparseMatrixCSC(CUDACore.zeros(Int32, 1), CUDACore.zeros(Int32, 0), CuVector{T}(undef, 0), (N, M)) Base.similar(Mat::CuSparseMatrixCSR, T::Type, N::Int, M::Int) = CuSparseMatrixCSR(CUDACore.zeros(Int32, 1), CUDACore.zeros(Int32, 0), CuVector{T}(undef, 0), (N,M)) Base.similar(Mat::CuSparseMatrixCOO, T::Type, N::Int, M::Int) = CuSparseMatrixCOO(CUDACore.zeros(Int32, 0), CUDACore.zeros(Int32, 0), CuVector{T}(undef, 0), (N,M)) +# For dims higher than 3 we default to dense (GPU) arrays as does Base +Base.similar(::CuSparseMatrix, T::Type, dims::Vararg{Int,N}) where N = CuArray{T}(undef, dims) Base.similar(Mat::CuSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) Base.similar(Mat::CuSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) Base.similar(Mat::CuSparseMatrixCOO{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) +# For dims higher than 3 we default to dense arrays as does Base +Base.similar(::CuSparseMatrix{Tv}, dims::Vararg{Int, N}) where {N, Tv} = CuArray{Tv}(undef, dims) Base.similar(Mat::CuSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) Base.similar(Mat::CuSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) Base.similar(Mat::CuSparseMatrixCOO, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) +# The next one would overwrite `similar` for dims::Tuple{Int}, which breaks linearization +# Base.similar(::CuSparseMatrix, T::Type, dims::NTuple{N, Int}) where N = CuArray{T}(undef, dims) Base.similar(Mat::CuSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...) Base.similar(Mat::CuSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...) Base.similar(Mat::CuSparseMatrixCOO, dims::Tuple{Int, Int}) = similar(Mat, dims...) +Base.similar(::CuSparseMatrix{Tv}, dims::NTuple{N, Int}) where {N, Tv} = CuArray{Tv}(undef, dims) Base.similar(Mat::CuSparseArrayCSR) = CuSparseArrayCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))