@@ -4,64 +4,66 @@ import Flux: _big_show
44 @compact(forward::Function; name=nothing, parameters...)
55
66Creates a layer by specifying some `parameters`, in the form of keywords,
7- and (usually as a `do` block) a function for the forward pass.
7+ and a function for the forward pass (often as a `do` block).
8+
89You may think of `@compact` as a specialized `let` block creating local variables
910that are trainable in Flux.
1011Declared variable names may be used within the body of the `forward` function.
1112
12- Here is a linear model:
13+ # Examples
14+
15+ Here is a linear model, equivalent to `Flux.Scale`:
1316
1417```
15- r = @compact(w = rand(3)) do x
16- w .* x
17- end
18- r([1, 1, 1]) # x is set to [1, 1, 1].
18+ using Flux, Fluxperimental
19+
20+ w = rand(3)
21+ sc = @compact(x -> x .* w; w)
22+
23+ sc([1 10 100]) # 3×3 Matrix as output.
24+ ans ≈ Flux.Scale(w)([1 10 100]) # equivalent Flux layer
1925```
2026
21- Here is a linear model with bias and activation:
27+ Here is a linear model with bias and activation, equivalent to Flux's `Dense` layer.
28+ The forward pass function is now written as a do block, instead of `x -> begin y = W * x; ...`
2229
2330```
24- d_in = 5
31+ d_in = 3
2532d_out = 7
26- d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
33+ layer = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
2734 y = W * x
2835 act.(y .+ b)
2936end
30- d(ones(5, 10)) # 7×10 Matrix as output.
31- d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5]) # Equivalent to a dense layer
37+
38+ den = Dense(layer.variables.W, zeros(7), relu)  # equivalent Flux layer
39+ layer(ones(3, 10)) ≈ den(ones(3, 10))  # 7×10 Matrix as output.
3240```
33- ```
3441
35- Finally, here is a simple MLP:
42+ Finally, here is a simple MLP, equivalent to a `Chain` with 5 `Dense` layers:
3643
3744```
38- using Flux
39-
40- n_in = 1
41- n_out = 1
45+ d_in = 1
4246nlayers = 3
4347
4448model = @compact(
45- w1=Dense(n_in, 128),
46- w2=[Dense(128, 128) for i=1:nlayers],
47- w3=Dense(128, n_out),
48- act=relu
49+ lay1 = Dense(d_in => 64),
50+ lay234 = [Dense(64 => 64) for i=1:nlayers],
51+ wlast = rand32(64),
4952) do x
50- embed = act(w1(x))
51- for w in w2
52- embed = act(w(embed))
53+ y = tanh.(lay1(x))
54+ for lay in lay234
55+ y = relu.(lay(y))
5356 end
54- out = w3(embed)
55- return out
57+ return wlast' * y
5658end
5759
58- model(randn(n_in, 32))  # 1×32 Matrix as output.
60+ model(randn(Float32, d_in, 8))  # 1×8 array as output.
5961```
6062
61- We can train this model just like any `Chain`:
63+ We can train this model just like any `Chain`, for example:
6264
6365```
64- data = [([x], 2x-x^3) for x in -2:0.1f0:2]
66+ data = [([x], [2x-x^3]) for x in -2:0.1f0:2]
6567optim = Flux.setup(Adam(), model)
6668
6769for epoch in 1:1000
7072```
7173"""
7274macro compact(_exs...)
75+ _compact(_exs...) |> esc
76+ end
77+
78+ function _compact(_exs...)
7379 # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
74- isempty(_exs) && error("expects at least two expressions: a function and at least one keyword")
80+ isempty(_exs) && error("@compact expects at least two expressions: a function and at least one keyword")
7581 if Meta.isexpr(_exs[1], :parameters)
76- length(_exs) >= 2 || error("expects an anonymous function")
82+ length(_exs) >= 2 || error("@compact expects an anonymous function")
7783 fex = _exs[2]
7884 _kwexs = (_exs[1], _exs[3:end]...)
7985 else
8086 fex = _exs[1]
8187 _kwexs = _exs[2:end]
8288 end
83- Meta.isexpr(fex, :(->)) || error("expects an anonymous function")
84- isempty(_kwexs) && error("expects keyword arguments")
85- all(ex -> Meta.isexpr(ex, (:kw, :(=), :parameters)), _kwexs) || error("expects only keyword arguments")
89+ Meta.isexpr(fex, :(->)) || error("@compact expects an anonymous function")
90+ isempty(_kwexs) && error("@compact expects keyword arguments")
91+ all(ex -> Meta.isexpr(ex, (:kw, :(=), :parameters)), _kwexs) || error("@compact expects only keyword arguments")
8692
8793 # process keyword arguments
8894 if Meta. isexpr(_kwexs[1 ], :parameters) # handle keyword arguments provided after semicolon
@@ -100,20 +106,20 @@ macro compact(_exs...)
100106 fex_args = fex.args[1]
101107 isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
102108 catch e
103- @warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
104- ""
109+ @warn """@compact's function stringifying does not yet handle all cases. Falling back to "?" """ maxlog=1
110+ "?"
105111 end
106- block = string(Base.remove_linenums!(fex).args[2])
112+ block = string(Base.remove_linenums!(fex).args[2])  # TODO make this remove macro comments
107113
108114 # edit expressions
109115 vars = map(ex -> ex.args[1], kwexs)
110- fex = supportself(fex, vars)
116+ fex = _supportself(fex, vars)
111117
112118 # assemble
113- return esc(:($CompactLayer($fex, ($input, $block); $(kwexs...))))
119+ return :($CompactLayer($fex, ($input, $block); $(kwexs...)))
114120end
115121
116- function supportself(fex::Expr, vars)
122+ function _supportself(fex::Expr, vars)
117123 @gensym self
118124 @gensym curried_f
119125 # To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
@@ -173,7 +179,7 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
173179 print(io, " " ^ indent, post)
174180 end
175181
176- input != "" && print(io, " do ", input)
182+ print(io, " do ", input)
177183 if block != ""
178184 block_to_print = block[6:end]
179185 # Increase indentation of block according to `indent`:
0 commit comments