-
-
Notifications
You must be signed in to change notification settings - Fork 76
Feat: Handle Adjoints through Initialization #1168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 55 commits
d6290bf
d3199c0
a4fa7c5
5a7dd26
94ec324
0c5564e
aca9cd4
9bec784
72cbb35
2d85e19
9a8a845
95ebbf3
957d7fe
a00574f
4562f0c
6c21324
a675a7f
7941a3c
9557e8c
8feae0e
d3b1669
e01eb77
0ad6c62
6df7987
c4c7807
b85b16e
8fc4136
1f1cce5
0d1abcc
b164b18
915d949
d2fd79a
a0cd94a
885794d
13a1ffb
7826866
6ceaa1a
dfebd0b
396f63e
de7e7da
f5fb559
019a051
91ee019
4b74718
94d5e2b
764d3ff
84b2602
8a4aa79
d69ccb1
2cc3673
22f056a
1f95b25
0d78fa8
6620e8a
5f5633b
984c2ce
82cd5fe
987e8be
056fffa
a122340
e105838
aaddc02
60be1c7
9edbe02
308ae5c
a333588
a2cf0e6
755c9df
75ad141
e07dd53
ee804b2
7abf42c
a7d4e5a
8e660fe
de63cf9
b88f468
7dd1cc7
cdaa2c7
d3608c4
4cf7bd5
2ae712b
229e691
6e1109e
d32b3f2
6e549e7
9789034
35937c0
f934635
a83cd29
2dfbc6e
9aecbfd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ jobs: | |
- Core5 | ||
- Core6 | ||
- Core7 | ||
- Core8 | ||
- QA | ||
- SDE1 | ||
- SDE2 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,7 +78,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f | |
unwrappedf = unwrapped_f(f) | ||
|
||
numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables) | ||
numindvar = length(u0) | ||
numindvar = isnothing(u0) ? nothing : length(u0) | ||
isautojacvec = get_jacvec(sensealg) | ||
|
||
issemiexplicitdae = false | ||
|
@@ -106,18 +106,22 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f | |
isempty(algevar_idxs) || (issemiexplicitdae = true) | ||
end | ||
if !issemiexplicitdae | ||
diffvar_idxs = eachindex(u0) | ||
diffvar_idxs = isnothing(u0) ? nothing : eachindex(u0) | ||
algevar_idxs = 1:0 | ||
end | ||
|
||
if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool) | ||
J = nothing | ||
else | ||
if alg === nothing || SciMLBase.forwarddiffs_model_time(alg) | ||
# 1 chunk is fine because it's only t | ||
_J = similar(u0, numindvar, numindvar) | ||
_J .= 0 | ||
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0))) | ||
if !isnothing(u0) | ||
# 1 chunk is fine because it's only t | ||
_J = similar(u0, numindvar, numindvar) | ||
_J .= 0 | ||
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0))) | ||
else | ||
J = nothing | ||
end | ||
else | ||
J = similar(u0, numindvar, numindvar) | ||
J .= 0 | ||
|
@@ -133,8 +137,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f | |
dg_val[1] .= false | ||
dg_val[2] .= false | ||
else | ||
dg_val = similar(u0, numindvar) # number of funcs size | ||
dg_val .= false | ||
if !isnothing(u0) | ||
dg_val = similar(u0, numindvar) # number of funcs size | ||
dg_val .= false | ||
else | ||
dg_val = nothing | ||
end | ||
end | ||
else | ||
pgpu = UGradientWrapper(g, _t, p) | ||
|
@@ -241,8 +249,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f | |
pJ = if (quad || !(autojacvec isa Bool)) | ||
nothing | ||
else | ||
_pJ = similar(u0, numindvar, numparams) | ||
_pJ .= false | ||
if !isnothing(u0) | ||
_pJ = similar(u0, numindvar, numparams) | ||
_pJ .= false | ||
else | ||
_pJ = nothing | ||
end | ||
end | ||
|
||
f_cache = isinplace ? deepcopy(u0) : nothing | ||
|
@@ -379,11 +391,11 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t; | |
if !isRODE | ||
__p = p isa SciMLBase.NullParameters ? _p : | ||
SciMLStructures.replace(Tunable(), p, _p) | ||
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t | ||
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t | ||
du1 = (p !== nothing && p !== SciMLBase.NullParameters()) ? | ||
similar(p, size(u)) : similar(u) | ||
du1 .= false | ||
f(du1, u, p, first(t)) | ||
f(du1, u, repack(p), first(t)) | ||
return vec(du1) | ||
end | ||
else | ||
|
@@ -402,8 +414,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t; | |
# because hasportion(Tunable(), NullParameters) == false | ||
__p = p isa SciMLBase.NullParameters ? _p : | ||
SciMLStructures.replace(Tunable(), p, _p) | ||
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t | ||
vec(f(u, p, first(t))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same down here? |
||
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t | ||
vec(f(u, repack(p), first(t))) | ||
end | ||
else | ||
tape = ReverseDiff.GradientTape((y, _p, [_t], _W)) do u, p, t, W | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack) | |
|
||
vjp = try | ||
f = unwrapped_f(prob.f) | ||
tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob.tspan[1]] | ||
if p === nothing || p isa SciMLBase.NullParameters | ||
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t | ||
ReverseDiff.GradientTape((copy(u0), tspan_)) do u, t | ||
du1 = similar(u, size(u)) | ||
du1 .= 0 | ||
f(du1, u, p, first(t)) | ||
return vec(du1) | ||
end | ||
else | ||
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t | ||
ReverseDiff.GradientTape((copy(u0), p, tspan_)) do u, p, t | ||
du1 = similar(u, size(u)) | ||
du1 .= 0 | ||
f(du1, u, repack(p), first(t)) | ||
|
@@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint( | |
tunables, repack = Functors.functor(p) | ||
end | ||
|
||
u0 = state_values(prob) === nothing ? Float64[] : u0 | ||
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack) | ||
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p, | ||
originator::SciMLBase.ADOriginator, args...; verbose, | ||
|
@@ -412,16 +414,42 @@ function DiffEqBase._concrete_solve_adjoint( | |
Base.diff_names(Base._nt_names(values(kwargs)), | ||
(:callback_adj, :callback))}(values(kwargs)) | ||
isq = sensealg isa QuadratureAdjoint | ||
|
||
igs, new_u0, new_p = if _prob.f.initialization_data !== nothing | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also needs to check that initializealg is not set, is the default, or is using OverrideInit. Should test this is not triggered with say manual BrownBasicInit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. My understanding was that OverrideInit was what we strictly needed. We can check for BrownBasicInit/ defaults here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There doesn't seem to be a method which can take a ERROR: MethodError: no method matching get_initial_values(::ODEProblem{…}, ::ODEProblem{…}, ::ODEFunction{…}, ::BrownFullBasicInit{…}, ::Val{…}; sensealg::SteadyStateAdjoint{…}, nlsolve_alg::Nothing)
Closest candidates are:
get_initial_values(::Any, ::Any, ::Any, ::NoInit, ::Any; kwargs...)
@ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:282
get_initial_values(::Any, ::Any, ::Any, ::SciMLBase.OverrideInit, ::Union{Val{true}, Val{false}}; nlsolve_alg, abstol, reltol, kwargs...)
@ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:224
get_initial_values(::SciMLBase.AbstractDEProblem, ::SciMLBase.DEIntegrator, ::Any, ::CheckInit, ::Union{Val{true}, Val{false}}; abstol, kwargs...)
@ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:161
...
Only There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having chatted with @AayushSabharwal on this, it seems like BrownBasic and ShampineCollocation do not yet have a path through There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What would be the best course of action here? Seems like supporting BrownBasicInit is a dispatch that will automatically be utilised when it is moved into SciMLBase. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I mean BrownBasicInit should not be taking this path. But that's a problem because then they will be disabled in the next stage below, and that needs to be accounted for. This dispatch is already built and setup for BrownBasicInit and there are tests on that. |
||
local new_u0 | ||
local new_p | ||
iy, back = Zygote.pullback(tunables) do tunables | ||
new_prob = remake(_prob, p = repack(tunables)) | ||
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true); | ||
abstol = 1e-6, | ||
reltol = 1e-6, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These don't make sense. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, these should probably inherit from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in 984c2ce |
||
sensealg = SteadyStateAdjoint(autojacvec = sensealg.autojacvec), | ||
kwargs...) | ||
new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p) | ||
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED | ||
sum(new_tunables) | ||
else | ||
sum(new_u0) + sum(new_tunables) | ||
end | ||
end | ||
igs = back(one(iy))[1] .- one(eltype(tunables)) | ||
|
||
igs, new_u0, new_p | ||
else | ||
nothing, u0, p | ||
end | ||
_prob = remake(_prob, u0 = new_u0, p = new_p) | ||
|
||
if sensealg isa BacksolveAdjoint | ||
sol = solve(_prob, alg, args...; save_noise = true, | ||
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should only noinit if the previous case was run. Won't this right now break the brownbasic tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there was no initialization data, it won't have run the initialization problem at all. If I can generically ignore handling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No that is not correct. If there was no initialization data then it will use the built in initialization, defaulting to BrownBasicInit. It's impossible for a DAE solver to generally work without running initialization of some form, the MTK one is just a new specialized one but there has always been a numerical one in the solver. And if it hits that case, this code will now disable that. https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L952-L978 this code will hit that. I think it's not failing because it's not so pronounced here. You might want to change that test to https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L975C5-L975C69 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right of course for the DAEs, but since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh 😅 . That case was too simple, MTK turns it into an ODE. Let's make it a DAE. 
@parameters σ ρ β A[1:3]
@variables x(t) y(t) z(t) w(t) w2(t)
eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z,
w ~ x + y + z + 2 * β
0 ~ x^2 + y^2 - w2^2
]
@mtkbuild sys = ODESystem(eqs, t) That should make it so that it eliminates the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That will need to change the integrator to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For SDEs, we will just need to make it compatible with BrownBasicInit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I was so confused why it worked out, I see the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay good. Yeah because MTK is too smart and makes lots of simple examples not DAEs 😅. But now you got the DAE, and if not running the built in init then you get the error I was expecting. The fix is that it needs to run brownbasic before solving for the same reason reverse needs to. Good we worked out a test for this |
||
save_start = save_start, save_end = save_end, | ||
saveat = saveat, kwargs_fwd...) | ||
elseif ischeckpointing(sensealg) | ||
sol = solve(_prob, alg, args...; save_noise = true, | ||
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, | ||
save_start = true, save_end = true, | ||
saveat = saveat, kwargs_fwd...) | ||
else | ||
sol = solve(_prob, alg, args...; save_noise = true, save_start = true, | ||
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, save_start = true, | ||
save_end = true, kwargs_fwd...) | ||
end | ||
|
||
|
@@ -491,6 +519,7 @@ function DiffEqBase._concrete_solve_adjoint( | |
_save_idxs = save_idxs === nothing ? Colon() : save_idxs | ||
|
||
function adjoint_sensitivity_backpass(Δ) | ||
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ | ||
function df_iip(_out, u, p, t, i) | ||
outtype = _out isa SubArray ? | ||
ArrayInterface.parameterless_type(_out.parent) : | ||
|
@@ -642,6 +671,8 @@ function DiffEqBase._concrete_solve_adjoint( | |
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : | ||
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp | ||
|
||
dp = Zygote.accum(dp, igs) | ||
|
||
_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() || | ||
!isscimlstructure(p) | ||
nothing, x -> (x,) | ||
|
@@ -1679,6 +1710,7 @@ function DiffEqBase._concrete_solve_adjoint( | |
u0, p, originator::SciMLBase.ADOriginator, | ||
args...; save_idxs = nothing, kwargs...) | ||
_prob = remake(prob, u0 = u0, p = p) | ||
|
||
sol = solve(_prob, alg, args...; kwargs...) | ||
_save_idxs = save_idxs === nothing ? Colon() : save_idxs | ||
|
||
|
@@ -1688,26 +1720,56 @@ function DiffEqBase._concrete_solve_adjoint( | |
out = SciMLBase.sensitivity_solution(sol, sol[_save_idxs]) | ||
end | ||
|
||
_, repack_adjoint = if isscimlstructure(p) | ||
Zygote.pullback(p) do p | ||
t, _, _ = canonicalize(Tunable(), p) | ||
t | ||
end | ||
else | ||
nothing, x -> (x,) | ||
end | ||
|
||
function steadystatebackpass(Δ) | ||
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ | ||
# Δ = dg/dx or diffcache.dg_val | ||
# del g/del p = 0 | ||
function df(_out, u, p, t, i) | ||
if _save_idxs isa Number | ||
_out[_save_idxs] = Δ[_save_idxs] | ||
elseif Δ isa Number | ||
@. _out[_save_idxs] = Δ | ||
else | ||
elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa AbstractArray | ||
@. _out[_save_idxs] = Δ[_save_idxs] | ||
elseif isnothing(_out) | ||
_out | ||
else | ||
@. _out[_save_idxs] = Δ.u[_save_idxs] | ||
end | ||
end | ||
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When the new reverse ode is built it needs to drop the initial eqs but still keep the dae constraints. It can brownbasic? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to drop the initial eqs after it's solved? The assumption was since we run with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But the reverse pass needs to run with some form of initialization or the starting algebraic conditions may not be satisfied. Don't run this one with NoInit(), that would be prone to hiding issues. For this one, at most CheckInit(), but I'm saying that BrownBasicInit() is likely the one justified here since the 0 initial condition is only true on the differential variables, while the algebraic variable initial conditions will be unknown, but the Newton solve will have zero derivative because all of the inputs are just Newton guesses, so BrownBasic will work out for the reverse. We should probably hardcode that since it's always the solution there. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, that will require us to add an OrdinaryDiffEqCore dep in this package. I will add that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the 0 derivative also applicable to parameters? Or only the unknowns? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
It's applicable to all Newton guess values. There is no parameter init going on to reverse so it's only for algebraic conditions so it's only Newton guesses.
ChrisRackauckas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
dp, Δtunables = if Δ isa AbstractArray || Δ isa Number | ||
# if Δ isa AbstractArray, the gradients correspond to `u` | ||
# this is something that needs changing in the future, but | ||
# this is what applies until the movement to structural | ||
# tangents is completed | ||
dp, _, _ = canonicalize(Tunable(), dp) | ||
dp, nothing | ||
else | ||
Δp = setproperties(dp, to_nt(Δ.prob.p)) | ||
Δtunables, _, _ = canonicalize(Tunable(), Δp) | ||
dp, _, _ = canonicalize(Tunable(), dp) | ||
dp, Δtunables | ||
end | ||
|
||
dp = Zygote.accum(dp, Δtunables) | ||
|
||
if originator isa SciMLBase.TrackerOriginator || | ||
originator isa SciMLBase.ReverseDiffOriginator | ||
(NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(), | ||
(NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(), | ||
ntuple(_ -> NoTangent(), length(args))...) | ||
else | ||
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(), | ||
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(), | ||
ntuple(_ -> NoTangent(), length(args))...) | ||
end | ||
end | ||
|
Uh oh!
There was an error while loading. Please reload this page.