11using ModelingToolkit
22using LinearAlgebra
33using DiffEqBase
4- import Base.==
5- import Base. unique, Base. unique!
6- using ModelingToolkit: < ₑ, value, isparameter
4+ import Base: unique, unique!, ==
5+ using ModelingToolkit: < ₑ, value, isparameter, operation, arguments, istree
76
87"""
98$(TYPEDEF)
@@ -17,7 +16,7 @@ or in place with `f(du, u, p, t)`.
1716If `linear_independent` is set to `true`, a linear independent basis is created from all atom function in `f`.
1817If `simplify_eqs` is set to `true`, `simplify` is called on `f`.
1918Additional keyworded arguments include `name`, which can be used to name the basis, `pins` used for connections and
20- `observed` for defining observeables.
19+ `observed` for defining observeables.
2120
2221# Fields
2322$(FIELDS)
@@ -65,19 +64,19 @@ mutable struct Basis <: ModelingToolkit.AbstractSystem
6564end
6665
6766function Basis (eqs:: AbstractVector , states:: AbstractVector ; parameters:: AbstractArray = [], iv = nothing ,
68-
67+
6968 simplify = false , linear_independent = false , name = gensym (:Basis ), eval_expression = false ,
7069 pins = [], observed = [],
7170 kwargs... )
72-
71+
7372 eqs = simplify ? ModelingToolkit. simplify .(eqs) : eqs
7473 eqs = linear_independent ? create_linear_independent_eqs (eqs) : eqs
7574 isnothing (iv) && (iv = Num (Variable (:t )))
7675 unique! (eqs, ! simplify)
77-
76+
7877 if eval_expression
7978 f_oop, f_iip = eval .(build_function (eqs, value .(states), value .(parameters), [value (iv)], expression = Val{true }))
80- else
79+ else
8180 f_oop, f_iip = build_function (eqs, value .(states), value .(parameters), [value (iv)], expression = Val{false })
8281 end
8382 eqs = [Variable (:φ ,i) ~ eq for (i,eq) ∈ enumerate (eqs)]
@@ -87,10 +86,10 @@ function Basis(eqs::AbstractVector, states::AbstractVector; parameters::Abstract
8786 return Basis (eqs, value .(states), value .(parameters), pins, observed, value (iv), f_, name, Basis[])
8887end
8988
90- function Basis (f:: Function , states:: AbstractVector ; parameters:: AbstractArray = [], iv = nothing , kwargs... )
91-
89+ function Basis (f:: Function , states:: AbstractVector ; parameters:: AbstractArray = [], iv = nothing , kwargs... )
90+
9291 isnothing (iv) && (iv = Num (Variable (:t )))
93-
92+
9493 try
9594 eqs = f (states, parameters, iv)
9695 return Basis (eqs, states, parameters = parameters, iv = iv; kwargs... )
@@ -99,9 +98,8 @@ function Basis(f::Function, states::AbstractVector; parameters::AbstractArray =
9998 end
10099end
101100
102- _get_name (x:: Num ) = _get_name (x. val)
103- _get_name (x) = x. name
104- _get_name (x:: Term ) = x. f. name
101+ _get_name (x:: Num ) = _get_name (value (x))
102+ _get_name (x) = nameof (istree (x) ? operation (x) : x)
105103
106104Base. show (io:: IO , x:: Basis ) = print (io, " $(String .(x. name)) : $(length (x. eqs)) dimensional basis in " , " $(String .([_get_name (v) for v in x. states])) " )
107105
@@ -115,14 +113,14 @@ Base.show(io::IO, x::Basis) = print(io, "$(String.(x.name)) : $(length(x.eqs)) d
115113 println (io, " $(eq. lhs) = $(eq. rhs) " )
116114 elseif i == 5
117115 println (io, " ..." )
118- else
116+ else
119117 continue
120118 end
121119 end
122120end
123121
124122@inline function Base. println (io:: IO , x:: Basis , fullview:: DataType = Val{false })
125- fullview == Val{false } && return print (io, x)
123+ fullview == Val{false } && return print (io, x)
126124 show (io, x)
127125 ! isempty (x. ps) && println (io, " \n Parameters : $(x. ps) " )
128126 println (io, " \n Independent variable: $(x. iv) " )
@@ -349,7 +347,7 @@ function dynamics(b::Basis)
349347 return b. f_
350348end
351349
352- # # Term Manipulation
350+ # # Symbolic Manipulation
353351
354352function is_unary (f:: Function )
355353 for m in methods (f)
@@ -362,19 +360,19 @@ count_operation(x::Number, op::Function, nested::Bool = true) = 0
362360count_operation (x:: Sym , op:: Function , nested:: Bool = true ) = 0
363361count_operation (x:: Num , op:: Function , nested:: Bool = true ) = count_operation (value (x), op, nested)
364362
365- function count_operation (x:: Term , op:: Function , nested:: Bool = true )
366- if x . f == op
363+ function count_operation (x, op:: Function , nested:: Bool = true )
364+ if operation (x) == op
367365 if is_unary (op)
368366 # Handles sin, cos and stuff
369- nested && return 1 + count_operation (x . arguments, op)
367+ nested && return 1 + count_operation (arguments (x) , op)
370368 return 1
371369 else
372370 # Handles +, *
373- nested && length (x . arguments) - 1 + count_operation (x . arguments, op)
374- return length (x . arguments) - 1
371+ nested && length (arguments (x)) - 1 + count_operation (arguments (x) , op)
372+ return length (arguments (x)) - 1
375373 end
376374 elseif nested
377- return count_operation (x . arguments, op, nested)
375+ return count_operation (arguments (x) , op, nested)
378376 end
379377 return 0
380378end
@@ -395,40 +393,35 @@ function count_operation(x::AbstractArray, ops::AbstractArray, nested::Bool = tr
395393 counter
396394end
397395
398- function split_term! (x:: AbstractArray , o:: Term , ops:: AbstractArray = [+ ])
399- n_ops = count_operation (o, ops, false )
400- c_ops = 0
401- @views begin
402- if n_ops == 0
403- x[begin ]= o
404- else
405- counter_ = 1
406- for oi in o. arguments
407- c_ops = count_operation (oi, ops, false )
408- split_term! (x[counter_: counter_+ c_ops], oi, ops)
409- counter_ += c_ops + 1
396+ function split_term! (x:: AbstractArray , o, ops:: AbstractArray = [+ ])
397+ if istree (o)
398+ n_ops = count_operation (o, ops, false )
399+ c_ops = 0
400+ @views begin
401+ if n_ops == 0
402+ x[begin ]= o
403+ else
404+ counter_ = 1
405+ for oi in arguments (o)
406+ c_ops = count_operation (oi, ops, false )
407+ split_term! (x[counter_: counter_+ c_ops], oi, ops)
408+ counter_ += c_ops + 1
409+ end
410410 end
411411 end
412+ else
413+ x[begin ] = o
412414 end
413- end
414-
415- split_term! (x:: AbstractArray ,o:: Num , ops:: AbstractArray = [+ ]) = split_term! (x, value (o), ops)
416-
417- function split_term! (x:: AbstractArray , o:: Sym , ops:: AbstractArray = [+ ])
418- x[begin ] = o
419415 return
420416end
421417
422- function split_term! (x:: AbstractArray , o:: Number , ops:: AbstractArray = [+ ])
423- x[begin ] = o
424- return
425- end
418+ split_term! (x:: AbstractArray ,o:: Num , ops:: AbstractArray = [+ ]) = split_term! (x, value (o), ops)
426419
427420remove_constant_factor (x:: Num ) = remove_constant_factor (value (x))
428- remove_constant_factor (x:: Sym ) = x
429421remove_constant_factor (x:: Number ) = one (x)
430422
431- function remove_constant_factor (x:: Term )
423+ function remove_constant_factor (x)
424+ istree (x) || return x
432425 n_ops = count_operation (x, * , false )+ 1
433426 ops = Array {Any} (undef, n_ops)
434427 @views split_term! (ops, x, [* ])
0 commit comments