Skip to content

Commit 7ecac73

Browse files
committed
Added tests for the adapt functions
1 parent 53b222f commit 7ecac73

File tree

7 files changed

+100
-8
lines changed

7 files changed

+100
-8
lines changed

src/adapt.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ end
2121

2222
function Adapt.adapt_structure(to,v::SplitVectorBlocks)
2323
own = Adapt.adapt(to,v.own)
24-
ghost = Adapt.adapt(to,v.own_ghost)
25-
24+
ghost = Adapt.adapt(to,v.ghost)
2625
split_vector_blocks(own,ghost)
2726
end
2827

2928
function Adapt.adapt_structure(to,v::SplitVector)
3029
blocks = Adapt.adapt(to,v.blocks)
31-
split_vector(blocks,v.permutation)
30+
perm = Adapt.adapt(to,v.permutation)
31+
split_vector(blocks,perm)
3232
end
3333

3434
function Adapt.adapt_structure(to,v::JaggedArray)
@@ -39,12 +39,14 @@ end
3939

4040
function Adapt.adapt_structure(to,v::SplitMatrix)
4141
blocks = Adapt.adapt_structure(to,v.blocks)
42-
col_par = v.col_permutation
43-
row_par = v.row_permutation
44-
split_matrix(blocks,row_par,col_par)
42+
col_per = v.col_permutation
43+
row_per = v.row_permutation
44+
split_matrix(blocks,row_par,col_per)
4545
end
4646

4747
function Adapt.adapt_structure(to,v::PSparseMatrix)
4848
matrix_partition = Adapt.adapt_structure(to,v.matrix_partition)
49-
PSparseMatrix(matrix_partition,v.row_partition,v.col_partition,v.assembled)
50-
end
49+
col_par = v.col_permutation
50+
row_par = v.row_permutation
51+
PSparseMatrix(matrix_partition,row_par,col_par,v.assembled)
52+
end

test/adapt_tests.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Test
2+
using PartitionedArrays
3+
using Adapt
4+
5+
struct FakeCuVector{A} <: AbstractVector{Float64}
6+
vector::A
7+
end
8+
9+
Base.size(v::FakeCuVector) = size(v.vector)
10+
Base.getindex(v::FakeCuVector,i::Integer) = v.vector[i]
11+
12+
function Adapt.adapt_storage(::Type{<:FakeCuVector},x::AbstractArray)
13+
FakeCuVector(x)
14+
end
15+
16+
function adapt_tests(distribute)
17+
18+
rank = distribute(LinearIndices((2,2)))
19+
20+
a = [[1,2],[3,4,5],Int[],[3,4]]
21+
b = JaggedArray(a)
22+
c = deepcopy(b)
23+
24+
c = Adapt.adapt(FakeCuVector,c)
25+
26+
@test typeof(c.data) == FakeCuVector{typeof(b.data)}
27+
@test typeof(c.ptrs) == FakeCuVector{typeof(b.ptrs)}
28+
@test typeof(c).name.wrapper == GenericJaggedArray
29+
30+
a = [1,2,3,4,5]
31+
b = deepcopy(a)
32+
b = Adapt.adapt(FakeCuVector,b)
33+
@test typeof(b) == FakeCuVector{typeof(a)}
34+
@test b.vector == a
35+
36+
own = [1,2,3,4]
37+
ghost = [5,6,7,8]
38+
block_a = split_vector_blocks(own, ghost)
39+
block_b = deepcopy(block_a)
40+
block_b = Adapt.adapt(FakeCuVector,block_b)
41+
@test block_b.own.vector == block_a.own
42+
@test block_b.ghost.vector == block_a.ghost
43+
@test typeof(block_b.own) == FakeCuVector{typeof(block_a.own)}
44+
@test typeof(block_b.ghost) == FakeCuVector{typeof(block_a.ghost)}
45+
46+
47+
a = split_vector(block_a,[1,2,3,4,5,6,7,8])
48+
b = deepcopy(a)
49+
b = Adapt.adapt(FakeCuVector,b)
50+
51+
@test b.blocks.own.vector == a.blocks.own
52+
@test b.blocks.ghost.vector == a.blocks.ghost
53+
@test b.permutation.vector == a.permutation
54+
55+
56+
a = distribute([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
57+
b = distribute([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
58+
b = Adapt.adapt(FakeCuVector,b)
59+
60+
map(a,b) do val_a,val_b
61+
@test typeof(val_b) == FakeCuVector{typeof(val_a)}
62+
@test val_b.vector == val_a
63+
end
64+
end

test/debug_array/adapt_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module DebugArrayAdaptTests
2+
3+
using PartitionedArrays
4+
5+
include(joinpath("..","adapt_tests.jl"))
6+
7+
with_debug(adapt_tests)
8+
9+
end # module

test/debug_array/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@ using PartitionedArrays
2323

2424
@testset "fem_example" begin include("fem_example.jl") end
2525

26+
@testset "adapt" begin include("adapt_tests.jl") end
27+
2628
end #module

test/mpi_array/adapt_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using MPI
2+
include("run_mpi_driver.jl")
3+
file = joinpath(@__DIR__,"drivers","adapt_tests.jl")
4+
run_mpi_driver(file;procs=4)
5+

test/mpi_array/drivers/adapt_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module MPIArrayAdaptTests
2+
3+
using PartitionedArrays
4+
5+
include(joinpath("..","..","adapt_tests.jl"))
6+
7+
with_mpi(adapt_tests)
8+
9+
end # module

test/mpi_array/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ using PartitionedArrays
1313
@testset "p_timer_tests" begin include("p_timer_tests.jl") end
1414
@testset "fdm_example" begin include("fdm_example.jl") end
1515
@testset "fem_example" begin include("fem_example.jl") end
16+
@testset "adapt" begin include("adapt_tests.jl") end
1617

1718
end #module

0 commit comments

Comments
 (0)