Description
I'm looking for an example of how to use automatic differentiation with a more complex ODE than the vanilla examples in the tutorials. Simple ODEs are easy to figure out, but I always run into the `Dual` number error with complex ODE models. It seems my ODE functions are never generic enough, and I'd like to learn how to make them work.
I tried modifying the Beeler-Reuter model (the CPU version) from the advanced tutorial, because the original code doesn't work with `autodiff=true`: it only accepts `Float` element types. I finally managed to make it work, but there are several things I don't understand. Here is the modified working code:
using DifferentialEquations, Sundials, ForwardDiff
const v0 = -84.624
const v1 = 10.0
const C_K1 = 1.0f0
const C_x1 = 1.0f0
const C_Na = 1.0f0
const C_s = 1.0f0
const D_Ca = 0.0f0
const D_Na = 0.0f0
const g_s = 0.09f0
const g_Na = 4.0f0
const g_NaC = 0.005f0
const ENa = 50.0f0 + D_Na
const γ = 0.5f0
const C_m = 1.0f0
# make this parametric
mutable struct BeelerReuterCpu{T} <: Function
t::T # the last timestep time to calculate Δt
diff_coef::T # the diffusion-coefficient (coupling strength)
C::Array{T,2} # intracellular calcium concentration
M::Array{T,2} # sodium current activation gate (m)
H::Array{T,2} # sodium current inactivation gate (h)
J::Array{T,2} # sodium current slow inactivation gate (j)
D::Array{T,2} # calcium current activation gate (d)
F::Array{T,2} # calcium current inactivation gate (f)
XI::Array{T,2} # inward-rectifying potassium current (iK1)
Δu::Array{T,2} # place-holder for the Laplacian
function BeelerReuterCpu(u0::Array{T,2}, diff_coef::T) where {T}
self = new{T}()
ny, nx = size(u0)
self.t = 0.0
self.diff_coef = diff_coef
self.C = fill(0.0001f0, (ny, nx))
self.M = fill(0.01f0, (ny, nx))
self.H = fill(0.988f0, (ny, nx))
self.J = fill(0.975f0, (ny, nx))
self.D = fill(0.003f0, (ny, nx))
self.F = fill(0.994f0, (ny, nx))
self.XI = fill(0.0001f0, (ny, nx))
self.Δu = zeros(ny, nx)
return self
end
end
# 5-point stencil
function laplacian(Δu, u)
n1, n2 = size(u)
# internal nodes
for j = 2:n2-1
for i = 2:n1-1
@inbounds Δu[i, j] = u[i+1, j] + u[i-1, j] + u[i, j+1] + u[i, j-1] - 4 * u[i, j]
end
end
# left/right edges
for i = 2:n1-1
@inbounds Δu[i, 1] = u[i+1, 1] + u[i-1, 1] + 2 * u[i, 2] - 4 * u[i, 1]
@inbounds Δu[i, n2] = u[i+1, n2] + u[i-1, n2] + 2 * u[i, n2-1] - 4 * u[i, n2]
end
# top/bottom edges
for j = 2:n2-1
@inbounds Δu[1, j] = u[1, j+1] + u[1, j-1] + 2 * u[2, j] - 4 * u[1, j]
@inbounds Δu[n1, j] = u[n1, j+1] + u[n1, j-1] + 2 * u[n1-1, j] - 4 * u[n1, j]
end
# corners
@inbounds Δu[1, 1] = 2 * (u[2, 1] + u[1, 2]) - 4 * u[1, 1]
@inbounds Δu[n1, 1] = 2 * (u[n1-1, 1] + u[n1, 2]) - 4 * u[n1, 1]
@inbounds Δu[1, n2] = 2 * (u[2, n2] + u[1, n2-1]) - 4 * u[1, n2]
@inbounds Δu[n1, n2] = 2 * (u[n1-1, n2] + u[n1, n2-1]) - 4 * u[n1, n2]
end
@inline function rush_larsen(g, α, β, Δt)
inf = α / (α + β)
τ = 1.0f0 / (α + β)
return clamp(g + (g - inf) * expm1(-Δt / τ), 0.0f0, 1.0f0)
end
function update_M_cpu(g, v, Δt)
# the condition is needed here to prevent NaN (0/0) when v == -47.0; the limit of α there is 10.0
α = isapprox(v, -47.0f0) ? 10.0f0 : -(v + 47.0f0) / (exp(-0.1f0 * (v + 47.0f0)) - 1.0f0)
β = (40.0f0 * exp(-0.056f0 * (v + 72.0f0)))
return rush_larsen(g, α, β, Δt)
end
function update_H_cpu(g, v, Δt)
α = 0.126f0 * exp(-0.25f0 * (v + 77.0f0))
β = 1.7f0 / (exp(-0.082f0 * (v + 22.5f0)) + 1.0f0)
return rush_larsen(g, α, β, Δt)
end
function update_J_cpu(g, v, Δt)
α = (0.55f0 * exp(-0.25f0 * (v + 78.0f0))) / (exp(-0.2f0 * (v + 78.0f0)) + 1.0f0)
β = 0.3f0 / (exp(-0.1f0 * (v + 32.0f0)) + 1.0f0)
return rush_larsen(g, α, β, Δt)
end
function update_D_cpu(g, v, Δt)
α = γ * (0.095f0 * exp(-0.01f0 * (v - 5.0f0))) / (exp(-0.072f0 * (v - 5.0f0)) + 1.0f0)
β = γ * (0.07f0 * exp(-0.017f0 * (v + 44.0f0))) / (exp(0.05f0 * (v + 44.0f0)) + 1.0f0)
return rush_larsen(g, α, β, Δt)
end
function update_F_cpu(g, v, Δt)
α = γ * (0.012f0 * exp(-0.008f0 * (v + 28.0f0))) / (exp(0.15f0 * (v + 28.0f0)) + 1.0f0)
β = γ * (0.0065f0 * exp(-0.02f0 * (v + 30.0f0))) / (exp(-0.2f0 * (v + 30.0f0)) + 1.0f0)
return rush_larsen(g, α, β, Δt)
end
function update_XI_cpu(g, v, Δt)
α = (0.0005f0 * exp(0.083f0 * (v + 50.0f0))) / (exp(0.057f0 * (v + 50.0f0)) + 1.0f0)
β = (0.0013f0 * exp(-0.06f0 * (v + 20.0f0))) / (exp(-0.04f0 * (v + 20.0f0)) + 1.0f0)
return rush_larsen(g, α, β, Δt)
end
function update_C_cpu(g, d, f, v, Δt)
ECa = D_Ca - 82.3f0 - 13.0278f0 * log(g)
kCa = C_s * g_s * d * f
iCa = kCa * (v - ECa)
inf = 1.0f-7 * (0.07f0 - g)
τ = 1.0f0 / 0.07f0
return g + (g - inf) * expm1(-Δt / τ)
end
function update_gates_cpu(u, XI, M, H, J, D, F, C, Δt)
# let Δt = Float32(Δt) # remove the Let
n1, n2 = size(u)
for j = 1:n2
for i = 1:n1
v = u[i, j]
XI[i, j] = update_XI_cpu(XI[i, j], v, Δt)
M[i, j] = update_M_cpu(M[i, j], v, Δt)
H[i, j] = update_H_cpu(H[i, j], v, Δt)
J[i, j] = update_J_cpu(J[i, j], v, Δt)
D[i, j] = update_D_cpu(D[i, j], v, Δt)
F[i, j] = update_F_cpu(F[i, j], v, Δt)
C[i, j] = update_C_cpu(C[i, j], D[i, j], F[i, j], v, Δt)
end
end
# end
end
# iK1 is the inward-rectifying potassium current
function calc_iK1(v)
ea = exp(0.04f0 * (v + 85.0f0))
eb = exp(0.08f0 * (v + 53.0f0))
ec = exp(0.04f0 * (v + 53.0f0))
ed = exp(-0.04f0 * (v + 23.0f0))
return 0.35f0 * (
4.0f0 * (ea - 1.0f0) / (eb + ec) +
0.2f0 * (isapprox(v, -23.0f0) ? 25.0f0 : (v + 23.0f0) / (1.0f0 - ed))
)
end
# ix1 is the time-independent background potassium current
function calc_ix1(v, xi)
ea = exp(0.04f0 * (v + 77.0f0))
eb = exp(0.04f0 * (v + 35.0f0))
return xi * 0.8f0 * (ea - 1.0f0) / eb
end
# iNa is the sodium current (similar to the classic Hodgkin-Huxley model)
function calc_iNa(v, m, h, j)
return C_Na * (g_Na * m^3 * h * j + g_NaC) * (v - ENa)
end
# iCa is the calcium current
function calc_iCa(v, d, f, c)
ECa = D_Ca - 82.3f0 - 13.0278f0 * log(c) # ECa is the calcium reversal potential
return C_s * g_s * d * f * (v - ECa)
end
function update_du_cpu(du, u, XI, M, H, J, D, F, C)
n1, n2 = size(u)
for j = 1:n2
for i = 1:n1
v = u[i, j]
# calculating individual currents
iK1 = calc_iK1(v)
ix1 = calc_ix1(v, XI[i, j])
iNa = calc_iNa(v, M[i, j], H[i, j], J[i, j])
iCa = calc_iCa(v, D[i, j], F[i, j], C[i, j])
# total current
I_sum = iK1 + ix1 + iNa + iCa
# the reaction part of the reaction-diffusion equation
du[i, j] = -I_sum / C_m
end
end
end
function (f::BeelerReuterCpu)(du, u, p, t)
Δt = t - f.t
if Δt != 0 || t == 0
update_gates_cpu(
u,
eltype(u).(f.XI), # type conversion
eltype(u).(f.M),
eltype(u).(f.H),
eltype(u).(f.J),
eltype(u).(f.D),
eltype(u).(f.F),
eltype(u).(f.C),
eltype(u)(Δt),
)
f.t = t
end
laplacian(eltype(u).(f.Δu), u) # type conversion
# calculate the reaction portion
update_du_cpu(du, u, f.XI, f.M, f.H, f.J, f.D, f.F, f.C)
# ...add the diffusion portion
du .+= f.diff_coef .* f.Δu
du
end
N = 10;
u0 = fill(v0, (N, N));
u0[5:6,5:6] .= v1; # a small square in the middle of the domain
using Plots
heatmap(u0)
deriv_cpu = BeelerReuterCpu(u0, 1.0);
prob = ODEProblem(deriv_cpu, u0, (0.0, 5.0))
# Lower tolerances to show the methods converge to the same value
@time sol=solve(prob, TRBDF2(autodiff = true), saveat = 100.0)
The main changes are:
- Make the `struct BeelerReuterCpu` parametric.
- In the `(f::BeelerReuterCpu)(du, u, p, t)` function, when calling the `update_gates_cpu` and `laplacian` functions, the arguments are converted to the element type of `u`, e.g. `eltype(u).(f.Δu)`.
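As a side note on the second change, here is a minimal sketch in plain Julia of what a broadcasted conversion such as `eltype(u).(f.Δu)` does. Broadcasting a type over an array allocates a new array, so a function that writes into the converted argument does not modify the original. The names `A`, `B`, and `fill_ones!` are made up for illustration and are not part of the model.

```julia
# Hypothetical helper that writes into its argument, the way `laplacian` writes into `Δu`.
function fill_ones!(dest)
    dest .= 1
    return dest
end

A = zeros(2, 2)      # stands in for a struct field such as `f.Δu`
B = Float64.(A)      # broadcasted conversion: a new array, even though the eltype is unchanged

fill_ones!(B)        # mutates the copy only

@show B              # all ones
@show A              # still all zeros
@show B === A        # false: two distinct arrays
```

If the same reasoning applies to the code above, `laplacian` and `update_gates_cpu` would be filling temporaries rather than the fields stored in `f`, which may be worth checking.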
I made these changes because otherwise I get the `Dual` number error:
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UJacobianWrapper{ODEFunction{true, BeelerReuterCpu{Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Float64, SciMLBase.NullParameters}, Float64}, Float64, 12})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
(::Type{T})(::T) where T<:Number at boot.jl:760
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
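For reference, the same `MethodError` can be reproduced outside the ODE solver. The sketch below assumes only `ForwardDiff`; `buf`, `reads_from_buffer`, and `writes_to_buffer` are hypothetical names. It shows that reading from a `Float64` array while differentiating is fine, whereas storing a `Dual` into a `Float64` array triggers the conversion that fails.

```julia
using ForwardDiff

# Float64 storage, standing in for a preallocated field of the struct.
const buf = [3.0]

# Reading is fine: Float64 * Dual promotes to Dual.
reads_from_buffer(x) = buf[1] * x + x^2

# Writing is not: `buf[1] = ...` has to convert the Dual to Float64.
function writes_to_buffer(x)
    buf[1] = 2x
    return buf[1]
end

d = ForwardDiff.derivative(reads_from_buffer, 1.0)   # derivative of 3x + x^2 at x = 1

failed = try
    ForwardDiff.derivative(writes_to_buffer, 1.0)
    false
catch err
    err isa MethodError    # "no method matching Float64(::ForwardDiff.Dual{...})"
end

@show d failed
```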
My questions are:
- Why is type conversion needed when calling `update_gates_cpu` and `laplacian`, but not `update_du_cpu`, inside `(f::BeelerReuterCpu)(du, u, p, t)`?
- I used a parametric type in `BeelerReuterCpu` to make the types consistent between `u` and the fields of `BeelerReuterCpu`. I assumed that was necessary for the `Dual` numbers to work, but obviously it's not enough. I suppose that's because the `struct` is not created dynamically? Is there a way to do that?
- The documentation says that `DiffEqBase.dualcache` is needed to avoid the `Dual` number error when using a cache. The cache example given there is very straightforward, but I wonder what exactly counts as a cache. Does it include any intermediate calculation variables?
I hope there could be an example of a complex ODE with automatic differentiation. All the complex ODE examples in the tutorials and benchmarks only accept `Float` element types and therefore don't work with automatic differentiation.
Thanks!