Skip to content

Commit ac4f1c4

Browse files
Merge pull request #3666 from AayushSabharwal/as/remove-pdeps
refactor: remove `parameter_dependencies` kwarg
2 parents 117bf00 + f26064f commit ac4f1c4

18 files changed

+211
-174
lines changed

src/systems/abstractsystem.jl

Lines changed: 98 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ function has_parameter_dependency_with_lhs(sys, sym)
295295
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
296296
return haskey(ic.dependent_pars_to_timeseries, unwrap(sym))
297297
else
298-
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
298+
return any(isequal(sym), [eq.lhs for eq in get_parameter_dependencies(sys)])
299299
end
300300
end
301301

@@ -565,7 +565,7 @@ function add_initialization_parameters(sys::AbstractSystem; split = true)
565565
D = Differential(get_iv(sys))
566566
union!(all_initialvars, [D(v) for v in all_initialvars if iscall(v)])
567567
end
568-
for eq in parameter_dependencies(sys)
568+
for eq in get_parameter_dependencies(sys)
569569
is_variable_floatingpoint(eq.lhs) || continue
570570
push!(all_initialvars, eq.lhs)
571571
end
@@ -596,6 +596,22 @@ function isinitial(p)
596596
operation(p) === getindex && isinitial(arguments(p)[1]))
597597
end
598598

599+
"""
600+
$(TYPEDSIGNATURES)
601+
602+
Find [`GlobalScope`](@ref)d variables in `sys` and add them to the unknowns/parameters.
603+
"""
604+
function discover_globalscoped(sys::AbstractSystem)
605+
newunknowns = OrderedSet()
606+
newparams = OrderedSet()
607+
iv = has_iv(sys) ? get_iv(sys) : nothing
608+
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
609+
setdiff!(newunknowns, observables(sys))
610+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
611+
@set! sys.unknowns = unique!(vcat(get_unknowns(sys), collect(newunknowns)))
612+
return sys
613+
end
614+
599615
"""
600616
$(TYPEDSIGNATURES)
601617
@@ -612,13 +628,7 @@ using [`toggle_namespacing`](@ref).
612628
"""
613629
function complete(
614630
sys::AbstractSystem; split = true, flatten = true, add_initial_parameters = true)
615-
newunknowns = OrderedSet()
616-
newparams = OrderedSet()
617-
iv = has_iv(sys) ? get_iv(sys) : nothing
618-
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
619-
# don't update unknowns to not disturb `mtkcompile` order
620-
# `GlobalScope`d unknowns will be picked up and added there
621-
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
631+
sys = discover_globalscoped(sys)
622632

623633
if flatten
624634
eqs = equations(sys)
@@ -632,6 +642,7 @@ function complete(
632642
@set! newsys.parent = complete(sys; split = false, flatten = false)
633643
end
634644
sys = newsys
645+
sys = process_parameter_equations(sys)
635646
if add_initial_parameters
636647
sys = add_initialization_parameters(sys; split)
637648
end
@@ -1263,6 +1274,12 @@ function parameters(sys::AbstractSystem; initial_parameters = false)
12631274
end
12641275

12651276
function dependent_parameters(sys::AbstractSystem)
1277+
if !iscomplete(sys)
1278+
throw(ArgumentError("""
1279+
`dependent_parameters` requires that the system is marked as complete. Call
1280+
`complete` or `mtkcompile` on the system.
1281+
"""))
1282+
end
12661283
return map(eq -> eq.lhs, parameter_dependencies(sys))
12671284
end
12681285

@@ -1279,27 +1296,33 @@ function parameters_toplevel(sys::AbstractSystem)
12791296
end
12801297

12811298
"""
1282-
$(TYPEDSIGNATURES)
1283-
Get the parameter dependencies of the system `sys` and its subsystems.
1299+
$(TYPEDSIGNATURES)
12841300
1285-
See also [`defaults`](@ref) and [`ModelingToolkit.get_parameter_dependencies`](@ref).
1301+
Get the parameter dependencies of the system `sys` and its subsystems. Requires that the
1302+
system is `complete`d.
12861303
"""
12871304
function parameter_dependencies(sys::AbstractSystem)
1305+
if !iscomplete(sys)
1306+
throw(ArgumentError("""
1307+
`parameter_dependencies` requires that the system is marked as complete. Call \
1308+
`complete` or `mtkcompile` on the system.
1309+
"""))
1310+
end
12881311
if !has_parameter_dependencies(sys)
12891312
return Equation[]
12901313
end
1291-
pdeps = get_parameter_dependencies(sys)
1292-
systems = get_systems(sys)
1293-
# put pdeps after those of subsystems to maintain topological sorted order
1294-
namespaced_deps = mapreduce(
1295-
s -> map(eq -> namespace_equation(eq, s), parameter_dependencies(s)), vcat,
1296-
systems; init = Equation[])
1297-
1298-
return vcat(namespaced_deps, pdeps)
1314+
get_parameter_dependencies(sys)
12991315
end
13001316

1317+
"""
1318+
$(TYPEDSIGNATURES)
1319+
1320+
Return all of the parameters of the system, including hidden initial parameters and ones
1321+
eliminated via `parameter_dependencies`.
1322+
"""
13011323
function full_parameters(sys::AbstractSystem)
1302-
vcat(parameters(sys; initial_parameters = true), dependent_parameters(sys))
1324+
dep_ps = [eq.lhs for eq in get_parameter_dependencies(sys)]
1325+
vcat(parameters(sys; initial_parameters = true), dep_ps)
13031326
end
13041327

13051328
"""
@@ -2079,7 +2102,7 @@ function Base.show(
20792102
end
20802103

20812104
# Print parameter dependencies
2082-
npdeps = has_parameter_dependencies(sys) ? length(parameter_dependencies(sys)) : 0
2105+
npdeps = has_parameter_dependencies(sys) ? length(get_parameter_dependencies(sys)) : 0
20832106
npdeps > 0 && printstyled(io, "\nParameter dependencies ($npdeps):"; bold)
20842107
npdeps > 0 && hint && print(io, " see parameter_dependencies($name)")
20852108

@@ -2588,15 +2611,15 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
25882611
eqs = union(get_eqs(basesys), get_eqs(sys))
25892612
sts = union(get_unknowns(basesys), get_unknowns(sys))
25902613
ps = union(get_ps(basesys), get_ps(sys))
2591-
dep_ps = union(parameter_dependencies(basesys), parameter_dependencies(sys))
2614+
dep_ps = union(get_parameter_dependencies(basesys), get_parameter_dependencies(sys))
25922615
obs = union(get_observed(basesys), get_observed(sys))
25932616
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
25942617
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
25952618
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
25962619
meta = merge(get_metadata(basesys), get_metadata(sys))
25972620
syss = union(get_systems(basesys), get_systems(sys))
25982621
args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps)
2599-
kwargs = (parameter_dependencies = dep_ps, observed = obs, continuous_events = cevs,
2622+
kwargs = (observed = obs, continuous_events = cevs,
26002623
discrete_events = devs, defaults = defs, systems = syss, metadata = meta,
26012624
name = name, description = description, gui_metadata = gui_metadata)
26022625

@@ -2610,7 +2633,10 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
26102633
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
26112634
end
26122635

2613-
return T(args...; kwargs...)
2636+
newsys = T(args...; kwargs...)
2637+
@set! newsys.parameter_dependencies = dep_ps
2638+
2639+
return newsys
26142640
end
26152641

26162642
"""
@@ -2752,60 +2778,62 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
27522778
initialization_eqs = fast_substitute(get_initialization_eqs(sys), rules)
27532779
cstrs = fast_substitute(get_constraints(sys), rules)
27542780
subsys = map(s -> substitute(s, rules), get_systems(sys))
2755-
System(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
2756-
guesses = guess, parameter_dependencies = pdeps, systems = subsys, noise_eqs,
2781+
newsys = System(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
2782+
guesses = guess, systems = subsys, noise_eqs,
27572783
observed, initialization_eqs, constraints = cstrs)
2784+
@set! newsys.parameter_dependencies = pdeps
27582785
else
27592786
error("substituting symbols is not supported for $(typeof(sys))")
27602787
end
27612788
end
27622789

2763-
struct InvalidParameterDependenciesType
2764-
got::Any
2765-
end
2766-
2767-
function Base.showerror(io::IO, err::InvalidParameterDependenciesType)
2768-
print(
2769-
io, "Parameter dependencies must be a `Dict`, or an array of `Pair` or `Equation`.")
2770-
if err.got !== nothing
2771-
print(io, " Got ", err.got)
2772-
end
2773-
end
2790+
"""
2791+
$(TYPEDSIGNATURES)
27742792
2775-
function process_parameter_dependencies(pdeps, ps)
2776-
if pdeps === nothing || isempty(pdeps)
2777-
return Equation[], ps
2778-
end
2779-
if pdeps isa Dict
2780-
pdeps = [k ~ v for (k, v) in pdeps]
2781-
else
2782-
pdeps isa AbstractArray || throw(InvalidParameterDependenciesType(pdeps))
2783-
pdeps = [if p isa Pair
2784-
p[1] ~ p[2]
2785-
elseif p isa Equation
2786-
p
2787-
else
2788-
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
2789-
end
2790-
for p in pdeps]
2793+
Find equations of `sys` involving only parameters and separate them out into the
2794+
`parameter_dependencies` field. Relative ordering of equations is maintained.
2795+
Parameter-only equations are validated to be explicit and sorted topologically. All such
2796+
explicitly determined parameters are removed from the parameters of `sys`. Return the new
2797+
system.
2798+
"""
2799+
function process_parameter_equations(sys::AbstractSystem)
2800+
if !isempty(get_systems(sys))
2801+
throw(ArgumentError("Expected flattened system"))
27912802
end
2792-
lhss = []
2793-
for p in pdeps
2794-
if !isparameter(p.lhs)
2795-
error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).")
2796-
end
2797-
syms = vars(p.rhs)
2798-
if !all(isparameter, syms)
2799-
error("RHS of parameter dependency must only include parameters. Found $(p.rhs)")
2803+
varsbuf = Set()
2804+
pareq_idxs = Int[]
2805+
eqs = equations(sys)
2806+
for (i, eq) in enumerate(eqs)
2807+
empty!(varsbuf)
2808+
vars!(varsbuf, eq; op = Union{Differential, Initial, Pre})
2809+
# singular equations
2810+
isempty(varsbuf) && continue
2811+
if all(varsbuf) do sym
2812+
is_parameter(sys, sym) ||
2813+
symbolic_type(sym) == ArraySymbolic() &&
2814+
is_sized_array_symbolic(sym) &&
2815+
all(Base.Fix1(is_parameter, sys), collect(sym))
2816+
end
2817+
if !isparameter(eq.lhs)
2818+
throw(ArgumentError("""
2819+
LHS of parameter dependency equation must be a single parameter. Found \
2820+
$(eq.lhs).
2821+
"""))
2822+
end
2823+
push!(pareq_idxs, i)
28002824
end
2801-
push!(lhss, p.lhs)
2802-
end
2803-
lhss = map(identity, lhss)
2804-
pdeps = topsort_equations(pdeps, union(ps, lhss))
2805-
ps = filter(ps) do p
2806-
!any(isequal(p), lhss)
28072825
end
2808-
return pdeps, ps
2826+
2827+
pareqs = [get_parameter_dependencies(sys); eqs[pareq_idxs]]
2828+
explicitpars = [eq.lhs for eq in pareqs]
2829+
pareqs = topsort_equations(pareqs, explicitpars)
2830+
2831+
eqs = eqs[setdiff(eachindex(eqs), pareq_idxs)]
2832+
2833+
@set! sys.eqs = eqs
2834+
@set! sys.parameter_dependencies = pareqs
2835+
@set! sys.ps = setdiff(get_ps(sys), explicitpars)
2836+
return sys
28092837
end
28102838

28112839
"""
@@ -2829,7 +2857,7 @@ See also: [`ModelingToolkit.dump_variable_metadata`](@ref), [`ModelingToolkit.du
28292857
"""
28302858
function dump_parameters(sys::AbstractSystem)
28312859
defs = defaults(sys)
2832-
pdeps = parameter_dependencies(sys)
2860+
pdeps = get_parameter_dependencies(sys)
28332861
metas = map(dump_variable_metadata.(parameters(sys))) do meta
28342862
if haskey(defs, meta.var)
28352863
meta = merge(meta, (; default = defs[meta.var]))

src/systems/codegen_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
246246
p_start += 1
247247
p_end += 1
248248
end
249-
pdeps = parameter_dependencies(sys)
249+
pdeps = get_parameter_dependencies(sys)
250250

251251
# only get the necessary observed equations, avoiding extra computation
252252
if add_observed && !isempty(obs)

src/systems/diffeqs/basic_transformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,13 @@ function change_independent_variable(
223223
wasflat = isempty(systems)
224224
sys = typeof(sys)( # recreate system with transformed fields
225225
eqs, iv2, unknowns, ps; observed, initialization_eqs,
226-
parameter_dependencies, defaults, guesses, connector_type,
226+
defaults, guesses, connector_type,
227227
assertions, name = nameof(sys), description = description(sys)
228228
)
229229
sys = compose(sys, systems) # rebuild hierarchical system
230230
if wascomplete
231231
sys = complete(sys; split = wassplit, flatten = wasflat) # complete output if input was complete
232+
@set! sys.parameter_dependencies = parameter_dependencies
232233
end
233234
return sys
234235
end

src/systems/index_cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ function IndexCache(sys::AbstractSystem)
315315
dependent_pars_to_timeseries = Dict{
316316
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
317317

318-
for eq in parameter_dependencies(sys)
318+
for eq in get_parameter_dependencies(sys)
319319
sym = eq.lhs
320320
vs = vars(eq.rhs)
321321
timeseries = TimeseriesSetType()

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,15 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
167167
for k in keys(defs)
168168
defs[k] = substitute(defs[k], paramsubs)
169169
end
170-
return System(eqs_ics,
170+
isys = System(eqs_ics,
171171
vars,
172172
pars;
173173
defaults = defs,
174174
checks = check_units,
175-
parameter_dependencies = new_parameter_deps,
176175
name,
177176
is_initializesystem = true,
178177
kwargs...)
178+
@set isys.parameter_dependencies = new_parameter_deps
179179
end
180180

181181
"""
@@ -280,15 +280,15 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
280280
for k in keys(defs)
281281
defs[k] = substitute(defs[k], paramsubs)
282282
end
283-
return System(eqs_ics,
283+
isys = System(eqs_ics,
284284
vars,
285285
pars;
286286
defaults = defs,
287287
checks = check_units,
288-
parameter_dependencies = new_parameter_deps,
289288
name,
290289
is_initializesystem = true,
291290
kwargs...)
291+
@set isys.parameter_dependencies = new_parameter_deps
292292
end
293293

294294
"""

0 commit comments

Comments
 (0)