Skip to content

Commit d91b8d1

Browse files
Merge pull request #3585 from AayushSabharwal/as/mtkparams-ptype
fix: allow specifying type of buffers inside `MTKParameters`
2 parents 60da9d5 + 01dd661 commit d91b8d1

20 files changed

+596
-218
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jobs:
3838
- {user: SciML, repo: MethodOfLines.jl, group: 2D_Diffusion}
3939
- {user: SciML, repo: MethodOfLines.jl, group: DAE}
4040
- {user: SciML, repo: ModelingToolkitNeuralNets.jl, group: All}
41+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core8}
4142

4243
- {user: Neuroblox, repo: Neuroblox.jl, group: All}
4344
steps:

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
99
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1010
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
11+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1112
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1213
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1314
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -65,7 +66,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6566
[weakdeps]
6667
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6768
CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f"
68-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6969
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
7070
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
7171
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
@@ -74,7 +74,6 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7474
[extensions]
7575
MTKBifurcationKitExt = "BifurcationKit"
7676
MTKCasADiDynamicOptExt = "CasADi"
77-
MTKChainRulesCoreExt = "ChainRulesCore"
7877
MTKDeepDiffsExt = "DeepDiffs"
7978
MTKFMIExt = "FMI"
8079
MTKInfiniteOptExt = "InfiniteOpt"
@@ -142,15 +141,15 @@ RecursiveArrayTools = "3.26"
142141
Reexport = "0.2, 1"
143142
RuntimeGeneratedFunctions = "0.5.9"
144143
SCCNonlinearSolve = "1.0.0"
145-
SciMLBase = "2.84"
144+
SciMLBase = "2.91.1"
146145
SciMLStructures = "1.7"
147146
Serialization = "1"
148147
Setfield = "0.7, 0.8, 1"
149148
SimpleNonlinearSolve = "0.1.0, 1, 2"
150149
SparseArrays = "1"
151150
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
152151
StaticArrays = "0.10, 0.11, 0.12, 1.0"
153-
StochasticDelayDiffEq = "1.8.1"
152+
StochasticDelayDiffEq = "1.10"
154153
StochasticDiffEq = "6.72.1"
155154
SymbolicIndexingInterface = "0.3.39"
156155
SymbolicUtils = "3.26.1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc
6262
using OffsetArrays: Origin
6363
import CommonSolve
6464
import EnumX
65+
import ChainRulesCore
66+
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6567

6668
using RuntimeGeneratedFunctions
6769
using RuntimeGeneratedFunctions: drop_expr
@@ -204,6 +206,8 @@ include("structural_transformation/StructuralTransformations.jl")
204206
@reexport using .StructuralTransformations
205207
include("inputoutput.jl")
206208

209+
include("adjoints.jl")
210+
207211
for S in subtypes(ModelingToolkit.AbstractSystem)
208212
S = nameof(S)
209213
@eval convert_system(::Type{<:$S}, sys::$S) = sys

ext/MTKChainRulesCoreExt.jl renamed to src/adjoints.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
module MTKChainRulesCoreExt
2-
3-
import ModelingToolkit as MTK
4-
import ChainRulesCore
5-
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6-
7-
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
1+
function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...)
82
function mtp_pullback(dt)
93
dt = unthunk(dt)
104
dtunables = dt isa AbstractArray ? dt : dt.tunable
115
(NoTangent(), dtunables[1:length(tunables)],
126
ntuple(_ -> NoTangent(), length(args))...)
137
end
14-
MTK.MTKParameters(tunables, args...), mtp_pullback
8+
MTKParameters(tunables, args...), mtp_pullback
159
end
1610

1711
function subset_idxs(idxs, portion, template)
@@ -70,23 +64,23 @@ function selected_tangents(
7064
end
7165

7266
function ChainRulesCore.rrule(
73-
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
67+
::typeof(remake_buffer), indp, oldbuf::MTKParameters, idxs, vals)
7468
if idxs isa AbstractSet
7569
idxs = collect(idxs)
7670
end
7771
idxs = map(idxs) do i
78-
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
72+
i isa ParameterIndex ? i : parameter_index(indp, i)
7973
end
80-
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
74+
newbuf = remake_buffer(indp, oldbuf, idxs, vals)
8175
tunable_idxs = reduce(
82-
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
76+
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Tunable);
8377
init = Union{Int, AbstractVector{Int}}[])
8478
initials_idxs = reduce(
85-
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
79+
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Initials);
8680
init = Union{Int, AbstractVector{Int}}[])
87-
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
88-
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
89-
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
81+
disc_idxs = subset_idxs(idxs, SciMLStructures.Discrete(), oldbuf.discrete)
82+
const_idxs = subset_idxs(idxs, SciMLStructures.Constants(), oldbuf.constant)
83+
nn_idxs = subset_idxs(idxs, NONNUMERIC_PORTION, oldbuf.nonnumeric)
9084

9185
pullback = let idxs = idxs
9286
function remake_buffer_pullback(buf′)
@@ -102,13 +96,11 @@ function ChainRulesCore.rrule(
10296
oldbuf′ = Tangent{typeof(oldbuf)}(;
10397
tunable, initials, discrete, constant, nonnumeric)
10498
idxs′ = NoTangent()
105-
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
99+
vals′ = map(i -> _ducktyped_parameter_values(buf′, i), idxs)
106100
return f′, indp′, oldbuf′, idxs′, vals′
107101
end
108102
end
109103
newbuf, pullback
110104
end
111105

112-
ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol)
113-
114-
end
106+
ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol)

src/linearization.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ function (linfun::LinearizationFunction)(u, p, t)
285285
linfun.num_states == length(u) ||
286286
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
287287
integ_cache = (linfun.caches,)
288-
integ = MockIntegrator{true}(u, p, t, integ_cache, nothing)
288+
integ = MockIntegrator{true}(u, p, t, fun, integ_cache, nothing)
289289
u, p, success = SciMLBase.get_initial_values(
290290
linfun.prob, integ, fun, linfun.initializealg, Val(true);
291291
linfun.initialize_kwargs...)
@@ -325,7 +325,7 @@ Mock `DEIntegrator` to allow using `CheckInit` without having to create a new in
325325
326326
$(TYPEDFIELDS)
327327
"""
328-
struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
328+
struct MockIntegrator{iip, U, P, T, F, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
329329
"""
330330
The state vector.
331331
"""
@@ -339,6 +339,10 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip
339339
"""
340340
t::T
341341
"""
342+
The wrapped `SciMLFunction`.
343+
"""
344+
f::F
345+
"""
342346
The integrator cache.
343347
"""
344348
cache::C
@@ -348,8 +352,9 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip
348352
opts::O
349353
end
350354

351-
function MockIntegrator{iip}(u::U, p::P, t::T, cache::C, opts::O) where {iip, U, P, T, C, O}
352-
return MockIntegrator{iip, U, P, T, C, O}(u, p, t, cache, opts)
355+
function MockIntegrator{iip}(
356+
u::U, p::P, t::T, f::F, cache::C, opts::O) where {iip, U, P, T, F, C, O}
357+
return MockIntegrator{iip, U, P, T, F, C, O}(u, p, t, f, cache, opts)
353358
end
354359

355360
SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u

src/structural_transformation/symbolics_tearing.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -860,12 +860,9 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
860860
# map of array observed variable (unscalarized) to number of its
861861
# scalarized terms that appear in observed equations
862862
arr_obs_occurrences = Dict()
863-
# to check if array variables occur in unscalarized form anywhere
864-
all_vars = Set()
865863
for (i, eq) in enumerate(obs)
866864
lhs = eq.lhs
867865
rhs = eq.rhs
868-
vars!(all_vars, rhs)
869866

870867
# HACK 1
871868
if cse && is_getindexed_array(rhs)
@@ -920,7 +917,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
920917
tempvar; T = Symbolics.symtype(rhs_arr)))
921918
tempvar = setmetadata(
922919
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
923-
vars!(all_vars, rhs_arr)
924920
tempeq = tempvar ~ rhs_arr
925921
rhs_to_tempvar[rhs_arr] = tempvar
926922
push!(obs, tempeq)
@@ -946,18 +942,10 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
946942
cnt == 0 && continue
947943
arr_obs_occurrences[arg1] = cnt + 1
948944
end
949-
for eq in neweqs
950-
vars!(all_vars, eq.rhs)
951-
end
952945

953-
# also count unscalarized variables used in callbacks
954-
for ev in Iterators.flatten((continuous_events(sys), discrete_events(sys)))
955-
vars!(all_vars, ev)
956-
end
957946
obs_arr_eqs = Equation[]
958947
for (arrvar, cnt) in arr_obs_occurrences
959948
cnt == length(arrvar) || continue
960-
arrvar in all_vars || continue
961949
# firstindex returns 1 for multidimensional array symbolics
962950
firstind = first(eachindex(arrvar))
963951
scal = [arrvar[i] for i in eachindex(arrvar)]

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,7 +1479,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
14791479
end
14801480

14811481
if simplify_system
1482-
isys = structural_simplify(isys; fully_determined)
1482+
isys = structural_simplify(isys; fully_determined, split = is_split(sys))
14831483
end
14841484

14851485
ts = get_tearing_state(isys)
@@ -1554,6 +1554,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
15541554
else
15551555
NonlinearLeastSquaresProblem
15561556
end
1557-
TProb(isys, u0map, parammap; kwargs...,
1557+
TProb{iip}(isys, u0map, parammap; kwargs...,
15581558
build_initializeprob = false, is_initializeprob = true)
15591559
end

src/systems/jumps/jumpsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
408408
error("The passed in JumpSystem contains `Equation`s or continuous events, please use a problem type that supports these features, such as ODEProblem.")
409409
end
410410

411-
_f, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
411+
_f, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap;
412412
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false, build_initializeprob = false, cse)
413413
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
414414

@@ -449,7 +449,7 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
449449
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
450450
end
451451

452-
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
452+
_, u0, p = process_SciMLProblem(EmptySciMLFunction{iip}, sys, u0map, parammap;
453453
t = tspan === nothing ? nothing : tspan[1], tofloat = false, check_length = false)
454454
# identity function to make syms works
455455
quote
@@ -506,7 +506,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
506506
return ODEProblem(osys, u0map, tspan, parammap; check_length = false,
507507
build_initializeprob = false, kwargs...)
508508
else
509-
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
509+
_, u0, p = process_SciMLProblem(EmptySciMLFunction{true}, sys, u0map, parammap;
510510
t = tspan === nothing ? nothing : tspan[1], tofloat = false,
511511
check_length = false, build_initializeprob = false, cse)
512512
f = (du, u, p, t) -> (du .= 0; nothing)

src/systems/nonlinear/initializesystem.jl

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,7 @@ function SciMLBase.remake_initialization_data(
513513
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
514514
end
515515
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
516-
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
517-
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata)
516+
return @set oldinitdata.initializeprob = initprob
518517
end
519518

520519
dvs = unknowns(sys)
@@ -582,21 +581,35 @@ function SciMLBase.remake_initialization_data(
582581
op, missing_unknowns, missing_pars = build_operating_point!(sys,
583582
u0map, pmap, defs, cmap, dvs, ps)
584583
floatT = float_type_from_varmap(op)
584+
u0_constructor = p_constructor = identity
585+
if newu0 isa StaticArray
586+
u0_constructor = vals -> SymbolicUtils.Code.create_array(
587+
typeof(newu0), floatT, Val(1), Val(length(vals)), vals...)
588+
end
589+
if newp isa StaticArray || newp isa MTKParameters && newp.initials isa StaticArray
590+
p_constructor = vals -> SymbolicUtils.Code.create_array(
591+
typeof(newp.initials), floatT, Val(1), Val(length(vals)), vals...)
592+
end
585593
kws = maybe_build_initialization_problem(
586-
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns;
587-
use_scc, initialization_eqs, floatT, allow_incomplete = true)
594+
sys, SciMLBase.isinplace(odefn), op, u0map, pmap, t0, defs, guesses, missing_unknowns;
595+
use_scc, initialization_eqs, floatT, u0_constructor, p_constructor, allow_incomplete = true)
588596

589-
return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
597+
odefn = remake(odefn; kws...)
598+
return SciMLBase.remake_initialization_data(sys, odefn, newu0, t0, newp, newu0, newp)
590599
end
591600

592601
function promote_u0_p(u0, p::MTKParameters, t0)
593602
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
594603
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)
595604

596-
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
597-
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
598-
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
599-
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
605+
if !isempty(p.tunable)
606+
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
607+
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
608+
end
609+
if !isempty(p.initials)
610+
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
611+
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
612+
end
600613

601614
return u0, p
602615
end
@@ -627,12 +640,12 @@ function SciMLBase.late_binding_update_u0_p(
627640
if length(newu0) != length(prob.u0)
628641
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
629642
end
630-
meta.set_initial_unknowns!(newp, newu0)
643+
newp = meta.set_initial_unknowns!(newp, newu0)
631644
return newu0, newp
632645
end
633646

634-
newp = p === missing ? copy(newp) : newp
635-
647+
syms = []
648+
vals = []
636649
allsyms = all_symbols(sys)
637650
for (k, v) in u0
638651
v === nothing && continue
@@ -644,9 +657,11 @@ function SciMLBase.late_binding_update_u0_p(
644657
k = k2
645658
end
646659
is_parameter(sys, Initial(k)) || continue
647-
setp(sys, Initial(k))(newp, v)
660+
push!(syms, Initial(k))
661+
push!(vals, v)
648662
end
649663

664+
newp = setp_oop(sys, syms)(newp, vals)
650665
return newu0, newp
651666
end
652667

0 commit comments

Comments
 (0)