Skip to content

Commit 2c40b8f

Browse files
committed
docs: add 3rd order AD example using Reactant
[skip ci]
1 parent 16031b6 commit 2c40b8f

File tree

4 files changed

+82
-7
lines changed

4 files changed

+82
-7
lines changed

docs/make.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ pages = [
4848
"manual/distributed_utils.md",
4949
"manual/nested_autodiff.md",
5050
"manual/compiling_lux_models.md",
51+
"manual/exporting_to_jax.md",
52+
"manual/nested_autodiff_reactant.md"
5153
],
5254
"API Reference" => [
5355
"Lux" => [

docs/src/.vitepress/config.mts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ export default defineConfig({
346346
text: "Exporting Lux Models to Jax",
347347
link: "/manual/exporting_to_jax",
348348
},
349+
{
350+
text: "Nested AutoDiff",
351+
link: "/manual/nested_autodiff_reactant",
352+
}
349353
],
350354
},
351355
{

docs/src/manual/nested_autodiff.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# [Nested Automatic Differentiation](@id nested_autodiff)
22

3-
!!! note
4-
5-
This is a relatively new feature in Lux, so there might be some rough edges. If you
6-
encounter any issues, please let us know by opening an issue on the
7-
[GitHub repository](https://github.com/LuxDL/Lux.jl).
8-
93
In this manual, we will explore how to use automatic differentiation (AD) inside your layers
104
or loss functions and have Lux automatically switch the AD backend with a faster one when
115
needed.
126

13-
!!! tip
7+
!!! tip "Reactant Support"
8+
9+
Reactant + Lux natively supports Nested AD (even higher dimensions). If you are using
10+
Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant)
11+
manual.
12+
13+
!!! tip "Disabling Nested AD Switching"
1414

1515
Don't wan't Lux to do this switching for you? You can disable it by setting the
1616
`automatic_nested_ad_switching` Preference to `false`.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)
2+
3+
We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).
4+
5+
```@example nested_ad_reactant
6+
using Reactant, Enzyme, Lux, Random, LinearAlgebra
7+
8+
const xdev = reactant_device(; force=true)
9+
const cdev = cpu_device()
10+
11+
function ∇potential(potential, x)
12+
dxs = stack(onehot(x))
13+
∇p = similar(x)
14+
colons = [Colon() for _ in 1:ndims(x)]
15+
@trace for i in 1:length(x)
16+
dxᵢ = dxs[colons..., i]
17+
res = only(Enzyme.autodiff(
18+
Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
19+
))
20+
@allowscalar ∇p[i] = res[i]
21+
end
22+
return ∇p
23+
end
24+
25+
function ∇²potential(potential, x)
26+
dxs = stack(onehot(x))
27+
∇²p = similar(x)
28+
colons = [Colon() for _ in 1:ndims(x)]
29+
@trace for i in 1:length(x)
30+
dxᵢ = dxs[colons..., i]
31+
res = only(Enzyme.autodiff(
32+
Enzyme.set_abi(Forward, Reactant.ReactantABI),
33+
∇potential, Const(potential), Duplicated(x, dxᵢ)
34+
))
35+
@allowscalar ∇²p[i] = res[i]
36+
end
37+
return ∇²p
38+
end
39+
40+
struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential}
41+
potential::P
42+
end
43+
44+
function (potential::PotentialNet)(x, ps, st)
45+
pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
46+
return ∇²potential(pnet, x), pnet.st
47+
end
48+
49+
model = PotentialNet(Dense(5 => 5, gelu))
50+
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
51+
52+
x_ra = randn(Float32, 5, 3) |> xdev
53+
54+
@code_hlo model(x_ra, ps, st)
55+
56+
@jit model(x_ra, ps, st)
57+
```
58+
59+
```@example nested_ad_reactant
60+
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
61+
62+
function enzyme_gradient(model, x, ps, st)
63+
return Enzyme.gradient(
64+
Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
65+
)
66+
end
67+
68+
@jit enzyme_gradient(model, x_ra, ps, st)
69+
```

0 commit comments

Comments
 (0)