
Commit e38be45

docs: add 3rd order AD example using Reactant
[skip ci]
1 parent a2167f2

3 files changed: 80 additions, 1 deletion

docs/src/.vitepress/config.mts

Lines changed: 4 additions & 0 deletions

@@ -356,6 +356,10 @@ export default defineConfig({
             {
                 text: "Profiling Lux Training Loops",
                 link: "/manual/profiling_training_loop",
+            },
+            {
+                text: "Nested AutoDiff",
+                link: "/manual/nested_autodiff_reactant",
             }
         ],
     },

docs/src/manual/nested_autodiff.md

Lines changed: 7 additions & 1 deletion

@@ -9,7 +9,13 @@ In this manual, we will explore how to use automatic differentiation (AD) inside
 or loss functions and have Lux automatically switch the AD backend with a faster one when
 needed.
 
-!!! tip
+!!! tip "Reactant Support"
+
+    Reactant + Lux natively supports Nested AD (even higher dimensions). If you are using
+    Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant)
+    manual.
+
+!!! tip "Disabling Nested AD Switching"
 
     Don't want Lux to do this switching for you? You can disable it by setting the
     `automatic_nested_ad_switching` Preference to `false`.
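
For context on the `automatic_nested_ad_switching` preference referenced in the hunk above, here is a minimal sketch (not part of this commit) of how such a preference could be set with the standard Preferences.jl workflow; the exact call is an assumption, and a Julia restart is needed before the preference takes effect:

```julia
# Sketch: disable Lux's automatic nested-AD backend switching via Preferences.jl.
# Assumes the standard `set_preferences!` API; restart Julia afterwards so the
# preference is picked up.
using Preferences, Lux

set_preferences!(Lux, "automatic_nested_ad_switching" => false; force=true)
```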
docs/src/manual/nested_autodiff_reactant.md (new file)

Lines changed: 69 additions & 0 deletions

# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).

```@example nested_ad_reactant
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

function ∇potential(potential, x)
    dxs = stack(onehot(x))
    ∇p = similar(x)
    colons = [Colon() for _ in 1:ndims(x)]
    @trace for i in 1:length(x)
        dxᵢ = dxs[colons..., i]
        res = only(Enzyme.autodiff(
            Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
        ))
        @allowscalar ∇p[i] = res[i]
    end
    return ∇p
end

function ∇²potential(potential, x)
    dxs = stack(onehot(x))
    ∇²p = similar(x)
    colons = [Colon() for _ in 1:ndims(x)]
    @trace for i in 1:length(x)
        dxᵢ = dxs[colons..., i]
        res = only(Enzyme.autodiff(
            Enzyme.set_abi(Forward, Reactant.ReactantABI),
            ∇potential, Const(potential), Duplicated(x, dxᵢ)
        ))
        @allowscalar ∇²p[i] = res[i]
    end
    return ∇²p
end

struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential}
    potential::P
end

function (potential::PotentialNet)(x, ps, st)
    pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
    return ∇²potential(pnet, x), pnet.st
end

model = PotentialNet(Dense(5 => 5, gelu))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo model(x_ra, ps, st)

@jit model(x_ra, ps, st)
```
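
As a usage note (not part of the committed file): the `cdev = cpu_device()` defined above can be used to move the compiled output back to the host. A small sketch, assuming the result of `@jit` can be destructured and transferred like a regular Lux output:

```julia
# Sketch: run the compiled model once and transfer the result to the host with `cdev`.
# Each output entry is a second derivative with respect to the matching input entry,
# so the output shape should equal the input shape.
∇²p_ra, _ = @jit model(x_ra, ps, st)
∇²p = ∇²p_ra |> cdev
size(∇²p) == size(x_ra)  # expected: true
```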
```@example nested_ad_reactant
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

function enzyme_gradient(model, x, ps, st)
    return Enzyme.gradient(
        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
    )
end

@jit enzyme_gradient(model, x_ra, ps, st)
```
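
A follow-up sketch (again not part of the committed file) for pulling the parameter gradients out of the result, assuming `Enzyme.gradient` returns one entry per argument of `sumabs2first`, with `nothing` for the `Const`-annotated ones:

```julia
# Sketch: the arguments passed after the function are (model, x, ps, st), so the
# third entry is assumed to hold the gradient with respect to `ps`.
grads = @jit enzyme_gradient(model, x_ra, ps, st)
∇ps = grads[3] |> cdev  # parameter gradients, transferred back to the host
```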
