Feat: Handle Adjoints through Initialization #1168

Open: wants to merge 61 commits into base: master

Commits (61)
d6290bf
feat(initialization): get gradients against initialization problem
DhairyaLGandhi Mar 6, 2025
d3199c0
test: add initialization problem to test suite
DhairyaLGandhi Mar 6, 2025
a4fa7c5
chore: pass tunables to jacobian
DhairyaLGandhi Mar 6, 2025
5a7dd26
test: cleanup imports
DhairyaLGandhi Mar 6, 2025
94ec324
chore: rm debug statements
DhairyaLGandhi Mar 6, 2025
0c5564e
test: add Core8 to CI
DhairyaLGandhi Mar 6, 2025
aca9cd4
chore: use get_initial_values
DhairyaLGandhi Mar 11, 2025
9bec784
chore: check for OVERDETERMINED initialization for solving initializa…
DhairyaLGandhi Mar 11, 2025
72cbb35
chore: pass sensealg to initial_values
DhairyaLGandhi Mar 11, 2025
2d85e19
chore: treat Delta as array
DhairyaLGandhi Mar 13, 2025
9a8a845
chore: use autojacvec from sensealg
DhairyaLGandhi Mar 13, 2025
95ebbf3
chore: move igs before solve, re-use initialization
DhairyaLGandhi Mar 13, 2025
957d7fe
chore: update igs to re-use initial values
DhairyaLGandhi Mar 13, 2025
a00574f
chore: qualify NoInit
DhairyaLGandhi Mar 13, 2025
4562f0c
chore: remove igs from steady state adjoint for initialization
DhairyaLGandhi Mar 17, 2025
6c21324
chore: accumulate gradients in steady state adjoint explicitly
DhairyaLGandhi Mar 19, 2025
a675a7f
fix: handle MTKparameters and Arrays uniformly
DhairyaLGandhi Mar 20, 2025
7941a3c
feat: allow reverse mode for initialization solving
DhairyaLGandhi Mar 20, 2025
9557e8c
test: add more tests for parameter initialization
DhairyaLGandhi Mar 20, 2025
8feae0e
test: fix label
DhairyaLGandhi Mar 20, 2025
d3b1669
chore: rename file
DhairyaLGandhi Mar 20, 2025
e01eb77
test: fix sensealg and confusing error message
DhairyaLGandhi Mar 21, 2025
0ad6c62
chore: return new_u0
DhairyaLGandhi Mar 24, 2025
6df7987
chore: rebase branch
DhairyaLGandhi Mar 24, 2025
c4c7807
chore: mark symbol local
DhairyaLGandhi Mar 25, 2025
b85b16e
chore: pass tunables to Tape
DhairyaLGandhi Mar 25, 2025
8fc4136
chore: update new_u0, new_p to original vals if not initializing
DhairyaLGandhi Mar 26, 2025
1f1cce5
chore: rebase master
DhairyaLGandhi Mar 26, 2025
0d1abcc
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Mar 26, 2025
b164b18
chore: add MSL to test deps
DhairyaLGandhi Mar 26, 2025
915d949
feat: allow analytically solved initialization solutions to propagate…
DhairyaLGandhi Apr 3, 2025
d2fd79a
chore: force allocated buffers for vjp to be deterministic
DhairyaLGandhi Apr 3, 2025
a0cd94a
chore: pass tunables to allocate vjp of MTKParameters
DhairyaLGandhi Apr 3, 2025
885794d
test: Core6 dont access J.du
DhairyaLGandhi Apr 4, 2025
13a1ffb
chore: SteadyStateAdjoint could be thunk
DhairyaLGandhi Apr 4, 2025
7826866
chore: import AbstractThunk
DhairyaLGandhi Apr 4, 2025
6ceaa1a
chore: handle upstream pullback better in steady state adjoint
DhairyaLGandhi Apr 4, 2025
dfebd0b
chore: dont accum pullback for parameters
DhairyaLGandhi Apr 4, 2025
396f63e
test: import SII
DhairyaLGandhi Apr 5, 2025
de7e7da
test: wrap ps in ComponentArray
DhairyaLGandhi Apr 7, 2025
f5fb559
chore: call du for jacobian
DhairyaLGandhi Apr 10, 2025
019a051
chore: add recursive_copyto for identical NT trees
DhairyaLGandhi Apr 10, 2025
91ee019
deps: MSL compat
DhairyaLGandhi Apr 10, 2025
4b74718
chore: undo du access
DhairyaLGandhi Apr 10, 2025
94d5e2b
chore: handle J through fwd mode
DhairyaLGandhi Apr 10, 2025
764d3ff
chore: J = nothing instead of NT
DhairyaLGandhi Apr 11, 2025
84b2602
chore: check nothing in steady state
DhairyaLGandhi Apr 13, 2025
8a4aa79
chore: also canonicalize dp
DhairyaLGandhi Apr 13, 2025
d69ccb1
chore: adjoint of preallocation tools
DhairyaLGandhi Apr 13, 2025
2cc3673
chore: pass Δ to canonicalize
DhairyaLGandhi Apr 14, 2025
22f056a
chore: handle different parameter types
DhairyaLGandhi Apr 14, 2025
1f95b25
chore: check for number
DhairyaLGandhi Apr 14, 2025
0d78fa8
test: rm commented out code
DhairyaLGandhi Apr 14, 2025
6620e8a
chore: get tunables from dp
DhairyaLGandhi Apr 14, 2025
5f5633b
test: clean up initialization tests
DhairyaLGandhi Apr 14, 2025
984c2ce
chore: pass initializealg
DhairyaLGandhi Apr 15, 2025
82cd5fe
chore: force path through BrownBasicInit
DhairyaLGandhi Apr 16, 2025
987e8be
chore: replace NoInit with BrownBasicInit
DhairyaLGandhi Apr 17, 2025
056fffa
test: add test to check residual with initialization
DhairyaLGandhi Apr 17, 2025
a122340
chore: replace BrownBasicInit with CheckInit
DhairyaLGandhi Apr 17, 2025
e105838
chore: qualify checkinit
DhairyaLGandhi Apr 17, 2025
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -24,6 +24,7 @@ jobs:
- Core5
- Core6
- Core7
- Core8
- QA
- SDE1
- SDE2
6 changes: 5 additions & 1 deletion Project.toml
@@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -77,12 +78,14 @@ LinearSolve = "2, 3"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
ModelingToolkitStandardLibrary = "2"
Mooncake = "0.4.52"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1, 4"
Optimization = "4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.81.1"
OrdinaryDiffEqCore = "1"
Pkg = "1.10"
PreallocationTools = "0.4.4"
QuadGK = "2.9.1"
@@ -117,6 +120,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -131,4 +135,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
-test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
+test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
7 changes: 6 additions & 1 deletion src/SciMLSensitivity.jl
@@ -39,14 +39,17 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm,
get_tmp_cache, has_adjoint, isinplace, reinit!, remake,
solve, u_modified!, LinearAliasSpecifier

using OrdinaryDiffEqCore: BrownFullBasicInit

# AD Backends
-using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent
+using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
using SciMLBase.ConstructionBase

# Std Libs
using LinearAlgebra: LinearAlgebra, Diagonal, I, UniformScaling, adjoint, axpy!,
@@ -56,6 +59,8 @@ using Markdown: Markdown, @doc_str
using Random: Random, rand!
using Statistics: Statistics, mean

using LinearAlgebra: diag

abstract type SensitivityFunction end
abstract type TransformedFunction end

40 changes: 26 additions & 14 deletions src/adjoint_common.jl
@@ -78,7 +78,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
unwrappedf = unwrapped_f(f)

numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables)
-numindvar = length(u0)
+numindvar = isnothing(u0) ? nothing : length(u0)
isautojacvec = get_jacvec(sensealg)

issemiexplicitdae = false
@@ -106,18 +106,22 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
isempty(algevar_idxs) || (issemiexplicitdae = true)
end
if !issemiexplicitdae
-diffvar_idxs = eachindex(u0)
+diffvar_idxs = isnothing(u0) ? nothing : eachindex(u0)
algevar_idxs = 1:0
end

if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool)
J = nothing
else
if alg === nothing || SciMLBase.forwarddiffs_model_time(alg)
-# 1 chunk is fine because it's only t
-_J = similar(u0, numindvar, numindvar)
-_J .= 0
-J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
+if !isnothing(u0)
+    # 1 chunk is fine because it's only t
+    _J = similar(u0, numindvar, numindvar)
+    _J .= 0
+    J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
+else
+    J = nothing
+end
else
J = similar(u0, numindvar, numindvar)
J .= 0
@@ -133,8 +137,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
dg_val[1] .= false
dg_val[2] .= false
else
-dg_val = similar(u0, numindvar) # number of funcs size
-dg_val .= false
+if !isnothing(u0)
+    dg_val = similar(u0, numindvar) # number of funcs size
+    dg_val .= false
+else
+    dg_val = nothing
+end
end
else
pgpu = UGradientWrapper(g, _t, p)
@@ -241,8 +249,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
pJ = if (quad || !(autojacvec isa Bool))
nothing
else
-_pJ = similar(u0, numindvar, numparams)
-_pJ .= false
+if !isnothing(u0)
+    _pJ = similar(u0, numindvar, numparams)
+    _pJ .= false
+else
+    _pJ = nothing
+end
end

f_cache = isinplace ? deepcopy(u0) : nothing
@@ -379,11 +391,11 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
if !isRODE
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
-tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
+tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
du1 = (p !== nothing && p !== SciMLBase.NullParameters()) ?
similar(p, size(u)) : similar(u)
du1 .= false
-f(du1, u, p, first(t))
+f(du1, u, repack(p), first(t))
return vec(du1)
end
else
@@ -402,8 +414,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
# because hasportion(Tunable(), NullParameters) == false
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
-tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
-    vec(f(u, p, first(t)))

[Review comment] Member: same down here?

+tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
+    vec(f(u, repack(p), first(t)))
end
else
tape = ReverseDiff.GradientTape((y, _p, [_t], _W)) do u, p, t, W
83 changes: 75 additions & 8 deletions src/concrete_solve.jl
@@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack)

vjp = try
f = unwrapped_f(prob.f)
tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob.tspan[1]]
if p === nothing || p isa SciMLBase.NullParameters
-ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
+ReverseDiff.GradientTape((copy(u0), tspan_)) do u, t
du1 = similar(u, size(u))
du1 .= 0
f(du1, u, p, first(t))
return vec(du1)
end
else
-ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
+ReverseDiff.GradientTape((copy(u0), p, tspan_)) do u, p, t
du1 = similar(u, size(u))
du1 .= 0
f(du1, u, repack(p), first(t))
@@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint(
tunables, repack = Functors.functor(p)
end

u0 = state_values(prob) === nothing ? Float64[] : u0
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p,
originator::SciMLBase.ADOriginator, args...; verbose,
@@ -371,6 +373,7 @@ function DiffEqBase._concrete_solve_adjoint(
args...; save_start = true, save_end = true,
saveat = eltype(prob.tspan)[],
save_idxs = nothing,
initializealg_default = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3),
kwargs...)
if !(sensealg isa GaussAdjoint) &&
!(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
@@ -412,16 +415,46 @@ function DiffEqBase._concrete_solve_adjoint(
Base.diff_names(Base._nt_names(values(kwargs)),
(:callback_adj, :callback))}(values(kwargs))
isq = sensealg isa QuadratureAdjoint

if haskey(kwargs, :initializealg) || haskey(prob.kwargs, :initializealg)
initializealg = haskey(kwargs, :initializealg) ? kwargs[:initializealg] : prob.kwargs[:initializealg]
else
initializealg = initializealg_default
end

igs, new_u0, new_p = if _prob.f.initialization_data !== nothing
[Review thread]
Member: Also needs to check that initializealg is not set, is the default, or is using OverrideInit. Should test this is not triggered with, say, a manual BrownBasicInit.

Member (author): Got it. My understanding was that OverrideInit was what we strictly needed. We can check for BrownBasicInit/defaults here.

Member (author): There doesn't seem to be a method which can take a BrownFullBasicInit(). I get a MethodError:

ERROR: MethodError: no method matching get_initial_values(::ODEProblem{…}, ::ODEProblem{…}, ::ODEFunction{…}, ::BrownFullBasicInit{…}, ::Val{…}; sensealg::SteadyStateAdjoint{…}, nlsolve_alg::Nothing)

Closest candidates are:
  get_initial_values(::Any, ::Any, ::Any, ::NoInit, ::Any; kwargs...)
   @ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:282
  get_initial_values(::Any, ::Any, ::Any, ::SciMLBase.OverrideInit, ::Union{Val{true}, Val{false}}; nlsolve_alg, abstol, reltol, kwargs...)
   @ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:224
  get_initial_values(::SciMLBase.AbstractDEProblem, ::SciMLBase.DEIntegrator, ::Any, ::CheckInit, ::Union{Val{true}, Val{false}}; abstol, kwargs...)
   @ SciMLBase ~/Downloads/arpa/jsmo/t2/SciMLBase.jl/src/initialization.jl:161
  ...

Only CheckInit, NoInit, and OverrideInit have dispatches.

Member (author): Having chatted with @AayushSabharwal on this, it seems BrownBasic and ShampineCollocation do not yet have a path through get_initial_values, and that would need to be fixed in OrdinaryDiffEq. Further, since SciMLSensitivity does not depend on OrdinaryDiffEq, it cannot check whether there is a default initialization, since those are defined there. Depending on it also seems like a pretty big hammer for a dep.

Member (author): What would be the best course of action here? It seems supporting BrownBasicInit is a dispatch that will automatically be utilized once it is moved into SciMLBase.

Member: Yes, I mean BrownBasicInit should not be taking this path. But that's a problem, because then they will be disabled in the next stage below, and that needs to be accounted for. This dispatch is already built and set up for BrownBasicInit, and there are tests on that.
(A minimal sketch of the OverrideInit call follows this hunk.)

local new_u0
local new_p
iy, back = Zygote.pullback(tunables) do tunables
new_prob = remake(_prob, p = repack(tunables))
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, initializealg, Val(isinplace(new_prob));
sensealg = SteadyStateAdjoint(autojacvec = sensealg.autojacvec),
kwargs_fwd...)
new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p)
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED
sum(new_tunables)
else
sum(new_u0) + sum(new_tunables)
end
end
igs = back(one(iy))[1] .- one(eltype(tunables))

igs, new_u0, new_p
else
nothing, u0, p
end
_prob = remake(_prob, u0 = new_u0, p = new_p)

if sensealg isa BacksolveAdjoint
-sol = solve(_prob, alg, args...; save_noise = true,
+sol = solve(_prob, alg, args...; initializealg = SciMLBase.CheckInit(), save_noise = true,
save_start = save_start, save_end = save_end,
saveat = saveat, kwargs_fwd...)
elseif ischeckpointing(sensealg)
-sol = solve(_prob, alg, args...; save_noise = true,
+sol = solve(_prob, alg, args...; initializealg = SciMLBase.CheckInit(), save_noise = true,
save_start = true, save_end = true,
saveat = saveat, kwargs_fwd...)
else
-sol = solve(_prob, alg, args...; save_noise = true, save_start = true,
+sol = solve(_prob, alg, args...; initializealg = SciMLBase.CheckInit(), save_noise = true, save_start = true,
save_end = true, kwargs_fwd...)
end
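
A note to make the review thread above concrete: the differentiable-initialization path hinges on which initialization algorithms SciMLBase.get_initial_values can dispatch on. Below is a minimal sketch of the one call this path assumes, under stated assumptions: `prob` is a hypothetical ODEProblem carrying initialization_data, and the tolerances mirror initializealg_default from this hunk.

using SciMLBase

# Sketch (illustrative, not the committed code): of the initialization
# algorithms, only OverrideInit, CheckInit, and NoInit currently have
# get_initial_values dispatches, and only OverrideInit actually solves the
# initialization system, so the adjoint path above assumes it.
initializealg = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3)
new_u0, new_p, success = SciMLBase.get_initial_values(
    prob, prob, prob.f, initializealg, Val(SciMLBase.isinplace(prob)))
# The diff discards the third return value; the first two feed the forward
# solve, and Zygote differentiates this call to obtain the initialization
# gradients (igs) accumulated into dp later in this file.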

@@ -491,6 +524,7 @@ function DiffEqBase._concrete_solve_adjoint(
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

function adjoint_sensitivity_backpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
function df_iip(_out, u, p, t, i)
outtype = _out isa SubArray ?
ArrayInterface.parameterless_type(_out.parent) :
@@ -642,6 +676,8 @@ function DiffEqBase._concrete_solve_adjoint(
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp

dp = Zygote.accum(dp, igs)

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
@@ -1679,6 +1715,7 @@ function DiffEqBase._concrete_solve_adjoint(
u0, p, originator::SciMLBase.ADOriginator,
args...; save_idxs = nothing, kwargs...)
_prob = remake(prob, u0 = u0, p = p)

sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

@@ -1688,26 +1725,56 @@ function DiffEqBase._concrete_solve_adjoint(
out = SciMLBase.sensitivity_solution(sol, sol[_save_idxs])
end

_, repack_adjoint = if isscimlstructure(p)
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
else
nothing, x -> (x,)
end

function steadystatebackpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
function df(_out, u, p, t, i)
if _save_idxs isa Number
_out[_save_idxs] = Δ[_save_idxs]
elseif Δ isa Number
@. _out[_save_idxs] = Δ
-else
+elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa AbstractArray
@. _out[_save_idxs] = Δ[_save_idxs]
+elseif isnothing(_out)
+_out
+else
+@. _out[_save_idxs] = Δ.u[_save_idxs]
end
end
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df)
[Review thread]
Member: When the new reverse ODE is built it needs to drop the initial eqs but still keep the DAE constraints. It can BrownBasic?

Member (author): Is there a way to drop the initial eqs after it's solved? The assumption was that since we run with NoInit, no initialization is run after the first call to get_initial_values, and we accumulate those gradients independently of the adaptive solve.

Member: But the reverse pass needs to run with some form of initialization, or the starting algebraic conditions may not be satisfied. Don't run this one with NoInit(); that would be prone to hiding issues. For this one, at most CheckInit(), but I'm saying that BrownBasicInit() is likely the one justified here, since the 0 initial condition is only true on the differential variables, while the algebraic variable initial conditions will be unknown; the Newton solve will have zero derivative because all of the inputs are just Newton guesses, so BrownBasic will work out for the reverse. We should probably hardcode that, since it's always the solution there.

Member (author): Ok, that will require us to add an OrdinaryDiffEqCore dep in this package. I will add that.

Member (author): Is the 0 derivative also applicable to parameters? Or only the unknowns?

Member: It's applicable to all Newton guess values. There is no parameter init going on in reverse, so it's only for algebraic conditions, so it's only Newton guesses.
(A sketch of this reverse-pass initialization choice follows this file's diff.)


dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
# if Δ isa AbstractArray, the gradients correspond to `u`
# this is something that needs changing in the future, but
# this is applicable until the movement to structural
# tangents is completed
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing
else
Δp = setproperties(dp, to_nt(Δ.prob.p))
Δtunables, _, _ = canonicalize(Tunable(), Δp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, Δtunables
end

dp = Zygote.accum(dp, Δtunables)

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
-(NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
+(NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
-(NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
+(NoTangent(), NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
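
The thread above lands on Brown's basic initialization for the reverse pass. A minimal sketch of that conclusion, assuming hypothetical stand-ins: `reverse_prob` for the adjoint problem assembled by adjoint_sensitivities and `alg` for the chosen integrator.

using OrdinaryDiffEq
using OrdinaryDiffEqCore: BrownFullBasicInit

# Sketch (illustrative, not the committed code): the reverse problem drops
# the initialization equations but keeps the DAE constraints. Its
# differential components start from the known adjoint seed, while the
# algebraic components are plain Newton guesses with zero derivative, so
# BrownFullBasicInit only has to solve for consistent algebraic initial
# values on the reverse solve.
rev_sol = solve(reverse_prob, alg; initializealg = BrownFullBasicInit())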
23 changes: 17 additions & 6 deletions src/derivative_wrappers.jl
@@ -144,6 +144,13 @@ function jacobian(f, x::AbstractArray{<:Number},
return J
end

function jacobian!(J::Nothing, f, x::AbstractArray{<:Number},
fx::Union{Nothing, AbstractArray{<:Number}},
alg::AbstractOverloadingSensitivityAlgorithm, jac_config::Nothing)
@assert isempty(x)
J
end
jacobian!(J::PreallocationTools.DiffCache, x::SciMLBase.UJacobianWrapper, args...) = jacobian!(J.du, x, args...)
function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
fx::Union{Nothing, AbstractArray{<:Number}},
alg::AbstractOverloadingSensitivityAlgorithm, jac_config)
@@ -456,9 +463,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
elseif inplace_sensitivity(S)
_y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y)
if W === nothing
-tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t
+_tunables, _repack, _ = canonicalize(Tunable(), _p)
+tape = ReverseDiff.GradientTape((_y, _tunables, [t])) do u, p, t
du1 = similar(u, size(u))
-f(du1, u, p, first(t))
+f(du1, u, _repack(p), first(t))
return vec(du1)
end
else
@@ -474,8 +482,9 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg
else
_y = eltype(y) === eltype(λ) ? y : convert.(promote_type(eltype(y), eltype(λ)), y)
if W === nothing
-tape = ReverseDiff.GradientTape((_y, _p, [t])) do u, p, t
-    vec(f(u, p, first(t)))
+_tunables, _repack, _ = canonicalize(Tunable(), _p)
+tape = ReverseDiff.GradientTape((_y, _tunables, [t])) do u, p, t
+    vec(f(u, _repack(p), first(t)))
end
else
_W = eltype(W) === eltype(λ) ? W :
@@ -1047,6 +1056,7 @@ function accumulate_cost(dλ, y, p, t, S::TS,
return dλ, dgrad
end

build_jac_config(alg, uf, u::Nothing) = nothing
function build_jac_config(alg, uf, u)
if alg_autodiff(alg)
jac_config = ForwardDiff.JacobianConfig(uf, u, u,
@@ -1068,9 +1078,10 @@ end

function build_param_jac_config(alg, pf, u, p)
if alg_autodiff(alg)
-jac_config = ForwardDiff.JacobianConfig(pf, u, p,
+tunables, repack, aliases = canonicalize(Tunable(), p)
+jac_config = ForwardDiff.JacobianConfig(pf, u, tunables,
ForwardDiff.Chunk{
-determine_chunksize(p,
+determine_chunksize(tunables,
alg)}())
else
if diff_type(alg) != Val{:complex}