
Stack overflow issue #809

@jakubMitura14


Hello,
I have a memory-constrained problem with a Lux.jl model that uses Zygote for most of the backpropagation.

I tried to approach this from the ChainRules perspective: I need to checkpoint each Lux.jl layer in the neural network, so I tried to achieve it like this:

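```julia
using ChainRulesCore, Lux, Zygote

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    # Forward pass without keeping intermediates
    y = Lux.apply(l, x, ps, st)

    function pullback_checkpointed(Δy)
        # Recompute the forward pass under Zygote, then pull Δy back through it
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end

    y, pullback_checkpointed
end
```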

The rule does get invoked during backpropagation. However, for some reason it recursively tries to backpropagate through the first line, `y = Lux.apply(l, x, ps, st)`, so I get a stack overflow error. How can I correct it?
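A likely cause, offered as a guess rather than a confirmed diagnosis: inside `pullback_checkpointed`, `Zygote.pullback(Lux.apply, l, x, ps, st)` differentiates `Lux.apply` again, so Zygote picks up this same rrule, and calling `pb(Δy)` re-enters `pullback_checkpointed`, which calls `Zygote.pullback` again, without end. The tuple returned by `pb(Δy)` is also not splatted, so the rule would return the wrong number of tangents. The sketch below moves the rule onto a separate wrapper, `checkpointed_apply` (a hypothetical name, not part of Lux or Zygote), and recomputes the forward pass with ChainRulesCore's `rrule_via_ad`, which asks the calling AD system (here Zygote) for a pullback of plain `Lux.apply`. It assumes a Zygote version that supports `rrule_via_ad` and has not been checked against every Lux release.

```julia
using ChainRulesCore, Lux, Zygote

# Hypothetical wrapper (not part of Lux or Zygote): returns exactly what
# Lux.apply returns, but its rrule below does not keep forward intermediates.
checkpointed_apply(l, x, ps, st) = Lux.apply(l, x, ps, st)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(checkpointed_apply), l, x, ps, st)
    # Plain forward pass: no pullback is recorded here.
    out = Lux.apply(l, x, ps, st)

    function checkpointed_pullback(Δout)
        # Recompute the forward pass through the AD backend. The target is
        # Lux.apply, not checkpointed_apply, so this rule is not re-entered.
        _, back = rrule_via_ad(config, Lux.apply, l, x, ps, st)
        # back returns (∂Lux.apply, ∂l, ∂x, ∂ps, ∂st); drop the first entry
        # and splat the rest after the tangent for checkpointed_apply itself.
        return (NoTangent(), Base.tail(back(unthunk(Δout)))...)
    end

    return out, checkpointed_pullback
end
```

Each layer to be checkpointed would then be called as `checkpointed_apply(layer, x, ps_layer, st_layer)` in the forward pass. Keeping the rule on a function you own, instead of on `Lux.apply` for every `AbstractExplicitLayer`, also avoids changing how all Lux layers (including a whole `Chain`) are differentiated.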

I have also posted this issue on Discourse: https://discourse.julialang.org/t/avoid-storing-intermediate-results-from-the-forward-pass-by-default/119694/4?u=jakub_mitura
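Zygote also provides `Zygote.checkpointed(f, xs...)`, which returns `f(xs...)` but re-runs that call's forward pass during the backward pass instead of storing its intermediates. Its adjoint is essentially the same pattern as the rule in the question, except that the recomputation differentiates `f` rather than `checkpointed` and splats the resulting tangents, so it does not recurse. A minimal sketch of per-layer checkpointing with it, assuming small stateless `Dense` layers kept in a plain tuple (the layer sizes and the helper name `forward_checkpointed` are illustrative, not from the original):

```julia
using Lux, Random, Zygote

rng = Random.default_rng()
# Example layers; any Lux layers with matching dimensions would do.
layers = (Dense(4 => 8, tanh), Dense(8 => 8, tanh), Dense(8 => 1))
setups = map(l -> Lux.setup(rng, l), layers)  # (ps, st) for each layer
ps = map(first, setups)
st = map(last, setups)
x = randn(rng, Float32, 4, 16)

# Apply the layers one at a time. Each call is wrapped in Zygote.checkpointed,
# so during the backward pass Zygote re-runs that layer's forward pass instead
# of storing the intermediates computed inside the layer.
function forward_checkpointed(layers, x, ps, st)
    h = x
    for i in eachindex(layers)
        h, _ = Zygote.checkpointed(Lux.apply, layers[i], h, ps[i], st[i])
    end
    return h
end

loss(ps) = sum(abs2, forward_checkpointed(layers, x, ps, st))
grads = only(Zygote.gradient(loss, ps))  # one NamedTuple of gradients per layer
```

Each layer's updated state is discarded here, which is harmless for stateless layers like `Dense`. Because the forward pass is re-run, this assumes each layer call is pure; layers whose returned state matters (for example `BatchNorm` in training mode) would need that state threaded through explicitly.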
