From 4d57436581ef470a532f6798606d24c500a4f57e Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 14 Sep 2023 14:30:44 +0530 Subject: [PATCH 1/3] Fixing `show` to not be confused by shared parameters. --- src/layers/show.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 0ae14dd9ee..c01a887166 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -93,7 +93,8 @@ function _big_finale(io::IO, m) if length(ps) > 2 pars = underscorise(sum(length, ps; init=0)) bytes = Base.format_bytes(Base.summarysize(m)) - noncnt = _childarray_sum(_->1, m) - length(ps) + unique_params = IdSet() + noncnt = _childarray_sum(x -> unique_param!(x, unique_params), m) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps; init=0)) printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) @@ -116,6 +117,15 @@ init=0) underscorise(n::Integer) = join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') +function unique_param!(x::AbstractArray{<:Number}, idset::Base.IdSet) + if x in idset + 0 + else + push!(idset, x) + 1 + end +end + function _nan_show(io::IO, x) if !isempty(x) && _all(iszero, x) printstyled(io, " (all zero)", color=:cyan) From 918da322f59ad9ef1eee71d92ab961184b3621fd Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 22 Sep 2023 02:27:32 +0530 Subject: [PATCH 2/3] Fixing implementation of `_big_finale` by adding two more methods for `_childarray_sum`. --- src/layers/show.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index c01a887166..36d9be63b4 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -71,7 +71,7 @@ for T in [ end end -function _layer_show(io::IO, layer, indent::Int=0, name=nothing) +function _layer_show(io::IO, layer, indent::Int=1, name=nothing) _str = isnothing(name) ? "" : "$name = " str = _str * sprint(show, layer, context=io) print(io, " "^indent, str, indent==0 ? "" : ",") @@ -94,7 +94,7 @@ function _big_finale(io::IO, m) pars = underscorise(sum(length, ps; init=0)) bytes = Base.format_bytes(Base.summarysize(m)) unique_params = IdSet() - noncnt = _childarray_sum(x -> unique_param!(x, unique_params), m) - length(ps) + noncnt = _childarray_sum(_ -> 1, m, unique_params) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps; init=0)) printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) @@ -111,21 +111,23 @@ end _childarray_sum(f, x::AbstractArray{<:Number}) = f(x) _childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x), init=0) +_childarray_sum(f, x::AbstractArray{<:Number}, idset::Base.IdSet) = f(x) +function _childarray_sum(f, x, idset::Base.IdSet) + isleaf(x) && return 0 -# utility functions - -underscorise(n::Integer) = - join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') - -function unique_param!(x::AbstractArray{<:Number}, idset::Base.IdSet) if x in idset - 0 + return 0 else push!(idset, x) - 1 + return sum(y -> _childarray_sum(f, y, idset), Functors.children(x), init = 0) end end +# utility functions + +underscorise(n::Integer) = + join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') + function _nan_show(io::IO, x) if !isempty(x) && _all(iszero, x) printstyled(io, " (all zero)", color=:cyan) From bdce0f094914bc20398e21b07ca45d8b2395f64c Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 22 Sep 2023 02:52:50 +0530 Subject: [PATCH 3/3] Minor fix. --- src/layers/show.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index 36d9be63b4..3fb8fb0d78 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -71,7 +71,7 @@ for T in [ end end -function _layer_show(io::IO, layer, indent::Int=1, name=nothing) +function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " str = _str * sprint(show, layer, context=io) print(io, " "^indent, str, indent==0 ? "" : ",")