Feat: Handle Adjoints through Initialization #1168

Merged: 91 commits into master from dg/initprob on Apr 24, 2025
Changes shown are from 6 of the 91 commits.

Commits
d6290bf
feat(initialization): get gradients against initialization problem
DhairyaLGandhi Mar 6, 2025
d3199c0
test: add initialization problem to test suite
DhairyaLGandhi Mar 6, 2025
a4fa7c5
chore: pass tunables to jacobian
DhairyaLGandhi Mar 6, 2025
5a7dd26
test: cleanup imports
DhairyaLGandhi Mar 6, 2025
94ec324
chore: rm debug statements
DhairyaLGandhi Mar 6, 2025
0c5564e
test: add Core8 to CI
DhairyaLGandhi Mar 6, 2025
aca9cd4
chore: use get_initial_values
DhairyaLGandhi Mar 11, 2025
9bec784
chore: check for OVERDETERMINED initialization for solving initializa…
DhairyaLGandhi Mar 11, 2025
72cbb35
chore: pass sensealg to initial_values
DhairyaLGandhi Mar 11, 2025
2d85e19
chore: treat Delta as array
DhairyaLGandhi Mar 13, 2025
9a8a845
chore: use autojacvec from sensealg
DhairyaLGandhi Mar 13, 2025
95ebbf3
chore: move igs before solve, re-use initialization
DhairyaLGandhi Mar 13, 2025
957d7fe
chore: update igs to re-use initial values
DhairyaLGandhi Mar 13, 2025
a00574f
chore: qualify NoInit
DhairyaLGandhi Mar 13, 2025
4562f0c
chore: remove igs from steady state adjoint for initialization
DhairyaLGandhi Mar 17, 2025
6c21324
chore: accumulate gradients in steady state adjoint explicitly
DhairyaLGandhi Mar 19, 2025
a675a7f
fix: handle MTKparameters and Arrays uniformly
DhairyaLGandhi Mar 20, 2025
7941a3c
feat: allow reverse mode for initialization solving
DhairyaLGandhi Mar 20, 2025
9557e8c
test: add more tests for parameter initialization
DhairyaLGandhi Mar 20, 2025
8feae0e
test: fix label
DhairyaLGandhi Mar 20, 2025
d3b1669
chore: rename file
DhairyaLGandhi Mar 20, 2025
e01eb77
test: fix sensealg and confusing error message
DhairyaLGandhi Mar 21, 2025
0ad6c62
chore: return new_u0
DhairyaLGandhi Mar 24, 2025
6df7987
chore: rebase branch
DhairyaLGandhi Mar 24, 2025
c4c7807
chore: mark symbol local
DhairyaLGandhi Mar 25, 2025
b85b16e
chore: pass tunables to Tape
DhairyaLGandhi Mar 25, 2025
8fc4136
chore: update new_u0, new_p to original vals if not initializing
DhairyaLGandhi Mar 26, 2025
1f1cce5
chore: rebase master
DhairyaLGandhi Mar 26, 2025
0d1abcc
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Mar 26, 2025
b164b18
chore: add MSL to test deps
DhairyaLGandhi Mar 26, 2025
915d949
feat: allow analytically solved initialization solutions to propagate…
DhairyaLGandhi Apr 3, 2025
d2fd79a
chore: force allocated buffers for vjp to be deterministic
DhairyaLGandhi Apr 3, 2025
a0cd94a
chore: pass tunables to allocate vjp of MTKParameters
DhairyaLGandhi Apr 3, 2025
885794d
test: Core6 dont access J.du
DhairyaLGandhi Apr 4, 2025
13a1ffb
chore: SteadyStateAdjoint could be thunk
DhairyaLGandhi Apr 4, 2025
7826866
chore: import AbstractThunk
DhairyaLGandhi Apr 4, 2025
6ceaa1a
chore: handle upstream pullback better in steady state adjoint
DhairyaLGandhi Apr 4, 2025
dfebd0b
chore: dont accum pullback for parameters
DhairyaLGandhi Apr 4, 2025
396f63e
test: import SII
DhairyaLGandhi Apr 5, 2025
de7e7da
test: wrap ps in ComponentArray
DhairyaLGandhi Apr 7, 2025
f5fb559
chore: call du for jacobian
DhairyaLGandhi Apr 10, 2025
019a051
chore: add recursive_copyto for identical NT trees
DhairyaLGandhi Apr 10, 2025
91ee019
deps: MSL compat
DhairyaLGandhi Apr 10, 2025
4b74718
chore: undo du access
DhairyaLGandhi Apr 10, 2025
94d5e2b
chore: handle J through fwd mode
DhairyaLGandhi Apr 10, 2025
764d3ff
chore: J = nothing instead of NT
DhairyaLGandhi Apr 11, 2025
84b2602
chore: check nothing in steady state
DhairyaLGandhi Apr 13, 2025
8a4aa79
chore: also canonicalize dp
DhairyaLGandhi Apr 13, 2025
d69ccb1
chore: adjoint of preallocation tools
DhairyaLGandhi Apr 13, 2025
2cc3673
chore: pass Δ to canonicalize
DhairyaLGandhi Apr 14, 2025
22f056a
chore: handle different parameter types
DhairyaLGandhi Apr 14, 2025
1f95b25
chore: check for number
DhairyaLGandhi Apr 14, 2025
0d78fa8
test: rm commented out code
DhairyaLGandhi Apr 14, 2025
6620e8a
chore: get tunables from dp
DhairyaLGandhi Apr 14, 2025
5f5633b
test: clean up initialization tests
DhairyaLGandhi Apr 14, 2025
984c2ce
chore: pass initializealg
DhairyaLGandhi Apr 15, 2025
82cd5fe
chore: force path through BrownBasicInit
DhairyaLGandhi Apr 16, 2025
987e8be
chore: replace NoInit with BrownBasicInit
DhairyaLGandhi Apr 17, 2025
056fffa
test: add test to check residual with initialization
DhairyaLGandhi Apr 17, 2025
a122340
chore: replace BrownBasicInit with CheckInit
DhairyaLGandhi Apr 17, 2025
e105838
chore: qualify CheckInit
DhairyaLGandhi Apr 17, 2025
aaddc02
test: add test to force DAE initialization and prevent simplification…
DhairyaLGandhi Apr 21, 2025
60be1c7
test: check DAE initialization takes in BrownFullBasicInit
DhairyaLGandhi Apr 21, 2025
9edbe02
chore: check for default path, handle nlsolve kwargs as ODECore inter…
DhairyaLGandhi Apr 21, 2025
308ae5c
chore: update imported symbols
DhairyaLGandhi Apr 21, 2025
a333588
Update src/concrete_solve.jl
DhairyaLGandhi Apr 21, 2025
a2cf0e6
Update src/concrete_solve.jl
ChrisRackauckas Apr 21, 2025
755c9df
Update mtk.jl
ChrisRackauckas Apr 21, 2025
75ad141
chore: typo
DhairyaLGandhi Apr 22, 2025
e07dd53
chore: qualify DefaultInit
DhairyaLGandhi Apr 22, 2025
ee804b2
chore: run parameter initialization by passing missing parameters
DhairyaLGandhi Apr 23, 2025
7abf42c
chore: rm dead code
DhairyaLGandhi Apr 23, 2025
a7d4e5a
test: allocate gt based on size of new_sol
DhairyaLGandhi Apr 23, 2025
8e660fe
Update Project.toml
ChrisRackauckas Apr 23, 2025
de63cf9
Update test/mtk.jl
ChrisRackauckas Apr 23, 2025
b88f468
Update mtk.jl
ChrisRackauckas Apr 23, 2025
7dd1cc7
Also test u0 gradients
ChrisRackauckas Apr 23, 2025
cdaa2c7
chore: merge upstream
DhairyaLGandhi Apr 24, 2025
d3608c4
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Apr 24, 2025
4cf7bd5
Update test/mtk.jl
DhairyaLGandhi Apr 24, 2025
2ae712b
chore: handle when p is a functor in steady state adjoint
DhairyaLGandhi Apr 24, 2025
229e691
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Apr 24, 2025
6e1109e
Merge branch 'master' into dg/initprob
DhairyaLGandhi Apr 24, 2025
d32b3f2
chore: git mixup
DhairyaLGandhi Apr 24, 2025
6e549e7
chore: git mixup
DhairyaLGandhi Apr 24, 2025
9789034
chore: revert bad commit
DhairyaLGandhi Apr 24, 2025
35937c0
chore: handle nothing dtunables for SteadyStateAdjoint
DhairyaLGandhi Apr 24, 2025
f934635
chore: rm u0 nothing forced to empty array
DhairyaLGandhi Apr 24, 2025
a83cd29
chore: reverse order of nothing check
DhairyaLGandhi Apr 24, 2025
2dfbc6e
chore: rm dead code
DhairyaLGandhi Apr 24, 2025
9aecbfd
chore: DEQ handling
DhairyaLGandhi Apr 24, 2025
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -24,6 +24,7 @@ jobs:
       - Core5
       - Core6
       - Core7
+      - Core8
       - QA
       - SDE1
       - SDE2
15 changes: 15 additions & 0 deletions src/concrete_solve.jl
@@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint(
save_end = true, kwargs_fwd...)
end

# Get gradients for the initialization problem if it exists
igs = if _prob.f.initialization_data.initializeprob != nothing
Member:

This should be before the solve, since you can use the initialization solution from here in the remakes at lines 397-405 to set the new u0 and p, and thus skip running the initialization a second time.

Member Author:

How can I indicate to `solve` that it should avoid running initialization?

Member:

initializealg = NoInit(). Should probably just do CheckInit() for safety, but either is fine.
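A minimal sketch of this suggestion, assuming `initializealg` is accepted as a `solve` keyword (as later commits in this PR rely on); `new_u0` and `new_p` are hypothetical names for values extracted from the initialization solution, not code from this PR:

```julia
# Hedged sketch: run initialization once, push its result into the problem,
# then tell the integrator not to redo it. `alg` is whatever ODE solver is in use.
isol = solve(_prob.f.initialization_data.initializeprob)
_prob2 = remake(_prob; u0 = new_u0, p = new_p)  # new_u0/new_p taken from isol (hypothetical)
# NoInit() would skip initialization entirely; CheckInit() only re-verifies consistency.
sol = solve(_prob2, alg; initializealg = SciMLBase.CheckInit())
```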

iprob = _prob.f.initialization_data.initializeprob
ip = parameter_values(iprob)
itunables, irepack, ialiases = canonicalize(Tunable(), ip)
igs, = Zygote.gradient(ip) do ip
Member:

This gradient isn't used? I think this would go into the backpass and, if I'm thinking clearly, the resulting return is dp .* igs?

Member Author:

Not yet. These gradients are currently against the parameters of the initialization problem, not the system exactly, and the mapping between the two is ill-defined, so we cannot simply accumulate.

I spoke with @AayushSabharwal about a way to map them; it seems initialization_data.initializeprobmap might have some support for returning the correctly shaped vector, but there are cases where we cannot know the ordering of dp either.

Member Author:

There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling the accumulation of gradients, or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would have to worry about incorrect gradients manually.

Member:

Oh yes, you need to use the initializeprobmap (https://github.com/SciML/SciMLBase.jl/blob/master/src/initialization.jl#L268) to map it back to the shape of the initial parameters.

> but there are cases where we cannot know the ordering of dp either.

p and dp just need the same ordering, so initializeprobmap should do the trick.

> I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients.

This is the only change to (u0, p) before solving, so this would account for it, given that initializeprobmap is just an index map and hence an identity function.

Member Author:

Addressed this occurrence in 95ebbf3 to check whether this is correct. Will need to work around the global call.

iprob2 = remake(iprob, p = ip)
sol = solve(iprob2)
sum(Array(sol))
end
igs
else
nothing
end

# Force `save_start` and `save_end` in the forward pass This forces the
# solver to do the backsolve all the way back to `u0` Since the start aliases
# `_prob.u0`, this doesn't actually use more memory But it cleans up the
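A hedged sketch of where the thread above lands: apply `initializeprobmap` to reorder the initialization-problem gradient into the outer problem's parameter ordering, then combine with `dp`. The combination rule `dp .* igs` is the one hypothesized by the reviewer, not settled code from this PR:

```julia
# Sketch only: `initializeprobmap` is treated as a pure index map (per the
# reviewer's comment), so applying it to the gradient vector reorders it to
# match `p`. Both the application to gradients and the product are assumptions.
initdata = _prob.f.initialization_data
igs_mapped = initdata.initializeprobmap(igs)
dp_total = dp .* igs_mapped
```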
5 changes: 3 additions & 2 deletions src/derivative_wrappers.jl
@@ -1068,9 +1068,10 @@ end

 function build_param_jac_config(alg, pf, u, p)
     if alg_autodiff(alg)
-        jac_config = ForwardDiff.JacobianConfig(pf, u, p,
+        tunables, repack, aliases = canonicalize(Tunable(), p)
+        jac_config = ForwardDiff.JacobianConfig(pf, u, tunables,
             ForwardDiff.Chunk{
-                determine_chunksize(p,
+                determine_chunksize(tunables,
                     alg)}())
     else
         if diff_type(alg) != Val{:complex}
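For readers unfamiliar with the API this hunk switches to: `canonicalize(Tunable(), p)` from SciMLStructures splits a structured parameter object into a flat vector of tunable values, a `repack` closure that rebuilds the original structure, and an aliasing flag. A minimal sketch, assuming SciMLStructures' plain-array fallback (structured types such as MTKParameters implement the same interface):

```julia
import SciMLStructures as SS

p = [1.0, 2.0, 3.0]  # plain arrays are trivially all-Tunable
tunables, repack, aliases = SS.canonicalize(SS.Tunable(), p)
# `tunables` is a flat vector suitable for ForwardDiff/Zygote;
# `repack` maps a vector of the same length back to an object shaped like `p`.
p2 = repack(tunables .* 2)
```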
10 changes: 6 additions & 4 deletions src/steadystate_adjoint.jl
@@ -50,13 +50,12 @@ end
     sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp,
         f, f.colorvec, needs_jac)
     (; diffcache, y, sol, λ, vjp, linsolve) = sense
-
     if needs_jac
         if SciMLBase.has_jac(f)
             f.jac(diffcache.J, y, p, nothing)
         else
             if DiffEqBase.isinplace(sol.prob)
-                jacobian!(diffcache.J, diffcache.uf, y, diffcache.f_cache,
+                jacobian!(diffcache.J.du, diffcache.uf, y, diffcache.f_cache,
                     sensealg, diffcache.jac_config)
             else
                 diffcache.J .= jacobian(diffcache.uf, y, sensealg)
@@ -103,15 +102,18 @@ end
     else
         if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
             # For the default case use `\` to avoid any form of unnecessary cache allocation
Member:

Yeah, I don't know about that comment; I think it's just old. (a) `\` always allocates because it uses lu instead of lu!, so it re-allocates the whole matrix, which is larger than any LinearSolve allocation, and (b) we have since 2023 set up tests on StaticArrays, so the immutable path is non-allocating. I don't think (b) was true when this was written.

Member Author:

So glad we can remove this branch altogether.

-            vec(λ) .= diffcache.J' \ vec(dgdu_val)
+            linear_problem = LinearProblem(diffcache.J.du', vec(dgdu_val'); u0 = vec(λ))
+            solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
         else
             linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ))
             solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
         end
     end

try
-        vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
+        tunables, repack, aliases = canonicalize(Tunable(), p)
+        vjp_tunables, vjp_repack, vjp_aliases = canonicalize(Tunable(), vjp)
+        vecjacobian!(vec(dgdu_val), y, λ, tunables, nothing, sense; dgrad = vjp_tunables, dy = nothing)
     catch e
         if sense.sensealg.autojacvec === nothing
             @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp"
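A small, self-contained sketch of the pattern the thread above argues for: route the default path through LinearSolve instead of `\`, so a factorization can be chosen and the matrix storage aliased rather than copied. Generic names, not this PR's cache objects; assumes `LinearAliasSpecifier` is available from LinearSolve as the diff uses it:

```julia
using LinearSolve, LinearAlgebra

J = rand(4, 4) + 4I   # stand-in for the adjoint Jacobian
v = rand(4)
# Solve J' * λ = v; alias_A lets the factorization reuse the matrix storage,
# whereas `J' \ v` LU-factorizes a fresh copy on every call.
lp = LinearProblem(J', v)
λ = solve(lp, LUFactorization(); alias = LinearAliasSpecifier(alias_A = true)).u
```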
64 changes: 64 additions & 0 deletions test/desauty_dae_mwe.jl
@@ -0,0 +1,64 @@
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks: Sine
using NonlinearSolve
using LinearSolve # DefaultLinearSolver below comes from here
using Test
import SciMLStructures as SS
import SymbolicIndexingInterface: parameter_values # used unqualified below
import SciMLSensitivity
using Zygote

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@variables t
@named resistor1 = Resistor(R = 5.0)
@named resistor2 = Resistor(R = 2.0)
@named capacitor1 = Capacitor(C = C₁)
@named capacitor2 = Capacitor(C = C₂)
@named source = Voltage()
@named input_signal = Sine(frequency = 100.0)
@named ground = Ground()
@named ampermeter = CurrentSensor()

eqs = [connect(input_signal.output, source.V)
connect(source.p, capacitor1.n, capacitor2.n)
connect(source.n, resistor1.p, resistor2.p, ground.g)
connect(resistor1.n, capacitor1.p, ampermeter.n)
connect(resistor2.n, capacitor2.p, ampermeter.p)]

@named circuit_model = ODESystem(eqs, t,
systems = [
resistor1, resistor2, capacitor1, capacitor2,
source, input_signal, ground, ampermeter,
])
end

desauty_model = create_model()
sys = structural_simplify(desauty_model)


prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.])
iprob = prob.f.initialization_data.initializeprob
isys = iprob.f.sys

tunables, repack, aliases = SS.canonicalize(SS.Tunable(), parameter_values(iprob))

linsolve = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP(), linsolve = linsolve)
igs, = Zygote.gradient(tunables) do p
iprob2 = remake(iprob, p = repack(p))
sol = solve(iprob2,
sensealg = sensealg
)
sum(Array(sol))
end

@test !iszero(sum(igs))


# tunable_parameters(isys) .=> gs

# gradient_unk1_idx = only(findfirst(x -> isequal(x, Initial(sys.capacitor1.v)), tunable_parameters(isys)))

# gs[gradient_unk1_idx]

# prob.f.initialization_data.update_initializeprob!(iprob, prob)
# prob.f.initialization_data.update_initializeprob!(iprob, ::Vector)
# prob.f.initialization_data.update_initializeprob!(iprob, gs)
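The commented-out lines above hint at inspecting individual gradient entries; a hedged sketch of that idea, assuming the ordering of `tunable_parameters(isys)` matches the `tunables` vector produced by `canonicalize` (the same assumption the commented code makes):

```julia
# Pair each tunable of the initialization system with its gradient entry.
grads = Dict(tunable_parameters(isys) .=> igs)
```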
6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -102,6 +102,12 @@ end
     end
 end

+if GROUP == "All" || GROUP == "Core8"
+    @testset "Core 8" begin
+        @time @safetestset "Initialization with MTK" include("desauty_dae_mwe.jl")
+    end
+end
+
 if GROUP == "All" || GROUP == "QA"
     @time @safetestset "Quality Assurance" include("aqua.jl")
 end
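To run just the new group locally, a hedged example assuming the standard SciML GROUP convention used in runtests.jl above:

```julia
# Select only the Core8 group before invoking the package tests;
# the GROUP environment variable is read by runtests.jl.
ENV["GROUP"] = "Core8"
using Pkg
Pkg.test("SciMLSensitivity")
```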