Summary
When a ContinuousCallback (or VectorContinuousCallback) event time falls within ULP precision of a saveat point in the forward solution, the BacksolveAdjoint and GaussAdjoint reverse passes produce gradients that disagree with ForwardDiff by roughly 4–5% on the affected state components (see the numbers in the reproduction below). The root cause is a topology mismatch: the Float64 forward solve and ForwardDiff's Dual-typed forward solve land the event time on different sides of a coincident saveat point, which exposes a corner case in the implicit ∇τ correction that the existing single-coincident-loss path does not cover.
This is a latent bug that has been masked by an earlier MethodError in VectorContinuousCallback tracking (PR #1446 fixes that and exposes this).
Minimal reproduction (pure ContinuousCallback, no VCC)
using OrdinaryDiffEq, Zygote, SciMLSensitivity, ForwardDiff
abstol = 1.0e-12; reltol = 1.0e-12
g(u) = sum((1.0 .- u) .^ 2) ./ 2
function f(du, u, p, t)
    du[1] = u[2]; du[2] = -p[1]; du[3] = u[4]; du[4] = 0.0
end
cond1(u, t, integrator) = u[1]
cond2(u, t, integrator) = u[3] - 10.0 # linear; rootfinder lands at 4.999...9
affect1!(integrator) = (integrator.u[2] = -integrator.u[2])
affect2!(integrator) = (integrator.u[4] = -integrator.u[4])
cb = CallbackSet(ContinuousCallback(cond1, affect1!), ContinuousCallback(cond2, affect2!))
u0 = [50.0, 0.0, 0.0, 2.0]; p = [9.8, 0.9]
prob = ODEProblem(f, u0, (0.0, 10.0), p)
du0_bs, _ = Zygote.gradient(
    (u0, p) -> g(solve(remake(prob; u0=u0, p=p), Tsit5();
        callback=cb, abstol, reltol, saveat=0.5,
        sensealg=BacksolveAdjoint())),
    u0, p)
gfd = ForwardDiff.gradient(
    theta -> g(solve(remake(prob; u0=theta[1:4], p=theta[5:6]), Tsit5();
        callback=cb, abstol, reltol, saveat=0.5)),
    [u0; p])
# Forward sol event times near 5.0: [4.999999999999999, 4.999999999999999, 5.0]
# BS adjoint du0[3:4]: [-412.577924432116, -2135.8253905543984]
# ForwardDiff du0[3:4]: [-394.57793640128364, -2047.8254504002346]
# Diff: [-17.999988030832355, -87.99994015416382]
In this reproduction the BacksolveAdjoint gradient is off by ≈ -18 in du0[3] and by ≈ -88 in du0[4] relative to ForwardDiff. Both BacksolveAdjoint and GaussAdjoint produce the same wrong number, so the bug is in the shared callback adjoint machinery, not the per-backend VJP.
The same bug appears with VectorContinuousCallback using the equivalent condition (this is what surfaced it):
function condition_vcc(out, u, t, integrator)
    out[1] = u[1]
    out[2] = (u[3] - 10.0) * u[3] # rootfinder happens to land at 4.999...9
end
function affect_vcc!(integrator, ev)
    indices = ev isa AbstractVector ? (i for i in eachindex(ev) if !iszero(ev[i])) : (ev,)
    for idx in indices
        if idx == 1
            integrator.u[2] = -integrator.u[2]
        elseif idx == 2
            integrator.u[4] = -integrator.u[4]
        end
    end
end
cb_vcc = VectorContinuousCallback(condition_vcc, affect_vcc!, 2)
# Same du0[3]/du0[4] mismatch as the CC reproduction above
This is the 4th testset in test/callbacks/vector_continuous_callbacks.jl ("callback with linear affect" at line 79). The other three VCC testsets pass.
Root cause
Triggered when:
- An event time τ lands on a float adjacent to a saveat time t_save (i.e. |τ - t_save| ≤ ~1 ulp), AND
- The Float64 rootfinder lands τ on a different side of t_save than ForwardDiff's Dual rootfinder lands τ_dual.
In the reproduction above, with u0 = [50.0, 0.0, 0.0, 2.0] and cond = u[3] - 10:
- Float64 rootfinder converges to τ = 4.999999999999999 (= 5.0 - 1ulp, below the saveat at 5.0).
- ForwardDiff's Dual rootfinder converges to τ = 5.000000000000006 (above the saveat at 5.0).
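The two landing points can be compared against the saveat in units of float spacing. The following is a minimal sketch using only Base Julia; `ulp_offset` is a hypothetical helper introduced for illustration, and the two τ literals are simply the values quoted in the bullets above.

```julia
# Signed distance from t_save to τ, measured in units of eps(t_save).
# Hypothetical helper, not part of SciMLSensitivity or DiffEqBase.
ulp_offset(τ, t_save) = (τ - t_save) / eps(t_save)

t_save = 5.0
τ_f64  = 4.999999999999999   # Float64 rootfinder result quoted above
τ_dual = 5.000000000000006   # Dual rootfinder result quoted above

@show τ_f64 == prevfloat(t_save)        # is τ_f64 the float just below 5.0?
@show ulp_offset(τ_f64, t_save)         # negative: event before the saveat
@show ulp_offset(τ_dual, t_save)        # positive: event after the saveat
@show sign(τ_f64 - t_save) != sign(τ_dual - t_save)
```

A check like this, applied to the event times stored in the forward solution, would flag the trigger condition without running the adjoint.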
The forward solution stores (event_pre @ 4.999…, event_post @ 4.999…, saveat @ 5.0) in chronological order. The saveat sample is taken ~1ulp after the event, so it captures the post-event state and is essentially identical to the event-post save. The user's g(sol) therefore counts the post-event state twice: once via the event-post save and once via the saveat sample at 5.0. ForwardDiff, going through the Dual topology with τ_dual > 5.0, arrives at the same scalar g value via a different chronological ordering: the saveat at 5.0 captures the pre-event state, followed by event-pre/post at 5.000…. The two topologies produce the same continuous-extension value of g, but the gradient at this discontinuity differs by ≈ 88 in du0[4] because the topologies route the ∂g/∂τ contributions through the affect's VJP differently.
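The two orderings described above can be written out with a toy model. This sketch does not call the solver; `saved_sequence` and its labels are hypothetical and only encode the rule that a saveat sample taken after the event sees the post-event state.

```julia
# Toy model of the saved-time ordering around one event (illustrative only).
function saved_sequence(τ::Float64, t_save::Float64)
    # A saveat sample after the event sees the post-event state, otherwise the pre-event state.
    saveat_label = t_save > τ ? :saveat_post_event : :saveat_pre_event
    entries = [(τ, :event_pre), (τ, :event_post), (t_save, saveat_label)]
    # MergeSort is stable, so event_pre stays ahead of event_post at equal times.
    return last.(sort(entries; by = first, alg = MergeSort))
end

float64_topology = saved_sequence(4.999999999999999, 5.0)
dual_topology    = saved_sequence(5.000000000000006, 5.0)

@show float64_topology
@show dual_topology
```

In the Float64 ordering two post-event samples sit after the event; in the Dual ordering the saveat sample sits before it.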
In the BS reverse pass, the implicit ∇τ correction at the event affect (src/callback_tracking.jl::affect! ≈ lines 358–374, 389–397, 434–449) computes Lu_right = dot(λ − dgdu_at(cur_time+1), dy_right) * gu_val. When exactly one loss has been added on the post-event side (the standard CC case where τ > saveat), this single subtraction correctly removes the just-added saveat loss from λ, so the implicit ∇τ projection sees only the "evolved" λ. But when the event lands BELOW the coincident saveat (τ < t_save within ULP), the post-event side of the reverse pass accumulates two loss contributions before the affect runs (the saveat at t_save AND the event-post save), and Lu_right subtracts only one. Worse, even summing both into Lu_right_init (the most obvious patch) doesn't recover the right answer, because the implicit-correction formula Lu_left − Lu_right is non-additive in the per-loss coincident terms: it is a global ∇τ projection of (λ − one_loss), not a sum of per-loss ∇τ projections.
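The "subtracts only one" part can be seen with toy numbers plugged into the Lu_right expression quoted above. Everything in this sketch is a made-up stand-in (the vectors, `gu_val`, and the helper `lu_right`); none of it comes from the reproduction or from the package source, and it illustrates only the leftover term, not the full implicit correction.

```julia
using LinearAlgebra: dot

# Hypothetical toy values, chosen only so the arithmetic is easy to follow.
λ_evolved   = [0.5, -1.0]   # adjoint carried back from later times
dgdu_saveat = [2.0, 0.5]    # loss gradient injected at the coincident saveat
dgdu_post   = [2.0, 0.5]    # loss gradient injected at the event-post save
dy_right    = [1.0, -1.0]   # stand-in for dy_right
gu_val      = 0.25          # stand-in for gu_val

# Same shape as the expression in the text: dot(λ − dgdu, dy_right) * gu_val
lu_right(λ, dgdu) = dot(λ - dgdu, dy_right) * gu_val

# Standard topology: one loss sits on λ and one is subtracted.
single = lu_right(λ_evolved + dgdu_post, dgdu_post)

# Corner topology: two losses sit on λ but still only one is subtracted.
double = lu_right(λ_evolved + dgdu_saveat + dgdu_post, dgdu_post)

leftover = double - single
@show single double leftover
@show leftover ≈ dot(dgdu_saveat, dy_right) * gu_val
```

The leftover is the unsubtracted saveat loss projected through dy_right. As noted above, subtracting both losses was tried and does not give the right gradient either, so this shows where the contamination enters rather than how to remove it.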
What I ruled out
- Not a VCC-specific bug. The exact same wrong gradient reproduces with a plain CallbackSet(ContinuousCallback, ContinuousCallback) whose second condition is u[3] - 10 (which happens to make the Float64 rootfinder converge to τ = 4.999…9 < saveat = 5.0). The recent VCC refactor commits (Restore multi-fire implicit ∇τ correction, Document why picking any fired condition for ∇τ is correct, Hoist findfirst out of the VCC condition wrappers, etc.) are NOT the source: the implicit-correction logic they touch produces the same gt_val/gu_val as the CC path in this case (verified via instrumented prints).
- Not a Lu_right_init off-by-one. Setting Lu_right_init = sum_of_coincident_dgdu (i.e. summing the event-post-save dgdu AND the saveat-coincident dgdu before passing into implicit_correction!) makes the result WORSE, not better: BS = -3127 vs FD = -2047, vs pre-patch BS = -2135. The implicit correction is not additive in coincident losses.
- Not a numerical issue in the forward sol. Off the discontinuity (e.g. u0[4] = 1.999 or u0[4] = 2.001) BS_VCC and BS_CC both match ForwardDiff to 5–6 digits. The bug only triggers at u0[4] = 2.0 exactly, where the rootfinder lands ≈ 1ulp away from saveat = 5.0.
- Not the cur_time indexing. The trace shows cur_time decrementing correctly through both loss-callback fires before the event affect; loss_indx = cur_time[]+1 points to the chronologically-most-recent coincident save. The off-by-one is in how many coincident saves contribute to the implicit correction, not in which index.
- Not the cb_dupl duplicate handler. The duplicate handler in generate_callbacks (for times that appear ≥2× in current_time(sol), here the two 4.999… event-pre/post saves) fires correctly. The bug is the additional saveat at 5.0 that appears once in current_time(sol) at a distinct (but ULP-close) float.
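Why only u0[4] = 2.0 hits the corner follows from the ODE in the reproduction: with du[3] = u[4] and du[4] = 0, u[3] grows linearly until the first crossing of u[3] - 10, so the exact event time is (10 - u0[3]) / u0[4]. A minimal sketch of that arithmetic (the helper names are hypothetical, and the rootfinder's ±ulp landing is not modeled):

```julia
# Exact first root of u[3](t) - 10 when u[3](t) = u0[3] + u0[4] * t.
event_time(u0) = (10.0 - u0[3]) / u0[4]

saveat_grid = 0.0:0.5:10.0   # saveat=0.5 on tspan (0.0, 10.0)
nearest_saveat(τ) = saveat_grid[argmin(abs.(saveat_grid .- τ))]

for v in (1.999, 2.0, 2.001)
    τ = event_time([50.0, 0.0, 0.0, v])
    gap = abs(τ - nearest_saveat(τ))
    println("u0[4] = ", v, "  exact τ = ", τ, "  gap to nearest saveat = ", gap)
end
```

At u0[4] = 2.0 the exact root sits on the saveat grid, so which side the computed τ lands on is decided by rounding in the rootfinder; at 1.999 and 2.001 the gap is about 2.5e-3, far outside any ULP-level ambiguity.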
Suspected file / lines
- src/callback_tracking.jl::_setup_reverse_callbacks::affect! lines 358–374 (Lu_right setup), 389–397 (dgdt + implicit_correction!), 434–449 (Lu_left and final dλ update). The single-coincident-loss case in Lu_right_init = dgdu(cur_time[]+1) works for the "standard" CC topology where τ is on one side of the nearest saveat; it does not generalize to the corner where τ is ULP-close to a coincident saveat on the opposite side.
- The cb_dupl PresetTimeCallback in src/adjoint_common.jl::generate_callbacks (≈ lines 817–836) handles duplicated event times correctly, but doesn't notice/handle the ULP-coincident saveat.
Possible fixes (none of which I could land cleanly)
- Sum all coincident-with-event dgdu's into Lu_right_init. I tried this (extending setup_reverse_callbacks to take save_times, iterating from cur_time[]+1 while |t[i] − τ| ≤ 8·eps(τ)·max(|τ|,1)). Off the discontinuity it does nothing (no coincident extras → identical to current code). AT the discontinuity it overshoots: -3127 vs target -2047. The implicit-correction math is not additive in coincident losses.
- Make the Float64 rootfinder land on the same side as ForwardDiff's Dual rootfinder. Probably an OrdinaryDiffEqCore / DiffEqBase fix (not SciMLSensitivity). The find_callback_time find_root machinery (DiffEqBase/src/callbacks.jl) routes through the same code, but Float64 and Dual brackets converge to neighboring floats. CC with cond (u[3]-10)*u[3] happens to land τ > 5.0 (matches Dual); CC with cond u[3]-10 (mathematically identical zero set) lands τ < 5.0 (doesn't match Dual). VCC with (u[3]-10)*u[3] lands τ < 5.0 too. So the convergence side is sensitive to micro-details of the bracket and interpolation polynomial.
- Treat ULP-coincident saveats as if they were the event-post save. I.e., in the forward sol, when a saveat falls within ULP of an event, suppress the saveat sample (or treat its value as a duplicate of the event-post sample for the adjoint accounting). Invasive: it touches forward-solver behavior, and the user's g(sol) semantics change.
- Sidestep the discontinuity in the test. Changing u0 = [50.0, 0.0, 0.0, 2.0] to e.g. u0 = [50.0, 0.0, 0.0, 2.001] in test/callbacks/vector_continuous_callbacks.jl::"callback with linear affect" makes all four VCC testsets pass with current master. This avoids testing at an event-saveat ULP coincidence (a measure-zero corner where the gradient is technically a directional limit of a jump-discontinuous function, so ForwardDiff and BS legitimately disagree on which side to take). Per CLAUDE.md I haven't done this; I'm flagging it as the safest near-term workaround if the test value matters more than the underlying corner.
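For reference, the coincidence test from the first option can be sketched in isolation. This is not the patch that was tried: `coincidence_tol` and `ulp_coincident_indices` are hypothetical names, and only the tolerance expression is taken from the text above.

```julia
# Tolerance as written in the first option: 8·eps(τ)·max(|τ|, 1).
coincidence_tol(τ) = 8 * eps(τ) * max(abs(τ), 1)

# Hypothetical helper: indices of saved times that differ from τ as floats
# but still fall inside the tolerance window around τ.
function ulp_coincident_indices(ts::AbstractVector{<:Real}, τ::Real)
    tol = coincidence_tol(τ)
    return [i for i in eachindex(ts) if ts[i] != τ && abs(ts[i] - τ) <= tol]
end

# Saved times shaped like the forward solution near the event in the reproduction.
ts = [4.5, prevfloat(5.0), prevfloat(5.0), 5.0, 5.5]
τ  = prevfloat(5.0)

@show ulp_coincident_indices(ts, τ)
```

The same scan could serve the third option as well, since it identifies which saveat samples would need to be suppressed or merged with the event-post save.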
Test status (with PR #1446 applied)
Test Summary:                                                  | Pass  Fail
VectorContinuous callbacks                                     |    7     4
  MSE loss function bouncing-ball like                         |    4
  callback with linear affect                                  |          4   <-- this issue
  Test condition function that depends on time only            |    4
  Structural simultaneous fire (algebraically tied conditions) |    1
  Coincidental simultaneous fire (corner trap)                 |    2
The 4 failing assertions are du01 ≈ dstuff[1:4], dp1 ≈ dstuff[5:6], du02 ≈ dstuff[1:4], dp2 ≈ dstuff[5:6] (BS and GA both vs ForwardDiff). Mismatch: du0[3] off by ~17.1, du0[4] off by ~88.0, dp[2] off by ~5.6, all other components match to 1e-5 rtol.
Versions
- 573f8add (with PR #1446, "Use single-affect VCC constructor in _track_callback")
- test_env_112

cc @ChrisRackauckas