diff --git a/docs/src/api.md b/docs/src/api.md
index bea4421..78e0ed3 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -30,6 +30,7 @@ Optimisers.AccumGrad
 Optimisers.ClipGrad
 Optimisers.ClipNorm
 Optimisers.MixedPrecision
+Optimisers.add_mixed_precision
 Optimisers.OptimiserChain
 Optimisers.SignDecay
 Optimisers.WeightDecay
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 07e1ac9..80b4b7c 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -24,7 +24,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad, MixedPrecision
+       AccumGrad, MixedPrecision, add_mixed_precision
 
 VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))
 
diff --git a/src/rules.jl b/src/rules.jl
index 224b27a..c1e8304 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -925,3 +925,46 @@ end
 
 adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.rule, eta))
 adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.rule; kw...))
+
+
+"""
+    add_mixed_precision([T], tree, model) -> new_tree
+
+Add mixed precision to an existing optimiser state `tree` built for `model`.
+If `T` is not provided, `Float32` is used.
+
+Each leaf of the returned tree contains a `MixedPrecision` rule wrapping the original rule,
+and the existing optimiser state is preserved but converted to type `T`.
+"""
+add_mixed_precision(tree, model) = add_mixed_precision(Float32, tree, model)
+
+function add_mixed_precision(T, tree, model)
+  cache = IdDict()
+  tree = _add_mixed_precision(T, tree, model; cache)
+  isempty(cache) && @warn "add_mixed_precision found no trainable parameters in this model"
+  return tree
+end
+
+function _add_mixed_precision(T, tree, x; cache)
+  ch, re = functor(tree)
+  return mapvalue((ti, xi) -> _add_mixed_precision(T, ti, xi; cache), ch, _trainable(x))
+end
+
+function _add_mixed_precision(T, tree::Optimisers.Leaf, x; cache)
+  haskey(cache, tree) && return cache[tree]
+  fT(z) = z isa AbstractFloat || isnumeric(z) ? T.(z) : z
+  if !(tree.rule isa MixedPrecision{T})
+    if tree.rule isa MixedPrecision  # different type
+      rulenew = MixedPrecision(T, tree.rule.rule)
+      statenew = fmap(fT, tree.state)
+    else
+      rulenew = MixedPrecision(T, tree.rule)
+      statenew = (T.(x), fmap(fT, tree.state))
+    end
+    treenew = Leaf(rulenew, statenew, tree.frozen)
+  else
+    treenew = tree
+  end
+  cache[tree] = treenew
+  return treenew
+end
diff --git a/test/rules.jl b/test/rules.jl
index 0455792..800c9e8 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -286,3 +286,31 @@ end
 
   @test_throws ArgumentError OptimiserChain(MixedPrecision(Adam()))
 end
+
+@testset "add_mixed_precision" begin
+  d = rand(Float16, 2,2)
+  d2 = rand(Float16, 2)
+  model = Foo(Foo(d, d2), d)
+  opt_state = Optimisers.setup(AdamW(), model)
+  @test opt_state.x.x === opt_state.y
+  @test opt_state.x.y.state[1] isa Vector{Float16}
+  @test opt_state.x.y.state[2] isa Vector{Float16}
+  @test opt_state.x.y.state[3] isa Tuple{Float16, Float16}
+
+  opt_state_new = add_mixed_precision(opt_state, model)
+
+  @test opt_state_new.x.x.rule isa MixedPrecision{Float32}
+  @test opt_state_new.x.x === opt_state_new.y
+  @test opt_state_new.x.x.state[1] isa Matrix{Float32}
+  @test opt_state_new.x.x.state[1] ≈ model.x.x
+  @test opt_state_new.x.y.state[2][1] isa Vector{Float32}
+  @test opt_state_new.x.y.state[2][2] isa Vector{Float32}
+  @test opt_state_new.x.y.state[2][3] isa Tuple{Float32, Float32}
+
+  opt_state_new2 = add_mixed_precision(Float64, opt_state_new, model)
+
+  @test opt_state_new2.x.x.rule isa MixedPrecision{Float64}  # MixedPrecision{Float32} replaced
+  @test opt_state_new2.x.x.rule.rule isa AdamW  # no nesting of MixedPrecision
+  @test opt_state_new2.x.x.state[1] isa Matrix{Float64}
+  @test opt_state_new2.x.x.state[2][1] isa Matrix{Float64}
+end
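
Not part of the patch, but for reviewers: a minimal usage sketch of the new function, based only on the docstring and testset above. The `Net` struct, its field names, and the gradient values are hypothetical stand-ins for whatever Functors-compatible model is in use; `setup`, `update!`, `AdamW`, and `MixedPrecision` are the existing Optimisers.jl API.

    using Optimisers, Functors

    # Hypothetical container, standing in for any Functors-compatible model.
    struct Net; w; b; end
    Functors.@functor Net

    model = Net(rand(Float16, 2, 2), rand(Float16, 2))  # low-precision parameters

    # Ordinary setup gives Float16 optimiser state.
    opt_state = Optimisers.setup(AdamW(), model)

    # Upgrade the existing state: each Leaf's rule becomes MixedPrecision{Float32}
    # wrapping AdamW, keeping a Float32 copy of the parameters together with the
    # converted state, so accumulated momenta are not discarded.
    opt_state = add_mixed_precision(opt_state, model)             # Float32 by default
    # opt_state = add_mixed_precision(Float64, opt_state, model)  # or another type

    # Training then continues as usual.
    grads = (w = ones(Float16, 2, 2), b = ones(Float16, 2))
    opt_state, model = Optimisers.update!(opt_state, model, grads)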