2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.4.6"
version = "0.4.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3 changes: 2 additions & 1 deletion docs/src/api.md
@@ -29,9 +29,10 @@ In addition to the main course, you may wish to order some of these condiments:
Optimisers.AccumGrad
Optimisers.ClipGrad
Optimisers.ClipNorm
Optimisers.MixedPrecision
Optimisers.OptimiserChain
Optimisers.SignDecay
Optimisers.WeightDecay
Optimisers.OptimiserChain
```

## Model Interface
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -24,7 +24,7 @@ include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad
AccumGrad, MixedPrecision

VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

71 changes: 70 additions & 1 deletion src/rules.jl
@@ -776,7 +776,11 @@ julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
struct OptimiserChain{O<:Tuple} <: AbstractRule
opts::O
end
OptimiserChain(opts...) = OptimiserChain(opts)

function OptimiserChain(opts...)
any(opt -> opt isa MixedPrecision, opts) && throw(ArgumentError("MixedPrecision optimisers should wrap the entire OptimiserChain, not be inside it."))
return OptimiserChain(opts)
end
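
To make the intended nesting concrete, here is a minimal sketch (using only rules exported by this package) of what the new guard enforces: `MixedPrecision` must wrap the whole chain, and placing it inside the chain now throws.

```julia
using Optimisers

# Intended usage: the mixed-precision wrapper goes around the entire chain.
opt = MixedPrecision(OptimiserChain(ClipNorm(), Adam(1e-3)))

# Rejected by the constructor above: a MixedPrecision rule inside the chain
# would leave the other rules operating in low precision, so it now errors.
try
    OptimiserChain(MixedPrecision(Adam(1e-3)), WeightDecay(1e-4))
catch err
    @assert err isa ArgumentError
end
```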

@functor OptimiserChain

@@ -856,3 +860,68 @@ function apply!(o::AccumGrad, state, x, dx)
return (accum_dx, counter + 1), nothing
end
end

"""
MixedPrecision([T = Float32,] opt)

An optimiser that wraps another optimiser `opt` in order to perform mixed precision
training [1].

The state of `MixedPrecision{T}` contains a copy `xT` of each trainable parameter `x` in
precision `T`, together with the internal state of `opt`, also kept at precision `T`.
If `T` is not specified, it defaults to `Float32`.

Let `g` be the gradient of `x`. Both `g` and `x` are typically stored in a precision lower than `T`
(e.g. `Float16`).

In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
and `x` is then updated with the value of `xT`.

# Reference

[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .

# Examples

```julia
x = rand(Float16, 2) # A trainable parameter in low precision

opt = MixedPrecision(Adam(1e-3)) # Equivalent to MixedPrecision(Float32, Adam(1e-3))
opt_state = Optimisers.setup(opt, x) # The state contains a copy of x in Float32 precision

g = rand(Float16, 2) # A gradient in low precision

# The update is accumulated in high precision,
# then the low-precision x is synced to it
Optimisers.update!(opt_state, x, g)
```
"""
struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
rule::O
end

@functor MixedPrecision

MixedPrecision(rule::AbstractRule) = MixedPrecision{Float32, typeof(rule)}(rule)
MixedPrecision(T::Type, rule::AbstractRule) = MixedPrecision{T, typeof(rule)}(rule)

function init(o::MixedPrecision{T}, x::AbstractArray) where T
xT = T.(x)
return (xT, init(o.rule, xT))
end

function apply!(o::MixedPrecision{T}, state, x, dx) where T
xT, st = state
st′, dx′ = apply!(o.rule, st, xT, dx)
xT = subtract!(xT, dx′)
if maywrite(x)
x .= xT
dx′ = nothing
else
dx′ = eltype(x).(x .- xT)
end
return (xT, st′), dx′
end
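
As a reading aid for the non-mutating branch above: the rule hands back the delta `x .- xT`, so the caller's subtraction `x .- dx′` lands exactly on the high-precision result, rounded back to `eltype(x)`. A minimal sketch with assumed values:

```julia
x   = Float16[1.0, 2.0]       # low-precision parameter (immutable path assumed)
xT  = Float32[0.999, 1.998]   # high-precision copy after the inner rule's step
dx′ = eltype(x).(x .- xT)     # delta returned to the caller, in Float16
x_new = x .- dx′              # ≈ Float16.(xT): x ends up synced to xT
```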

adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))
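
A note on the two `adjust` methods just above: they rebuild the wrapper explicitly, so the precision parameter `T` survives while hyperparameters of the inner rule change, which is what the new tests below check. A small usage sketch:

```julia
opt  = MixedPrecision(Float64, Adam(1e-3))
opt2 = Optimisers.adjust(opt, 2e-3)        # inner eta becomes 2e-3
opt2 isa MixedPrecision{Float64, <:Adam}   # true: Float64 precision is preserved
```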
3 changes: 3 additions & 0 deletions test/Project.toml
@@ -10,3 +10,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
Optimisers = { path = ".." }
24 changes: 24 additions & 0 deletions test/rules.jl
@@ -9,6 +9,7 @@ RULES = [
Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
MixedPrecision(Float64, Adam()),
# A few chained combinations:
OptimiserChain(SignDecay(0.001), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -262,3 +263,26 @@ end
os, x = Optimisers.update(os, x, δx)
@test x ≈ Float16[1.835, -0.886, 0.5493] rtol=1e-3
end

@testset "MixedPrecision" begin
x = rand(Float16, 2)
opt_state = Optimisers.setup(MixedPrecision(Adam(1e-3)), x)
@test opt_state.state[1] isa Vector{Float32}
@test opt_state.state[2][1] isa Vector{Float32}
g = rand(Float16, 2)
new_state, new_x = Optimisers.update(opt_state, x, g)
@test new_x == Float16.(new_state.state[1])
@test new_x ≈ x .- 1e-3 .* g

x = rand(Float16, 2)
opt_state = Optimisers.setup(MixedPrecision(Float64, Adam(1e-3)), x)
@test opt_state.state[1] isa Vector{Float64}
@test opt_state.state[2][1] isa Vector{Float64}

opt = MixedPrecision(Float64, Adam(1e-3))
opt2 = Optimisers.adjust(opt, 2e-3)
@test opt2.rule.eta == 2e-3
@test opt2 isa MixedPrecision{Float64, <:Adam}

@test_throws ArgumentError OptimiserChain(MixedPrecision(Adam()))
end
28 changes: 28 additions & 0 deletions test/runtests.jl
@@ -257,6 +257,20 @@ end
@test sc2.γ.rule.opts[1].delta == 2.5
@test sc2.γ.rule.opts[2].eta == 0.001 # unchanged
@test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]

# MixedPrecision
mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
mp1, _ = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
@test mp1.γ.rule.rule.eta == 0.1
@test mp1.γ.state[2] ≈ [0.1, 1, 10]

mp2 = Optimisers.adjust(mp1, 0.2)
@test mp2.γ.rule.rule.eta == 0.2
@test mp2.γ.rule.rule.rho == 0.9

mp3 = Optimisers.adjust(mp1; eta=0.3, rho=0.7)
@test mp3.γ.rule.rule.eta == 0.3
@test mp3.γ.rule.rule.rho == 0.7
end

@testset "adjusting parameters, in-place" begin
@@ -301,6 +315,20 @@ end
@test sc1.γ.rule.opts[1].delta ≈ 2.5
@test sc1.γ.rule.opts[2].eta ≈ 0.2 # unchanged
@test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]

# MixedPrecision
mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
mp1, _ = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
@test mp1.γ.rule.rule.eta == 0.1
@test mp1.γ.state[2] ≈ [0.1, 1, 10]

Optimisers.adjust!(mp1, 0.2)
@test mp1.γ.rule.rule.eta == 0.2
@test mp1.γ.rule.rule.rho == 0.9

Optimisers.adjust!(mp1; eta=0.3, rho=0.7)
@test mp1.γ.rule.rule.eta == 0.3
@test mp1.γ.rule.rule.rho == 0.7
end

@testset "freeze/thaw" begin