
Commit 63b5d52

docs: add 3rd order AD example using Reactant
[skip ci]
1 parent b49a1a3 commit 63b5d52

4 files changed: +85 -7 lines changed


docs/make.jl

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ pages = [
         "manual/distributed_utils.md",
         "manual/nested_autodiff.md",
         "manual/compiling_lux_models.md",
+        "manual/exporting_to_jax.md",
+        "manual/nested_autodiff_reactant.md"
     ],
     "API Reference" => [
         "Lux" => [

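For context, the `pages` vector this hunk edits is what `docs/make.jl` eventually passes to Documenter. A minimal sketch of that relationship, with the `sitename` and the flat page list chosen purely for illustration rather than copied from the real script:

```julia
using Documenter

# Illustrative only: the actual docs/make.jl passes more keyword arguments
# and a larger, nested `pages` structure.
makedocs(;
    sitename="Lux.jl",
    pages=[
        "manual/nested_autodiff.md",
        "manual/compiling_lux_models.md",
        "manual/exporting_to_jax.md",
        "manual/nested_autodiff_reactant.md",  # page registered by this commit
    ],
)
```
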
docs/src/.vitepress/config.mts

Lines changed: 4 additions & 0 deletions
@@ -346,6 +346,10 @@ export default defineConfig({
           text: "Exporting Lux Models to Jax",
           link: "/manual/exporting_to_jax",
         },
+        {
+          text: "Nested AutoDiff",
+          link: "/manual/nested_autodiff_reactant",
+        }
       ],
     },
     {

docs/src/manual/nested_autodiff.md

Lines changed: 7 additions & 7 deletions
@@ -1,16 +1,16 @@
 # [Nested Automatic Differentiation](@id nested_autodiff)
 
-!!! note
-
-    This is a relatively new feature in Lux, so there might be some rough edges. If you
-    encounter any issues, please let us know by opening an issue on the
-    [GitHub repository](https://github.com/LuxDL/Lux.jl).
-
 In this manual, we will explore how to use automatic differentiation (AD) inside your layers
 or loss functions and have Lux automatically switch the AD backend with a faster one when
 needed.
 
-!!! tip
+!!! tip "Reactant Support"
+
+    Reactant + Lux natively supports Nested AD (even higher order). If you are using
+    Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant)
+    manual.
+
+!!! tip "Disabling Nested AD Switching"
 
     Don't want Lux to do this switching for you? You can disable it by setting the
     `automatic_nested_ad_switching` Preference to `false`.
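
Not part of the diff, but worth noting for readers of the tip above: the `automatic_nested_ad_switching` preference is a package preference, so (assuming it is managed through Preferences.jl, the usual mechanism for such settings) it can be flipped like this, followed by a Julia session restart:

```julia
using Lux, Preferences

# Writes the value into LocalPreferences.toml of the active project.
# Restart Julia afterwards so Lux picks up the new setting.
set_preferences!(Lux, "automatic_nested_ad_switching" => false)
```
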
docs/src/manual/nested_autodiff_reactant.md

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)
+
+We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).
+
+```@example nested_ad_reactant
+using Reactant, Enzyme, Lux, Random, LinearAlgebra
+
+const xdev = reactant_device(; force=true)
+const cdev = cpu_device()
+
+# XXX: We need to be able to compile this with a for-loop else tracing time will scale
+# proportionally to the number of elements in the input.
+function ∇potential(potential, x)
+    dxs = stack(onehot(x))
+    ∇p = similar(x)
+    colons = [Colon() for _ in 1:ndims(x)]
+    @trace for i in 1:length(x)
+        dxᵢ = dxs[colons..., i]
+        res = only(Enzyme.autodiff(
+            Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
+        ))
+        @allowscalar ∇p[i] = res[i]
+    end
+    return ∇p
+end
+
+# function ∇²potential(potential, x)
+#     dxs = onehot(x)
+#     ∇²p = similar(x)
+#     for i in eachindex(dxs)
+#         dxᵢ = dxs[i]
+#         res = only(Enzyme.autodiff(
+#             Enzyme.set_abi(Forward, Reactant.ReactantABI),
+#             ∇potential, Const(potential), Duplicated(x, dxᵢ)
+#         ))
+#         @allowscalar ∇²p[i] = res[i]
+#     end
+#     return ∇²p
+# end
+
+struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential}
+    potential::P
+end
+
+function (potential::PotentialNet)(x, ps, st)
+    pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
+    return ∇potential(pnet, x), pnet.st
+    # return ∇²potential(pnet, x), pnet.st
+end
+
+model = PotentialNet(Dense(5 => 5, gelu))
+ps, st = Lux.setup(Random.default_rng(), model) |> xdev
+
+x_ra = randn(Float32, 5, 3) |> xdev
+
+@code_hlo model(x_ra, ps, st)
+
+1 + 1
+
+model_compiled = @compile model(x_ra, ps, st)
+model_compiled(x_ra, ps, st)
+
+sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
+
+function enzyme_gradient(model, x, ps, st)
+    return Enzyme.gradient(
+        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
+    )
+end
+
+@jit enzyme_gradient(model, x_ra, ps, st)
+```
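
A hedged follow-up to the new example (not part of the commit): `cdev = cpu_device()` is defined but never used, so a natural next step is to move the compiled output and the Enzyme gradient back to the host for inspection. The indexing below assumes `Enzyme.gradient` returns one entry per argument, with `nothing` for the `Const` ones:

```julia
# Sketch only: transfer device results back to the CPU.
∇p, _ = model_compiled(x_ra, ps, st)
∇p_host = ∇p |> cdev                  # gradient of the potential w.r.t. its input

grads = @jit enzyme_gradient(model, x_ra, ps, st)
∂ps = grads[3] |> cdev                # entry 3 is assumed to correspond to `ps`
```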
