Skip to content

Commit 52bd150

Browse files
test/parameter_initialization: use Enzyme.autodiff with Duplicated
Apply the same pattern as test/desauty (SciML#1454) to the `Adjoint through Prob (Enzyme)` testset: express the loss as a plain function `enzyme_loss(t, prob_)`, pass `prob` as an explicit `Duplicated` argument (with `make_zero(prob)` shadow), and reconstruct `repack` *inside* the loss from `prob_` so its captured parameter template shares Enzyme's shadow. This is the idiomatic Enzyme formulation for a loss that captures mutable state; `Const(loss)` on a closure capturing mutable `prob` doesn't allocate shadows for the captures, while `autodiff(... Duplicated(prob, dprob))` does. The `@test_broken` is retained because the rewrite now exposes a deeper, well-localized blocker: `Enzyme.autodiff` enters the `GaussAdjoint` `_concrete_solve_adjoint` rule, which calls `_init_originator_gradient` to differentiate the parameter-init `tunables -> new_u0` mapping. That helper invokes `Enzyme.gradient(Enzyme.Reverse, Const(init_loss), tunables)` without `set_runtime_activity`, so the inner Enzyme call still raises `EnzymeRuntimeActivityError` at the MTK init `remake` — the outer call's runtime-activity setting does not propagate into that nested gradient. The test now clearly indicates what's left: teach `_init_originator_gradient(::EnzymeOriginator, ...)` to use `set_runtime_activity(Reverse)` (or `autodiff` + `Duplicated` itself). Flipping `@test_broken` -> `@test` is the only change needed here once that fix lands. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 5aadc53 commit 52bd150

1 file changed

Lines changed: 41 additions & 22 deletions

File tree

test/parameter_initialization.jl

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,33 +70,52 @@ tunables, repack, _ = SS.canonicalize(SS.Tunable(), parameter_values(prob))
7070
end
7171

7272
# Exercises the EnzymeOriginator method of `_init_originator_gradient`
73-
# added alongside this testset. Annotations follow the documented
74-
# user-side pattern: `Const(loss)` for the closure that captures the
75-
# mutable `ODEProblem`, and `set_runtime_activity(Reverse)` so Enzyme's
76-
# activity analysis tolerates the runtime-activity transitions through
77-
# MTK's `remake` path. With these in place the activity layer is
78-
# handled; the remaining blocker is a `MixedDuplicated` /
79-
# `Core.SimpleVector` MethodError further down in Enzyme's
80-
# runtime-activity wrapping for MTK-System / NonlinearSolution
81-
# types — tracked in SciMLSensitivity.jl#1359. When that lifts,
73+
# via a full ODE solve under `GaussAdjoint(EnzymeVJP())`. We use the
74+
# idiomatic Enzyme pattern (mirroring the SCC-init rewrite in #1454):
75+
# express the loss as a plain function whose captured mutable state
76+
# (the `ODEProblem`) is passed as an explicit `Duplicated` argument,
77+
# and reconstruct `repack` *inside* the loss from the duplicated
78+
# `prob_` so its captured parameter template shares the Enzyme
79+
# shadow. `Const(loss)` annotates the function and
80+
# `set_runtime_activity(Reverse)` tolerates runtime-activity
81+
# transitions through MTK's `remake` path.
82+
#
83+
# Residual blocker: `Enzyme.autodiff` enters the `GaussAdjoint`
84+
# `_concrete_solve_adjoint` rule, which calls
85+
# `_init_originator_gradient` to differentiate the parameter-init
86+
# `tunables -> new_u0` mapping. That helper currently invokes
87+
# `Enzyme.gradient(Enzyme.Reverse, Const(init_loss), tunables)`
88+
# without `set_runtime_activity`, so the inner Enzyme call still
89+
# raises `EnzymeRuntimeActivityError` at the MTK init `remake`. The
90+
# outer-call activity setting doesn't propagate into that nested
91+
# gradient. Fixing this requires teaching
92+
# `_init_originator_gradient(::EnzymeOriginator, ...)` to wrap with
93+
# `set_runtime_activity(Reverse)` (or use `autodiff` + `Duplicated`
94+
# itself); both are outside this test file's scope. When that lifts,
8295
# flipping `@test_broken` → `@test` is the only change needed here.
8396
@testset "Adjoint through Prob (Enzyme)" begin
84-
sensealg = SciMLSensitivity.GaussAdjoint(
85-
autojacvec = SciMLSensitivity.EnzymeVJP(),
86-
)
87-
loss = let prob = prob, repack = repack, sensealg = sensealg
88-
function (tunables)
89-
new_prob = remake(prob; p = repack(tunables))
90-
sol = solve(new_prob; sensealg)
91-
return sum(sol)
92-
end
97+
function enzyme_loss(t, prob_)
98+
_, repack_, _ = SS.canonicalize(
99+
SS.Tunable(), parameter_values(prob_),
100+
)
101+
new_prob = remake(prob_; p = repack_(t))
102+
sensealg = SciMLSensitivity.GaussAdjoint(
103+
autojacvec = SciMLSensitivity.EnzymeVJP(),
104+
)
105+
sol = solve(new_prob; sensealg)
106+
return sum(sol)
93107
end
94108
@test_broken begin
95-
g = Enzyme.gradient(
109+
dprob = Enzyme.make_zero(prob)
110+
dtunables = zero(tunables)
111+
Enzyme.autodiff(
96112
Enzyme.set_runtime_activity(Enzyme.Reverse),
97-
Enzyme.Const(loss), copy(tunables),
98-
)[1]
99-
any(!iszero, g)
113+
Enzyme.Const(enzyme_loss),
114+
Enzyme.Active,
115+
Enzyme.Duplicated(copy(tunables), dtunables),
116+
Enzyme.Duplicated(prob, dprob),
117+
)
118+
any(!iszero, dtunables)
100119
end
101120
end
102121
end

0 commit comments

Comments
 (0)