1
1
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
2
- G}
2
+ G, SAlg <: GaussAdjoint }
3
3
sol:: S
4
4
p:: pType
5
5
y:: uType
@@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
8
8
f_cache:: rateType
9
9
pJ:: PJT
10
10
paramjac_config:: PJC
11
- sensealg:: GaussAdjoint
11
+ sensealg:: SAlg
12
12
dgdp_cache:: DGP
13
13
dgdp:: G
14
14
end
15
15
16
16
struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache ,
17
17
Alg <: GaussAdjoint ,
18
18
uType, SType, CPS, pType,
19
- fType <: DiffEqBase.AbstractDiffEqFunction } <: SensitivityFunction
19
+ fType <: DiffEqBase.AbstractDiffEqFunction ,
20
+ GI <: GaussIntegrand ,
21
+ ICB} <: SensitivityFunction
20
22
diffcache:: C
21
23
sensealg:: Alg
22
24
discrete:: Bool
@@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
25
27
checkpoint_sol:: CPS
26
28
prob:: pType
27
29
f:: fType
28
- GaussInt:: GaussIntegrand
30
+ GaussInt:: GI
31
+ integrating_cb:: ICB
29
32
end
30
33
31
34
TruncatedStacktraces. @truncate_stacktrace ODEGaussAdjointSensitivityFunction
41
44
function ODEGaussAdjointSensitivityFunction (
42
45
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
43
46
f, alg,
44
- checkpoints, tols, tstops = nothing ;
47
+ checkpoints, integrating_cb, tols, tstops = nothing ;
45
48
tspan = reverse (sol. prob. tspan))
46
49
checkpointing = ischeckpointing (sensealg, sol)
47
50
(checkpointing && checkpoints === nothing ) &&
@@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction(
84
87
g, sensealg, discrete, sol, dgdu, dgdp, sol. prob. f, alg;
85
88
quad = true )
86
89
return ODEGaussAdjointSensitivityFunction (diffcache, sensealg, discrete,
87
- y, sol, checkpoint_sol, sol. prob, f, gaussint)
90
+ y, sol, checkpoint_sol, sol. prob, f, gaussint, integrating_cb )
88
91
end
89
92
90
93
function Gaussfindcursor (intervals, t)
@@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
202
205
end
203
206
204
207
@noinline function ODEAdjointProblem (sol, sensealg:: GaussAdjoint , alg,
205
- GaussInt:: GaussIntegrand ,
208
+ GaussInt:: GaussIntegrand , integrating_cb,
206
209
t = nothing ,
207
210
dgdu_discrete:: DG1 = nothing ,
208
211
dgdp_discrete:: DG2 = nothing ,
275
278
λ = zero (u0)
276
279
end
277
280
sense = ODEGaussAdjointSensitivityFunction (g, sensealg, GaussInt, discrete, sol,
278
- dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
281
+ dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
279
282
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)
280
283
281
284
init_cb = (discrete || dgdu_discrete != = nothing ) # && tspan[1] == t[end]
@@ -521,15 +524,16 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
521
524
kwargs... )
522
525
integrand = GaussIntegrand (sol, sensealg, checkpoints, dgdp_continuous)
523
526
integrand_values = IntegrandValuesSum (allocate_zeros (sol. prob. p))
524
- cb = IntegratingSumCallback ((out, u, t, integrator) -> integrand (out, t, u),
527
+ integrating_cb = IntegratingSumCallback ((out, u, t, integrator) -> integrand (out, t, u),
525
528
integrand_values, allocate_vjp (sol. prob. p))
526
529
rcb = nothing
527
530
cb2 = nothing
528
531
adj_prob = nothing
529
532
530
533
if sol. prob isa ODEProblem
531
534
adj_prob, cb2, rcb = ODEAdjointProblem (
532
- sol, sensealg, alg, integrand, t, dgdu_discrete,
535
+ sol, sensealg, alg, integrand, integrating_cb,
536
+ t, dgdu_discrete,
533
537
dgdp_discrete,
534
538
dgdu_continuous, dgdp_continuous, g, Val (true );
535
539
checkpoints = checkpoints,
@@ -544,7 +548,7 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
544
548
adj_sol = solve (
545
549
adj_prob, alg; abstol = abstol, reltol = reltol, save_everystep = false ,
546
550
save_start = false , save_end = true , saveat = eltype (sol. u[1 ])[], tstops = tstops,
547
- callback = CallbackSet (cb , cb2), kwargs... )
551
+ callback = CallbackSet (integrating_cb , cb2), kwargs... )
548
552
res = integrand_values. integrand
549
553
550
554
if rcb != = nothing && ! isempty (rcb. Δλas)
0 commit comments