Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/OrdinaryDiffEqOperatorSplitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Unrolled: @unroll

import SciMLBase, DiffEqBase, DataStructures

import OrdinaryDiffEqCore
import OrdinaryDiffEqCore: OrdinaryDiffEqCore, isadaptive, alg_order

import UnPack: @unpack
import DiffEqBase: init, TimeChoiceIterator
Expand All @@ -16,14 +16,13 @@ abstract type AbstractOperatorSplitFunction <: DiffEqBase.AbstractODEFunction{tr
abstract type AbstractOperatorSplittingAlgorithm end
abstract type AbstractOperatorSplittingCache end

@inline DiffEqBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false

include("function.jl")
include("problem.jl")
include("integrator.jl")
include("solver.jl")
include("utils.jl")
include("controller.jl")

export GenericSplitFunction, OperatorSplittingProblem, LieTrotterGodunov
export GenericSplitFunction, OperatorSplittingProblem, LieTrotterGodunov, PalindromicPairLieTrotterGodunov

end
98 changes: 98 additions & 0 deletions src/controller.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
@inline OrdinaryDiffEqCore.ispredictive(::AbstractOperatorSplittingAlgorithm) = false
@inline OrdinaryDiffEqCore.isstandard(::AbstractOperatorSplittingAlgorithm) = false
function OrdinaryDiffEqCore.beta2_default(alg::AbstractOperatorSplittingAlgorithm)
isadaptive(alg) ? 2 // (5alg_order(alg)) : 0 // 1
end
function OrdinaryDiffEqCore.beta1_default(alg::AbstractOperatorSplittingAlgorithm, beta2)
isadaptive(alg) ? 7 // (10alg_order(alg)) : 0 // 1
end

function OrdinaryDiffEqCore.qmin_default(alg::AbstractOperatorSplittingAlgorithm)
isadaptive(alg) ? 1 // 5 : 0 // 1
end
OrdinaryDiffEqCore.qmax_default(alg::AbstractOperatorSplittingAlgorithm) = 10 // 1
function OrdinaryDiffEqCore.gamma_default(alg::AbstractOperatorSplittingAlgorithm)
isadaptive(alg) ? 9 // 10 : 0 // 1
end
OrdinaryDiffEqCore.qsteady_min_default(alg::AbstractOperatorSplittingAlgorithm) = 1 // 1
OrdinaryDiffEqCore.qsteady_max_default(alg::AbstractOperatorSplittingAlgorithm) = 1 // 1

mutable struct PIController{T} <: OrdinaryDiffEqCore.AbstractController
qmin::T
qmax::T
qsteady_min::T
qsteady_max::T
qoldinit::T
beta1::T
beta2::T
gamma::T
# Internal
q11::T
qold::T
q::T
end
PIController(; qmin, qmax, qsteady_min, qsteady_max, qoldinit, beta1, beta2, gamma, q11) = PIController(qmin, qmax, qsteady_min, qsteady_max, qoldinit, beta1, beta2, gamma, q11, qoldinit, qoldinit)

function default_controller(alg, cache)
if !isadaptive(alg)
@warn "Trying to construct a controller for $alg, but the algorithm is not adaptive."
return nothing
end

beta2 = OrdinaryDiffEqCore.beta2_default(alg)
beta1 = OrdinaryDiffEqCore.beta1_default(alg, beta2)
qmin = OrdinaryDiffEqCore.qmin_default(alg)
qmax = OrdinaryDiffEqCore.qmax_default(alg)
gamma = OrdinaryDiffEqCore.gamma_default(alg)
qsteady_min = OrdinaryDiffEqCore.qsteady_min_default(alg)
qsteady_max = OrdinaryDiffEqCore.qsteady_max_default(alg)
qoldinit = 1 // 10^4
q11 = 1 // 1
PIController(;
beta1, beta2,
qmin, qmax,
gamma,
qsteady_min, qsteady_max,
qoldinit, q11
)
end

@inline DiffEqBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false

@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator, controller::PIController, alg)
(; qold, qmin, qmax, gamma) = controller
(; beta1, beta2) = controller
EEst = DiffEqBase.value(integrator.EEst)

if iszero(EEst)
q = inv(qmax)
else
q11 = OrdinaryDiffEqCore.fastpower(EEst, convert(typeof(EEst), beta1))
q = q11 / OrdinaryDiffEqCore.fastpower(qold, convert(typeof(EEst), beta2))
controller.q11 = q11
@fastmath q = max(inv(qmax), min(inv(qmin), q / gamma))
end
controller.q = q # Return Q for temporary compat with OrdinaryDiffEqCore
end

function step_accept_controller!(integrator::OperatorSplittingIntegrator, controller::PIController, alg)
(; q, qsteady_min, qsteady_max, qoldinit) = controller
EEst = DiffEqBase.value(integrator.EEst)

if qsteady_min <= q <= qsteady_max
q = one(q)
end
controller.qold = max(EEst, qoldinit)
integrator.dt /= q
return nothing
end

function step_reject_controller!(integrator::OperatorSplittingIntegrator, controller::PIController, alg)
(; q11, qmin, gamma) = controller
integrator.dt /= min(inv(qmin), q11 / gamma)
return nothing
end

@inline function should_accept_step(integrator, controller::OrdinaryDiffEqCore.AbstractController)
return integrator.EEst <= 1
end
139 changes: 56 additions & 83 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ end

IntegratorStats() = IntegratorStats(0, 0)

Base.@kwdef mutable struct IntegratorOptions{tType, fType, F3}
Base.@kwdef mutable struct IntegratorOptions{tType, fType, F2, F3}
adaptive::Bool
dtmin::tType = eps(Float64)
dtmax::tType = Inf
failfactor::fType = 4.0
verbose::Bool = false
internalnorm::F2 = DiffEqBase.ODE_DEFAULT_NORM
isoutofdomain::F3 = DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN
end

Expand Down Expand Up @@ -70,6 +71,7 @@ mutable struct OperatorSplittingIntegrator{
synchronizer_tree::syncTreeType
iter::Int
controller::controllerType
EEst::Float64 # TODO integrate with controller cache
opts::optionsType
stats::IntegratorStats
tdir::tType
Expand Down Expand Up @@ -129,16 +131,22 @@ function DiffEqBase.__init(

callback = DiffEqBase.CallbackSet(callback)

opts = IntegratorOptions(; verbose, adaptive, kwargs...)

subintegrator_tree,
cache = build_subintegrator_tree_with_cache(
prob, alg,
uprev, u,
1:length(u),
t0, dt, tf,
tstops, saveat, d_discontinuities, callback,
adaptive, verbose
opts,
)

if controller === nothing && adaptive
controller = default_controller(alg, cache)
end

integrator = OperatorSplittingIntegrator(
prob.f,
alg,
Expand Down Expand Up @@ -168,7 +176,8 @@ function DiffEqBase.__init(
build_synchronizer_tree(prob.f),
0,
controller,
IntegratorOptions(; verbose, adaptive),
NaN,
opts,
IntegratorStats(),
tType(tstops_internal.ordering isa DataStructures.FasterForward ? 1 : -1)
)
Expand Down Expand Up @@ -280,13 +289,17 @@ increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter +
# Controller interface
function reject_step!(integrator::OperatorSplittingIntegrator)
OrdinaryDiffEqCore.increment_reject!(integrator.stats)
reject_step!(integrator, integrator.cache, integrator.controller)
reject_step!(integrator, integrator.controller)
end
function reject_step!(integrator::OperatorSplittingIntegrator, cache, controller)
function reject_step!(integrator::OperatorSplittingIntegrator, controller)
integrator.u .= integrator.uprev
# TODO what do we need to do with the subintegrators?
if !integrator.force_stepfail
step_reject_controller!(integrator, controller, integrator.alg)
end
# We need to roll-back the sub-integrators
prepare_subintegrators_to_redo_step!(integrator)
end
function reject_step!(integrator::OperatorSplittingIntegrator, cache, ::Nothing)
function reject_step!(integrator::OperatorSplittingIntegrator, ::Nothing)
if length(integrator.uprev) == 0
error("Cannot roll back integrator. Aborting time integration step at $(integrator.t).")
end
Expand All @@ -297,9 +310,9 @@ function should_accept_step(integrator::OperatorSplittingIntegrator)
if integrator.force_stepfail || integrator.isout
return false
end
return should_accept_step(integrator, integrator.cache, integrator.controller)
return should_accept_step(integrator, integrator.controller)
end
function should_accept_step(integrator::OperatorSplittingIntegrator, cache, ::Nothing)
function should_accept_step(integrator::OperatorSplittingIntegrator, ::Nothing)
return !(integrator.force_stepfail)
end
function accept_step!(integrator::OperatorSplittingIntegrator)
Expand Down Expand Up @@ -366,7 +379,7 @@ function step_footer!(integrator::OperatorSplittingIntegrator)
integrator.t = ttmp#OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp)
# OrdinaryDiffEqCore.handle_callbacks!(integrator)
step_accept_controller!(integrator) # Noop for non-adaptive algorithms
elseif integrator.force_stepfail
elseif integrator.force_stepfail # Rejected by solver
if SciMLBase.isadaptive(integrator)
step_reject_controller!(integrator)
OrdinaryDiffEqCore.post_newton_controller!(integrator, integrator.alg)
Expand Down Expand Up @@ -525,9 +538,8 @@ end
Updates the controller using the current state of the integrator if the operator splitting algorithm is adaptive.
"""
@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
stepsize_controller!(integrator, algorithm)
DiffEqBase.isadaptive(integrator) || return nothing
stepsize_controller!(integrator, integrator.controller, integrator.alg)
end

"""
Expand All @@ -536,9 +548,8 @@ end
Updates `dtcache` of the integrator if the step is accepted and the operator splitting algorithm is adaptive.
"""
@inline function step_accept_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
step_accept_controller!(integrator, algorithm, nothing)
DiffEqBase.isadaptive(integrator) || return nothing
step_accept_controller!(integrator, integrator.controller, integrator.alg)
end

"""
Expand All @@ -547,9 +558,8 @@ end
Updates `dtcache` of the integrator if the step is rejected and the the operator splitting algorithm is adaptive.
"""
@inline function step_reject_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
step_reject_controller!(integrator, algorithm, nothing)
DiffEqBase.isadaptive(integrator) || return nothing
step_reject_controller!(integrator, integrator.controller, integrator.alg)
end

# helper functions for dealing with time-reversed integrators in the same way
Expand Down Expand Up @@ -646,7 +656,7 @@ end
function synchronize_subintegrator!(
subintegrator::SciMLBase.DEIntegrator, integrator::OperatorSplittingIntegrator)
@unpack t, dt = integrator
@assert subintegrator.t == t
@assert subintegrator.t == t "Integrators out of sync. The outer integrator is at $t, but inner integrator is at $(subintegrator.t)"
if !DiffEqBase.isadaptive(subintegrator)
SciMLBase.set_proposed_dt!(subintegrator, dt)
end
Expand All @@ -662,72 +672,37 @@ end
# Dispatch for tree node construction
function build_subintegrator_tree_with_cache(
prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm,
f::GenericSplitFunction, p::Tuple,
uprevouter::AbstractVector, uouter::AbstractVector,
solution_indices,
t0, dt, tf,
tstops, saveat, d_discontinuities, callback,
adaptive, verbose
args...,
)
(; f, p) = prob
subintegrator_tree_with_caches = ntuple(
i -> build_subintegrator_tree_with_cache(
prob,
alg.inner_algs[i],
get_operator(f, i),
p[i],
uprevouter, uouter,
f.solution_indices[i],
t0, dt, tf,
tstops, saveat, d_discontinuities, callback,
adaptive, verbose
),
length(f.functions)
)
# subintegrator_tree_with_caches = ntuple(
# i -> build_subintegrator_tree_with_cache(
# OperatorSplittingProblem(),
# alg.inner_algs[i],
# get_operator(f, i),
# p[i],
# uprevouter, uouter,
# f.solution_indices[i],
# t0, dt, tf,
# args...,
# ),
# length(f.functions)
# )

# subintegrator_tree_leafs = first.(subintegrator_tree_with_caches)
# inner_caches = last.(subintegrator_tree_with_caches)

subintegrator_tree = ntuple(
i -> subintegrator_tree_with_caches[i][1], length(f.functions))
caches = ntuple(i -> subintegrator_tree_with_caches[i][2], length(f.functions))

# TODO fix mixed device type problems we have to be smarter
return subintegrator_tree,
init_cache(f, alg;
uprev = uprevouter, u = uouter, alias_u = true,
inner_caches = caches
)
end

function build_subintegrator_tree_with_cache(
prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm,
f::GenericSplitFunction, p::Tuple,
uprevouter::AbstractVector, uouter::AbstractVector,
solution_indices,
t0, dt, tf,
tstops, saveat, d_discontinuities, callback,
adaptive, verbose,
save_end = false,
controller = nothing
)
subintegrator_tree_with_caches = ntuple(
i -> build_subintegrator_tree_with_cache(
prob,
alg.inner_algs[i],
get_operator(f, i),
p[i],
uprevouter, uouter,
f.solution_indices[i],
t0, dt, tf,
tstops, saveat, d_discontinuities, callback,
adaptive, verbose
),
i-> DiffEqBase.__init()
length(f.functions)
)

subintegrator_tree = first.(subintegrator_tree_with_caches)
inner_caches = last.(subintegrator_tree_with_caches)

# TODO fix mixed device type problems we have to be smarter
uprev = @view uprevouter[solution_indices]
u = @view uouter[solution_indices]
u = @view uouter[solution_indices]
return subintegrator_tree,
init_cache(f, alg;
uprev = uprev, u = u,
Expand All @@ -741,13 +716,11 @@ function build_subintegrator_tree_with_cache(
uprevouter::S, uouter::S,
solution_indices,
t0::T, dt::T, tf::T,
tstops, saveat, d_discontinuities, callback,
adaptive, verbose,
save_end = false,
controller = nothing
opts,
args...,
) where {S, T, P, F}
uprev = @view uprevouter[solution_indices]
u = @view uouter[solution_indices]
u = @view uouter[solution_indices]

integrator = DiffEqBase.__init(
SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf)), p),
Expand All @@ -757,9 +730,9 @@ function build_subintegrator_tree_with_cache(
d_discontinuities,
save_everystep = false,
advance_to_tstop = false,
adaptive,
controller,
verbose
opts.adaptive,
opts.verbose,
args...,
)

return integrator, integrator.cache
Expand Down
Loading