Skip to content

Commit a13e1c8

Browse files
authored
Merge pull request #51 from JuliaDiffEq/minor_basis_fix
Minor basis fix
2 parents 0d5c3bc + a8a2d44 commit a13e1c8

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/basis.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,21 @@ function Basis(basis::AbstractVector{Operation}, variables::AbstractVector{Opera
3232
bs = unique(basis)
3333
fix_single_vars_in_basis!(bs, variables)
3434

35-
vs = sort!([b for b in [ModelingToolkit.vars(bs)...] if !b.known], by = x -> x.name)
36-
ps = sort!([b for b in [ModelingToolkit.vars(bs)...] if b.known], by = x -> x.name )
35+
vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables]
36+
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters]
3737

3838
f_ = ModelingToolkit.build_function(bs, vs, ps, (), simplified_expr, Val{false})[1]
3939
return Basis(bs, variables, parameters, f_)
4040
end
4141

42-
function update!(b::Basis)
43-
vs = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if !bi.known], by = x->x.name)
44-
ps = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if bi.known], by = x->x.name)
42+
function update!(basis::Basis)
43+
vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables(basis)]
44+
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters(basis)]
4545

46-
b.f_ = ModelingToolkit.build_function(b.basis, vs, ps, (), simplified_expr, Val{false})[1]
46+
basis.f_ = ModelingToolkit.build_function(basis.basis, vs, ps, (), simplified_expr, Val{false})[1]
4747
return
4848
end
49-
49+
5050
function Base.push!(b::Basis, ops::AbstractArray{Operation})
5151
@inbounds for o in ops
5252
push!(b.basis, o)
@@ -139,10 +139,12 @@ Base.length(b::Basis) = length(b.basis)
139139
ModelingToolkit.parameters(b::Basis) = b.parameter
140140
variables(b::Basis) = b.variables
141141

142-
function jacobian(b::Basis)
143-
vs = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if !bi.known], by = x-> x.name)
144-
ps = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if bi.known], by = x-> x.name)
145-
j = calculate_jacobian(b.basis, variables(b))
142+
function jacobian(basis::Basis)
143+
144+
vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables(basis)]
145+
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters(basis)]
146+
147+
j = calculate_jacobian(basis.basis, variables(basis))
146148
return ModelingToolkit.build_function(expand_derivatives.(j), vs, ps, (), simplified_expr, Val{false})[1]
147149
end
148150

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ using Test
3737
unique!(basis)
3838
@test size(basis) == size(h)
3939

40+
@variables a
41+
g = [u[1]; u[3]; a]
42+
basis = Basis(g, [u; a])
43+
@test basis([1; 2; 3; 4]) == [1; 3; 4]
4044
g = [1.0*u[1]; 1.0*u[3]; 1.0*u[2]]
4145
basis = Basis(g, u, parameters = [])
4246
X = ones(Float64, 3, 10)
@@ -230,7 +234,7 @@ end
230234
X = sol[:, :] + 1e-3*randn(size(sol[:,:])...)
231235
set_threshold!(opt, 3.5e-1)
232236
Ψ = SInDy(X, DX, basis, maxiter = 10000, opt = opt, denoise = true, normalize = true)
233-
237+
234238
estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
235239
sol_4 = solve(estimator,Tsit5(), saveat = dt)
236240
@test norm(sol[:,:] - sol_4[:,:], 2) < 5e-1

0 commit comments

Comments
 (0)