Commit 3b73b84

Eager parameter updating (#1541)
1 parent 5a8d0ed commit 3b73b84

2 files changed (+41 −0)

docs/src/utils.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -26,6 +26,7 @@ Zygote.hook
 Zygote.Buffer
 Zygote.forwarddiff
 Zygote.checkpointed
+Zygote.eager_update!
 ```
 
 `Params` and `Grads` can be copied to and from arrays using the `copy!` function.
````
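
A rough sketch of the `copy!` round trip mentioned in the last context line above; the parameter arrays `W` and `b` are hypothetical:

```julia
using Zygote

W, b = rand(2, 3), rand(2)
ps = Zygote.Params([W, b])

# Flatten every parameter into a single vector, in iteration order.
v = zeros(sum(length, ps))
copy!(v, ps)

# Copy (possibly modified) values back into the original arrays in place.
copy!(ps, 2 .* v)
```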

src/lib/grad.jl

Lines changed: 40 additions & 0 deletions
````diff
@@ -27,6 +27,46 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
   return y, pullback_checkpointed
 end
 
+
+
+"""
+
+    eager_update!(state, model, update!)
+
+Eagerly updates the model parameters, discarding the gradients after use to save memory.
+`model` stores the parameters to be updated, `state` is the optimization state (e.g. from Optimisers.jl) matching your model component, and
+`update!` is the function that updates the parameters (e.g. from Optimisers.jl), usually called as `update!(state, model, grads)`.
+
+If `f` is a function that takes a single layer, called as `h = f(model.layers[i], h, other_args...)`, then we can eagerly update with:
+
+```julia
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+or combine this with gradient checkpointing (for additional memory savings at the cost of increased execution time) with:
+
+```julia
+h = Zygote.checkpointed(f, Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+If `model.layers[i]` itself is callable, we can use the above by first wrapping it:
+
+```julia
+f(model, xs...) = model(xs...)
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+!!! warning
+    If different layers share trainable parameters, then `eager_update!` will likely give wrong results.
+"""
+function eager_update!(state, model, update!)
+  function update_hook(dmodel)
+    update!(state, model, dmodel)
+    return nothing
+  end
+  return Zygote.hook(update_hook, model)
+end
+
 """
     hessian(f, x)
 
````
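Putting the pieces together: a minimal training-step sketch for the new function, assuming Optimisers.jl; the `Affine` layer, model shape, and loss below are illustrative and not part of this commit:

```julia
using Zygote, Optimisers, Functors

# A hypothetical callable layer (illustrative; not part of this commit).
struct Affine; W; b; end
(a::Affine)(x) = tanh.(a.W * x .+ a.b)
Functors.@functor Affine

model = (layers = [Affine(randn(4, 8), zeros(4)), Affine(randn(2, 4), zeros(2))],)
state = Optimisers.setup(Optimisers.Adam(1e-3), model)

# `eager_update!` wraps a layer passed to a function, so apply each layer
# through a small wrapper, as in the docstring above.
f(layer, x) = layer(x)

x, y = randn(8), randn(2)

# During the backward pass, each layer's parameters are updated as soon as its
# gradient arrives, and that gradient is then discarded to save memory.
Zygote.gradient(model) do m
  h = x
  for i in eachindex(m.layers)
    h = f(Zygote.eager_update!(state.layers[i], m.layers[i], Optimisers.update!), h)
  end
  sum(abs2, h .- y)
end
```

Because `update_hook` returns `nothing`, the gradient that `Zygote.gradient` reports for `model` is empty; the updates happen as a side effect on `model` and `state` during the backward pass.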
