Skip to content

Commit 5f6ab43

Browse files
Merge pull request #1168 from SciML/dg/initprob
Feat: Handle Adjoints through Initialization
2 parents a21900f + 9aecbfd commit 5f6ab43

16 files changed

+476
-60
lines changed

.github/workflows/CI.yml

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jobs:
2424
- Core5
2525
- Core6
2626
- Core7
27+
- Core8
2728
- QA
2829
- SDE1
2930
- SDE2

Project.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2424
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2525
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2626
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
27+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
2728
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2829
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
2930
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -76,13 +77,15 @@ LinearAlgebra = "1.10"
7677
LinearSolve = "2, 3"
7778
Lux = "1"
7879
Markdown = "1.10"
79-
ModelingToolkit = "9.42"
80+
ModelingToolkit = "9.74"
81+
ModelingToolkitStandardLibrary = "2"
8082
Mooncake = "0.4.52"
8183
NLsolve = "4.5.1"
8284
NonlinearSolve = "3.0.1, 4"
8385
Optimization = "4"
8486
OptimizationOptimisers = "0.3"
8587
OrdinaryDiffEq = "6.81.1"
88+
OrdinaryDiffEqCore = "1"
8689
Pkg = "1.10"
8790
PreallocationTools = "0.4.4"
8891
QuadGK = "2.9.1"
@@ -117,6 +120,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
117120
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
118121
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
119122
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
123+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
120124
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
121125
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
122126
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -131,4 +135,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
131135
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
132136

133137
[targets]
134-
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
138+
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]

src/SciMLSensitivity.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,19 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm,
3737
RODEFunction, RODEProblem, ReturnCode, SDEFunction,
3838
SDEProblem, VectorContinuousCallback, deleteat!,
3939
get_tmp_cache, has_adjoint, isinplace, reinit!, remake,
40-
solve, u_modified!, LinearAliasSpecifier
40+
solve, u_modified!, LinearAliasSpecifier, OverrideInit, CheckInit
41+
42+
using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit, default_nlsolve, has_autodiff
4143

4244
# AD Backends
43-
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent
45+
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk
4446
using Enzyme: Enzyme
4547
using FiniteDiff: FiniteDiff
4648
using ForwardDiff: ForwardDiff
4749
using Tracker: Tracker, TrackedArray
4850
using ReverseDiff: ReverseDiff
4951
using Zygote: Zygote
52+
using SciMLBase.ConstructionBase
5053

5154
# Std Libs
5255
using LinearAlgebra: LinearAlgebra, Diagonal, I, UniformScaling, adjoint, axpy!,
@@ -56,6 +59,8 @@ using Markdown: Markdown, @doc_str
5659
using Random: Random, rand!
5760
using Statistics: Statistics, mean
5861

62+
using LinearAlgebra: diag
63+
5964
abstract type SensitivityFunction end
6065
abstract type TransformedFunction end
6166

src/adjoint_common.jl

+26-14
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
7878
unwrappedf = unwrapped_f(f)
7979

8080
numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables)
81-
numindvar = length(u0)
81+
numindvar = isnothing(u0) ? nothing : length(u0)
8282
isautojacvec = get_jacvec(sensealg)
8383

8484
issemiexplicitdae = false
@@ -106,18 +106,22 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
106106
isempty(algevar_idxs) || (issemiexplicitdae = true)
107107
end
108108
if !issemiexplicitdae
109-
diffvar_idxs = eachindex(u0)
109+
diffvar_idxs = isnothing(u0) ? nothing : eachindex(u0)
110110
algevar_idxs = 1:0
111111
end
112112

113113
if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool)
114114
J = nothing
115115
else
116116
if alg === nothing || SciMLBase.forwarddiffs_model_time(alg)
117-
# 1 chunk is fine because it's only t
118-
_J = similar(u0, numindvar, numindvar)
119-
_J .= 0
120-
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
117+
if !isnothing(u0)
118+
# 1 chunk is fine because it's only t
119+
_J = similar(u0, numindvar, numindvar)
120+
_J .= 0
121+
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
122+
else
123+
J = nothing
124+
end
121125
else
122126
J = similar(u0, numindvar, numindvar)
123127
J .= 0
@@ -133,8 +137,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
133137
dg_val[1] .= false
134138
dg_val[2] .= false
135139
else
136-
dg_val = similar(u0, numindvar) # number of funcs size
137-
dg_val .= false
140+
if !isnothing(u0)
141+
dg_val = similar(u0, numindvar) # number of funcs size
142+
dg_val .= false
143+
else
144+
dg_val = nothing
145+
end
138146
end
139147
else
140148
pgpu = UGradientWrapper(g, _t, p)
@@ -241,8 +249,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
241249
pJ = if (quad || !(autojacvec isa Bool))
242250
nothing
243251
else
244-
_pJ = similar(u0, numindvar, numparams)
245-
_pJ .= false
252+
if !isnothing(u0)
253+
_pJ = similar(u0, numindvar, numparams)
254+
_pJ .= false
255+
else
256+
_pJ = nothing
257+
end
246258
end
247259

248260
f_cache = isinplace ? deepcopy(u0) : nothing
@@ -379,11 +391,11 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
379391
if !isRODE
380392
__p = p isa SciMLBase.NullParameters ? _p :
381393
SciMLStructures.replace(Tunable(), p, _p)
382-
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
394+
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
383395
du1 = (p !== nothing && p !== SciMLBase.NullParameters()) ?
384396
similar(p, size(u)) : similar(u)
385397
du1 .= false
386-
f(du1, u, p, first(t))
398+
f(du1, u, repack(p), first(t))
387399
return vec(du1)
388400
end
389401
else
@@ -402,8 +414,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
402414
# because hasportion(Tunable(), NullParameters) == false
403415
__p = p isa SciMLBase.NullParameters ? _p :
404416
SciMLStructures.replace(Tunable(), p, _p)
405-
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
406-
vec(f(u, p, first(t)))
417+
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
418+
vec(f(u, repack(p), first(t)))
407419
end
408420
else
409421
tape = ReverseDiff.GradientTape((y, _p, [_t], _W)) do u, p, t, W

src/concrete_solve.jl

+111-11
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack)
4646

4747
vjp = try
4848
f = unwrapped_f(prob.f)
49+
tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob.tspan[1]]
4950
if p === nothing || p isa SciMLBase.NullParameters
50-
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
51+
ReverseDiff.GradientTape((copy(u0), tspan_)) do u, t
5152
du1 = similar(u, size(u))
5253
du1 .= 0
5354
f(du1, u, p, first(t))
5455
return vec(du1)
5556
end
5657
else
57-
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
58+
ReverseDiff.GradientTape((copy(u0), p, tspan_)) do u, p, t
5859
du1 = similar(u, size(u))
5960
du1 .= 0
6061
f(du1, u, repack(p), first(t))
@@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint(
299300
tunables, repack = Functors.functor(p)
300301
end
301302

303+
u0 = state_values(prob) === nothing ? Float64[] : u0
302304
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
303305
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p,
304306
originator::SciMLBase.ADOriginator, args...; verbose,
@@ -371,6 +373,7 @@ function DiffEqBase._concrete_solve_adjoint(
371373
args...; save_start = true, save_end = true,
372374
saveat = eltype(prob.tspan)[],
373375
save_idxs = nothing,
376+
initializealg_default = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3),
374377
kwargs...)
375378
if !(sensealg isa GaussAdjoint) &&
376379
!(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
@@ -412,16 +415,61 @@ function DiffEqBase._concrete_solve_adjoint(
412415
Base.diff_names(Base._nt_names(values(kwargs)),
413416
(:callback_adj, :callback))}(values(kwargs))
414417
isq = sensealg isa QuadratureAdjoint
418+
kwargs_init = kwargs_adj[Base.diff_names(Base._nt_names(kwargs_adj), (:initializealg,))]
419+
420+
if haskey(kwargs, :initializealg) || haskey(prob.kwargs, :initializealg)
421+
initializealg = haskey(kwargs, :initializealg) ? kwargs[:initializealg] : prob.kwargs[:initializealg]
422+
else
423+
initializealg = DefaultInit()
424+
end
425+
426+
default_inits = Union{OverrideInit, Nothing, DefaultInit}
427+
igs, new_u0, new_p, new_initializealg = if (SciMLBase.has_initialization_data(_prob.f) && initializealg isa default_inits)
428+
local new_u0
429+
local new_p
430+
initializeprob = prob.f.initialization_data.initializeprob
431+
iu0 = state_values(initializeprob)
432+
isAD = if iu0 === nothing
433+
AutoForwardDiff
434+
elseif has_autodiff(alg)
435+
OrdinaryDiffEqCore.alg_autodiff(alg) isa AutoForwardDiff
436+
else
437+
true
438+
end
439+
nlsolve_alg = default_nlsolve(nothing, Val(isinplace(_prob)), iu0, initializeprob, isAD)
440+
initializealg = initializealg isa Union{Nothing, DefaultInit} ? initializealg_default : initializealg
441+
442+
iy, back = Zygote.pullback(tunables) do tunables
443+
new_prob = remake(_prob, p = repack(tunables))
444+
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, initializealg, Val(isinplace(new_prob));
445+
sensealg = SteadyStateAdjoint(autojacvec = sensealg.autojacvec),
446+
nlsolve_alg,
447+
kwargs_init...)
448+
new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p)
449+
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED
450+
sum(new_tunables)
451+
else
452+
sum(new_u0) + sum(new_tunables)
453+
end
454+
end
455+
igs = back(one(iy))[1] .- one(eltype(tunables))
456+
457+
igs, new_u0, new_p, SciMLBase.NoInit()
458+
else
459+
nothing, u0, p, initializealg
460+
end
461+
_prob = remake(_prob, u0 = new_u0, p = new_p)
462+
415463
if sensealg isa BacksolveAdjoint
416-
sol = solve(_prob, alg, args...; save_noise = true,
464+
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true,
417465
save_start = save_start, save_end = save_end,
418466
saveat = saveat, kwargs_fwd...)
419467
elseif ischeckpointing(sensealg)
420-
sol = solve(_prob, alg, args...; save_noise = true,
468+
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true,
421469
save_start = true, save_end = true,
422470
saveat = saveat, kwargs_fwd...)
423471
else
424-
sol = solve(_prob, alg, args...; save_noise = true, save_start = true,
472+
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true, save_start = true,
425473
save_end = true, kwargs_fwd...)
426474
end
427475

@@ -491,6 +539,7 @@ function DiffEqBase._concrete_solve_adjoint(
491539
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
492540

493541
function adjoint_sensitivity_backpass(Δ)
542+
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
494543
function df_iip(_out, u, p, t, i)
495544
outtype = _out isa SubArray ?
496545
ArrayInterface.parameterless_type(_out.parent) :
@@ -628,20 +677,22 @@ function DiffEqBase._concrete_solve_adjoint(
628677
dgdu_discrete = df_iip,
629678
sensealg = sensealg,
630679
callback = cb2,
631-
kwargs_adj...)
680+
kwargs_init...)
632681
else
633682
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
634683
dgdu_discrete = df_oop,
635684
sensealg = sensealg,
636685
callback = cb2,
637-
kwargs_adj...)
686+
kwargs_init...)
638687
end
639688

640689
du0 = reshape(du0, size(u0))
641690

642691
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
643692
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp
644693

694+
dp = Zygote.accum(dp, igs)
695+
645696
_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
646697
!isscimlstructure(p)
647698
nothing, x -> (x,)
@@ -1679,6 +1730,7 @@ function DiffEqBase._concrete_solve_adjoint(
16791730
u0, p, originator::SciMLBase.ADOriginator,
16801731
args...; save_idxs = nothing, kwargs...)
16811732
_prob = remake(prob, u0 = u0, p = p)
1733+
16821734
sol = solve(_prob, alg, args...; kwargs...)
16831735
_save_idxs = save_idxs === nothing ? Colon() : save_idxs
16841736

@@ -1688,26 +1740,74 @@ function DiffEqBase._concrete_solve_adjoint(
16881740
out = SciMLBase.sensitivity_solution(sol, sol[_save_idxs])
16891741
end
16901742

1743+
_, repack_adjoint = if isscimlstructure(p)
1744+
Zygote.pullback(p) do p
1745+
t, _, _ = canonicalize(Tunable(), p)
1746+
t
1747+
end
1748+
elseif isfunctor(p)
1749+
ps, re = Functors.functor(p)
1750+
ps, x -> (re(x),)
1751+
else
1752+
nothing, x -> (x,)
1753+
end
1754+
16911755
function steadystatebackpass(Δ)
1756+
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
16921757
# Δ = dg/dx or diffcache.dg_val
16931758
# del g/del p = 0
16941759
function df(_out, u, p, t, i)
16951760
if _save_idxs isa Number
16961761
_out[_save_idxs] = Δ[_save_idxs]
16971762
elseif Δ isa Number
16981763
@. _out[_save_idxs] = Δ
1699-
else
1764+
elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa AbstractArray
17001765
@. _out[_save_idxs] = Δ[_save_idxs]
1766+
elseif isnothing(_out)
1767+
_out
1768+
else
1769+
@. _out[_save_idxs] = Δ.u[_save_idxs]
1770+
end
1771+
end
1772+
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df, initializealg = BrownFullBasicInit())
1773+
1774+
dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
1775+
# if Δ isa AbstractArray, the gradients correspond to `u`
1776+
# this is something that needs changing in the future, but
1777+
# this remains applicable until the movement to structural
1778+
# tangents is completed
1779+
dp, Δtunables = if isscimlstructure(dp)
1780+
dp, _, _ = canonicalize(Tunable(), dp)
1781+
dp, nothing
1782+
elseif isfunctor(dp)
1783+
dp, _ = Functors.functor(dp)
1784+
dp, nothing
1785+
else
1786+
dp, nothing
1787+
end
1788+
else
1789+
dp, Δtunables = if isscimlstructure(p)
1790+
Δp = setproperties(dp, to_nt(Δ.prob.p))
1791+
Δtunables, _, _ = canonicalize(Tunable(), Δp)
1792+
dp, _, _ = canonicalize(Tunable(), dp)
1793+
dp, Δtunables
1794+
elseif isfunctor(p)
1795+
dp, _ = Functors.functor(dp)
1796+
Δtunables, _ = Functors.functor(Δ.prob.p)
1797+
dp, Δtunables
1798+
else
1799+
dp, Δ.prob.p
17011800
end
17021801
end
1703-
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df)
1802+
1803+
dp = Zygote.accum(dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : Δtunables)
17041804

17051805
if originator isa SciMLBase.TrackerOriginator ||
17061806
originator isa SciMLBase.ReverseDiffOriginator
1707-
(NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
1807+
(NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
17081808
ntuple(_ -> NoTangent(), length(args))...)
17091809
else
1710-
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
1810+
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
17111811
ntuple(_ -> NoTangent(), length(args))...)
17121812
end
17131813
end

0 commit comments

Comments
 (0)