Skip to content

Commit 2224e4b

Browse files
add more conversions + mirror sparse conversion tests
1 parent 08b2fee commit 2224e4b

File tree

2 files changed

+41
-30
lines changed

2 files changed

+41
-30
lines changed

Diff for: ext/SparseMatricesCSRExt.jl

+18-16
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,32 @@
11
module SparseMatricesCSRExt
22

33
using CUDA
4-
import CUDA: CUSPARSE
4+
import CUDA.CUSPARSE:
5+
CuSparseMatrixCSR, CuSparseMatrixCSC, CuSparseMatrixCOO, CuSparseMatrixBSR,
6+
SparseMatrixCSC
57
using SparseMatricesCSR
68
import SparseMatricesCSR: SparseMatrixCSR
79
import Adapt
810

11+
# CPU → GPU
912
CUSPARSE.CuSparseMatrixCSR{T}(Mat::SparseMatrixCSR) where {T} =
1013
CUSPARSE.CuSparseMatrixCSR{T}(
1114
CuVector{Cint}(Mat.rowptr), CuVector{Cint}(Mat.colval),
1215
CuVector{T}(Mat.nzval), size(Mat)
1316
)
14-
15-
CUSPARSE.CuSparseMatrixCSC{T}(Mat::SparseMatrixCSR) where {T} =
16-
CUSPARSE.CuSparseMatrixCSC{T}(CUSPARSE.CuSparseMatrixCSR(Mat))
17-
18-
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSR) =
19-
SparseMatrixCSR(CUSPARSE.SparseMatrixCSC(A)) # no direct conversion (gpu_CSR -> cpu_CSC -> cpu_CSR)
20-
21-
Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSR) =
22-
CUSPARSE.CuSparseMatrixCSR(xs)
23-
24-
Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseMatrixCSR) where {T} =
25-
CUSPARSE.CuSparseMatrixCSR{T}(xs)
26-
27-
Adapt.adapt_storage(::Type{Array}, mat::CUSPARSE.CuSparseMatrixCSR) =
28-
SparseMatrixCSR(mat)
17+
CUSPARSE.CuSparseMatrixCSC{T}(Mat::SparseMatrixCSR) where {T} = CuSparseMatrixCSC(CuSparseMatrixCSR{T}(Mat))
18+
CUSPARSE.CuSparseMatrixCOO{T}(Mat::SparseMatrixCSR) where {T} = CuSparseMatrixCOO(CuSparseMatrixCSR{T}(Mat))
19+
CUSPARSE.CuSparseMatrixBSR{T}(Mat::SparseMatrixCSR, blockdim) where {T} = CuSparseMatrixBSR(CuSparseMatrixCSR{T}(Mat), blockdim)
20+
21+
# GPU → CPU
22+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSR) = SparseMatrixCSR{1}(size(A)..., Array(A.rowPtr), Array(A.colVal), Array(A.nzVal))
23+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCOO) = SparseMatrixCSR(CuSparseMatrixCSR(A))
24+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSC) = SparseMatrixCSR(CuSparseMatrixCSR(A))
25+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixBSR) = SparseMatrixCSR(CuSparseMatrixCSR(A))
26+
27+
# Adapt
28+
Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSR) = CUSPARSE.CuSparseMatrixCSR(xs)
29+
Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseMatrixCSR) where {T} = CUSPARSE.CuSparseMatrixCSR{T}(xs)
30+
Adapt.adapt_storage(::Type{Array}, mat::CUSPARSE.CuSparseMatrixCSR) = SparseMatrixCSR(mat)
2931

3032
end

Diff for: test/extensions/sparse_matrices_csr.jl

+23-14
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
using SparseMatricesCSR
22
using SparseArrays
33
using CUDA
4+
using CUDA.CUSPARSE
45
using Test
56

6-
@testset "SparseMatricesCSR" begin
7-
A = sprand(10, 10, 0.1)
8-
A_csr = SparseMatrixCSR(A)
9-
A_gpu = CUSPARSE.CuSparseMatrixCSR(A_csr)
7+
@testset "SparseMatricesCSRExt" begin
108

11-
@test size(A_gpu) == size(A_csr)
12-
@test CUSPARSE.nnz(A_gpu) == nnz(A_csr)
13-
@test SparseMatrixCSR(A_gpu) A_csr
14-
@test A_csr |> cu isa CUSPARSE.CuSparseMatrixCSR
15-
16-
# convert from CSR to CuCSC
17-
A_csc_gpu = CUSPARSE.CuSparseMatrixCSC(A_csr)
18-
@test size(A_csc_gpu) == size(A)
19-
@test CUSPARSE.nnz(A_csc_gpu) == nnz(A)
20-
@test SparseMatrixCSC(A_csc_gpu) A
9+
for (n, bd, p) in [(100, 5, 0.02), (5, 1, 0.8), (4, 2, 0.5)]
10+
v"12.0" <= CUSPARSE.version() < v"12.1" && n == 4 && continue
11+
@testset "conversions between CuSparseMatrices (n, bd, p) = ($n, $bd, $p)" begin
12+
_A = sprand(n, n, p)
13+
A = SparseMatrixCSR(_A)
14+
blockdim = bd
15+
for CuSparseMatrixType1 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
16+
dA1 = CuSparseMatrixType1 == CuSparseMatrixBSR ? CuSparseMatrixType1(A, blockdim) : CuSparseMatrixType1(A)
17+
@testset "conversion $CuSparseMatrixType1 --> SparseMatrixCSR" begin
18+
@test SparseMatrixCSR(dA1) A
19+
end
20+
for CuSparseMatrixType2 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
21+
CuSparseMatrixType1 == CuSparseMatrixType2 && continue
22+
dA2 = CuSparseMatrixType2 == CuSparseMatrixBSR ? CuSparseMatrixType2(dA1, blockdim) : CuSparseMatrixType2(dA1)
23+
@testset "conversion $CuSparseMatrixType1 --> $CuSparseMatrixType2" begin
24+
@test collect(dA1) collect(dA2)
25+
end
26+
end
27+
end
28+
end
29+
end
2130
end

0 commit comments

Comments
 (0)