@@ -27,6 +27,46 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
 return y, pullback_checkpointed
 end
 
+
+
+"""
+    eager_update!(state, model, update!)
+
+Eagerly updates the model parameters, discarding each gradient as soon as it has been applied, to save memory.
+`model` holds the parameters to be updated, `state` is the optimization state matching your model component (e.g. from Optimisers.jl), and
+`update!` is the function that applies the update (e.g. `Optimisers.update!`), called as `update!(state, model, grads)`.
+
+If `f` is a function that applies a single layer, called as `h = f(model.layers[i], h, other_args...)`, then we can eagerly update with:
+
+```julia
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+or combine this with gradient checkpointing (for additional memory savings at the cost of increased execution time) with:
+
+```julia
+h = Zygote.checkpointed(f, Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+If `model.layers[i]` is itself callable, we can use the above by first wrapping it in such a function:
+
+```julia
+f(model, xs...) = model(xs...)
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+!!! warning
+    If different layers share trainable parameters, then `eager_update!` will likely give incorrect results.
+"""
+function eager_update!(state, model, update!)
+    function update_hook(dmodel)
+        update!(state, model, dmodel)
+        return nothing
+    end
+    return Zygote.hook(update_hook, model)
+end
+
 """
     hessian(f, x)
 
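For context, a minimal end-to-end sketch of how the new helper composes with Optimisers.jl. The `Affine` layer, the `f` wrapper, and the model shape here are hypothetical illustrations, not part of this diff; treat this as a usage sketch under those assumptions rather than a tested example:

```julia
using Zygote, Optimisers, Functors

# Hypothetical toy layer; any callable struct with trainable fields works.
struct Affine; W; b; end
(a::Affine)(x) = tanh.(a.W * x .+ a.b)
Functors.@functor Affine  # let Optimisers.jl see W and b as trainable

# Wrapper so each layer is passed positionally, as the docstring requires.
f(layer, x) = layer(x)

model = (layers = (Affine(randn(4, 3), zeros(4)), Affine(randn(2, 4), zeros(2))),)
state = Optimisers.setup(Optimisers.Adam(1e-3), model)

x = randn(3)
Zygote.gradient(model) do m
    h = x
    for i in eachindex(m.layers)
        # Each layer's gradient is consumed by update! inside the pullback
        # and then discarded, rather than accumulated into the returned grads.
        h = f(Zygote.eager_update!(state.layers[i], m.layers[i], Optimisers.update!), h)
    end
    sum(abs2, h)
end
```

Because each layer's gradient is applied and dropped during the backward pass, peak memory no longer scales with the number of wrapped layers' gradient buffers, which is the point of the helper.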