Skip to content

Commit cc0e36f

Browse files
gaurav-aryaMilesCranmermcabbott
authored
Scope self arguments using let block syntax (#17)
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent 514dff6 commit cc0e36f

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

src/compact.jl

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,18 @@ println(model) # "Linear(3 => 1)"
8383
This can be useful when using `@compact` to hierarchically construct
8484
complex models to be used inside a `Chain`.
8585
"""
86-
macro compact(fex, _kwexs...)
87-
# check inputs
88-
Meta.isexpr(fex, :(->)) || error("expects a do block")
86+
macro compact(_exs...)
87+
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
88+
isempty(_exs) && error("expects at least two expressions: a function and at least one keyword")
89+
if Meta.isexpr(_exs[1], :parameters)
90+
length(_exs) >= 2 || error("expects an anonymous function")
91+
fex = _exs[2]
92+
_kwexs = (_exs[1], _exs[3:end]...)
93+
else
94+
fex = _exs[1]
95+
_kwexs = _exs[2:end]
96+
end
97+
Meta.isexpr(fex, :(->)) || error("expects an anonymous function")
8998
isempty(_kwexs) && error("expects keyword arguments")
9099
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")
91100

@@ -112,33 +121,37 @@ macro compact(fex, _kwexs...)
112121
# make strings
113122
layer = "@compact"
114123
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
115-
input = join(fex.args[1].args, ", ")
124+
input =
125+
try
126+
fex_args = fex.args[1]
127+
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
128+
catch e
129+
@warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
130+
""
131+
end
116132
block = string(Base.remove_linenums!(fex).args[2])
117133

118134
# edit expressions
119135
vars = map(ex -> ex.args[1], kwexs)
120-
@gensym self
121-
pushfirst!(fex.args[1].args, self)
122-
addprefix!(fex, self, vars)
136+
fex = supportself(fex, vars)
123137

124138
# assemble
125-
return esc(quote
126-
let
127-
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))
128-
end
129-
end)
139+
return esc(:($CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
130140
end
131141

132-
function addprefix!(ex::Expr, self, vars)
133-
for i = 1:length(ex.args)
134-
if ex.args[i] in vars
135-
ex.args[i] = :($self.$(ex.args[i]))
136-
else
137-
addprefix!(ex.args[i], self, vars)
142+
function supportself(fex::Expr, vars)
143+
@gensym self
144+
@gensym curried_f
145+
# To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
146+
# that wraps the full fex expression, and then uncurry it programatically rather than syntactically.
147+
let_exprs = map(var -> :($var = $self.$var), vars)
148+
return quote
149+
$curried_f = ($self) -> let $(let_exprs...)
150+
$fex
138151
end
152+
($self, args...; kwargs...) -> $curried_f($self)(args...; kwargs...)
139153
end
140154
end
141-
addprefix!(not_ex, self, vars) = nothing
142155

143156
struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
144157
fun::F

test/compact.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,5 +212,25 @@ end
212212
end
213213
@test model(2) == _a + _b * 2 + c * 2^2
214214
end
215+
216+
@testset "Keyword arguments with anonymous function" begin
217+
model = @test_nowarn @compact(x -> x+a+b; a=1, b=2)
218+
@test model(3) == 1 + 2 + 3
219+
expected_string = """@compact(
220+
a = 1,
221+
b = 2,
222+
) do x
223+
x + a + b
224+
end"""
225+
@test similar_strings(get_model_string(model), expected_string)
226+
end
227+
228+
@testset "Scoping of parameter arguments" begin
229+
model = @compact(w1 = 3, w2 = 5) do a
230+
g(w1, w2) = 2 * w1 * w2
231+
return (w1 + w2) * g(a, a)
232+
end
233+
@test model(2) == (3 + 5) * 2 * 2 * 2
234+
end
215235
end
216236

0 commit comments

Comments
 (0)