4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -356,6 +356,10 @@ export default defineConfig({
             {
                 text: "Profiling Lux Training Loops",
                 link: "/manual/profiling_training_loop",
             },
+            {
+                text: "Nested AutoDiff",
+                link: "/manual/nested_autodiff_reactant",
+            }
         ],
     },
2 changes: 1 addition & 1 deletion docs/src/manual/nested_autodiff.md
@@ -3,7 +3,7 @@
 !!! note "Reactant"
 
     Reactant.jl natively supports nested AD (with orders greater than 2nd order). For more
-    robust nested AD, use Lux with Reactant.jl.
+    robust nested AD, use [Lux with Reactant.jl](@ref nested_autodiff_reactant).
 
 In this manual, we will explore how to use automatic differentiation (AD) inside your layers
 or loss functions and have Lux automatically switch the AD backend with a faster one when
79 changes: 79 additions & 0 deletions docs/src/manual/nested_autodiff_reactant.md
@@ -0,0 +1,79 @@
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614): computing, for a batched network `potential`, the elementwise first and second derivatives of its outputs with respect to the matching inputs, by nesting Enzyme forward-mode AD inside a Reactant-compiled function.

```@example nested_ad_reactant
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# Builds all one-hot perturbation directions for an N×B matrix at once, reshaped to
# N×B×N×B so that `dxs[:, :, j, i]` selects the direction for entry `(j, i)`.
function stacked_onehot(x::AbstractMatrix{T}) where {T}
    onehot_matrix = Reactant.promote_to(
        Reactant.TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
    )
    return Reactant.materialize_traced_array(reshape(onehot_matrix, size(x)..., size(x)...))
end

# Computes ∂potential(x)[j, i] / ∂x[j, i] for every entry via forward-mode JVPs seeded with
# one-hot tangents. `@trace` lowers the loops into Reactant's IR instead of unrolling them.
function ∇potential(potential, x::AbstractMatrix)
    N, B = size(x)
    dxs = stacked_onehot(x)
    ∇p = similar(x)
    @trace for i in 1:B
        @trace for j in 1:N
            dxᵢ = dxs[:, :, j, i]
            res = only(Enzyme.autodiff(Forward, potential, Duplicated(x, dxᵢ)))
            @allowscalar ∇p[j, i] = res[j, i]
        end
    end
    return ∇p
end

# Computes ∂²potential(x)[j, i] / ∂x[j, i]² by running forward-mode AD over `∇potential`.
function ∇²potential(potential, x::AbstractMatrix)
    N, B = size(x)
    dxs = stacked_onehot(x)
    ∇²p = similar(x)
    @trace for i in 1:B
        @trace for j in 1:N
            dxᵢ = dxs[:, :, j, i]
            res = only(Enzyme.autodiff(
                Forward, ∇potential, Const(potential), Duplicated(x, dxᵢ)
            ))
            @allowscalar ∇²p[j, i] = res[j, i]
        end
    end
    return ∇²p
end
```
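
As a quick sanity check (an illustrative sketch, not part of the original example; `cube`, `x_test`, and the comparisons are assumptions made here), we can apply these helpers to a simple elementwise function whose derivatives are known in closed form: for `x -> x .^ 3` we expect `∇potential` to return `3 .* x .^ 2` and `∇²potential` to return `6 .* x`.

```julia
# Hypothetical sanity check: an elementwise potential with known derivatives
cube(x) = x .^ 3

x_test = randn(Float32, 3, 4) |> xdev

∇p_test = @jit ∇potential(cube, x_test)
∇²p_test = @jit ∇²potential(cube, x_test)

# Both differences should be at the level of Float32 round-off
maximum(abs, cdev(∇p_test) .- 3 .* cdev(x_test) .^ 2)
maximum(abs, cdev(∇²p_test) .- 6 .* cdev(x_test))
```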

```@example nested_ad_reactant
# A wrapper layer whose forward pass returns the elementwise second derivatives of the
# wrapped network's outputs with respect to its inputs.
struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential}
    potential::P
end

function (potential::PotentialNet)(x, ps, st)
    # `StatefulLuxLayer` closes over `ps` and `st`, so `pnet(x)` is a plain function of `x`,
    # which is the calling convention `∇potential`/`∇²potential` expect.
    pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
    return ∇²potential(pnet, x), pnet.st
end
```

```@example nested_ad_reactant
model = PotentialNet(Dense(5 => 5, gelu))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev  # parameters/states on the Reactant device

x_ra = randn(Float32, 5, 1024) |> xdev  # batch of 1024 five-dimensional inputs

# Compile and run the forward pass: returns the second derivatives and the updated state
@jit model(x_ra, ps, st)
```
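
The compiled result can also be cross-checked on the CPU. The sketch below is illustrative only: `fd_second_derivative` is a local helper defined here (not a Lux or Reactant API) that estimates the same elementwise second derivatives with central finite differences, so the agreement is limited by Float32 finite-difference accuracy.

```julia
# Move the parameters/states and the input back to the CPU for a reference computation
ps_cpu, st_cpu = cdev(ps), cdev(st)
pnet_cpu = StatefulLuxLayer{true}(model.potential, ps_cpu, st_cpu)
x_cpu = cdev(x_ra)

# Central finite-difference estimate of ∂²f(x)[j, i] / ∂x[j, i]²
function fd_second_derivative(f, x; h=5f-2)
    fx = f(x)
    out = similar(x)
    for i in axes(x, 2), j in axes(x, 1)
        xp = copy(x); xp[j, i] += h
        xm = copy(x); xm[j, i] -= h
        out[j, i] = (f(xp)[j, i] - 2 * fx[j, i] + f(xm)[j, i]) / h^2
    end
    return out
end

∇²_fd = fd_second_derivative(pnet_cpu, x_cpu)
∇²_ra = first(@jit model(x_ra, ps, st))

# Should be small, limited only by the finite-difference error
maximum(abs, ∇²_fd .- cdev(∇²_ra))
```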

```@example nested_ad_reactant
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

# Reverse-mode gradient with respect to the parameters only; all other arguments are
# marked `Const` so Enzyme does not differentiate with respect to them.
function enzyme_gradient(model, x, ps, st)
    return Enzyme.gradient(
        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
    )
end

@jit enzyme_gradient(model, x_ra, ps, st)
```
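
In a training loop, the same nested-AD gradient can in principle be driven through Lux's training API. The sketch below is an illustration rather than part of the original example: it assumes Optimisers.jl is available, and `loss_fn` together with the `Adam` learning rate are arbitrary choices.

```julia
using Optimisers

# Objective with the `(loss, state, stats)` return convention used by `Training.single_train_step!`
function loss_fn(model, ps, st, x)
    ∇²p, st = model(x, ps, st)
    return sum(abs2, ∇²p), st, (;)
end

train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

# One compiled training step: Enzyme reverse mode over the nested forward-mode derivatives
_, loss, _, train_state = Training.single_train_step!(AutoEnzyme(), loss_fn, x_ra, train_state)
```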