Skip to content

Commit 266e1e4

Browse files
Merge pull request #34 from AlCap23/master
Restructured Symbolic Recovery
2 parents 63854e4 + b85d921 commit 266e1e4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+535
-225
lines changed
-83.4 KB
Binary file not shown.

LotkaVolterra/Manifest.toml

Lines changed: 105 additions & 87 deletions
Large diffs are not rendered by default.
-188 KB
Binary file not shown.
-28.7 KB
Binary file not shown.
-67.4 KB
Binary file not shown.

LotkaVolterra/hudson_bay.jl

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,23 @@ using DelimitedFiles
1616
using Random
1717
Random.seed!(5443)
1818

19+
svname = "HudsonBay"
1920
## Data Preprocessing
2021
# The data has been taken from https://jmahaffy.sdsu.edu/courses/f00/math122/labs/labj/q3v1.htm
21-
# Originally published in
22+
# Originally published in E. P. Odum (1953), Fundamentals of Ecology, Philadelphia, W. B. Saunders
2223
hudson_bay_data = readdlm("hudson_bay_data.dat", '\t', Float32, '\n')
2324
# Measurements of prey and predator
2425
Xₙ = Matrix(transpose(hudson_bay_data[:, 2:3]))
25-
plot(t, transpose(Xₙ))
26+
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
2627
# Normalize the data; since the data domain is strictly positive
2728
# we just need to divide by the maximum
2829
xscale = maximum(Xₙ, dims =2)
2930
Xₙ .= 1f0 ./ xscale .* Xₙ
3031
# Time from 0 -> n
31-
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
3232
tspan = (t[1], t[end])
3333

3434
# Plot the data
35-
scatter(t, transpose(Xₙ), xlabel = "t [a]", ylabel = "x(t), y(t)")
35+
scatter(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")
3636
plot!(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")
3737

3838
## Direct Identification via SINDy + Collocation
@@ -50,17 +50,17 @@ plot(t, dx̂')
5050
b = [polynomial_basis(u, 5); sin.(u)]
5151
basis = Basis(b, u)
5252
# Create an optimizer for the SINDy problem
53-
opt = SR3(Float32(1e-2), Float32(1e-2))
53+
opt = STRRidge()#SR3(Float32(1e-2), Float32(1e-2))
5454
# Create the thresholds which should be used in the search process
5555
λ = Float32.(exp10.(-7:0.1:3))
5656
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
5757
g(x) = x[1] < 1 ? Inf : norm(x, 2)
5858
# Test on derivative data
59-
Ψ = SINDy(x̂, dx̂, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true) # Succeed
59+
Ψ = SINDy(x̂, dx̂, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
6060
println(Ψ)
6161
print_equations(Ψ) # Fails
6262
b2 = Basis((u,p,t)->Ψ(u,ones(length(parameters(Ψ))),t),u, linear_independent = true)
63-
Ψ = SINDy(x̂, dx̂, b2, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true) # Succeed
63+
Ψ = SINDy(x̂, dx̂, b2, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
6464
println(Ψ)
6565
print_equations(Ψ) # Fails
6666
parameters(Ψ)
@@ -69,7 +69,6 @@ parameters(Ψ)
6969
# We assume we have only 5 measurements in y, evenly distributed
7070
ty = collect(t[1]:Float32(t[end]/5):t[end])
7171
# Create datasets for the different measurements
72-
t
7372
XS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # All x data
7473
TS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # Time data
7574
YS = zeros(Float32, length(ty)-1, 2) # Just two measurements in y
@@ -160,6 +159,12 @@ println("Training loss after $(length(losses)) iterations: $(losses[end])")
160159
res3 = DiffEqFlux.sciml_train(loss, res2.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 10000)
161160
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
162161

162+
163+
pl_losses = plot(1:101, losses[1:101], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM (Shooting)", color = :blue)
164+
plot!(102:302, losses[102:302], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS (Shooting)", color = :red)
165+
plot!(302:length(losses), losses[302:end], color = :black, label = "BFGS (L2)")
166+
savefig(pl_losses, joinpath(pwd(), "plots", "$(svname)_losses.pdf"))
167+
163168
# Rename the best candidate
164169
p_trained = res3.minimizer
165170

@@ -168,14 +173,18 @@ p_trained = res3.minimizer
168173
tsample = t[1]:0.5:t[end]
169174
X̂ = predict(p_trained, Xₙ[:,1], tsample)
170175
# Trained on noisy data vs real solution
171-
plot(t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
172-
plot!(tsample, transpose(X̂), color = :red, label = ["Interpolation" nothing])
176+
pl_trajectory = scatter(t, transpose(Xₙ), color = :black, label = ["Measurements" nothing], xlabel = "t", ylabel = "x(t), y(t)")
177+
plot!(tsample, transpose(X̂), color = :red, label = ["UDE Approximation" nothing])
178+
savefig(pl_trajectory, joinpath(pwd(), "plots", "$(svname)_trajectory_reconstruction.pdf"))
173179

174180
# Neural network guess
175181
Ŷ = U(X̂,p_trained[3:end])
176182

177-
scatter(tsample, transpose(Ŷ), xlabel = "t", ylabel ="I1(t), I2(t)", color = :red, label = ["UDE Approximation" nothing])
178-
183+
pl_reconstruction = scatter(tsample, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
184+
plot!(tsample, transpose(Ŷ), color = :red, lw = 2, style = :dash, label = [nothing nothing])
185+
savefig(pl_reconstruction, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction.pdf"))
186+
pl_missing = plot(pl_trajectory, pl_reconstruction, layout = (2,1))
187+
savefig(pl_missing, joinpath(pwd(), "plots", "$(svname)_reconstruction.pdf"))
179188
## Symbolic regression via sparse regression ( SINDy based )
180189

181190
# Create a Basis
@@ -187,7 +196,7 @@ b = [polynomial_basis(u, 5); sin.(u)]
187196
basis = Basis(b, u)
188197

189198
# Create an optimizer for the SINDy problem
190-
opt = SR3(Float32(1e-2), Float32(1e-2))
199+
opt = STRRidge()
191200
# Create the thresholds which should be used in the search process
192201
λ = Float32.(exp10.(-7:0.1:3))
193202
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
@@ -196,51 +205,76 @@ g(x) = x[1] < 1 ? Inf : norm(x, 2)
196205
# Test on uode derivative data
197206
println("SINDy on learned, partial, available data")
198207
Ψ = SINDy(X̂, Ŷ, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
199-
println(Ψ)
200-
print_equations(Ψ)
201208

209+
@info "Found equations:"
210+
print_equations(Ψ)
202211
# Extract the parameter
203212
p̂ = parameters(Ψ)
204213
println("First parameter guess : $(p̂)")
205214

206-
# Just the equations -> we reiterate on sindy here
207-
# searching all linear independent components again
208-
b = Basis((u, p, t)->Ψ(u, ones(length(p̂)), t), u, linear_independent = true)
209-
println(b)
210-
# Retune for better parameters -> we could also use DiffEqFlux or other parameter estimation tools here.
211-
opt = SR3(Float32(1e-2), Float32(1e-2))
212-
Ψf = SINDy(X̂, Ŷ, b, opt, maxiter = 10000, normalize = true, convergence_error = eps()) # Succeed
213-
println(Ψf)
214-
print_equations(Ψf)
215-
p̂ = parameters(Ψf)
216-
println("Second parameter guess : $(p̂)")
217-
218215
# Define the recovered, hyrid model with the rescaled dynamics
219216
function recovered_dynamics!(du,u, p, t)
220-
û = Ψf(u, p[3:4]) # Network prediction
217+
û = Ψ(u, p[3:end]) # Network prediction
221218
du[1] = p[1]*u[1] + û[1]
222219
du[2] = -p[2]*u[2] + û[2]
223220
end
224221

225222
p_model = [p_trained[1:2];p̂]
226223
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], tspan, p_model)
227-
estimate = solve(estimation_prob, Tsit5(), saveat = 0.1)
224+
# Convert for reuse
225+
sys = modelingtoolkitize(estimation_prob);
226+
dudt = ODEFunction(sys);
227+
estimation_prob = ODEProblem(dudt,Xₙ[:, 1], tspan, p_model)
228+
estimate = solve(estimation_prob, Tsit5(), saveat = t)
229+
230+
## Fit the found model
231+
function loss_fit(θ)
232+
= Array(solve(estimation_prob, Tsit5(), p = θ, saveat = t))
233+
sum(abs2, X̂ .- Xₙ)
234+
end
235+
236+
# Post-fit the model
237+
res_fit = DiffEqFlux.sciml_train(loss_fit, p_model, BFGS(initial_stepnorm = 0.1f0), maxiters = 1000)
238+
p_fitted = res_fit.minimizer
239+
240+
# Estimate
241+
estimate_rough = solve(estimation_prob, Tsit5(), saveat = 0.1*mean(diff(t)), p = p_model)
242+
estimate = solve(estimation_prob, Tsit5(), saveat = 0.1*mean(diff(t)), p = p_fitted)
228243

229244
# Plot
230-
plot(t, transpose(Xₙ))
231-
plot!(estimate)
245+
pl_fitted = plot(t, transpose(Xₙ), style = :dash, lw = 2,color = :black, label = ["Measurements" nothing], xlabel = "t", ylabel = "x(t), y(t)")
246+
plot!(estimate_rough, color = :red, label = ["Recovered" nothing])
247+
plot!(estimate, color = :blue, label = ["Recovered + Fitted" nothing])
248+
savefig(pl_fitted,joinpath(pwd(),"plots","$(svname)recovery_fitting.pdf"))
232249

233250
## Simulation
234251

235252
# Look at long term prediction
236253
t_long = (0.0f0, 50.0f0)
237-
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], t_long, p_model)
238-
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.25)
239-
plot(estimate_long)
254+
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.25f0, tspan = t_long,p = p_fitted)
255+
plot(estimate_long.t, transpose(xscale .* estimate_long[:,:]), xlabel = "t", ylabel = "x(t),y(t)")
256+
240257

241258
## Save the results
242-
save("Hudson_Bay_recovery.jld2",
259+
save(joinpath(pwd(),"results","Hudson_Bay_recovery.jld2"),
243260
"X", Xₙ, "t" , t, "neural_network" , U, "initial_parameters", p, "trained_parameters" , p_trained, # Training
244-
"losses", losses, "result", Ψf, "recovered_parameters", p̂, # Recovery
245-
"model", recovered_dynamics!, "model_parameter", p_model,
261+
"losses", losses, "result", Ψ, "recovered_parameters", p̂, # Recovery
262+
"model", recovered_dynamics!, "model_parameter", p_model, "fitted_parameter", p_fitted,
246263
"long_estimate", estimate_long) # Estimation
264+
265+
## Post Processing and Plots
266+
267+
c1 = 3 # RGBA(174/255,192/255,201/255,1) # Maroon
268+
c2 = :orange # RGBA(132/255,159/255,173/255,1) # Red
269+
c3 = :blue # RGBA(255/255,90/255,0,1) # Orange
270+
c4 = :purple # RGBA(153/255,50/255,204/255,1) # Purple
271+
272+
p3 = scatter(t, transpose(Xₙ), color = [c1 c2], label = ["x data" "y data"],
273+
title = "Recovered Model from Hudson Bay Data",
274+
titlefont = "Helvetica", legendfont = "Helvetica",
275+
markersize = 5)
276+
277+
plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"])
278+
plot!(p3,[19.99,20.01],[0.0,maximum(Xₙ)*1.25],lw=1,color=:black, label = nothing)
279+
annotate!([(10.0,maximum(Xₙ)*1.25,text("Training \nData",12 , :center, :top, :black, "Helvetica"))])
280+
savefig(p3,joinpath(pwd(),"plots","$(svname)full_plot.pdf"))

0 commit comments

Comments
 (0)