Skip to content

Commit 3c6591e

Browse files
committed
Add make_zero(!) tests
Aiming for full coverage of both new and old implementations of make_zero(!)
1 parent cf665cc commit 3c6591e

File tree

3 files changed

+323
-13
lines changed

3 files changed

+323
-13
lines changed

test/abi.jl

-13
Original file line numberDiff line numberDiff line change
@@ -480,19 +480,6 @@ mulsin(x) = sin(x[1] * x[2])
480480
@test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1]
481481
end
482482

483-
mutable struct ConstVal
484-
x::Float64
485-
const y::Float64
486-
end
487-
488-
@testset "Make Zero" begin
489-
v = ConstVal(2.0, 3.0)
490-
dv = make_zero(v)
491-
@test dv isa ConstVal
492-
@test dv.x 0.0
493-
@test dv.y 0.0
494-
end
495-
496483
@testset "Type inference" begin
497484
x = ones(10)
498485
@inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))

test/make_zero.jl

+322
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
using Enzyme
2+
using Test
3+
4+
mutable struct MutableWrapper{T}
5+
x::T
6+
end
7+
8+
Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || isequal(a.x, b.x)
9+
10+
struct Incomplete{T}
11+
s::String
12+
x::Float64
13+
w::T
14+
z # not initialized
15+
Incomplete(s, x, w) = new{typeof(w)}(s, x, w)
16+
end
17+
18+
function Base.:(==)(a::Incomplete, b::Incomplete)
19+
(a === b) && return true
20+
(isequal(a.s, b.s) && isequal(a.x, b.x) && isequal(a.w, b.w)) || return false
21+
if isdefined(a, :z) && isdefined(b, :z)
22+
isequal(a.z, b.z) || return false
23+
elseif isdefined(a, :z) || isdefined(b, :z)
24+
return false
25+
end
26+
return true
27+
end
28+
29+
mutable struct MutableIncomplete{T}
30+
s::String
31+
const x::Float64
32+
y::Float64
33+
z # not initialized
34+
w::T
35+
function MutableIncomplete(s, x, y, w)
36+
ret = new{typeof(w)}(s, x, y)
37+
ret.w = w
38+
return ret
39+
end
40+
end
41+
42+
function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete)
43+
(a === b) && return true
44+
if !isequal(a.s, b.s) || !isequal(a.x, b.x) || !isequal(a.y, b.y) || !isequal(a.w, b.w)
45+
return false
46+
end
47+
if isdefined(a, :z) && isdefined(b, :z)
48+
isequal(a.z, b.z) || return false
49+
elseif isdefined(a, :z) || isdefined(b, :z)
50+
return false
51+
end
52+
return true
53+
end
54+
55+
@testset "make_zero" begin
56+
# floats
57+
@test make_zero(1.0) == 0.0
58+
@test make_zero(1.0im) == 0.0im
59+
60+
# float arrays + multiple references
61+
rearr = [1.0]
62+
imarr = [1.0im]
63+
rearr0 = make_zero(rearr)
64+
imarr0 = make_zero(imarr)
65+
@test typeof(rearr0) === typeof(rearr)
66+
@test typeof(imarr0) === typeof(imarr)
67+
@test rearr == [1.0] # no mutation
68+
@test imarr == [1.0im] # no mutation
69+
@test rearr0 == [0.0]
70+
@test imarr0 == [0.0im]
71+
rearrs0 = make_zero((rearr, rearr))
72+
imarrs0 = make_zero((imarr, imarr))
73+
@test typeof(rearrs0) === typeof((rearr, rearr))
74+
@test typeof(imarrs0) === typeof((imarr, imarr))
75+
@test rearr == [1.0] # no mutation
76+
@test imarr == [1.0im] # no mutation
77+
@test rearrs0[1] === rearrs0[2]
78+
@test imarrs0[1] === imarrs0[2]
79+
@test rearrs0[1] == [0.0]
80+
@test imarrs0[1] == [0.0im]
81+
82+
# floats in structs
83+
rewrapped = MutableWrapper(1.0)
84+
imwrapped = MutableWrapper(1.0im)
85+
rewrapped0 = make_zero(rewrapped)
86+
imwrapped0 = make_zero(imwrapped)
87+
@test typeof(rewrapped0) === typeof(rewrapped)
88+
@test typeof(imwrapped0) === typeof(imwrapped)
89+
@test rewrapped == MutableWrapper(1.0) # no mutation
90+
@test imwrapped == MutableWrapper(1.0im) # no mutation
91+
@test rewrapped0 == MutableWrapper(0.0)
92+
@test imwrapped0 == MutableWrapper(0.0im)
93+
94+
# generic array + multiple references
95+
wrapped = MutableWrapper(1.0)
96+
mixarr = ["a", 1.0, wrapped]
97+
mixarr0 = make_zero(mixarr)
98+
@test typeof(mixarr0) === typeof(mixarr)
99+
@test view(mixarr, 1:2) == ["a", 1.0] # no mutation
100+
@test mixarr[3] === wrapped # no mutation
101+
@test mixarr0 == ["a", 0.0, MutableWrapper(0.0)]
102+
mixarrs0 = make_zero((mixarr, mixarr))
103+
@test typeof(mixarrs0) === typeof((mixarr, mixarr))
104+
@test view(mixarr, 1:2) == ["a", 1.0] # no mutation
105+
@test mixarr[3] === wrapped # no mutation
106+
@test mixarrs0[1] === mixarrs0[2]
107+
@test mixarrs0[1] == ["a", 0.0, MutableWrapper(0.0)]
108+
109+
# non-differentiable array + copy_if_inactive
110+
constarr = ["a"]
111+
constarr0 = make_zero(constarr)
112+
@test typeof(constarr0) === typeof(constarr)
113+
@test constarr == ["a"] # no mutation
114+
@test constarr0 === constarr
115+
constarr0copy = make_zero(constarr, #=copy_if_inactive=#Val(true))
116+
@test typeof(constarr0copy) === typeof(constarr0)
117+
@test constarr == ["a"] # no mutation
118+
@test constarr0copy !== constarr
119+
@test constarr0copy == constarr
120+
121+
# Tuple
122+
tup = ("a", 1.0, MutableWrapper(1.0))
123+
tup0 = make_zero(tup)
124+
@test typeof(tup0) === typeof(tup)
125+
@test tup == ("a", 1.0, MutableWrapper(1.0)) # no mutation
126+
@test tup0 == ("a", 0.0, MutableWrapper(0.0))
127+
128+
# NamedTuple
129+
ntup = (a="a", b=1.0, c=MutableWrapper(1.0))
130+
ntup0 = make_zero(ntup)
131+
@test typeof(ntup0) === typeof(ntup)
132+
@test ntup == (a="a", b=1.0, c=MutableWrapper(1.0)) # no mutation
133+
@test ntup0 == (a="a", b=0.0, c=MutableWrapper(0.0))
134+
135+
# Box + multiple references
136+
box = Core.Box(1.0)
137+
box0 = make_zero(box)
138+
@test typeof(box0) === typeof(box)
139+
@test box.contents == 1.0 # no mutation
140+
@test box0.contents == 0.0
141+
boxes0 = make_zero((box, box))
142+
@test typeof(boxes0) === typeof((box, box))
143+
@test box.contents == 1.0 # no mutation
144+
@test boxes0[1] === boxes0[2]
145+
@test boxes0[1].contents == 0.0
146+
147+
# differentiable custom type + multiple references
148+
wrapped = MutableWrapper(1.0)
149+
wrapped0 = make_zero(wrapped)
150+
@test typeof(wrapped0) === typeof(wrapped)
151+
@test wrapped == MutableWrapper(1.0) # no mutation
152+
@test wrapped0 == MutableWrapper(0.0)
153+
wrappeds0 = make_zero((wrapped, wrapped))
154+
@test typeof(wrappeds0) === typeof((wrapped, wrapped))
155+
@test wrapped == MutableWrapper(1.0) # no mutation
156+
@test wrappeds0[1] === wrappeds0[2]
157+
@test wrappeds0[1] == MutableWrapper(0.0)
158+
159+
# non-differentiable custom type + copy_if_inactive
160+
constwrapped = MutableWrapper("a")
161+
constwrapped0 = make_zero(constwrapped)
162+
@test typeof(constwrapped0) === typeof(constwrapped)
163+
@test constwrapped == MutableWrapper("a") # no mutation
164+
@test constwrapped0 === constwrapped
165+
constwrapped0copy = make_zero(constwrapped, #=copy_if_inactive=#Val(true))
166+
@test typeof(constwrapped0copy) === typeof(constwrapped0)
167+
@test constwrapped == MutableWrapper("a") # no mutation
168+
@test constwrapped0copy !== constwrapped
169+
@test constwrapped0copy == constwrapped
170+
171+
# immutable struct with active, mutable, inactive and undefined fields
172+
incomplete = Incomplete("a", 1.0, MutableWrapper(1.0))
173+
incomplete0 = make_zero(incomplete)
174+
@test typeof(incomplete0) === typeof(incomplete)
175+
@test incomplete == Incomplete("a", 1.0, MutableWrapper(1.0)) # no mutation
176+
@test incomplete0 == Incomplete("a", 0.0, MutableWrapper(0.0))
177+
178+
# mutable struct with inactive, active, undefined, and mutable fields
179+
# + multiple references
180+
incompletemut = MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0))
181+
incompletemut0 = make_zero(incompletemut)
182+
@test typeof(incompletemut0) === typeof(incompletemut)
183+
@test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation
184+
@test incompletemut0 == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0))
185+
incompletemuts0 = make_zero((incompletemut, incompletemut))
186+
@test typeof(incompletemuts0) === typeof((incompletemut, incompletemut))
187+
@test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation
188+
@test incompletemuts0[1] === incompletemuts0[2]
189+
@test incompletemuts0[1] == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0))
190+
end
191+
192+
@testset "make_zero!" begin
193+
# floats in mutable struct
194+
rewrapped, imwrapped = MutableWrapper(1.0), MutableWrapper(1.0im)
195+
make_zero!(rewrapped)
196+
make_zero!(imwrapped)
197+
@test rewrapped == MutableWrapper(0.0)
198+
@test imwrapped == MutableWrapper(0.0im)
199+
200+
# mixed tuple in mutable container
201+
wrapped = MutableWrapper(1.0)
202+
tuparr = [(1.0, wrapped)]
203+
make_zero!(tuparr)
204+
@test tuparr[1] === (0.0, wrapped)
205+
@test wrapped == MutableWrapper(0.0)
206+
207+
# mixed namedtuple in mutable container
208+
wrapped = MutableWrapper(1.0)
209+
ntuparr = [(a=1.0, b=wrapped)]
210+
make_zero!(ntuparr)
211+
@test ntuparr[1] === (a=0.0, b=wrapped)
212+
@test wrapped == MutableWrapper(0.0)
213+
214+
# immutable struct with active, mutable, inactive and undefined fields in mutable container
215+
wrapped = MutableWrapper(1.0)
216+
incompletearr = [Incomplete("a", 1.0, wrapped)]
217+
make_zero!(incompletearr)
218+
@test incompletearr[1] == Incomplete("a", 0.0, wrapped)
219+
@test wrapped == MutableWrapper(0.0)
220+
221+
# floats in Ref
222+
reref, imref = Ref(1.0), Ref(1.0im)
223+
make_zero!(reref)
224+
make_zero!(imref)
225+
@test reref[] == 0.0
226+
@test imref[] == 0.0im
227+
228+
# float arrays
229+
rearr, imarr = [1.0], [1.0im]
230+
make_zero!(rearr)
231+
make_zero!(imarr)
232+
@test rearr[1] == 0.0
233+
@test imarr[1] == 0.0im
234+
235+
# non-differentiable array
236+
constarr = ["a"]
237+
make_zero!(constarr)
238+
@test constarr[1] == "a"
239+
240+
# array with active, mutable, inactive and unassigned elements + multiple references
241+
wrapped = MutableWrapper(1.0)
242+
genericarr = Vector(undef, 4)
243+
genericarr[1:3] .= ("a", 1.0, wrapped)
244+
genericarrs = [genericarr, genericarr]
245+
make_zero!(genericarrs)
246+
@test genericarrs[1] === genericarrs[2]
247+
@test genericarrs[1] === genericarr
248+
@test view(genericarr, 1:2) == ["a", 0.0]
249+
@test genericarr[3] === wrapped
250+
@test wrapped == MutableWrapper(0.0)
251+
@test !isassigned(genericarr, 4)
252+
253+
# Ref with multiple references
254+
genericref = Ref((1.0,))
255+
genericrefs = [genericref, genericref]
256+
make_zero!(genericrefs)
257+
@test genericrefs[1] === genericrefs[2]
258+
@test genericrefs[1] === genericref
259+
@test genericref[] == (0.0,)
260+
261+
# Ref with mutable value
262+
wrapped = MutableWrapper(1.0)
263+
mutref = Ref(wrapped)
264+
make_zero!(mutref)
265+
@test mutref[] === wrapped
266+
@test wrapped == MutableWrapper(0.0)
267+
268+
# Ref with non-differentiable value
269+
constref = Ref("a")
270+
make_zero!(constref)
271+
@test constref[] == "a"
272+
273+
# Box with multiple references
274+
box = Core.Box(1.0)
275+
boxes = [box, box]
276+
make_zero!(boxes)
277+
@test boxes[1] === boxes[2]
278+
@test boxes[1] === box
279+
@test box.contents == 0.0
280+
281+
# Box with mutable value
282+
wrapped = MutableWrapper(1.0)
283+
mutbox = Core.Box(wrapped)
284+
make_zero!(mutbox)
285+
@test mutbox.contents === wrapped
286+
@test wrapped == MutableWrapper(0.0)
287+
288+
# Box with non-differentiable value
289+
constbox = Core.Box("a")
290+
make_zero!(constbox)
291+
@test constbox.contents == "a"
292+
293+
# mutable struct with inactive, active, const active, undefined, and mutable fields
294+
# + multiple references
295+
wrapped = MutableWrapper(1.0)
296+
incompletemut = MutableIncomplete("a", #=const=#1.0, 1.0, wrapped)
297+
incompletemuts = [incompletemut, incompletemut]
298+
make_zero!(incompletemuts)
299+
@test incompletemuts[1] === incompletemuts[2]
300+
@test incompletemuts[1] === incompletemut
301+
@test incompletemut == MutableIncomplete("a", #=const=#0.0, 0.0, MutableWrapper(0.0))
302+
@test incompletemut.w === wrapped
303+
304+
# wrapped differentiable array
305+
arr = [1.0]
306+
arrwrapped = MutableWrapper(arr)
307+
make_zero!(arrwrapped)
308+
@test arrwrapped.x === arr
309+
@test arr == [0.0]
310+
311+
# early error on active/mixed type
312+
@test_throws ArgumentError make_zero!(1.0)
313+
@test_throws ArgumentError make_zero!((1.0, MutableWrapper(1.0)))
314+
315+
# immutable struct with both active and undefined fields in immutable container
316+
# (the previous implementation would fail due to #1935)
317+
wrapped = MutableWrapper(1.0)
318+
incompletetuparr = [(Incomplete("a", 1.0, wrapped),)]
319+
make_zero!(incompletetuparr)
320+
@test incompletetuparr[1][1] == Incomplete("a", 0.0, MutableWrapper(0.0))
321+
@test incompletetuparr[1][1].w === wrapped
322+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ end
8484

8585
include("abi.jl")
8686
include("typetree.jl")
87+
include("make_zero.jl")
8788

8889
include("rules.jl")
8990
include("rrules.jl")

0 commit comments

Comments
 (0)