@@ -83,9 +83,18 @@ println(model) # "Linear(3 => 1)"
8383This can be useful when using `@compact` to hierarchically construct
8484complex models to be used inside a `Chain`.
8585"""
86- macro compact(fex, _kwexs... )
87- # check inputs
88- Meta. isexpr(fex, :(-> )) || error(" expects a do block" )
86+ macro compact(_exs... )
87+ # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
88+ isempty(_exs) && error(" expects at least two expressions: a function and at least one keyword" )
89+ if Meta. isexpr(_exs[1 ], :parameters)
90+ length(_exs) >= 2 || error(" expects an anonymous function" )
91+ fex = _exs[2 ]
92+ _kwexs = (_exs[1 ], _exs[3 : end ]. .. )
93+ else
94+ fex = _exs[1 ]
95+ _kwexs = _exs[2 : end ]
96+ end
97+ Meta. isexpr(fex, :(-> )) || error(" expects an anonymous function" )
8998 isempty(_kwexs) && error(" expects keyword arguments" )
9099 all(ex -> Meta. isexpr(ex, (:kw,:(= ),:parameters)), _kwexs) || error(" expects only keyword arguments" )
91100
@@ -112,33 +121,37 @@ macro compact(fex, _kwexs...)
112121 # make strings
113122 layer = " @compact"
114123 setup = NamedTuple(map(ex -> Symbol(string(ex. args[1 ])) => string(ex. args[2 ]), kwexs))
115- input = join(fex. args[1 ]. args, " , " )
124+ input =
125+ try
126+ fex_args = fex. args[1 ]
127+ isa(fex_args, Symbol) ? string(fex_args) : join(fex_args. args, " , " )
128+ catch e
129+ @warn " Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
130+ " "
131+ end
116132 block = string(Base. remove_linenums!(fex). args[2 ])
117133
118134 # edit expressions
119135 vars = map(ex -> ex. args[1 ], kwexs)
120- @gensym self
121- pushfirst!(fex. args[1 ]. args, self)
122- addprefix!(fex, self, vars)
136+ fex = supportself(fex, vars)
123137
124138 # assemble
125- return esc(quote
126- let
127- $ CompactLayer($ fex, $ name, ($ layer, $ input, $ block), $ setup; $ (kwexs... ))
128- end
129- end )
139+ return esc(:($ CompactLayer($ fex, $ name, ($ layer, $ input, $ block), $ setup; $ (kwexs... ))))
130140end
131141
132- function addprefix!(ex:: Expr , self, vars)
133- for i = 1 : length(ex. args)
134- if ex. args[i] in vars
135- ex. args[i] = :($ self.$ (ex. args[i]))
136- else
137- addprefix!(ex. args[i], self, vars)
142+ function supportself(fex:: Expr , vars)
143+ @gensym self
144+ @gensym curried_f
145+ # To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
146+ # that wraps the full fex expression, and then uncurry it programatically rather than syntactically.
147+ let_exprs = map(var -> :($ var = $ self.$ var), vars)
148+ return quote
149+ $ curried_f = ($ self) -> let $ (let_exprs... )
150+ $ fex
138151 end
152+ ($ self, args... ; kwargs... ) -> $ curried_f($ self)(args... ; kwargs... )
139153 end
140154end
141- addprefix!(not_ex, self, vars) = nothing
142155
143156struct CompactLayer{F,NT1<: NamedTuple ,NT2<: NamedTuple }
144157 fun:: F
0 commit comments