@@ -9,9 +9,10 @@ Invalid TrainState construction using a compiled function.
99
1010`TrainState` is being constructed with a reactant compiled function, i.e. a
1111`Reactant.Compiler.Thunk`. This is likely a mistake as the model should be
12- passed in directly without being compiled first.
12+ passed in directly without being compiled first. When `single_train_step` or other
13+ functions are called on the `TrainState`, the model will be compiled automatically.
1314
14- This is likely originating from the following style of usage:
15+ The correct usage is :
1516
1617```julia
1718using Lux, Reactant, Random, Optimisers
@@ -22,17 +23,25 @@ model = Dense(10, 10)
2223ps, st = Lux.setup(Random.default_rng(), model) |> rdev
2324x = rand(10) |> rdev
2425
25- model_compiled = @compile model(x, ps, st)
26-
27- train_state = Training.TrainState(model_compiled, ps, st, Adam())
26+ train_state = TrainState(model, ps, st, Adam())
2827```
2928
30- Instead avoid compiling the model and pass it directly to `TrainState`. When
31- `single_train_step` or other functions are called on the `TrainState`, the
32- model will be compiled automatically.
29+ The error originates because the model is being compiled first, which is not
30+ supported. **The following is the incorrect way, which potentially causes this
31+ error.**
3332
3433```julia
35- train_state = Training.TrainState(model, ps, st, Adam())
34+ using Lux, Reactant, Random, Optimisers
35+
36+ rdev = reactant_device()
37+
38+ model = Dense(10, 10)
39+ ps, st = Lux.setup(Random.default_rng(), model)
40+ x = rand(10) |> rdev
41+
42+ model_compiled = @compile model(x, ps, st)
43+
44+ train_state = Training.TrainState(model_compiled, ps, st, Adam())
3645```
3746
3847For end-to-end usage example refer to the documentation:
0 commit comments