Skip to content

Commit 8a1832a

Browse files
committed
remove name from at-compact
1 parent cc0e36f commit 8a1832a

File tree

2 files changed

+25
-83
lines changed

2 files changed

+25
-83
lines changed

src/compact.jl

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,6 @@ for epoch in 1:1000
6868
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
6969
end
7070
```
71-
72-
You may also specify a `name` for the model, which will
73-
be used instead of the default printout, which gives a verbatim
74-
representation of the code used to construct the model:
75-
76-
```
77-
model = @compact(w=rand(3), name="Linear(3 => 1)") do x
78-
sum(w .* x)
79-
end
80-
println(model) # "Linear(3 => 1)"
81-
```
82-
83-
This can be useful when using `@compact` to hierarchically construct
84-
complex models to be used inside a `Chain`.
8571
"""
8672
macro compact(_exs...)
8773
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
@@ -108,16 +94,6 @@ macro compact(_exs...)
10894
kwexs2 = map(ex -> Expr(:kw, ex.args...), _kwexs) # handle keyword arguments provided before semicolon
10995
kwexs = (kwexs1..., kwexs2...)
11096

111-
# check if user has named layer:
112-
name = findfirst(ex -> ex.args[1] == :name, kwexs)
113-
if name !== nothing && kwexs[name].args[2] !== nothing
114-
length(kwexs) == 1 && error("expects keyword arguments")
115-
name_str = kwexs[name].args[2]
116-
# remove name from kwexs (a tuple)
117-
kwexs = (kwexs[1:name-1]..., kwexs[name+1:end]...)
118-
name = name_str
119-
end
120-
12197
# make strings
12298
layer = "@compact"
12399
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
@@ -136,7 +112,7 @@ macro compact(_exs...)
136112
fex = supportself(fex, vars)
137113

138114
# assemble
139-
return esc(:($CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
115+
return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
140116
end
141117

142118
function supportself(fex::Expr, vars)
@@ -155,12 +131,11 @@ end
155131

156132
struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
157133
fun::F
158-
name::Union{String,Nothing}
159134
strings::NTuple{3,String}
160135
setup_strings::NT1
161136
variables::NT2
162137
end
163-
CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
138+
CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
164139
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
165140
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
166141
Flux.@functor CompactLayer
@@ -179,19 +154,9 @@ end
179154

180155
function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
181156
setup_strings = obj.setup_strings
182-
local_name = obj.name
183-
has_explicit_name = local_name !== nothing
184-
if has_explicit_name
185-
if indent != 0 || length(Flux.params(obj)) <= 2
186-
_just_show_params(io, local_name, obj, indent)
187-
else # indent == 0
188-
print(io, local_name)
189-
Flux._big_finale(io, obj)
190-
end
191-
else # no name, so print normally
192157
layer, input, block = obj.strings
193158
pre, post = ("(", ")")
194-
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
159+
println(io, " "^indent, "@compact", pre)
195160
for k in keys(obj.variables)
196161
v = obj.variables[k]
197162
if Flux._show_leaflike(v)
@@ -220,7 +185,6 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
220185
else
221186
println(io, ",")
222187
end
223-
end
224188
end
225189

226190
# Modified from src/layers/show.jl

test/compact.jl

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,6 @@ end
118118
@test similar_strings(get_model_string(model), expected_string)
119119
end
120120

121-
@testset "Custom naming" begin
122-
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
123-
tmp = sum(w(x))
124-
return tmp + y
125-
end
126-
expected_string = "Linear(...) # 1_056 parameters"
127-
@test similar_strings(get_model_string(model), expected_string)
128-
end
129-
130121
@testset "Hierarchical models" begin
131122
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
132123
w2(w1(x))
@@ -148,6 +139,28 @@ end
148139
@test similar_strings(get_model_string(model2), expected_string)
149140
end
150141

142+
#= # This test is broken:
143+
144+
julia> model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
145+
w2(w1(x));
146+
147+
julia> model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
148+
w2(w1(x))
149+
end
150+
@compact(
151+
@compact(
152+
w1 = Dense(32 => 32, relu), # 1_056 parameters
153+
w2 = Dense(32 => 32, relu), # 1_056 parameters
154+
) do x
155+
w2(w1(x))
156+
end,
157+
w2 = Dense(32 => 32, relu), # 1_056 parameters
158+
) do x
159+
w2(w1(x))
160+
end # Total: 6 arrays, 3_168 parameters, 13.239 KiB.
161+
162+
=#
163+
151164
@testset "Array parameters" begin
152165
model = @compact(x=randn(32), w=Dense(32=>32)) do s
153166
w(x .* s)
@@ -161,41 +174,6 @@ end
161174
@test similar_strings(get_model_string(model), expected_string)
162175
end
163176

164-
@testset "Hierarchy with inner model named" begin
165-
model = @compact(
166-
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
167-
w1 * x
168-
end,
169-
w2=randn(32, 32),
170-
w3=randn(32),
171-
) do x
172-
w2 * w1(x)
173-
end
174-
expected_string = """@compact(
175-
Model(32), # 1_024 parameters
176-
w2 = randn(32, 32), # 1_024 parameters
177-
w3 = randn(32), # 32 parameters
178-
) do x
179-
w2 * w1(x)
180-
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
181-
@test similar_strings(get_model_string(model), expected_string)
182-
end
183-
184-
@testset "Hierarchy with outer model named" begin
185-
model = @compact(
186-
w1=@compact(w1=randn(32, 32)) do x
187-
w1 * x
188-
end,
189-
w2=randn(32, 32),
190-
w3=randn(32),
191-
name="Model(32)"
192-
) do x
193-
w2 * w1(x)
194-
end
195-
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
196-
@test similar_strings(get_model_string(model), expected_string)
197-
end
198-
199177
@testset "Dependent initializations" begin
200178
# Test that initialization lines cannot depend on each other
201179
@test_throws UndefVarError @compact(y = 3, z = y^2) do x

0 commit comments

Comments
 (0)