11using Enzyme, Test
22
3- concat () = ()
4- concat (a) = a
5- concat (a, b) = (a... , b... )
6- concat (a, b, c... ) = concat ( concat (a, b), c... )
3+ mixed_concat () = ()
4+ mixed_concat (a) = a
5+ mixed_concat (a, b) = (a... , b... )
6+ mixed_concat (a, b, c... ) = mixed_concat ( mixed_concat (a, b), c... )
77
8- metaconcat (x) = concat (x... )
9-
10- metaconcat2 (x, y) = concat (x... , y... )
11-
12- midconcat (x, y) = (x, concat (y... )... )
13-
14- metaconcat3 (x, y, z) = concat (x... , y... , z... )
8+ mixed_metaconcat (x) = mixed_concat (x... )
159
1610function mixed_metasumsq (f, args... )
1711 res = 0.0
@@ -33,35 +27,35 @@ function mixed_metasumsq3(f, args...)
3327 return res
3428end
3529
36- function make_byref (out, fn, args... )
30+ function mixed_make_byref (out, fn, args... )
3731 out[] = fn (args... )
3832 nothing
3933end
4034
41- function tupapprox (a, b)
42- if a isa Tuple && b isa Tuple
43- if length (a) != length (b)
44- return false
45- end
46- for (aa, bb) in zip (a, b)
47- if ! tupapprox (aa, bb)
48- return false
49- end
50- end
51- return true
52- end
53- if a isa Array && b isa Array
54- if size (a) != size (b)
55- return false
56- end
57- for i in length (a)
58- if ! tupapprox (a[i], b[i])
59- return false
60- end
61- end
62- return true
63- end
64- return a ≈ b
35+ function mixed_tupapprox (a, b)
36+ if a isa Tuple && b isa Tuple
37+ if length (a) != length (b)
38+ return false
39+ end
40+ for (aa, bb) in zip (a, b)
41+ if ! mixed_tupapprox (aa, bb)
42+ return false
43+ end
44+ end
45+ return true
46+ end
47+ if a isa Array && b isa Array
48+ if size (a) != size (b)
49+ return false
50+ end
51+ for i in length (a)
52+ if ! mixed_tupapprox (a[i], b[i])
53+ return false
54+ end
55+ end
56+ return true
57+ end
58+ return a ≈ b
6559end
6660
6761@testset " Mixed Reverse Apply iterate (tuple)" begin
8074 ),
8175 ]
8276 dx = deepcopy (dx_pre)
83- Enzyme. autodiff (Reverse, mixed_metasumsq, Active, Const (metaconcat ), Duplicated (x, dx))
84- @test tupapprox (dx, dx_post)
77+ Enzyme. autodiff (Reverse, mixed_metasumsq, Active, Const (mixed_metaconcat ), Duplicated (x, dx))
78+ @test mixed_tupapprox (dx, dx_post)
8579
8680 dx = deepcopy (dx_pre)
87- res = Enzyme. autodiff (ReverseWithPrimal, mixed_metasumsq, Active, Const (metaconcat ), Duplicated (x, dx))
81+ res = Enzyme. autodiff (ReverseWithPrimal, mixed_metasumsq, Active, Const (mixed_metaconcat ), Duplicated (x, dx))
8882 @test res[2 ] ≈ primal
89- @test tupapprox (dx, dx_post)
83+ @test mixed_tupapprox (dx, dx_post)
9084 end
9185end
9286
@@ -110,20 +104,20 @@ end
110104 ]
111105 out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
112106 dx, dx2 = deepcopy .((dx_pre, dx_pre))
113- Enzyme. autodiff (Reverse, make_byref , Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (mixed_metasumsq), Const (metaconcat ), BatchDuplicated (x, (dx, dx2)))
107+ Enzyme. autodiff (Reverse, mixed_make_byref , Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (mixed_metasumsq), Const (mixed_metaconcat ), BatchDuplicated (x, (dx, dx2)))
114108 @test dout[] ≈ 0
115109 @test dout2[] ≈ 0
116- @test tupapprox (dx, dx_post)
117- @test tupapprox (dx2, dx2_post)
110+ @test mixed_tupapprox (dx, dx_post)
111+ @test mixed_tupapprox (dx2, dx2_post)
118112
119113 out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
120114 dx, dx2 = deepcopy .((dx_pre, dx_pre))
121- Enzyme. autodiff (Reverse, make_byref , Const, BatchDuplicated (out, (dout, dout2)), Const (mixed_metasumsq), Const (metaconcat ), BatchDuplicated (x, (dx, dx2)))
115+ Enzyme. autodiff (Reverse, mixed_make_byref , Const, BatchDuplicated (out, (dout, dout2)), Const (mixed_metasumsq), Const (mixed_metaconcat ), BatchDuplicated (x, (dx, dx2)))
122116 @test out[] ≈ primal
123117 @test dout[] ≈ 0
124118 @test dout2[] ≈ 0
125- @test tupapprox (dx, dx_post)
126- @test tupapprox (dx2, dx2_post)
119+ @test mixed_tupapprox (dx, dx_post)
120+ @test mixed_tupapprox (dx2, dx2_post)
127121 end
128122end
129123
@@ -143,13 +137,13 @@ end
143137 ),
144138 ]
145139 dx = deepcopy (dx_pre)
146- Enzyme. autodiff (Reverse, mixed_metasumsq, Active, Const (metaconcat ), Duplicated (x, dx))
147- @test tupapprox (dx, dx_post)
140+ Enzyme. autodiff (Reverse, mixed_metasumsq, Active, Const (mixed_metaconcat ), Duplicated (x, dx))
141+ @test mixed_tupapprox (dx, dx_post)
148142
149143 dx = deepcopy (dx_pre)
150- res = Enzyme. autodiff (ReverseWithPrimal, mixed_metasumsq, Active, Const (metaconcat ), Duplicated (x, dx))
144+ res = Enzyme. autodiff (ReverseWithPrimal, mixed_metasumsq, Active, Const (mixed_metaconcat ), Duplicated (x, dx))
151145 @test res[2 ] ≈ primal
152- @test tupapprox (dx, dx_post)
146+ @test mixed_tupapprox (dx, dx_post)
153147 end
154148end
155149
@@ -173,20 +167,20 @@ end
173167 ]
174168 out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
175169 dx, dx2 = deepcopy .((dx_pre, dx_pre))
176- Enzyme. autodiff (Reverse, make_byref , Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (mixed_metasumsq), Const (metaconcat ), BatchDuplicated (x, (dx, dx2)))
170+ Enzyme. autodiff (Reverse, mixed_make_byref , Const, BatchDuplicatedNoNeed (out, (dout, dout2)), Const (mixed_metasumsq), Const (mixed_metaconcat ), BatchDuplicated (x, (dx, dx2)))
177171 @test dout[] ≈ 0
178172 @test dout2[] ≈ 0
179- @test tupapprox (dx, dx_post)
180- @test tupapprox (dx2, dx2_post)
173+ @test mixed_tupapprox (dx, dx_post)
174+ @test mixed_tupapprox (dx2, dx2_post)
181175
182176 out, dout, dout2 = Ref .((out_pre, dout_pre, dout2_pre))
183177 dx, dx2 = deepcopy .((dx_pre, dx_pre))
184- Enzyme. autodiff (Reverse, make_byref , Const, BatchDuplicated (out, (dout, dout2)), Const (mixed_metasumsq), Const (metaconcat ), BatchDuplicated (x, (dx, dx2)))
178+ Enzyme. autodiff (Reverse, mixed_make_byref , Const, BatchDuplicated (out, (dout, dout2)), Const (mixed_metasumsq), Const (mixed_metaconcat ), BatchDuplicated (x, (dx, dx2)))
185179 @test out[] ≈ primal
186180 @test dout[] ≈ 0
187181 @test dout2[] ≈ 0
188- @test tupapprox (dx, dx_post)
189- @test tupapprox (dx2, dx2_post)
182+ @test mixed_tupapprox (dx, dx_post)
183+ @test mixed_tupapprox (dx2, dx2_post)
190184 end
191185end
192186
@@ -196,12 +190,12 @@ struct MyRectilinearGrid5{FT,FZ}
196190end
197191
198192
199- @inline flatten_tuple (a:: Tuple ) = @inbounds a[2 : end ]
200- @inline flatten_tuple (a:: Tuple{<:Any} ) = tuple () # inner_flatten_tuple (a[1])...)
193+ @inline mixediter_flatten_tuple (a:: Tuple ) = @inbounds a[2 : end ]
194+ @inline mixediter_flatten_tuple (a:: Tuple{<:Any} ) = tuple () # inner_mixediter_flatten_tuple (a[1])...)
201195
202196function myupdate_state! (model)
203197 tupled = Base. inferencebarrier ((model,model))
204- flatten_tuple (tupled)
198+ mixediter_flatten_tuple (tupled)
205199 return nothing
206200end
207201
0 commit comments