Skip to content

Commit cedb753

Browse files
test/desauty: user-side alias break in Enzyme init closure
The desauty SCC init test's Enzyme block was @test_broken because `Enzyme.gradient(set_runtime_activity(Reverse), Const(closure), itunables)` silently returned zero. Root cause: * The closure captures `iprob` (an `SCCNonlinearProblem`). * `irepack(t)` builds a new `MTKParameters` via `@set! p.tunable = t`, which **shares the `caches` tuple** with the captured `iprob.p.caches` (correct semantics for SciMLStructures' lightweight repack). * The inner `solve!`'s SCC machinery mutates that aliased cache buffer. * `Const(closure_capturing_iprob)` tells Enzyme not to allocate a shadow for the captured state, so the derivative info carried by those cache writes has no shadow buffer to land in and is silently dropped. See EnzymeAD/Enzyme.jl#3124 for the minimal reproducer. User-side fix (no source changes outside the test): explicitly copy the caches tuple via `ConstructionBase.setproperties` before `remake`, so the new `MTKParameters` carries fresh cache buffers and Enzyme's `set_runtime_activity` reverse pass correctly produces non-zero gradients. Verified locally: matches FiniteDiff to 8 significant figures. `@test_broken` → `@test`; the `use_scc = false` and `use_scc = true` branches now both share the same test block. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 7896045 commit cedb753

1 file changed

Lines changed: 29 additions & 25 deletions

File tree

test/desauty_dae_mwe.jl

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
44
import SciMLStructures as SS
55
import SciMLSensitivity
66
import SciMLBase
7+
import ConstructionBase
78
using SymbolicIndexingInterface
89
using FiniteDiff
910
using ForwardDiff
@@ -134,37 +135,40 @@ eqs = [
134135
# per EnzymeAD/Enzyme.jl#3117 — annotating with `Const` is the
135136
# fix.
136137
#
137-
# With these annotations, the plain `NonlinearProblem` case
138-
# (use_scc = false) now passes. The `SCCNonlinearProblem` case
139-
# (use_scc = true) still trips a `MixedDuplicated` /
140-
# `Core.SimpleVector` MethodError further down in Enzyme's
141-
# runtime-activity wrapping for the MTK-System /
142-
# NonlinearSolution types involved in SCC sub-problem
143-
# assembly — tracked in SciMLSensitivity.jl#1359. When that
144-
# lifts, flipping `@test_broken` → `@test` in the `use_scc`
145-
# branch is the only change needed here.
138+
# For `use_scc = true` the closure must additionally **break
139+
# the alias** between the captured `iprob.p.caches` and the
140+
# `MTKParameters` constructed by `irepack`: in MTK SCC init,
141+
# `SciMLStructures.replace(Tunable, p, t)` builds the new
142+
# `MTKParameters` via `@set! p.tunable = t`, which shares
143+
# `caches` with the captured template. When the inner
144+
# `solve!`'s cache writes flow into that shared buffer,
145+
# `Enzyme.gradient(set_runtime_activity(Reverse), Const(loss))`
146+
# has no shadow for the captured mutable state and silently
147+
# drops the derivative info (see EnzymeAD/Enzyme.jl#3124).
148+
# Copying `caches` via `ConstructionBase.setproperties` before
149+
# `remake` decouples the buffers and lets Enzyme produce the
150+
# correct gradient.
146151
enzyme_init_loss = let iprob = iprob, irepack = irepack
147152
p -> begin
148-
iprob2 = remake(iprob, p = irepack(p))
153+
p_initial = irepack(p)
154+
fresh_caches = ntuple(
155+
i -> copy(p_initial.caches[i]),
156+
length(p_initial.caches),
157+
)
158+
p_fresh = ConstructionBase.setproperties(
159+
p_initial; caches = fresh_caches,
160+
)
161+
iprob2 = remake(iprob, p = p_fresh)
149162
sol = solve(iprob2, NewtonRaphson())
150163
sum(sol.u)
151164
end
152165
end
153-
if use_scc
154-
@test_broken begin
155-
igs = Enzyme.gradient(
156-
Enzyme.set_runtime_activity(Enzyme.Reverse),
157-
Enzyme.Const(enzyme_init_loss), itunables,
158-
)
159-
!iszero(sum(igs))
160-
end
161-
else
162-
igs = Enzyme.gradient(
163-
Enzyme.set_runtime_activity(Enzyme.Reverse),
164-
Enzyme.Const(enzyme_init_loss), itunables,
165-
)
166-
@test !iszero(sum(igs))
167-
end
166+
igs = Enzyme.gradient(
167+
Enzyme.set_runtime_activity(Enzyme.Reverse),
168+
Enzyme.Const(enzyme_init_loss), itunables,
169+
)
170+
@test !iszero(sum(igs))
171+
@test isapprox(igs[1], fd_init_grad, rtol = 0.05)
168172
end
169173

170174
@testset "Mooncake through init" begin

0 commit comments

Comments
 (0)