
SEIR Example #44

@ghost

Description

Translation of the SEIR example, based on the Lotka-Volterra 1 example:
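For reference, this is the system that `corona!` in the script implements (state order S, E, I, R, N, D, C). The UDE replaces the β(t)SI/N term with the network output. The NODE replaces the right-hand sides of S, E, I, R and D.

$$
\begin{aligned}
\dot S &= -\beta_0 \frac{S F}{N} - \beta(t)\frac{S I}{N} - \mu S \\
\dot E &= \beta_0 \frac{S F}{N} + \beta(t)\frac{S I}{N} - (\sigma+\mu) E \\
\dot I &= \sigma E - (\gamma+\mu) I \\
\dot R &= \gamma I - \mu R \\
\dot N &= -\mu N \\
\dot D &= d\,\gamma I - \lambda D \\
\dot C &= \sigma E \\
\beta(t) &= \beta_0 (1-\alpha)\left(1 - \frac{D}{N}\right)^{\kappa}
\end{aligned}
$$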

Hiya, ok, here's the first...

  • I can't get it to train only on dE, dI and dR the way the old example did, although the way I did it seems like it should work. (Also, back when this did work on the old example, it seemed to be linearising the other equations in order to get there, so...)
  • It doesn't predict correctly for either the UDE or the NODE. The UDE prediction is linear, which is the same problem I was having before. The NODE prediction is non-linear but incorrect. (Hopefully it's something simple in the ADAM and BFGS setup, because I don't fully understand that part of the code.)
  • I don't know the correct code to extrapolate after training (lines 155-158 and 248-251)
  • I haven't done the SINDy part of the code yet since the approximation doesn't work.
  • Why can't I drop the .jl file in here?
  • The savefig calls don't work in the Lotka-Volterra examples; they can be changed to this format.
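On the first point, one way to train only against E, I and R is to compare just rows 2:4 of the prediction with the data. The row order in `corona!` is S, E, I, R, N, D, C. The following is a minimal sketch on toy matrices. `loss_EIR` and `EIR_ROWS` are hypothetical names, not part of any package.

```julia
# Sketch: L2 loss restricted to the E, I, R rows (2:4) of a 7×T state matrix.
# Row order follows corona!: S, E, I, R, N, D, C.
const EIR_ROWS = 2:4

# Hypothetical helper; in the script, X̂ would be predict(θ)
loss_EIR(data, X̂) = sum(abs2, view(data, EIR_ROWS, :) .- view(X̂, EIR_ROWS, :))

# Toy example: large errors in S (row 1) and N (row 5), small error in I (row 3)
data = reshape(collect(1.0:21.0), 7, 3)
X̂ = copy(data)
X̂[1, :] .+= 100.0
X̂[5, :] .-= 50.0
X̂[3, 2] += 2.0

full_loss = sum(abs2, data .- X̂)   # dominated by the S and N errors
eir_loss = loss_EIR(data, X̂)       # only sees the error in I

println("full loss: ", full_loss)
println("E, I, R loss: ", eir_loss)
```

In the script, this would amount to `loss(θ) = loss_EIR(noisy_data, predict(θ))`. With this loss, the S, N, D and C rows are no longer penalised at all, so they may drift. Whether that is acceptable depends on what the model is meant to recover.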
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()

# Single experiment, move to ensemble further on
# Some good parameter values are stored as comments right now
# because this is really good practice
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra 
using SciMLSensitivity
using Random
using Optimization, OptimizationFlux, OptimizationOptimJL #OptimizationFlux for ADAM and OptimizationOptimJL for BFGS
using Lux
using Statistics
using Plots
gr()
#using DiffEqSensitivity, Optim
#using DiffEqFlux, Flux

function corona!(du,u,p,t)
    S,E,I,R,N,D,C = u
    F, β0,α,κ,μ,σ,γ,d,λ = p
    dS = -β0*S*F/N - β(t,β0,D,N,κ,α)*S*I/N -μ*S # susceptible
    dE = β0*S*F/N + β(t,β0,D,N,κ,α)*S*I/N -(σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # +cumulative cases

    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end

β(t,β0,D,N,κ,α) = β0*(1-α)*(1-D/N)^κ
S0 = 14e6
u0 = [0.9*S0, 0.0, 0.0, 0.0, S0, 0.0, 0.0]
p_ = [10.0, 0.5944, 0.4239, 1117.3, 0.02, 1/3, 1/5,0.2, 1/11.2]
R0 = p_[2]/p_[7]*p_[6]/(p_[6]+p_[5])
tspan = (0.0, 21.0)
prob = ODEProblem(corona!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
t = solution.t
#[2:4] are Exposed, Infected, Removed 
X = Array(solution[2:4,:])'
plot(X)

#Extrapolate to a longer timespan
tspan2 = (0.0,60.0)
prob = ODEProblem(corona!, u0, tspan2, p_)
solution_extrapolate = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
extrapolate = Array(solution_extrapolate[2:4,:])'
plot(extrapolate)

# Ideal data
tsdata = Array(solution)

# Add noise to the data
noisy_data = tsdata + Float32(1e-5)*randn(eltype(tsdata), size(tsdata))
# You can see that the noise looks random
plot(abs.(tsdata-noisy_data)')

### Neural ODE
#Predicts for unknown equations
rng = Random.default_rng()
Random.seed!(111)

#7 inputs for 7 equations, 5 outputs because we know 2 equations already
U = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 5))

# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

function coronaNODE(du,u,p,t,p_)
    û = U(u, p, st)[1] # Network prediction
    S,E,I,R,N,D,C = u
    μ, σ = p_[5], p_[6] # p_ = [F, β0, α, κ, μ, σ, γ, d, λ]; `μ,σ = p_` would bind F and β0
    dS = û[1]
    dE = û[2]
    dI = û[3]
    dR = û[4]
    dN = -μ*N # total population
    dD = û[5]
    dC = σ*E # +cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end

# Closure with the known parameters
NODE_dynamics!(du,u,p,t) = coronaNODE(du,u,p,t,p_)
# Define the problem
prob_node = ODEProblem(NODE_dynamics!, u0, tspan, p)

## Function to train the network
# Define a predictor
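function predict(θ, X = noisy_data[:,1], T = t)
    # remake so the time span follows T; with the original tspan = (0, 21),
    # saveat points past t = 21 (e.g. for extrapolation) are never reached
    _prob = remake(prob_node, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())))
end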

# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, noisy_data .- X̂)
end


# Container to track the losses
losses = Float32[]

callback = function (p, l)
  push!(losses, l)
  if length(losses)%50==0
      println("Current loss after $(length(losses)) iterations: $(losses[end])")
  end
  return false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "pl_lossesNODE.png")

# Rename the best candidate
p_trained = res2.minimizer

## Analysis of the trained network
# Plot the data and the approximation

# Make the prediction to match solution.t
X̂ = predict(p_trained, noisy_data[:,1], t)
# Prediction trained on noisy data vs the noisy measurements
pl_trajectory = plot(t, transpose(X̂[2:4,:]), xlabel = "t", ylabel ="E(t), I(t), R(t)", color = :red, label = ["NODE Approximation" nothing nothing])
scatter!(solution.t, transpose(noisy_data[2:4,:]), color = :black, label = ["Measurements" nothing nothing])
savefig(pl_trajectory, "plots_trajectory_reconstructionNODE.png")

#Extrapolate the solution to match tspan2
ExtrapolateX̂ = predict(p_trained, noisy_data[:,1], solution_extrapolate.t)
extrapolate_trajectory = plot(solution_extrapolate.t, transpose(ExtrapolateX̂[2:4,:]), xlabel = "t", ylabel ="E(t), I(t), R(t)", color = :red, label = ["NODE Approximation" nothing nothing])
scatter!(solution_extrapolate.t, transpose(solution_extrapolate[2:4,:]), color = :black, label = ["True solution" nothing nothing])
savefig(extrapolate_trajectory, "ExtrapolateNODE.png")

### Universal ODE
## Prediction for the missing interaction term β(t)*S*I/N
rng = Random.default_rng()
Random.seed!(222)

#7 inputs for 7 equations, 1 output for 1 missing part of the equation
U = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 1))

# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

function coronaUDE(du,u,p,t,p_true)
    û = U(u, p, st)[1] # Network prediction
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_true # known parameters passed in through the closure
    dS = -β0*S*F/N - û[1] -μ*S # susceptible
    dE = β0*S*F/N + û[1] -(σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # +cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end

# Closure with the known parameters
UDE_dynamics!(du,u,p,t) = coronaUDE(du,u,p,t,p_)
# Define the problem
prob_ude = ODEProblem(UDE_dynamics!, u0, tspan, p)

## Function to train the network
# Define a predictor
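function predict(θ, X = noisy_data[:,1], T = t)
    # remake so the time span follows T; with the original tspan = (0, 21),
    # saveat points past t = 21 (e.g. for extrapolation) are never reached
    _prob = remake(prob_ude, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())))
end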

# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, noisy_data .- X̂)
end


# Container to track the losses
losses = Float32[]

callback = function (p, l)
  push!(losses, l)
  if length(losses)%50==0
      println("Current loss after $(length(losses)) iterations: $(losses[end])")
  end
  return false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1UDE = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob2 = Optimization.OptimizationProblem(optf, res1UDE.minimizer) # res1 holds the NODE parameters
res2UDE = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "plot_lossesUDE.png")
# Rename the best candidate
p_trained = res2UDE.minimizer

## Analysis of the trained network
# Plot the data and the approximation
X̂ = predict(p_trained, noisy_data[:,1], t)
# Trained on noisy data vs real solution
pl_trajectory = plot(t, transpose(X̂[2:4,:]), xlabel = "t", ylabel ="E(t), I(t), R(t)", color = :red, label = ["UDE Approximation" nothing nothing])
scatter!(solution.t, transpose(solution[2:4,:]), color = :black, label = ["True solution" nothing nothing])
savefig(pl_trajectory, "plot_trajectory_reconstructionUDE.png")

# Extrapolate out
ExtrapolateX̂ = predict(p_trained, noisy_data[:,1], solution_extrapolate.t)
extrapolate_trajectory = plot(solution_extrapolate.t, transpose(ExtrapolateX̂[2:4,:]), xlabel = "t", ylabel ="E(t), I(t), R(t)", color = :red, label = ["UDE Approximation" nothing nothing])
scatter!(solution_extrapolate.t, transpose(solution_extrapolate[2:4,:]), color = :black, label = ["True solution" nothing nothing])
savefig(extrapolate_trajectory, "ExtrapolateUDE.png")
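One likely contributor to the linear UDE prediction is the following; this is an assumption, not verified against this script. `U` is fed the raw state, whose S and N entries are about 1e7. The first `tanh` layer would then saturate at ±1, and its output would hardly change as E and I change. That would leave the learned term nearly constant. The sketch below uses a hand-written two-unit layer with made-up weights (plain Julia, not Lux). It compares raw inputs with inputs divided by an assumed per-compartment scale.

```julia
# Sketch with illustrative numbers, not the script's network:
# a tanh layer saturates when fed raw SEIR states of order 1e7.

S0 = 14e6
# Two states that differ only in E and I (order: S, E, I, R, N, D, C)
u_a = [0.9S0, 10.0, 5.0, 0.0, S0, 0.0, 0.0]
u_b = [0.9S0, 500.0, 300.0, 0.0, S0, 0.0, 0.0]

# Hypothetical 2-unit tanh layer with fixed, made-up weights
W = [ 0.5  0.1  0.1  0.0  -0.3   0.0  0.0;
     -0.2  0.3  0.0  0.1   0.25  0.0  0.0]
layer(x) = tanh.(W * x)

# Raw inputs: pre-activations are ~1e6, so tanh returns exactly 1.0 for both states
raw_diff = maximum(abs.(layer(u_a) .- layer(u_b)))

# Per-state scaling by an assumed typical magnitude of each compartment
scale = [S0, 1e3, 1e3, 1e3, S0, 1e3, 1e3]
scaled_diff = maximum(abs.(layer(u_a ./ scale) .- layer(u_b ./ scale)))

println("max hidden-layer change, raw inputs:    ", raw_diff)
println("max hidden-layer change, scaled inputs: ", scaled_diff)
```

In the script, this could be applied inside `coronaUDE` (and `coronaNODE`) as `û = U(u ./ u_scale, p, st)[1]`. Here `u_scale` is a hypothetical constant vector, such as the per-row maximum of `noisy_data`. The network output may also need rescaling, depending on the size of β(t)·S·I/N relative to the values a freshly initialised network produces.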
