Skip to content

Commit df7b809

Browse files
committed
docs: add 3rd order AD example using Reactant
[skip ci] docs: add stub code
1 parent 420dce3 commit df7b809

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

docs/src/.vitepress/config.mts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,10 @@ export default defineConfig({
356356
{
357357
text: "Profiling Lux Training Loops",
358358
link: "/manual/profiling_training_loop",
359+
},
360+
{
361+
text: "Nested AutoDiff",
362+
link: "/manual/nested_autodiff_reactant",
359363
}
360364
],
361365
},
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)
2+
3+
We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).
4+
5+
```@example nested_ad_reactant
6+
using Reactant, Enzyme, Lux, Random, LinearAlgebra
7+
8+
# Reactant (XLA) device; `force=true` errors out instead of silently falling back to CPU.
const xdev = reactant_device(; force=true)
9+
# Plain CPU device, used to move results back off the accelerator when needed.
const cdev = cpu_device()
10+
11+
# Compute the entrywise gradient `∂potential(x)[j, i] / ∂x[j, i]` under Reactant
# tracing, i.e. the diagonal of the Jacobian of `potential` evaluated at `x`.
#
# For every entry `(j, i)` we push a one-hot tangent through `potential` with
# forward-mode Enzyme and read off the matching entry of the JVP.
#
# NOTE(review): this performs one forward-mode pass per entry of `x` (N*B JVPs),
# which is fine for a small docs example but not how you'd do it at scale.
function ∇potential(potential, x::AbstractMatrix)
    N, B = size(x)
    # One-hot tangents for every entry of `x`, shaped (N, B, N, B) so that
    # dxs[:, :, j, i] is the unit perturbation of x[j, i].
    dxs = Reactant.materialize_traced_array(reshape(stack(onehot(x)), N, B, N, B))
    ∇p = similar(x)
    @trace for i in 1:B
        @trace for j in 1:N
            dxᵢ = dxs[:, :, j, i]
            res = only(Enzyme.autodiff(Forward, potential, Duplicated(x, dxᵢ)))
            # The JVP against the (j, i) one-hot is exactly the (j, i) partial.
            @allowscalar ∇p[j, i] = res[j, i]
        end
    end
    return ∇p
end
26+
27+
# Build a small dense "potential" network and move parameters/state to the
# Reactant device.
model = Dense(5 => 5, gelu)
ps, st = xdev(Lux.setup(Random.default_rng(), model))
pnet = StatefulLuxLayer(model, ps, st)

# A small device-resident test input.
x_ra = xdev(randn(Float32, 5, 3))

# Inspect the HLO generated for the plain forward pass and for the nested-AD
# gradient program.
@code_hlo pnet(x_ra)
@code_hlo ∇potential(pnet, x_ra)
35+
36+
# Diagonal of the second derivative of `potential`: differentiates `∇potential`
# once more with forward-mode Enzyme, seeding one one-hot tangent per entry of
# `x` (so `length(x)` JVPs in total).
function ∇²potential(potential, x)
    tangents = stack(onehot(x))
    hdiag = similar(x)
    # Select all leading dims of `tangents`; the trailing index picks the seed.
    all_dims = [Colon() for _ in 1:ndims(x)]
    @trace for k in 1:length(x)
        seed = tangents[all_dims..., k]
        jvp = only(Enzyme.autodiff(
            Forward, ∇potential, Const(potential), Duplicated(x, seed)
        ))
        # Matching entry of the JVP is the k-th diagonal second derivative.
        @allowscalar hdiag[k] = jvp[k]
    end
    return hdiag
end
49+
50+
# Show the HLO for the full second-order (gradient-of-gradient) program.
@code_hlo ∇²potential(pnet, x_ra)
51+
52+
"""
    PotentialNet(potential)

Lux wrapper layer whose forward pass returns the elementwise second derivative
of the wrapped `potential` network (computed via `∇²potential`).
"""
struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential}
    potential::P
end
55+
56+
# Forward pass: wrap the inner network in a stateful layer so nested AD can call
# it as a plain function, then return the second-derivative diagonal together
# with the (possibly updated) layer state.
function (potential::PotentialNet)(x, ps, st)
    net = StatefulLuxLayer{true}(potential.potential, ps, st)
    return ∇²potential(net, x), net.st
end
60+
61+
# Instantiate the wrapper layer and set it up on the Reactant device.
model = PotentialNet(Dense(5 => 5, gelu))
ps, st = xdev(Lux.setup(Random.default_rng(), model))

x_ra = xdev(randn(Float32, 5, 3))

# Inspect the compiled program, then actually execute it.
@code_hlo model(x_ra, ps, st)

@jit model(x_ra, ps, st)
69+
```
70+
71+
```@example nested_ad_reactant
72+
# Scalar loss: sum of squares of the model output (first element of the
# `(output, state)` tuple returned by a Lux layer call).
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
73+
74+
# Reverse-mode gradient of the scalar loss with respect to the parameters `ps`
# only; the model, input, and state are all held constant. Together with the
# forward-over-forward `∇²potential` inside the model, this is 3rd-order AD.
function enzyme_gradient(model, x, ps, st)
    grads = Enzyme.gradient(
        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
    )
    return grads
end
79+
80+
# Compile and run the full 3rd-order program end to end.
@jit enzyme_gradient(model, x_ra, ps, st)
81+
```

0 commit comments

Comments
 (0)