Improve type stability of LayerNorm and Dropout #2005
Conversation
force-pushed from 25f0a1b to 9259e4a
TTFG timings using the following snippet:

```julia
using Metalhead, Flux, Zygote
using Metalhead: ChannelLayerNorm

model = ConvNeXt(:tiny; inchannels=1, nclasses=1).layers
# ChannelLayerNorm isn't type stable yet (for the same reason as LayerNorm wasn't),
# so remove it for this demo
model = fmap(Returns(identity), model; exclude=Base.Fix2(isa, ChannelLayerNorm))
# display(model); println()

loss(m, x) = sum(m(x))
inputs = randn(Float32, 32, 32, 1, 1)
# @time loss(model, inputs)
# @time loss(model, inputs)

loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)
@time loss_grad(model, inputs)
# @time loss_grad(model, inputs)
```

Replacing the …
For kicks, here is Diffractor with JuliaDiff/ChainRules.jl#644:

```julia
julia> @time loss_grad(model, inputs)
 30.442982 seconds (92.61 M allocations: 4.148 GiB, 3.18% gc time, 89.07% compilation time) # tuple chain
 23.051121 seconds (88.06 M allocations: 3.920 GiB, 3.81% gc time, 85.11% compilation time) # vector chain, requires https://github.com/JuliaDiff/Diffractor.jl/pull/82
```

Re-enabling …

Edit: added times for vector chains using a patched Diffractor.
Does Diffractor already work with most Flux models (or at least those with built-in layers)? I was under the impression that it wasn't there yet 😅
Not OOTB, which is why that ChainRules PR is required.
@ToucheSir Could you try running the layer norm gradient with `gpu`? I have tried that manual broadcast fusion before, but it allocated more memory there.
You're right, it allocates one more time for over 2x the memory overhead. I also found this out the hard way recently while trying to fuse the RNN cell kernels for #2023, but forgot about the change here.
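For anyone following along, here is a rough sketch of the two formulations being compared. This is my own minimal example, not the PR's code; `norm_fused` and `norm_unfused` are hypothetical names:

```julia
using Statistics

x  = randn(Float32, 32, 32, 64, 4)
μ  = mean(x; dims=(1, 2, 3))
σ² = var(x; dims=(1, 2, 3), corrected=false, mean=μ)
ϵ  = 1f-5

# Fused: the whole normalization is a single broadcast expression.
norm_fused(x, μ, σ², ϵ) = @. (x - μ) / sqrt(σ² + ϵ)

# Unfused: the denominator is materialized first. The forward pass pays
# one extra (small) allocation, but the pullback can reuse it instead of
# recomputing the sqrt inside the gradient kernel.
function norm_unfused(x, μ, σ², ϵ)
    s = sqrt.(σ² .+ ϵ)
    return (x .- μ) ./ s
end

norm_fused(x, μ, σ², ϵ) ≈ norm_unfused(x, μ, σ², ϵ)  # true
```

Which form wins depends on the backend, hence the GPU measurement discussed above.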
force-pushed from 9259e4a to 29ef2ff
Codecov Report

```diff
@@            Coverage Diff             @@
##           master    #2005      +/-   ##
==========================================
+ Coverage   87.10%   87.37%   +0.27%
==========================================
  Files          20       20
  Lines        1528     1553      +25
==========================================
+ Hits         1331     1357      +26
+ Misses        197      196       -1
```

Continue to review full report at Codecov.
Any updates on this (like benchmarks after unfusing)? |
These two layers made use of explicit or implicit control flow (e.g. default keyword argument values), which Zygote does not like. This PR is essentially a set of small hacks to work around that.
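As a toy illustration (mine, not Flux's code) of how a default keyword argument value can decide a function's return type, and hence what Zygote has to infer through; the real layers hit this through their own defaults and runtime flags:

```julia
using Statistics

# `dims=:` makes mean/std return scalars, while `dims=1` makes them
# arrays, so the return type of the body depends on the keyword's
# (default) value.
unstable_norm(x; dims=:) = (x .- mean(x; dims)) ./ std(x; dims)

x = randn(Float32, 4, 4)
unstable_norm(x)          # scalar statistics path
unstable_norm(x; dims=1)  # array statistics path
```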
Any ideas on how to avoid `return_type` in `_dropout` would be much appreciated, but for now it seems to work. TODO: benchmarks.
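For reference, a minimal sketch of what the inference trick buys. This is my simplification using `Base.promote_op` (which wraps the same machinery as `Core.Compiler.return_type`); `dropout_sketch` is a hypothetical stand-in, not Flux's `_dropout`:

```julia
using Random

# Compute the output eltype via inference so the inactive branch can
# return an array of the same type the active branch would produce,
# avoiding a type-unstable Union across the `active` check.
function dropout_sketch(rng::AbstractRNG, x::AbstractArray, p::Real; active::Bool=true)
    T = Base.promote_op(*, eltype(x), typeof(p))
    active || return T.(x)           # same eltype as the active path
    scale = T(inv(1 - p))
    mask = rand(rng, eltype(x), size(x)) .> p
    return x .* mask .* scale
end

dropout_sketch(Random.default_rng(), randn(Float32, 4), 0.5f0)
```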
PR Checklist