Skip to content

Commit ffd2584

Browse files
authored
Merge pull request #161 from SciML/simplify_basis_routines
Adapt count operations and introduce linear independent basis
2 parents 5da00ed + ce9cda9 commit ffd2584

File tree

2 files changed

+101
-14
lines changed

2 files changed

+101
-14
lines changed

src/basis.jl

Lines changed: 97 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ is_independent(o::Operation) = isempty(o.args)
3434

3535

3636
"""
37-
Basis(f, u; p, iv, eval_expression)
37+
Basis(f, u; p, iv, linear_independent = false, simplify_eqs = true, eval_expression = false)
3838
3939
A basis over the variables `u` with parameters `p` and independent variable `iv`.
4040
`f` can either be a Julia function which is able to use ModelingToolkit variables or
4141
a vector of `Operation`.
4242
It can be called with the typical DiffEq signature, meaning out of place with `f(u,p,t)`
4343
or in place with `f(du, u, p, t)`.
44+
If `linear_independent` is set to `true`, a linear independent basis is created from all atom function in `f`.
45+
If `simplify_eqs` is set to `true`, `simplify` is called on `f`.
4446
4547
# Example
4648
@@ -66,10 +68,13 @@ on sufficiently large basis functions. By default eval_expression=false.
6668
6769
"""
6870
function Basis(basis::AbstractArray{Operation}, variables::AbstractArray{Operation};
69-
parameters::AbstractArray = Operation[], iv = nothing, eval_expression = false)
71+
parameters::AbstractArray = Operation[], iv = nothing, linear_independent::Bool = false, simplify_eqs = true, eval_expression = false)
7072
@assert all(is_independent.(variables)) "Please provide independent states."
7173

72-
bs = unique(basis)
74+
bs = deepcopy(basis)
75+
simplify_eqs && (bs = simplify.(bs))
76+
linear_independent && (bs = create_linear_independent_eqs(bs))
77+
unique!(bs)
7378

7479
if isnothing(iv)
7580
@parameters t
@@ -92,7 +97,7 @@ function Basis(basis::AbstractArray{Operation}, variables::AbstractArray{Operati
9297
end
9398

9499

95-
function Basis(basis::Function, variables::AbstractArray{Operation}; parameters::AbstractArray = Operation[], iv = nothing)
100+
function Basis(basis::Function, variables::AbstractArray{Operation}; parameters::AbstractArray = Operation[], iv = nothing, kwargs...)
96101
@assert all(is_independent.(variables)) "Please provide independent variables for basis."
97102

98103
if isnothing(iv)
@@ -102,7 +107,7 @@ function Basis(basis::Function, variables::AbstractArray{Operation}; parameters
102107

103108
try
104109
eqs = basis(variables, parameters, iv)
105-
return Basis(eqs, variables, parameters = parameters, iv = iv)
110+
return Basis(eqs, variables, parameters = parameters, iv = iv, kwargs...)
106111
catch e
107112
rethrow(e)
108113
end
@@ -202,18 +207,97 @@ function (==)(x::Basis, y::Basis)
202207
return all(n)
203208
end
204209

205-
function count_operation(o::Expression, ops::AbstractArray)
206-
if isa(o, ModelingToolkit.Constant)
207-
return 0
210+
function is_unary(f::Function)
211+
for m in methods(f)
212+
m.nargs - 1 > 1 && return false
208213
end
209-
k = o.op ops ? 1 : 0
210-
if !isempty(o.args)
211-
k += sum([count_operation(ai, ops) for ai in o.args])
214+
return true
215+
end
216+
217+
function count_operation(x::T, op::Function, nested::Bool = true) where T <: Expression
218+
isa(x, ModelingToolkit.Constant) && return 0
219+
isa(x.op, Expression) && return 0
220+
if x.op == op
221+
if is_unary(op)
222+
# Handles sin, cos and stuff
223+
nested && return 1 + count_operation(x.args, op)
224+
return 1
225+
else
226+
# Handles +, *
227+
nested && length(x.args)-1 + count_operation(x.args, op)
228+
return length(x.args)-1
229+
end
230+
elseif nested
231+
return count_operation(x.args, op, nested)
232+
end
233+
return 0
234+
end
235+
236+
function count_operation(x::T, ops::AbstractArray,nested::Bool = true) where T<:Expression
237+
c_ops = 0
238+
for oi in ops
239+
c_ops += count_operation(x, oi, nested)
240+
end
241+
return c_ops
242+
end
243+
244+
function count_operation(x::AbstractVector{T}, op, nested::Bool = true) where T <: Expression
245+
c_ops = 0
246+
for xi in x
247+
c_ops += count_operation(xi, op, nested)
248+
end
249+
return c_ops
250+
end
251+
252+
function remove_constant_factor(o::T) where T <: Expression
253+
isa(o, ModelingToolkit.Constant) && return ModelingToolkit.Constant(1)
254+
n_ops = count_operation(o, *, false) +1
255+
ops = Array{Expression}(undef, n_ops)
256+
@views split_operation!(ops, o, [*])
257+
filter!(x->!isa(x, ModelingToolkit.Constant), ops)
258+
return prod(ops)
259+
end
260+
261+
function remove_constant_factor!(o::AbstractArray{T}) where T <: Expression
262+
for i in eachindex(o)
263+
o[i] = remove_constant_factor(o[i])
264+
end
265+
end
266+
267+
function split_operation!(k::AbstractVector{T}, o::Expression, ops::AbstractArray = [+]) where T <: Expression
268+
n_ops = count_operation(o, ops, false)
269+
c_ops = 0
270+
@views begin
271+
if n_ops == 0
272+
k .= o
273+
else
274+
counter_ = 1
275+
for oi in o.args
276+
c_ops = count_operation(oi, ops, false)
277+
split_operation!(k[counter_:counter_+c_ops], oi, ops)
278+
counter_ += c_ops + 1
279+
end
280+
end
281+
end
282+
end
283+
284+
function create_linear_independent_eqs(o::AbstractVector{T}) where T <: Expression
285+
unique!(o)
286+
n_ops = [count_operation(bi, +, false) for bi in o]
287+
n_x = sum(n_ops) + length(o)
288+
u_o = Array{T}(undef, n_x)
289+
ind_lo, ind_up = 0, 0
290+
for i in eachindex(o)
291+
ind_lo = i > 1 ? sum(n_ops[1:i-1]) + i : 1
292+
ind_up = sum(n_ops[1:i]) + i
293+
@views split_operation!(u_o[ind_lo:ind_up], o[i], [+])
212294
end
213-
return k
295+
remove_constant_factor!(u_o)
296+
unique!(u_o)
297+
return u_o
214298
end
215299

216-
free_parameters(b::Basis; operations = [+]) = sum([count_operation(bi, operations) for bi in b.basis]) + length(b.basis)
300+
free_parameters(b::Basis; operations = [+]) = count_operation(b.basis, operations) + length(b.basis)
217301

218302
(b::Basis)(u, p::DiffEqBase.NullParameters, t) = b(u, [], t)
219303
(b::Basis)(du, u, p::DiffEqBase.NullParameters, t) = b(du, u, [], t)

test/basis.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
h = [u[1]; u[2]; cos(w[1]*u[2]+w[2]*u[3]); u[3]+u[2]]
66
h_not_unique = [u[1]; u[1]; u[1]^1; h]
77
basis = Basis(h_not_unique, u, parameters = w, iv = t)
8-
8+
basis_2 = Basis(h_not_unique, u, parameters = w, iv = t, linear_independent = true)
99
@test isequal(variables(basis), u)
1010
@test isequal(parameters(basis), w)
1111
@test isequal(independent_variable(basis), t)
12+
1213
@test free_parameters(basis) == 6
14+
@test free_parameters(basis_2) == 5
1315
@test free_parameters(basis, operations = [+, cos]) == 7
16+
@test free_parameters(basis_2, operations = [+, cos]) == 6
1417
@test DataDrivenDiffEq.count_operation((ModelingToolkit.Constant(1) + cos(u[2])*sin(u[1]))^3, [+, cos, ^, *]) == 4
1518

1619
basis_2 = unique(basis)

0 commit comments

Comments
 (0)