1- struct Foo
2- x
3- y
4- end
1+
2+ using Functors : functor
3+
4+ struct Foo; x; y; end
55@functor Foo
66
7- struct Bar
8- x
9- end
7+ struct Bar; x; end
108@functor Bar
119
12- struct Baz
13- x
14- y
15- z
16- end
17- @functor Baz (y,)
10+ struct OneChild3; x; y; z; end
11+ @functor OneChild3 (y,)
1812
19- struct NoChildren
20- x
21- y
22- end
13+ struct NoChildren2; x; y; end
2314
2415@static if VERSION >= v" 1.6"
2516 @testset " ComposedFunction" begin
3122 end
3223end
3324
25+ # ##
26+ # ## Basic functionality
27+ # ##
28+
3429@testset " Nested" begin
3530 model = Bar(Foo(1 , [1 , 2 , 3 ]))
3631
5348 @test fmap(f, x; exclude = x -> x isa AbstractArray) == x
5449end
5550
51+ @testset " Property list" begin
52+ model = OneChild3(1 , 2 , 3 )
53+ model′ = fmap(x -> 2 x, model)
54+
55+ @test (model′. x, model′. y, model′. z) == (1 , 4 , 3 )
56+ end
57+
58+ @testset " cache" begin
59+ shared = [1 ,2 ,3 ]
60+ m1 = Foo(shared, Foo([1 ,2 ,3 ], Foo(shared, [1 ,2 ,3 ])))
61+ m1f = fmap(float, m1)
62+ @test m1f. x === m1f. y. y. x
63+ @test m1f. x != = m1f. y. x
64+ m1p = fmapstructure(identity, m1; prune = nothing )
65+ @test m1p == (x = [1 , 2 , 3 ], y = (x = [1 , 2 , 3 ], y = (x = nothing , y = [1 , 2 , 3 ])))
66+
67+ # A non-leaf node can also be repeated:
68+ m2 = Foo(Foo(shared, 4 ), Foo(shared, 4 ))
69+ @test m2. x === m2. y
70+ m2f = fmap(float, m2)
71+ @test m2f. x. x === m2f. y. x
72+ m2p = fmapstructure(identity, m2; prune = Bar(0 ))
73+ @test m2p == (x = (x = [1 , 2 , 3 ], y = 4 ), y = Bar(0 ))
74+
75+ # Repeated isbits types should not automatically be regarded as shared:
76+ m3 = Foo(Foo(shared, 1 : 3 ), Foo(1 : 3 , shared))
77+ m3p = fmapstructure(identity, m3; prune = 0 )
78+ @test m3p. y. y == 0
79+ @test_broken m3p. y. x == 1 : 3
80+ end
81+
82+ @testset " functor(typeof(x), y) from @functor" begin
83+ nt1, re1 = functor(Foo, (x= 1 , y= 2 , z= 3 ))
84+ @test nt1 == (x = 1 , y = 2 )
85+ @test re1((x = 10 , y = 20 )) == Foo(10 , 20 )
86+ re1((y = 22 , x = 11 )) # gives Foo(22, 11), is that a bug?
87+
88+ nt2, re2 = functor(Foo, (z= 33 , x= 1 , y= 2 ))
89+ @test nt2 == (x = 1 , y = 2 )
90+ @test re2((x = 10 , y = 20 )) == Foo(10 , 20 )
91+
92+ @test_throws Exception functor(Foo, (z= 33 , x= 1 )) # type NamedTuple has no field y
93+
94+ nt3, re3 = functor(OneChild3, (x= 1 , y= 2 , z= 3 ))
95+ @test nt3 == (y = 2 ,)
96+ @test re3((y = 20 ,)) == OneChild3(1 , 20 , 3 )
97+ re3(22 ) # gives OneChild3(1, 22, 3), is that a bug?
98+ end
99+
100+ @testset " functor(typeof(x), y) for Base types" begin
101+ nt11, re11 = functor(NamedTuple{(:x, :y)}, (x= 1 , y= 2 , z= 3 ))
102+ @test nt11 == (x = 1 , y = 2 )
103+ @test re11((x = 10 , y = 20 )) == (x = 10 , y = 20 )
104+ re11((y = 22 , x = 11 ))
105+ re11((11 , 22 )) # passes right through
106+
107+ nt12, re12 = functor(NamedTuple{(:x, :y)}, (z= 33 , x= 1 , y= 2 ))
108+ @test nt12 == (x = 1 , y = 2 )
109+ @test re12((x = 10 , y = 20 )) == (x = 10 , y = 20 )
110+
111+ @test_throws Exception functor(NamedTuple{(:x, :y)}, (z= 33 , x= 1 ))
112+ end
113+
114+ # ##
115+ # ## Extras
116+ # ##
117+
56118@testset " Walk" begin
57119 model = Foo((0 , Bar([1 , 2 , 3 ])), [4 , 5 ])
58120
59121 model′ = fmapstructure(identity, model)
60122 @test model′ == (; x= (0 , (; x= [1 , 2 , 3 ])), y= [4 , 5 ])
61123end
62124
63- @testset " Property list" begin
64- model = Baz(1 , 2 , 3 )
65- model′ = fmap(x -> 2 x, model)
66-
67- @test (model′. x, model′. y, model′. z) == (1 , 4 , 3 )
68- end
69-
70125@testset " fcollect" begin
71126 m1 = [1 , 2 , 3 ]
72127 m2 = 1
78133
79134 m1 = [1 , 2 , 3 ]
80135 m2 = Bar(m1)
81- m0 = NoChildren (:a, :b)
136+ m0 = NoChildren2 (:a, :b)
82137 m3 = Foo(m2, m0)
83138 m4 = Bar(m3)
84139 @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
89144 @test all(fcollect(m3) .=== [m3, m1, m2])
90145end
91146
147+ # ##
148+ # ## Vararg forms
149+ # ##
150+
151+ @testset " fmap(f, x, y)" begin
152+ m1 = (x = [1 ,2 ], y = 3 )
153+ n1 = (x = [4 ,5 ], y = 6 )
154+ @test fmap(+ , m1, n1) == (x = [5 , 7 ], y = 9 )
155+
156+ # Reconstruction type comes from the first argument
157+ foo1 = Foo([7 ,8 ], 9 )
158+ @test fmap(+ , m1, foo1) == (x = [8 , 10 ], y = 12 )
159+ @test fmap(+ , foo1, n1) isa Foo
160+ @test fmap(+ , foo1, n1). x == [11 , 13 ]
161+
162+ # Mismatched trees should be an error
163+ m2 = (x = [1 ,2 ], y = (a = [3 ,4 ], b = 5 ))
164+ n2 = (x = [6 ,7 ], y = 8 )
165+ @test_throws Exception fmap(first∘ tuple, m2, n2) # ERROR: type Int64 has no field a
166+ @test_throws Exception fmap(first∘ tuple, m2, n2)
167+
168+ # The cache uses IDs from the first argument
169+ shared = [1 ,2 ,3 ]
170+ m3 = (x = shared, y = [4 ,5 ,6 ], z = shared)
171+ n3 = (x = shared, y = shared, z = [7 ,8 ,9 ])
172+ @test fmap(+ , m3, n3) == (x = [2 , 4 , 6 ], y = [5 , 7 , 9 ], z = [2 , 4 , 6 ])
173+ z3 = fmap(+ , m3, n3)
174+ @test z3. x === z3. z
175+
176+ # Pruning of duplicates:
177+ @test fmap(+ , m3, n3; prune = nothing ) == (x = [2 ,4 ,6 ], y = [5 ,7 ,9 ], z = nothing )
178+
179+ # More than two arguments:
180+ z4 = fmap(+ , m3, n3, m3, n3)
181+ @test z4 == fmap(x -> 2 x, z3)
182+ @test z4. x === z4. z
183+
184+ @test fmap(+ , foo1, m1, n1) isa Foo
185+ @static if VERSION >= v" 1.6" # fails on Julia 1.0
186+ @test fmap(.* , m1, foo1, n1) == (x = [4 * 7 , 2 * 5 * 8 ], y = 3 * 6 * 9 )
187+ end
188+ end
189+
190+ @static if VERSION >= v" 1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope
191+ @testset " old test update.jl" begin
192+ struct M{F,T,S}
193+ σ:: F
194+ W:: T
195+ b:: S
196+ end
197+
198+ @functor M
199+
200+ (m:: M )(x) = m. σ.(m. W * x .+ m. b)
201+
202+ m = M(identity, ones(Float32, 3 , 4 ), zeros(Float32, 3 ))
203+ x = ones(Float32, 4 , 2 )
204+ m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
205+ m̂ = Functors. fmap(m, m̄) do x, y
206+ isnothing(x) && return y
207+ isnothing(y) && return x
208+ x .- 0.1f0 .* y
209+ end
210+
211+ @test m̂. W ≈ fill(0.8f0 , size(m. W))
212+ @test m̂. b ≈ fill(- 0.2f0 , size(m. b))
213+ end
214+ end # VERSION
215+
216+ # ##
217+ # ## FlexibleFunctors.jl
218+ # ##
219+
92220struct FFoo
93221 x
94222 y
@@ -102,13 +230,13 @@ struct FBar
102230end
103231@flexiblefunctor FBar p
104232
105- struct FBaz
233+ struct FOneChild4
106234 x
107235 y
108236 z
109237 p
110238end
111- @flexiblefunctor FBaz p
239+ @flexiblefunctor FOneChild4 p
112240
113241@testset " Flexible Nested" begin
114242 model = FBar(FFoo(1 , [1 , 2 , 3 ], (:y, )), (:x,))
132260end
133261
134262@testset " Flexible Property list" begin
135- model = FBaz (1 , 2 , 3 , (:x, :z))
263+ model = FOneChild4 (1 , 2 , 3 , (:x, :z))
136264 model′ = fmap(x -> 2 x, model)
137265
138266 @test (model′. x, model′. y, model′. z) == (2 , 2 , 6 )
147275 @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
148276 @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])
149277
150- m0 = NoChildren (:a, :b)
278+ m0 = NoChildren2 (:a, :b)
151279 m1 = [1 , 2 , 3 ]
152280 m2 = FBar(m1, ())
153281 m3 = FFoo(m2, m0, (:x, :y,))
0 commit comments