@@ -34,13 +34,15 @@ is_independent(o::Operation) = isempty(o.args)
3434
3535
3636"""
37- Basis(f, u; p, iv, eval_expression)
37+ Basis(f, u; p, iv, linear_independent = false, simplify_eqs = true, eval_expression = false )
3838
3939A basis over the variables `u` with parameters `p` and independent variable `iv`.
4040`f` can either be a Julia function which is able to use ModelingToolkit variables or
4141a vector of `Operation`.
4242It can be called with the typical DiffEq signature, meaning out of place with `f(u,p,t)`
4343or in place with `f(du, u, p, t)`.
44+ If `linear_independent` is set to `true`, a linear independent basis is created from all atom function in `f`.
45+ If `simplify_eqs` is set to `true`, `simplify` is called on `f`.
4446
4547# Example
4648
@@ -66,10 +68,13 @@ on sufficiently large basis functions. By default eval_expression=false.
6668
6769"""
6870function Basis (basis:: AbstractArray{Operation} , variables:: AbstractArray{Operation} ;
69- parameters:: AbstractArray = Operation[], iv = nothing , eval_expression = false )
71+ parameters:: AbstractArray = Operation[], iv = nothing , linear_independent :: Bool = false , simplify_eqs = true , eval_expression = false )
7072 @assert all (is_independent .(variables)) " Please provide independent states."
7173
72- bs = unique (basis)
74+ bs = deepcopy (basis)
75+ simplify_eqs && (bs = simplify .(bs))
76+ linear_independent && (bs = create_linear_independent_eqs (bs))
77+ unique! (bs)
7378
7479 if isnothing (iv)
7580 @parameters t
@@ -92,7 +97,7 @@ function Basis(basis::AbstractArray{Operation}, variables::AbstractArray{Operati
9297end
9398
9499
95- function Basis (basis:: Function , variables:: AbstractArray{Operation} ; parameters:: AbstractArray = Operation[], iv = nothing )
100+ function Basis (basis:: Function , variables:: AbstractArray{Operation} ; parameters:: AbstractArray = Operation[], iv = nothing , kwargs ... )
96101 @assert all (is_independent .(variables)) " Please provide independent variables for basis."
97102
98103 if isnothing (iv)
@@ -102,7 +107,7 @@ function Basis(basis::Function, variables::AbstractArray{Operation}; parameters
102107
103108 try
104109 eqs = basis (variables, parameters, iv)
105- return Basis (eqs, variables, parameters = parameters, iv = iv)
110+ return Basis (eqs, variables, parameters = parameters, iv = iv, kwargs ... )
106111 catch e
107112 rethrow (e)
108113 end
@@ -202,18 +207,97 @@ function (==)(x::Basis, y::Basis)
202207 return all (n)
203208end
204209
205- function count_operation (o :: Expression , ops :: AbstractArray )
206- if isa (o, ModelingToolkit . Constant )
207- return 0
210+ function is_unary (f :: Function )
211+ for m in methods (f )
212+ m . nargs - 1 > 1 && return false
208213 end
209- k = o. op ∈ ops ? 1 : 0
210- if ! isempty (o. args)
211- k += sum ([count_operation (ai, ops) for ai in o. args])
214+ return true
215+ end
216+
217+ function count_operation (x:: T , op:: Function , nested:: Bool = true ) where T <: Expression
218+ isa (x, ModelingToolkit. Constant) && return 0
219+ isa (x. op, Expression) && return 0
220+ if x. op == op
221+ if is_unary (op)
222+ # Handles sin, cos and stuff
223+ nested && return 1 + count_operation (x. args, op)
224+ return 1
225+ else
226+ # Handles +, *
227+ nested && length (x. args)- 1 + count_operation (x. args, op)
228+ return length (x. args)- 1
229+ end
230+ elseif nested
231+ return count_operation (x. args, op, nested)
232+ end
233+ return 0
234+ end
235+
236+ function count_operation (x:: T , ops:: AbstractArray ,nested:: Bool = true ) where T<: Expression
237+ c_ops = 0
238+ for oi in ops
239+ c_ops += count_operation (x, oi, nested)
240+ end
241+ return c_ops
242+ end
243+
244+ function count_operation (x:: AbstractVector{T} , op, nested:: Bool = true ) where T <: Expression
245+ c_ops = 0
246+ for xi in x
247+ c_ops += count_operation (xi, op, nested)
248+ end
249+ return c_ops
250+ end
251+
252+ function remove_constant_factor (o:: T ) where T <: Expression
253+ isa (o, ModelingToolkit. Constant) && return ModelingToolkit. Constant (1 )
254+ n_ops = count_operation (o, * , false ) + 1
255+ ops = Array {Expression} (undef, n_ops)
256+ @views split_operation! (ops, o, [* ])
257+ filter! (x-> ! isa (x, ModelingToolkit. Constant), ops)
258+ return prod (ops)
259+ end
260+
261+ function remove_constant_factor! (o:: AbstractArray{T} ) where T <: Expression
262+ for i in eachindex (o)
263+ o[i] = remove_constant_factor (o[i])
264+ end
265+ end
266+
267+ function split_operation! (k:: AbstractVector{T} , o:: Expression , ops:: AbstractArray = [+ ]) where T <: Expression
268+ n_ops = count_operation (o, ops, false )
269+ c_ops = 0
270+ @views begin
271+ if n_ops == 0
272+ k .= o
273+ else
274+ counter_ = 1
275+ for oi in o. args
276+ c_ops = count_operation (oi, ops, false )
277+ split_operation! (k[counter_: counter_+ c_ops], oi, ops)
278+ counter_ += c_ops + 1
279+ end
280+ end
281+ end
282+ end
283+
284+ function create_linear_independent_eqs (o:: AbstractVector{T} ) where T <: Expression
285+ unique! (o)
286+ n_ops = [count_operation (bi, + , false ) for bi in o]
287+ n_x = sum (n_ops) + length (o)
288+ u_o = Array {T} (undef, n_x)
289+ ind_lo, ind_up = 0 , 0
290+ for i in eachindex (o)
291+ ind_lo = i > 1 ? sum (n_ops[1 : i- 1 ]) + i : 1
292+ ind_up = sum (n_ops[1 : i]) + i
293+ @views split_operation! (u_o[ind_lo: ind_up], o[i], [+ ])
212294 end
213- return k
295+ remove_constant_factor! (u_o)
296+ unique! (u_o)
297+ return u_o
214298end
215299
216- free_parameters (b:: Basis ; operations = [+ ]) = sum ([ count_operation (bi, operations) for bi in b. basis] ) + length (b. basis)
300+ free_parameters (b:: Basis ; operations = [+ ]) = count_operation (b. basis, operations ) + length (b. basis)
217301
218302(b:: Basis )(u, p:: DiffEqBase.NullParameters , t) = b (u, [], t)
219303(b:: Basis )(du, u, p:: DiffEqBase.NullParameters , t) = b (du, u, [], t)
0 commit comments