Feat: Handle Adjoints through Initialization #1168


Open — wants to merge 77 commits into base: master (diff shown from 9 commits)

77 commits:
d6290bf
feat(initialization): get gradients against initialization problem
DhairyaLGandhi Mar 6, 2025
d3199c0
test: add initialization problem to test suite
DhairyaLGandhi Mar 6, 2025
a4fa7c5
chore: pass tunables to jacobian
DhairyaLGandhi Mar 6, 2025
5a7dd26
test: cleanup imports
DhairyaLGandhi Mar 6, 2025
94ec324
chore: rm debug statements
DhairyaLGandhi Mar 6, 2025
0c5564e
test: add Core8 to CI
DhairyaLGandhi Mar 6, 2025
aca9cd4
chore: use get_initial_values
DhairyaLGandhi Mar 11, 2025
9bec784
chore: check for OVERDETERMINED initialization for solving initializa…
DhairyaLGandhi Mar 11, 2025
72cbb35
chore: pass sensealg to initial_values
DhairyaLGandhi Mar 11, 2025
2d85e19
chore: treat Delta as array
DhairyaLGandhi Mar 13, 2025
9a8a845
chore: use autojacvec from sensealg
DhairyaLGandhi Mar 13, 2025
95ebbf3
chore: move igs before solve, re-use initialization
DhairyaLGandhi Mar 13, 2025
957d7fe
chore: update igs to re-use inital values
DhairyaLGandhi Mar 13, 2025
a00574f
chore: qualify NoInit
DhairyaLGandhi Mar 13, 2025
4562f0c
chore: remove igs from steady state adjoint gor initialization
DhairyaLGandhi Mar 17, 2025
6c21324
chore: accumulate gradients in steady state adjoint explicitly
DhairyaLGandhi Mar 19, 2025
a675a7f
fix: handle MTKparameters and Arrays uniformly
DhairyaLGandhi Mar 20, 2025
7941a3c
feat: allow reverse mode for initialization solving
DhairyaLGandhi Mar 20, 2025
9557e8c
test: add more tests for parameter initialization
DhairyaLGandhi Mar 20, 2025
8feae0e
test: fix label
DhairyaLGandhi Mar 20, 2025
d3b1669
chore: rename file
DhairyaLGandhi Mar 20, 2025
e01eb77
test: fix sensealg and confusing error message
DhairyaLGandhi Mar 21, 2025
0ad6c62
chore: return new_u0
DhairyaLGandhi Mar 24, 2025
6df7987
chore: rebase branch
DhairyaLGandhi Mar 24, 2025
c4c7807
chore: mark symbol local
DhairyaLGandhi Mar 25, 2025
b85b16e
chore: pass tunables to Tape
DhairyaLGandhi Mar 25, 2025
8fc4136
chore: update new_u0, new_p to orignal vals if not initializing
DhairyaLGandhi Mar 26, 2025
1f1cce5
chore: rebase master
DhairyaLGandhi Mar 26, 2025
0d1abcc
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Mar 26, 2025
b164b18
chore: add MSL to test deps
DhairyaLGandhi Mar 26, 2025
915d949
feat: allow analytically solved initialization solutions to propagate…
DhairyaLGandhi Apr 3, 2025
d2fd79a
chore: force allocated buffers for vjp to be deterministic
DhairyaLGandhi Apr 3, 2025
a0cd94a
chore: pass tunables to allocate vjp of MTKParameters
DhairyaLGandhi Apr 3, 2025
885794d
test: Core6 dont access J.du
DhairyaLGandhi Apr 4, 2025
13a1ffb
chore: SteadyStateAdjoint could be thunk
DhairyaLGandhi Apr 4, 2025
7826866
chore: import AbstractThunk
DhairyaLGandhi Apr 4, 2025
6ceaa1a
chore: handle upstream pullback better in steady state adjoint
DhairyaLGandhi Apr 4, 2025
dfebd0b
chore: dont accum pullback for parameters
DhairyaLGandhi Apr 4, 2025
396f63e
test: import SII
DhairyaLGandhi Apr 5, 2025
de7e7da
test: wrap ps in ComponentArray
DhairyaLGandhi Apr 7, 2025
f5fb559
chore: call du for jacobian
DhairyaLGandhi Apr 10, 2025
019a051
chore: add recursive_copyto for identical NT trees
DhairyaLGandhi Apr 10, 2025
91ee019
deps: MSL compat
DhairyaLGandhi Apr 10, 2025
4b74718
chore: undo du access
DhairyaLGandhi Apr 10, 2025
94d5e2b
chore: handle J through fwd mode
DhairyaLGandhi Apr 10, 2025
764d3ff
chore: J = nothing instead of NT
DhairyaLGandhi Apr 11, 2025
84b2602
chore: check nothing in steady state
DhairyaLGandhi Apr 13, 2025
8a4aa79
chore: also canonicalize dp
DhairyaLGandhi Apr 13, 2025
d69ccb1
chore: adjoint of preallocation tools
DhairyaLGandhi Apr 13, 2025
2cc3673
chore: pass Δ to canonicalize
DhairyaLGandhi Apr 14, 2025
22f056a
chore: handle different parameter types
DhairyaLGandhi Apr 14, 2025
1f95b25
chore: check for number
DhairyaLGandhi Apr 14, 2025
0d78fa8
test: rm commented out code
DhairyaLGandhi Apr 14, 2025
6620e8a
chore: get tunables from dp
DhairyaLGandhi Apr 14, 2025
5f5633b
test: clean up initialization tests
DhairyaLGandhi Apr 14, 2025
984c2ce
chore: pass initializealg
DhairyaLGandhi Apr 15, 2025
82cd5fe
chore: force path through BrownBasicInit
DhairyaLGandhi Apr 16, 2025
987e8be
chore: replace NoInit with BrownBasicInit
DhairyaLGandhi Apr 17, 2025
056fffa
test: add test to check residual with initialization
DhairyaLGandhi Apr 17, 2025
a122340
chore: replace BrownBasicInit with CheckInit
DhairyaLGandhi Apr 17, 2025
e105838
chore: qualify checinit
DhairyaLGandhi Apr 17, 2025
aaddc02
test: add test to force DAE initialization and prevent simplification…
DhairyaLGandhi Apr 21, 2025
60be1c7
test: check DAE initialization takes in BrownFullBasicInit
DhairyaLGandhi Apr 21, 2025
9edbe02
chore: check for default path, handle nlsolve kwargs as ODECore inter…
DhairyaLGandhi Apr 21, 2025
308ae5c
chore: update imported symbols
DhairyaLGandhi Apr 21, 2025
a333588
Update src/concrete_solve.jl
DhairyaLGandhi Apr 21, 2025
a2cf0e6
Update src/concrete_solve.jl
ChrisRackauckas Apr 21, 2025
755c9df
Update mtk.jl
ChrisRackauckas Apr 21, 2025
75ad141
chore: typo
DhairyaLGandhi Apr 22, 2025
e07dd53
chore: qualify DefaultInit
DhairyaLGandhi Apr 22, 2025
ee804b2
chore: run parameter initialization by passing missing parameters
DhairyaLGandhi Apr 23, 2025
7abf42c
chore: rm dead code
DhairyaLGandhi Apr 23, 2025
a7d4e5a
test: allocate gt based on size of new_sol
DhairyaLGandhi Apr 23, 2025
8e660fe
Update Project.toml
ChrisRackauckas Apr 23, 2025
de63cf9
Update test/mtk.jl
ChrisRackauckas Apr 23, 2025
b88f468
Update mtk.jl
ChrisRackauckas Apr 23, 2025
7dd1cc7
Also test u0 gradients
ChrisRackauckas Apr 23, 2025
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -24,6 +24,7 @@ jobs:
- Core5
- Core6
- Core7
- Core8
- QA
- SDE1
- SDE2
4 changes: 2 additions & 2 deletions src/adjoint_common.jl
@@ -379,11 +379,11 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
if !isRODE
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
du1 = (p !== nothing && p !== SciMLBase.NullParameters()) ?
similar(p, size(u)) : similar(u)
du1 .= false
f(du1, u, p, first(t))
f(du1, u, repack(p), first(t))
return vec(du1)
end
else
47 changes: 46 additions & 1 deletion src/concrete_solve.jl
@@ -299,6 +299,7 @@ function DiffEqBase._concrete_solve_adjoint(
tunables, repack = Functors.functor(p)
end

u0 = state_values(prob) === nothing ? Float64[] : u0
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p,
originator::SciMLBase.ADOriginator, args...; verbose,
@@ -425,6 +426,25 @@
save_end = true, kwargs_fwd...)
end

# Get gradients for the initialization problem if it exists
igs = if _prob.f.initialization_data != nothing
Zygote.gradient(tunables) do tunables
new_prob = remake(_prob, p = repack(tunables))
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true);
abstol = 1e-6,
reltol = 1e-6,
Member: These don't make sense.

Member (author): Yes, these should probably inherit from the kwargs or be set to some default. Note that we must specify a tolerance for this dispatch.

Member (author): Addressed in 984c2ce
sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()))
Member: Shouldn't default to ZygoteVJP. Should use the autojacvec of the ODE.

Member (author): Addressed in 9a8a845

new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p)
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED
sum(new_tunables)
else
sum(new_u0) + sum(new_tunables)
end
end[1] .- one(eltype(tunables))
else
nothing
end
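The `igs` computation above differentiates a scalar `sum` of the initialization outputs, which seeds the pullback with a vector of ones (the trailing `.- one(eltype(tunables))` then subtracts that seed back out). A toy Zygote sketch of the sum-seeding idea — illustrative only, with a made-up function `f` standing in for the initialization map:

```julia
using Zygote

# `f` stands in for "parameters -> initialized values". Differentiating
# sum(f(p)) gives the gradient with a ones-vector seed on every output.
f(p) = 2 .* p .+ 1.0
p = [1.0, 2.0, 3.0]
g = Zygote.gradient(p -> sum(f(p)), p)[1]   # == [2.0, 2.0, 2.0]
```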

# Force `save_start` and `save_end` in the forward pass This forces the
# solver to do the backsolve all the way back to `u0` Since the start aliases
# `_prob.u0`, this doesn't actually use more memory But it cleans up the
@@ -642,6 +662,8 @@ function DiffEqBase._concrete_solve_adjoint(
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp

dp = Zygote.accum(dp, igs)

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
@@ -1686,6 +1708,25 @@ function DiffEqBase._concrete_solve_adjoint(
out = SciMLBase.sensitivity_solution(sol, sol[_save_idxs])
end

# Get gradients for the initialization problem if it exists
igs = if _prob.f.initialization_data != nothing
Zygote.gradient(tunables) do tunables
new_prob = remake(_prob, p = repack(tunables))
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true);
abstol = 1e-6,
reltol = 1e-6,
sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()))
new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p)
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED
sum(new_tunables)
else
sum(new_u0) + sum(new_tunables)
end
end[1] .- one(eltype(tunables))
else
nothing
end

function steadystatebackpass(Δ)
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
@@ -1694,11 +1735,15 @@
_out[_save_idxs] = Δ[_save_idxs]
elseif Δ isa Number
@. _out[_save_idxs] = Δ
else
elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray
@. _out[_save_idxs] = Δ[_save_idxs]
else
@. _out[_save_idxs] = Δ.u[_save_idxs]
end
end
# dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df, dgdp = dp)
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df)
Member: When the new reverse ODE is built it needs to drop the initial eqs but still keep the DAE constraints. Can it use BrownBasic?

Member (author): Is there a way to drop the initial eqs after it's solved? The assumption was that since we run with NoInit, no initialization is run after the first call to get_initial_values, and we accumulate those gradients independently of the adaptive solve.

Member: But the reverse pass needs to run with some form of initialization, or the starting algebraic conditions may not be satisfied. Don't run this one with NoInit(); that would be prone to hiding issues. For this one, at most CheckInit(), but BrownBasicInit() is likely the one justified here: the zero initial condition is only true for the differential variables, while the algebraic-variable initial conditions are unknown. The Newton solve will have zero derivative, because all of its inputs are just Newton guesses, so BrownBasic will work out for the reverse. We should probably hardcode that, since it's always the solution there.

Member (author): Ok, that will require us to add an OrdinaryDiffEqCore dep in this package. I will add that.

Member (author): Is the 0 derivative also applicable to parameters? Or only the unknowns?

Member: It's applicable to all Newton guess values. There is no parameter init going on in reverse, so it's only for the algebraic conditions, i.e. only Newton guesses.
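The conclusion of this thread — run the reverse-pass solve with Brown's basic initialization rather than NoInit — could be sketched as below. This is an assumption-laden sketch, not the PR's code: `reverse_prob` and `alg` are placeholder names, and it presumes the OrdinaryDiffEqCore dependency discussed above.

```julia
using OrdinaryDiffEqCore: BrownFullBasicInit

# Sketch: solve the reverse problem with Brown's basic initialization so the
# algebraic constraints are re-satisfied at the start of the backsolve, while
# the differential variables (which start the adjoint at zero) are untouched.
# `reverse_prob` and `alg` are hypothetical stand-ins for the adjoint problem
# and solver chosen by adjoint_sensitivities.
sol_reverse = solve(reverse_prob, alg;
    initializealg = BrownFullBasicInit(),
    sensealg = sensealg)
```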

dp = Zygote.accum(dp, SciMLStructures.replace(SciMLStructures.Tunable(), dp, igs))

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
10 changes: 6 additions & 4 deletions src/derivative_wrappers.jl
@@ -456,9 +456,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
elseif inplace_sensitivity(S)
_y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y)
if W === nothing
tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t
_tunables, _repack, _ = canonicalize(Tunable(), _p)
tape = ReverseDiff.GradientTape((_y, _tunables, [t])) do u, p, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
f(du1, u, _repack(p), first(t))
return vec(du1)
end
else
@@ -1068,9 +1069,10 @@ end

function build_param_jac_config(alg, pf, u, p)
if alg_autodiff(alg)
jac_config = ForwardDiff.JacobianConfig(pf, u, p,
tunables, repack, aliases = canonicalize(Tunable(), p)
jac_config = ForwardDiff.JacobianConfig(pf, u, tunables,
ForwardDiff.Chunk{
determine_chunksize(p,
determine_chunksize(tunables,
alg)}())
else
if diff_type(alg) != Val{:complex}
16 changes: 6 additions & 10 deletions src/steadystate_adjoint.jl
@@ -50,13 +50,12 @@ end
sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp,
f, f.colorvec, needs_jac)
(; diffcache, y, sol, λ, vjp, linsolve) = sense

if needs_jac
if SciMLBase.has_jac(f)
f.jac(diffcache.J, y, p, nothing)
else
if DiffEqBase.isinplace(sol.prob)
jacobian!(diffcache.J, diffcache.uf, y, diffcache.f_cache,
jacobian!(diffcache.J.du, diffcache.uf, y, diffcache.f_cache,
sensealg, diffcache.jac_config)
else
diffcache.J .= jacobian(diffcache.uf, y, sensealg)
@@ -101,17 +100,14 @@ end
linear_problem = LinearProblem(soperator, vec(dgdu_val); u0 = vec(λ))
solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...)
else
if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
# For the default case use `\` to avoid any form of unnecessary cache allocation
vec(λ) .= diffcache.J' \ vec(dgdu_val)
else
linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ))
solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
end
linear_problem = LinearProblem(diffcache.J.du', vec(dgdu_val'); u0 = vec(λ))
solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
end

try
vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
tunables, repack, aliases = canonicalize(Tunable(), p)
vjp_tunables, vjp_repack, vjp_aliases = canonicalize(Tunable(), vjp)
vecjacobian!(vec(dgdu_val), y, λ, tunables, nothing, sense; dgrad = vjp_tunables, dy = nothing)
catch e
if sense.sensealg.autojacvec === nothing
@warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp"
64 changes: 64 additions & 0 deletions test/desauty_dae_mwe.jl
@@ -0,0 +1,64 @@
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks: Sine
using NonlinearSolve
using LinearSolve # used below for DefaultLinearSolver
using Test
import SciMLStructures as SS
import SciMLSensitivity
import SymbolicIndexingInterface: parameter_values
using Zygote

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@variables t
@named resistor1 = Resistor(R = 5.0)
@named resistor2 = Resistor(R = 2.0)
@named capacitor1 = Capacitor(C = C₁)
@named capacitor2 = Capacitor(C = C₂)
@named source = Voltage()
@named input_signal = Sine(frequency = 100.0)
@named ground = Ground()
@named ampermeter = CurrentSensor()

eqs = [connect(input_signal.output, source.V)
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]

@named circuit_model = ODESystem(eqs, t,
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter,
])
end

desauty_model = create_model()
sys = structural_simplify(desauty_model)


prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.])
iprob = prob.f.initialization_data.initializeprob
isys = iprob.f.sys

tunables, repack, aliases = SS.canonicalize(SS.Tunable(), parameter_values(iprob))

linsolve = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP(), linsolve = linsolve)
igs, = Zygote.gradient(tunables) do p
iprob2 = remake(iprob, p = repack(p))
sol = solve(iprob2,
sensealg = sensealg
)
sum(Array(sol))
end

@test !iszero(sum(igs))


# tunable_parameters(isys) .=> gs

# gradient_unk1_idx = only(findfirst(x -> isequal(x, Initial(sys.capacitor1.v)), tunable_parameters(isys)))

# gs[gradient_unk1_idx]

# prob.f.initialization_data.update_initializeprob!(iprob, prob)
# prob.f.initialization_data.update_initializeprob!(iprob, ::Vector)
# prob.f.initialization_data.update_initializeprob!(iprob, gs)
6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -102,6 +102,12 @@ end
end
end

if GROUP == "All" || GROUP == "Core8"
@testset "Core 8" begin
@time @safetestset "Initialization with MTK" include("desauty_dae_mwe.jl")
end
end

if GROUP == "All" || GROUP == "QA"
@time @safetestset "Quality Assurance" include("aqua.jl")
end