18 changes: 9 additions & 9 deletions src/optimise/train.jl
@@ -81,29 +81,29 @@ batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
step!(loss, params, opt)
optimstep!(loss, params, opt)
Author

I suggest `optimstep!` instead of `trainstep!` to indicate that this is the optimiser interface, and to keep the ML jargon to a minimum.

Member

@mcabbott mcabbott Mar 20, 2022

One vote for something evoking train! to stress that they are closely related.

If the longer-term plan is to use Optimisers.jl, this may not fit with `train!` at all -- some recent discussion here: #1902 (comment). In which case there will be an implicit-style `train!` & `Params` story, and an explicit-style `gradient` and `Optimisers.update!`. With such a divide, this function wants to be clearly on the `train!` & `Params` side.

Maybe it should just be 3-arg train!? Without a data iterator, there is no iteration, that's all:

```julia
train!(loss, ::Params, data, ::AbstractOptimiser)  # calls loss(d...) for d in data
train!(loss, ::Params, ::AbstractOptimiser)        # calls loss() since there is no data
```


`step!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`)
`optimstep!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`)
based on a pluggable [Optimisers](@ref) (`opt`). It represents a single step in
the training loop `train!`.

The default implementation for `step!` takes the gradient of `loss`
The default implementation for `optimstep!` takes the gradient of `loss`
and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload
`step!` for specific types of `opt`. This can be useful if your optimisation routine
`optimstep!` for specific types of `opt`. This can be useful if your optimisation routine
does not follow the standard gradient descent procedure (e.g. gradient-free optimisers).

Unlike `train!`, the loss function of `step!` accepts no input.
Instead, `train!` cycles through the data in a loop and calls `step!`:
Unlike `train!`, the loss function of `optimstep!` accepts no input.
Instead, `train!` cycles through the data in a loop and calls `optimstep!`:
```julia
for d in data
    step!(ps, opt) do
    optimstep!(ps, opt) do
        loss(d)
    end
end
```
If you are writing [Custom Training loops](@ref), then you should follow this pattern.
"""
function step!(loss, params, opt)
function optimstep!(loss, params, opt)
    val, gs = withgradient(loss, params)
    update!(opt, params, gs)
    return val, gs
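
The docstring's mention of overloading for gradient-free optimisers could be illustrated with a toy sketch. Everything below is invented for illustration: `RandomSearch` is a hypothetical optimiser type, and a plain vector of arrays stands in for Flux's `Params`; this is not Flux API.

```julia
# Hypothetical sketch: overloading `optimstep!` for a gradient-free optimiser.
# `RandomSearch` and this `optimstep!` method are invented for illustration.

struct RandomSearch
    stepsize::Float64
end

# A vector of parameter arrays stands in for Flux's `Params` here.
function optimstep!(loss, params::Vector{<:AbstractArray}, opt::RandomSearch)
    before = loss()
    saved = [copy(p) for p in params]
    # Propose a random perturbation of every parameter array.
    for p in params
        p .+= opt.stepsize .* randn(size(p)...)
    end
    after = loss()
    if after > before            # revert the move unless the loss improved
        for (p, s) in zip(params, saved)
            p .= s
        end
        after = before
    end
    return after, nothing        # no gradients for a gradient-free method
end

# Usage: minimise (x - 3)^2 for a single scalar parameter.
x = [0.0]
sqloss() = (x[1] - 3.0)^2
for _ in 1:5000
    optimstep!(sqloss, [x], RandomSearch(0.1))
end
```

The point of the pattern is that dispatch on the `opt` type replaces the `withgradient`/`update!` default while keeping the same `(loss, params, opt)` calling convention, so `train!` needs no changes.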
@@ -135,7 +135,7 @@ function train!(loss, ps, data, opt; cb = () -> ())
    cb = runall(cb)
    @progress for d in data
        try
            step!(ps, opt) do
            optimstep!(ps, opt) do
                loss(batchmemaybe(d)...)
            end
            cb()
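
The `batchmemaybe` helper at the top of this diff is what lets the loop above pass both plain batches and pre-grouped tuples to the loss. A self-contained sketch of that behaviour (the two `loss` methods and the sample batches are invented for illustration):

```julia
# `batchmemaybe` (reproduced from the top of this diff) wraps a lone batch
# in a tuple so that splatting into the loss always works.
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

loss(a) = sum(a)               # one-argument loss
loss(a, b) = sum(a) + sum(b)   # two-argument loss

d1 = [1, 2, 3]                 # a single-argument batch
d2 = ([1, 2], [3, 4])          # a batch already grouped as a tuple

v1 = loss(batchmemaybe(d1)...)   # d1 is wrapped, so this calls loss([1, 2, 3])
v2 = loss(batchmemaybe(d2)...)   # d2 splats as-is: loss([1, 2], [3, 4])
```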