## Environment and packages
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()
using JLD2, FileIO
using Statistics
using DelimitedFiles
# Set a random seed for reproducible behaviour
using Random
Random.seed!(5443)

## Data Preprocessing
# The data has been taken from https://jmahaffy.sdsu.edu/courses/f00/math122/labs/labj/q3v1.htm
# Originally published in
hudson_bay_data = readdlm("hudson_bay_data.dat", '\t', Float32, '\n')
# Time from 0 -> n
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
tspan = (t[1], t[end])
# Measurements of prey and predator
Xₙ = Matrix(transpose(hudson_bay_data[:, 2:3]))
# Normalize the data; since the data domain is strictly positive
# we just need to divide by the maximum
xscale = maximum(Xₙ, dims = 2)
Xₙ .= 1f0 ./ xscale .* Xₙ

# Plot the data
scatter(t, transpose(Xₙ), xlabel = "t [a]", ylabel = "x(t), y(t)")
plot!(t, transpose(Xₙ))
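# Optional sanity checks (not part of the original workflow): one time stamp
# per measurement column, and the normalization maps the maximum to one.
@assert length(t) == size(Xₙ, 2)
@assert maximum(Xₙ) ≈ 1f0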

## Direct Identification via SINDy + Collocation
# Estimate the derivatives via collocation with a Gaussian kernel
dx̂, x̂ = collocate_data(Xₙ, t, GaussianKernel())
# Look at the estimated derivatives
plot(t, dx̂')
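# Optional visual check (a sketch, not in the original script): the smoothed
# states x̂ returned by the collocation should closely track the measurements.
plot(t, x̂', label = ["x̂(t)" "ŷ(t)"])
scatter!(t, Xₙ', label = ["x(t)" "y(t)"])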
# Perform SINDy on the collocation estimates

# Create a Basis
@variables u[1:2]

# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)
# Create an optimizer for the SINDy problem
opt = SR3(Float32(1e-2), Float32(1e-2))
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:3))
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
g(x) = x[1] < 1 ? Inf : norm(x, 2)
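# How the selection works: for each threshold in λ, the candidate model is
# scored by g([sparsity, error]); models with no active terms are rejected via
# Inf, e.g. g([0, 0.5]) == Inf while g([3, 0.1]) == norm([3, 0.1]).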
# Test on derivative data
Ψ = SINDy(x̂, dx̂, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ)
# Reiterate on the linearly independent components of the first result
b2 = Basis((u, p, t) -> Ψ(u, ones(length(parameters(Ψ))), t), u, linear_independent = true)
Ψ = SINDy(x̂, dx̂, b2, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ)
parameters(Ψ)
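# Note: this direct SINDy fit on the noisy collocation derivatives serves as a
# baseline; the UDE approach below learns only the unknown interaction terms.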
## UDE Approach
# Subsample the data in y -> initial fitting strategy (batching)
# We assume y is measured only at 6 evenly spaced time points, giving 5 shooting windows
ty = collect(t[1]:Float32(t[end]/5):t[end])
# Create datasets for the different measurements
XS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # All x data
TS = zeros(Float32, length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # Time data
YS = zeros(Float32, length(ty)-1, 2) # Just the two boundary measurements in y

for i in 1:length(ty)-1
    idxs = ty[i] .<= t .<= ty[i+1]
    XS[i, :] = Xₙ[1, idxs]
    TS[i, :] = t[idxs]
    YS[i, :] = [Xₙ[2, t .== ty[i]]; Xₙ[2, t .== ty[i+1]]]
end
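# Optional shape check (not in the original script): every window must contain
# the same number of x samples for the rectangular storage above to hold.
@assert size(XS) == size(TS) && size(YS) == (size(XS, 1), 2)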

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Define the network 2->5->5->5->2
U = FastChain(
    FastDense(2, 5, rbf), FastDense(5, 5, rbf), FastDense(5, 5, tanh), FastDense(5, 2)
)

# Get the initial parameters; the first two are the linear birth / decay rates
# of prey and predator, the rest are the network weights
p = [rand(Float32, 2); initial_params(U)]
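# Optional smoke test (added sketch): the network maps the 2 states to 2
# interaction terms.
@assert length(U(Xₙ[:, 1], p[3:end])) == 2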

# Define the hybrid model
function ude_dynamics!(du, u, p, t)
    û = U(u, p[3:end]) # Network prediction
    # We assume a linear birth rate for the prey
    du[1] = p[1]*u[1] + û[1]
    # We assume a linear decay rate for the predator
    du[2] = -p[2]*u[2] + û[2]
end

# Define the problem
prob_nn = ODEProblem(ude_dynamics!, Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:, 1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p = θ,
                tspan = (T[1], T[end]), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end
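# Usage note: `predict(θ)` simulates the full time span, while
# `predict(θ, [x0, y0], Ts)` restarts from a window's boundary values; the
# shooting loss below relies on the latter.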

# Multiple-shooting-like loss
function shooting_loss(θ)
    # Start with a regularization on the network
    l = convert(eltype(θ), 1e-3) * sum(abs2, θ[3:end]) / length(θ[3:end])
    for i in 1:size(XS, 1)
        X̂ = predict(θ, [XS[i, 1], YS[i, 1]], TS[i, :])
        # Full prediction in x
        l += sum(abs2, XS[i, :] .- X̂[1, :])
        # Add the boundary condition in y
        l += abs2(YS[i, 2] - X̂[2, end])
    end

    return l
end
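# Rationale: fitting many short windows first sidesteps the poor gradients of a
# long-horizon fit with a badly initialized network; the full trajectory loss
# below then refines the result.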

function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ - X̂) + convert(eltype(θ), 1e-3) * sum(abs2, θ[3:end]) / length(θ[3:end])
end

# Container to track the losses
losses = Float32[]

# Callback to show the loss during training
callback(θ, l) = begin
    push!(losses, l)
    if length(losses) % 5 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

## Training -> First shooting / batching to get a rough estimate

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
res1 = DiffEqFlux.sciml_train(shooting_loss, p, ADAM(0.1f0), cb = callback, maxiters = 100)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve a partial fit of the data
res2 = DiffEqFlux.sciml_train(shooting_loss, res1.minimizer, BFGS(initial_stepnorm = 0.01f0), cb = callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Full L2 loss for the full prediction
res3 = DiffEqFlux.sciml_train(loss, res2.minimizer, BFGS(initial_stepnorm = 0.01f0), cb = callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
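# Optional (not in the original script): inspect the optimization progress on a
# log scale.
plot(losses, yaxis = :log10, xlabel = "Iteration", ylabel = "Loss")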

# Rename the best candidate
p_trained = res3.minimizer

## Analysis of the trained network
# Interpolate the solution
tsample = t[1]:0.5f0:t[end]
X̂ = predict(p_trained, Xₙ[:, 1], tsample)
# Measurements vs trained interpolation
plot(t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
plot!(tsample, transpose(X̂), color = :red, label = ["Interpolation" nothing])

# Neural network guess of the interaction terms
Ŷ = U(X̂, p_trained[3:end])

scatter(tsample, transpose(Ŷ), xlabel = "t", ylabel = "I1(t), I2(t)", color = :red, label = ["UDE Approximation" nothing])
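# Note: Ŷ holds the learned interaction terms I1, I2 evaluated along the
# interpolated trajectory; these serve as targets for the sparse regression below.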

## Symbolic regression via sparse regression (SINDy based)

# Create a Basis
@variables u[1:2]

# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)

# Create an optimizer for the SINDy problem
opt = SR3(Float32(1e-2), Float32(1e-2))
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:3))
# Target function to choose the results from; x = L0 of coefficients and L2-Error of the model
g(x) = x[1] < 1 ? Inf : norm(x, 2)

# Test on the UDE-derived data
println("SINDy on learned, partial, available data")
Ψ = SINDy(X̂, Ŷ, basis, λ, opt, g = g, maxiter = 50000, normalize = true, denoise = true)
println(Ψ)
print_equations(Ψ)

# Extract the parameters
p̂ = parameters(Ψ)
println("First parameter guess : $(p̂)")

# Just the equations -> we reiterate on SINDy here,
# searching all linearly independent components again
b = Basis((u, p, t) -> Ψ(u, ones(length(p̂)), t), u, linear_independent = true)
println(b)
# Retune for better parameters -> we could also use DiffEqFlux or other parameter estimation tools here.
opt = SR3(Float32(1e-2), Float32(1e-2))
Ψf = SINDy(X̂, Ŷ, b, opt, maxiter = 10000, normalize = true, convergence_error = eps())
println(Ψf)
print_equations(Ψf)
p̂ = parameters(Ψf)
println("Second parameter guess : $(p̂)")

# Define the recovered, hybrid model with the rescaled dynamics
function recovered_dynamics!(du, u, p, t)
    û = Ψf(u, p[3:end]) # Recovered interaction terms
    du[1] = p[1]*u[1] + û[1]
    du[2] = -p[2]*u[2] + û[2]
end

p_model = [p_trained[1:2]; p̂]
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], tspan, p_model)
estimate = solve(estimation_prob, Tsit5(), saveat = 0.1)

# Plot the recovered model against the data
plot(t, transpose(Xₙ))
plot!(estimate)

## Simulation

# Look at the long-term prediction
t_long = (0.0f0, 50.0f0)
estimation_prob = ODEProblem(recovered_dynamics!, Xₙ[:, 1], t_long, p_model)
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.25)
plot(estimate_long)
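# Note: if the recovery worked, the long-term simulation should remain bounded
# and periodic rather than decaying or blowing up.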

## Save the results
save("Hudson_Bay_recovery.jld2",
     "X", Xₙ, "t", t, "neural_network", U, "initial_parameters", p, "trained_parameters", p_trained, # Training
     "losses", losses, "result", Ψf, "recovered_parameters", p̂, # Recovery
     "model", recovered_dynamics!, "model_parameter", p_model,
     "long_estimate", estimate_long) # Estimation