@@ -14,13 +14,13 @@ for T in [
1414end
1515
1616function _big_show(io:: IO , obj, indent:: Int = 0 , name= nothing )
17- pre, post = obj isa Chain{ <: AbstractVector } ? ( " ([ " , " ]) " ) : ( " ( " , " ) " )
17+ pre, post = _show_pre_post(obj )
1818 children = _show_children(obj)
1919 if all(_show_leaflike, children)
2020 _layer_show(io, obj, indent, name)
2121 else
22- println(io, " " ^ indent, isnothing(name) ? " " : " $name = " , nameof(typeof(obj)), pre)
23- if obj isa Chain{<: NamedTuple } && children == getfield( obj, :layers)
22+ println(io, " " ^ indent, isnothing(name) ? " " : " $name = " , pre)
23+ if obj isa Chain{<: NamedTuple } || obj isa NamedTuple
2424 # then we insert names -- can this be done more generically?
2525 for k in Base. keys(obj)
2626 _big_show(io, obj[k], indent+ 2 , k)
@@ -44,6 +44,11 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
4444 end
4545end
4646
47+ _show_pre_post(obj) = string(nameof(typeof(obj)), " (" ), " )"
48+ _show_pre_post(:: Chain{<:AbstractVector} ) = " Chain([" , " ])"
49+ _show_pre_post(:: AbstractVector ) = " [" , " ]"
50+ _show_pre_post(:: NamedTuple ) = " (;" , " )"
51+
4752_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
4853
4954# note the covariance of tuple, using <:T causes warning or error
7378
7479function _layer_show(io:: IO , layer, indent:: Int = 0 , name= nothing )
7580 _str = isnothing(name) ? " " : " $name = "
76- str = _str * sprint(show , layer, context = io )
81+ str = _str * _layer_string(io , layer)
7782 print(io, " " ^ indent, str, indent== 0 ? " " : " ," )
7883 if ! isempty(params(layer))
7984 print(io, " " ^ max(2 , (indent== 0 ? 20 : 39 ) - indent - length(str)))
@@ -88,6 +93,12 @@ color=:light_black)
8893 indent== 0 || println(io)
8994end
9095
96+ _layer_string(io:: IO , layer) = sprint(show, layer, context= io)
97+ # _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray
98+ # _layer_string(::IO, a::AbstractArray) = Base.dims2string(size(a)) * " " * String(typeof(a).name.name)
99+ # _layer_string(::IO, a::Array{T}) where T = Base.dims2string(size(a)) * " Array{$T}"
100+ # _layer_string(::IO, a::AbstractArray{T}) where T = Base.dims2string(size(a)) * " AbstractArray{$T}"
101+
91102function _big_finale(io:: IO , m)
92103 ps = params(m)
93104 if length(ps) > 2
@@ -133,3 +144,43 @@ _any(f, x::Number) = f(x)
133144# _any(f, x) = false
134145
135146_all(f, xs) = ! _any(! f, xs)
147+
148+ #=
149+
150+ julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
151+
152+ # Before, notice Array(), NamedTuple(), and values
153+
154+ julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
155+ Chain(
156+ Tmp2(
157+ Array(
158+ Dense(2 => 3), # 9 parameters
159+ [0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
160+ ),
161+ NamedTuple(
162+ 1:3, # 3 parameters
163+ Dense(3 => 4), # 16 parameters
164+ [0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
165+ ),
166+ ),
167+ ) # Total: 7 arrays, 43 parameters, 644 bytes.
168+
169+ # After, (; x=, y=, z=) and "3-element Array"
170+
171+ julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
172+ Chain(
173+ Tmp2(
174+ [
175+ Dense(2 => 3), # 9 parameters
176+ 4×3 Adjoint, # 12 parameters
177+ ],
178+ (;
179+ x = 3-element UnitRange, # 3 parameters
180+ y = Dense(3 => 4), # 16 parameters
181+ z = 3-element Array, # 3 parameters
182+ ),
183+ ),
184+ ) # Total: 7 arrays, 43 parameters, 644 bytes.
185+
186+ =#
0 commit comments