
Commit 6804d18

DrWatson committed
Add Hudson Bay example
1 parent 558e4f1 commit 6804d18

File tree

3 files changed: +267 −0 lines changed

Binary file (83.4 KB) not shown.

LotkaVolterra/hudson_bay.jl

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
## Environment and packages
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()
using JLD2, FileIO
using Statistics
using DelimitedFiles
# Set a random seed for reproducible behaviour
using Random
Random.seed!(5443)

## Data Preprocessing
# The data has been taken from https://jmahaffy.sdsu.edu/courses/f00/math122/labs/labj/q3v1.htm
# Originally published in
hudson_bay_data = readdlm("hudson_bay_data.dat", '\t', Float32, '\n')
# Time from 0 -> n
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
tspan = (t[1], t[end])
# Measurements of prey and predator
Xₙ = Matrix(transpose(hudson_bay_data[:, 2:3]))
plot(t, transpose(Xₙ))
# Normalize the data; since the data domain is strictly positive
# we just need to divide by the maximum
xscale = maximum(Xₙ, dims = 2)
Xₙ .= 1f0 ./ xscale .* Xₙ

# Plot the data
scatter(t, transpose(Xₙ), xlabel = "t [a]", ylabel = "x(t), y(t)")
plot!(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")

## Direct Identification via SINDy + Collocation
# Create a collocation
dx̂, x̂ = collocate_data(Xₙ, t, GaussianKernel())
# Look at the collocation
plot(t, dx̂')
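
# collocate_data returns estimated derivatives dx̂ and smoothed states x̂ by
# kernel-smoothing the measurements, so we get derivative targets for the
# sparse regression without solving any differential equation.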
# Perform SINDy

# Create a Basis
@variables u[1:2]

# Generate the basis functions: multivariate polynomials up to degree 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)
# Create an optimizer for the SINDy problem
opt = SR3(Float32(1e-2), Float32(1e-2))
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:3))
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
g(x) = x[1] < 1 ? Inf : norm(x, 2)
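
# g discards empty models (fewer than one active term) and otherwise scores a
# candidate by the Euclidean norm of (number of terms, L2 error), so the sweep
# over the thresholds in λ trades sparsity against fit quality.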
# Test on derivative data
Ψ = SINDy(x̂, dx̂, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true) # Succeeds
println(Ψ)
print_equations(Ψ) # Fails
b2 = Basis((u, p, t) -> Ψ(u, ones(length(parameters(Ψ))), t), u, linear_independent = true)
Ψ = SINDy(x̂, dx̂, b2, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true) # Succeeds
println(Ψ)
print_equations(Ψ) # Fails
parameters(Ψ)
## UDE Approach
# Subsample the data in y -> initial fitting strategy (batching)
# We assume we have only 5 measurements in y, evenly distributed
ty = collect(t[1]:Float32(t[end]/5):t[end])
# Create datasets for the different measurements
XS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # All x data
TS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # Time data
YS = zeros(Float32, length(ty)-1, 2) # Just two measurements in y

for i in 1:length(ty)-1
    idxs = ty[i] .<= t .<= ty[i+1]
    XS[i, :] = Xₙ[1, idxs]
    TS[i, :] = t[idxs]
    YS[i, :] = [Xₙ[2, t .== ty[i]]'; Xₙ[2, t .== ty[i+1]]]
end
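
# Each row i of XS / TS holds the dense prey measurements and times inside the
# window [ty[i], ty[i+1]]; YS keeps only the predator values at the window
# boundaries -> dense information in x, sparse information in y.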

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Define the network 2->5->5->5->2
U = FastChain(
    FastDense(2, 5, rbf), FastDense(5, 5, rbf), FastDense(5, 5, tanh), FastDense(5, 2)
)

# Get the initial parameters; the first two are the linear birth / decay rates of prey and predator
p = [rand(Float32, 2); initial_params(U)]

# Define the hybrid model
function ude_dynamics!(du, u, p, t)
    û = U(u, p[3:end]) # Network prediction
    # We assume a linear birth rate for the prey
    du[1] = p[1]*u[1] + û[1]
    # We assume a linear decay rate for the predator
    du[2] = -p[2]*u[2] + û[2]
end
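
# For the classical Lotka-Volterra model ẋ = α*x - β*x*y and ẏ = -δ*y + γ*x*y,
# so the network û should ideally recover just the missing interaction terms,
# û[1] ≈ -β*u[1]*u[2] and û[2] ≈ γ*u[1]*u[2].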

# Define the problem
prob_nn = ODEProblem(ude_dynamics!, Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:, 1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p = θ,
                tspan = (T[1], T[end]), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end
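
# Forward-mode sensitivities are cheap for this moderate parameter count;
# for much larger networks an adjoint-based sensealg would likely scale better.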

# Multiple shooting like loss
function shooting_loss(θ)
    # Start with a regularization on the network
    l = convert(eltype(θ), 1e-3) * sum(abs2, θ[3:end]) / length(θ[3:end])
    for i in 1:size(XS, 1)
        X̂ = predict(θ, [XS[i, 1], YS[i, 1]], TS[i, :])
        # Full prediction in x
        l += sum(abs2, XS[i, :] .- X̂[1, :])
        # Add the boundary condition in y
        l += abs2(YS[i, 2] - X̂[2, end])
    end

    return l
end
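
# Fitting many short windows instead of the full trajectory keeps each solve
# short and re-initializes from data at every window boundary, which should
# make the loss surface more benign than a single-shooting fit over all years.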

function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ - X̂) + convert(eltype(θ), 1e-3) * sum(abs2, θ[3:end]) / length(θ[3:end])
end

# Container to track the losses
losses = Float32[]

# Callback to show the loss during training
callback(θ, l) = begin
    push!(losses, l)
    if length(losses) % 5 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

## Training -> First shooting / batching to get a rough estimate

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
res1 = DiffEqFlux.sciml_train(shooting_loss, p, ADAM(0.1f0), cb = callback, maxiters = 100)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve partial fit of the data
res2 = DiffEqFlux.sciml_train(shooting_loss, res1.minimizer, BFGS(initial_stepnorm = 0.01f0), cb = callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Full L2-Loss for full prediction
res3 = DiffEqFlux.sciml_train(loss, res2.minimizer, BFGS(initial_stepnorm = 0.01f0), cb = callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Rename the best candidate
p_trained = res3.minimizer
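
# A quick look at the recorded loss trajectory helps judge convergence;
# `losses` was filled by the training callback above.
plot(losses, yaxis = :log10, xlabel = "Iteration", ylabel = "Loss")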

## Analysis of the trained network
# Interpolate the solution
tsample = t[1]:0.5:t[end]
X̂ = predict(p_trained, Xₙ[:, 1], tsample)
# Trained on noisy data vs real solution
plot(t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
plot!(tsample, transpose(X̂), color = :red, label = ["Interpolation" nothing])

# Neural network guess
Ŷ = U(X̂, p_trained[3:end])

scatter(tsample, transpose(Ŷ), xlabel = "t", ylabel = "I1(t), I2(t)", color = :red, label = ["UDE Approximation" nothing])

## Symbolic regression via sparse regression (SINDy based)

# Create a Basis
@variables u[1:2]

# Generate the basis functions: multivariate polynomials up to degree 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)

# Create an optimizer for the SINDy problem
opt = SR3(Float32(1e-2), Float32(1e-2))
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:3))
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
g(x) = x[1] < 1 ? Inf : norm(x, 2)

# Test on UDE derivative data
println("SINDy on learned, partial, available data")
Ψ = SINDy(X̂, Ŷ, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ)
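
# The regression target here is the network output Ŷ as a function of the
# interpolated states X̂, i.e. only the residual interaction terms -- a much
# easier problem than identifying the full dynamics directly from the data.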

# Extract the parameters
p̂ = parameters(Ψ)
println("First parameter guess : $(p̂)")

# Just the equations -> we reiterate on SINDy here,
# searching all linearly independent components again
b = Basis((u, p, t) -> Ψ(u, ones(length(p̂)), t), u, linear_independent = true)
println(b)
# Retune for better parameters -> we could also use DiffEqFlux or other parameter estimation tools here.
opt = SR3(Float32(1e-2), Float32(1e-2))
Ψf = SINDy(X̂, Ŷ, b, opt, maxiter = 10000, normalize = true, convergence_error = eps()) # Succeeds
println(Ψf)
print_equations(Ψf)
p̂ = parameters(Ψf)
println("Second parameter guess : $(p̂)")

# Define the recovered, hybrid model with the rescaled dynamics
function recovered_dynamics!(du, u, p, t)
    û = Ψf(u, p[3:4]) # Recovered model prediction
    du[1] = p[1]*u[1] + û[1]
    du[2] = -p[2]*u[2] + û[2]
end

p_model = [p_trained[1:2]; p̂]
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], tspan, p_model)
estimate = solve(estimation_prob, Tsit5(), saveat = 0.1)

# Plot
plot(t, transpose(Xₙ))
plot!(estimate)

## Simulation

# Look at long term prediction
t_long = (0.0f0, 50.0f0)
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], t_long, p_model)
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.25)
plot(estimate_long)

## Save the results
save("Hudson_Bay_recovery.jld2",
     "X", Xₙ, "t", t, "neural_network", U, "initial_parameters", p, "trained_parameters", p_trained, # Training
     "losses", losses, "result", Ψf, "recovered_parameters", p̂, # Recovery
     "model", recovered_dynamics!, "model_parameter", p_model,
     "long_estimate", estimate_long) # Estimation

LotkaVolterra/hudson_bay_data.dat

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
1900 30 4
1901 47.2 6.1
1902 70.2 9.8
1903 77.4 35.2
1904 36.3 59.4
1905 20.6 41.7
1906 18.1 19
1907 21.4 13
1908 22 8.3
1909 25.4 9.1
1910 27.1 7.4
1911 40.3 8
1912 57 12.3
1913 76.6 19.5
1914 52.3 45.7
1915 19.5 51.1
1916 11.2 29.7
1917 7.6 15.8
1918 14.6 9.7
1919 16.2 10.1
1920 24.7 8.6
