24 changes: 19 additions & 5 deletions docs/src/index.md
@@ -1,5 +1,23 @@
# Optimisers.jl

Optimisers.jl defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models.

This was written as the new training system for [Flux.jl](https://github.com/FluxML/Flux.jl) neural networks,
and is also used by [Lux.jl](https://github.com/LuxDL/Lux.jl).
But it can be used separately on any array, or anything else understood by [Functors.jl](https://github.com/FluxML/Functors.jl).
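
For instance, the rules can be applied to a plain array with no model framework at all. A minimal sketch (the parameter values and gradient below are made up purely for illustration):

```julia
using Optimisers

x = [1.0, 2.0, 3.0]    # parameters to optimise (any array works)
dx = [0.1, 0.1, 0.1]   # a gradient, written by hand here for illustration

state = Optimisers.setup(Optimisers.Adam(0.01), x)  # initialise optimiser state for x
state, x = Optimisers.update(state, x, dx)          # one optimisation step
```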

## Installation

In the Julia REPL, type
```julia
]add Optimisers
```

or
```julia-repl
julia> import Pkg; Pkg.add("Optimisers")
```

## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref Optimisers.apply!) and [`init`](@ref Optimisers.init).
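
As an illustrative sketch (not the package's own example; `SimpleDescent` is a made-up name), a fixed-step descent rule could be written like this:

```julia
using Optimisers

# Hypothetical rule, for illustration only: plain descent with a fixed learning rate.
struct SimpleDescent <: Optimisers.AbstractRule
    eta::Float64
end

# This rule needs no per-array state.
Optimisers.init(o::SimpleDescent, x::AbstractArray) = nothing

# `apply!` receives the state, the array `x`, and its gradient `dx`, and returns
# the new state plus the adjusted gradient, which `update` then subtracts from `x`.
function Optimisers.apply!(o::SimpleDescent, state, x, dx)
    return state, o.eta .* dx
end
```

Such a rule is then used like any built-in one, e.g. `Optimisers.setup(SimpleDescent(0.1), model)`.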
@@ -38,7 +56,6 @@ state for every trainable array. Then at each step, [`update`](@ref Optimisers.u
to adjust the model:

```julia

using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu # define a model to train
@@ -54,7 +71,6 @@ end;

state_tree, model = Optimisers.update(state_tree, model, ∇model);
@show sum(model(image)); # reduced

```

Notice that a completely new instance of the model is returned. Internally, this
@@ -91,7 +107,6 @@ Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
identical trees of nested `NamedTuple`s.)

```julia

using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
@@ -113,7 +128,6 @@ opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y); # now reduced

```

Besides the parameters stored in `params` and gradually optimised, any other model state
@@ -297,7 +311,7 @@ similarly to what [`destructure`](@ref Optimisers.destructure) does but without
concatenating the arrays into a flat vector.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

```julia
```julia-repl
julia> using Flux, Optimisers
julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));
2 changes: 1 addition & 1 deletion src/destructure.jl
@@ -38,7 +38,7 @@ This is what [`destructure`](@ref Optimisers.destructure) returns, and `re(p)` w
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
# Example
```julia
```julia-repl
julia> using Flux, Optimisers
julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid))
6 changes: 3 additions & 3 deletions src/trainables.jl
@@ -32,10 +32,10 @@ julia> trainables(x)
1-element Vector{AbstractArray}:
[1.0, 2.0, 3.0]
julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
```