Skip to content

Commit 76701d2

Browse files
committed
Add adapt definitions for SparseArrays
1 parent a18368b commit 76701d2

File tree

4 files changed

+29
-2
lines changed

4 files changed

+29
-2
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
name = "Adapt"
22
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3-
version = "4.1.1"
3+
version = "4.2.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
88

99
[weakdeps]
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1011
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1112

1213
[extensions]
14+
AdaptSparseArraysExt = "SparseArrays"
1315
AdaptStaticArraysExt = "StaticArrays"
1416

1517
[compat]
1618
Requires = "1"
1719
StaticArrays = "1"
1820
julia = "1.6"
1921
LinearAlgebra = "1"
22+
SparseArrays = "1"
2023

2124
[extras]
2225
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
26+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2327
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2428

2529
[targets]
26-
test = ["StaticArrays", "Test"]
30+
test = ["SparseArrays", "StaticArrays", "Test"]

ext/AdaptSparseArraysExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module AdaptSparseArraysExt
2+
3+
using Adapt
4+
isdefined(Base, :get_extension) ? (using SparseArrays) : (using ..SparseArrays)
5+
6+
Adapt.adapt_storage(::Type{Array}, xs::SparseVector) = xs
7+
Adapt.adapt_storage(::Type{Array}, xs::SparseMatrixCSC) = xs
8+
Adapt.adapt_storage(::Type{Array{T}}, xs::SparseVector) where {T} = SparseVector{T}(xs)
9+
Adapt.adapt_storage(::Type{Array{T}}, xs::SparseMatrixCSC) where {T} = SparseMatrixCSC{T}(xs)
10+
11+
end

src/Adapt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ end
7373

7474
@static if !isdefined(Base, :get_extension)
7575
function __init__()
76+
@require SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" begin include("../ext/AdaptSparseArraysExt.jl") end
7677
@require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin include("../ext/AdaptStaticArraysExt.jl") end
7778
end
7879
end

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,17 @@ end
224224
@test typeof(copy(adapt(CustomArray, bc))) == typeof(broadcast(f(mat), (mat,)))
225225
end
226226

227+
@testset "SparseArrays" begin
228+
using SparseArrays
229+
m = sparse([1, 2], [2, 1], [1, 2])
230+
@test_adapt Array m m
231+
@test_adapt Array{Float64} m SparseMatrixCSC{Float64}(m)
232+
233+
v = sparsevec([1, 3], [1, 2])
234+
@test_adapt Array v v
235+
@test_adapt Array{Float64} v SparseVector{Float64}(v)
236+
end
237+
227238
@testset "StaticArrays" begin
228239
using StaticArrays
229240
@test_adapt SArray{Tuple{3}} [1,2,3] SArray{Tuple{3}}([1,2,3])

0 commit comments

Comments
 (0)