Feat: Handle Adjoints through Initialization #1168

Merged: 91 commits into master from dg/initprob, Apr 24, 2025
Changes from 55 commits

Commits (91)
d6290bf
feat(initialization): get gradients against initialization problem
DhairyaLGandhi Mar 6, 2025
d3199c0
test: add initialization problem to test suite
DhairyaLGandhi Mar 6, 2025
a4fa7c5
chore: pass tunables to jacobian
DhairyaLGandhi Mar 6, 2025
5a7dd26
test: cleanup imports
DhairyaLGandhi Mar 6, 2025
94ec324
chore: rm debug statements
DhairyaLGandhi Mar 6, 2025
0c5564e
test: add Core8 to CI
DhairyaLGandhi Mar 6, 2025
aca9cd4
chore: use get_initial_values
DhairyaLGandhi Mar 11, 2025
9bec784
chore: check for OVERDETERMINED initialization for solving initializa…
DhairyaLGandhi Mar 11, 2025
72cbb35
chore: pass sensealg to initial_values
DhairyaLGandhi Mar 11, 2025
2d85e19
chore: treat Delta as array
DhairyaLGandhi Mar 13, 2025
9a8a845
chore: use autojacvec from sensealg
DhairyaLGandhi Mar 13, 2025
95ebbf3
chore: move igs before solve, re-use initialization
DhairyaLGandhi Mar 13, 2025
957d7fe
chore: update igs to re-use initial values
DhairyaLGandhi Mar 13, 2025
a00574f
chore: qualify NoInit
DhairyaLGandhi Mar 13, 2025
4562f0c
chore: remove igs from steady state adjoint for initialization
DhairyaLGandhi Mar 17, 2025
6c21324
chore: accumulate gradients in steady state adjoint explicitly
DhairyaLGandhi Mar 19, 2025
a675a7f
fix: handle MTKparameters and Arrays uniformly
DhairyaLGandhi Mar 20, 2025
7941a3c
feat: allow reverse mode for initialization solving
DhairyaLGandhi Mar 20, 2025
9557e8c
test: add more tests for parameter initialization
DhairyaLGandhi Mar 20, 2025
8feae0e
test: fix label
DhairyaLGandhi Mar 20, 2025
d3b1669
chore: rename file
DhairyaLGandhi Mar 20, 2025
e01eb77
test: fix sensealg and confusing error message
DhairyaLGandhi Mar 21, 2025
0ad6c62
chore: return new_u0
DhairyaLGandhi Mar 24, 2025
6df7987
chore: rebase branch
DhairyaLGandhi Mar 24, 2025
c4c7807
chore: mark symbol local
DhairyaLGandhi Mar 25, 2025
b85b16e
chore: pass tunables to Tape
DhairyaLGandhi Mar 25, 2025
8fc4136
chore: update new_u0, new_p to original vals if not initializing
DhairyaLGandhi Mar 26, 2025
1f1cce5
chore: rebase master
DhairyaLGandhi Mar 26, 2025
0d1abcc
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Mar 26, 2025
b164b18
chore: add MSL to test deps
DhairyaLGandhi Mar 26, 2025
915d949
feat: allow analytically solved initialization solutions to propagate…
DhairyaLGandhi Apr 3, 2025
d2fd79a
chore: force allocated buffers for vjp to be deterministic
DhairyaLGandhi Apr 3, 2025
a0cd94a
chore: pass tunables to allocate vjp of MTKParameters
DhairyaLGandhi Apr 3, 2025
885794d
test: Core6 dont access J.du
DhairyaLGandhi Apr 4, 2025
13a1ffb
chore: SteadyStateAdjoint could be thunk
DhairyaLGandhi Apr 4, 2025
7826866
chore: import AbstractThunk
DhairyaLGandhi Apr 4, 2025
6ceaa1a
chore: handle upstream pullback better in steady state adjoint
DhairyaLGandhi Apr 4, 2025
dfebd0b
chore: dont accum pullback for parameters
DhairyaLGandhi Apr 4, 2025
396f63e
test: import SII
DhairyaLGandhi Apr 5, 2025
de7e7da
test: wrap ps in ComponentArray
DhairyaLGandhi Apr 7, 2025
f5fb559
chore: call du for jacobian
DhairyaLGandhi Apr 10, 2025
019a051
chore: add recursive_copyto for identical NT trees
DhairyaLGandhi Apr 10, 2025
91ee019
deps: MSL compat
DhairyaLGandhi Apr 10, 2025
4b74718
chore: undo du access
DhairyaLGandhi Apr 10, 2025
94d5e2b
chore: handle J through fwd mode
DhairyaLGandhi Apr 10, 2025
764d3ff
chore: J = nothing instead of NT
DhairyaLGandhi Apr 11, 2025
84b2602
chore: check nothing in steady state
DhairyaLGandhi Apr 13, 2025
8a4aa79
chore: also canonicalize dp
DhairyaLGandhi Apr 13, 2025
d69ccb1
chore: adjoint of preallocation tools
DhairyaLGandhi Apr 13, 2025
2cc3673
chore: pass Δ to canonicalize
DhairyaLGandhi Apr 14, 2025
22f056a
chore: handle different parameter types
DhairyaLGandhi Apr 14, 2025
1f95b25
chore: check for number
DhairyaLGandhi Apr 14, 2025
0d78fa8
test: rm commented out code
DhairyaLGandhi Apr 14, 2025
6620e8a
chore: get tunables from dp
DhairyaLGandhi Apr 14, 2025
5f5633b
test: clean up initialization tests
DhairyaLGandhi Apr 14, 2025
984c2ce
chore: pass initializealg
DhairyaLGandhi Apr 15, 2025
82cd5fe
chore: force path through BrownBasicInit
DhairyaLGandhi Apr 16, 2025
987e8be
chore: replace NoInit with BrownBasicInit
DhairyaLGandhi Apr 17, 2025
056fffa
test: add test to check residual with initialization
DhairyaLGandhi Apr 17, 2025
a122340
chore: replace BrownBasicInit with CheckInit
DhairyaLGandhi Apr 17, 2025
e105838
chore: qualify CheckInit
DhairyaLGandhi Apr 17, 2025
aaddc02
test: add test to force DAE initialization and prevent simplification…
DhairyaLGandhi Apr 21, 2025
60be1c7
test: check DAE initialization takes in BrownFullBasicInit
DhairyaLGandhi Apr 21, 2025
9edbe02
chore: check for default path, handle nlsolve kwargs as ODECore inter…
DhairyaLGandhi Apr 21, 2025
308ae5c
chore: update imported symbols
DhairyaLGandhi Apr 21, 2025
a333588
Update src/concrete_solve.jl
DhairyaLGandhi Apr 21, 2025
a2cf0e6
Update src/concrete_solve.jl
ChrisRackauckas Apr 21, 2025
755c9df
Update mtk.jl
ChrisRackauckas Apr 21, 2025
75ad141
chore: typo
DhairyaLGandhi Apr 22, 2025
e07dd53
chore: qualify DefaultInit
DhairyaLGandhi Apr 22, 2025
ee804b2
chore: run parameter initialization by passing missing parameters
DhairyaLGandhi Apr 23, 2025
7abf42c
chore: rm dead code
DhairyaLGandhi Apr 23, 2025
a7d4e5a
test: allocate gt based on size of new_sol
DhairyaLGandhi Apr 23, 2025
8e660fe
Update Project.toml
ChrisRackauckas Apr 23, 2025
de63cf9
Update test/mtk.jl
ChrisRackauckas Apr 23, 2025
b88f468
Update mtk.jl
ChrisRackauckas Apr 23, 2025
7dd1cc7
Also test u0 gradients
ChrisRackauckas Apr 23, 2025
cdaa2c7
chore: merge upstream
DhairyaLGandhi Apr 24, 2025
d3608c4
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Apr 24, 2025
4cf7bd5
Update test/mtk.jl
DhairyaLGandhi Apr 24, 2025
2ae712b
chore: handle when p is a functor in steady state adjoint
DhairyaLGandhi Apr 24, 2025
229e691
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Apr 24, 2025
6e1109e
Merge branch 'master' into dg/initprob
DhairyaLGandhi Apr 24, 2025
d32b3f2
chore: git mixup
DhairyaLGandhi Apr 24, 2025
6e549e7
chore: git mixup
DhairyaLGandhi Apr 24, 2025
9789034
chore: revert bad commit
DhairyaLGandhi Apr 24, 2025
35937c0
chore: handle nothing dtunables for SteadyStateAdjoint
DhairyaLGandhi Apr 24, 2025
f934635
chore: rm u0 nothing forced to empty array
DhairyaLGandhi Apr 24, 2025
a83cd29
chore: reverse order of nothing check
DhairyaLGandhi Apr 24, 2025
2dfbc6e
chore: rm dead code
DhairyaLGandhi Apr 24, 2025
9aecbfd
chore: DEQ handling
DhairyaLGandhi Apr 24, 2025
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -24,6 +24,7 @@ jobs:
- Core5
- Core6
- Core7
- Core8
- QA
- SDE1
- SDE2
4 changes: 3 additions & 1 deletion Project.toml
@@ -77,6 +77,7 @@ LinearSolve = "2, 3"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
ModelingToolkitStandardLibrary = "2"
Mooncake = "0.4.52"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1, 4"
@@ -117,6 +118,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -131,4 +133,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
5 changes: 4 additions & 1 deletion src/SciMLSensitivity.jl
@@ -40,13 +40,14 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm,
solve, u_modified!, LinearAliasSpecifier

# AD Backends
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
using SciMLBase.ConstructionBase

# Std Libs
using LinearAlgebra: LinearAlgebra, Diagonal, I, UniformScaling, adjoint, axpy!,
@@ -56,6 +57,8 @@ using Markdown: Markdown, @doc_str
using Random: Random, rand!
using Statistics: Statistics, mean

using LinearAlgebra: diag

abstract type SensitivityFunction end
abstract type TransformedFunction end

40 changes: 26 additions & 14 deletions src/adjoint_common.jl
@@ -78,7 +78,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
unwrappedf = unwrapped_f(f)

numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables)
numindvar = length(u0)
numindvar = isnothing(u0) ? nothing : length(u0)
isautojacvec = get_jacvec(sensealg)

issemiexplicitdae = false
@@ -106,18 +106,22 @@
isempty(algevar_idxs) || (issemiexplicitdae = true)
end
if !issemiexplicitdae
diffvar_idxs = eachindex(u0)
diffvar_idxs = isnothing(u0) ? nothing : eachindex(u0)
algevar_idxs = 1:0
end

if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool)
J = nothing
else
if alg === nothing || SciMLBase.forwarddiffs_model_time(alg)
# 1 chunk is fine because it's only t
_J = similar(u0, numindvar, numindvar)
_J .= 0
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
if !isnothing(u0)
# 1 chunk is fine because it's only t
_J = similar(u0, numindvar, numindvar)
_J .= 0
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
else
J = nothing
end
else
J = similar(u0, numindvar, numindvar)
J .= 0
@@ -133,8 +137,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
dg_val[1] .= false
dg_val[2] .= false
else
dg_val = similar(u0, numindvar) # number of funcs size
dg_val .= false
if !isnothing(u0)
dg_val = similar(u0, numindvar) # number of funcs size
dg_val .= false
else
dg_val = nothing
end
end
else
pgpu = UGradientWrapper(g, _t, p)
@@ -241,8 +249,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
pJ = if (quad || !(autojacvec isa Bool))
nothing
else
_pJ = similar(u0, numindvar, numparams)
_pJ .= false
if !isnothing(u0)
_pJ = similar(u0, numindvar, numparams)
_pJ .= false
else
_pJ = nothing
end
end

f_cache = isinplace ? deepcopy(u0) : nothing
@@ -379,11 +391,11 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
if !isRODE
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
du1 = (p !== nothing && p !== SciMLBase.NullParameters()) ?
similar(p, size(u)) : similar(u)
du1 .= false
f(du1, u, p, first(t))
f(du1, u, repack(p), first(t))
return vec(du1)
end
else
@@ -402,8 +414,8 @@
# because hasportion(Tunable(), NullParameters) == false
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
vec(f(u, p, first(t)))

Member: same down here?

tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
vec(f(u, repack(p), first(t)))
end
else
tape = ReverseDiff.GradientTape((y, _p, [_t], _W)) do u, p, t, W
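
The tape construction above now records against the flat tunable vector _p and calls repack(p) inside the closure, so ReverseDiff only ever traces a plain array while the model still receives a structured parameter object. A minimal standalone sketch of the pattern, assuming a hypothetical decay! RHS and a trivial repack in place of the SciMLStructures machinery:

using ReverseDiff

# Hypothetical in-place RHS, standing in for the unwrapped `f` in the diff.
decay!(du, u, p, t) = (du .= -p[1] .* u .+ p[2])

u0 = [1.0, 2.0]
_p = [0.5, 0.1]        # flat tunables, as handed to the tape
repack(p) = p          # trivial here; MTKParameters would rebuild the structure

# Mirrors f(du1, u, repack(p), first(t)) in get_paramjac_config: the tape
# sees only arrays, and repack reconstructs the parameter object when needed.
tape = ReverseDiff.GradientTape((copy(u0), _p, [0.0])) do u, p, t
    du1 = similar(p, size(u))
    du1 .= false
    decay!(du1, u, repack(p), first(t))
    return vec(du1)
end
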
78 changes: 70 additions & 8 deletions src/concrete_solve.jl
@@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack)

vjp = try
f = unwrapped_f(prob.f)
tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob.tspan[1]]
if p === nothing || p isa SciMLBase.NullParameters
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
ReverseDiff.GradientTape((copy(u0), tspan_)) do u, t
du1 = similar(u, size(u))
du1 .= 0
f(du1, u, p, first(t))
return vec(du1)
end
else
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
ReverseDiff.GradientTape((copy(u0), p, tspan_)) do u, p, t
du1 = similar(u, size(u))
du1 .= 0
f(du1, u, repack(p), first(t))
@@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint(
tunables, repack = Functors.functor(p)
end

u0 = state_values(prob) === nothing ? Float64[] : u0
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p,
originator::SciMLBase.ADOriginator, args...; verbose,
@@ -412,16 +414,42 @@ function DiffEqBase._concrete_solve_adjoint(
Base.diff_names(Base._nt_names(values(kwargs)),
(:callback_adj, :callback))}(values(kwargs))
isq = sensealg isa QuadratureAdjoint

igs, new_u0, new_p = if _prob.f.initialization_data !== nothing

Member: Also needs to check that initializealg is not set, is the default, or is using OverrideInit. Should test this is not triggered with, say, a manual BrownBasicInit.

Member Author: Got it. My understanding was that OverrideInit was what we strictly needed. We can check for BrownBasicInit/defaults here.

Member Author: There doesn't seem to be a method which can take a BrownFullBasicInit(). I get a MethodError:

ERROR: MethodError: no method matching get_initial_values(::ODEProblem{…}, ::ODEProblem{…}, ::ODEFunction{…}, ::BrownFullBasicInit{…}, ::Val{…}; sensealg::SteadyStateAdjoint{…}, nlsolve_alg::Nothing)

Closest candidates are:
  get_initial_values(::Any, ::Any, ::Any, ::NoInit, ::Any; kwargs...)
   @ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:282
  get_initial_values(::Any, ::Any, ::Any, ::SciMLBase.OverrideInit, ::Union{Val{true}, Val{false}}; nlsolve_alg, abstol, reltol, kwargs...)
   @ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:224
  get_initial_values(::SciMLBase.AbstractDEProblem, ::SciMLBase.DEIntegrator, ::Any, ::CheckInit, ::Union{Val{true}, Val{false}}; abstol, kwargs...)
   @ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:161
  ...

Only CheckInit, NoInit, and OverrideInit have dispatches.

Member Author: Having chatted with @AayushSabharwal on this, it seems like BrownBasic and ShampineCollocation do not yet have a path through get_initial_values, and that would need to be fixed in OrdinaryDiffEq. Further, since SciMLSensitivity does not depend on OrdinaryDiffEq, it cannot check whether there is a default initialization, since those are defined there. Depending on it also seems like a pretty big hammer for a dep.

Member Author: What would be the best course of action here? Seems like supporting BrownBasicInit is a dispatch that will automatically be utilised when it is moved into SciMLBase.

Member: Yes, I mean BrownBasicInit should not be taking this path. But that's a problem because then they will be disabled in the next stage below, and that needs to be accounted for. This dispatch is already built and set up for BrownBasicInit, and there are tests on that.

local new_u0
local new_p
iy, back = Zygote.pullback(tunables) do tunables
new_prob = remake(_prob, p = repack(tunables))
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, SciMLBase.OverrideInit(), Val(true);
abstol = 1e-6,
reltol = 1e-6,

Member: These don't make sense.

Member Author: Yes, these should probably inherit from kwargs or be set up to some default. Note that we must specify a tol for this dispatch.

Member Author: Addressed in 984c2ce

sensealg = SteadyStateAdjoint(autojacvec = sensealg.autojacvec),
kwargs...)
new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p)
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED
sum(new_tunables)
else
sum(new_u0) + sum(new_tunables)
end
end
igs = back(one(iy))[1] .- one(eltype(tunables))

igs, new_u0, new_p
else
nothing, u0, p
end
_prob = remake(_prob, u0 = new_u0, p = new_p)

if sensealg isa BacksolveAdjoint
sol = solve(_prob, alg, args...; save_noise = true,
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true,

Member: It should only NoInit if the previous case was run. Won't this right now break the BrownBasic tests?

Member Author (DhairyaLGandhi, Apr 16, 2025): If there was no initialization data, it won't have run the initialization problem at all. If I can generically ignore handling initializealg and pass it directly to get_initial_values, that would be good. Then I can also pass NoInit here generically.

Member: No, that is not correct. If there was no initialization data then it will use the built-in initialization, defaulting to BrownBasicInit. It's impossible for a DAE solver to generally work without running initialization of some form; the MTK one is just a new specialized one, but there has always been a numerical one in the solver. And if it hits that case, this code will now disable that.

https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L952-L978 will hit that. I think it's not failing because it's not so pronounced here. You might want to change that test to https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L975C5-L975C69 with prob_singular_mm = ODEProblem(f, [1.0, 0.0, 1.0], (0.0, 100), p), and it would pass before and fail now.

Member Author: You're right of course for the DAEs, but since BrownBasicInit is defined in OrdinaryDiffEq, and this package does not depend on it, I need a way for us to be able to dispatch to it. So if I understand the comment from earlier, we need a check for the default initialization, and to add a branch that solves for that prob and collects all the outputs.

Member Author: Both BrownBasicInit and OrdinaryDiffEqCore.DefaultInit require us to depend on a whole package for the default dispatch. Can it be exposed as a dispatch of get_initial_values instead?

Member: Oh 😅. That case was too simple, MTK turns it into an ODE. Let's make it a DAE.

@parameters σ ρ β A[1:3]
@variables x(t) y(t) z(t) w(t) w2(t)
eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z,
    w ~ x + y + z + 2 * β,
    0 ~ x^2 + y^2 - w2^2]
@mtkbuild sys = ODESystem(eqs, t)

That should make it so that it eliminates the w term, but doesn't eliminate the w2 term. The DAE check is on the w2 term, the observed handling check is on the w term.

Member: That will need to change the integrator to Rodas5P; Tsit5 will not be compatible with this form.

Member: For SDEs, we will just need to make it compatible with BrownBasicInit.

Member Author: Ah, I was so confused why it worked out. I see the InitialFailure now.

Member: Okay, good. Yeah, because MTK is too smart and makes lots of simple examples not DAEs 😅. But now you have the DAE, and if not running the built-in init then you get the error I was expecting. The fix is that it needs to run BrownBasic before solving, for the same reason the reverse pass needs to. Good, we worked out a test for this.

save_start = save_start, save_end = save_end,
saveat = saveat, kwargs_fwd...)
elseif ischeckpointing(sensealg)
sol = solve(_prob, alg, args...; save_noise = true,
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true,
save_start = true, save_end = true,
saveat = saveat, kwargs_fwd...)
else
sol = solve(_prob, alg, args...; save_noise = true, save_start = true,
sol = solve(_prob, alg, args...; initializealg = SciMLBase.NoInit(), save_noise = true, save_start = true,
save_end = true, kwargs_fwd...)
end

@@ -491,6 +519,7 @@ function DiffEqBase._concrete_solve_adjoint(
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

function adjoint_sensitivity_backpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
function df_iip(_out, u, p, t, i)
outtype = _out isa SubArray ?
ArrayInterface.parameterless_type(_out.parent) :
@@ -642,6 +671,8 @@ function DiffEqBase._concrete_solve_adjoint(
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp

dp = Zygote.accum(dp, igs)

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
@@ -1679,6 +1710,7 @@ function DiffEqBase._concrete_solve_adjoint(
u0, p, originator::SciMLBase.ADOriginator,
args...; save_idxs = nothing, kwargs...)
_prob = remake(prob, u0 = u0, p = p)

sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

@@ -1688,26 +1720,56 @@ function DiffEqBase._concrete_solve_adjoint(
out = SciMLBase.sensitivity_solution(sol, sol[_save_idxs])
end

_, repack_adjoint = if isscimlstructure(p)
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
else
nothing, x -> (x,)
end

function steadystatebackpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
function df(_out, u, p, t, i)
if _save_idxs isa Number
_out[_save_idxs] = Δ[_save_idxs]
elseif Δ isa Number
@. _out[_save_idxs] = Δ
else
elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa AbstractArray
@. _out[_save_idxs] = Δ[_save_idxs]
elseif isnothing(_out)
_out
else
@. _out[_save_idxs] = Δ.u[_save_idxs]
end
end
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df)

Member: When the new reverse ODE is built, it needs to drop the initial eqs but still keep the DAE constraints. Can it BrownBasic?

Member Author: Is there a way to drop the initial eqs after it's solved? The assumption was that since we run with NoInit, no initialization is run after the first call to get_initial_values, and we accumulate those gradients independently of the adaptive solve.

Member: But the reverse pass needs to run with some form of initialization, or the starting algebraic conditions may not be satisfied. Don't run this one with NoInit(); that would be prone to hiding issues. For this one, at most CheckInit(), but I'm saying that BrownBasicInit() is likely the one justified here, since the 0 initial condition is only true on the differential variables, while the algebraic variable initial conditions will be unknown. The Newton solve will have zero derivative because all of the inputs are just Newton guesses, so BrownBasic will work out for the reverse. We should probably hardcode that, since it's always the solution there.

Member Author: Ok, that will require us to add an OrdinaryDiffEqCore dep in this package. I will add that.

Member Author: Is the 0 derivative also applicable to parameters? Or only the unknowns?

Member: It's applicable to all Newton guess values. There is no parameter init going on in reverse, so it's only for algebraic conditions, so it's only Newton guesses.

dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
# if Δ isa AbstractArray, the gradients correspond to `u`
# this is something that needs changing in the future, but
# this is applicable till the movement to structural
# tangents is completed
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing
else
Δp = setproperties(dp, to_nt(Δ.prob.p))
Δtunables, _, _ = canonicalize(Tunable(), Δp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, Δtunables
end

dp = Zygote.accum(dp, Δtunables)

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
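
End to end, the path added above means a Zygote gradient through solve now also differentiates the initialization problem: the pullback over get_initial_values yields igs, which adjoint_sensitivity_backpass accumulates into dp. A rough sketch of the user-facing flow this enables, reusing the DAE system from the review thread; the initial conditions, guesses, parameter values, tolerances, and loss are illustrative only:

using ModelingToolkit, OrdinaryDiffEq, SciMLSensitivity, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLStructures: canonicalize, Tunable

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t) w2(t)
eqs = [D(D(x)) ~ σ * (y - x),
       D(y) ~ x * (ρ - z) - y,
       D(z) ~ x * y - β * z,
       w ~ x + y + z + 2 * β,
       0 ~ x^2 + y^2 - w2^2]
@mtkbuild sys = ODESystem(eqs, t)

# w2 is algebraic, so the initialization problem must solve for it from a guess.
prob = ODEProblem(sys, [x => 1.0, y => 0.0, z => 0.0, D(x) => 1.0],
    (0.0, 1.0), [σ => 10.0, ρ => 28.0, β => 8 / 3];
    guesses = [w2 => 1.0])

# Flat tunable vector plus a closure that rebuilds the parameter object.
tunables, repack, _ = canonicalize(Tunable(), prob.p)

function loss(tunables)
    newprob = remake(prob, p = repack(tunables))
    sol = solve(newprob, Rodas5P(), abstol = 1e-8, reltol = 1e-8)
    sum(abs2, sol[end])
end

# The adjoint now includes the initialization-gradient contribution igs.
grads = Zygote.gradient(loss, tunables)[1]
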
23 changes: 17 additions & 6 deletions src/derivative_wrappers.jl
@@ -144,6 +144,13 @@ function jacobian(f, x::AbstractArray{<:Number},
return J
end

function jacobian!(J::Nothing, f, x::AbstractArray{<:Number},
fx::Union{Nothing, AbstractArray{<:Number}},
alg::AbstractOverloadingSensitivityAlgorithm, jac_config::Nothing)
@assert isempty(x)
J
end
jacobian!(J::PreallocationTools.DiffCache, x::SciMLBase.UJacobianWrapper, args...) = jacobian!(J.du, x, args...)
function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
fx::Union{Nothing, AbstractArray{<:Number}},
alg::AbstractOverloadingSensitivityAlgorithm, jac_config)
@@ -456,9 +463,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
elseif inplace_sensitivity(S)
_y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y)
if W === nothing
tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t
_tunables, _repack, _ = canonicalize(Tunable(), _p)
tape = ReverseDiff.GradientTape((_y, _tunables, [t])) do u, p, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
f(du1, u, _repack(p), first(t))
return vec(du1)
end
else
@@ -474,8 +482,9 @@
else
_y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y)
if W === nothing
tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t
vec(f(u, p, first(t)))
_tunables, _repack, _ = canonicalize(Tunable(), _p)
tape = ReverseDiff.GradientTape((_y, _tunables, [t])) do u, p, t
vec(f(u, _repack(p), first(t)))
end
else
_W = eltype(W) === eltype(λ) ? W :
@@ -1047,6 +1056,7 @@ function accumulate_cost(dλ, y, p, t, S::TS,
return dλ, dgrad
end

build_jac_config(alg, uf, u::Nothing) = nothing
function build_jac_config(alg, uf, u)
if alg_autodiff(alg)
jac_config = ForwardDiff.JacobianConfig(uf, u, u,
@@ -1068,9 +1078,10 @@

function build_param_jac_config(alg, pf, u, p)
if alg_autodiff(alg)
jac_config = ForwardDiff.JacobianConfig(pf, u, p,
tunables, repack, aliases = canonicalize(Tunable(), p)
jac_config = ForwardDiff.JacobianConfig(pf, u, tunables,
ForwardDiff.Chunk{
determine_chunksize(p,
determine_chunksize(tunables,
alg)}())
else
if diff_type(alg) != Val{:complex}
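
build_param_jac_config now sizes the ForwardDiff chunk from the canonicalized tunables rather than the structured parameter object, whose length may be undefined. A small standalone sketch of the same idea, where PJW is a hypothetical stand-in for the package's parameter-Jacobian wrapper and decay! is an invented RHS:

using ForwardDiff

# Hypothetical wrapper: du as a function of p at fixed (u, t).
struct PJW{F, U, T}
    f::F
    u::U
    t::T
end
(pf::PJW)(p) = (du = similar(p, length(pf.u)); pf.f(du, pf.u, p, pf.t); du)

decay!(du, u, p, t) = (du .= -p[1] .* u .+ p[2])
pf = PJW(decay!, [1.0, 2.0], 0.0)
tunables = [0.5, 0.1]

# The chunk is chosen from the flat tunables, not the full parameter object.
chunk = ForwardDiff.Chunk{ForwardDiff.pickchunksize(length(tunables))}()
cfg = ForwardDiff.JacobianConfig(pf, tunables, chunk)
J = ForwardDiff.jacobian(pf, tunables, cfg)   # 2x2 parameter Jacobian
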
7 changes: 5 additions & 2 deletions src/parameters_handling.jl
@@ -15,6 +15,9 @@ end
recursive_copyto!(y::T, x::T) where {T} = fmap(recursive_copyto!, y, x)
recursive_copyto!(y, ::Nothing) = y
recursive_copyto!(::Nothing, ::Nothing) = nothing
function recursive_copyto!(y::T, x::NamedTuple) where T
fmap(recursive_copyto!, y, x)
end

"""
neg!(x)
@@ -61,14 +64,14 @@ recursive_add!(::Nothing, ::Nothing) = nothing

`similar(λ, size(x))` for generic `x`. This is used to handle non-array parameters!
"""
allocate_vjp(λ::AbstractArray, x::AbstractArray) = similar(λ, size(x))
allocate_vjp(λ::AbstractArray{T}, x::AbstractArray) where T = fill!(similar(λ, size(x)), zero(T))
allocate_vjp(λ::AbstractArray, x::Tuple) = allocate_vjp.((λ,), x)
function allocate_vjp(λ::AbstractArray, x::NamedTuple{F}) where {F}
NamedTuple{F}(allocate_vjp.((λ,), values(x)))
end
allocate_vjp(λ::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp, λ), x)

allocate_vjp(x::AbstractArray) = similar(x)
allocate_vjp(x::AbstractArray) = zero(x) # similar(x)
allocate_vjp(x::Tuple) = allocate_vjp.(x)
allocate_vjp(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp.(values(x)))
allocate_vjp(x) = fmap(allocate_vjp, x)
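
The allocate_vjp change swaps similar for an explicit zero fill: similar returns uninitialized memory, so accumulation buffers that were not fully overwritten made VJPs non-deterministic (the "force allocated buffers for vjp to be deterministic" commit). A short sketch reusing the two methods from the diff:

# Every leaf starts at zero, so recursive_add!-style accumulation is deterministic.
allocate_vjp(λ::AbstractArray{T}, x::AbstractArray) where {T} =
    fill!(similar(λ, size(x)), zero(T))
allocate_vjp(λ::AbstractArray, x::NamedTuple{F}) where {F} =
    NamedTuple{F}(allocate_vjp.((λ,), values(x)))

λ = rand(3)
x = (a = rand(2), b = rand(4))
buf = allocate_vjp(λ, x)   # (a = zeros(2), b = zeros(4)), not garbage memory
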