@@ -724,7 +724,7 @@ Update the system equations, unknowns, and observables after simplification.
724
724
"""
725
725
function update_simplified_system! (
726
726
state:: TearingState , neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
727
- cse_hack = true , array_hack = true )
727
+ array_hack = true )
728
728
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state. structure
729
729
diff_to_var = invview (var_to_diff)
730
730
@@ -748,8 +748,7 @@ function update_simplified_system!(
748
748
unknowns = [unknowns; extra_unknowns]
749
749
@set! sys. unknowns = unknowns
750
750
751
- obs = cse_and_array_hacks (
752
- sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
751
+ obs = tearing_hacks (sys, obs, unknowns, neweqs; array = array_hack)
753
752
754
753
deps = Vector{Int}[i == 1 ? Int[] : collect (1 : (i - 1 ))
755
754
for i in 1 : length (solved_eqs)]
@@ -793,7 +792,7 @@ appear in the system. Algebraic variables are variables that are not
793
792
differential variables.
794
793
"""
795
794
function tearing_reassemble (state:: TearingState , var_eq_matching,
796
- full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
795
+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , array_hack = true )
797
796
extra_vars = Int[]
798
797
if full_var_eq_matching != = nothing
799
798
for v in 𝑑vertices (state. structure. graph)
@@ -829,68 +828,30 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
829
828
state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
830
829
831
830
sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
832
- extra_unknowns; cse_hack, array_hack)
831
+ extra_unknowns; array_hack)
833
832
834
833
@set! state. sys = sys
835
834
@set! sys. tearing_state = state
836
835
return invalidate_cache! (sys)
837
836
end
838
837
839
838
"""
840
- # HACK 1
841
-
842
- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
843
- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
844
- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
845
- avoid the unnecessary cost. This and the below hack are implemented simultaneously
846
-
847
- # HACK 2
839
+ # HACK
848
840
849
841
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
850
842
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
851
843
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
852
844
not) we first count the number of times the scalarized form of each observed variable
853
845
occurs in observed equations (and unknowns if it's split).
854
846
"""
855
- function cse_and_array_hacks (sys, obs, unknowns, neweqs; cse = true , array = true )
856
- # HACK 1
857
- # mapping of rhs to temporary CSE variable
858
- # `f(...) => tmpvar` in above example
859
- rhs_to_tempvar = Dict ()
860
-
861
- # HACK 2
847
+ function tearing_hacks (sys, obs, unknowns, neweqs; array = true )
862
848
# map of array observed variable (unscalarized) to number of its
863
849
# scalarized terms that appear in observed equations
864
850
arr_obs_occurrences = Dict ()
865
851
for (i, eq) in enumerate (obs)
866
852
lhs = eq. lhs
867
853
rhs = eq. rhs
868
854
869
- # HACK 1
870
- if cse && is_getindexed_array (rhs)
871
- rhs_arr = arguments (rhs)[1 ]
872
- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
873
- if ! haskey (rhs_to_tempvar, rhs_arr)
874
- tempvar = gensym (Symbol (lhs))
875
- N = length (rhs_arr)
876
- tempvar = unwrap (Symbolics. variable (
877
- tempvar; T = Symbolics. symtype (rhs_arr)))
878
- tempvar = setmetadata (
879
- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
880
- tempeq = tempvar ~ rhs_arr
881
- rhs_to_tempvar[rhs_arr] = tempvar
882
- push! (obs, tempeq)
883
- end
884
-
885
- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
886
- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
887
- # which fails the topological sort
888
- neweq = lhs ~ getindex_wrapper (
889
- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
890
- obs[i] = neweq
891
- end
892
- # end HACK 1
893
-
894
855
array || continue
895
856
iscall (lhs) || continue
896
857
operation (lhs) === getindex || continue
@@ -901,31 +862,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
901
862
continue
902
863
end
903
864
904
- # Also do CSE for `equations(sys)`
905
- if cse
906
- for (i, eq) in enumerate (neweqs)
907
- (; lhs, rhs) = eq
908
- is_getindexed_array (rhs) || continue
909
- rhs_arr = arguments (rhs)[1 ]
910
- if ! haskey (rhs_to_tempvar, rhs_arr)
911
- tempvar = gensym (Symbol (lhs))
912
- N = length (rhs_arr)
913
- tempvar = unwrap (Symbolics. variable (
914
- tempvar; T = Symbolics. symtype (rhs_arr)))
915
- tempvar = setmetadata (
916
- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
917
- tempeq = tempvar ~ rhs_arr
918
- rhs_to_tempvar[rhs_arr] = tempvar
919
- push! (obs, tempeq)
920
- end
921
- # don't need getindex_wrapper, but do it anyway to know that this
922
- # hack took place
923
- neweq = lhs ~ getindex_wrapper (
924
- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
925
- neweqs[i] = neweq
926
- end
927
- end
928
-
929
865
# count variables in unknowns if they are scalarized forms of variables
930
866
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
931
867
# is an observed equation.
@@ -960,18 +896,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
960
896
return obs
961
897
end
962
898
963
- function is_getindexed_array (rhs)
964
- (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
965
- iscall (rhs) && operation (rhs) === getindex &&
966
- Symbolics. shape (rhs) != Symbolics. Unknown ()
967
- end
968
-
969
- # PART OF HACK 1
970
- getindex_wrapper (x, i) = x[i... ]
971
-
972
- @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
973
-
974
- # PART OF HACK 2
899
+ # PART OF HACK
975
900
function change_origin (origin, arr)
976
901
if all (isone, Tuple (origin))
977
902
return arr
@@ -999,10 +924,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
999
924
instead, which calls this function internally.
1000
925
"""
1001
926
function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1002
- simplify = false , cse_hack = true , array_hack = true , kwargs... )
927
+ simplify = false , array_hack = true , kwargs... )
1003
928
var_eq_matching, full_var_eq_matching = tearing (state)
1004
929
invalidate_cache! (tearing_reassemble (
1005
- state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
930
+ state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
1006
931
end
1007
932
1008
933
"""
@@ -1024,7 +949,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1024
949
the system is balanced.
1025
950
"""
1026
951
function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1027
- mm = nothing , cse_hack = true , array_hack = true , kwargs... )
952
+ mm = nothing , array_hack = true , kwargs... )
1028
953
jac = let state = state
1029
954
(eqs, vars) -> begin
1030
955
symeqs = EquationsView (state)[eqs]
@@ -1048,5 +973,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1048
973
end
1049
974
var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
1050
975
kwargs... )
1051
- tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack)
976
+ tearing_reassemble (state, var_eq_matching; simplify, mm, array_hack)
1052
977
end
0 commit comments