Commit 15e62d3

fix: update how the error message looks (#1553)
1 parent 8c4405c commit 15e62d3

2 files changed: +30 -20 lines changed

ext/LuxReactantExt/training.jl

Lines changed: 18 additions & 9 deletions
@@ -9,9 +9,10 @@ Invalid TrainState construction using a compiled function.
 
 `TrainState` is being constructed with a reactant compiled function, i.e. a
 `Reactant.Compiler.Thunk`. This is likely a mistake as the model should be
-passed in directly without being compiled first.
+passed in directly without being compiled first. When `single_train_step` or other
+functions are called on the `TrainState`, the model will be compiled automatically.
 
-This is likely originating from the following style of usage:
+The correct usage is:
 
 ```julia
 using Lux, Reactant, Random, Optimisers
@@ -22,17 +23,25 @@ model = Dense(10, 10)
 ps, st = Lux.setup(Random.default_rng(), model) |> rdev
 x = rand(10) |> rdev
 
-model_compiled = @compile model(x, ps, st)
-
-train_state = Training.TrainState(model_compiled, ps, st, Adam())
+train_state = TrainState(model, ps, st, Adam())
 ```
 
-Instead avoid compiling the model and pass it directly to `TrainState`. When
-`single_train_step` or other functions are called on the `TrainState`, the
-model will be compiled automatically.
+The error originates because the model is being compiled first, which is not
+supported. **The following is the incorrect way, which potentially causes this
+error.**
 
 ```julia
-train_state = Training.TrainState(model, ps, st, Adam())
+using Lux, Reactant, Random, Optimisers
+
+rdev = reactant_device()
+
+model = Dense(10, 10)
+ps, st = Lux.setup(Random.default_rng(), model)
+x = rand(10) |> rdev
+
+model_compiled = @compile model(x, ps, st)
+
+train_state = Training.TrainState(model_compiled, ps, st, Adam())
 ```
 
 For end-to-end usage example refer to the documentation:

src/helpers/training.jl

Lines changed: 12 additions & 11 deletions
@@ -69,30 +69,31 @@ end
 
 function Adapt.adapt_structure(to::ReactantDevice, ts::TrainState)
 @warn """
-Moving `TrainState` to `ReactantDevice` might lead to unwanted behaviour. This
-potentially originates from the following style of usage:
+Moving `TrainState` to `ReactantDevice` might lead to unwanted behaviour.
+
+Move the `ps` and `st` to the device before constructing the `TrainState`.
+This ensures the optimizer state and other internal states are on the device on
+construction. Prefer using the following style:
 
 ```julia
 rdev = reactant_device()
 
-ps, st = Lux.setup(rng, model)
+ps, st = Lux.setup(rng, model) |> rdev
 train_state = TrainState(model, ps, st, opt)
-train_state = train_state |> rdev
 ```
 
-Specifically, `ps` and `st` we on the host device when `train_state` is being
-constructed and later `train_state` is moved to the device. Instead it is recommended
-to do the following:
+This warning potentially originates from having `ps` and `st` on the host when
+constructing the `TrainState`, and later moving the `TrainState` to the device.
+**The following is the incorrect way, which potentially causes this warning to
+appear.**
 
 ```julia
 rdev = reactant_device()
 
-ps, st = Lux.setup(rng, model) |> rdev
+ps, st = Lux.setup(rng, model)
 train_state = TrainState(model, ps, st, opt)
+train_state = train_state |> rdev
 ```
-
-This ensures the optimizer state and other internal states are on the device on
-construction.
 """
 return @invoke Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
 end
