|
| 1 | +# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant) |
| 2 | + |
| 3 | +In this tutorial we work through the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614): differentiating a neural network's output inside its own forward pass, compiled with Reactant.
| 4 | + |
| 5 | +```@example nested_ad_reactant |
| 6 | +using Reactant, Enzyme, Lux, Random, LinearAlgebra |
| 7 | +
|
| 8 | +const xdev = reactant_device(; force=true) |
| 9 | +const cdev = cpu_device() |
| 10 | +
|
| 11 | +# XXX: We need to be able to compile this with a for-loop else tracing time will scale |
| 12 | +# proportionally to the number of elements in the input. |
| 13 | +function ∇potential(potential, x)  # forward-mode derivative of `potential` along each coordinate of `x`
| 14 | +    dxs = stack(onehot(x))  # one-hot tangent directions, one per element of x, stacked along a trailing dim
| 15 | +    ∇p = similar(x)  # preallocated output, same shape/eltype as x
| 16 | +    colons = [Colon() for _ in 1:ndims(x)]  # index prefix so we can slice the stacked one-hots along the last dim
| 17 | +    @trace for i in 1:length(x)  # @trace keeps the loop in the compiled program, so trace time doesn't scale with length(x)
| 18 | +        dxᵢ = dxs[colons..., i]  # i-th one-hot tangent direction
| 19 | +        res = only(Enzyme.autodiff(
| 20 | +            Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
| 21 | +        ))  # JVP of `potential` at x in direction dxᵢ (ReactantABI so Enzyme works on traced arrays)
| 22 | +        @allowscalar ∇p[i] = res[i]  # element i of the i-th JVP, i.e. ∂f_i/∂x_i; scalar indexing needs @allowscalar
| 23 | +    end
| 24 | +    return ∇p
| 25 | +end
| 26 | +
|
| 27 | +# function ∇²potential(potential, x) |
| 28 | +# dxs = onehot(x) |
| 29 | +# ∇²p = similar(x) |
| 30 | +# for i in eachindex(dxs) |
| 31 | +# dxᵢ = dxs[i] |
| 32 | +# res = only(Enzyme.autodiff( |
| 33 | +# Enzyme.set_abi(Forward, Reactant.ReactantABI), |
| 34 | +# ∇potential, Const(potential), Duplicated(x, dxᵢ) |
| 35 | +# )) |
| 36 | +# @allowscalar ∇²p[i] = res[i] |
| 37 | +# end |
| 38 | +# return ∇²p |
| 39 | +# end |
| 40 | +
|
| 41 | +struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential}  # wrapper layer: Lux routes ps/st to the `potential` field
| 42 | +    potential::P  # inner network whose derivative the forward pass returns
| 43 | +end
| 44 | +
|
| 45 | +function (potential::PotentialNet)(x, ps, st)  # forward pass returns the derivative of the inner net, not its value
| 46 | +    pnet = StatefulLuxLayer{true}(potential.potential, ps, st)  # closes over ps/st so pnet behaves as a plain x -> y function
| 47 | +    return ∇potential(pnet, x), pnet.st
| 48 | +    # return ∇²potential(pnet, x), pnet.st  # second-order variant, currently disabled (see commented ∇²potential above)
| 49 | +end
| 50 | +
|
| 51 | +model = PotentialNet(Dense(5 => 5, gelu))  # 5 -> 5 dense network as the inner potential
| 52 | +ps, st = Lux.setup(Random.default_rng(), model) |> xdev  # initialize and move params/state to the Reactant device
| 53 | +
|
| 54 | +x_ra = randn(Float32, 5, 3) |> xdev  # sample input (5 features, batch of 3) on the Reactant device
| 55 | +
|
| 56 | +@code_hlo model(x_ra, ps, st)  # inspect the StableHLO that Reactant generates for the forward pass
| 57 | +
|
| 58 | +1 + 1  # NOTE(review): looks like leftover scratch — confirm whether the doc build relies on it to suppress the @code_hlo output
| 59 | +
|
| 60 | +model_compiled = @compile model(x_ra, ps, st)  # compile the forward pass for this input signature
| 61 | +model_compiled(x_ra, ps, st)  # run the compiled forward pass
| 62 | +
|
| 63 | +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))  # scalar loss: sum of squares of the model output (first element of the (out, st) tuple)
| 64 | +
|
| 65 | +function enzyme_gradient(model, x, ps, st)  # reverse-mode gradient of the loss w.r.t. `ps` only — all other args are marked Const
| 66 | +    return Enzyme.gradient(
| 67 | +        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
| 68 | +    )
| 69 | +end
| 70 | +
|
| 71 | +@jit enzyme_gradient(model, x_ra, ps, st)  # compile and run: reverse-mode over the forward-mode ∇potential (nested AD)
| 72 | +``` |
0 commit comments