Skip to content

Commit 7cd980c

Browse files
authored
Merge pull request #47 from JuliaDiffEq/fix_sindy_init
Fixes dimensionality for parameter vector in sindy
2 parents b26019d + f49336b commit 7cd980c

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/sindy.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ end
3737
function SInDy(X::AbstractArray, Ẋ::AbstractArray, Ψ::Basis; p::AbstractArray = [], maxiter::Int64 = 10, opt::T = Optimise.STRRidge()) where T <: Optimise.AbstractOptimiser
3838
@assert size(X)[end] == size(Ẋ)[end]
3939
nx, nm = size(X)
40+
ny, nm = size(Ẋ)
4041

41-
Ξ = zeros(eltype(X), length(Ψ), nx)
42+
Ξ = zeros(eltype(X), length(Ψ), ny)
4243
θ = Ψ(X, p = p)
4344

4445
# Initial estimate
@@ -52,14 +53,15 @@ end
5253
function SInDy(X::AbstractArray, Ẋ::AbstractArray, Ψ::Basis, thresholds::AbstractArray ; p::AbstractArray = [], maxiter::Int64 = 10, opt::T = Optimise.STRRidge()) where T <: Optimise.AbstractOptimiser
5354
@assert size(X)[end] == size(Ẋ)[end]
5455
nx, nm = size(X)
56+
ny, nm = size(Ẋ)
5557

5658
θ = Ψ(X, p = p)
5759

58-
ξ = zeros(eltype(X), length(Ψ), nx)
59-
Ξ_opt = zeros(eltype(X), length(Ψ), nx)
60-
Ξ = zeros(eltype(X), length(thresholds), nx, length(Ψ))
61-
x = zeros(eltype(X), length(thresholds), nx, 2)
62-
p = zeros(eltype(X), nx, length(thresholds))
60+
ξ = zeros(eltype(X), length(Ψ), ny)
61+
Ξ_opt = zeros(eltype(X), length(Ψ), ny)
62+
Ξ = zeros(eltype(X), length(thresholds), ny, length(Ψ))
63+
x = zeros(eltype(X), length(thresholds), ny, 2)
64+
pareto = zeros(eltype(X), ny, length(thresholds))
6365

6466
@inbounds for (j, threshold) in enumerate(thresholds)
6567
set_threshold!(opt, threshold)
@@ -72,8 +74,8 @@ function SInDy(X::AbstractArray, Ẋ::AbstractArray, Ψ::Basis, thresholds::Abst
7274
# Create the evaluation
7375
@inbounds for i in 1:nx
7476
x[:, i, 2] .= x[:, i, 2]./maximum(x[:, i, 2])
75-
p[i, :] = [norm(x[j, i, :], 2) for j in 1:length(thresholds)]
76-
_, indx = findmin(p[i, :])
77+
pareto[i, :] = [norm(x[j, i, :], 2) for j in 1:length(thresholds)]
78+
_, indx = findmin(pareto[i, :])
7779
Ξ_opt[:, i] = Ξ[indx, i, :]
7880
end
7981

0 commit comments

Comments
 (0)