Description
Hello
I have a memory-constrained problem with a Lux.jl model that uses Zygote for most of the backpropagation.
Approaching this from the ChainRules side, I need to checkpoint each Lux.jl layer in the neural network, i.e. recompute its forward pass during the backward pass instead of storing the intermediate results. I tried to do it like this:
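```julia
using ChainRulesCore, Lux, Zygote

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    # Forward pass without recording a pullback
    y = Lux.apply(l, x, ps, st)
    function pullback_checkpointed(Δy)
        # Recompute the forward pass, this time recording a pullback
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end
    return y, pullback_checkpointed
end
```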
The rule is invoked during backpropagation. However, for some reason it recursively tries to backpropagate through the first line,
`y = Lux.apply(l, x, ps, st)`
so I get a stack overflow error. How can I correct it?
I have also posted this issue on Discourse: https://discourse.julialang.org/t/avoid-storing-intermediate-results-from-the-forward-pass-by-default/119694/4?u=jakub_mitura
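The stack overflow most likely comes from the pullback rather than the forward line: inside `pullback_checkpointed`, `Zygote.pullback(Lux.apply, ...)` makes Zygote look up a rule for `Lux.apply`, find this same `rrule`, and return a pullback that again calls `Zygote.pullback(Lux.apply, ...)`, with no end. The recomputation has to differentiate something the rule does not cover. Below is a minimal sketch of that pattern, assuming only ChainRulesCore and Zygote. The helper `checkpointed_call` is hypothetical (not part of either package, nor of Lux): it places the rule on a wrapper and, in the pullback, differentiates the wrapped function `f` through `rrule_via_ad`, so the rule is not re-entered. It also splats the tangents, because an `rrule` pullback must return one tangent per argument, while `return NoTangent(), pb(Δy)` returns a 2-tuple.

```julia
using ChainRulesCore
using Zygote

# Hypothetical helper (not part of ChainRulesCore, Zygote or Lux):
# checkpointed_call(f, args...) returns exactly f(args...), but its rrule
# keeps no intermediates of `f` from the forward pass.
checkpointed_call(f, args...) = f(args...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(checkpointed_call), f, args...)
    # Forward pass: plain evaluation, nothing recorded for AD.
    y = f(args...)
    function checkpointed_call_pullback(Δy)
        # Re-run the forward pass of `f`, this time recording a pullback.
        # We differentiate `f`, not `checkpointed_call`, so this rule is
        # not triggered again and the recursion stops here.
        _, pb = rrule_via_ad(config, f, args...)
        # `pb` returns tangents for `f` and each element of `args`;
        # prepend one for `checkpointed_call` itself and splat.
        return (NoTangent(), pb(Δy)...)
    end
    return y, checkpointed_call_pullback
end

# Usage: gradients through the checkpointed call match the plain ones.
g(x, w) = sum(abs2, tanh.(w * x))
x = [1.0, 2.0]
w = [0.5 -0.25; 0.1 0.3]

gx, gw = Zygote.gradient((x, w) -> checkpointed_call(g, x, w), x, w)
rx, rw = Zygote.gradient(g, x, w)
```

Applied to the rule in the question, the same idea would mean differentiating the layer call itself instead of `Lux.apply`, for example `rrule_via_ad(config, l, x, ps, st)` inside an `rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(Lux.apply), l, x, ps, st)` method, since `Lux.apply(l, x, ps, st)` forwards to `l(x, ps, st)`. Recent Zygote versions also provide `Zygote.checkpointed(f, xs...)`, which uses the same recompute-in-the-pullback approach, so `Zygote.checkpointed(Lux.apply, l, x, ps, st)` may be enough without a custom rule. In every variant, checkpointing assumes that `f` returns the same result when it is re-run.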