Skip to content

Commit 2b378c1

Browse files
committed
test batch forward
1 parent 00cbd6e commit 2b378c1

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

test/finalizers.jl

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,44 @@ GC.gc()
2929
@test length(FREE_LIST) == 1
3030
empty!(FREE_LIST)
3131

32-
dx, x = autodiff(ForwardWithPrimal, compute, Duplicated(1.0, 2.0))
33-
@test dx == 4.0
34-
GC.gc()
35-
@test length(FREE_LIST) == 2
36-
empty!(FREE_LIST)
32+
@testset "forward" begin
33+
dx, x = autodiff(ForwardWithPrimal, compute, Duplicated(1.0, 2.0))
34+
@test dx == 4.0
35+
GC.gc()
36+
@test length(FREE_LIST) == 2
37+
empty!(FREE_LIST)
3738

38-
((dx,), x) = autodiff(ReverseWithPrimal, compute, Active(1.0))
39-
@test dx == 2.0
40-
GC.gc()
41-
@test length(FREE_LIST) == 2
42-
empty!(FREE_LIST)
39+
dx, = autodiff(Forward, compute, Duplicated(1.0, 2.0))
40+
@test dx == 4.0
41+
GC.gc()
42+
@test length(FREE_LIST) == 2
43+
empty!(FREE_LIST)
44+
end
4345

44-
dx, x = autodiff(ForwardWithPrimal, compute, DuplicatedBatched(1.0, (1.0, 2.0)))
45-
@test dx == 4.0
46-
GC.gc()
47-
@test length(FREE_LIST) == 3
48-
empty!(FREE_LIST)
46+
@testset "batched forward" begin
47+
dx, x = autodiff(ForwardWithPrimal, compute, BatchDuplicated(1.0, (1.0, 2.0)))
48+
@test dx == 4.0
49+
GC.gc()
50+
@test length(FREE_LIST) == 3
51+
empty!(FREE_LIST)
52+
53+
dx, = autodiff(Forward, compute, BatchDuplicated(1.0, (1.0, 2.0)))
54+
@test dx == 4.0
55+
GC.gc()
56+
@test length(FREE_LIST) == 3
57+
empty!(FREE_LIST)
58+
end
59+
60+
@testset "reverse" begin
61+
((dx,), x) = autodiff(ReverseWithPrimal, compute, Active(1.0))
62+
@test dx == 2.0
63+
GC.gc()
64+
@test length(FREE_LIST) == 2
65+
empty!(FREE_LIST)
66+
67+
((dx,), x) = autodiff(Reverse, compute, Active(1.0))
68+
@test dx == 2.0
69+
GC.gc()
70+
@test length(FREE_LIST) == 2
71+
empty!(FREE_LIST)
72+
end

0 commit comments

Comments
 (0)