Skip to content

Commit 5a6a813

Browse files
Merge pull request #3493 from AayushSabharwal/as/param-derivatives
fix: handle derivatives of time-dependent parameters
2 parents b1ccc75 + a2b4745 commit 5a6a813

File tree

5 files changed

+158
-3
lines changed

5 files changed

+158
-3
lines changed

src/structural_transformation/pantelides.jl

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
5454
D(eq.lhs)
5555
end
5656
rhs = ModelingToolkit.expand_derivatives(D(eq.rhs))
57+
rhs = fast_substitute(rhs, state.param_derivative_map)
5758
substitution_dict = Dict(x.lhs => x.rhs
5859
for x in out_eqs if x !== nothing && x.lhs isa Symbolic)
5960
sub_rhs = substitute(rhs, substitution_dict)

src/structural_transformation/symbolics_tearing.jl

+17-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,23 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)
6565

6666
sys = ts.sys
6767
eq = equations(ts)[ieq]
68-
eq = 0 ~ Symbolics.derivative(eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true)
68+
eq = 0 ~ fast_substitute(
69+
ModelingToolkit.derivative(
70+
eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map)
71+
72+
vs = ModelingToolkit.vars(eq.rhs)
73+
for v in vs
74+
# parameters with unknown derivatives have a value of `nothing` in the map,
75+
# so use `missing` as the default.
76+
get(ts.param_derivative_map, v, missing) === nothing || continue
77+
_original_eq = equations(ts)[ieq]
78+
error("""
79+
Encountered derivative of discrete variable `$(only(arguments(v)))` when \
80+
differentiating equation `$(_original_eq)`. This may indicate a model error or a \
81+
missing equation of the form `$v ~ ...` that defines this derivative.
82+
""")
83+
end
84+
6985
push!(equations(ts), eq)
7086
# Analyze the new equation and update the graph/solvable_graph
7187
# First, copy the previous incidence and add the derivative terms.

src/systems/systemstructure.jl

+30-1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
207207
fullvars::Vector
208208
structure::SystemStructure
209209
extra_eqs::Vector
210+
param_derivative_map::Dict{BasicSymbolic, Any}
210211
end
211212

212213
TransformationState(sys::AbstractSystem) = TearingState(sys)
@@ -253,6 +254,12 @@ function Base.push!(ev::EquationsView, eq)
253254
push!(ev.ts.extra_eqs, eq)
254255
end
255256

257+
function is_time_dependent_parameter(p, iv)
258+
return iv !== nothing && isparameter(p) && iscall(p) &&
259+
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], iv) ||
260+
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
261+
end
262+
256263
function TearingState(sys; quick_cancel = false, check = true)
257264
sys = flatten(sys)
258265
ivs = independent_variables(sys)
@@ -264,6 +271,7 @@ function TearingState(sys; quick_cancel = false, check = true)
264271
var2idx = Dict{Any, Int}()
265272
symbolic_incidence = []
266273
fullvars = []
274+
param_derivative_map = Dict{BasicSymbolic, Any}()
267275
var_counter = Ref(0)
268276
var_types = VariableType[]
269277
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
@@ -276,11 +284,23 @@ function TearingState(sys; quick_cancel = false, check = true)
276284

277285
vars = OrderedSet()
278286
varsvec = []
287+
eqs_to_retain = trues(length(eqs))
279288
for (i, eq′) in enumerate(eqs)
280289
if eq′.lhs isa Connection
281290
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
282291
return nothing
283292
end
293+
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
294+
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
295+
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
296+
# we want to store `nothing` in the map because that means `fast_substitute`
297+
# will ignore the rule. We will this identify the presence of `eq′.lhs` in
298+
# the differentiated expression and error.
299+
param_derivative_map[eq′.lhs] = coalesce(eq′.rhs, nothing)
300+
eqs_to_retain[i] = false
301+
# change the equation if the RHS is `missing` so the rest of this loop works
302+
eq′ = eq′.lhs ~ coalesce(eq′.rhs, 0.0)
303+
end
284304
if _iszero(eq′.lhs)
285305
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
286306
eq = eq′
@@ -295,6 +315,12 @@ function TearingState(sys; quick_cancel = false, check = true)
295315
any(isequal(_var), ivs) && continue
296316
if isparameter(_var) ||
297317
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
318+
if is_time_dependent_parameter(_var, iv) &&
319+
!haskey(param_derivative_map, Differential(iv)(_var))
320+
# Parameter derivatives default to zero - they stay constant
321+
# between callbacks
322+
param_derivative_map[Differential(iv)(_var)] = 0.0
323+
end
298324
continue
299325
end
300326
v = scalarize(v)
@@ -351,6 +377,9 @@ function TearingState(sys; quick_cancel = false, check = true)
351377
eqs[i] = eqs[i].lhs ~ rhs
352378
end
353379
end
380+
eqs = eqs[eqs_to_retain]
381+
neqs = length(eqs)
382+
symbolic_incidence = symbolic_incidence[eqs_to_retain]
354383

355384
### Handle discrete variables
356385
lowest_shift = Dict()
@@ -438,7 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true)
438467
ts = TearingState(sys, fullvars,
439468
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
440469
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
441-
Any[])
470+
Any[], param_derivative_map)
442471
if sys isa DiscreteSystem
443472
ts = shift_discrete_system(ts)
444473
end

test/state_selection.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ModelingToolkit, OrdinaryDiffEq, Test
22
using ModelingToolkit: t_nounits as t, D_nounits as D
33

44
sts = @variables x1(t) x2(t) x3(t) x4(t)
5-
params = @parameters u1(t) u2(t) u3(t) u4(t)
5+
params = @parameters u1 u2 u3 u4
66
eqs = [x1 + x2 + u1 ~ 0
77
x1 + x2 + x3 + u2 ~ 0
88
x1 + D(x3) + x4 + u3 ~ 0

test/structural_transformation/utils.jl

+109
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using SparseArrays
55
using UnPack
66
using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm
77
using Symbolics: unwrap
8+
using DataInterpolations
89
const ST = StructuralTransformations
910

1011
# Define some variables
@@ -282,3 +283,111 @@ end
282283
@test length(mapping) == 3
283284
end
284285
end
286+
287+
@testset "Issue#3480: Derivatives of time-dependent parameters" begin
288+
@component function FilteredInput(; name, x0 = 0, T = 0.1)
289+
params = @parameters begin
290+
k(t) = x0
291+
T = T
292+
end
293+
vars = @variables begin
294+
x(t) = k
295+
dx(t) = 0
296+
ddx(t)
297+
end
298+
systems = []
299+
eqs = [D(x) ~ dx
300+
D(dx) ~ ddx
301+
dx ~ (k - x) / T]
302+
return ODESystem(eqs, t, vars, params; systems, name)
303+
end
304+
305+
@component function FilteredInputExplicit(; name, x0 = 0, T = 0.1)
306+
params = @parameters begin
307+
k(t)[1:1] = [x0]
308+
T = T
309+
end
310+
vars = @variables begin
311+
x(t) = k
312+
dx(t) = 0
313+
ddx(t)
314+
end
315+
systems = []
316+
eqs = [D(x) ~ dx
317+
D(dx) ~ ddx
318+
D(k[1]) ~ 1.0
319+
dx ~ (k[1] - x) / T]
320+
return ODESystem(eqs, t, vars, params; systems, name)
321+
end
322+
323+
@component function FilteredInputErr(; name, x0 = 0, T = 0.1)
324+
params = @parameters begin
325+
k(t) = x0
326+
T = T
327+
end
328+
vars = @variables begin
329+
x(t) = k
330+
dx(t) = 0
331+
ddx(t)
332+
end
333+
systems = []
334+
eqs = [D(x) ~ dx
335+
D(dx) ~ ddx
336+
dx ~ (k - x) / T
337+
D(k) ~ missing]
338+
return ODESystem(eqs, t, vars, params; systems, name)
339+
end
340+
341+
@named sys = FilteredInputErr()
342+
@test_throws ["derivative of discrete variable", "k(t)"] structural_simplify(sys)
343+
344+
@mtkbuild sys = FilteredInput()
345+
vs = Set()
346+
for eq in equations(sys)
347+
ModelingToolkit.vars!(vs, eq)
348+
end
349+
for eq in observed(sys)
350+
ModelingToolkit.vars!(vs, eq)
351+
end
352+
353+
@test !(D(sys.k) in vs)
354+
355+
@mtkbuild sys = FilteredInputExplicit()
356+
obsfn1 = ModelingToolkit.build_explicit_observed_function(sys, sys.ddx)
357+
obsfn2 = ModelingToolkit.build_explicit_observed_function(sys, sys.dx)
358+
u = [1.0]
359+
p = MTKParameters(sys, [sys.k => [2.0], sys.T => 3.0])
360+
@test obsfn1(u, p, 0.0) (1 - obsfn2(u, p, 0.0)) / 3.0
361+
362+
@testset "Called parameter still has derivative" begin
363+
@component function FilteredInput2(; name, x0 = 0, T = 0.1)
364+
ts = collect(0.0:0.1:10.0)
365+
spline = LinearInterpolation(ts .^ 2, ts)
366+
params = @parameters begin
367+
(k::LinearInterpolation)(..) = spline
368+
T = T
369+
end
370+
vars = @variables begin
371+
x(t) = k(t)
372+
dx(t) = 0
373+
ddx(t)
374+
end
375+
systems = []
376+
eqs = [D(x) ~ dx
377+
D(dx) ~ ddx
378+
dx ~ (k(t) - x) / T]
379+
return ODESystem(eqs, t, vars, params; systems, name)
380+
end
381+
382+
@mtkbuild sys = FilteredInput2()
383+
vs = Set()
384+
for eq in equations(sys)
385+
ModelingToolkit.vars!(vs, eq)
386+
end
387+
for eq in observed(sys)
388+
ModelingToolkit.vars!(vs, eq)
389+
end
390+
391+
@test D(sys.k(t)) in vs
392+
end
393+
end

0 commit comments

Comments
 (0)