
Commit 753127d

Remove @compact(name=...) and replace with NoShow (#19)
Co-authored-by: Gaurav Arya <gauravarya272@gmail.com>
1 parent d3738a1 commit 753127d
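In short: the `name` keyword of `@compact` is gone, and wrapping a model in `NoShow` now controls the printout. A minimal before/after sketch, reconstructed from the removed docstring and the new tests in this commit (the `"Linear(3 => 1)"` label is just an example):

```julia
using Flux, Fluxperimental

# Before this commit (now removed): custom printout via the `name` keyword.
# model = @compact(w=rand(3), name="Linear(3 => 1)") do x
#     sum(w .* x)
# end

# After this commit: build the layer as usual, then wrap it in NoShow.
inner = @compact(w = rand(3)) do x
    sum(w .* x)
end
model = NoShow("Linear(3 => 1)", inner)

println(model)  # prints "Linear(3 => 1)" instead of the verbatim @compact code
```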

6 files changed: +110 additions, −83 deletions


src/Fluxperimental.jl

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@ include("chain.jl")
 
 include("compact.jl")
 
+include("noshow.jl")
+export NoShow
+
 include("new_recur.jl")
 
 end # module Fluxperimental

src/compact.jl

Lines changed: 3 additions & 38 deletions
@@ -68,20 +68,7 @@ for epoch in 1:1000
   Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
 end
 ```
-
-You may also specify a `name` for the model, which will
-be used instead of the default printout, which gives a verbatim
-representation of the code used to construct the model:
-
-```
-model = @compact(w=rand(3), name="Linear(3 => 1)") do x
-    sum(w .* x)
-end
-println(model) # "Linear(3 => 1)"
-```
-
-This can be useful when using `@compact` to hierarchically construct
-complex models to be used inside a `Chain`.
+To specify a custom printout for the model, you may find [`NoShow`](@ref) useful.
 """
 macro compact(_exs...)
   # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
@@ -108,16 +95,6 @@ macro compact(_exs...)
   kwexs2 = map(ex -> Expr(:kw, ex.args...), _kwexs) # handle keyword arguments provided before semicolon
   kwexs = (kwexs1..., kwexs2...)
 
-  # check if user has named layer:
-  name = findfirst(ex -> ex.args[1] == :name, kwexs)
-  if name !== nothing && kwexs[name].args[2] !== nothing
-    length(kwexs) == 1 && error("expects keyword arguments")
-    name_str = kwexs[name].args[2]
-    # remove name from kwexs (a tuple)
-    kwexs = (kwexs[1:name-1]..., kwexs[name+1:end]...)
-    name = name_str
-  end
-
   # make strings
   layer = "@compact"
   setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
@@ -136,7 +113,7 @@ macro compact(_exs...)
   fex = supportself(fex, vars)
 
   # assemble
-  return esc(:($CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
+  return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
 end
 
 function supportself(fex::Expr, vars)
@@ -155,12 +132,11 @@ end
 
 struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
   fun::F
-  name::Union{String,Nothing}
   strings::NTuple{3,String}
   setup_strings::NT1
   variables::NT2
 end
-CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
+CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
 (m::CompactLayer)(x...) = m.fun(m.variables, x...)
 CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
 Flux.@functor CompactLayer
@@ -179,16 +155,6 @@ end
 
 function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
   setup_strings = obj.setup_strings
-  local_name = obj.name
-  has_explicit_name = local_name !== nothing
-  if has_explicit_name
-    if indent != 0 || length(Flux.params(obj)) <= 2
-      _just_show_params(io, local_name, obj, indent)
-    else # indent == 0
-      print(io, local_name)
-      Flux._big_finale(io, obj)
-    end
-  else # no name, so print normally
   layer, input, block = obj.strings
   pre, post = ("(", ")")
   println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
@@ -220,7 +186,6 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
   else
     println(io, ",")
   end
-  end
 end
 
 # Modified from src/layers/show.jl

src/noshow.jl

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+
+"""
+    NoShow(layer)
+    NoShow(string, layer)
+
+This alters printing (for instance at the REPL prompt) to let you hide the complexity
+of some part of a Flux model. It has no effect on the actual running of the model.
+
+By default it prints `NoShow(...)` instead of the given layer.
+If you provide a string, it prints that instead -- it can be anything,
+but it may make sense to print the name of a function which will
+re-create the same structure.
+
+# Examples
+
+```jldoctest
+julia> Chain(Dense(2 => 3), NoShow(Parallel(vcat, Dense(3 => 4), Dense(3 => 5))), Dense(9 => 10))
+Chain(
+  Dense(2 => 3),        # 9 parameters
+  NoShow(...),          # 36 parameters
+  Dense(9 => 10),       # 100 parameters
+)                   # Total: 8 arrays, 145 parameters, 1.191 KiB.
+
+julia> pseudolayer((i,o)::Pair) = NoShow(
+           "pseudolayer(\$i => \$o)",
+           Parallel(+, Dense(i => o, relu), Dense(i => o, tanh)),
+       )
+pseudolayer (generic function with 1 method)
+
+julia> Chain(Dense(2 => 3), pseudolayer(3 => 10), Dense(9 => 10))
+Chain(
+  Dense(2 => 3),          # 9 parameters
+  pseudolayer(3 => 10),   # 80 parameters
+  Dense(9 => 10),         # 100 parameters
+)                   # Total: 8 arrays, 189 parameters, 1.379 KiB.
+```
+"""
+struct NoShow{T}
+  str::String
+  layer::T
+end
+
+NoShow(layer) = NoShow("NoShow(...)", layer)
+
+Flux.@functor NoShow
+
+(no::NoShow)(x...) = no.layer(x...)
+
+Base.show(io::IO, no::NoShow) = print(io, no.str)
+
+Flux._show_leaflike(::NoShow) = true  # I think this is right
+Flux._show_children(::NoShow) = (;)   # Seems to be needed?
+
+function Base.show(io::IO, ::MIME"text/plain", m::NoShow)
+  if get(io, :typeinfo, nothing) === nothing  # e.g., top level of REPL
+    Flux._big_show(io, m)
+  elseif !get(io, :compact, false)  # e.g., printed inside a Vector, but not a matrix
+    Flux._layer_show(io, m)
+  else
+    show(io, m)
+  end
+end
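The new file keeps `NoShow` out of the way of training: `(no::NoShow)(x...) = no.layer(x...)` forwards calls and `Flux.@functor` exposes the wrapped parameters. A small sketch of that behaviour, in the spirit of the `test/noshow.jl` file added below (the layer sizes and the `"hidden"` label here are arbitrary):

```julia
using Flux, Fluxperimental

plain   = Dense(3 => 4, tanh)
wrapped = NoShow("hidden", plain)   # same layer object, different printout

x = randn(Float32, 3, 5)
@assert wrapped(x) == plain(x)      # forward pass is untouched

# Gradients flow through as well; the wrapped gradient simply nests under `.layer`.
g_plain   = gradient(m -> sum(m(x)), plain)[1]
g_wrapped = gradient(m -> sum(m(x)), wrapped)[1]
@assert g_wrapped.layer.weight ≈ g_plain.weight
```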

test/compact.jl

Lines changed: 13 additions & 45 deletions
@@ -101,7 +101,7 @@ end
     (1, 128),
     (1,),
   ]
-    @test size(model(randn(n_in, 32))) == (1, 32)
+    @test size(model(randn(Float32, n_in, 32))) == (1, 32)
   end
 
   @testset "String representations" begin
@@ -118,15 +118,6 @@ end
     @test similar_strings(get_model_string(model), expected_string)
   end
 
-  @testset "Custom naming" begin
-    model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
-      tmp = sum(w(x))
-      return tmp + y
-    end
-    expected_string = "Linear(...) # 1_056 parameters"
-    @test similar_strings(get_model_string(model), expected_string)
-  end
-
   @testset "Hierarchical models" begin
     model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
       w2(w1(x))
@@ -161,41 +152,6 @@ end
     @test similar_strings(get_model_string(model), expected_string)
   end
 
-  @testset "Hierarchy with inner model named" begin
-    model = @compact(
-      w1=@compact(w1=randn(32, 32), name="Model(32)") do x
-        w1 * x
-      end,
-      w2=randn(32, 32),
-      w3=randn(32),
-    ) do x
-      w2 * w1(x)
-    end
-    expected_string = """@compact(
-      Model(32), # 1_024 parameters
-      w2 = randn(32, 32), # 1_024 parameters
-      w3 = randn(32), # 32 parameters
-    ) do x
-      w2 * w1(x)
-    end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
-    @test similar_strings(get_model_string(model), expected_string)
-  end
-
-  @testset "Hierarchy with outer model named" begin
-    model = @compact(
-      w1=@compact(w1=randn(32, 32)) do x
-        w1 * x
-      end,
-      w2=randn(32, 32),
-      w3=randn(32),
-      name="Model(32)"
-    ) do x
-      w2 * w1(x)
-    end
-    expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
-    @test similar_strings(get_model_string(model), expected_string)
-  end
-
   @testset "Dependent initializations" begin
     # Test that initialization lines cannot depend on each other
     @test_throws UndefVarError @compact(y = 3, z = y^2) do x
@@ -234,3 +190,15 @@ end
   end
 end
 
+
+@testset "Custom naming of @compact with NoShow" begin
+  _model = @compact(w=Dense(32, 32)) do x, y
+    tmp = sum(w(x))
+    return tmp + y
+  end
+  model = NoShow(_model)
+  expected_string = "NoShow(...) # 1_056 parameters"
+  @test similar_strings(get_model_string(model), expected_string)
+  model2 = NoShow("test", _model)
+  @test contains(get_model_string(model2), "test")
+end

test/noshow.jl

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+
+@testset "NoShow" begin
+  d23 = Dense(2 => 3)
+  d34 = Dense(3 => 4, tanh)
+  d35 = Dense(3 => 5, relu)
+  d910 = Dense(9 => 10)
+
+  model = Chain(d23, Parallel(vcat, d34, d35), d910)
+  m_no = Chain(d23, NoShow(Parallel(vcat, d34, NoShow("zzz", d35))), d910)
+
+  @test sum(length, Flux.params(model)) == sum(length, Flux.params(m_no))
+
+  xin = randn(Float32, 2, 7)
+  @test model(xin) ≈ m_no(xin)
+
+  # gradients
+  grad = gradient(m -> m(xin)[1], model)[1]
+  g_no = gradient(m -> m(xin)[1], m_no)[1]
+
+  @test grad.layers[2].layers[1].bias ≈ g_no.layers[2].layer.layers[1].bias
+  @test grad.layers[2].layers[2].bias ≈ g_no.layers[2].layer.layers[2].layer.bias
+
+  # printing -- see also compact.jl for another test
+  @test !contains(string(model), "NoShow(...)")
+  @test contains(string(m_no), "NoShow(...)")
+  @test !contains(string(m_no), "3 => 4")
+end
+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ using Flux, Fluxperimental
 include("chain.jl")
 
 include("compact.jl")
+include("noshow.jl")
 
 include("new_recur.jl")
 