# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

This tutorial works through the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614).

```@example nested_ad_reactant
| 6 | +using Reactant, Enzyme, Lux, Random, LinearAlgebra |
| 7 | +
|
| 8 | +const xdev = reactant_device(; force=true) |
| 9 | +const cdev = cpu_device() |
| 11 | +function stacked_onehot(x::AbstractMatrix{T}) where {T} |
| 12 | + onehot_matrix = Reactant.promote_to( |
| 13 | + Reactant.TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x)) |
| 14 | + ) |
| 15 | + return Reactant.materialize_traced_array(reshape(onehot_matrix, size(x)..., size(x)...)) |
| 16 | +end |
| 18 | +function ∇potential(potential, x::AbstractMatrix) |
| 19 | + N, B = size(x) |
| 20 | + dxs = stacked_onehot(x) |
| 21 | + ∇p = similar(x) |
| 22 | + @trace for i in 1:B |
| 23 | + @trace for j in 1:N |
| 24 | + dxᵢ = dxs[:, :, j, i] |
| 25 | + res = only(Enzyme.autodiff(Forward, potential, Duplicated(x, dxᵢ))) |
| 26 | + @allowscalar ∇p[j, i] = res[j, i] |
| 27 | + end |
| 28 | + end |
| 29 | + return ∇p |
| 30 | +end |
| 32 | +function ∇²potential(potential, x::AbstractMatrix) |
| 33 | + N, B = size(x) |
| 34 | + dxs = stacked_onehot(x) |
| 35 | + ∇²p = similar(x) |
| 36 | + @trace for i in 1:B |
| 37 | + @trace for j in 1:N |
| 38 | + dxᵢ = dxs[:, :, j, i] |
| 39 | + res = only(Enzyme.autodiff( |
| 40 | + Forward, ∇potential, Const(potential), Duplicated(x, dxᵢ) |
| 41 | + )) |
| 42 | + @allowscalar ∇²p[j, i] = res[j, i] |
| 43 | + end |
| 44 | + end |
| 45 | + return ∇²p |
| 46 | +end |
```

```@example nested_ad_reactant
| 50 | +struct PotentialNet{P} <: AbstractLuxWrapperLayer{:potential} |
| 51 | + potential::P |
| 52 | +end |
| 54 | +function (potential::PotentialNet)(x, ps, st) |
| 55 | + pnet = StatefulLuxLayer{true}(potential.potential, ps, st) |
| 56 | + return ∇²potential(pnet, x), pnet.st |
| 57 | +end |
```

```@example nested_ad_reactant
| 61 | +model = PotentialNet(Dense(5 => 5, gelu)) |
| 62 | +ps, st = Lux.setup(Random.default_rng(), model) |> xdev |
| 63 | +
|
| 64 | +x_ra = randn(Float32, 5, 1024) |> xdev |
| 65 | +
|
| 66 | +@jit model(x_ra, ps, st) |
```

```@example nested_ad_reactant
| 70 | +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) |
| 72 | +function enzyme_gradient(model, x, ps, st) |
| 73 | + return Enzyme.gradient( |
| 74 | + Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st) |
| 75 | + ) |
| 76 | +end |
| 78 | +@jit enzyme_gradient(model, x_ra, ps, st) |
```