Conversation
Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)...,
       MeanPool((2, 2))]...)
Suggested change:

- Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)...,
-        MeanPool((2, 2))]...)
+ Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)]..., MeanPool((2, 2)))
The return type of conv_bn is already a Vector, so shouldn't just Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., MeanPool((2, 2))) work? Also, I know this suggestion has been shot down before because it would cause visual noise, but simply tweaking conv_bn to return a Chain does wonders for the time to first gradient (TTFG):
master:
julia> using Metalhead
julia> using Flux: Zygote
julia> den = DenseNet();
julia> ip = rand(Float32, 224, 224, 3, 1);
julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
77.621622 seconds (124.76 M allocations: 11.324 GiB, 1.67% gc time, 97.00% compilation time)

with conv_bn returning a Chain:
julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
28.244888 seconds (89.40 M allocations: 9.049 GiB, 3.60% gc time, 90.78% compilation time)
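For reference, a minimal sketch of what "conv_bn returning a Chain" would look like (a hypothetical simplification; the real Metalhead conv_bn also handles stride, padding, activation, and the rev keyword):

using Flux

# Hypothetical simplified conv_bn: wrap the conv + batchnorm pair in a Chain
# instead of returning a Vector of layers, so call sites no longer splat.
function conv_bn(kernelsize, inplanes, outplanes; bias = false)
    return Chain(Conv(kernelsize, inplanes => outplanes; bias),
                 BatchNorm(outplanes, relu))
end

# Call sites then compose without splatting:
# Chain(conv_bn((1, 1), inplanes, outplanes; bias = false), MeanPool((2, 2)))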
^ Getting timings that fast needs some tricks, though. One major trick is that large Vectors of layers that are being splatted into Chains should not be splatted at all (Flux 0.13 deals with this, so this works). Removing a single splat of a large vector of layers (the "body" of the DenseNet) makes the time shoot back up:
julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
46.788491 seconds (117.59 M allocations: 10.873 GiB, 2.65% gc time, 94.90% compilation time)
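To illustrate the splat-versus-no-splat distinction (a hypothetical layer list, not the actual DenseNet body):

using Flux

layers = [Dense(8 => 8, relu) for _ in 1:100]  # stand-in for a long vector of layers

Chain(layers...)  # splatting builds a 100-parameter tuple type, which is expensive to compile
Chain(layers)     # Flux 0.13+ accepts the Vector directly, keeping the type small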
Woops, you are indeed right and the suggestion looks good.
One thing I am curious about is the large discrepancy between first compiles on master. I regularly get ~500s TTFG with DenseNet; you don't seem to get nearly as bad times. Mine is with GPUs turned off. Does that make up some of the difference?
I am testing on an M1 Mac CPU with 4 threads and Julia master, so maybe some of the discrepancy is there? Julia 1.8+ seemed to be an order of magnitude faster than Julia 1.7 for some compilation workloads, last I checked.
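For comparing such timings across machines, one way to record the relevant environment (standard Julia tooling, nothing specific to this PR):

julia> versioninfo()        # Julia version, OS, CPU, and thread count

julia> using Pkg; Pkg.status()   # Flux / Metalhead / Zygote versions being benchmarked

julia> Threads.nthreads()   # number of Julia threads in use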
DenseNet had a major regression in the compile time to differentiate it over the releases. This is often due to very long Chains. This is a small fix that makes things a lot more manageable for the moment. This is a pattern we have across the library, so maybe something to fix elsewhere as well.