diff --git a/Project.toml b/Project.toml index daa62a5e42..db6403c628 100644 --- a/Project.toml +++ b/Project.toml @@ -41,11 +41,13 @@ demumble_jll = "1e29f10c-031c-5a83-9565-69cddfc27673" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" [extensions] ChainRulesCoreExt = "ChainRulesCore" EnzymeCoreExt = "EnzymeCore" SpecialFunctionsExt = "SpecialFunctions" +SparseMatricesCSRExt = "SparseMatricesCSR" [compat] AbstractFFTs = "0.4, 0.5, 1.0" @@ -80,6 +82,7 @@ RandomNumbers = "1.5.3" Reexport = "0.2, 1.0" Requires = "0.5, 1.0" SparseArrays = "1" +SparseMatricesCSR = "0.6.9" SpecialFunctions = "1.3, 2" StaticArrays = "1" Statistics = "1" @@ -90,3 +93,4 @@ julia = "1.10" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" diff --git a/ext/SparseMatricesCSRExt.jl b/ext/SparseMatricesCSRExt.jl new file mode 100644 index 0000000000..7ffe884014 --- /dev/null +++ b/ext/SparseMatricesCSRExt.jl @@ -0,0 +1,32 @@ +module SparseMatricesCSRExt + +using CUDA +import CUDA.CUSPARSE: + CuSparseMatrixCSR, CuSparseMatrixCSC, CuSparseMatrixCOO, CuSparseMatrixBSR, + SparseMatrixCSC +using SparseMatricesCSR +import SparseMatricesCSR: SparseMatrixCSR +import Adapt + +# CPU → GPU +CUSPARSE.CuSparseMatrixCSR{T}(Mat::SparseMatrixCSR) where {T} = + CUSPARSE.CuSparseMatrixCSR{T}( + CuVector{Cint}(Mat.rowptr), CuVector{Cint}(Mat.colval), + CuVector{T}(Mat.nzval), size(Mat) +) +CUSPARSE.CuSparseMatrixCSC{T}(Mat::SparseMatrixCSR) where {T} = CuSparseMatrixCSC(CuSparseMatrixCSR{T}(Mat)) +CUSPARSE.CuSparseMatrixCOO{T}(Mat::SparseMatrixCSR) where {T} = CuSparseMatrixCOO(CuSparseMatrixCSR{T}(Mat)) +CUSPARSE.CuSparseMatrixBSR{T}(Mat::SparseMatrixCSR, blockdim) where {T} = CuSparseMatrixBSR(CuSparseMatrixCSR{T}(Mat), blockdim) + +# GPU → CPU +SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSR) = SparseMatrixCSR{1}(size(A)..., Array(A.rowPtr), Array(A.colVal), Array(A.nzVal)) +SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCOO) = SparseMatrixCSR(CuSparseMatrixCSR(A)) +SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSC) = SparseMatrixCSR(CuSparseMatrixCSR(A)) +SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixBSR) = SparseMatrixCSR(CuSparseMatrixCSR(A)) + +# Adapt +Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSR) = CUSPARSE.CuSparseMatrixCSR(xs) +Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseMatrixCSR) where {T} = CUSPARSE.CuSparseMatrixCSR{T}(xs) +Adapt.adapt_storage(::Type{Array}, mat::CUSPARSE.CuSparseMatrixCSR) = SparseMatrixCSR(mat) + +end diff --git a/test/extensions/sparse_matrices_csr.jl b/test/extensions/sparse_matrices_csr.jl new file mode 100644 index 0000000000..032a7ec8df --- /dev/null +++ b/test/extensions/sparse_matrices_csr.jl @@ -0,0 +1,30 @@ +using SparseMatricesCSR +using SparseArrays +using CUDA +using CUDA.CUSPARSE +using Test + +@testset "SparseMatricesCSRExt" begin + + for (n, bd, p) in [(100, 5, 0.02), (5, 1, 0.8), (4, 2, 0.5)] + v"12.0" <= CUSPARSE.version() < v"12.1" && n == 4 && continue + @testset "conversions between CuSparseMatrices (n, bd, p) = ($n, $bd, $p)" begin + _A = sprand(n, n, p) + A = SparseMatrixCSR(_A) + blockdim = bd + for CuSparseMatrixType1 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR) + dA1 = CuSparseMatrixType1 == CuSparseMatrixBSR ? CuSparseMatrixType1(A, blockdim) : CuSparseMatrixType1(A) + @testset "conversion $CuSparseMatrixType1 --> SparseMatrixCSR" begin + @test SparseMatrixCSR(dA1) ≈ A + end + for CuSparseMatrixType2 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR) + CuSparseMatrixType1 == CuSparseMatrixType2 && continue + dA2 = CuSparseMatrixType2 == CuSparseMatrixBSR ? CuSparseMatrixType2(dA1, blockdim) : CuSparseMatrixType2(dA1) + @testset "conversion $CuSparseMatrixType1 --> $CuSparseMatrixType2" begin + @test collect(dA1) ≈ collect(dA2) + end + end + end + end + end +end