-
Notifications
You must be signed in to change notification settings - Fork 60
Open
Description
I couldn't find an example of multiple shooting in SciMLSensitivity, sorry - it's coming up with the error "UndefVarError: multiple_shoot not defined" when I try to run ADAM.
I also tried FENEP but it seems like there's some issue with Tracked Arrays, here: SciML/SciMLSensitivity.jl#609
And also with Flux, because the original doesn't work with it. I get the error: UndefVarError: TrackedArray not defined
So i'll leave this one.
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using SciMLSensitivity
using Random
using Optimization, OptimizationOptimisers, OptimizationOptimJL #OptimizationOptimisers for ADAM and OptimizationOptimJL for BFGS
using Lux
using Statistics
using JLD2, FileIO
using DelimitedFiles
using Plots
gr()
Random.seed!(5443)
#### NOTE
# Since the recent release of DataDrivenDiffEq v0.6.0 where a complete overhaul of the optimizers took
# place, SR3 has been used. Right now, STLSQ performs better and has been changed.
# Additionally, the behaviour of the optimization has changed slightly. This has been adjusted
# by decreasing the tolerance of the gradient.
svname = "HudsonBay"
## Data Preprocessing
# The data has been taken from https://jmahaffy.sdsu.edu/courses/f00/math122/labs/labj/q3v1.htm
# Originally published in E. P. Odum (1953), Fundamentals of Ecology, Philadelphia, W. B. Saunders
hudson_bay_data = readdlm("hudson_bay_data.dat", '\t', Float32, '\n')
# Measurements of prey and predator
Xₙ = Matrix(transpose(hudson_bay_data[:, 2:3]))
t = hudson_bay_data[:, 1] .- hudson_bay_data[1, 1]
# Normalize the data; since the data domain is strictly positive
# we just need to divide by the maximum
xscale = maximum(Xₙ, dims =2)
Xₙ .= 1f0 ./ xscale .* Xₙ
# Time from 0 -> n
tspan = (t[1], t[end])
# Plot the data
scatter(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")
plot!(t, transpose(Xₙ), xlabel = "t", ylabel = "x(t), y(t)")
## Direct Identification via SINDy + Collocation
# Create the problem using a gaussian kernel for collocation
full_problem = ContinuousDataDrivenProblem(Xₙ, t, DataDrivenDiffEq.GaussianKernel())
# Look at the collocation
plot(full_problem.t, full_problem.X')
plot(full_problem.t, full_problem.DX')
# Create a Basis
@variables u[1:2]
# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b, u)
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:5))
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)
# Best result so far
full_res = solve(full_problem, basis, opt, maxiter = 10000, progress = true, denoise = true, normalize = true)
println(full_res)
println(result(full_res))
## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))
# Define the network 2->5->5->5->2
U = Lux.Chain(
Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, tanh), Lux.Dense(5,2)
)
# Get the initial parameters, first two is linear birth / decay of prey and predator
rng = Random.default_rng()
p, st = Lux.setup(rng, U)
# Define the hybrid model
function ude_dynamics!(du,u, p, t)
û = U(u, p[3:end]) # Network prediction
# We assume a linear birth rate for the prey
du[1] = p[1]*u[1] + û[1]
# We assume a linear decay rate for the predator
du[2] = -p[2]*u[2] + û[2]
end
# Define the problem
prob_nn = ODEProblem(ude_dynamics!,Xₙ[:, 1], tspan, p)
## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
tspan = (T[1], T[end]), saveat = T,
abstol=1e-6, reltol=1e-6,
sensealg = ForwardDiffSensitivity()
))
end
# Define parameters for Multiple Shooting
group_size = 5
continuity_term = 200.0f0
function loss(data, pred)
return sum(abs2, data - pred)
end
function shooting_loss(p)
return multiple_shoot(p, Xₙ, t, prob_nn, loss, Vern7(),
group_size; continuity_term)
end
function loss(θ)
X̂ = predict(θ)
sum(abs2, Xₙ - X̂) / size(Xₙ, 2) + convert(eltype(θ), 1e-3)*sum(abs2, θ[3:end]) ./ length(θ[3:end])
end
# Container to track the losses
losses = Float32[]
# Callback to show the loss during training
callback(θ,args...) = begin
l = loss(θ) # Equivalent L2 loss
push!(losses, l)
if length(losses)%5==0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
false
end
## Training -> First shooting / batching to get a rough estimate
# First train with ADAM for better convergence -> move the parameters into a
# favourable starting positing for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->shooting_loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.1f0), cb=callback, maxiters = 100)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve partial fit of the data
optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 500)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Full L2-Loss for full prediction
optf2 = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob2 = Optimization.OptimizationProblem(optf2, res2.minimizer)
res3 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
pl_losses = plot(1:101, losses[1:101], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM (Shooting)", color = :blue)
plot!(102:302, losses[102:302], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS (Shooting)", color = :red)
plot!(302:length(losses), losses[302:end], color = :black, label = "BFGS (L2)")
savefig(pl_losses, "plot_losses.png"))
Metadata
Metadata
Assignees
Labels
No labels