
Commit 3e1ea7c

Rule for mixed precision training (#152)
* mixed precision
* docs
* handle non-writeable
* adjust
* more tests
* Update src/rules.jl
  Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
* Update src/rules.jl
  Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
* change constructor
* add back nothing
* address review
* more comments
* improve precision
* fix test
* fix test

---------

Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
1 parent: 4ff61fc

File tree: 6 files changed, +128 −3 lines

docs/src/api.md

Lines changed: 2 additions & 1 deletion
@@ -29,9 +29,10 @@ In addition to the main course, you may wish to order some of these condiments:
 Optimisers.AccumGrad
 Optimisers.ClipGrad
 Optimisers.ClipNorm
+Optimisers.MixedPrecision
+Optimisers.OptimiserChain
 Optimisers.SignDecay
 Optimisers.WeightDecay
-Optimisers.OptimiserChain
 ```

 ## Model Interface

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad
+       AccumGrad, MixedPrecision

 VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

src/rules.jl

Lines changed: 70 additions & 1 deletion
@@ -776,7 +776,11 @@ julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
 struct OptimiserChain{O<:Tuple} <: AbstractRule
   opts::O
 end
-OptimiserChain(opts...) = OptimiserChain(opts)
+
+function OptimiserChain(opts...)
+  any(opt -> opt isa MixedPrecision, opts) && throw(ArgumentError("MixedPrecision optimisers should wrap the entire OptimiserChain, not be inside it."))
+  return OptimiserChain(opts)
+end

 @functor OptimiserChain
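
To make the intent of this constructor guard concrete, here is a minimal sketch (not part of the commit; the clipping and learning-rate values are arbitrary) of the supported composition, with `MixedPrecision` on the outside:

```julia
using Optimisers

# Supported: the mixed-precision wrapper goes around the whole chain.
rule = MixedPrecision(OptimiserChain(ClipNorm(1.0), Adam(1e-3)))

# Not supported: nesting it inside the chain now raises an ArgumentError.
# OptimiserChain(ClipNorm(1.0), MixedPrecision(Adam(1e-3)))
```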

@@ -856,3 +860,68 @@ function apply!(o::AccumGrad, state, x, dx)
     return (accum_dx, counter + 1), nothing
   end
 end
+
+"""
+    MixedPrecision([T = Float32,] opt)
+
+An optimiser that wraps another optimiser `opt` in order to perform mixed precision
+training [1].
+
+The state of `MixedPrecision{T}` will contain a copy in precision `T` of any trainable parameter `x`,
+call it `xT`, as well as the internal state of `opt` also at precision `T`.
+If `T` is not specified, it defaults to `Float32`.
+
+Call `g` the gradient of `x`. Both `g` and `x` are typically in a precision lower than `T`
+(e.g. `Float16`).
+
+In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
+then `x` is updated with the value of `xT`.
+
+# Reference
+
+[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
+
+# Examples
+
+```julia
+x = rand(Float16, 2) # A trainable parameter in low precision
+
+opt = MixedPrecision(Adam(1e-3)) # Equivalent to MixedPrecision(Float32, Adam(1e-3))
+opt_state = Optimisers.setup(opt, x) # The state contains a copy of x in Float32 precision
+
+g = rand(Float16, 2) # A gradient in low precision
+
+# Accumulation is performed in high precision,
+# then also the low precision x is synced
+Optimisers.update!(opt_state, x, g)
+```
+"""
+struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
+  rule::O
+end
+
+@functor MixedPrecision
+
+MixedPrecision(rule::AbstractRule) = MixedPrecision{Float32, typeof(rule)}(rule)
+MixedPrecision(T::Type, rule::AbstractRule) = MixedPrecision{T, typeof(rule)}(rule)
+
+function init(o::MixedPrecision{T}, x::AbstractArray) where T
+  xT = T.(x)
+  return (xT, init(o.rule, xT))
+end
+
+function apply!(o::MixedPrecision{T}, state, x, dx) where T
+  xT, st = state
+  st′, dx′ = apply!(o.rule, st, xT, dx)
+  xT = subtract!(xT, dx′)
+  if maywrite(x)
+    x .= xT
+    dx′ = nothing
+  else
+    dx′ = eltype(x).(x .- xT)
+  end
+  return (xT, st′), dx′
+end
+
+adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
+adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))
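
Putting the new rule together with the rest of the API, a rough usage sketch follows; the values and the tiny NamedTuple model are invented for illustration, and the state layout (a high-precision copy plus the wrapped rule's state) follows the docstring above:

```julia
using Optimisers

model = (w = rand(Float16, 3), b = zeros(Float16, 3))   # low-precision parameters
grads = (w = rand(Float16, 3), b = rand(Float16, 3))    # low-precision gradients

state = Optimisers.setup(MixedPrecision(Adam(1e-3)), model)  # Float32 master copies

# Adam runs on the Float32 copies; model.w and model.b are then overwritten from them.
state, model = Optimisers.update!(state, model, grads)

# adjust reaches through the wrapper to the inner Adam's learning rate.
state = Optimisers.adjust(state, 1e-4)
```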

test/Project.toml

Lines changed: 3 additions & 0 deletions
@@ -10,3 +10,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[sources]
+Optimisers = { path = ".." }

test/rules.jl

Lines changed: 24 additions & 0 deletions
@@ -9,6 +9,7 @@ RULES = [
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
   AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
+  MixedPrecision(Float64, Adam()),
   # A few chained combinations:
   OptimiserChain(SignDecay(0.001), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -262,3 +263,26 @@ end
   os, x = Optimisers.update(os, x, δx)
   @test x ≈ Float16[1.835, -0.886, 0.5493] rtol=1e-3
 end
+
+@testset "MixedPrecision" begin
+  x = rand(Float16, 2)
+  opt_state = Optimisers.setup(MixedPrecision(Adam(1e-3)), x)
+  @test opt_state.state[1] isa Vector{Float32}
+  @test opt_state.state[2][1] isa Vector{Float32}
+  g = rand(Float16, 2)
+  new_state, new_x = Optimisers.update(opt_state, x, rand(Float16, 2))
+  @test new_x == Float16.(new_state.state[1])
+  @test new_x ≈ x .- 1e-3 .* g
+
+  x = rand(Float16, 2)
+  opt_state = Optimisers.setup(MixedPrecision(Float64, Adam(1e-3)), x)
+  @test opt_state.state[1] isa Vector{Float64}
+  @test opt_state.state[2][1] isa Vector{Float64}
+
+  opt = MixedPrecision(Float64, Adam(1e-3))
+  opt2 = Optimisers.adjust(opt, 2e-3)
+  @test opt2.rule.eta == 2e-3
+  @test opt2 isa MixedPrecision{Float64, <:Adam}
+
+  @test_throws ArgumentError OptimiserChain(MixedPrecision(Adam()))
+end
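
The `maywrite` branch added in `apply!` also matters for parameters that cannot be mutated in place; the tests above only exercise ordinary `Vector`s, so here is a hedged sketch with a StaticArrays `SVector` (StaticArrays is already a test dependency, but this snippet is not part of the commit):

```julia
using Optimisers, StaticArrays

x = SVector{2}(Float16(0.5), Float16(-0.25))   # immutable, so the update must be out of place
g = SVector{2}(Float16(1.0), Float16(2.0))

state = Optimisers.setup(MixedPrecision(Descent(0.1)), x)

# update (not update!) returns a fresh x, rounded back down from the Float32 master copy.
state, xnew = Optimisers.update(state, x, g)
```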

test/runtests.jl

Lines changed: 28 additions & 0 deletions
@@ -257,6 +257,20 @@ end
   @test sc2.γ.rule.opts[1].delta == 2.5
   @test sc2.γ.rule.opts[2].eta == 0.001  # unchanged
   @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  # MixedPrecision
+  mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
+  mp1, _ = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
+  @test mp1.γ.rule.rule.eta == 0.1
+  @test mp1.γ.state[2] ≈ [0.1, 1, 10]
+
+  mp2 = Optimisers.adjust(mp1, 0.2)
+  @test mp2.γ.rule.rule.eta == 0.2
+  @test mp2.γ.rule.rule.rho == 0.9
+
+  mp3 = Optimisers.adjust(mp1; eta=0.3, rho=0.7)
+  @test mp3.γ.rule.rule.eta == 0.3
+  @test mp3.γ.rule.rule.rho == 0.7
 end

 @testset "adjusting parameters, in-place" begin
@@ -301,6 +315,20 @@ end
   @test sc1.γ.rule.opts[1].delta ≈ 2.5
   @test sc1.γ.rule.opts[2].eta ≈ 0.2  # unchanged
   @test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  # MixedPrecision
+  mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
+  mp1, _ = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
+  @test mp1.γ.rule.rule.eta == 0.1
+  @test mp1.γ.state[2] ≈ [0.1, 1, 10]
+
+  Optimisers.adjust!(mp1, 0.2)
+  @test mp1.γ.rule.rule.eta == 0.2
+  @test mp1.γ.rule.rule.rho == 0.9
+
+  Optimisers.adjust!(mp1; eta=0.3, rho=0.7)
+  @test mp1.γ.rule.rule.eta == 0.3
+  @test mp1.γ.rule.rule.rho == 0.7
 end

 @testset "freeze/thaw" begin
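
For readers puzzling over the double `.rule` in these tests: the `rule` field of each state `Leaf` holds the `MixedPrecision` wrapper, and its own `rule` field holds the wrapped optimiser, so the inner hyperparameters sit two hops down. A small illustrative sketch (the model here is made up):

```julia
using Optimisers

m = (γ = Float16[1, 2, 3],)
st = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)

st.γ.rule            # the MixedPrecision wrapper
st.γ.rule.rule.eta   # 0.1, the wrapped Momentum's learning rate

Optimisers.adjust!(st, 0.2)   # forwarded to the inner rule via MixedPrecision's adjust
st.γ.rule.rule.eta   # now 0.2
```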
