@@ -636,10 +636,12 @@ function _norm(dx::Broadcast.Broadcasted, p::Real)
636636end
637637
638638"""
639- OptimiserChain(opts...)
639+ OptimiserChain(o1, o2, o34...)
640+ o1 => o2 => o3
640641
641- Compose a sequence of optimisers so that each `opt` in `opts `
642+ Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...) `
642643updates the gradient, in the order specified.
644+ May be entered using `Pair` syntax with several `AbstractRule`s.
643645
644646With an empty sequence, `OptimiserChain()` is the identity,
645647so `update!` will subtract the full gradient from the parameters.
@@ -648,12 +650,13 @@ This is equivalent to `Descent(1)`.
648650# Example
649651
650652```jldoctest
651- julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
653+ julia> o = ClipGrad(1.0) => Descent(0.1)
654+ OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
652655
653656julia> m = (zeros(3),);
654657
655658julia> s = Optimisers.setup(o, m)
656- (Leaf(OptimiserChain( ClipGrad(1.0), Descent(0.1) ), (nothing, nothing)),)
659+ (Leaf(ClipGrad(1.0) => Descent(0.1), (nothing, nothing)),)
657660
658661julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
659662([-0.03, -0.1, -0.1],)
@@ -664,6 +667,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
664667end
665668OptimiserChain(opts... ) = OptimiserChain(opts)
666669
670+ Base. Pair(a:: AbstractRule , b:: AbstractRule ) = OptimiserChain(a, b)
671+ Base. Pair(a:: AbstractRule , bc:: OptimiserChain ) = OptimiserChain(a, bc. opts... )
672+
667673@functor OptimiserChain
668674
669675init(o:: OptimiserChain , x:: AbstractArray ) = map(opt -> init(opt, x), o. opts)
@@ -679,7 +685,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
679685 end
680686end
681687
682- function Base. show(io:: IO , c:: OptimiserChain )
688+ function Base. show(io:: IO , c:: OptimiserChain ) # compact show
689+ if length(c. opts) > 1
690+ join(io, c. opts, " => " )
691+ else
692+ show(io, MIME" text/plain" (), c)
693+ end
694+ end
695+ function Base. show(io:: IO , :: MIME"text/plain" , c:: OptimiserChain )
683696 print(io, " OptimiserChain(" )
684697 join(io, c. opts, " , " )
685698 print(io, " )" )
0 commit comments