Skip to content

Commit a71f40a

Browse files
committed
docs: add 3rd order AD example using Reactant
1 parent 219eeba commit a71f40a

File tree

4 files changed

+79
-7
lines changed

4 files changed

+79
-7
lines changed

docs/make.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ pages = [
4747
"manual/distributed_utils.md",
4848
"manual/nested_autodiff.md",
4949
"manual/compiling_lux_models.md",
50+
"manual/exporting_to_jax.md",
51+
"manual/nested_autodiff_reactant.md"
5052
],
5153
"API Reference" => [
5254
"Lux" => [

docs/src/.vitepress/config.mts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,10 @@ export default defineConfig({
314314
text: "Exporting Lux Models to Jax",
315315
link: "/manual/exporting_to_jax",
316316
},
317+
{
318+
text: "Nested AutoDiff",
319+
link: "/manual/nested_autodiff_reactant",
320+
}
317321
],
318322
},
319323
{

docs/src/manual/nested_autodiff.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# [Nested Automatic Differentiation](@id nested_autodiff)
22

3-
!!! note
4-
5-
This is a relatively new feature in Lux, so there might be some rough edges. If you
6-
encounter any issues, please let us know by opening an issue on the
7-
[GitHub repository](https://github.com/LuxDL/Lux.jl).
8-
93
In this manual, we will explore how to use automatic differentiation (AD) inside your layers
104
or loss functions and have Lux automatically switch the AD backend with a faster one when
115
needed.
126

13-
!!! tip
7+
!!! tip "Reactant Support"
8+
9+
Reactant + Lux natively supports Nested AD (even higher orders). If you are using
10+
Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant)
11+
manual.
12+
13+
!!! tip "Disabling Nested AD Switching"
1414

1515
Don't want Lux to do this switching for you? You can disable it by setting the
1616
`automatic_nested_ad_switching` Preference to `false`.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)
2+
3+
We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).
4+
5+
```@example nested_ad_reactant
6+
using Reactant, Enzyme, Lux, Random, LinearAlgebra
7+
8+
# Device handles. `xdev` is the one the parameters, states and input are piped through
# later in this example; `cdev` is defined but not used anywhere else in this code block.
const xdev = reactant_device()
9+
const cdev = cpu_device()
10+
11+
# XXX: We need to be able to compile this with a for-loop else tracing time will scale
12+
# proportionally to the number of elements in the input.
13+
# Loop over the seeds in `onehot(x)`; for each index `i`, run a forward-mode
# `Enzyme.autodiff` of `potential` with `x` paired with the seed `dxs[i]`, and copy entry
# `i` of the result into `∇p[i]`. Returns `∇p`, which is allocated with `similar(x)`.
# NOTE(review): only entry `i` of each result is kept, so this looks like an element-wise
# derivative (Jacobian diagonal) rather than a full Jacobian — confirm this is intended.
function ∇potential(potential, x)
14+
# presumably one seed array per element of `x` — verify against Enzyme's `onehot` docs
dxs = onehot(x)
15+
∇p = similar(x)
16+
for i in eachindex(dxs)
17+
dxᵢ = dxs[i]
18+
# `only` unwraps the single value returned by the forward-mode call
res = only(Enzyme.autodiff(
19+
Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
20+
))
21+
# scalar read/write of the arrays is explicitly permitted here via `@allowscalar`
@allowscalar ∇p[i] = res[i]
22+
end
23+
return ∇p
24+
end
25+
26+
# Differentiate `∇potential` once more in forward mode. For each seed `dxs[i]` from
# `onehot(x)`, `potential` is passed as `Const` and `x` as `Duplicated(x, dxs[i])`; entry
# `i` of the result is stored in `∇²p[i]`. Returns `∇²p`, allocated with `similar(x)`.
function ∇²potential(potential, x)
27+
dxs = onehot(x)
28+
∇²p = similar(x)
29+
for i in eachindex(dxs)
30+
dxᵢ = dxs[i]
31+
# `only` unwraps the single value returned by the forward-mode call
res = only(Enzyme.autodiff(
32+
Enzyme.set_abi(Forward, Reactant.ReactantABI),
33+
∇potential, Const(potential), Duplicated(x, dxᵢ)
34+
))
35+
# scalar read/write of the arrays is explicitly permitted here via `@allowscalar`
@allowscalar ∇²p[i] = res[i]
36+
end
37+
return ∇²p
38+
end
39+
40+
# Layer type with a single field, `potential`, holding the inner layer. It subtypes
# `Lux.AbstractLuxWrapperLayer{:potential}`; the `:potential` type parameter matches the
# field name.
struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential}
41+
potential::P
42+
end
43+
44+
# Forward pass of `PotentialNet`: bundle the inner layer with `ps` and `st` into a
# `StatefulLuxLayer{true}`, pass that object to `∇²potential` together with `x`, and
# return the tuple `(∇²potential(pnet, x), pnet.st)`.
function (potential::PotentialNet)(x, ps, st)
45+
pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
46+
return ∇²potential(pnet, x), pnet.st
47+
end
48+
49+
# Wrap a `Dense(5 => 5, gelu)` layer in `PotentialNet`.
model = PotentialNet(Dense(5 => 5, gelu))
49+
# Set up parameters and states with the default RNG, then pipe them through `xdev`.
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
50+
51+
# Random `Float32` input of size 5×3, also piped through `xdev`.
x_ra = randn(Float32, 5, 3) |> xdev
52+
53+
# Compile the forward pass for these arguments with `@compile`, then call the result.
model_compiled = @compile model(x_ra, ps, st)
54+
model_compiled(x_ra, ps, st)
56+
57+
# Loss used for the gradient below: `sum(abs2, ...)` over the first element of the value
# returned by `model(x, ps, st)`.
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
58+
59+
# Reverse-mode `Enzyme.gradient` of `sumabs2first`. `model`, `x` and `st` are wrapped in
# `Const`; `ps` is the only argument passed without a `Const` wrapper.
function enzyme_gradient(model, x, ps, st)
60+
return Enzyme.gradient(
61+
Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
62+
)
63+
end
64+
65+
# Invoke `enzyme_gradient` on the `xdev` parameters, states and input through `@jit`.
@jit enzyme_gradient(model, x_ra, ps, st)
66+
```

0 commit comments

Comments
 (0)