Skip to content

Commit 5d3cb26

Browse files
Merge pull request #3696 from AayushSabharwal/as/v9-array-initprobpmap
[v9] fix: remove CSE hack, fix unscalarized variables in initializeprobpmap
2 parents bc87875 + 74c38c5 commit 5d3cb26

File tree

5 files changed

+68
-168
lines changed

5 files changed

+68
-168
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 11 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ Update the system equations, unknowns, and observables after simplification.
724724
"""
725725
function update_simplified_system!(
726726
state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
727-
cse_hack = true, array_hack = true)
727+
array_hack = true)
728728
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
729729
diff_to_var = invview(var_to_diff)
730730

@@ -748,8 +748,7 @@ function update_simplified_system!(
748748
unknowns = [unknowns; extra_unknowns]
749749
@set! sys.unknowns = unknowns
750750

751-
obs = cse_and_array_hacks(
752-
sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
751+
obs = tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack)
753752

754753
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
755754
for i in 1:length(solved_eqs)]
@@ -793,7 +792,7 @@ appear in the system. Algebraic variables are variables that are not
793792
differential variables.
794793
"""
795794
function tearing_reassemble(state::TearingState, var_eq_matching,
796-
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
795+
full_var_eq_matching = nothing; simplify = false, mm = nothing, array_hack = true)
797796
extra_vars = Int[]
798797
if full_var_eq_matching !== nothing
799798
for v in 𝑑vertices(state.structure.graph)
@@ -829,68 +828,30 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
829828
state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
830829

831830
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
832-
extra_unknowns; cse_hack, array_hack)
831+
extra_unknowns; array_hack)
833832

834833
@set! state.sys = sys
835834
@set! sys.tearing_state = state
836835
return invalidate_cache!(sys)
837836
end
838837

839838
"""
840-
# HACK 1
841-
842-
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
843-
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
844-
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
845-
avoid the unnecessary cost. This and the below hack are implemented simultaneously
846-
847-
# HACK 2
839+
# HACK
848840
849841
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
850842
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
851843
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
852844
not) we first count the number of times the scalarized form of each observed variable
853845
occurs in observed equations (and unknowns if it's split).
854846
"""
855-
function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = true)
856-
# HACK 1
857-
# mapping of rhs to temporary CSE variable
858-
# `f(...) => tmpvar` in above example
859-
rhs_to_tempvar = Dict()
860-
861-
# HACK 2
847+
function tearing_hacks(sys, obs, unknowns, neweqs; array = true)
862848
# map of array observed variable (unscalarized) to number of its
863849
# scalarized terms that appear in observed equations
864850
arr_obs_occurrences = Dict()
865851
for (i, eq) in enumerate(obs)
866852
lhs = eq.lhs
867853
rhs = eq.rhs
868854

869-
# HACK 1
870-
if cse && is_getindexed_array(rhs)
871-
rhs_arr = arguments(rhs)[1]
872-
iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue
873-
if !haskey(rhs_to_tempvar, rhs_arr)
874-
tempvar = gensym(Symbol(lhs))
875-
N = length(rhs_arr)
876-
tempvar = unwrap(Symbolics.variable(
877-
tempvar; T = Symbolics.symtype(rhs_arr)))
878-
tempvar = setmetadata(
879-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
880-
tempeq = tempvar ~ rhs_arr
881-
rhs_to_tempvar[rhs_arr] = tempvar
882-
push!(obs, tempeq)
883-
end
884-
885-
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
886-
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
887-
# which fails the topological sort
888-
neweq = lhs ~ getindex_wrapper(
889-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
890-
obs[i] = neweq
891-
end
892-
# end HACK 1
893-
894855
array || continue
895856
iscall(lhs) || continue
896857
operation(lhs) === getindex || continue
@@ -901,31 +862,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
901862
continue
902863
end
903864

904-
# Also do CSE for `equations(sys)`
905-
if cse
906-
for (i, eq) in enumerate(neweqs)
907-
(; lhs, rhs) = eq
908-
is_getindexed_array(rhs) || continue
909-
rhs_arr = arguments(rhs)[1]
910-
if !haskey(rhs_to_tempvar, rhs_arr)
911-
tempvar = gensym(Symbol(lhs))
912-
N = length(rhs_arr)
913-
tempvar = unwrap(Symbolics.variable(
914-
tempvar; T = Symbolics.symtype(rhs_arr)))
915-
tempvar = setmetadata(
916-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
917-
tempeq = tempvar ~ rhs_arr
918-
rhs_to_tempvar[rhs_arr] = tempvar
919-
push!(obs, tempeq)
920-
end
921-
# don't need getindex_wrapper, but do it anyway to know that this
922-
# hack took place
923-
neweq = lhs ~ getindex_wrapper(
924-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
925-
neweqs[i] = neweq
926-
end
927-
end
928-
929865
# count variables in unknowns if they are scalarized forms of variables
930866
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
931867
# is an observed equation.
@@ -960,18 +896,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
960896
return obs
961897
end
962898

963-
function is_getindexed_array(rhs)
964-
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
965-
iscall(rhs) && operation(rhs) === getindex &&
966-
Symbolics.shape(rhs) != Symbolics.Unknown()
967-
end
968-
969-
# PART OF HACK 1
970-
getindex_wrapper(x, i) = x[i...]
971-
972-
@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})
973-
974-
# PART OF HACK 2
899+
# PART OF HACK
975900
function change_origin(origin, arr)
976901
if all(isone, Tuple(origin))
977902
return arr
@@ -999,10 +924,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
999924
instead, which calls this function internally.
1000925
"""
1001926
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
1002-
simplify = false, cse_hack = true, array_hack = true, kwargs...)
927+
simplify = false, array_hack = true, kwargs...)
1003928
var_eq_matching, full_var_eq_matching = tearing(state)
1004929
invalidate_cache!(tearing_reassemble(
1005-
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
930+
state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
1006931
end
1007932

1008933
"""
@@ -1024,7 +949,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1024949
the system is balanced.
1025950
"""
1026951
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1027-
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
952+
mm = nothing, array_hack = true, kwargs...)
1028953
jac = let state = state
1029954
(eqs, vars) -> begin
1030955
symeqs = EquationsView(state)[eqs]
@@ -1048,5 +973,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1048973
end
1049974
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
1050975
kwargs...)
1051-
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
976+
tearing_reassemble(state, var_eq_matching; simplify, mm, array_hack)
1052977
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -739,20 +739,6 @@ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
739739
push!(rm_idxs, i)
740740
continue
741741
end
742-
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
743-
var, idxs = arguments(eq.rhs)
744-
subs[eq.rhs] = var[idxs...]
745-
push!(tempvars, var)
746-
end
747-
end
748-
749-
for (i, eq) in enumerate(eqs)
750-
iscall(eq.rhs) || continue
751-
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
752-
var, idxs = arguments(eq.rhs)
753-
subs[eq.rhs] = var[idxs...]
754-
push!(tempvars, var)
755-
end
756742
end
757743

758744
for (i, eq) in enumerate(obseqs)

test/code_generation.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,34 @@ end
7878
@test SciMLBase.successful_retcode(sol)
7979
end
8080
end
81+
82+
@testset "scalarized array observed calling same function multiple times" begin
83+
@variables x(t) y(t)[1:2]
84+
@parameters foo(::Real)[1:2]
85+
val = Ref(0)
86+
function _tmp_fn2(x)
87+
val[] += 1
88+
return [x, 2x]
89+
end
90+
@mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
91+
@test length(equations(sys)) == 1
92+
@test length(observed(sys)) == 3
93+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
94+
val[] = 0
95+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
96+
@test val[] == 1
97+
98+
@testset "CSE in equations(sys)" begin
99+
val[] = 0
100+
@variables z(t)[1:2]
101+
@mtkbuild sys = ODESystem(
102+
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
103+
@test length(equations(sys)) == 5
104+
@test length(observed(sys)) == 0
105+
prob = ODEProblem(
106+
sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2])
107+
val[] = 0
108+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
109+
@test val[] == 2
110+
end
111+
end

test/initializationsystem.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,3 +1650,24 @@ end
16501650
@test !SciMLBase.isinplace(prob)
16511651
@test !SciMLBase.isinplace(prob.f.initialization_data.initializeprob)
16521652
end
1653+
1654+
@testset "Array unknowns occurring unscalarized in initializeprobpmap" begin
1655+
@variables begin
1656+
u(t)[1:2] = 0.9ones(2)
1657+
x(t)[1:2], [guess = 0.01ones(2)]
1658+
o(t)[1:2]
1659+
end
1660+
@parameters p[1:4] = [2.0, 1.875, 2.0, 1.875]
1661+
1662+
eqs = [D(u[1]) ~ p[1] * u[1] - p[2] * u[1] * u[2] + x[1] + 0.1
1663+
D(u[2]) ~ p[4] * u[1] * u[2] - p[3] * u[2] - x[2]
1664+
o[1] ~ sum(p) * sum(u)
1665+
o[2] ~ sum(p) * sum(x)
1666+
x[1] ~ 0.01exp(-1)
1667+
x[2] ~ 0.01cos(t)]
1668+
1669+
@mtkbuild sys = ODESystem(eqs, t)
1670+
prob = ODEProblem(sys, [], (0.0, 1.0))
1671+
sol = solve(prob, Tsit5())
1672+
@test SciMLBase.successful_retcode(sol)
1673+
end

test/structural_transformation/utils.jl

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ end
5151
@mtkbuild sys = ODESystem(
5252
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
5353
@test length(equations(sys)) == 1
54-
@test length(observed(sys)) == 7
54+
@test length(observed(sys)) == 6
5555
@test any(obs -> isequal(obs, y), observables(sys))
5656
@test any(obs -> isequal(obs, z), observables(sys))
5757
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
@@ -61,76 +61,20 @@ end
6161
@test length(unknowns(isys)) == 5
6262
@test length(equations(isys)) == 4
6363
@test !any(equations(isys)) do eq
64-
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
65-
StructuralTransformations.change_origin]
64+
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.change_origin]
6665
end
6766
end
6867

69-
@testset "scalarized array observed calling same function multiple times" begin
70-
@variables x(t) y(t)[1:2]
71-
@parameters foo(::Real)[1:2]
72-
val = Ref(0)
73-
function _tmp_fn2(x)
74-
val[] += 1
75-
return [x, 2x]
76-
end
77-
@mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
78-
@test length(equations(sys)) == 1
79-
@test length(observed(sys)) == 4
80-
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
81-
val[] = 0
82-
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
83-
@test val[] == 1
84-
85-
isys = ModelingToolkit.generate_initializesystem(sys)
86-
@test length(unknowns(isys)) == 3
87-
@test length(equations(isys)) == 2
88-
@test !any(equations(isys)) do eq
89-
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
90-
StructuralTransformations.change_origin]
91-
end
92-
93-
@testset "CSE hack in equations(sys)" begin
94-
val[] = 0
95-
@variables z(t)[1:2]
96-
@mtkbuild sys = ODESystem(
97-
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
98-
@test length(equations(sys)) == 5
99-
@test length(observed(sys)) == 2
100-
prob = ODEProblem(
101-
sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2])
102-
val[] = 0
103-
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
104-
@test val[] == 2
105-
106-
isys = ModelingToolkit.generate_initializesystem(sys)
107-
@test length(unknowns(isys)) == 5
108-
@test length(equations(isys)) == 2
109-
@test !any(equations(isys)) do eq
110-
iscall(eq.rhs) &&
111-
operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
112-
StructuralTransformations.change_origin]
113-
end
114-
end
115-
end
116-
117-
@testset "array and cse hacks can be disabled" begin
68+
@testset "array hack can be disabled" begin
11869
@testset "fully_determined = true" begin
11970
@variables x(t) y(t)[1:2] z(t)[1:2]
12071
@parameters foo(::AbstractVector)[1:2]
12172
_tmp_fn(x) = 2x
12273
@named sys = ODESystem(
12374
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
12475

125-
sys1 = structural_simplify(sys; cse_hack = false)
126-
@test length(observed(sys1)) == 6
127-
@test !any(observed(sys1)) do eq
128-
iscall(eq.rhs) &&
129-
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
130-
end
131-
13276
sys2 = structural_simplify(sys; array_hack = false)
133-
@test length(observed(sys2)) == 5
77+
@test length(observed(sys2)) == 4
13478
@test !any(observed(sys2)) do eq
13579
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
13680
end
@@ -143,15 +87,8 @@ end
14387
@named sys = ODESystem(
14488
[D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
14589

146-
sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false)
147-
@test length(observed(sys1)) == 6
148-
@test !any(observed(sys1)) do eq
149-
iscall(eq.rhs) &&
150-
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
151-
end
152-
15390
sys2 = structural_simplify(sys; array_hack = false, fully_determined = false)
154-
@test length(observed(sys2)) == 5
91+
@test length(observed(sys2)) == 4
15592
@test !any(observed(sys2)) do eq
15693
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
15794
end

0 commit comments

Comments
 (0)