Skip to content
8 changes: 5 additions & 3 deletions lib/OrdinaryDiffEqCore/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,19 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
if cf isa AutoSwitch
nonstiffalg = _prepare_autoswitch_alg(cf.nonstiffalg, u0, p, prob)
stiffalg = _prepare_autoswitch_alg(cf.stiffalg, u0, p, prob)
cf = AutoSwitch(nonstiffalg, stiffalg,
cf = AutoSwitch(
nonstiffalg, stiffalg,
cf.maxstiffstep, cf.maxnonstiffstep,
cf.nonstifftol, cf.stifftol,
cf.dtfac, cf.stiffalgfirst, cf.switch_max)
cf.dtfac, cf.stiffalgfirst, cf.switch_max
)
end
return CompositeAlgorithm(algs, cf)
end

_prepare_autoswitch_alg(alg, u0, p, prob) = DiffEqBase.prepare_alg(alg, u0, p, prob)
function _prepare_autoswitch_alg(algs::Tuple, u0, p, prob)
map(a -> DiffEqBase.prepare_alg(a, u0, p, prob), algs)
return map(a -> DiffEqBase.prepare_alg(a, u0, p, prob), algs)
end

has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function prepare_ADType(autodiff_alg::AutoSparse, prob, u0, p, standardtag)
end

function prepare_ADType(autodiff_alg::AutoForwardDiff, prob, u0, p, standardtag::Bool)
prepare_ADType(autodiff_alg, prob, u0, p, Val(standardtag))
return prepare_ADType(autodiff_alg, prob, u0, p, Val(standardtag))
end

function _prepare_ADType_fwd(autodiff_alg::AutoForwardDiff, prob, u0, tag)
Expand Down
50 changes: 50 additions & 0 deletions lib/OrdinaryDiffEqMultirate/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name = "OrdinaryDiffEqMultirate"
uuid = "d4b830b4-ac80-426b-8507-16693d424963"
authors = ["singhharsh1708 <hs1663531@gmail.com>"]
version = "0.1.0"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[compat]
Aqua = "0.8.11"
DiffEqBase = "6.194"
DiffEqDevTools = "2.44.4"
FastBroadcast = "0.3"
JET = "0.9, 0.11"
MuladdMacro = "0.2"
OrdinaryDiffEqCore = "3"
OrdinaryDiffEqLowOrderRK = "1"
Pkg = "1"
RecursiveArrayTools = "3.36"
Reexport = "1.2"
SafeTestsets = "0.1.0"
SciMLBase = "2.116"
Test = "<0.0.1, 1"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources.OrdinaryDiffEqCore]
path = "../OrdinaryDiffEqCore"

[sources.OrdinaryDiffEqLowOrderRK]
path = "../OrdinaryDiffEqLowOrderRK"

[targets]
test = ["DiffEqDevTools", "LinearAlgebra", "OrdinaryDiffEqLowOrderRK", "SafeTestsets", "Test", "Pkg"]
26 changes: 26 additions & 0 deletions lib/OrdinaryDiffEqMultirate/src/OrdinaryDiffEqMultirate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module OrdinaryDiffEqMultirate

import OrdinaryDiffEqCore: alg_order, isfsal,
OrdinaryDiffEqAdaptiveAlgorithm,
generic_solver_docstring,
unwrap_alg, initialize!, perform_step!,
calculate_residuals, calculate_residuals!,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
@cache, alg_cache, full_cache, get_fsalfirstlast
import OrdinaryDiffEqCore
import FastBroadcast: @..
import MuladdMacro: @muladd
import RecursiveArrayTools: recursivefill!
import DiffEqBase: prepare_alg

using Reexport
@reexport using SciMLBase

include("algorithms.jl")
include("alg_utils.jl")
include("mreef_caches.jl")
include("mreef_perform_step.jl")

export MREEF

end
6 changes: 6 additions & 0 deletions lib/OrdinaryDiffEqMultirate/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
alg_order(alg::MREEF) = alg.order
isfsal(::MREEF) = false

function prepare_alg(alg::MREEF, u0::AbstractArray, p, prob)
return alg
end
30 changes: 30 additions & 0 deletions lib/OrdinaryDiffEqMultirate/src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@doc generic_solver_docstring(
"Multirate Richardson Extrapolation with Euler as the base method (MREEF).

Solves a split ODE of the form `du/dt = f1(u,t) + f2(u,t)` where `f1` is the
fast component and `f2` is the slow component (SciML convention). The slow rate
`f2` is frozen over each macro interval and the fast rate `f1` is integrated
with `m` explicit Euler substeps. Aitken–Neville Richardson extrapolation over
`order` base solutions is then applied to boost accuracy.",
"MREEF",
"Multirate explicit method.",
"""@article{engstrom2009multirate,
title={Multirate explicit Adams methods for time integration of conservation laws},
author={Engstr{\\\"o}m, C and Ferm, L and L{\\\"o}tstedt, P and Sj{\\\"o}green, B},
year={2009}}""",
"""
- `m`: number of fast substeps per macro interval. Default is `4`.
- `order`: extrapolation order (number of base solutions). Default is `4`.
- `seq`: subdivision sequence, `:harmonic` (default) or `:romberg`.
""",
"""
m::Int = 4,
order::Int = 4,
seq::Symbol = :harmonic,
"""
)
Base.@kwdef struct MREEF <: OrdinaryDiffEqAdaptiveAlgorithm
m::Int = 4
order::Int = 4
seq::Symbol = :harmonic
end
44 changes: 44 additions & 0 deletions lib/OrdinaryDiffEqMultirate/src/mreef_caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
struct MREEFConstantCache{T} <: OrdinaryDiffEqConstantCache
T::T # pre-allocated extrapolation table: Vector of length `order`
end

@cache mutable struct MREEFCache{uType, rateType, uNoUnitsType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
tmp::uType
atmp::uNoUnitsType
k_slow::rateType
k_fast::rateType
T::Array{uType, 1}
fsalfirst::rateType
k::rateType
end

get_fsalfirstlast(cache::MREEFCache, u) = (cache.fsalfirst, cache.k)

function alg_cache(
alg::MREEF, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tmp = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
k_slow = zero(rate_prototype)
k_fast = zero(rate_prototype)
T = [zero(u) for _ in 1:(alg.order)]
fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
return MREEFCache(u, uprev, tmp, atmp, k_slow, k_fast, T, fsalfirst, k)
end

function alg_cache(
alg::MREEF, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
T = Vector{typeof(u)}(undef, alg.order)
return MREEFConstantCache(T)
end
163 changes: 163 additions & 0 deletions lib/OrdinaryDiffEqMultirate/src/mreef_perform_step.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# ── MREEF: step-count sequence ────────────────────────────────────────────────

@inline function _mreef_sequence(seq::Symbol, order::Int)
if seq === :harmonic
return ntuple(j -> j, order)
elseif seq === :romberg
return ntuple(j -> 1 << (j - 1), order)
else
throw(ArgumentError("MREEF: unknown sequence `$seq`, choose :harmonic or :romberg"))
end
end

# ── MREEF initialize! ─────────────────────────────────────────────────────────

function initialize!(integrator, cache::MREEFCache)
integrator.kshortsize = 2
(; fsalfirst, k) = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = k
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
integrator.f.f1(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
integrator.f.f2(cache.tmp, integrator.uprev, integrator.p, integrator.t)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) # f1
integrator.stats.nf2 += 1 # f2
return integrator.fsalfirst .+= cache.tmp
end

function initialize!(integrator, cache::MREEFConstantCache)
integrator.kshortsize = 2
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
integrator.fsalfirst = integrator.f.f1(integrator.uprev, integrator.p, integrator.t) +
integrator.f.f2(integrator.uprev, integrator.p, integrator.t)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
integrator.stats.nf2 += 1
integrator.fsallast = zero(integrator.fsalfirst)
integrator.k[1] = integrator.fsalfirst
return integrator.k[2] = integrator.fsallast
end

# ── MREEF perform_step! (in-place, MutableCache) ──────────────────────────────
#
# Base multirate Euler with nj macro intervals, m fast substeps each:
# 1. k_slow = f.f2(u, p, t_mac) — frozen slow rate for the macro interval
# 2. m fast substeps: u += h_fast*(k_slow + f.f1(u, p, t_fast))
# f1 = fast/stiff (large eigenvalues), f2 = slow/non-stiff (SciML convention).
# Then apply Aitken–Neville Richardson extrapolation over T[1..order].

function perform_step!(integrator, cache::MREEFCache, repeat_step = false)
(; t, dt, uprev, u, f, p) = integrator
(; tmp, atmp, k_slow, k_fast, T) = cache
alg = unwrap_alg(integrator, false)
m = alg.m
order = alg.order
ns = _mreef_sequence(alg.seq, order)

# Fill first tableau column: T[j] = base method with ns[j] macro intervals
for j in 1:order
nj = ns[j]
h_mac = dt / nj
h_fast = h_mac / m

@.. broadcast = false T[j] = uprev

for i_mac in 1:nj
t_mac = t + (i_mac - 1) * h_mac

# Slow evaluation (f2): frozen for all m fast substeps
f.f2(k_slow, T[j], p, t_mac)
integrator.stats.nf2 += 1

for i_fast in 1:m
t_fast = t_mac + (i_fast - 1) * h_fast
f.f1(k_fast, T[j], p, t_fast)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@.. broadcast = false T[j] = T[j] + h_fast * k_slow + h_fast * k_fast
end
end
end

# Aitken–Neville Richardson extrapolation (in-place, reverse-row order)
# Formula: T[j] <- T[j] + (T[j] - T[j-1]) / (ns[j]/ns[j-k] - 1)
for k in 1:(order - 1)
for j in order:-1:(k + 1)
ratio = ns[j] / ns[j - k]
@.. broadcast = false tmp = (T[j] - T[j - 1]) / (ratio - 1)
@.. broadcast = false T[j] = T[j] + tmp
end
end

@.. broadcast = false u = T[order]

return if integrator.opts.adaptive
@.. broadcast = false tmp = T[order] - T[order - 1]
calculate_residuals!(
atmp,
tmp,
uprev,
u,
integrator.opts.abstol,
integrator.opts.reltol,
integrator.opts.internalnorm,
t,
)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
end

# ── MREEF perform_step! (out-of-place, ConstantCache) ─────────────────────────

@muladd function perform_step!(integrator, cache::MREEFConstantCache, repeat_step = false)
(; t, dt, uprev, f, p) = integrator
alg = unwrap_alg(integrator, false)
m = alg.m
order = alg.order
ns = _mreef_sequence(alg.seq, order)
T = cache.T

for j in 1:order
nj = ns[j]
h_mac = dt / nj
h_fast = h_mac / m

u_cur = uprev
for i_mac in 1:nj
t_mac = t + (i_mac - 1) * h_mac
k_slow = f.f2(u_cur, p, t_mac)
integrator.stats.nf2 += 1
for i_fast in 1:m
t_fast = t_mac + (i_fast - 1) * h_fast
k_fast = f.f1(u_cur, p, t_fast)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
u_cur = @.. broadcast = false u_cur + h_fast * k_slow + h_fast * k_fast
end
end
T[j] = u_cur
end

# Aitken–Neville Richardson extrapolation
for k in 1:(order - 1)
for j in order:-1:(k + 1)
ratio = ns[j] / ns[j - k]
T[j] = @.. broadcast = false T[j] + (T[j] - T[j - 1]) / (ratio - 1)
end
end

integrator.u = T[order]

if integrator.opts.adaptive
utilde = @.. broadcast = false T[order] - T[order - 1]
atmp = calculate_residuals(
utilde,
uprev,
integrator.u,
integrator.opts.abstol,
integrator.opts.reltol,
integrator.opts.internalnorm,
t,
)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
end
Loading
Loading