diff --git a/Project.toml b/Project.toml index d714e04..bc0c983 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,17 @@ name = "Adapt" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.1.1" +version = "4.2.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [weakdeps] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] +AdaptSparseArraysExt = "SparseArrays" AdaptStaticArraysExt = "StaticArrays" [compat] @@ -17,10 +19,12 @@ Requires = "1" StaticArrays = "1" julia = "1.6" LinearAlgebra = "1" +SparseArrays = "1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [targets] -test = ["StaticArrays", "Test"] +test = ["SparseArrays", "StaticArrays", "Test"] diff --git a/ext/AdaptSparseArraysExt.jl b/ext/AdaptSparseArraysExt.jl new file mode 100644 index 0000000..09ee7d0 --- /dev/null +++ b/ext/AdaptSparseArraysExt.jl @@ -0,0 +1,11 @@ +module AdaptSparseArraysExt + +using Adapt +isdefined(Base, :get_extension) ? (using SparseArrays) : (using ..SparseArrays) + +Adapt.adapt_storage(::Type{Array}, xs::SparseVector) = xs +Adapt.adapt_storage(::Type{Array}, xs::SparseMatrixCSC) = xs +Adapt.adapt_storage(::Type{Array{T}}, xs::SparseVector) where {T} = SparseVector{T}(xs) +Adapt.adapt_storage(::Type{Array{T}}, xs::SparseMatrixCSC) where {T} = SparseMatrixCSC{T}(xs) + +end diff --git a/src/Adapt.jl b/src/Adapt.jl index 6910c8c..fe5f817 100644 --- a/src/Adapt.jl +++ b/src/Adapt.jl @@ -73,7 +73,12 @@ end @static if !isdefined(Base, :get_extension) function __init__() - @require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin include("../ext/AdaptStaticArraysExt.jl") end + @require SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" begin + include("../ext/AdaptSparseArraysExt.jl") + end + @require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin + include("../ext/AdaptStaticArraysExt.jl") + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 27d9031..56e184c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -224,6 +224,17 @@ end @test typeof(copy(adapt(CustomArray, bc))) == typeof(broadcast(f(mat), (mat,))) end +@testset "SparseArrays" begin + using SparseArrays + m = sparse([1, 2], [2, 1], [1, 2]) + @test_adapt Array m m + @test_adapt Array{Float64} m SparseMatrixCSC{Float64}(m) + + v = sparsevec([1, 3], [1, 2]) + @test_adapt Array v v + @test_adapt Array{Float64} v SparseVector{Float64}(v) +end + @testset "StaticArrays" begin using StaticArrays @test_adapt SArray{Tuple{3}} [1,2,3] SArray{Tuple{3}}([1,2,3])