BacksolveAdjoint/GaussAdjoint: incorrect ∇τ correction when event time coincides with a saveat point within float ULP #1447

@ChrisRackauckas-Claude

Summary

When a ContinuousCallback (or VectorContinuousCallback) event time falls within about 1 ULP of a saveat point in the forward solution, the BacksolveAdjoint and GaussAdjoint reverse passes produce gradients that disagree with ForwardDiff by roughly 4 to 5% on the affected state components (≈18 on du0[3] and ≈88 on du0[4] in the reproduction below). The root cause is a topology mismatch: the Float64 forward solver and ForwardDiff's Dual-typed forward solver land the event time on different sides of the coincident saveat point, which exposes a corner case in the implicit ∇τ correction that the existing single-coincident-loss path does not cover.

This is a latent bug that has been masked by an earlier MethodError in VectorContinuousCallback tracking (PR #1446 fixes that and exposes this).

Minimal reproduction (pure ContinuousCallback, no VCC)

using OrdinaryDiffEq, Zygote, SciMLSensitivity, ForwardDiff

abstol = 1.0e-12; reltol = 1.0e-12
g(u) = sum((1.0 .- u) .^ 2) ./ 2
function f(du, u, p, t)
    du[1] = u[2]; du[2] = -p[1]; du[3] = u[4]; du[4] = 0.0
end
cond1(u, t, integrator) = u[1]
cond2(u, t, integrator) = u[3] - 10.0  # linear; rootfinder lands at 4.999...9
affect1!(integrator) = (integrator.u[2] = -integrator.u[2])
affect2!(integrator) = (integrator.u[4] = -integrator.u[4])
cb = CallbackSet(ContinuousCallback(cond1, affect1!), ContinuousCallback(cond2, affect2!))
u0 = [50.0, 0.0, 0.0, 2.0]; p = [9.8, 0.9]
prob = ODEProblem(f, u0, (0.0, 10.0), p)

du0_bs, _ = Zygote.gradient(
    (u0, p) -> g(solve(remake(prob; u0=u0, p=p), Tsit5();
        callback=cb, abstol, reltol, saveat=0.5,
        sensealg=BacksolveAdjoint())),
    u0, p)

gfd = ForwardDiff.gradient(
    theta -> g(solve(remake(prob; u0=theta[1:4], p=theta[5:6]), Tsit5();
        callback=cb, abstol, reltol, saveat=0.5)),
    [u0; p])

# Forward sol saved times near 5.0 (event-pre, event-post, saveat): [4.999999999999999, 4.999999999999999, 5.0]
# BS adjoint du0[3:4]: [-412.577924432116,  -2135.8253905543984]
# ForwardDiff du0[3:4]: [-394.57793640128364, -2047.8254504002346]
# Diff:                 [-17.999988030832355, -87.99994015416382]

The bug is that du0_bs[3] is off by ≈ -18 and du0_bs[4] by ≈ -88 relative to ForwardDiff. Both BacksolveAdjoint and GaussAdjoint produce the same wrong numbers, so the fault is in the shared callback adjoint machinery, not the per-backend VJP.

The same bug appears with VectorContinuousCallback using the equivalent condition (this is what surfaced it):

function condition_vcc(out, u, t, integrator)
    out[1] = u[1]
    out[2] = (u[3] - 10.0) * u[3]   # rootfinder happens to land at 4.999...9
end
function affect_vcc!(integrator, ev)
    indices = ev isa AbstractVector ? (i for i in eachindex(ev) if !iszero(ev[i])) : (ev,)
    for idx in indices
        if idx == 1
            integrator.u[2] = -integrator.u[2]
        elseif idx == 2
            integrator.u[4] = -integrator.u[4]
        end
    end
end
cb_vcc = VectorContinuousCallback(condition_vcc, affect_vcc!, 2)
# Same du0[3]/du0[4] mismatch as the CC reproduction above

This is the 4th testset in test/callbacks/vector_continuous_callbacks.jl ("callback with linear affect" at line 79). The other three VCC testsets pass.

Root cause

Triggered when:

  1. An event time τ falls within ULP precision of a saveat time t_save (i.e. |τ - t_save| ≤ ~1 ulp, so the two are distinct but neighboring floats), AND
  2. The Float64 rootfinder lands τ on a different side of t_save than ForwardDiff's Dual rootfinder lands τ_dual.

In the reproduction above, with u0 = [50.0, 0.0, 0.0, 2.0] and cond = u[3] - 10:

  • Float64 rootfinder converges to τ = 4.999999999999999 (= 5.0 - 1ulp, below the saveat at 5.0).
  • ForwardDiff's Dual rootfinder converges to τ = 5.000000000000006 (above the saveat at 5.0).
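
The float spacing around the saveat can be checked with Base Julia alone. The sketch below only inspects the two τ values quoted above; the helper `ulps_from_save` is a hypothetical name introduced for this illustration and is not part of any SciML package.

```julia
t_save = 5.0
τ_f64  = 4.999999999999999   # Float64 rootfinder result quoted above
τ_dual = 5.000000000000006   # Dual rootfinder result quoted above

# Hypothetical helper: signed distance from the saveat in units of eps(t_save).
ulps_from_save(τ) = (τ - t_save) / eps(t_save)

println(τ_f64 == prevfloat(t_save))   # true: the float immediately below 5.0
println(ulps_from_save(τ_f64))        # -1.0
println(ulps_from_save(τ_dual))       # 7.0
```

Both values are floats distinct from 5.0, so neither collapses onto the saveat entry in the saved time vector.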

The forward solution stores (event_pre @ 4.999…, event_post @ 4.999…, saveat @ 5.0) chronologically; the saveat sample is taken ~1ulp after the event, so it captures the post-event state — essentially identical to the event-post save. The user's g(sol) therefore counts the post-event state twice: once via the event-post save and once via the saveat sample at 5.0. ForwardDiff, going through the Dual topology with τ_dual > 5.0, captures the same scalar g value via a different chronological ordering (saveat at 5.0 captures the pre-event state, then event-pre/post at 5.000…). The two topologies produce the same continuous-extension value of g, but the gradient at this discontinuity differs by ≈ 88 in du0[4] because of how the topologies route the ∂g/∂τ contributions through the affect's VJP.
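
One way to see where the ≈18 and ≈88 come from is a closed-form check of the saveat sample at t = 5. The sketch below assumes the only difference between the two topologies is whether that single sample is attributed to the pre-event or the post-event branch, and it looks only at u[3] and u[4]. The names `pre_event`, `post_event`, `sample_loss`, and `grad_fd` are hypothetical helpers written for this illustration, not SciMLSensitivity code.

```julia
# Closed-form u[3], u[4] for du[3] = u[4], du[4] = 0, with the velocity
# flipped when u[3] reaches 10. Valid while at most one bounce has occurred.
pre_event(x0, v, t) = (x0 + v * t, v)      # sample taken before the bounce
function post_event(x0, v, t)              # sample taken after the bounce
    τ = (10.0 - x0) / v                    # event time from u[3](τ) = 10
    return (10.0 - v * (t - τ), -v)
end

# Same per-sample loss as g in the reproduction, restricted to these two states.
sample_loss(s) = sum((1.0 .- s) .^ 2) / 2

# Central finite difference of one sample's loss with respect to x0 and v.
function grad_fd(branch, x0, v, t; h = 1.0e-6)
    dx0 = (sample_loss(branch(x0 + h, v, t)) - sample_loss(branch(x0 - h, v, t))) / (2h)
    dv = (sample_loss(branch(x0, v + h, t)) - sample_loss(branch(x0, v - h, t))) / (2h)
    return (dx0, dv)
end

g_pre = grad_fd(pre_event, 0.0, 2.0, 5.0)    # saveat sample counted as pre-event
g_post = grad_fd(post_event, 0.0, 2.0, 5.0)  # saveat sample counted as post-event
println(g_post .- g_pre)                     # approximately (-18.0, -88.0)
```

Under these assumptions the one-sided differences line up with the Diff row in the reproduction, which is consistent with the two solvers attributing the saveat sample to opposite sides of the bounce.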

In the BS reverse pass, the implicit ∇τ correction at the event affect (src/callback_tracking.jl::affect! ≈ lines 358–374, 389–397, 434–449) computes Lu_right = dot(λ − dgdu_at(cur_time+1), dy_right) * gu_val. When one loss has been added on the post-event side (the standard CC case where τ > saveat), this single subtraction correctly removes the just-added saveat loss from λ so that the implicit ∇τ projection sees only the "evolved" λ. But when the event lands BELOW the coincident saveat (τ < t_save within ULP), the post-event-side of the reverse pass accumulates two loss contributions before the affect runs (the saveat at t_save AND the event-post save), and Lu_right only subtracts one. Worse, even summing both into Lu_right_init (the most obvious patch) doesn't recover the right answer because the implicit-correction formula Lu_left − Lu_right is non-additive in the per-loss coincident terms — it's a global ∇τ projection of (λ − one_loss), not a sum of per-loss ∇τ projections.

What I ruled out

  • Not a VCC-specific bug. The exact same wrong gradient reproduces with a plain CallbackSet(ContinuousCallback, ContinuousCallback) whose second condition is u[3] - 10 (which happens to make the Float64 rootfinder converge to τ = 4.999…9 < saveat = 5.0). The recent VCC refactor commits (Restore multi-fire implicit ∇τ correction, Document why picking any fired condition for ∇τ is correct, Hoist findfirst out of the VCC condition wrappers, etc.) are NOT the source — the implicit-correction logic they touch produces the same gt_val/gu_val as the CC path in this case (verified via instrumented prints).
  • Not a Lu_right_init off-by-one. Setting Lu_right_init = sum_of_coincident_dgdu (i.e. summing the event-post-save dgdu AND the saveat-coincident dgdu before passing into implicit_correction!) makes the result WORSE, not better: BS = -3127 vs FD = -2047, vs pre-patch BS = -2135. The implicit correction is not additive in coincident losses.
  • Not a numerical issue in the forward sol. Off the discontinuity (e.g. u0[4] = 1.999 or u0[4] = 2.001) BS_VCC and BS_CC both match ForwardDiff to 5–6 digits. The bug only triggers at u0[4] = 2.0 exactly, where the rootfinder lands ≈ 1ulp away from saveat = 5.0.
  • Not the cur_time indexing. The trace shows cur_time decrementing correctly through both loss-callback fires before the event affect; loss_indx = cur_time[]+1 points to the chronologically-most-recent coincident save. The off-by-one is in how many coincident saves contribute to the implicit correction, not in which index.
  • Not the cb_dupl duplicate handler. The duplicate handler in generate_callbacks (for times that appear ≥2× in current_time(sol), here the two 4.999… event-pre/post saves) fires correctly. The bug is the additional saveat at 5.0 that appears once in current_time(sol) at a distinct (but ULP-close) float.

Suspected file / lines

  • src/callback_tracking.jl::_setup_reverse_callbacks::affect! lines 358–374 (Lu_right setup), 389–397 (dgdt + implicit_correction!), 434–449 (Lu_left and final dλ update). The single-coincident-loss case in Lu_right_init = dgdu(cur_time[]+1) works for the "standard" CC topology where τ is on one side of the nearest saveat; it does not generalize to the corner where τ is ULP-close to a coincident saveat on the opposite side.
  • The cb_dupl PresetTimeCallback in src/adjoint_common.jl::generate_callbacks (≈ lines 817–836) handles duplicated event times correctly, but doesn't notice/handle the ULP-coincident saveat.

Possible fixes (none of which I could land cleanly)

  1. Sum all coincident-with-event dgdu's into Lu_right_init. I tried this (extending setup_reverse_callbacks to take save_times, iterating from cur_time[]+1 while |t[i] − τ| ≤ 8·eps(τ)·max(|τ|,1)). Off the discontinuity it does nothing (no coincident extras → identical to current code). AT the discontinuity it overshoots: -3127 vs target -2047. The implicit-correction math is not additive in coincident losses.

  2. Make the Float64 rootfinder land on the same side as ForwardDiff's Dual rootfinder. Probably an OrdinaryDiffEqCore / DiffEqBase fix (not SciMLSensitivity). The find_callback_time find_root machinery (DiffEqBase/src/callbacks.jl) routes through the same code, but Float64 and Dual brackets converge to neighboring floats. CC with cond (u[3]-10)*u[3] happens to land τ > 5.0 (matches Dual); CC with cond u[3]-10 (mathematically identical zero set) lands τ < 5.0 (doesn't match Dual). VCC with (u[3]-10)*u[3] lands τ < 5.0 too. So the convergence side is sensitive to micro-details of the bracket and interpolation polynomial.

  3. Treat ULP-coincident saveats as if they were the event-post save. I.e., in the forward sol, when a saveat falls within ULP of an event, suppress the saveat sample (or treat its value as a duplicate of the event-post sample for the adjoint accounting). Invasive — touches forward-solver behavior, and the user's g(sol) semantics change.

  4. Sidestep the discontinuity in the test. Changing u0 = [50.0, 0.0, 0.0, 2.0] to e.g. u0 = [50.0, 0.0, 0.0, 2.001] in test/callbacks/vector_continuous_callbacks.jl::"callback with linear affect" makes all four VCC testsets pass with current master. This avoids testing at an event-saveat ULP coincidence (a measure-zero corner where the gradient is technically a directional limit of a jump-discontinuous function — ForwardDiff and BS legitimately disagree on which side to take). Per CLAUDE.md I haven't done this — flagging it as the safest near-term workaround if the test value matters more than the underlying corner.
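
The coincidence window from option 1 can be sketched in isolation. This is a minimal illustration of the tolerance test only, not the attempted patch itself. The function name `coincident_indices`, its arguments, and the forward scan direction are assumptions made for the example.

```julia
# Hypothetical helper: indices of saved times that fall inside the window
# |t[i] - τ| <= 8 * eps(τ) * max(|τ|, 1), scanning forward from `start`
# and stopping at the first time outside the window.
function coincident_indices(ts::AbstractVector{<:AbstractFloat}, τ::AbstractFloat, start::Integer)
    tol = 8 * eps(τ) * max(abs(τ), 1)
    idxs = Int[]
    i = start
    while i <= length(ts) && abs(ts[i] - τ) <= tol
        push!(idxs, i)
        i += 1
    end
    return idxs
end

# Saved times shaped like the reproduction: event-pre, event-post, then saveat.
ts = [4.5, 4.999999999999999, 4.999999999999999, 5.0, 5.5]
println(coincident_indices(ts, 4.999999999999999, 2))   # [2, 3, 4]
```

A window like this detects the extra saveat at 5.0, but as noted in option 1, summing the detected dgdu terms into Lu_right_init did not recover the ForwardDiff value.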

Test status (with PR #1446 applied)

Test Summary:                                                    | Pass  Fail
VectorContinuous callbacks                                       |    7     4
  MSE loss function bouncing-ball like                           |          4
    callback with linear affect                                  |          4   <-- this issue
  Test condition function that depends on time only              |    4
  Structural simultaneous fire (algebraically tied conditions)   |    1
  Coincidental simultaneous fire (corner trap)                   |    2

The 4 failing assertions are du01 ≈ dstuff[1:4], dp1 ≈ dstuff[5:6], du02 ≈ dstuff[1:4], dp2 ≈ dstuff[5:6] (BS and GA both vs ForwardDiff). Mismatch: du0[3] off by ~17.1, du0[4] off by ~88.0, dp[2] off by ~5.6, all other components match to 1e-5 rtol.

Versions

cc @ChrisRackauckas
