@@ -621,23 +621,26 @@ function apply!(o::ClipNorm, state, x, dx)
621621end
622622
623623"""
624-     OptimiserChain(opts...)
624+     OptimiserChain(o1, o2, o34...)
625+     o1 => o2 => o3
625626
626- Compose a sequence of optimisers so that each `opt` in `opts`
627+ Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
627628updates the gradient, in the order specified.
629+ May also be written with `=>` (`Pair`) syntax between two or more `AbstractRule`s.
628630
629631With an empty sequence, `OptimiserChain()` is the identity,
630632so `update!` will subtract the full gradient from the parameters.
631633This is equivalent to `Descent(1)`.
632634
633635# Example
634636```jldoctest
635- julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
637+ julia> o = ClipGrad(1.0) => Descent(0.1)
638+ OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
636639
637640julia> m = (zeros(3),);
638641
639642julia> s = Optimisers.setup(o, m)
640- (Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), (nothing, nothing)),)
643+ (Leaf(ClipGrad{Float64}(1.0) => Descent{Float64}(0.1), (nothing, nothing)),)
641644
642645julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
643646([-0.03, -0.1, -0.1],)
@@ -648,6 +651,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
648651end
649652OptimiserChain(opts...) = OptimiserChain(opts)
650653
654+ Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
655+ Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
656+
651657@functor OptimiserChain
652658
653659init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
@@ -659,7 +665,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
659665 end
660666end
661667
662- function Base.show(io::IO, c::OptimiserChain)
668+ function Base.show(io::IO, c::OptimiserChain)  # compact show
669+   if length(c.opts) > 1
670+     join(io, c.opts, " => ")
671+   else
672+     show(io, MIME"text/plain"(), c)
673+   end
674+ end
675+ function Base.show(io::IO, ::MIME"text/plain", c::OptimiserChain)
663676   print(io, "OptimiserChain(")
664677   join(io, c.opts, ", ")
665678   print(io, ")")
0 commit comments