Commit 15bddaf

docs: add 3rd order AD example using Reactant
1 parent 5b24b62 commit 15bddaf

File tree

- docs/src/.vitepress/config.mts
- docs/src/manual/nested_autodiff.md
- docs/src/manual/nested_autodiff_reactant.md

3 files changed: +84 -1 lines changed


docs/src/.vitepress/config.mts

Lines changed: 4 additions & 0 deletions

@@ -356,6 +356,10 @@ export default defineConfig({
             {
                 text: "Profiling Lux Training Loops",
                 link: "/manual/profiling_training_loop",
+            },
+            {
+                text: "Nested AutoDiff",
+                link: "/manual/nested_autodiff_reactant",
             }
         ],
     },

docs/src/manual/nested_autodiff.md

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 !!! note "Reactant"
 
     Reactant.jl natively supports nested AD (with orders greater than 2nd order). For more
-    robust nested AD, use Lux with Reactant.jl.
+    robust nested AD, use [Lux with Reactant.jl](@ref nested_autodiff_reactant).
 
 In this manual, we will explore how to use automatic differentiation (AD) inside your layers
 or loss functions and have Lux automatically switch the AD backend with a faster one when
docs/src/manual/nested_autodiff_reactant.md

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

We will use the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614),
which requires third-order automatic differentiation.

```@example nested_ad_reactant
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# Build an `N × B × N × B` array of one-hot tangents, one for every entry of `x`.
function stacked_onehot(x::AbstractMatrix{T}) where {T}
    onehot_matrix = Reactant.promote_to(
        Reactant.TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
    )
    return Reactant.materialize_traced_array(reshape(onehot_matrix, size(x)..., size(x)...))
end

# Elementwise first derivative: push each one-hot tangent through `potential` with
# forward-mode AD and keep only the matching output entry.
function ∇potential(potential, x::AbstractMatrix)
    N, B = size(x)
    dxs = stacked_onehot(x)
    ∇p = similar(x)
    @trace for i in 1:B
        @trace for j in 1:N
            dxᵢ = dxs[:, :, j, i]
            res = only(Enzyme.autodiff(Forward, potential, Duplicated(x, dxᵢ)))
            @allowscalar ∇p[j, i] = res[j, i]
        end
    end
    return ∇p
end

# Elementwise second derivative: differentiate `∇potential` the same way.
function ∇²potential(potential, x::AbstractMatrix)
    N, B = size(x)
    dxs = stacked_onehot(x)
    ∇²p = similar(x)
    @trace for i in 1:B
        @trace for j in 1:N
            dxᵢ = dxs[:, :, j, i]
            res = only(Enzyme.autodiff(
                Forward, ∇potential, Const(potential), Duplicated(x, dxᵢ)
            ))
            @allowscalar ∇²p[j, i] = res[j, i]
        end
    end
    return ∇²p
end
```
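
In words, each forward-mode `Enzyme.autodiff` call above pushes a one-hot tangent through `potential` and keeps only the matching output entry, so for an `N × B` input the two helpers return the elementwise first and second derivatives:

```math
(\nabla p)_{j,i} = \frac{\partial\, \mathrm{potential}(x)_{j,i}}{\partial x_{j,i}}, \qquad
(\nabla^2 p)_{j,i} = \frac{\partial^2\, \mathrm{potential}(x)_{j,i}}{\partial x_{j,i}^2},
\qquad j = 1, \dots, N, \quad i = 1, \dots, B.
```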

```@example nested_ad_reactant
struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential}
    potential::P
end

function (potential::PotentialNet)(x, ps, st)
    # `StatefulLuxLayer` closes over `ps` and `st`, so the network can be called as
    # `pnet(x)` inside the derivative helpers above.
    pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
    return ∇²potential(pnet, x), pnet.st
end
```

```@example nested_ad_reactant
model = PotentialNet(Dense(5 => 5, gelu))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

x_ra = randn(Float32, 5, 1024) |> xdev

@jit model(x_ra, ps, st)
```
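
The `cdev` defined above can move results back to the host; a minimal sketch with illustrative variable names:

```julia
# Sketch: run the compiled model and bring the result back to the CPU.
∇²p, st_ = @jit model(x_ra, ps, st)   # elementwise second derivatives, shape (5, 1024)
∇²p_cpu = cdev(∇²p)
size(∇²p_cpu)
```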

```@example nested_ad_reactant
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

function enzyme_gradient(model, x, ps, st)
    return Enzyme.gradient(
        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
    )
end

@jit enzyme_gradient(model, x_ra, ps, st)
```
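
Assuming Enzyme's convention that `Const` arguments come back as `nothing` in the returned tuple, the parameter gradient can be extracted as follows (a sketch, with illustrative names):

```julia
# Sketch: `grads` lines up with (model, x, ps, st); only the `ps` slot is non-`nothing`.
grads = @jit enzyme_gradient(model, x_ra, ps, st)
∂ps = grads[3]
```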
