Skip to content

Commit 94ad9dd

Browse files
committed
don't capture a string for each variable
1 parent 8a1832a commit 94ad9dd

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

src/compact.jl

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ macro compact(_exs...)
9595
kwexs = (kwexs1..., kwexs2...)
9696

9797
# make strings
98-
layer = "@compact"
99-
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
10098
input =
10199
try
102100
fex_args = fex.args[1]
@@ -112,7 +110,7 @@ macro compact(_exs...)
112110
fex = supportself(fex, vars)
113111

114112
# assemble
115-
return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
113+
return esc(:($CompactLayer($fex, ($input, $block); $(kwexs...))))
116114
end
117115

118116
function supportself(fex::Expr, vars)
@@ -129,17 +127,18 @@ function supportself(fex::Expr, vars)
129127
end
130128
end
131129

132-
struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
130+
struct CompactLayer{F<:Function, NT<:NamedTuple}
133131
fun::F
134-
strings::NTuple{3,String}
135-
setup_strings::NT1
136-
variables::NT2
132+
strings::NTuple{2,String}
133+
variables::NT
137134
end
138-
CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
139-
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
140-
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
135+
CompactLayer(f::Function, str::Tuple; kw...) = CompactLayer(f, str, NamedTuple(kw))
136+
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro @compact")
137+
141138
Flux.@functor CompactLayer
142139

140+
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
141+
143142
Flux._show_children(m::CompactLayer) = m.variables
144143

145144
function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
@@ -153,16 +152,17 @@ function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
153152
end
154153

155154
function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
156-
setup_strings = obj.setup_strings
157-
layer, input, block = obj.strings
155+
input, block = obj.strings
158156
pre, post = ("(", ")")
159157
println(io, " "^indent, "@compact", pre)
160158
for k in keys(obj.variables)
161159
v = obj.variables[k]
162-
if Flux._show_leaflike(v)
160+
if false # Flux._show_leaflike(v)
163161
# If the value is a leaf, just print verbatim what the user wrote:
164-
str = String(k) * " = " * setup_strings[k]
162+
# str = String(k) * " = " * summary(v)
163+
str = String(k) * " isa " * string(typeof(v))
165164
_just_show_params(io, str, v, indent+2)
165+
# Flux._layer_show(io::IO, str, indent+2, nothing) # doesn't work
166166
else
167167
Flux._big_show(io, v, indent+2, String(k))
168168
end
@@ -187,6 +187,27 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
187187
end
188188
end
189189

190+
# Temporarily fixing things via piracy, but would be an easy change in Flux
191+
using Flux: params, underscorise, _childarray_sum, _nan_show
192+
function Flux._layer_show(io::IO, layer::AbstractArray, indent::Int=0, name=nothing)
193+
_str = isnothing(name) ? "" : "$name = "
194+
# str = _str * sprint(show, layer, context=io) # before
195+
# str = _str * String(typeof(layer).name.name) # print Array
196+
str = _str * summary(layer) # print size too, sometimes too long... trim it?
197+
print(io, " "^indent, str, indent==0 ? "" : ",")
198+
if !isempty(params(layer))
199+
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
200+
printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
201+
color=:light_black)
202+
nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
203+
if nonparam > 0
204+
printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
205+
end
206+
_nan_show(io, params(layer))
207+
end
208+
indent==0 || println(io)
209+
end
210+
190211
# Modified from src/layers/show.jl
191212
function _just_show_params(io::IO, str::String, layer, indent::Int=0)
192213
print(io, " "^indent, str, indent==0 ? "" : ",")

0 commit comments

Comments
 (0)