Skip to content

Commit 7cbc770

Browse files
authored
Merge pull request #187 from SciML/myb/mtk
MTK 5 update
2 parents b9c8474 + 25c869f commit 7cbc770

File tree

2 files changed

+41
-48
lines changed

2 files changed

+41
-48
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ DataInterpolations = "3.1"
2525
DiffEqBase = "6.45"
2626
DocStringExtensions = "0.7, 0.8"
2727
FiniteDifferences = "0.11"
28-
ModelingToolkit = "4.0.8"
28+
ModelingToolkit = "4.0.8, 5.0"
2929
ProximalOperators = "0.11, 0.12, 0.13"
3030
QuadGK = "2.4"
3131
StatsBase = "0.32.0, 0.33"

src/basis.jl

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
using ModelingToolkit
22
using LinearAlgebra
33
using DiffEqBase
4-
import Base.==
5-
import Base.unique, Base.unique!
6-
using ModelingToolkit: <ₑ, value, isparameter
4+
import Base: unique, unique!, ==
5+
using ModelingToolkit: <ₑ, value, isparameter, operation, arguments, istree
76

87
"""
98
$(TYPEDEF)
@@ -17,7 +16,7 @@ or in place with `f(du, u, p, t)`.
1716
If `linear_independent` is set to `true`, a linear independent basis is created from all atom function in `f`.
1817
If `simplify_eqs` is set to `true`, `simplify` is called on `f`.
1918
Additional keyworded arguments include `name`, which can be used to name the basis, `pins` used for connections and
20-
`observed` for defining observeables.
19+
`observed` for defining observeables.
2120
2221
# Fields
2322
$(FIELDS)
@@ -65,19 +64,19 @@ mutable struct Basis <: ModelingToolkit.AbstractSystem
6564
end
6665

6766
function Basis(eqs::AbstractVector, states::AbstractVector; parameters::AbstractArray = [], iv = nothing,
68-
67+
6968
simplify = false, linear_independent = false, name = gensym(:Basis), eval_expression = false,
7069
pins = [], observed = [],
7170
kwargs...)
72-
71+
7372
eqs = simplify ? ModelingToolkit.simplify.(eqs) : eqs
7473
eqs = linear_independent ? create_linear_independent_eqs(eqs) : eqs
7574
isnothing(iv) && (iv = Num(Variable(:t)))
7675
unique!(eqs, !simplify)
77-
76+
7877
if eval_expression
7978
f_oop, f_iip = eval.(build_function(eqs, value.(states), value.(parameters), [value(iv)], expression = Val{true}))
80-
else
79+
else
8180
f_oop, f_iip = build_function(eqs, value.(states), value.(parameters), [value(iv)], expression = Val{false})
8281
end
8382
eqs = [Variable(,i) ~ eq for (i,eq) enumerate(eqs)]
@@ -87,10 +86,10 @@ function Basis(eqs::AbstractVector, states::AbstractVector; parameters::Abstract
8786
return Basis(eqs, value.(states), value.(parameters), pins, observed, value(iv), f_, name, Basis[])
8887
end
8988

90-
function Basis(f::Function, states::AbstractVector; parameters::AbstractArray = [], iv = nothing, kwargs...)
91-
89+
function Basis(f::Function, states::AbstractVector; parameters::AbstractArray = [], iv = nothing, kwargs...)
90+
9291
isnothing(iv) && (iv = Num(Variable(:t)))
93-
92+
9493
try
9594
eqs = f(states, parameters, iv)
9695
return Basis(eqs, states, parameters = parameters, iv = iv; kwargs...)
@@ -99,9 +98,8 @@ function Basis(f::Function, states::AbstractVector; parameters::AbstractArray =
9998
end
10099
end
101100

102-
_get_name(x::Num) = _get_name(x.val)
103-
_get_name(x) = x.name
104-
_get_name(x::Term) = x.f.name
101+
_get_name(x::Num) = _get_name(value(x))
102+
_get_name(x) = nameof(istree(x) ? operation(x) : x)
105103

106104
Base.show(io::IO, x::Basis) = print(io, "$(String.(x.name)) : $(length(x.eqs)) dimensional basis in ", "$(String.([_get_name(v) for v in x.states]))")
107105

@@ -115,14 +113,14 @@ Base.show(io::IO, x::Basis) = print(io, "$(String.(x.name)) : $(length(x.eqs)) d
115113
println(io, "$(eq.lhs) = $(eq.rhs)")
116114
elseif i == 5
117115
println(io, "...")
118-
else
116+
else
119117
continue
120118
end
121119
end
122120
end
123121

124122
@inline function Base.println(io::IO, x::Basis, fullview::DataType = Val{false})
125-
fullview == Val{false} && return print(io, x)
123+
fullview == Val{false} && return print(io, x)
126124
show(io, x)
127125
!isempty(x.ps) && println(io, "\nParameters : $(x.ps)")
128126
println(io, "\nIndependent variable: $(x.iv)")
@@ -349,7 +347,7 @@ function dynamics(b::Basis)
349347
return b.f_
350348
end
351349

352-
## Term Manipulation
350+
## Symbolic Manipulation
353351

354352
function is_unary(f::Function)
355353
for m in methods(f)
@@ -362,19 +360,19 @@ count_operation(x::Number, op::Function, nested::Bool = true) = 0
362360
count_operation(x::Sym, op::Function, nested::Bool = true) = 0
363361
count_operation(x::Num, op::Function, nested::Bool = true) = count_operation(value(x), op, nested)
364362

365-
function count_operation(x::Term, op::Function, nested::Bool = true)
366-
if x.f == op
363+
function count_operation(x, op::Function, nested::Bool = true)
364+
if operation(x)== op
367365
if is_unary(op)
368366
# Handles sin, cos and stuff
369-
nested && return 1 + count_operation(x.arguments, op)
367+
nested && return 1 + count_operation(arguments(x), op)
370368
return 1
371369
else
372370
# Handles +, *
373-
nested && length(x.arguments)-1 + count_operation(x.arguments, op)
374-
return length(x.arguments)-1
371+
nested && length(arguments(x))-1 + count_operation(arguments(x), op)
372+
return length(arguments(x))-1
375373
end
376374
elseif nested
377-
return count_operation(x.arguments, op, nested)
375+
return count_operation(arguments(x), op, nested)
378376
end
379377
return 0
380378
end
@@ -395,40 +393,35 @@ function count_operation(x::AbstractArray, ops::AbstractArray, nested::Bool = tr
395393
counter
396394
end
397395

398-
function split_term!(x::AbstractArray, o::Term, ops::AbstractArray = [+])
399-
n_ops = count_operation(o, ops, false)
400-
c_ops = 0
401-
@views begin
402-
if n_ops == 0
403-
x[begin]= o
404-
else
405-
counter_ = 1
406-
for oi in o.arguments
407-
c_ops = count_operation(oi, ops, false)
408-
split_term!(x[counter_:counter_+c_ops], oi, ops)
409-
counter_ += c_ops + 1
396+
function split_term!(x::AbstractArray, o, ops::AbstractArray = [+])
397+
if istree(o)
398+
n_ops = count_operation(o, ops, false)
399+
c_ops = 0
400+
@views begin
401+
if n_ops == 0
402+
x[begin]= o
403+
else
404+
counter_ = 1
405+
for oi in arguments(o)
406+
c_ops = count_operation(oi, ops, false)
407+
split_term!(x[counter_:counter_+c_ops], oi, ops)
408+
counter_ += c_ops + 1
409+
end
410410
end
411411
end
412+
else
413+
x[begin] = o
412414
end
413-
end
414-
415-
split_term!(x::AbstractArray,o::Num, ops::AbstractArray = [+]) = split_term!(x, value(o), ops)
416-
417-
function split_term!(x::AbstractArray, o::Sym, ops::AbstractArray = [+])
418-
x[begin] = o
419415
return
420416
end
421417

422-
function split_term!(x::AbstractArray, o::Number, ops::AbstractArray = [+])
423-
x[begin] = o
424-
return
425-
end
418+
split_term!(x::AbstractArray,o::Num, ops::AbstractArray = [+]) = split_term!(x, value(o), ops)
426419

427420
remove_constant_factor(x::Num) = remove_constant_factor(value(x))
428-
remove_constant_factor(x::Sym) = x
429421
remove_constant_factor(x::Number) = one(x)
430422

431-
function remove_constant_factor(x::Term)
423+
function remove_constant_factor(x)
424+
istree(x) || return x
432425
n_ops = count_operation(x, *, false)+1
433426
ops = Array{Any}(undef, n_ops)
434427
@views split_term!(ops, x, [*])

0 commit comments

Comments
 (0)