Commit 10ea255

docs: migrate most examples to Reactant (#1180)

* docs: Basics run on CPU
* docs: Run Polynomial Fitting using Reactant
* feat: allow users to dump the HLO
* docs: update Optimization tutorial
* docs: use Reactant for CPU in SimpleChains
* docs: update PINN2DPDE
* docs: partially move HyperNet to reactant
* chore: run formatter [skip tests]
* docs: highlight Reactant more prominently
* docs: update SimpleRNN
* fix: incorrect check in Embedding
* fix: bump Enzyme in project
* feat: handle weight initializers for reactant RNGs
* fix: workaround for #1186
* fix: SimpleRNN works with reactant
* fix: failing tests and use overlay
* revert: HyperNet keep in CUDA for now

1 parent 476f3f4 · commit 10ea255

File tree

40 files changed (+362, -169 lines)

Project.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.4.4"
+version = "1.5.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -88,7 +88,7 @@ Compat = "4.16"
 ComponentArrays = "0.15.18"
 ConcreteStructs = "0.2.3"
 DispatchDoctor = "0.4.12"
-Enzyme = "0.13.16"
+Enzyme = "0.13.28"
 EnzymeCore = "0.8.8"
 FastClosures = "0.3.2"
 Flux = "0.15, 0.16"
```

README.md

Lines changed: 32 additions & 0 deletions
````diff
@@ -170,6 +170,38 @@ gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss
     (x, dev(rand(rng, Float32, 10, 2))), train_state)
 ```

+## 🤸 Quickstart with Reactant
+
+```julia
+using Lux, Random, Optimisers, Reactant, Enzyme
+
+rng = Random.default_rng()
+Random.seed!(rng, 0)
+
+model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
+
+dev = reactant_device()
+
+ps, st = Lux.setup(rng, model) |> dev
+
+x = rand(rng, Float32, 128, 2) |> dev
+
+# We need to compile the model before we can use it.
+model_forward = @compile model(x, ps, Lux.testmode(st))
+model_forward(x, ps, Lux.testmode(st))
+
+# Gradients can be computed using Enzyme
+@jit Enzyme.gradient(Reverse, sum ∘ first ∘ Lux.apply, Const(model), x, ps, Const(st))
+
+# All of this can be automated using the TrainState API
+train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
+
+gs, loss, stats, train_state = Training.single_train_step!(
+    AutoEnzyme(), MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))), train_state
+)
+```
+
 ## 📚 Examples

 Look in the [examples](/examples/) directory for self-contained usage examples. The [documentation](https://lux.csail.mit.edu) has examples sorted into proper categories.
````
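
As a follow-up to the added quickstart, the sketch below shows how the same `TrainState` API extends to a small training loop. This is not part of the commit: the target array `y`, the step count, and the `train_loop!` helper are illustrative assumptions layered on top of the names defined in the snippet above.

```julia
# Hypothetical continuation of the Reactant quickstart above; assumes `model`,
# `ps`, `st`, `x`, `dev`, and `rng` are defined exactly as in that snippet.
y = dev(rand(rng, Float32, 10, 2))  # illustrative targets for the 10-output model

function train_loop!(model, ps, st, x, y; nsteps=100)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for _ in 1:nsteps
        # Each call computes gradients with Enzyme (through Reactant) and updates
        # the parameters stored in `train_state`.
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), train_state
        )
    end
    return train_state
end

train_loop!(model, ps, st, x, y)
```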

docs/Project.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -66,6 +66,7 @@ julia = "1.10"
 [sources]
 Lux = { path = "../" }
 LuxLib = { path = "../lib/LuxLib" }
+LuxCUDA = { path = "../lib/LuxCUDA" }
 LuxCore = { path = "../lib/LuxCore" }
 MLDataDevices = { path = "../lib/MLDataDevices" }
 LuxTestUtils = { path = "../lib/LuxTestUtils" }
```

docs/make.jl

Lines changed: 6 additions & 4 deletions
```diff
@@ -1,7 +1,6 @@
 using Documenter, DocumenterVitepress, Pkg
 using Lux, LuxCore, LuxLib, WeightInitializers, NNlib
 using LuxTestUtils, MLDataDevices
-using LuxCUDA

 using Optimisers # for some docstrings

@@ -78,8 +77,10 @@ pages = [
 #! format: on

 deploy_config = Documenter.auto_detect_deploy_system()
-deploy_decision = Documenter.deploy_folder(deploy_config; repo="github.com/LuxDL/Lux.jl",
-    devbranch="main", devurl="dev", push_preview=true)
+deploy_decision = Documenter.deploy_folder(
+    deploy_config; repo="github.com/LuxDL/Lux.jl",
+    devbranch="main", devurl="dev", push_preview=true
+)

 makedocs(;
     sitename="Lux.jl Docs",
@@ -96,7 +97,8 @@ makedocs(;
     repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}",
     format=DocumenterVitepress.MarkdownVitepress(;
         repo="github.com/LuxDL/Lux.jl", devbranch="main", devurl="dev",
-        deploy_url="https://lux.csail.mit.edu", deploy_decision),
+        deploy_url="https://lux.csail.mit.edu", deploy_decision
+    ),
     draft=false,
     pages
 )
```

docs/src/index.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -23,8 +23,8 @@ hero:

 features:
   - icon: 🚀
-    title: Fast & Extendible
-    details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs.
+    title: Fast & Extendable
+    details: Lux.jl is written in Julia itself, making it extremely extendable. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs.
     link: /introduction

   - icon: 🐎
```

docs/src/introduction/index.md

Lines changed: 33 additions & 17 deletions
````diff
@@ -25,8 +25,7 @@ Pkg.add("Lux")

 ```@example quickstart
 using Lux, Random, Optimisers, Zygote
-using LuxCUDA # For CUDA support
-# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support
+# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
 ```

 We take randomness very seriously
@@ -66,26 +65,33 @@ y, st = Lux.apply(model, x, ps, st)
 train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

 ## We can compute the gradients using Training.compute_gradients
-gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state)
+gs, loss, stats, train_state = Lux.Training.compute_gradients(
+    AutoZygote(), MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))), train_state
+)

 ## Optimization
 train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

 # Both these steps can be combined into a single call
-gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state)
+gs, loss, stats, train_state = Training.single_train_step!(
+    AutoZygote(), MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))), train_state
+)
 ```

 ## Defining Custom Layers

+We can train our model using the above code, but let's go ahead and see how to use Reactant.
+Reactant is a julia frontend that generates MLIR and then compiles it using XLA (after
+running fancy optimizations). It is the current recommended way to train large models in
+Lux. For more details on using Reactant, see the [manual](@ref reactant-compilation).
+
 ```@example custom_compact
-using Lux, Random, Optimisers, Zygote
-using LuxCUDA # For CUDA support
-# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support
+using Lux, Random, Optimisers, Reactant, Enzyme
 using Printf # For pretty printing

-dev = gpu_device()
+dev = reactant_device()
 ```

 We will define a custom MLP using the `@compact` macro. The macro takes in a list of
@@ -97,10 +103,12 @@ n_in = 1
 n_out = 1
 nlayers = 3

-model = @compact(w1=Dense(n_in => 32),
+model = @compact(
+    w1=Dense(n_in => 32),
     w2=[Dense(32 => 32) for i in 1:nlayers],
     w3=Dense(32 => n_out),
-    act=relu) do x
+    act=relu
+) do x
     embed = act(w1(x))
     for w in w2
         embed = act(w(embed))
@@ -116,21 +124,24 @@ We can initialize the model and train it with the same code as before!
 rng = Random.default_rng()
 Random.seed!(rng, 0)

-ps, st = Lux.setup(Xoshiro(0), model) |> dev
+ps, st = Lux.setup(rng, model) |> dev

 x = rand(rng, Float32, n_in, 32) |> dev

-model(x, ps, st) # 1×32 Matrix and updated state as output.
+@jit model(x, ps, st) # 1×32 Matrix and updated state as output.

-x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :) |> dev
+x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
 y_data = 2 .* x_data .- x_data .^ 3
+x_data, y_data = dev(x_data), dev(y_data)

 function train_model!(model, ps, st, x_data, y_data)
     train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

     for iter in 1:1000
-        _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
-            (x_data, y_data), train_state)
+        _, loss, _, train_state = Lux.Training.single_train_step!(
+            AutoEnzyme(), MSELoss(),
+            (x_data, y_data), train_state
+        )
         if iter % 100 == 1 || iter == 1000
             @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
         end
@@ -155,6 +166,11 @@ packages mentioned in this documentation are available via the Julia General Reg

 You can install all those packages via `import Pkg; Pkg.add(<package name>)`.

+## XLA (CPU/GPU/TPU) Support
+
+Lux.jl supports XLA compilation for CPU, GPU, and TPU using
+[Reactant.jl](https://github.com/EnzymeAD/Reactant.jl).
+
 ## GPU Support

 GPU Support for Lux.jl requires loading additional packages:
````
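
For orientation, a minimal hedged sketch of how the pieces from the updated tutorial above fit together at a call site (not part of the commit; it simply reuses the `model`, `ps`, `st`, `x_data`, and `y_data` bindings defined in those hunks):

```julia
# Hypothetical driver for the tutorial code in the diff above; every name here
# comes from the preceding snippets.
@jit model(x_data, ps, st)                   # compiled forward pass via Reactant
train_model!(model, ps, st, x_data, y_data)  # 1000 Enzyme-backed training steps
```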

docs/src/manual/compiling_lux_models.md

Lines changed: 10 additions & 0 deletions
```diff
@@ -124,6 +124,16 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme |> cpu_device())

 ## [Using the `TrainState` API](@id compile_lux_model_trainstate)

+!!! tip "Debugging TrainState API Failures"
+
+    If the code fails to compile with Reactant, it is useful to dump the HLO. Starting the
+    Julia session with the `LUX_DUMP_REACTANT_HLO_OPTIMIZE` environment variable set to
+    `no_enzyme`, `false`, or `true` will dump the HLO to a file (the filename will be
+    displayed). This is useful information to provide when opening an issue.
+
+    Alternatively, you can set the global reference `Lux.DUMP_REACTANT_HLO_OPT_MODE` to a
+    symbol corresponding to the `optimize` keyword argument to `@code_hlo`.
+
 Now that we saw the low-level API let's see how to train the model without any of this
 boilerplate. Simply follow the following steps:

```
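
A hedged sketch of the two debugging paths described in the tip above. The shell invocation mirrors the environment-variable option; treating `Lux.DUMP_REACTANT_HLO_OPT_MODE` as a `Ref` assigned via `[]` is an assumption of this sketch, not something the tip spells out.

```julia
# Option 1 (per the tip above): start Julia with the environment variable set,
# e.g. from the shell:
#   LUX_DUMP_REACTANT_HLO_OPTIMIZE=no_enzyme julia --project train.jl

# Option 2: set the global reference from inside the session. The `[]=` syntax
# below assumes the reference is a `Ref`; the `:no_enzyme` value mirrors the
# env-var options listed above.
using Lux

Lux.DUMP_REACTANT_HLO_OPT_MODE[] = :no_enzyme
```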

docs/src/manual/gpu_management.md

Lines changed: 6 additions & 7 deletions
````diff
@@ -1,12 +1,5 @@
 # GPU Management

-!!! info
-
-    Starting from `v0.5`, Lux has transitioned to a new GPU management system. The old
-    system using `cpu` and `gpu` functions is still in place but will be removed in `v1`.
-    Using the old functions might lead to performance regressions if used inside
-    performance critical code.
-
 `Lux.jl` can handle multiple GPU backends. Currently, the following backends are supported:

 ```@example gpu_management
@@ -16,6 +9,12 @@ using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
 supported_gpu_backends()
 ```

+!!! tip "GPU Support via Reactant"
+
+    If you are using Reactant, you can use the [`reactant_device`](@ref) function to
+    automatically select the Reactant backend if available. Additionally, to force Reactant
+    to use `gpu`, you can run `Reactant.set_default_backend("gpu")` (this is automatic).
+
 !!! danger "Metal Support"

     Support for Metal GPUs should be considered extremely experimental at this point.
````
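
To make the new tip concrete, here is a minimal hedged sketch of selecting the Reactant backend and moving a model onto it. The `Dense` layer, RNG, and input sizes are illustrative and not part of the commit; only `reactant_device` and `Reactant.set_default_backend("gpu")` come from the tip itself.

```julia
# Minimal sketch of the Reactant device-selection flow described in the tip above.
using Lux, Random, Reactant

Reactant.set_default_backend("gpu")  # force the GPU backend; normally automatic
dev = reactant_device()              # selects the Reactant device when available

# Illustrative model and data, just to show movement onto the device:
model = Dense(2 => 4, relu)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = dev(rand(Float32, 2, 8))

y, st = @jit model(x, ps, st)  # compile-and-run the forward pass with Reactant
```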

docs/tutorials.jl

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,10 +1,11 @@
 #! format: off
 const BEGINNER_TUTORIALS = [
-    "Basics/main.jl" => "CUDA",
+    "Basics/main.jl" => "CPU",
     "PolynomialFitting/main.jl" => "CUDA",
     "SimpleRNN/main.jl" => "CUDA",
+    # Technically this is run on CPU but we need a better machine to run it
     "SimpleChains/main.jl" => "CUDA",
-    "OptimizationIntegration/main.jl" => "CUDA",
+    "OptimizationIntegration/main.jl" => "CPU",
 ]
 const INTERMEDIATE_TUTORIALS = [
     "NeuralODE/main.jl" => "CUDA",
```

examples/Basics/Project.toml

Lines changed: 0 additions & 2 deletions
```diff
@@ -2,7 +2,6 @@
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -12,6 +11,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 ComponentArrays = "0.15.18"
 ForwardDiff = "0.10"
 Lux = "1"
-LuxCUDA = "0.3"
 Optimisers = "0.4.1"
 Zygote = "0.6"
```
