Skip to content

Commit 40340bf

Browse files
authored
New version and Symbolic Regression Result handling (#325)
* Adapt Symbolic regression solution process * Update getters * Remove old code for now * Bump version
1 parent 2925a9c commit 40340bf

File tree

7 files changed

+83
-114
lines changed

7 files changed

+83
-114
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DataDrivenDiffEq"
22
uuid = "2445eb08-9709-466a-b3fc-47e12bd697a2"
33
authors = ["Julius Martensen <[email protected]>"]
4-
version = "0.6.9"
4+
version = "0.7.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/basis/type.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ function dynamics(b::AbstractBasis)
238238
end
239239

240240
## Callable
241-
get_f(b::AbstractBasis) = getproperty(b, :f)
241+
get_f(b::AbstractBasis) = getfield(b, :f)
242242

243243
# Fallback
244244
(b::AbstractBasis)(args...) = get_f(b)(args...)

src/koopman/type.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,19 +146,20 @@ end
146146

147147

148148
# We assume that we only have real valued observed
149-
Base.Matrix(k::AbstractKoopman) = real.(Matrix(k.K))
149+
Base.Matrix(k::AbstractKoopman) = real.(Matrix(_get_K(k)))
150150

151151
# Get the lifting function
152-
lifting(k::AbstractKoopman) = k.lift
152+
lifting(k::AbstractKoopman) = getfield(k, :lift)
153153

154+
# Get K
155+
_get_K(k::AbstractKoopman) = getfield(k, :K)
154156

155-
# TODO FIXME MAYBE?
156157
"""
157158
$(SIGNATURES)
158159
159160
Returns `true` if the `AbstractKoopmanOperator` `k` is discrete in time.
160161
"""
161-
is_discrete(k::AbstractKoopman) = !(!k.is_discrete)
162+
is_discrete(k::AbstractKoopman) = getfield(k, :is_discrete)
162163

163164

164165

@@ -167,28 +168,32 @@ $(SIGNATURES)
167168
168169
Returns `true` if the `AbstractKoopmanOperator` `k` is continuous in time.
169170
"""
170-
is_continuous(k::AbstractKoopman) = !k.is_discrete
171+
is_continuous(k::AbstractKoopman) = !is_discrete(k)
171172

172173
"""
173174
$(SIGNATURES)
174175
175176
Return the eigendecomposition of the `AbstractKoopmanOperator`.
176177
"""
177-
LinearAlgebra.eigen(k::AbstractKoopman) = isa(k.K, Eigen) ? k.K : eigen(k.K)
178+
LinearAlgebra.eigen(k::AbstractKoopman) = begin
179+
K = _get_K(k)
180+
isa(K, Eigen) && return K
181+
eigen(K)
182+
end
178183

179184
"""
180185
$(SIGNATURES)
181186
182187
Return the eigenvalues of the `AbstractKoopmanOperator`.
183188
"""
184-
LinearAlgebra.eigvals(k::AbstractKoopman) = eigvals(k.K)
189+
LinearAlgebra.eigvals(k::AbstractKoopman) = eigvals(_get_K(k))
185190

186191
"""
187192
$(SIGNATURES)
188193
189194
Return the eigenvectors of the `AbstractKoopmanOperator`.
190195
"""
191-
LinearAlgebra.eigvecs(k::AbstractKoopman) = eigvecs(k.K)
196+
LinearAlgebra.eigvecs(k::AbstractKoopman) = eigvecs(_get_K(k))
192197

193198
"""
194199
$(SIGNATURES)
@@ -209,14 +214,14 @@ $(SIGNATURES)
209214
210215
Return the approximation of the discrete Koopman operator stored in `k`.
211216
"""
212-
operator(k::AbstractKoopman) = is_discrete(k) ? k.K : throw(AssertionError("Koopman is continouos."))
217+
operator(k::AbstractKoopman) = is_discrete(k) ? _get_K(k) : throw(AssertionError("Koopman is continouos."))
213218

214219
"""
215220
$(SIGNATURES)
216221
217222
Return the approximation of the continuous Koopman generator stored in `k`.
218223
"""
219-
generator(k::AbstractKoopman) = is_continuous(k) ? k.K : throw(AssertionError("Koopman is discrete."))
224+
generator(k::AbstractKoopman) = is_continuous(k) ? _get_K(k) : throw(AssertionError("Koopman is discrete."))
220225

221226
"""
222227
$(SIGNATURES)
@@ -240,7 +245,11 @@ Returns `true` if either:
240245
+ the Koopman operator has just eigenvalues with magnitude less than one or
241246
+ the Koopman generator has just eigenvalues with a negative real part
242247
"""
243-
is_stable(k::AbstractKoopman) = is_discrete(k) ? all(real.(eigvals(k)) .< real.(one(eltype(k.K)))) : all(real.(eigvals(k)) .< zero(eltype(k.K)))
248+
is_stable(k::AbstractKoopman) = begin
249+
K = _get_K(k)
250+
is_discrete(k) && all(real.(eigvals(k)) .< real.(one(eltype(K))))
251+
all(real.(eigvals(k)) .< zero(eltype(K)))
252+
end
244253

245254
# TODO This does not work, since we are using the reduced basis instead of the
246255
# original, lifted dynamics...

src/solution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ function DataDrivenSolution(prob::AbstractDataDrivenProblem, k, C, B, Q, P, inds
375375
res_ = Koopman(equations(bs), states(bs),
376376
parameters = parameters(bs),
377377
controls = controls(bs), iv = get_iv(bs),
378-
K = k, C = C, Q = Q, P = P, lift = b.f,
378+
K = k, C = C, Q = Q, P = P, lift = get_f(b),
379379
is_discrete = is_discrete(prob),
380380
eval_expression = eval_expression)
381381

src/symbolic_regression/symbolic_regression.jl

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,16 @@ function to_options(x::EQSearch)
3737
)
3838
end
3939

40+
41+
4042
function DiffEqBase.solve(prob::AbstractDataDrivenProblem, alg::EQSearch;
4143
max_iter::Int = 10,
4244
weights = nothing,
4345
numprocs = nothing, procs = nothing,
4446
multithreading = false,
4547
runtests::Bool = true,
46-
eval_expression = false
48+
eval_expression = false,
49+
kwargs...
4750
)
4851

4952
opt = to_options(alg)
@@ -60,29 +63,63 @@ function DiffEqBase.solve(prob::AbstractDataDrivenProblem, alg::EQSearch;
6063
hof = SymbolicRegression.EquationSearch(X, Y, niterations = max_iter, weights = weights, options = opt,
6164
numprocs = numprocs, procs = procs, multithreading = multithreading,
6265
runtests = runtests)
63-
# Sort the paretofront
64-
doms = map(1:size(Y, 1)) do i
65-
calculateParetoFrontier(X, Y[i, :], hof[i], opt)
66-
end
6766

68-
build_solution(prob, alg, doms; eval_expression = eval_expression)
67+
build_solution(prob, alg, hof; eval_expression = eval_expression)
6968
end
7069

70+
function pareto_optimal_equations(hof::HallOfFame, prob, alg)
71+
return pareto_optimal_equations([hof], prob, alg)
72+
end
7173

72-
function build_solution(prob::AbstractDataDrivenProblem, alg::EQSearch, doms; eval_expression = false)
7374

74-
opt = to_options(alg)
75+
function pareto_optimal_equations(hof::Vector{HallOfFame}, prob, alg)
76+
77+
opts = DataDrivenDiffEq.to_options(alg)
78+
y = DataDrivenDiffEq.get_target(prob)
79+
x, _, t, c = DataDrivenDiffEq.get_oop_args(prob)
80+
X = vcat([x for x in (x, c, permutedims(t)) if !isempty(x)]...)
81+
7582
@variables x[1:size(prob.X, 1)] u[1:size(prob.U,1)] t
7683
x = Symbolics.scalarize(x)
7784
u = Symbolics.scalarize(u)
78-
x_ = [x;u;t]
85+
x_ = Num[x;u;t]
7986

8087
# Build a dict
8188
subs = Dict([SymbolicUtils.Sym{Number}(Symbol("x$(i)")) => x_[i] for i in 1:size(x_, 1)]...)
82-
# Create a variable
83-
eqs = vcat(map(x->node_to_symbolic(x[end].tree, opt), doms))
84-
eqs = map(x->substitute(x, subs), eqs)
8589

90+
91+
eqs = map(1:size(hof, 1)) do i
92+
@show i
93+
d = calculateParetoFrontier(X, y[i,:], hof[i], opts)
94+
isempty(d) && return Num(0)
95+
eq_ = node_to_symbolic(last(d).tree, opts)
96+
substitute(eq_, subs)
97+
end
98+
99+
return eqs, x, u, t
100+
end
101+
102+
103+
104+
function build_solution(prob::AbstractDataDrivenProblem, alg::EQSearch, hof; eval_expression = false)
105+
106+
#opt = to_options(alg)
107+
#
108+
#@variables x[1:size(prob.X, 1)] u[1:size(prob.U,1)] t
109+
#x = Symbolics.scalarize(x)
110+
#u = Symbolics.scalarize(u)
111+
#x_ = [x;u;t]
112+
113+
# Build a dict
114+
#subs = Dict([SymbolicUtils.Sym{Number}(Symbol("x$(i)")) => x_[i] for i in 1:size(x_, 1)]...)
115+
116+
117+
# Create a variable
118+
#eqs = vcat(map(x->node_to_symbolic(x[end].tree, opt), doms))
119+
#eqs = map(x->substitute(x, subs), eqs)
120+
121+
eqs, x, u, t = pareto_optimal_equations(hof, prob, alg)
122+
86123
lhs, dt = assert_lhs(prob)
87124

88125

@@ -104,10 +141,11 @@ function build_solution(prob::AbstractDataDrivenProblem, alg::EQSearch, doms; ev
104141
Y = res_(get_oop_args(prob)...)
105142

106143

144+
107145
error = sum(abs2, X-Y, dims = 2)[:,1]
108146
retcode = :converged
109147

110148
return DataDrivenSolution(
111-
false, res_, [], retcode, alg, doms, prob, error
149+
false, res_, [], retcode, alg, hof, prob, error
112150
)
113151
end

src/utils/utils.jl

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -142,90 +142,3 @@ function optimal_shrinkage!(X::AbstractArray{T, 2}) where T <: Number
142142
X .= U[:, inds]*Diagonal(S[inds])*V[:, inds]'
143143
return
144144
end
145-
146-
147-
## TODO
148-
# This is old code and will be processed to be used with the Problems
149-
#"""
150-
# ($SIGNATURES)
151-
#
152-
#Randomly selects `n` bursts of data with size `samplesize` from the data `X`.
153-
#
154-
#Randomly selects `n` bursts of data with size `samplesize` from the data `X` and `Y`.
155-
#
156-
#Randomly selects `n` bursts of data within a time window `period` from the data `X`. The time information
157-
#has to be provided in `t`.
158-
#"""
159-
#@inline function burst_sampling(x::AbstractArray, samplesize::Int64, bursts::Int64)
160-
# @assert size(x)[end] >= samplesize*bursts "Bursting impossible. Please provide more data or reduce bursts or samplesize."
161-
# inds = sample(1:size(x)[end]-samplesize, bursts, replace = false)
162-
# inds = sort(unique(vcat([collect(i:i+samplesize) for i in inds]...)))
163-
# return resample(x, inds)
164-
#end
165-
#
166-
#@inline function burst_sampling(x::AbstractArray, y::AbstractArray, samplesize::Int64, bursts::Int64)
167-
# @assert size(x)[end] >= samplesize*bursts "Bursting impossible. Please provide more data or reduce bursts or samplesize"
168-
# @assert size(x)[end] == size(y)[end]
169-
# inds = sample(1:size(x)[end]-samplesize, bursts, replace = false)
170-
# inds = sort(unique(vcat([collect(i:i+samplesize) for i in inds]...)))
171-
# return resample(x, inds), resample(y, inds)
172-
#end
173-
#
174-
#@inline function burst_sampling(x::AbstractArray, t::AbstractVector, period::T, bursts::Int64) where T <: AbstractFloat
175-
# @assert period > zero(typeof(period)) "Sampling period has to be positive."
176-
# @assert size(x)[end] == size(t)[end] "Provide consistent data."
177-
# @assert bursts >= 1 "Number of bursts has to be positive."
178-
# @assert t[end]-t[1]>= period*bursts "Bursting impossible. Please provide more data or reduce bursts or samplesize"
179-
# t_ids = zero(eltype(t)) .<= t .- period .<= t[end] .- 2*period
180-
# samplesize = Int64(floor(period/(t[end]-t[1])*length(t)))
181-
# inds = sample(collect(1:length(t))[t_ids], bursts, replace = false)
182-
# inds = sort(unique(vcat([collect(i:i+samplesize) for i in inds]...)))
183-
# return resample(x, inds), resample(t, inds)
184-
#end
185-
#
186-
#
187-
#"""
188-
# $(SIGNATURES)
189-
#
190-
#Returns the subsampled `X` with only every `n`-th entry.
191-
#
192-
#Returns the subsampled `X` with a a minimum period of `dt` between two data points. `t` provides the
193-
#time information.
194-
#"""
195-
#@inline function subsample(x::AbstractVector, frequency::Int64)
196-
# @assert frequency > 0 "Sampling frequency has to be positive."
197-
# return x[1:frequency:end]
198-
#end
199-
#
200-
#@inline function subsample(x::AbstractArray, frequency::Int64)
201-
# @assert frequency > 0 "Sampling frequency has to be positive."
202-
# return x[:, 1:frequency:end]
203-
#end
204-
#
205-
#@inline function subsample(x::AbstractArray, t::AbstractVector, period::T) where T <: AbstractFloat
206-
# @assert period > zero(typeof(period)) "Sampling period has to be positive."
207-
# @assert size(x)[end] == size(t)[end] "Provide consistent data."
208-
# @assert t[end]-t[1]>= period "Subsampling impossible. Sampling period exceeds time window."
209-
# idx = Int64[1]
210-
# t_now = t[1]
211-
# @inbounds for (i, t_current) in enumerate(t)
212-
# if t_current - t_now >= period
213-
# push!(idx, i)
214-
# t_now = t_current
215-
# end
216-
# end
217-
# return resample(x, idx), resample(t, idx)
218-
#end
219-
#
220-
#@inline function resample(x::AbstractArray{T,1}, indx::AbstractArray{Int64}) where T <: Number
221-
# @assert maximum(indx) <= length(x) "Sampling index has to be consistent with array dimensions."
222-
# @assert minimum(indx) >= 1 "Sampling index has to be consistent with array dimensions."
223-
# return x[indx]
224-
#end
225-
#
226-
#@inline function resample(x::AbstractArray{T,2}, indx::AbstractArray{Int64}) where T <: Number
227-
# @assert maximum(indx) <= size(x, 2) "Sampling index has to be consistent with array dimensions."
228-
# @assert minimum(indx) >= 1 "Sampling index has to be consistent with array dimensions."
229-
# return x[:, indx]
230-
#end
231-
#

test/symbolic_regression/symbolic_regression.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,13 @@
1515
x = states(sys)
1616
@test all(m[:L₂] .<= eps())
1717
@test isequal([x.rhs for x in equations(sys)], [sin(x[1]); exp(x[2])])
18+
19+
# Single target
20+
prob = DirectDataDrivenProblem(X, Y[1:1,:])
21+
res = solve(prob, opts, numprocs = 0, multithreading = false)
22+
sys = result(res)
23+
m = metrics(res)
24+
x = states(sys)
25+
@test all(m[:L₂] .<= eps())
26+
@test isequal([x.rhs for x in equations(sys)], [sin(x[1])])
1827
end

0 commit comments

Comments
 (0)