
automatic differentiation not working in advanced examples? #387

Open
@JianghuiDu


I'm trying to find an example of how to use automatic differentiation with more complex ODEs than the vanilla examples in the tutorials. Simple ODEs are quite easy to figure out, but with complex ODE models I always run into the Dual number error. It seems my ODE functions are never generic enough, and I'd like to learn how to make them work.
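
A minimal sketch of what "not generic enough" usually means, assuming ForwardDiff.jl is installed: a type annotation such as `::Float64` rejects the `ForwardDiff.Dual` numbers that the solver passes in when `autodiff = true`. The rate functions below are made-up examples, not Beeler-Reuter rates.

```julia
using ForwardDiff

# Too narrow: the annotation only admits Float64, so Dual numbers are rejected.
rate_narrow(v::Float64) = 0.1 * exp(0.05 * v)

# Generic: any Real is accepted, including ForwardDiff.Dual.
rate_generic(v::Real) = 0.1 * exp(0.05 * v)

# The generic version differentiates fine: d/dv = 0.005 * exp(0.05 * v).
d = ForwardDiff.derivative(rate_generic, -80.0)

# The narrow version throws a MethodError because no method accepts a Dual.
narrow_fails = try
    ForwardDiff.derivative(rate_narrow, -80.0)
    false
catch err
    err isa MethodError
end
```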

I tried to modify the Beeler-Reuter model (the CPU version) from the advanced tutorial, because the original code doesn't work with `autodiff=true`: it only accepts floating-point numbers. I finally managed to make it work, but there are several things I don't understand. Here is the modified, working code:

```julia
using DifferentialEquations, Sundials, ForwardDiff

const v0 = -84.624
const v1 = 10.0
const C_K1 = 1.0f0
const C_x1 = 1.0f0
const C_Na = 1.0f0
const C_s = 1.0f0
const D_Ca = 0.0f0
const D_Na = 0.0f0
const g_s = 0.09f0
const g_Na = 4.0f0
const g_NaC = 0.005f0
const ENa = 50.0f0 + D_Na
const γ = 0.5f0
const C_m = 1.0f0

# make this parametric
mutable struct BeelerReuterCpu{T} <: Function
    t::T              # the last timestep time to calculate Δt
    diff_coef::T      # the diffusion-coefficient (coupling strength)

    C::Array{T,2}     # intracellular calcium concentration
    M::Array{T,2}     # sodium current activation gate (m)
    H::Array{T,2}     # sodium current inactivation gate (h)
    J::Array{T,2}     # sodium current slow inactivation gate (j)
    D::Array{T,2}     # calcium current activation gate (d)
    F::Array{T,2}     # calcium current inactivation gate (f)
    XI::Array{T,2}    # inward-rectifying potassium current (iK1)
    Δu::Array{T,2}    # place-holder for the Laplacian

    function BeelerReuterCpu(u0::Array{T,2}, diff_coef::T) where {T}
        self = new{T}()

        ny, nx = size(u0)
        self.t = 0.0
        self.diff_coef = diff_coef

        self.C = fill(0.0001f0, (ny, nx))
        self.M = fill(0.01f0, (ny, nx))
        self.H = fill(0.988f0, (ny, nx))
        self.J = fill(0.975f0, (ny, nx))
        self.D = fill(0.003f0, (ny, nx))
        self.F = fill(0.994f0, (ny, nx))
        self.XI = fill(0.0001f0, (ny, nx))
        self.Δu = zeros(ny, nx)

        return self
    end
end

# 5-point stencil
function laplacian(Δu, u)

    n1, n2 = size(u)

    # internal nodes
    for j = 2:n2-1
        for i = 2:n1-1
            @inbounds Δu[i, j] = u[i+1, j] + u[i-1, j] + u[i, j+1] + u[i, j-1] - 4 * u[i, j]
        end
    end

    # left/right edges
    for i = 2:n1-1
        @inbounds Δu[i, 1] = u[i+1, 1] + u[i-1, 1] + 2 * u[i, 2] - 4 * u[i, 1]
        @inbounds Δu[i, n2] = u[i+1, n2] + u[i-1, n2] + 2 * u[i, n2-1] - 4 * u[i, n2]
    end

    # top/bottom edges
    for j = 2:n2-1
        @inbounds Δu[1, j] = u[1, j+1] + u[1, j-1] + 2 * u[2, j] - 4 * u[1, j]
        @inbounds Δu[n1, j] = u[n1, j+1] + u[n1, j-1] + 2 * u[n1-1, j] - 4 * u[n1, j]
    end

    # corners
    @inbounds Δu[1, 1] = 2 * (u[2, 1] + u[1, 2]) - 4 * u[1, 1]
    @inbounds Δu[n1, 1] = 2 * (u[n1-1, 1] + u[n1, 2]) - 4 * u[n1, 1]
    @inbounds Δu[1, n2] = 2 * (u[2, n2] + u[1, n2-1]) - 4 * u[1, n2]
    @inbounds Δu[n1, n2] = 2 * (u[n1-1, n2] + u[n1, n2-1]) - 4 * u[n1, n2]
end

@inline function rush_larsen(g, α, β, Δt)
    inf = α / (α + β)
    τ = 1.0f0 / (α + β)
    return clamp(g + (g - inf) * expm1(-Δt / τ), 0.0f0, 1.0f0)
end

function update_M_cpu(g, v, Δt)
    # the condition is needed here to prevent NaN when v == 47.0
    α = isapprox(v, 47.0f0) ? 10.0f0 : -(v + 47.0f0) / (exp(-0.1f0 * (v + 47.0f0)) - 1.0f0)
    β = (40.0f0 * exp(-0.056f0 * (v + 72.0f0)))
    return rush_larsen(g, α, β, Δt)
end

function update_H_cpu(g, v, Δt)
    α = 0.126f0 * exp(-0.25f0 * (v + 77.0f0))
    β = 1.7f0 / (exp(-0.082f0 * (v + 22.5f0)) + 1.0f0)
    return rush_larsen(g, α, β, Δt)
end

function update_J_cpu(g, v, Δt)
    α = (0.55f0 * exp(-0.25f0 * (v + 78.0f0))) / (exp(-0.2f0 * (v + 78.0f0)) + 1.0f0)
    β = 0.3f0 / (exp(-0.1f0 * (v + 32.0f0)) + 1.0f0)
    return rush_larsen(g, α, β, Δt)
end

function update_D_cpu(g, v, Δt)
    α = γ * (0.095f0 * exp(-0.01f0 * (v - 5.0f0))) / (exp(-0.072f0 * (v - 5.0f0)) + 1.0f0)
    β = γ * (0.07f0 * exp(-0.017f0 * (v + 44.0f0))) / (exp(0.05f0 * (v + 44.0f0)) + 1.0f0)
    return rush_larsen(g, α, β, Δt)
end

function update_F_cpu(g, v, Δt)
    α = γ * (0.012f0 * exp(-0.008f0 * (v + 28.0f0))) / (exp(0.15f0 * (v + 28.0f0)) + 1.0f0)
    β = γ * (0.0065f0 * exp(-0.02f0 * (v + 30.0f0))) / (exp(-0.2f0 * (v + 30.0f0)) + 1.0f0)
    return rush_larsen(g, α, β, Δt)
end

function update_XI_cpu(g, v, Δt)
    α = (0.0005f0 * exp(0.083f0 * (v + 50.0f0))) / (exp(0.057f0 * (v + 50.0f0)) + 1.0f0)
    β = (0.0013f0 * exp(-0.06f0 * (v + 20.0f0))) / (exp(-0.04f0 * (v + 20.0f0)) + 1.0f0)
    return rush_larsen(g, α, β, Δt)
end

function update_C_cpu(g, d, f, v, Δt)
    ECa = D_Ca - 82.3f0 - 13.0278f0 * log(g)
    kCa = C_s * g_s * d * f
    iCa = kCa * (v - ECa)
    inf = 1.0f-7 * (0.07f0 - g)
    τ = 1.0f0 / 0.07f0
    return g + (g - inf) * expm1(-Δt / τ)
end

function update_gates_cpu(u, XI, M, H, J, D, F, C, Δt)
    # the tutorial wrapped this loop in `let Δt = Float32(Δt) ... end`; removed here
    n1, n2 = size(u)
    for j = 1:n2
        for i = 1:n1
            v = u[i, j]

            XI[i, j] = update_XI_cpu(XI[i, j], v, Δt)
            M[i, j] = update_M_cpu(M[i, j], v, Δt)
            H[i, j] = update_H_cpu(H[i, j], v, Δt)
            J[i, j] = update_J_cpu(J[i, j], v, Δt)
            D[i, j] = update_D_cpu(D[i, j], v, Δt)
            F[i, j] = update_F_cpu(F[i, j], v, Δt)

            C[i, j] = update_C_cpu(C[i, j], D[i, j], F[i, j], v, Δt)
        end
    end
end

# iK1 is the inward-rectifying potassium current
function calc_iK1(v)
    ea = exp(0.04f0 * (v + 85.0f0))
    eb = exp(0.08f0 * (v + 53.0f0))
    ec = exp(0.04f0 * (v + 53.0f0))
    ed = exp(-0.04f0 * (v + 23.0f0))
    return 0.35f0 * (
        4.0f0 * (ea - 1.0f0) / (eb + ec) +
        0.2f0 * (isapprox(v, -23.0f0) ? 25.0f0 : (v + 23.0f0) / (1.0f0 - ed))
    )
end

# ix1 is the time-independent background potassium current
function calc_ix1(v, xi)
    ea = exp(0.04f0 * (v + 77.0f0))
    eb = exp(0.04f0 * (v + 35.0f0))
    return xi * 0.8f0 * (ea - 1.0f0) / eb
end

# iNa is the sodium current (similar to the classic Hodgkin-Huxley model)
function calc_iNa(v, m, h, j)
    return C_Na * (g_Na * m^3 * h * j + g_NaC) * (v - ENa)
end

# iCa is the calcium current
function calc_iCa(v, d, f, c)
    ECa = D_Ca - 82.3f0 - 13.0278f0 * log(c)    # ECa is the calcium reversal potential
    return C_s * g_s * d * f * (v - ECa)
end

function update_du_cpu(du, u, XI, M, H, J, D, F, C)
    n1, n2 = size(u)

    for j = 1:n2
        for i = 1:n1
            v = u[i, j]

            # calculating individual currents
            iK1 = calc_iK1(v)
            ix1 = calc_ix1(v, XI[i, j])
            iNa = calc_iNa(v, M[i, j], H[i, j], J[i, j])
            iCa = calc_iCa(v, D[i, j], F[i, j], C[i, j])

            # total current
            I_sum = iK1 + ix1 + iNa + iCa

            # the reaction part of the reaction-diffusion equation
            du[i, j] = -I_sum / C_m
        end
    end
end

function (f::BeelerReuterCpu)(du, u, p, t)
    Δt = t - f.t

    if Δt != 0 || t == 0
        update_gates_cpu(
            u,
            eltype(u).(f.XI), # type conversion (allocates a new array; f.XI itself is not updated)
            eltype(u).(f.M),
            eltype(u).(f.H),
            eltype(u).(f.J),
            eltype(u).(f.D),
            eltype(u).(f.F),
            eltype(u).(f.C),
            eltype(u)(Δt),
        )
        f.t = t
    end

    laplacian(eltype(u).(f.Δu), u) # type conversion (writes into a new array; f.Δu is not updated)

    # calculate the reaction portion
    update_du_cpu(du, u, f.XI, f.M, f.H, f.J, f.D, f.F, f.C)

    # ...add the diffusion portion
    du .+= f.diff_coef .* f.Δu
    du
end

N = 10;
u0 = fill(v0, (N, N));
u0[5:6, 5:6] .= v1;   # a small square in the middle of the domain

using Plots
heatmap(u0)

deriv_cpu = BeelerReuterCpu(u0, 1.0);
prob = ODEProblem(deriv_cpu, u0, (0.0, 5.0))
@time sol = solve(prob, TRBDF2(autodiff = true), saveat = 100.0)
```

The main changes are:

  1. Make the struct `BeelerReuterCpu` parametric.
  2. In the `(f::BeelerReuterCpu)(du, u, p, t)` function, when calling `update_gates_cpu` and `laplacian`, convert the arguments to the element type of `u`, e.g. `eltype(u).(f.Δu)`.
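
One side effect of change 2 is worth checking. Broadcasting a type constructor, as in `eltype(u).(f.XI)`, always allocates a new array, so `update_gates_cpu` writes into that temporary and `f.XI` (like the other gate fields) never changes between calls. The same holds for `laplacian(eltype(u).(f.Δu), u)`: the Laplacian is written into a temporary, and the later `du .+= f.diff_coef .* f.Δu` reads `f.Δu`, which is still the array of zeros created in the constructor, so the diffusion term is always zero. This plain-Julia sketch shows the copy behavior, with `XI` standing in for `f.XI`:

```julia
# Stand-in for a Float64 field such as f.XI.
XI = fill(0.5, 2, 2)

# Broadcasting a constructor over an array returns a new array, even when the
# element type does not change (in the ODE function: eltype(u).(f.XI)).
XI_conv = Float64.(XI)

# In-place updates go into the copy only ...
XI_conv .= 0.9

# ... and the original array is left untouched.
println(XI_conv === XI)   # false
println(XI)               # still all 0.5
```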

I made these changes because otherwise I get the Dual number error:

```
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UJacobianWrapper{ODEFunction{true, BeelerReuterCpu{Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Float64, SciMLBase.NullParameters}, Float64}, Float64, 12})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  (::Type{T})(::T) where T<:Number at boot.jl:760
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
```
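
The error can be reproduced without the solver. A minimal sketch, assuming ForwardDiff.jl: storing a `Dual` into a `Float64` array needs `Float64(::Dual)`, which has no method, while reading `Float64` values into arithmetic with a `Dual` simply promotes the result to `Dual`. This is likely why `update_du_cpu` needs no conversion: it only reads from `f.XI`, `f.M`, etc. and writes into `du`, which the solver allocates with a `Dual` element type when it builds the Jacobian, whereas `update_gates_cpu` and `laplacian` write `Dual` results into `Float64` arrays.

```julia
using ForwardDiff

buf = zeros(3)                     # Vector{Float64}, like f.XI or f.Δu
x = ForwardDiff.Dual(2.0, 1.0)     # value 2.0 with a single partial 1.0

# Writing: setindex! must convert the Dual to Float64, which fails.
write_fails = try
    buf[1] = x
    false
catch err
    err isa MethodError
end

# Reading: Float64 values combine with Dual numbers and promote to Dual.
y = buf[2] + 3.0 * x
```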

My questions are:

  1. Why is type conversion needed when calling `update_gates_cpu` and `laplacian`, but not `update_du_cpu`, inside `(f::BeelerReuterCpu)(du, u, p, t)`?
  2. I used a parametric type for `BeelerReuterCpu` to keep the types consistent between `u` and the fields of `BeelerReuterCpu`. I assumed that was necessary for Dual numbers to work, but evidently it's not enough. I suppose that's because the struct is not created dynamically? Is there a way to do that?
  3. The documentation says `DiffEqBase.dualcache` is needed to avoid the Dual number error when using caches. The cache example given there is very straightforward, but I wonder what exactly counts as a cache. Does it include any intermediate calculation variables?
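
On questions 2 and 3: as I understand it, a "cache" here is any preallocated array that the ODE function writes into, such as `f.Δu`; scalar locals and arrays created inside the function with `similar(u)` already have the right element type and need no special handling. Instead of a struct parametrized once at construction time, the storage can be picked per call from `eltype(u)`. Below is a hand-rolled sketch of that idea, assuming ForwardDiff.jl; `EltypeCache` and `get_buffer` are hypothetical names, not library API. As far as I know, `dualcache` with `get_tmp` (since moved to PreallocationTools.jl) packages the same idea.

```julia
using ForwardDiff

# Hypothetical helper: one preallocated buffer per element type, created the
# first time that element type is requested (Float64 for ordinary calls,
# a ForwardDiff.Dual type when the Jacobian is being computed).
struct EltypeCache{N}
    dims::NTuple{N,Int}
    buffers::Dict{DataType,Any}
end
EltypeCache(dims::Int...) = EltypeCache(dims, Dict{DataType,Any}())

function get_buffer(c::EltypeCache{N}, ::Type{T}) where {N,T}
    return get!(() -> zeros(T, c.dims), c.buffers, T)::Array{T,N}
end

# An in-place right-hand side that stores an intermediate result in the cache.
function rhs!(du, u, cache)
    tmp = get_buffer(cache, eltype(u))   # buffer matching u's element type
    @. tmp = u^2                          # intermediate result
    @. du = -tmp
    return du
end

cache = EltypeCache(3)
u = [1.0, 2.0, 3.0]

rhs!(similar(u), u, cache)   # ordinary call: uses the Float64 buffer
J = ForwardDiff.jacobian((du, x) -> rhs!(du, x, cache), similar(u), u)  # Dual buffer
```

A `Dict` keyed by element type keeps the sketch short at the cost of one lookup per call.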

I hope there could be an example of a complex ODE with automatic differentiation. All the complex ODE examples in the tutorials and benchmarks only accept floats and therefore don't work with automatic differentiation.
Thanks!
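
Not a full answer to the request, but a small sketch of one pattern that stays AD-compatible without any caches, assuming ForwardDiff.jl: keep every evolving quantity, including gating variables, in the state `u`, and write the right-hand side as a function of `u` alone. This differs from the tutorial, which updates the gates with a Rush-Larsen step inside the ODE function; here the gate obeys an ordinary ODE instead. The rate functions `α_m` and `β_m`, the cubic current term and all constants are made up for illustration.

```julia
using ForwardDiff

# Made-up gate rate functions (not the Beeler-Reuter ones).
α_m(v) = 0.1 * exp(0.05 * v)
β_m(v) = 0.2 * exp(-0.05 * v)

# State layout on a 1-D line of n cells: u[:, 1] is the voltage v and
# u[:, 2] is a gating variable m. Nothing is stored between calls, so the
# function works for any element type, Float64 or Dual.
function toy_rhs!(du, u, D)
    n = size(u, 1)
    for i in 1:n
        v, m = u[i, 1], u[i, 2]
        # no-flux boundary: mirror the interior neighbour at both ends
        left  = i == 1 ? u[2, 1]   : u[i-1, 1]
        right = i == n ? u[n-1, 1] : u[i+1, 1]
        du[i, 1] = D * (left + right - 2v) - m^3 * (v - 50.0)
        du[i, 2] = α_m(v) * (1 - m) - β_m(v) * m
    end
    return du
end

u0 = hcat(fill(-80.0, 5), fill(0.05, 5))   # 5 cells, 2 variables each
J = ForwardDiff.jacobian((du, u) -> toy_rhs!(du, u, 1.0), similar(u0), u0)
```

To use it with DifferentialEquations.jl, it should be enough to wrap it as `(du, u, p, t) -> toy_rhs!(du, u, p)` with `p = 1.0` in an `ODEProblem`; the trade-off is that the solver, not a Rush-Larsen step, integrates the gate.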
