Motivation and description
In my application I do 25 gradient descent (`update!`) steps in a loop (solving a differential equation). I need the momentum from the previous 25 GD steps to NOT carry over to the next 25 GD steps. In other words, the behavior I am looking for is analogous to calling `Flux.setup(optimiser, model)` every time. Unfortunately, `Flux.setup` is type-unstable (#162). It would be great to have a function `reset!(optimiser_state)` that resets the momenta. A more stringent requirement would be that
state = Flux.setup(optimiser, model)
# do some training
reset!(state)
state == Flux.setup(optimiser, model)
holds.
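To make the intended use concrete, here is a rough sketch of the loop I have in mind; the model, loss, and data below are placeholders, and `reset!` is the proposed function, not an existing API:

```julia
using Flux

# Placeholder model and data, only to show where `reset!` would be called.
model = Chain(Dense(2 => 16, relu), Dense(16 => 1))
state = Flux.setup(Adam(1e-3), model)

for outer in 1:100              # e.g. one differential-equation solve per outer iteration
    for step in 1:25            # 25 GD steps that should share momentum with each other...
        x, y = rand(Float32, 2, 32), rand(Float32, 1, 32)
        grads = Flux.gradient(m -> sum(abs2, m(x) .- y), model)[1]
        Flux.update!(state, model, grads)
    end
    reset!(state)               # ...but not with the next block of 25 steps
end
```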
Possible Implementation
Below is an implementation for `Adam`.
using Optimisers
using Optimisers: Leaf   # Leaf may not be exported, so bring it in explicitly

function reset!(leaf::Leaf{A, S}) where {A <: Optimisers.Adam, S}
    # Adam's state is (first moment, second moment, current decay powers (β1^t, β2^t)).
    leaf.state[1] .= 0   # zero the first-moment estimate in place
    leaf.state[2] .= 0   # zero the second-moment estimate in place
    leaf.state = (leaf.state[1], leaf.state[2], leaf.rule.beta)   # reset the decay powers
    nothing
end

function reset!(state::NamedTuple{(:layers,), L}) where {L}
    # Walk the per-layer state of a Chain of Dense-like layers.
    for layer in state.layers
        reset!(layer.weight)
        reset!(layer.bias)
    end
    nothing
end
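The traversal above assumes a Chain of Dense-like layers with `weight` and `bias` fields. Purely as an untested sketch, something like the following could walk an arbitrary state tree instead (the name `reset_all!` is just for illustration):

```julia
using Optimisers: Leaf

# Untested sketch: recurse over the whole state tree instead of hard-coding
# :layers/:weight/:bias; the per-rule reset (e.g. the Adam method above) is
# still dispatched on the Leaf's rule type.
reset_all!(leaf::Leaf) = reset!(leaf)
reset_all!(tree::Union{NamedTuple, Tuple}) = foreach(reset_all!, values(tree))
reset_all!(x) = nothing   # skip non-trainable entries such as `nothing`
```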