From 30162b1d76bdd7a8d120e7de488ae7abe86ad959 Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Fri, 15 Jul 2022 23:17:47 +0000 Subject: [PATCH 1/3] limit conversion to Cint --- lib/cusparse/conversions.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/cusparse/conversions.jl b/lib/cusparse/conversions.jl index 328180bcfd..6ecf580317 100644 --- a/lib/cusparse/conversions.jl +++ b/lib/cusparse/conversions.jl @@ -136,7 +136,7 @@ end # by flipping rows and columns, we can use that to get CSC to CSR too for elty in (Float32, Float64, ComplexF32, ComplexF64) @eval begin - function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; inda::SparseChar='O') + function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') m,n = size(csr) colPtr = CUDA.zeros(Cint, n+1) rowVal = CUDA.zeros(Cint, nnz(csr)) @@ -156,10 +156,10 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64) $elty, CUSPARSE_ACTION_NUMERIC, inda, CUSPARSE_CSR2CSC_ALG1, buffer) end - CuSparseMatrixCSC(colPtr,rowVal,nzVal,size(csr)) + CuSparseMatrixCSC{$elty, Cint}(colPtr,rowVal,nzVal,size(csr)) end - function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; inda::SparseChar='O') + function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O') m,n = size(csc) rowPtr = CUDA.zeros(Cint,m+1) colVal = CUDA.zeros(Cint,nnz(csc)) @@ -179,7 +179,7 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64) $elty, CUSPARSE_ACTION_NUMERIC, inda, CUSPARSE_CSR2CSC_ALG1, buffer) end - CuSparseMatrixCSR(rowPtr,colVal,nzVal,size(csc)) + CuSparseMatrixCSR{$elty, Cint}(rowPtr,colVal,nzVal,size(csc)) end end end From 7af43e1b69ee9f7d594f12a16066dac25d2f007e Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Sat, 16 Jul 2022 01:06:20 +0000 Subject: [PATCH 2/3] fix all tests --- lib/cusparse/array.jl | 72 ++++++++++++++++++++++---------- lib/cusparse/conversions.jl | 80 +++++++++++++++++++++++------------- test/cusparse/conversions.jl | 11 ++++- 3 files changed, 113 insertions(+), 50 deletions(-) diff --git a/lib/cusparse/array.jl b/lib/cusparse/array.jl index 34d52c0a85..0858ad2908 100644 --- a/lib/cusparse/array.jl +++ b/lib/cusparse/array.jl @@ -7,7 +7,7 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCO CuSparseVecOrMat using LinearAlgebra: BlasFloat -using SparseArrays: nonzeroinds, dimlub +using SparseArrays: nonzeroinds, dimlub, SparseMatrixCSC, SparseVector abstract type AbstractCuSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end const AbstractCuSparseVector{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 1} @@ -27,6 +27,12 @@ mutable struct CuSparseVector{Tv, Ti} <: AbstractCuSparseVector{Tv, Ti} end end +CuSparseVector(A::CuSparseVector) = A + +function CuSparseVector{Tv, Ti}(A::CuSparseVector) where {Tv, Ti} + return CuSparseVector{Tv, Ti}(A.iPtr, A.nzVal, A.len) +end + function CUDA.unsafe_free!(xs::CuSparseVector) unsafe_free!(nonzeroinds(xs)) unsafe_free!(nonzeros(xs)) @@ -48,6 +54,10 @@ end CuSparseMatrixCSC(A::CuSparseMatrixCSC) = A +function CuSparseMatrixCSC{Tv, Ti}(A::CuSparseMatrixCSC) where {Tv, Ti} + return CuSparseMatrixCSC{Tv, Ti}(A.colPtr, A.rowVal, A.nzVal, A.dims) +end + function CUDA.unsafe_free!(xs::CuSparseMatrixCSC) unsafe_free!(xs.colPtr) unsafe_free!(rowvals(xs)) @@ -83,6 +93,10 @@ end CuSparseMatrixCSR(A::CuSparseMatrixCSR) = A +function CuSparseMatrixCSR{Tv, Ti}(A::CuSparseMatrixCSR) where {Tv, Ti} + return CuSparseMatrixCSR{Tv, Ti}(A.rowPtr, A.colVal, A.nzVal, A.dims) +end + function CUDA.unsafe_free!(xs::CuSparseMatrixCSR) unsafe_free!(xs.rowPtr) unsafe_free!(xs.colVal) @@ -112,6 +126,9 @@ mutable struct CuSparseMatrixBSR{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} end CuSparseMatrixBSR(A::CuSparseMatrixBSR) = A +function CuSparseMatrixBSR{Tv, Ti}(A::CuSparseMatrixBSR) where {Tv, Ti} + return CuSparseMatrixBSR{Tv, Ti}(A.rowPtr, A.colVal, A.nzVal, A.dims, A.blockDim, A.dir, A.nnz) +end function CUDA.unsafe_free!(xs::CuSparseMatrixBSR) unsafe_free!(xs.rowPtr) @@ -140,6 +157,9 @@ mutable struct CuSparseMatrixCOO{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti} end CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A +function CuSparseMatrixCOO{Tv, Ti}(A::CuSparseMatrixCOO) where {Tv, Ti} + return CuSparseMatrixCOO{Tv, Ti}(A.rowInd, A.colInd, A.nzVal, A.dims, A.nnz) +end """ Utility union type of [`CuSparseMatrixCSC`](@ref), [`CuSparseMatrixCSR`](@ref), @@ -353,28 +373,38 @@ end ## interop with sparse CPU arrays # cpu to gpu -# NOTE: we eagerly convert the indices to Cint here to avoid additional conversion later on -CuSparseVector{T}(Vec::SparseVector) where {T} = - CuSparseVector(CuVector{Cint}(Vec.nzind), CuVector{T}(Vec.nzval), length(Vec)) -CuSparseVector{T}(Mat::SparseMatrixCSC) where {T} = +CuSparseVector{Tv, Ti}(Vec::SparseVector) where {Tv, Ti} = + CuSparseVector(CuVector{Ti}(Vec.nzind), CuVector{Tv}(Vec.nzval), length(Vec)) +CuSparseVector{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = size(Mat,2) == 1 ? - CuSparseVector(CuVector{Cint}(Mat.rowval), CuVector{T}(Mat.nzval), size(Mat)[1]) : + CuSparseVector(CuVector{Ti}(Mat.rowval), CuVector{Tv}(Mat.nzval), size(Mat)[1]) : throw(ArgumentError("The input argument must have a single column")) -CuSparseMatrixCSC{T}(Vec::SparseVector) where {T} = - CuSparseMatrixCSC{T}(CuVector{Cint}([1]), CuVector{Cint}(Vec.nzind), - CuVector{T}(Vec.nzval), size(Vec)) -CuSparseMatrixCSC{T}(Mat::SparseMatrixCSC) where {T} = - CuSparseMatrixCSC{T}(CuVector{Cint}(Mat.colptr), CuVector{Cint}(Mat.rowval), - CuVector{T}(Mat.nzval), size(Mat)) -CuSparseMatrixCSR{T}(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {T, Tv} = - CuSparseMatrixCSR{T}(CuVector{Cint}(parent(Mat).colptr), CuVector{Cint}(parent(Mat).rowval), - CuVector{T}(parent(Mat).nzval), size(Mat)) -CuSparseMatrixCSR{T}(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {T, Tv} = - CuSparseMatrixCSR{T}(CuVector{Cint}(parent(Mat).colptr), CuVector{Cint}(parent(Mat).rowval), - CuVector{T}(conj.(parent(Mat).nzval)), size(Mat)) -CuSparseMatrixCSR{T}(Mat::SparseMatrixCSC) where {T} = CuSparseMatrixCSR(CuSparseMatrixCSC{T}(Mat)) -CuSparseMatrixBSR{T}(Mat::SparseMatrixCSC, blockdim) where {T} = CuSparseMatrixBSR(CuSparseMatrixCSR{T}(Mat), blockdim) -CuSparseMatrixCOO{T}(Mat::SparseMatrixCSC) where {T} = CuSparseMatrixCOO(CuSparseMatrixCSR{T}(Mat)) +CuSparseMatrixCSC{Tv, Ti}(Vec::SparseVector) where {Tv, Ti} = + CuSparseMatrixCSC{Tv}(CuVector{Ti}([1]), CuVector{Ti}(Vec.nzind), + CuVector{Tv}(Vec.nzval), size(Vec)) +CuSparseMatrixCSC{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = + CuSparseMatrixCSC{Tv, Ti}(CuVector{Ti}(Mat.colptr), CuVector{Ti}(Mat.rowval), + CuVector{Tv}(Mat.nzval), size(Mat)) +CuSparseMatrixCSR{Tv, Ti}(Mat::Transpose{<:Any, <:SparseMatrixCSC}) where {Tv, Ti} = + CuSparseMatrixCSR{Tv, Ti}(CuVector{Ti}(parent(Mat).colptr), CuVector{Ti}(parent(Mat).rowval), + CuVector{Tv}(parent(Mat).nzval), size(Mat)) +CuSparseMatrixCSR{Tv, Ti}(Mat::Adjoint{<:Any, <:SparseMatrixCSC}) where {Tv, Ti} = + CuSparseMatrixCSR{Tv, Ti}(CuVector{Ti}(parent(Mat).colptr), CuVector{Ti}(parent(Mat).rowval), + CuVector{Tv}(conj.(parent(Mat).nzval)), size(Mat)) +CuSparseMatrixCSR{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = CuSparseMatrixCSR(CuSparseMatrixCSC{Tv, Ti}(Mat)) +CuSparseMatrixBSR{Tv, Ti}(Mat::SparseMatrixCSC, blockdim) where {Tv, Ti} = CuSparseMatrixBSR(CuSparseMatrixCSR{Tv, Ti}(Mat), blockdim) +CuSparseMatrixCOO{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = CuSparseMatrixCOO(CuSparseMatrixCSR{Tv, Ti}(Mat)) + +# NOTE: we eagerly convert the indices to Cint here to avoid additional conversion later on +CuSparseVector{Tv}(Vec::SparseVector) where {Tv} = CuSparseVector{Tv, Cint}(Vec) +CuSparseVector{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseVector{Tv, Cint}(Mat) +CuSparseMatrixCSC{Tv}(Vec::SparseVector) where {Tv} = CuSparseMatrixCSC{Tv, Cint}(Vec) +CuSparseMatrixCSC{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseMatrixCSC{Tv, Cint}(Mat) +CuSparseMatrixCSR{Tv}(Mat::Transpose{<:Any, <:SparseMatrixCSC}) where {Tv} = CuSparseMatrixCSR{Tv, Cint}(Mat) +CuSparseMatrixCSR{Tv}(Mat::Adjoint{<:Any, <:SparseMatrixCSC}) where {Tv} = CuSparseMatrixCSR{Tv, Cint}(Mat) +CuSparseMatrixCSR{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseMatrixCSR{Tv, Cint}(Mat) +CuSparseMatrixBSR{Tv}(Mat::SparseMatrixCSC, blockdim) where {Tv} = CuSparseMatrixBSR{Tv, Cint}(Mat, blockdim) +CuSparseMatrixCOO{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseMatrixCOO{Tv, Cint}(Mat) # untyped variants CuSparseVector(x::AbstractSparseArray{T}) where {T} = CuSparseVector{T}(x) diff --git a/lib/cusparse/conversions.jl b/lib/cusparse/conversions.jl index 6ecf580317..94bdc89bce 100644 --- a/lib/cusparse/conversions.jl +++ b/lib/cusparse/conversions.jl @@ -136,6 +136,11 @@ end # by flipping rows and columns, we can use that to get CSC to CSR too for elty in (Float32, Float64, ComplexF32, ComplexF64) @eval begin + CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') = + CuSparseMatrixCSC{$elty, Cint}(csr; inda) + CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O') = + CuSparseMatrixCSR{$elty, Cint}(csc; inda) + function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') m,n = size(csr) colPtr = CUDA.zeros(Cint, n+1) @@ -189,7 +194,12 @@ end for (elty, welty) in ((:Float16, :Float32), (:ComplexF16, :ComplexF32)) @eval begin - function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; inda::SparseChar='O') + CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') = + CuSparseMatrixCSC{$elty, Cint}(csr; inda) + CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O') = + CuSparseMatrixCSR{$elty, Cint}(csc; inda) + + function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') m,n = size(csr) colPtr = CUDA.zeros(Cint, n+1) rowVal = CUDA.zeros(Cint, nnz(csr)) @@ -210,15 +220,15 @@ for (elty, welty) in ((:Float16, :Float32), $elty, CUSPARSE_ACTION_NUMERIC, inda, CUSPARSE_CSR2CSC_ALG1, buffer) end - return CuSparseMatrixCSC(colPtr,rowVal,nzVal,size(csr)) + return CuSparseMatrixCSC{$elty, Cint}(colPtr,rowVal,nzVal,size(csr)) else - wide_csr = CuSparseMatrixCSR(csr.rowPtr, csr.colVal, convert(CuVector{$welty}, nonzeros(csr)), size(csr)) - wide_csc = CuSparseMatrixCSC(wide_csr) - return CuSparseMatrixCSC(wide_csc.colPtr, wide_csc.rowVal, convert(CuVector{$elty}, nonzeros(wide_csc)), size(wide_csc)) + wide_csr = CuSparseMatrixCSR{$welty, Cint}(csr.rowPtr, csr.colVal, convert(CuVector{$welty}, nonzeros(csr)), size(csr)) + wide_csc = CuSparseMatrixCSC{$welty, Cint}(wide_csr) + return CuSparseMatrixCSC{$elty, Cint}(wide_csc.colPtr, wide_csc.rowVal, convert(CuVector{$elty}, nonzeros(wide_csc)), size(wide_csc)) end end - function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; inda::SparseChar='O') + function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O') m,n = size(csc) rowPtr = CUDA.zeros(Cint,m+1) colVal = CUDA.zeros(Cint,nnz(csc)) @@ -239,11 +249,11 @@ for (elty, welty) in ((:Float16, :Float32), $elty, CUSPARSE_ACTION_NUMERIC, inda, CUSPARSE_CSR2CSC_ALG1, buffer) end - return CuSparseMatrixCSR(rowPtr,colVal,nzVal,size(csc)) + return CuSparseMatrixCSR{$elty, Cint}(rowPtr,colVal,nzVal,size(csc)) else - wide_csc = CuSparseMatrixCSC(csc.colPtr, csc.rowVal, convert(CuVector{$welty}, nonzeros(csc)), size(csc)) - wide_csr = CuSparseMatrixCSR(wide_csc) - return CuSparseMatrixCSR(wide_csr.rowPtr, wide_csr.colVal, convert(CuVector{$elty}, nonzeros(wide_csr)), size(wide_csr)) + wide_csc = CuSparseMatrixCSC{$welty, Cint}(csc.colPtr, csc.rowVal, convert(CuVector{$welty}, nonzeros(csc)), size(csc)) + wide_csr = CuSparseMatrixCSR{$welty, Cint}(wide_csc) + return CuSparseMatrixCSR{$elty, Cint}(wide_csr.rowPtr, wide_csr.colVal, convert(CuVector{$elty}, nonzeros(wide_csr)), size(wide_csr)) end end end @@ -255,15 +265,19 @@ for (elty, felty) in ((:Int16, :Float16), (:Int64, :Float64), (:Int128, :ComplexF64)) @eval begin - function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}) - csc_compat = CuSparseMatrixCSC( + CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}) = + CuSparseMatrixCSR{$elty, Cint}(csc) + CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}) = + CuSparseMatrixCSC{$elty, Cint}(csr) + function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}) + csc_compat = CuSparseMatrixCSC{$elty, Cint}( csc.colPtr, csc.rowVal, reinterpret($felty, csc.nzVal), size(csc) ) csr_compat = CuSparseMatrixCSR(csc_compat) - CuSparseMatrixCSR( + CuSparseMatrixCSR{$elty, Cint}( csr_compat.rowPtr, csr_compat.colVal, reinterpret($elty, csr_compat.nzVal), @@ -271,15 +285,15 @@ for (elty, felty) in ((:Int16, :Float16), ) end - function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}) - csr_compat = CuSparseMatrixCSR( + function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}) + csr_compat = CuSparseMatrixCSR{$elty, Cint}( csr.rowPtr, csr.colVal, reinterpret($felty, csr.nzVal), size(csr) ) csc_compat = CuSparseMatrixCSC(csr_compat) - CuSparseMatrixCSC( + CuSparseMatrixCSC{$elty, Cint}( csc_compat.colPtr, csc_compat.rowVal, reinterpret($elty, csc_compat.nzVal), @@ -296,7 +310,11 @@ for (fname,elty) in ((:cusparseScsr2bsr, :Float32), (:cusparseCcsr2bsr, :ComplexF32), (:cusparseZcsr2bsr, :ComplexF64)) @eval begin - function CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty}, blockDim::Integer; + CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim::Integer; + dir::SparseChar='R', inda::SparseChar='O', indc::SparseChar='O') = + CuSparseMatrixBSR{$elty, Cint}(csr, blockDim; dir, inda, indc) + + function CuSparseMatrixBSR{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim::Integer; dir::SparseChar='R', inda::SparseChar='O', indc::SparseChar='O') m,n = size(csr) @@ -314,7 +332,7 @@ for (fname,elty) in ((:cusparseScsr2bsr, :Float32), cudesca, nonzeros(csr), csr.rowPtr, csr.colVal, blockDim, cudescc, bsrNzVal, bsrRowPtr, bsrColInd) - CuSparseMatrixBSR{$elty}(bsrRowPtr, bsrColInd, bsrNzVal, size(csr), blockDim, dir, nnz_ref[]) + CuSparseMatrixBSR{$elty, Cint}(bsrRowPtr, bsrColInd, bsrNzVal, size(csr), blockDim, dir, nnz_ref[]) end end end @@ -324,7 +342,10 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32), (:cusparseCbsr2csr, :ComplexF32), (:cusparseZbsr2csr, :ComplexF64)) @eval begin - function CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty}; + CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty, Cint}; + inda::SparseChar='O', indc::SparseChar='O') = + CuSparseMatrixCSR{$elty, Cint}(bsr;inda, indc) + function CuSparseMatrixCSR{$elty, Cint}(bsr::CuSparseMatrixBSR{$elty, Cint}; inda::SparseChar='O', indc::SparseChar='O') m,n = size(bsr) mb = cld(m, bsr.blockDim) @@ -340,7 +361,7 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32), csrColInd) # XXX: the size here may not match the expected size, when the matrix dimension # is not a multiple of the block dimension! - CuSparseMatrixCSR(csrRowPtr, csrColInd, csrNzVal, (mb*bsr.blockDim, nb*bsr.blockDim)) + CuSparseMatrixCSR{$elty, Cint}(csrRowPtr, csrColInd, csrNzVal, (mb*bsr.blockDim, nb*bsr.blockDim)) end end end @@ -351,8 +372,11 @@ for (elty, felty) in ((:Int16, :Float16), (:Int64, :Float64), (:Int128, :ComplexF64)) @eval begin - function CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty}) - bsr_compat = CuSparseMatrixBSR( + CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty, Cint}) = CuSparseMatrixCSR{$elty, Cint}(bsr) + CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim) = CuSparseMatrixBSR{$elty, Cint}(csr, blockDim) + + function CuSparseMatrixCSR{$elty, Cint}(bsr::CuSparseMatrixBSR{$elty, Cint}) + bsr_compat = CuSparseMatrixBSR{$elty, Cint}( bsr.rowPtr, bsr.colVal, reinterpret($felty, bsr.nzVal), @@ -361,8 +385,8 @@ for (elty, felty) in ((:Int16, :Float16), bsr.nnzb, size(bsr) ) - csr_compat = CuSparseMatrixCSR(bsr_compat) - CuSparseMatrixCSR( + csr_compat = CuSparseMatrixCSR{$elty, Cint}(bsr_compat) + CuSparseMatrixCSR{$elty, Cint}( csr_compat.rowPtr, csr_compat.colVal, reinterpret($elty, csr_compat.nzVal), @@ -370,15 +394,15 @@ for (elty, felty) in ((:Int16, :Float16), ) end - function CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty}, blockDim) - csr_compat = CuSparseMatrixCSR( + function CuSparseMatrixBSR{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim) + csr_compat = CuSparseMatrixCSR{$elty, Cint}( csr.rowPtr, csr.colVal, reinterpret($felty, csr.nzVal), size(csr) ) - bsr_compat = CuSparseMatrixBSR(csr_compat, blockDim) - CuSparseMatrixBSR( + bsr_compat = CuSparseMatrixBSR{$elty, Cint}(csr_compat, blockDim) + CuSparseMatrixBSR{$elty, Cint}( bsr_compat.rowPtr, bsr_compat.colVal, reinterpret($elty, bsr_compat.nzVal), diff --git a/test/cusparse/conversions.jl b/test/cusparse/conversions.jl index b3a35e637a..c8a52a49aa 100644 --- a/test/cusparse/conversions.jl +++ b/test/cusparse/conversions.jl @@ -56,4 +56,13 @@ end dZ = CuSparseMatrixCSR{Float64, Int32}(dX) @test SparseMatrixCSC(dY) ≈ SparseMatrixCSC(dZ) @test SparseMatrixCSC(CuSparseMatrixCSC(X)) ≈ SparseMatrixCSC(CuSparseMatrixCSR(X)) -end \ No newline at end of file +end + +@testset "$TA{$T}(::$TB)" for T in [Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64], + TA in [CuSparseMatrixCSC, CuSparseMatrixCSR], TB in [CuSparseMatrixCSC, CuSparseMatrixCSR] + X = sprand(T, 10, 10, 0.1) + dX = TA{T, Cint}(TB{T, Cint}(X)) + @test TA{T}(X) isa TA{T, Cint} # eagerly convert to Cint + @test TA(X) isa TA{T, Cint} + @test SparseMatrixCSC(dX) ≈ X +end From 19be4bd463ef923154f022b4dcf68b0fff79909d Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Sat, 16 Jul 2022 03:38:27 +0000 Subject: [PATCH 3/3] fix integer conversion --- lib/cusparse/conversions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/cusparse/conversions.jl b/lib/cusparse/conversions.jl index 94bdc89bce..e905c3705a 100644 --- a/lib/cusparse/conversions.jl +++ b/lib/cusparse/conversions.jl @@ -270,7 +270,7 @@ for (elty, felty) in ((:Int16, :Float16), CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}) = CuSparseMatrixCSC{$elty, Cint}(csr) function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}) - csc_compat = CuSparseMatrixCSC{$elty, Cint}( + csc_compat = CuSparseMatrixCSC{$felty, Cint}( csc.colPtr, csc.rowVal, reinterpret($felty, csc.nzVal), @@ -286,7 +286,7 @@ for (elty, felty) in ((:Int16, :Float16), end function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}) - csr_compat = CuSparseMatrixCSR{$elty, Cint}( + csr_compat = CuSparseMatrixCSR{$felty, Cint}( csr.rowPtr, csr.colVal, reinterpret($felty, csr.nzVal),