|
| 1 | +# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant) |
| 2 | + |
| 3 | +We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614). |
| 4 | + |
| 5 | +```@example nested_ad_reactant |
| 6 | +using Reactant, Enzyme, Lux, Random, LinearAlgebra |
| 7 | +
|
| 8 | +const xdev = reactant_device(; force=true) |
| 9 | +const cdev = cpu_device() |
| 10 | +
|
| 11 | +function ∇potential(potential, x) |
| 12 | + dxs = stack(onehot(x)) |
| 13 | + ∇p = similar(x) |
| 14 | + colons = [Colon() for _ in 1:ndims(x)] |
| 15 | + @trace for i in 1:length(x) |
| 16 | + dxᵢ = dxs[colons..., i] |
| 17 | + res = only(Enzyme.autodiff( |
| 18 | + Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ) |
| 19 | + )) |
| 20 | + @allowscalar ∇p[i] = res[i] |
| 21 | + end |
| 22 | + return ∇p |
| 23 | +end |
| 24 | +
|
| 25 | +function ∇²potential(potential, x) |
| 26 | + dxs = stack(onehot(x)) |
| 27 | + ∇²p = similar(x) |
| 28 | + colons = [Colon() for _ in 1:ndims(x)] |
| 29 | + @trace for i in 1:length(x) |
| 30 | + dxᵢ = dxs[colons..., i] |
| 31 | + res = only(Enzyme.autodiff( |
| 32 | + Enzyme.set_abi(Forward, Reactant.ReactantABI), |
| 33 | + ∇potential, Const(potential), Duplicated(x, dxᵢ) |
| 34 | + )) |
| 35 | + @allowscalar ∇²p[i] = res[i] |
| 36 | + end |
| 37 | + return ∇²p |
| 38 | +end |
| 39 | +
|
| 40 | +struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential} |
| 41 | + potential::P |
| 42 | +end |
| 43 | +
|
| 44 | +function (potential::PotentialNet)(x, ps, st) |
| 45 | + pnet = StatefulLuxLayer{true}(potential.potential, ps, st) |
| 46 | + return ∇²potential(pnet, x), pnet.st |
| 47 | +end |
| 48 | +
|
| 49 | +model = PotentialNet(Dense(5 => 5, gelu)) |
| 50 | +ps, st = Lux.setup(Random.default_rng(), model) |> xdev |
| 51 | +
|
| 52 | +x_ra = randn(Float32, 5, 3) |> xdev |
| 53 | +
|
| 54 | +@code_hlo model(x_ra, ps, st) |
| 55 | +
|
| 56 | +@jit model(x_ra, ps, st) |
| 57 | +``` |
| 58 | + |
| 59 | +```@example nested_ad_reactant |
| 60 | +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) |
| 61 | +
|
| 62 | +function enzyme_gradient(model, x, ps, st) |
| 63 | + return Enzyme.gradient( |
| 64 | + Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st) |
| 65 | + ) |
| 66 | +end |
| 67 | +
|
| 68 | +@jit enzyme_gradient(model, x_ra, ps, st) |
| 69 | +``` |
0 commit comments