Skip to content

Commit 8e1c24e

Browse files
authored
Rename functions in tests to avoid overwriting (#2597)
* Rename functions in tests to avoid overwriting * Also fix indentation
1 parent 75007b4 commit 8e1c24e

File tree

1 file changed

+53
-59
lines changed

1 file changed

+53
-59
lines changed

test/mixedapplyiter.jl

Lines changed: 53 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
using Enzyme, Test
22

3-
concat() = ()
4-
concat(a) = a
5-
concat(a, b) = (a..., b...)
6-
concat(a, b, c...) = concat(concat(a, b), c...)
3+
mixed_concat() = ()
4+
mixed_concat(a) = a
5+
mixed_concat(a, b) = (a..., b...)
6+
mixed_concat(a, b, c...) = mixed_concat(mixed_concat(a, b), c...)
77

8-
metaconcat(x) = concat(x...)
9-
10-
metaconcat2(x, y) = concat(x..., y...)
11-
12-
midconcat(x, y) = (x, concat(y...)...)
13-
14-
metaconcat3(x, y, z) = concat(x..., y..., z...)
8+
mixed_metaconcat(x) = mixed_concat(x...)
159

1610
function mixed_metasumsq(f, args...)
1711
res = 0.0
@@ -33,35 +27,35 @@ function mixed_metasumsq3(f, args...)
3327
return res
3428
end
3529

36-
function make_byref(out, fn, args...)
30+
function mixed_make_byref(out, fn, args...)
3731
out[] = fn(args...)
3832
nothing
3933
end
4034

41-
function tupapprox(a, b)
42-
if a isa Tuple && b isa Tuple
43-
if length(a) != length(b)
44-
return false
45-
end
46-
for (aa, bb) in zip(a, b)
47-
if !tupapprox(aa, bb)
48-
return false
49-
end
50-
end
51-
return true
52-
end
53-
if a isa Array && b isa Array
54-
if size(a) != size(b)
55-
return false
56-
end
57-
for i in length(a)
58-
if !tupapprox(a[i], b[i])
59-
return false
60-
end
61-
end
62-
return true
63-
end
64-
return a b
35+
function mixed_tupapprox(a, b)
36+
if a isa Tuple && b isa Tuple
37+
if length(a) != length(b)
38+
return false
39+
end
40+
for (aa, bb) in zip(a, b)
41+
if !mixed_tupapprox(aa, bb)
42+
return false
43+
end
44+
end
45+
return true
46+
end
47+
if a isa Array && b isa Array
48+
if size(a) != size(b)
49+
return false
50+
end
51+
for i in length(a)
52+
if !mixed_tupapprox(a[i], b[i])
53+
return false
54+
end
55+
end
56+
return true
57+
end
58+
return a b
6559
end
6660

6761
@testset "Mixed Reverse Apply iterate (tuple)" begin
@@ -80,13 +74,13 @@ end
8074
),
8175
]
8276
dx = deepcopy(dx_pre)
83-
Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
84-
@test tupapprox(dx, dx_post)
77+
Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(mixed_metaconcat), Duplicated(x, dx))
78+
@test mixed_tupapprox(dx, dx_post)
8579

8680
dx = deepcopy(dx_pre)
87-
res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
81+
res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(mixed_metaconcat), Duplicated(x, dx))
8882
@test res[2] primal
89-
@test tupapprox(dx, dx_post)
83+
@test mixed_tupapprox(dx, dx_post)
9084
end
9185
end
9286

@@ -110,20 +104,20 @@ end
110104
]
111105
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
112106
dx, dx2 = deepcopy.((dx_pre, dx_pre))
113-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
107+
Enzyme.autodiff(Reverse, mixed_make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(mixed_metaconcat), BatchDuplicated(x, (dx, dx2)))
114108
@test dout[] 0
115109
@test dout2[] 0
116-
@test tupapprox(dx, dx_post)
117-
@test tupapprox(dx2, dx2_post)
110+
@test mixed_tupapprox(dx, dx_post)
111+
@test mixed_tupapprox(dx2, dx2_post)
118112

119113
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
120114
dx, dx2 = deepcopy.((dx_pre, dx_pre))
121-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
115+
Enzyme.autodiff(Reverse, mixed_make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(mixed_metaconcat), BatchDuplicated(x, (dx, dx2)))
122116
@test out[] primal
123117
@test dout[] 0
124118
@test dout2[] 0
125-
@test tupapprox(dx, dx_post)
126-
@test tupapprox(dx2, dx2_post)
119+
@test mixed_tupapprox(dx, dx_post)
120+
@test mixed_tupapprox(dx2, dx2_post)
127121
end
128122
end
129123

@@ -143,13 +137,13 @@ end
143137
),
144138
]
145139
dx = deepcopy(dx_pre)
146-
Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
147-
@test tupapprox(dx, dx_post)
140+
Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(mixed_metaconcat), Duplicated(x, dx))
141+
@test mixed_tupapprox(dx, dx_post)
148142

149143
dx = deepcopy(dx_pre)
150-
res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx))
144+
res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(mixed_metaconcat), Duplicated(x, dx))
151145
@test res[2] primal
152-
@test tupapprox(dx, dx_post)
146+
@test mixed_tupapprox(dx, dx_post)
153147
end
154148
end
155149

@@ -173,20 +167,20 @@ end
173167
]
174168
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
175169
dx, dx2 = deepcopy.((dx_pre, dx_pre))
176-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
170+
Enzyme.autodiff(Reverse, mixed_make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(mixed_metaconcat), BatchDuplicated(x, (dx, dx2)))
177171
@test dout[] 0
178172
@test dout2[] 0
179-
@test tupapprox(dx, dx_post)
180-
@test tupapprox(dx2, dx2_post)
173+
@test mixed_tupapprox(dx, dx_post)
174+
@test mixed_tupapprox(dx2, dx2_post)
181175

182176
out, dout, dout2 = Ref.((out_pre, dout_pre, dout2_pre))
183177
dx, dx2 = deepcopy.((dx_pre, dx_pre))
184-
Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2)))
178+
Enzyme.autodiff(Reverse, mixed_make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(mixed_metaconcat), BatchDuplicated(x, (dx, dx2)))
185179
@test out[] primal
186180
@test dout[] 0
187181
@test dout2[] 0
188-
@test tupapprox(dx, dx_post)
189-
@test tupapprox(dx2, dx2_post)
182+
@test mixed_tupapprox(dx, dx_post)
183+
@test mixed_tupapprox(dx2, dx2_post)
190184
end
191185
end
192186

@@ -196,12 +190,12 @@ struct MyRectilinearGrid5{FT,FZ}
196190
end
197191

198192

199-
@inline flatten_tuple(a::Tuple) = @inbounds a[2:end]
200-
@inline flatten_tuple(a::Tuple{<:Any}) = tuple() #inner_flatten_tuple(a[1])...)
193+
@inline mixediter_flatten_tuple(a::Tuple) = @inbounds a[2:end]
194+
@inline mixediter_flatten_tuple(a::Tuple{<:Any}) = tuple() #inner_mixediter_flatten_tuple(a[1])...)
201195

202196
function myupdate_state!(model)
203197
tupled = Base.inferencebarrier((model,model))
204-
flatten_tuple(tupled)
198+
mixediter_flatten_tuple(tupled)
205199
return nothing
206200
end
207201

0 commit comments

Comments
 (0)