@@ -776,7 +776,11 @@ julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
 struct OptimiserChain{O<:Tuple} <: AbstractRule
   opts::O
 end
-OptimiserChain(opts...) = OptimiserChain(opts)
+
+function OptimiserChain(opts...)
+  any(opt -> opt isa MixedPrecision, opts) && throw(ArgumentError("MixedPrecision optimisers should wrap the entire OptimiserChain, not be inside it."))
+  return OptimiserChain(opts)
+end
 
 @functor OptimiserChain
 
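With this check in place, `MixedPrecision` can only wrap a chain from the outside. A minimal usage sketch (illustrative, not part of the commit; `ClipGrad` and `Adam` are existing Optimisers.jl rules):

```julia
using Optimisers

# Throws ArgumentError: MixedPrecision must not sit inside a chain.
# OptimiserChain(MixedPrecision(Adam(1e-3)), ClipGrad(1.0))

# Intended composition: wrap the whole chain, so all rule state is kept at Float32.
opt = MixedPrecision(Float32, OptimiserChain(ClipGrad(1.0), Adam(1e-3)))
```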
@@ -856,3 +860,68 @@ function apply!(o::AccumGrad, state, x, dx)
     return (accum_dx, counter + 1), nothing
   end
 end
+
+"""
+    MixedPrecision([T = Float32,] opt)
+
+An optimiser that wraps another optimiser `opt` in order to perform mixed precision
+training [1].
+
+For each trainable parameter `x`, the state of `MixedPrecision{T}` contains a copy
+of `x` in precision `T`, call it `xT`, as well as the internal state of `opt`, also at precision `T`.
+If `T` is not specified, it defaults to `Float32`.
+
+Let `g` be the gradient of `x`. Both `g` and `x` are typically stored in a precision
+lower than `T` (e.g. `Float16`).
+
+In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
+and then `x` is updated with the new value of `xT`.
+
+# Reference
+
+[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
+
+# Examples
+
+```julia
+x = rand(Float16, 2)  # a trainable parameter in low precision
+
+opt = MixedPrecision(Adam(1e-3))  # equivalent to MixedPrecision(Float32, Adam(1e-3))
+opt_state = Optimisers.setup(opt, x)  # the state contains a copy of x in Float32 precision
+
+g = rand(Float16, 2)  # a gradient in low precision
+
+# Accumulation is performed in high precision,
+# then the low-precision x is synced as well.
+Optimisers.update!(opt_state, x, g)
+```
+"""
+struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
+  rule::O
+end
+
+@functor MixedPrecision
+
+MixedPrecision(rule::AbstractRule) = MixedPrecision{Float32, typeof(rule)}(rule)
+MixedPrecision(T::Type, rule::AbstractRule) = MixedPrecision{T, typeof(rule)}(rule)
+
+function init(o::MixedPrecision{T}, x::AbstractArray) where T
+  xT = T.(x)  # high-precision master copy of the parameter
+  return (xT, init(o.rule, xT))
+end
+
+function apply!(o::MixedPrecision{T}, state, x, dx) where T
+  xT, st = state
+  st′, dx′ = apply!(o.rule, st, xT, dx)  # run the wrapped rule on the precision-T copy
+  xT = subtract!(xT, dx′)
+  if maywrite(x)
+    x .= xT  # sync the low-precision parameter in place
+    dx′ = nothing
+  else
+    dx′ = eltype(x).(x .- xT)  # for immutable x, return dx′ so that x - dx′ yields xT in x's precision
+  end
+  return (xT, st′), dx′
+end
+
+adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
+adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))
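The `adjust` methods above rebuild the wrapper around the adjusted inner rule, so the precision parameter `T` survives hyperparameter changes. As a quick illustration of the immutable branch of `apply!` (a hand-worked sketch under assumed values, not part of the commit): since `dx′ = eltype(x).(x .- xT)` is computed after `xT` has been updated, the caller's `subtract!(x, dx′)` produces `x - (x - xT) == xT`, cast back to the low precision of `x`:

```julia
# Hypothetical values: a Float16 parameter with a Float32 master copy.
x   = Float16[0.5, 1.0]        # original low-precision parameter (treated as immutable)
xT  = Float32[0.4998, 0.9997]  # master copy after the inner rule's update step
dx′ = eltype(x).(x .- xT)      # what apply! returns in the immutable branch
x .- dx′                       # what the update then produces: ≈ Float16.(xT)
```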