Skip to content

Commit b718a7a

Browse files
Almost there
1 parent 60aada0 commit b718a7a

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

Diff for: src/callback_tracking.jl

+16-3
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,11 @@ function _setup_reverse_callbacks(
272272
du = first(get_tmp_cache(integrator))
273273
λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S)
274274

275+
if sensealg isa GaussAdjoint
276+
dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache
277+
recursive_copyto!(dgrad, 0)
278+
end
279+
275280
# if save_positions[2] = false, then the right limit is not saved. Thus, for
276281
# the QuadratureAdjoint we would need to lift y from the left to the right limit.
277282
# However, one also needs to update dgrad later on.
@@ -339,7 +344,10 @@ function _setup_reverse_callbacks(
339344
vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS;
340345
dgrad = dgrad, dy = dy)
341346

342-
dgrad !== nothing && (dgrad .*= -1)
347+
if dgrad !== nothing && !(sensealg isa QuadratureAdjoint)
348+
dgrad .*= -1
349+
end
350+
343351
if cb isa Union{ContinuousCallback, VectorContinuousCallback}
344352
# second correction to correct for left limit
345353
@unpack Lu_left = correction
@@ -358,8 +366,13 @@ function _setup_reverse_callbacks(
358366

359367
λ .=
360368

361-
if !(sensealg isa QuadratureAdjoint) && !(sensealg isa GaussAdjoint)
362-
grad .-= dgrad
369+
if sensealg isa GaussAdjoint
370+
@assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
371+
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad
372+
373+
#recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad)
374+
elseif !(sensealg isa QuadratureAdjoint)
375+
grad .= dgrad
363376
end
364377
end
365378

Diff for: src/gauss_adjoint.jl

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
2-
G}
2+
G, SAlg <: GaussAdjoint}
33
sol::S
44
p::pType
55
y::uType
@@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
88
f_cache::rateType
99
pJ::PJT
1010
paramjac_config::PJC
11-
sensealg::GaussAdjoint
11+
sensealg::SAlg
1212
dgdp_cache::DGP
1313
dgdp::G
1414
end
1515

1616
struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
1717
Alg <: GaussAdjoint,
1818
uType, SType, CPS, pType,
19-
fType <: DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction
19+
fType <: DiffEqBase.AbstractDiffEqFunction,
20+
GI <: GaussIntegrand,
21+
ICB} <: SensitivityFunction
2022
diffcache::C
2123
sensealg::Alg
2224
discrete::Bool
@@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
2527
checkpoint_sol::CPS
2628
prob::pType
2729
f::fType
28-
GaussInt::GaussIntegrand
30+
GaussInt::GI
31+
integrating_cb::ICB
2932
end
3033

3134
TruncatedStacktraces.@truncate_stacktrace ODEGaussAdjointSensitivityFunction
@@ -41,7 +44,7 @@ end
4144
function ODEGaussAdjointSensitivityFunction(
4245
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
4346
f, alg,
44-
checkpoints, tols, tstops = nothing;
47+
checkpoints, integrating_cb, tols, tstops = nothing;
4548
tspan = reverse(sol.prob.tspan))
4649
checkpointing = ischeckpointing(sensealg, sol)
4750
(checkpointing && checkpoints === nothing) &&
@@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction(
8487
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
8588
quad = true)
8689
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
87-
y, sol, checkpoint_sol, sol.prob, f, gaussint)
90+
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
8891
end
8992

9093
function Gaussfindcursor(intervals, t)
@@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
202205
end
203206

204207
@noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg,
205-
GaussInt::GaussIntegrand,
208+
GaussInt::GaussIntegrand, integrating_cb,
206209
t = nothing,
207210
dgdu_discrete::DG1 = nothing,
208211
dgdp_discrete::DG2 = nothing,
@@ -275,7 +278,7 @@ end
275278
λ = zero(u0)
276279
end
277280
sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol,
278-
dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
281+
dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
279282
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)
280283

281284
init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end]
@@ -521,15 +524,16 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
521524
kwargs...)
522525
integrand = GaussIntegrand(sol, sensealg, checkpoints, dgdp_continuous)
523526
integrand_values = IntegrandValuesSum(allocate_zeros(sol.prob.p))
524-
cb = IntegratingSumCallback((out, u, t, integrator) -> integrand(out, t, u),
527+
integrating_cb = IntegratingSumCallback((out, u, t, integrator) -> integrand(out, t, u),
525528
integrand_values, allocate_vjp(sol.prob.p))
526529
rcb = nothing
527530
cb2 = nothing
528531
adj_prob = nothing
529532

530533
if sol.prob isa ODEProblem
531534
adj_prob, cb2, rcb = ODEAdjointProblem(
532-
sol, sensealg, alg, integrand, t, dgdu_discrete,
535+
sol, sensealg, alg, integrand, integrating_cb,
536+
t, dgdu_discrete,
533537
dgdp_discrete,
534538
dgdu_continuous, dgdp_continuous, g, Val(true);
535539
checkpoints = checkpoints,
@@ -544,7 +548,7 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
544548
adj_sol = solve(
545549
adj_prob, alg; abstol = abstol, reltol = reltol, save_everystep = false,
546550
save_start = false, save_end = true, saveat = eltype(sol.u[1])[], tstops = tstops,
547-
callback = CallbackSet(cb, cb2), kwargs...)
551+
callback = CallbackSet(integrating_cb, cb2), kwargs...)
548552
res = integrand_values.integrand
549553

550554
if rcb !== nothing && !isempty(rcb.Δλas)

0 commit comments

Comments
 (0)