Skip to content

Commit 5ab2a62

Browse files
authored
v0.6.1
* Document sparse regression * Add eval_expression to solve and build_solutions * Update wrong sizing in tests * v0.6.1
1 parent dc32f12 commit 5ab2a62

File tree

9 files changed

+83
-35
lines changed

9 files changed

+83
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DataDrivenDiffEq"
22
uuid = "2445eb08-9709-466a-b3fc-47e12bd697a2"
33
authors = ["Julius Martensen <[email protected]>"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/src/prob_and_solve.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ ps = parameter_map(res)
6767

6868
## [Optional Arguments](@id optional_arguments)
6969

70+
!!! info
71+
The keyword argument `eval_expression` controls the function creation
72+
behavior. `eval_expression=true` means that `eval` is used, so normal
73+
world-age behavior applies (i.e. the functions cannot be called from
74+
the function that generates them). If `eval_expression=false`,
75+
then construction via GeneralizedGenerated.jl is utilized to allow for
76+
same world-age evaluation. However, this can cause Julia to segfault
77+
on sufficiently large basis functions. By default eval_expression=false.
78+
7079
Koopman based algorithms can be called without a [`Basis`](@ref), resulting in dynamic mode decomposition like methods, or with a basis for extened dynamic mode decomposition :
7180

7281
```julia
@@ -81,6 +90,9 @@ Possible keyworded arguments include
8190
+ `digits` controls the digits / rounding used for deriving the system equations (`digits = 1` would round `10.02` to `10.0`)
8291
+ `operator_only` returns a `NamedTuple` containing the operator, input and output mapping and matrices used for updating the operator as described [here](https://arxiv.org/pdf/1406.7187.pdf)
8392

93+
!!! info
94+
If `eval_expression` is set to `true`, the returning result of the Koopman based inference will not contain a parametrized equation, but rather use the numeric values of the operator/generator.
95+
8496
SINDy based algorithms can be called like :
8597

8698
```julia

src/optimizers/sparseregression.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
$(SIGNATURES)
44
55
Implements a sparse regression, given an `AbstractOptimizer` or `AbstractSubspaceOptimizer`.
6-
`maxiter` indicate the maximum iterations for each call of the optimizer, `abstol` the absolute tolerance of
6+
`X` denotes the coefficient matrix, `A` the design matrix and `Y` the matrix of observed or target values.
7+
`X` can be derived via `init(opt, A, Y)`.
8+
`maxiter` indicates the maximum iterations for each call of the optimizer, `abstol` the absolute tolerance of
79
the difference between iterations in the 2 norm. If the optimizer is called with a `Vector` of thresholds, each `maxiter` indicates
810
the maximum iterations for each threshold.
911
10-
If `progress` is set to `true`, a progressbar will be available.
12+
If `progress` is set to `true`, a progressbar will be available. `progress_outer` and `progress_offset` are used to compute the initial offset of the
13+
progressbar.
14+
15+
If used with a `Vector` of thresholds, the functions `f` with signature `f(X, A, Y)` and `g` with signature `g(x, threshold) = G(f(X, A, Y))` with the arguments given as stated above can be passed in. These are
16+
used for finding the pareto-optimal solution to the sparse regression.
1117
"""
1218
function sparse_regression!(X, A, Y, opt::AbstractOptimizer{T};
1319
maxiter::Int = maximum(size(A)),
@@ -83,7 +89,7 @@ function sparse_regression!(X, A, Y, opt::AbstractOptimizer{T};
8389
)
8490
end
8591
end
86-
92+
8793
for i in 1:size(Y, 2)
8894
@views clip_by_threshold!(X[:, i], λs[i])
8995
end

src/solution.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function Base.print(io::IO, r::DataDrivenSolution, fullview::DataType)
108108
println(io, "")
109109
print(io, r.res)
110110
println(io, "")
111-
if length(r.res.ps) > 0
111+
if length(r.res.ps) > 0
112112
x = parameter_map(r)
113113
println(io, "Parameters:")
114114
for v in x
@@ -154,14 +154,15 @@ end
154154

155155

156156
# Explicit sindy
157-
function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::Optimize.AbstractOptimizer, b::Basis)
157+
function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::Optimize.AbstractOptimizer, b::Basis;
158+
eval_expression = false)
158159
if all(iszero(Ξ))
159160
@warn "Sparse regression failed! All coefficients are zero."
160161
return DataDrivenSolution(
161-
nothing , :failed, nothing, opt, Ξ, (Problem = prob, Basis = b, nothing),
162+
nothing , :failed, nothing, opt, Ξ, (Problem = prob, Basis = b, nothing),
162163
)
163164
end
164-
165+
165166
eqs, ps, p_ = build_parametrized_eqs(Ξ, b)
166167

167168
# Build the lhs
@@ -176,7 +177,8 @@ function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::Optimi
176177
eqs, states(b),
177178
parameters = [parameters(b); p_], iv = independent_variable(b),
178179
controls = controls(b), observed = observed(b),
179-
name = gensym(:Basis)
180+
name = gensym(:Basis),
181+
eval_expression = eval_expression
180182
)
181183

182184
sparsity = norm(Ξ, 0)
@@ -225,15 +227,15 @@ function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::Optimi
225227
end
226228

227229
function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::Optimize.AbstractSubspaceOptimizer,
228-
b::Basis, implicits::Vector{Num})
230+
b::Basis, implicits::Vector{Num}; eval_expression = false)
229231

230232
if all(iszero(Ξ))
231233
@warn "Sparse regression failed! All coefficients are zero."
232234
return DataDrivenSolution(
233-
nothing , :failed, nothing, opt, Ξ, (Problem = prob, Basis = b, nothing),
235+
nothing , :failed, nothing, opt, Ξ, (Problem = prob, Basis = b, nothing),
234236
)
235237
end
236-
238+
237239
eqs, ps, p_ = build_parametrized_eqs(Ξ, b)
238240
eqs = [0 .~ eq for eq in eqs]
239241

@@ -242,7 +244,8 @@ function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::Optimi
242244
eqs, states(b),
243245
parameters = [parameters(b); p_], iv = independent_variable(b),
244246
controls = controls(b), observed = observed(b),
245-
name = gensym(:Basis)
247+
name = gensym(:Basis),
248+
eval_expression = eval_expression
246249
)
247250

248251
sparsity = norm(Ξ, 0)
@@ -298,8 +301,8 @@ function _round!(x::AbstractArray{T, N}, digits::Int) where {T, N}
298301
end
299302

300303
#function build_solution(prob::DataDrivenProblem, Ξ::AbstractMatrix, opt::AbstractKoopmanAlgorithm)
301-
function build_solution(prob::DataDrivenProblem, k, C, B, Q, P, inds, b::AbstractBasis, alg::AbstractKoopmanAlgorithm; digits::Int = 10)
302-
# Build parameterized equations
304+
function build_solution(prob::DataDrivenProblem, k, C, B, Q, P, inds, b::AbstractBasis, alg::AbstractKoopmanAlgorithm; digits::Int = 10, eval_expression = false)
305+
# Build parameterized equations, inds indicate the location of basis elements containing an input
303306
Ξ = zeros(eltype(B), size(C,2), length(b))
304307

305308
Ξ[:, inds] .= real.(Matrix(k))
@@ -308,7 +311,14 @@ function build_solution(prob::DataDrivenProblem, k, C, B, Q, P, inds, b::Abstrac
308311
end
309312

310313
# Transpose because of the nature of build_parametrized_eqs
311-
eqs, ps, p_ = build_parametrized_eqs(_round!(C*Ξ, digits)', b)
314+
if !eval_expression
315+
eqs, ps, p_ = build_parametrized_eqs(_round!(C*Ξ, digits)', b)
316+
else
317+
= _round!(C*Ξ, digits)
318+
eqs =*Num[states(b); controls(b)]
319+
p_ = []
320+
ps = [K̃...]
321+
end
312322

313323
# Build the lhs
314324
if length(eqs) == length(states(b))
@@ -317,18 +327,25 @@ function build_solution(prob::DataDrivenProblem, k, C, B, Q, P, inds, b::Abstrac
317327
eqs = [d(xs[i]) ~ eq for (i,eq) in enumerate(eqs)]
318328
end
319329

330+
320331
res_ = Koopman(eqs, states(b),
321332
parameters = [parameters(b); p_],
322333
controls = controls(b), iv = independent_variable(b),
323-
K = k, C = C, Q = Q, P = P, lift = b.f)
334+
K = k, C = C, Q = Q, P = P, lift = b.f,
335+
eval_expression = eval_expression)
324336

325-
326-
retcode = :success
327-
pnew = !isempty(parameters(b)) ? [prob.p; ps] : ps
328-
# Equation space
337+
retcode = :sucess
329338
X = prob.DX
330-
X_, _, t, U = get_oop_args(prob)
331-
Y = res_(X_, pnew, t, U)
339+
X_, p_, t, U = get_oop_args(prob)
340+
341+
pnew = !isempty(parameters(b)) ? [p_; ps] : ps
342+
343+
if !eval_expression
344+
# Equation space
345+
Y = res_(X_, pnew, t, U)
346+
else
347+
Y =*b(X_, p_, t, U)
348+
end
332349

333350
# Build the metrics
334351
error = norm(X-Y, 2)
@@ -357,5 +374,4 @@ function build_solution(prob::DataDrivenProblem, k, C, B, Q, P, inds, b::Abstrac
357374
return DataDrivenSolution(
358375
res_, retcode, pnew, alg, Ξ, inputs, metrics
359376
)
360-
return K
361377
end

src/solve/koopman.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
function DiffEqBase.solve(prob::DataDrivenProblem{dType}, alg::AbstractKoopmanAlgorithm;
33
B::AbstractArray = [], digits::Int = 10, operator_only::Bool = false,
4+
eval_expression = false,
45
kwargs...) where {dType <: Number}
56
# Check the validity
67
@assert is_valid(prob) "The problem seems to be ill-defined. Please check the problem definition."
@@ -39,11 +40,12 @@ function DiffEqBase.solve(prob::DataDrivenProblem{dType}, alg::AbstractKoopmanAl
3940

4041
operator_only && return (K = k, C = C, B = B, Q = Q, P = P)
4142

42-
return build_solution(prob, k, C, B, Q, P, inds, b, alg, digits = digits)
43+
return build_solution(prob, k, C, B, Q, P, inds, b, alg, digits = digits, eval_expression = eval_expression)
4344
end
4445

4546
function DiffEqBase.solve(prob::DataDrivenProblem{dType}, b::Basis, alg::AbstractKoopmanAlgorithm;
4647
digits::Int = 10, operator_only::Bool = false,
48+
eval_expression = false,
4749
kwargs...) where {dType <: Number}
4850
# Check the validity
4951
@assert is_valid(prob) "The problem seems to be ill-defined. Please check the problem definition."
@@ -98,5 +100,5 @@ function DiffEqBase.solve(prob::DataDrivenProblem{dType}, b::Basis, alg::Abstrac
98100

99101
operator_only && return (K = k, C = C, B = B, Q = Q, P = P)
100102

101-
return build_solution(prob, k, C, B, Q, P, inds, b, alg, digits = digits)
103+
return build_solution(prob, k, C, B, Q, P, inds, b, alg, digits = digits, eval_expression = eval_expression)
102104
end

src/solve/sindy.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ end
2222
# Main
2323
function DiffEqBase.solve(p::DataDrivenProblem{dType}, b::Basis, opt::Optimize.AbstractOptimizer;
2424
normalize::Bool = false, denoise::Bool = false, maxiter::Int = 0,
25-
round::Bool = true, kwargs...) where {dType <: Number}
25+
round::Bool = true,
26+
eval_expression = false, kwargs...) where {dType <: Number}
2627
# Check the validity
2728
@assert is_valid(p) "The problem seems to be ill-defined. Please check the problem definition."
2829

@@ -48,7 +49,7 @@ function DiffEqBase.solve(p::DataDrivenProblem{dType}, b::Basis, opt::Optimize.A
4849

4950
# Build solution Basis
5051
return build_solution(
51-
p, Ξ, opt, b
52+
p, Ξ, opt, b, eval_expression = eval_expression
5253
)
5354
end
5455

@@ -78,6 +79,7 @@ end
7879
@views function DiffEqBase.solve(p::DataDrivenProblem{dType}, b::Basis,
7980
opt::Optimize.AbstractSubspaceOptimizer, implicits::Vector{Num} = Num[];
8081
normalize::Bool = false, denoise::Bool = false, maxiter::Int = 0,
82+
eval_expression = false,
8183
round::Bool = true, kwargs...) where {dType <: Number}
8284
# Check the validity
8385
@assert is_valid(p) "The problem seems to be ill-defined. Please check the problem definition."
@@ -119,6 +121,6 @@ end
119121
normalize ? rescale_xi!(Ξ, scales, round) : nothing
120122
# Build solution Basis
121123
return build_solution(
122-
p, Ξ, opt, b, implicits
124+
p, Ξ, opt, b, implicits, eval_expression = eval_expression
123125
)
124126
end

test/dmd/linear_autonomous.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
u0 = [10.0; -20.0]
3535
prob = ODEProblem(f, u0, (0.0, 10.0))
3636
sol = solve(prob, Tsit5(), saveat = 0.001)
37-
37+
3838
prob = DataDrivenProblem(sol)
3939

4040
for alg in [DMDPINV(), DMDSVD(), TOTALDMD()]
@@ -45,7 +45,7 @@ end
4545
@test isempty(estimator.B)
4646
res = solve(prob, alg , operator_only = false)
4747
m = metrics(res)
48-
@test m.Error ./ size(X, 2) < 3e-1
48+
@test m.Error ./ size(sol, 2) < 3e-1
4949
@test Matrix(result(res)) Matrix(estimator.K)
5050
end
5151
end
@@ -74,6 +74,16 @@ end
7474
end
7575
end
7676

77+
@testset "Big System" begin
78+
# Creates a big system which would resulting in a segfault otherwise
79+
X = rand([0, 1], 128, 936);
80+
T = collect(LinRange(0, 4.367058580858928, 936));
81+
problem = DiscreteDataDrivenProblem(X, T);
82+
res1 = solve(problem, DMDSVD(), eval_expression = true)
83+
res2 = solve(problem, DMDSVD(), operator_only = true)
84+
@test Matrix(result(res1)) == real.(Matrix(res2.K))
85+
end
86+
7787
# TODO Include the Big System test
7888
# This fails to generate an equation right now...
7989
#@testset "Big System" begin

test/dmd/nonlinear_autonomous.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
b = result(res)
2222
m = metrics(res)
2323
@test isapprox(eigvals(b), [2*p[1]; p[1]; p[2]], atol = 1e-1)
24-
@test m.Error/size(X, 2) < 1e-1
24+
@test m.Error/size(solution, 2) < 1e-1
2525

2626
_prob = ODEProblem((args...)->b(args...), u0, tspan, parameters(res))
2727
_sol = solve(_prob, Tsit5(), saveat = solution.t)
28-
@test norm(solution - _sol)/size(X, 2) < 1e-1
28+
@test norm(solution - _sol)/size(solution, 2) < 1e-1
2929
end
3030
end

test/dmd/nonlinear_forced.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
problem = ODEProblem(slow_manifold, u0, tspan, p)
1313
solution = solve(problem, Tsit5(), saveat = 0.01)
14-
14+
1515
ufun(u,p,t) = sin(t^2)
1616

1717
prob = ContinuousDataDrivenProblem(solution, U = ufun)
@@ -24,7 +24,7 @@
2424
b = result(res)
2525
m = metrics(res)
2626
@test isapprox(eigvals(b), [2*p[1]; p[1]; p[2]], atol = 1e-1)
27-
@test m.Error/size(X, 2) < 3e-1
27+
@test m.Error/size(solution, 2) < 3e-1
2828

2929
# TODO This does not work right now, but it should
3030
#sdict = Dict([y[1] => sin(t^2)])

0 commit comments

Comments
 (0)