
reset!(optimiser_state) #163


@Vilin97

Motivation and description

In my application I do 25 gradient descent update! steps in a loop (solving a differential equation). I need the momentum from the previous 25 GD steps to NOT carry over to the next 25 GD steps. In other words, the behavior I am looking for is analogous to calling Flux.setup(optimiser, model) every time. Unfortunately, Flux.setup is type-unstable (#162). It would be great to have a function reset!(optimiser_state) that resets the momenta. A stricter requirement is that

state = Flux.setup(optimiser, model)
# do some training
reset!(state)
state == Flux.setup(optimiser, model)

holds.
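
Concretely, the training loop would look something like this (a sketch: model, loss, data, and n_blocks are placeholders, and reset! is the proposed function):

state = Flux.setup(Adam(1e-3), model)
for block in 1:n_blocks
    for _ in 1:25                        # 25 GD steps per block
        grads = Flux.gradient(m -> loss(m, data), model)
        Flux.update!(state, model, grads[1])
    end
    reset!(state)                        # drop accumulated momenta before the next block
end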

Possible Implementation

Below is a possible implementation for Adam, for a model that is a Chain of layers each holding a weight and a bias (e.g. Dense).

using Optimisers: Leaf, Adam

# Zero Adam's moment estimates and restore the initial beta powers,
# so the leaf matches what Flux.setup would freshly produce.
# (Assigning leaf.state requires Leaf to be mutable, as in recent Optimisers.jl.)
function reset!(leaf::Leaf{<:Adam})
    mt, vt, _ = leaf.state
    mt .= 0                                # first moment
    vt .= 0                                # second moment
    leaf.state = (mt, vt, leaf.rule.beta)  # beta powers back to (β1, β2)
    nothing
end

# Walk the state tree for a Chain of layers with weight and bias (e.g. Dense).
function reset!(state::NamedTuple{(:layers,)})
    for layer in state.layers
        reset!(layer.weight)
        reset!(layer.bias)
    end
    nothing
end
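
The hand-written walk above is tied to one model shape. Since the state tree is just nested tuples and NamedTuples, the traversal could instead be written generically with Functors.fmap, replacing the layer-by-layer method; a sketch, assuming Functors.jl is available and Leaf is mutable as above:

using Functors: fmap
using Optimisers: Leaf

# Visit every Leaf in an arbitrary state tree and reset it in place.
# `exclude` stops recursion at Leaf nodes, so the closure is applied to
# each Leaf; the tree fmap rebuilds is discarded, only the mutation matters.
function reset!(tree::Union{Tuple, NamedTuple})
    fmap(x -> (reset!(x); x), tree; exclude = x -> x isa Leaf)
    return nothing
end

With this version, a Leaf whose rule has no reset! method (anything other than Adam here) raises a MethodError instead of being silently skipped.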
