Skip to content

limit csc/csr/bsr sparse conversion index to be cint & fix a few conversion bugs #1563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions lib/cusparse/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCO
CuSparseVecOrMat

using LinearAlgebra: BlasFloat
using SparseArrays: nonzeroinds, dimlub
using SparseArrays: nonzeroinds, dimlub, SparseMatrixCSC, SparseVector

abstract type AbstractCuSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
const AbstractCuSparseVector{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 1}
Expand All @@ -27,6 +27,12 @@ mutable struct CuSparseVector{Tv, Ti} <: AbstractCuSparseVector{Tv, Ti}
end
end

CuSparseVector(A::CuSparseVector) = A

function CuSparseVector{Tv, Ti}(A::CuSparseVector) where {Tv, Ti}
return CuSparseVector{Tv, Ti}(A.iPtr, A.nzVal, A.len)
end

function CUDA.unsafe_free!(xs::CuSparseVector)
unsafe_free!(nonzeroinds(xs))
unsafe_free!(nonzeros(xs))
Expand All @@ -48,6 +54,10 @@ end

CuSparseMatrixCSC(A::CuSparseMatrixCSC) = A

function CuSparseMatrixCSC{Tv, Ti}(A::CuSparseMatrixCSC) where {Tv, Ti}
return CuSparseMatrixCSC{Tv, Ti}(A.colPtr, A.rowVal, A.nzVal, A.dims)
end

function CUDA.unsafe_free!(xs::CuSparseMatrixCSC)
unsafe_free!(xs.colPtr)
unsafe_free!(rowvals(xs))
Expand Down Expand Up @@ -83,6 +93,10 @@ end

CuSparseMatrixCSR(A::CuSparseMatrixCSR) = A

function CuSparseMatrixCSR{Tv, Ti}(A::CuSparseMatrixCSR) where {Tv, Ti}
return CuSparseMatrixCSR{Tv, Ti}(A.rowPtr, A.colVal, A.nzVal, A.dims)
end

function CUDA.unsafe_free!(xs::CuSparseMatrixCSR)
unsafe_free!(xs.rowPtr)
unsafe_free!(xs.colVal)
Expand Down Expand Up @@ -112,6 +126,9 @@ mutable struct CuSparseMatrixBSR{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti}
end

CuSparseMatrixBSR(A::CuSparseMatrixBSR) = A
function CuSparseMatrixBSR{Tv, Ti}(A::CuSparseMatrixBSR) where {Tv, Ti}
return CuSparseMatrixBSR{Tv, Ti}(A.rowPtr, A.colVal, A.nzVal, A.dims, A.blockDim, A.dir, A.nnz)
end

function CUDA.unsafe_free!(xs::CuSparseMatrixBSR)
unsafe_free!(xs.rowPtr)
Expand Down Expand Up @@ -140,6 +157,9 @@ mutable struct CuSparseMatrixCOO{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti}
end

CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A
function CuSparseMatrixCOO{Tv, Ti}(A::CuSparseMatrixCOO) where {Tv, Ti}
return CuSparseMatrixCOO{Tv, Ti}(A.rowInd, A.colInd, A.nzVal, A.dims, A.nnz)
end

"""
Utility union type of [`CuSparseMatrixCSC`](@ref), [`CuSparseMatrixCSR`](@ref),
Expand Down Expand Up @@ -353,28 +373,38 @@ end
## interop with sparse CPU arrays

# cpu to gpu
# NOTE: we eagerly convert the indices to Cint here to avoid additional conversion later on
CuSparseVector{T}(Vec::SparseVector) where {T} =
CuSparseVector(CuVector{Cint}(Vec.nzind), CuVector{T}(Vec.nzval), length(Vec))
CuSparseVector{T}(Mat::SparseMatrixCSC) where {T} =
CuSparseVector{Tv, Ti}(Vec::SparseVector) where {Tv, Ti} =
CuSparseVector(CuVector{Ti}(Vec.nzind), CuVector{Tv}(Vec.nzval), length(Vec))
CuSparseVector{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} =
size(Mat,2) == 1 ?
CuSparseVector(CuVector{Cint}(Mat.rowval), CuVector{T}(Mat.nzval), size(Mat)[1]) :
CuSparseVector(CuVector{Ti}(Mat.rowval), CuVector{Tv}(Mat.nzval), size(Mat)[1]) :
throw(ArgumentError("The input argument must have a single column"))
CuSparseMatrixCSC{T}(Vec::SparseVector) where {T} =
CuSparseMatrixCSC{T}(CuVector{Cint}([1]), CuVector{Cint}(Vec.nzind),
CuVector{T}(Vec.nzval), size(Vec))
CuSparseMatrixCSC{T}(Mat::SparseMatrixCSC) where {T} =
CuSparseMatrixCSC{T}(CuVector{Cint}(Mat.colptr), CuVector{Cint}(Mat.rowval),
CuVector{T}(Mat.nzval), size(Mat))
CuSparseMatrixCSR{T}(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {T, Tv} =
CuSparseMatrixCSR{T}(CuVector{Cint}(parent(Mat).colptr), CuVector{Cint}(parent(Mat).rowval),
CuVector{T}(parent(Mat).nzval), size(Mat))
CuSparseMatrixCSR{T}(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {T, Tv} =
CuSparseMatrixCSR{T}(CuVector{Cint}(parent(Mat).colptr), CuVector{Cint}(parent(Mat).rowval),
CuVector{T}(conj.(parent(Mat).nzval)), size(Mat))
CuSparseMatrixCSR{T}(Mat::SparseMatrixCSC) where {T} = CuSparseMatrixCSR(CuSparseMatrixCSC{T}(Mat))
CuSparseMatrixBSR{T}(Mat::SparseMatrixCSC, blockdim) where {T} = CuSparseMatrixBSR(CuSparseMatrixCSR{T}(Mat), blockdim)
CuSparseMatrixCOO{T}(Mat::SparseMatrixCSC) where {T} = CuSparseMatrixCOO(CuSparseMatrixCSR{T}(Mat))
CuSparseMatrixCSC{Tv, Ti}(Vec::SparseVector) where {Tv, Ti} =
CuSparseMatrixCSC{Tv}(CuVector{Ti}([1]), CuVector{Ti}(Vec.nzind),
CuVector{Tv}(Vec.nzval), size(Vec))
CuSparseMatrixCSC{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} =
CuSparseMatrixCSC{Tv, Ti}(CuVector{Ti}(Mat.colptr), CuVector{Ti}(Mat.rowval),
CuVector{Tv}(Mat.nzval), size(Mat))
CuSparseMatrixCSR{Tv, Ti}(Mat::Transpose{<:Any, <:SparseMatrixCSC}) where {Tv, Ti} =
CuSparseMatrixCSR{Tv, Ti}(CuVector{Ti}(parent(Mat).colptr), CuVector{Ti}(parent(Mat).rowval),
CuVector{Tv}(parent(Mat).nzval), size(Mat))
CuSparseMatrixCSR{Tv, Ti}(Mat::Adjoint{<:Any, <:SparseMatrixCSC}) where {Tv, Ti} =
CuSparseMatrixCSR{Tv, Ti}(CuVector{Ti}(parent(Mat).colptr), CuVector{Ti}(parent(Mat).rowval),
CuVector{Tv}(conj.(parent(Mat).nzval)), size(Mat))
CuSparseMatrixCSR{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = CuSparseMatrixCSR(CuSparseMatrixCSC{Tv, Ti}(Mat))
CuSparseMatrixBSR{Tv, Ti}(Mat::SparseMatrixCSC, blockdim) where {Tv, Ti} = CuSparseMatrixBSR(CuSparseMatrixCSR{Tv, Ti}(Mat), blockdim)
CuSparseMatrixCOO{Tv, Ti}(Mat::SparseMatrixCSC) where {Tv, Ti} = CuSparseMatrixCOO(CuSparseMatrixCSR{Tv, Ti}(Mat))

# NOTE: we eagerly convert the indices to Cint here to avoid additional conversion later on
CuSparseVector{Tv}(Vec::SparseVector) where {Tv} = CuSparseVector{Tv, Cint}(Vec)
CuSparseVector{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseVector{Tv, Cint}(Mat)
CuSparseMatrixCSC{Tv}(Vec::SparseVector) where {Tv} = CuSparseMatrixCSC{Tv, Cint}(Vec)
CuSparseMatrixCSC{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseMatrixCSC{Tv, Cint}(Mat)
CuSparseMatrixCSR{Tv}(Mat::Transpose{<:Any, <:SparseMatrixCSC}) where {Tv} = CuSparseMatrixCSR{Tv, Cint}(Mat)
CuSparseMatrixCSR{Tv}(Mat::Adjoint{<:Any, <:SparseMatrixCSC}) where {Tv} = CuSparseMatrixCSR{Tv, Cint}(Mat)
CuSparseMatrixCSR{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseMatrixCSR{Tv, Cint}(Mat)
CuSparseMatrixBSR{Tv}(Mat::SparseMatrixCSC, blockdim) where {Tv} = CuSparseMatrixBSR{Tv, Cint}(Mat, blockdim)
CuSparseMatrixCOO{Tv}(Mat::SparseMatrixCSC) where {Tv} = CuSparseMatrixCOO{Tv, Cint}(Mat)

# untyped variants
CuSparseVector(x::AbstractSparseArray{T}) where {T} = CuSparseVector{T}(x)
Expand Down
88 changes: 56 additions & 32 deletions lib/cusparse/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ end
# by flipping rows and columns, we can use that to get CSC to CSR too
for elty in (Float32, Float64, ComplexF32, ComplexF64)
@eval begin
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; inda::SparseChar='O')
CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') =
CuSparseMatrixCSC{$elty, Cint}(csr; inda)
CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O') =
CuSparseMatrixCSR{$elty, Cint}(csc; inda)

function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O')
m,n = size(csr)
colPtr = CUDA.zeros(Cint, n+1)
rowVal = CUDA.zeros(Cint, nnz(csr))
Expand All @@ -156,10 +161,10 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64)
$elty, CUSPARSE_ACTION_NUMERIC, inda,
CUSPARSE_CSR2CSC_ALG1, buffer)
end
CuSparseMatrixCSC(colPtr,rowVal,nzVal,size(csr))
CuSparseMatrixCSC{$elty, Cint}(colPtr,rowVal,nzVal,size(csr))
end

function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; inda::SparseChar='O')
function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O')
m,n = size(csc)
rowPtr = CUDA.zeros(Cint,m+1)
colVal = CUDA.zeros(Cint,nnz(csc))
Expand All @@ -179,7 +184,7 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64)
$elty, CUSPARSE_ACTION_NUMERIC, inda,
CUSPARSE_CSR2CSC_ALG1, buffer)
end
CuSparseMatrixCSR(rowPtr,colVal,nzVal,size(csc))
CuSparseMatrixCSR{$elty, Cint}(rowPtr,colVal,nzVal,size(csc))
end
end
end
Expand All @@ -189,7 +194,12 @@ end
for (elty, welty) in ((:Float16, :Float32),
(:ComplexF16, :ComplexF32))
@eval begin
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; inda::SparseChar='O')
CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O') =
CuSparseMatrixCSC{$elty, Cint}(csr; inda)
CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O') =
CuSparseMatrixCSR{$elty, Cint}(csc; inda)

function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}; inda::SparseChar='O')
m,n = size(csr)
colPtr = CUDA.zeros(Cint, n+1)
rowVal = CUDA.zeros(Cint, nnz(csr))
Expand All @@ -210,15 +220,15 @@ for (elty, welty) in ((:Float16, :Float32),
$elty, CUSPARSE_ACTION_NUMERIC, inda,
CUSPARSE_CSR2CSC_ALG1, buffer)
end
return CuSparseMatrixCSC(colPtr,rowVal,nzVal,size(csr))
return CuSparseMatrixCSC{$elty, Cint}(colPtr,rowVal,nzVal,size(csr))
else
wide_csr = CuSparseMatrixCSR(csr.rowPtr, csr.colVal, convert(CuVector{$welty}, nonzeros(csr)), size(csr))
wide_csc = CuSparseMatrixCSC(wide_csr)
return CuSparseMatrixCSC(wide_csc.colPtr, wide_csc.rowVal, convert(CuVector{$elty}, nonzeros(wide_csc)), size(wide_csc))
wide_csr = CuSparseMatrixCSR{$welty, Cint}(csr.rowPtr, csr.colVal, convert(CuVector{$welty}, nonzeros(csr)), size(csr))
wide_csc = CuSparseMatrixCSC{$welty, Cint}(wide_csr)
return CuSparseMatrixCSC{$elty, Cint}(wide_csc.colPtr, wide_csc.rowVal, convert(CuVector{$elty}, nonzeros(wide_csc)), size(wide_csc))
end
end

function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; inda::SparseChar='O')
function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint}; inda::SparseChar='O')
m,n = size(csc)
rowPtr = CUDA.zeros(Cint,m+1)
colVal = CUDA.zeros(Cint,nnz(csc))
Expand All @@ -239,11 +249,11 @@ for (elty, welty) in ((:Float16, :Float32),
$elty, CUSPARSE_ACTION_NUMERIC, inda,
CUSPARSE_CSR2CSC_ALG1, buffer)
end
return CuSparseMatrixCSR(rowPtr,colVal,nzVal,size(csc))
return CuSparseMatrixCSR{$elty, Cint}(rowPtr,colVal,nzVal,size(csc))
else
wide_csc = CuSparseMatrixCSC(csc.colPtr, csc.rowVal, convert(CuVector{$welty}, nonzeros(csc)), size(csc))
wide_csr = CuSparseMatrixCSR(wide_csc)
return CuSparseMatrixCSR(wide_csr.rowPtr, wide_csr.colVal, convert(CuVector{$elty}, nonzeros(wide_csr)), size(wide_csr))
wide_csc = CuSparseMatrixCSC{$welty, Cint}(csc.colPtr, csc.rowVal, convert(CuVector{$welty}, nonzeros(csc)), size(csc))
wide_csr = CuSparseMatrixCSR{$welty, Cint}(wide_csc)
return CuSparseMatrixCSR{$elty, Cint}(wide_csr.rowPtr, wide_csr.colVal, convert(CuVector{$elty}, nonzeros(wide_csr)), size(wide_csr))
end
end
end
Expand All @@ -255,31 +265,35 @@ for (elty, felty) in ((:Int16, :Float16),
(:Int64, :Float64),
(:Int128, :ComplexF64))
@eval begin
function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty})
csc_compat = CuSparseMatrixCSC(
CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty, Cint}) =
CuSparseMatrixCSR{$elty, Cint}(csc)
CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}) =
CuSparseMatrixCSC{$elty, Cint}(csr)
function CuSparseMatrixCSR{$elty, Cint}(csc::CuSparseMatrixCSC{$elty, Cint})
csc_compat = CuSparseMatrixCSC{$felty, Cint}(
csc.colPtr,
csc.rowVal,
reinterpret($felty, csc.nzVal),
size(csc)
)
csr_compat = CuSparseMatrixCSR(csc_compat)
CuSparseMatrixCSR(
CuSparseMatrixCSR{$elty, Cint}(
csr_compat.rowPtr,
csr_compat.colVal,
reinterpret($elty, csr_compat.nzVal),
size(csr_compat)
)
end

function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty})
csr_compat = CuSparseMatrixCSR(
function CuSparseMatrixCSC{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint})
csr_compat = CuSparseMatrixCSR{$felty, Cint}(
csr.rowPtr,
csr.colVal,
reinterpret($felty, csr.nzVal),
size(csr)
)
csc_compat = CuSparseMatrixCSC(csr_compat)
CuSparseMatrixCSC(
CuSparseMatrixCSC{$elty, Cint}(
csc_compat.colPtr,
csc_compat.rowVal,
reinterpret($elty, csc_compat.nzVal),
Expand All @@ -296,7 +310,11 @@ for (fname,elty) in ((:cusparseScsr2bsr, :Float32),
(:cusparseCcsr2bsr, :ComplexF32),
(:cusparseZcsr2bsr, :ComplexF64))
@eval begin
function CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty}, blockDim::Integer;
CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim::Integer;
dir::SparseChar='R', inda::SparseChar='O', indc::SparseChar='O') =
CuSparseMatrixBSR{$elty, Cint}(csr, blockDim; dir, inda, indc)

function CuSparseMatrixBSR{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim::Integer;
dir::SparseChar='R', inda::SparseChar='O',
indc::SparseChar='O')
m,n = size(csr)
Expand All @@ -314,7 +332,7 @@ for (fname,elty) in ((:cusparseScsr2bsr, :Float32),
cudesca, nonzeros(csr), csr.rowPtr, csr.colVal,
blockDim, cudescc, bsrNzVal, bsrRowPtr,
bsrColInd)
CuSparseMatrixBSR{$elty}(bsrRowPtr, bsrColInd, bsrNzVal, size(csr), blockDim, dir, nnz_ref[])
CuSparseMatrixBSR{$elty, Cint}(bsrRowPtr, bsrColInd, bsrNzVal, size(csr), blockDim, dir, nnz_ref[])
end
end
end
Expand All @@ -324,7 +342,10 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
(:cusparseCbsr2csr, :ComplexF32),
(:cusparseZbsr2csr, :ComplexF64))
@eval begin
function CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty};
CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty, Cint};
inda::SparseChar='O', indc::SparseChar='O') =
CuSparseMatrixCSR{$elty, Cint}(bsr;inda, indc)
function CuSparseMatrixCSR{$elty, Cint}(bsr::CuSparseMatrixBSR{$elty, Cint};
inda::SparseChar='O', indc::SparseChar='O')
m,n = size(bsr)
mb = cld(m, bsr.blockDim)
Expand All @@ -340,7 +361,7 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
csrColInd)
# XXX: the size here may not match the expected size, when the matrix dimension
# is not a multiple of the block dimension!
CuSparseMatrixCSR(csrRowPtr, csrColInd, csrNzVal, (mb*bsr.blockDim, nb*bsr.blockDim))
CuSparseMatrixCSR{$elty, Cint}(csrRowPtr, csrColInd, csrNzVal, (mb*bsr.blockDim, nb*bsr.blockDim))
end
end
end
Expand All @@ -351,8 +372,11 @@ for (elty, felty) in ((:Int16, :Float16),
(:Int64, :Float64),
(:Int128, :ComplexF64))
@eval begin
function CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty})
bsr_compat = CuSparseMatrixBSR(
CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty, Cint}) = CuSparseMatrixCSR{$elty, Cint}(bsr)
CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim) = CuSparseMatrixBSR{$elty, Cint}(csr, blockDim)

function CuSparseMatrixCSR{$elty, Cint}(bsr::CuSparseMatrixBSR{$elty, Cint})
bsr_compat = CuSparseMatrixBSR{$elty, Cint}(
bsr.rowPtr,
bsr.colVal,
reinterpret($felty, bsr.nzVal),
Expand All @@ -361,24 +385,24 @@ for (elty, felty) in ((:Int16, :Float16),
bsr.nnzb,
size(bsr)
)
csr_compat = CuSparseMatrixCSR(bsr_compat)
CuSparseMatrixCSR(
csr_compat = CuSparseMatrixCSR{$elty, Cint}(bsr_compat)
CuSparseMatrixCSR{$elty, Cint}(
csr_compat.rowPtr,
csr_compat.colVal,
reinterpret($elty, csr_compat.nzVal),
size(csr_compat)
)
end

function CuSparseMatrixBSR{$elty}(csr::CuSparseMatrixCSR{$elty}, blockDim)
csr_compat = CuSparseMatrixCSR(
function CuSparseMatrixBSR{$elty, Cint}(csr::CuSparseMatrixCSR{$elty, Cint}, blockDim)
csr_compat = CuSparseMatrixCSR{$elty, Cint}(
csr.rowPtr,
csr.colVal,
reinterpret($felty, csr.nzVal),
size(csr)
)
bsr_compat = CuSparseMatrixBSR(csr_compat, blockDim)
CuSparseMatrixBSR(
bsr_compat = CuSparseMatrixBSR{$elty, Cint}(csr_compat, blockDim)
CuSparseMatrixBSR{$elty, Cint}(
bsr_compat.rowPtr,
bsr_compat.colVal,
reinterpret($elty, bsr_compat.nzVal),
Expand Down
11 changes: 10 additions & 1 deletion test/cusparse/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@ end
dZ = CuSparseMatrixCSR{Float64, Int32}(dX)
@test SparseMatrixCSC(dY) ≈ SparseMatrixCSC(dZ)
@test SparseMatrixCSC(CuSparseMatrixCSC(X)) ≈ SparseMatrixCSC(CuSparseMatrixCSR(X))
end
end

@testset "$TA{$T}(::$TB)" for T in [Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64],
TA in [CuSparseMatrixCSC, CuSparseMatrixCSR], TB in [CuSparseMatrixCSC, CuSparseMatrixCSR]
X = sprand(T, 10, 10, 0.1)
dX = TA{T, Cint}(TB{T, Cint}(X))
@test TA{T}(X) isa TA{T, Cint} # eagerly convert to Cint
@test TA(X) isa TA{T, Cint}
@test SparseMatrixCSC(dX) ≈ X
end