# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).

```@example nested_ad_reactant
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
"""
    ∇potential(potential, x::AbstractMatrix)

Compute the elementwise derivative `∂potential(x)[j, i] / ∂x[j, i]` of
`potential` at `x` using forward-mode AD: one Jacobian-vector product per
entry, each seeded with a one-hot tangent.
"""
function ∇potential(potential, x::AbstractMatrix)
    N, B = size(x)
    # One-hot tangents for every entry of `x`, reshaped so that
    # `dxs[:, :, j, i]` is the seed selecting entry (j, i).
    dxs = Reactant.materialize_traced_array(reshape(stack(onehot(x)), N, B, N, B))
    ∇p = similar(x)
    @trace for i in 1:B
        @trace for j in 1:N
            dxᵢ = dxs[:, :, j, i]
            # Forward-mode JVP of `potential` at `x` along the one-hot tangent.
            res = only(Enzyme.autodiff(Forward, potential, Duplicated(x, dxᵢ)))
            @allowscalar ∇p[j, i] = res[j, i]
        end
    end
    return ∇p
end

model = Dense(5 => 5, gelu)
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
pnet = StatefulLuxLayer(model, ps, st)

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo pnet(x_ra)
@code_hlo ∇potential(pnet, x_ra)

"""
    ∇²potential(potential, x)

Compute the diagonal of the second derivative of `potential` at `x` by
forward-differentiating `∇potential` once per entry, each time along a
one-hot tangent.
"""
function ∇²potential(potential, x)
    seeds = stack(onehot(x))
    ∇²p = similar(x)
    slicer = [Colon() for _ in 1:ndims(x)]
    @trace for k in 1:length(x)
        tangent = seeds[slicer..., k]
        # Forward-over-forward: JVP of the gradient along the k-th one-hot seed.
        jvp = only(Enzyme.autodiff(
            Forward, ∇potential, Const(potential), Duplicated(x, tangent)
        ))
        @allowscalar ∇²p[k] = jvp[k]
    end
    return ∇²p
end

@code_hlo ∇²potential(pnet, x_ra)

"""
    PotentialNet(potential)

Lux wrapper layer around a `potential` network; its forward pass produces the
diagonal second derivative of the wrapped network with respect to its input.
"""
struct PotentialNet{L} <: AbstractLuxWrapperLayer{:potential}
    potential::L
end
function (potential::PotentialNet)(x, ps, st)
    # Bundle parameters/state into a stateful closure, evaluate the
    # second-derivative diagonal, and hand back the (possibly updated) state.
    snet = StatefulLuxLayer{true}(potential.potential, ps, st)
    ∇²p = ∇²potential(snet, x)
    return ∇²p, snet.st
end

model = PotentialNet(Dense(5 => 5, gelu))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

x_ra = randn(Float32, 5, 3) |> xdev

@code_hlo model(x_ra, ps, st)

@jit model(x_ra, ps, st)
```

```@example nested_ad_reactant
# Scalar loss: squared ℓ² norm of the model's primary output (the first
# element of whatever the Lux-style call `model(x, ps, st)` returns).
function sumabs2first(model, x, ps, st)
    out = model(x, ps, st)
    return sum(abs2, first(out))
end
function enzyme_gradient(model, x, ps, st)
    # Reverse-mode gradient of the scalar loss with respect to `ps` only;
    # the model, input, and state are all held constant.
    loss = Const(sumabs2first)
    return Enzyme.gradient(
        Enzyme.Reverse, loss, Const(model), Const(x), ps, Const(st)
    )
end

@jit enzyme_gradient(model, x_ra, ps, st)
```