diff --git a/src/OrdinaryDiffEqOperatorSplitting.jl b/src/OrdinaryDiffEqOperatorSplitting.jl index 7647074..a386d75 100644 --- a/src/OrdinaryDiffEqOperatorSplitting.jl +++ b/src/OrdinaryDiffEqOperatorSplitting.jl @@ -11,13 +11,15 @@ import SciMLBase: DEIntegrator, NullParameters, isadaptive import RecursiveArrayTools -import OrdinaryDiffEqCore +import OrdinaryDiffEqCore: OrdinaryDiffEqCore, isdtchangeable, + stepsize_controller!, step_accept_controller!, step_reject_controller! abstract type AbstractOperatorSplitFunction <: SciMLBase.AbstractODEFunction{true} end abstract type AbstractOperatorSplittingAlgorithm end abstract type AbstractOperatorSplittingCache end @inline SciMLBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false +@inline isdtchangeable(alg::AbstractOperatorSplittingAlgorithm) = all(isdtchangeable.(alg.inner_algs)) include("function.jl") include("problem.jl") diff --git a/src/function.jl b/src/function.jl index 31e18ac..b06b4b5 100644 --- a/src/function.jl +++ b/src/function.jl @@ -13,11 +13,24 @@ struct GenericSplitFunction{fSetType <: Tuple, idxSetType <: Tuple, sSetType <: # Operators to update the ode function parameters. synchronizers::sSetType function GenericSplitFunction(fs::Tuple, drs::Tuple, syncers::Tuple) - @assert length(fs) == length(drs) == length(syncers) + @assert length(fs) == length(drs) == length(syncers) "Number of input tuples does not match." + gsf_recursive_function_type_safety_check.(fs) return new{typeof(fs), typeof(drs), typeof(syncers)}(fs, drs, syncers) end end +function gsf_recursive_function_type_safety_check(f::GenericSplitFunction) + return gsf_recursive_function_type_safety_check.(f.functions) +end + +function gsf_recursive_function_type_safety_check(dunno) + return @warn "One of the inner functions in GenericSplitFunction is of type $(typeof(dunno)) which is not a subtype of SciMLBase.AbstractDiffEqFunction." +end + +function gsf_recursive_function_type_safety_check(::SciMLBase.AbstractDiffEqFunction) + # OK +end + num_operators(f::GenericSplitFunction) = length(f.functions) """ diff --git a/src/integrator.jl b/src/integrator.jl index 45b2df0..ffd6c10 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -15,12 +15,116 @@ Base.@kwdef mutable struct IntegratorOptions{tType, fType, F3} isoutofdomain::F3 = DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN end + """ - OperatorSplittingIntegrator <: AbstractODEIntegrator + SplitSubIntegratorStatus + +Minimal error-communication object carried by a [`SplitSubIntegrator`](@ref). +Holds only `retcode` so that failure can be propagated up the operator-splitting +tree without carrying an actual solution vector. +""" +mutable struct SplitSubIntegratorStatus + retcode::ReturnCode.T +end + +SplitSubIntegratorStatus() = SplitSubIntegratorStatus(ReturnCode.Default) + + +""" + SplitSubIntegrator <: AbstractODEIntegrator + +An intermediate node in the operator-splitting subintegrator tree. + +Each `SplitSubIntegrator` is self-contained: it knows its own solution indices, +its children's synchronizers, solution indices, and sub-integrators. It does +**not** carry an `f` field (operator information lives in the cache/algorithm). + +## Fields +- `alg` — `AbstractOperatorSplittingAlgorithm` at this level +- `u` — local solution buffer for this sub-problem (may be a + view *or* an independent array, e.g. for GPU sub-problems) +- `uprev` — copy of `u` at the start of a step (for rollback) +- `u_master` — reference to the full master solution vector of the + outermost `OperatorSplittingIntegrator` (needed for sync) +- `t`, `dt`, `dtcache` — time tracking + `dtchangeable`, `stops` +- `iter` — step counter at this level +- `EEst` — error estimate (`NaN` for non-adaptive, `1.0` default + for adaptive) +- `controller` — step-size controller (or `nothing` for non-adaptive) +- `force_stepfail` — flag: current step must be retried +- `last_step_failed` — flag: previous step failed (double-failure detection) +- `status` — [`SplitSubIntegratorStatus`](@ref) for retcode communication +- `cache` — `AbstractOperatorSplittingCache` for the algorithm at + this level +- `child_subintegrators` — tuple of direct children (`SplitSubIntegrator` or + `DEIntegrator`) +- `solution_indices` — global indices (into parent `u`) **owned by this node** +- `child_solution_indices` — tuple of per-child global solution indices +- `child_synchronizers` — tuple of per-child synchronizer objects +""" +mutable struct SplitSubIntegrator{ + algType, + uType, + tType, + tstopsType, + EEstType, + controllerType, + cacheType, + childSubintType, + solidxType, + childSolidxType, + childSyncType, + optionsType, + } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} + alg::algType + u::uType # local solution buffer + uprev::uType # local rollback buffer + u_master::uType # reference to outermost master u + t::tType + tprev::tType + dt::tType + dtcache::tType + const dtchangeable::Bool + tstops::tstopsType + iter::Int + EEst::EEstType + controller::controllerType + force_stepfail::Bool + last_step_failed::Bool + u_modified::Bool # TODO we can probably remove this + status::SplitSubIntegratorStatus + stats::IntegratorStats + cache::cacheType + child_subintegrators::childSubintType # Tuple + solution_indices::solidxType + child_solution_indices::childSolidxType # Tuple + child_synchronizers::childSyncType # Tuple + opts::optionsType + tdir::tType +end + +# --- SplitSubIntegrator interface --- + +tdir(integrator::SplitSubIntegrator) = sign(integrator.dt) -A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) to perform opeartor splitting. +# proposed-dt interface (mirrors ODEIntegrator) +function SciMLBase.set_proposed_dt!(sub::SplitSubIntegrator, dt) + if sub.dtcache != dt # only touch if actually changing + sub.dtcache = dt + if !isadaptive(sub) + sub.dt = dt + end + end + return nothing +end + + +""" + OperatorSplittingIntegrator <: AbstractODEIntegrator -Derived from https://github.com/CliMA/ClimaTimeSteppers.jl/blob/ef3023747606d2750e674d321413f80638136632/src/integrators.jl. +A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) +to perform operator splitting. """ mutable struct OperatorSplittingIntegrator{ fType, @@ -35,39 +139,38 @@ mutable struct OperatorSplittingIntegrator{ cacheType, solType, subintTreeType, - solidxTreeType, - syncTreeType, + childSolidxType, + childSyncType, controllerType, optionsType, } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} const f::fType const alg::algType - u::uType # Master Solution - uprev::uType # Master Solution - tmp::uType # Interpolation buffer + u::uType # Master solution + uprev::uType # Master solution previous step + tmp::uType # Interpolation buffer p::pType - t::tType # Current time + t::tType # Current time tprev::tType - dt::tType # This is the time step length which which we use during time marching - dtcache::tType # This is the proposed time step length - const dtchangeable::Bool # Indicator whether dtcache can be changed + dt::tType # Time step length used during time marching + dtcache::tType # Proposed time step length + const dtchangeable::Bool tstops::heapType - _tstops::tstopsType # argument to __init used as default argument to reinit! + _tstops::tstopsType saveat::heapType - _saveat::saveatType # argument to __init used as default argument to reinit! + _saveat::saveatType callback::callbackType advance_to_tstop::Bool - # TODO group these into some internal flag struct last_step_failed::Bool force_stepfail::Bool isout::Bool u_modified::Bool - # DiffEqBase.initialize! and DiffEqBase.finalize! cache::cacheType sol::solType - subintegrator_tree::subintTreeType - solution_index_tree::solidxTreeType - synchronizer_tree::syncTreeType + # Tuple of SplitSubIntegrator nodes (one per top-level operator). + child_subintegrators::subintTreeType + child_solution_indices::childSolidxType # Tuple + child_synchronizers::childSyncType # Tuple iter::Int controller::controllerType opts::optionsType @@ -75,7 +178,11 @@ mutable struct OperatorSplittingIntegrator{ tdir::tType end -# called by DiffEqBase.init and DiffEqBase.solve +const AnySplitIntegrator = Union{SplitSubIntegrator, OperatorSplittingIntegrator} + +# --------------------------------------------------------------------------- +# __init +# --------------------------------------------------------------------------- function SciMLBase.__init( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, @@ -89,6 +196,7 @@ function SciMLBase.__init( advance_to_tstop = false, adaptive = isadaptive(alg), controller = nothing, + # controller = OrdinaryDiffEqCore.PIController(0.14, 0.08), alias_u0 = false, verbose = true, kwargs... @@ -101,10 +209,10 @@ function SciMLBase.__init( dt = tf > t0 ? dt : -dt tType = typeof(dt) - # Warn if the algorithm is non-adaptive but the user tries to make it adaptive. - (!isadaptive(alg) && adaptive && verbose) && warn("The algorithm $alg is not adaptive.") + (!isadaptive(alg) && adaptive && verbose) && + @warn("The algorithm $alg is not adaptive.") - dtchangeable = true # isadaptive(alg) + dtchangeable = isdtchangeable(alg) if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number _tstops = nothing @@ -113,7 +221,6 @@ function SciMLBase.__init( tstops = () end - # Setup tstop logic tstops_internal = OrdinaryDiffEqCore.initialize_tstops( tType, tstops, d_discontinuities, prob.tspan ) @@ -128,57 +235,58 @@ function SciMLBase.__init( uType = typeof(u) sol = SciMLBase.build_solution(prob, alg, tType[], uType[]) - callback = DiffEqBase.CallbackSet(callback) - subintegrator_tree, - cache = build_subintegrator_tree_with_cache( + child_subintegrators = build_subintegrators( prob, alg, uprev, u, + u, # u_master == u at the outermost level 1:length(u), t0, dt, tf, tstops, saveat, d_discontinuities, callback, adaptive, verbose ) + cache = init_cache( + prob.f, alg; + uprev = uprev, u = u, + ) + + child_solution_indices = ntuple(i -> prob.f.solution_indices[i], length(prob.f.functions)) + child_synchronizers = ntuple(i -> prob.f.synchronizers[i], length(prob.f.functions)) + integrator = OperatorSplittingIntegrator( prob.f, alg, - u, - uprev, - tmp, + u, uprev, tmp, p, - t0, - copy(t0), - dt, - dtcache, + t0, copy(dt), + dt, dtcache, dtchangeable, - tstops_internal, - tstops, - saveat_internal, - saveat, + tstops_internal, tstops, + saveat_internal, saveat, callback, advance_to_tstop, - false, - false, - false, - false, - cache, - sol, - subintegrator_tree, - build_solution_index_tree(prob.f), - build_synchronizer_tree(prob.f), + false, false, false, false, + cache, sol, + child_subintegrators, + child_solution_indices, + child_synchronizers, 0, controller, IntegratorOptions(; verbose, adaptive), IntegratorStats(), tType(tstops_internal.ordering isa DataStructures.FasterForward ? 1 : -1) ) - DiffEqBase.initialize!(callback, u0, t0, integrator) # Do I need this? + DiffEqBase.initialize!(callback, u0, t0, integrator) return integrator end +# --------------------------------------------------------------------------- +# reinit! +# --------------------------------------------------------------------------- SciMLBase.has_reinit(integrator::OperatorSplittingIntegrator) = true + function DiffEqBase.reinit!( integrator::OperatorSplittingIntegrator, u0 = integrator.sol.prob.u0; @@ -198,7 +306,8 @@ function DiffEqBase.reinit!( if dt !== nothing integrator.dt = dt end - integrator.tstops, integrator.saveat = tstops_and_saveat_heaps(t0, tf, tstops, saveat) + integrator.tstops, integrator.saveat = + tstops_and_saveat_heaps(t0, tf, tstops, saveat) integrator.iter = 0 if erase_sol resize!(integrator.sol.t, 0) @@ -206,7 +315,7 @@ function DiffEqBase.reinit!( end if reinit_callbacks DiffEqBase.initialize!(integrator.callback, u0, t0, integrator) - else # always reinit the saving callback so that t0 can be saved if needed + else saving_callback = integrator.callback.discrete_callbacks[end] DiffEqBase.initialize!(saving_callback, u0, t0, integrator) end @@ -216,56 +325,92 @@ function DiffEqBase.reinit!( ) end - return subreinit!( + _subreinit_tuple!( integrator.f, u0, - 1:length(u0), - integrator.subintegrator_tree; + integrator.child_subintegrators; t0, tf, dt, - erase_sol, - tstops, - saveat, - reinit_callbacks, - reinit_retcode + erase_sol, tstops, saveat, + reinit_callbacks, reinit_retcode ) + return nothing end -function subreinit!( +# --- subreinit! helpers --- + +# Iterate over a tuple of children (outermost call from reinit!) +@unroll function _subreinit_tuple!( f, u0, - solution_indices, - subintegrator::DEIntegrator; + children::Tuple; + kwargs... + ) + i = 1 + @unroll for child in children + _subreinit_child!(get_operator(f, i), u0, child; kwargs...) + i += 1 + end +end + +# Reinitialise a leaf DEIntegrator child +function _subreinit_child!( + f_child, + u0, + child::DEIntegrator; dt, kwargs... ) - # dt is not reset as expected in reinit! - if dt !== nothing - subintegrator.dt = dt + if dt !== nothing && child.dtchangeable + SciMLBase.set_proposed_dt!(child, dt) + # Reinit does not touch this, so we reset it manually. + set_dt!(child, dt) end - return DiffEqBase.reinit!(subintegrator, u0[solution_indices]; kwargs...) + return DiffEqBase.reinit!(child; kwargs...) end -@unroll function subreinit!( - f, +# Reinitialise an intermediate SplitSubIntegrator child +function _subreinit_child!( + f_child, u0, - solution_indices, - subintegrators::Tuple; + sub::SplitSubIntegrator; + t0, + tf, + dt, kwargs... ) - i = 1 - @unroll for subintegrator in subintegrators - subreinit!(get_operator(f, i), u0, f.solution_indices[i], subintegrator; kwargs...) - i += 1 + sub.t = t0 + if dt !== nothing + SciMLBase.set_proposed_dt!(sub, dt) + set_dt!(sub, dt) + end + sub.iter = 0 + sub.force_stepfail = false + sub.last_step_failed = false + sub.status = SplitSubIntegratorStatus(ReturnCode.Default) + # Reset EEst to its appropriate default + if isadaptive(sub) + sub.EEst = one(sub.EEst) end + # Recurse into this node's children + _subreinit_tuple!( + f_child, + u0, + sub.child_subintegrators; + t0, tf, dt, kwargs... + ) + return nothing end -function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrator) +# --------------------------------------------------------------------------- +# handle_tstop! +# --------------------------------------------------------------------------- +function OrdinaryDiffEqCore.handle_tstop!(integrator::AnySplitIntegrator) if SciMLBase.has_tstop(integrator) tdir_t = tdir(integrator) * integrator.t tdir_tstop = SciMLBase.first_tstop(integrator) if tdir_t == tdir_tstop - while tdir_t == tdir_tstop #remove all redundant copies - res = SciMLBase.pop_tstop!(integrator) + while tdir_t == tdir_tstop + SciMLBase.pop_tstop!(integrator) SciMLBase.has_tstop(integrator) ? (tdir_tstop = SciMLBase.first_tstop(integrator)) : break end @@ -274,8 +419,8 @@ function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrato if !integrator.dtchangeable SciMLBase.change_t_via_interpolation!( integrator, - tdir(integrator) * - SciMLBase.pop_tstop!(integrator), Val{true} + tdir(integrator) * SciMLBase.pop_tstop!(integrator), + Val{true} ) notify_integrator_hit_tstop!(integrator) else @@ -286,118 +431,180 @@ function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrato return nothing end -notify_integrator_hit_tstop!(integrator::OperatorSplittingIntegrator) = nothing +notify_integrator_hit_tstop!(integrator::AnySplitIntegrator) = nothing -is_first_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter == 0 -increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter += 1 -# Controller interface -function reject_step!(integrator::OperatorSplittingIntegrator) +# --------------------------------------------------------------------------- +# Step accept/reject +# --------------------------------------------------------------------------- +function reject_step!(integrator::AnySplitIntegrator) OrdinaryDiffEqCore.increment_reject!(integrator.stats) return reject_step!(integrator, integrator.cache, integrator.controller) end -function reject_step!(integrator::OperatorSplittingIntegrator, cache, controller) - return integrator.u .= integrator.uprev - # TODO what do we need to do with the subintegrators? +function reject_step!(integrator::AnySplitIntegrator, cache, controller) + integrator.u .= integrator.uprev + rollback_children!(integrator) + return nothing end -function reject_step!(integrator::OperatorSplittingIntegrator, cache, ::Nothing) - return if length(integrator.uprev) == 0 +function reject_step!(integrator::AnySplitIntegrator, cache, ::Nothing) + if length(integrator.uprev) == 0 error("Cannot roll back integrator. Aborting time integration step at $(integrator.t).") end + return nothing end -# Solution looping interface function should_accept_step(integrator::OperatorSplittingIntegrator) - if integrator.force_stepfail || integrator.isout - return false - end + integrator.force_stepfail || integrator.isout && return false + return should_accept_step(integrator, integrator.cache, integrator.controller) +end +function should_accept_step(integrator::SplitSubIntegrator) + integrator.force_stepfail && return false return should_accept_step(integrator, integrator.cache, integrator.controller) end -function should_accept_step(integrator::OperatorSplittingIntegrator, cache, ::Nothing) +function should_accept_step(integrator::AnySplitIntegrator, cache, ::Nothing) return !(integrator.force_stepfail) end -function accept_step!(integrator::OperatorSplittingIntegrator) + +function accept_step!(integrator::AnySplitIntegrator) OrdinaryDiffEqCore.increment_accept!(integrator.stats) return accept_step!(integrator, integrator.cache, integrator.controller) end -function accept_step!(integrator::OperatorSplittingIntegrator, cache, controller) +function accept_step!(integrator::AnySplitIntegrator, cache, controller) return store_previous_info!(integrator) end -function store_previous_info!(integrator::OperatorSplittingIntegrator) - return if length(integrator.uprev) > 0 # Integrator can rollback +function store_previous_info!(integrator::AnySplitIntegrator) + if length(integrator.uprev) > 0 update_uprev!(integrator) end + return nothing end - -function update_uprev!(integrator::OperatorSplittingIntegrator) +function update_uprev!(integrator::AnySplitIntegrator) RecursiveArrayTools.recursivecopy!(integrator.uprev, integrator.u) return nothing end -function step_header!(integrator::OperatorSplittingIntegrator) - # Accept or reject the step +# Roll back each child's local buffer to match master u. +# For DEIntegrators the leaf will be re-synced via forward_sync before the +# next attempt, so there is nothing to do here. +rollback_children!(integrator::OperatorSplittingIntegrator) = rollback_children!(integrator.child_subintegrators, integrator.u) +@unroll function rollback_children!(children::Tuple, u_master) + @unroll for child in children + rollback_child!(child, u_master) + end +end +function rollback_child!(child::SplitSubIntegrator, u_master) + child.u .= @view u_master[child.solution_indices] + RecursiveArrayTools.recursivecopy!(child.uprev, child.u) + _rollback_children!(child.child_subintegrators, u_master) + return nothing +end +function rollback_child!(child::DEIntegrator, u_master) + # forward_sync before the next sub-step will restore this correctly. + return nothing +end + +# --------------------------------------------------------------------------- +# step_header! / step_footer! +# --------------------------------------------------------------------------- +function step_header!(integrator::AnySplitIntegrator) if !is_first_iteration(integrator) if should_accept_step(integrator) accept_step!(integrator) - else # Step should be rejected and hence repeated + else reject_step!(integrator) end - elseif integrator.u_modified # && integrator.iter == 0 + elseif integrator.u_modified update_uprev!(integrator) end - - # Before stepping we might need to adjust the dt increment_iteration(integrator) - # OrdinaryDiffEqCore.choose_algorithm!(integrator, integrator.cache) OrdinaryDiffEqCore.fix_dt_at_bounds!(integrator) OrdinaryDiffEqCore.modify_dt_for_tstops!(integrator) - return integrator.force_stepfail = false + integrator.force_stepfail = false + return nothing end +is_first_iteration(integrator::AnySplitIntegrator) = integrator.iter == 0 +increment_iteration(integrator::AnySplitIntegrator) = integrator.iter += 1 + function footer_reset_flags!(integrator) - return integrator.u_modified = false + integrator.u_modified = false + return end +footer_reset_flags!(::SplitSubIntegrator) = nothing function setup_validity_flags!(integrator, t_next) - return integrator.isout = false #integrator.opts.isoutofdomain(integrator.u, integrator.p, t_next) + integrator.isout = false + return end +setup_validity_flags!(::SplitSubIntegrator, _) = nothing function fix_solution_buffer_sizes!(integrator, sol) resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) - return if !(integrator.sol isa SciMLBase.DAESolution) + if !(integrator.sol isa SciMLBase.DAESolution) resize!(integrator.sol.k, integrator.saveiter_dense) end + return nothing end -function step_footer!(integrator::OperatorSplittingIntegrator) - ttmp = integrator.t + tdir(integrator) * integrator.dt +function OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator::AnySplitIntegrator, ttmp) + return if DiffEqBase.has_tstop(integrator) + tstop = integrator.tdir * DiffEqBase.first_tstop(integrator) + if abs(ttmp - tstop) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * + oneunit(integrator.t) + try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) + tstop + else + ttmp + end + else + ttmp + end +end +function try_snap_children_to_tstop!(integrator::SplitSubIntegrator, tstop) + if abs(tstop - integrator.t) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) + integrator.t = tstop + else + @warn "Failed to snap timestep for integrator $(integrator.t) with parent integrator hitting the tstop $(tstop)." + end + return try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) +end +function try_snap_children_to_tstop!(integrator::DEIntegrator, tstop) + return if abs(tstop - integrator.t) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) + integrator.t = tstop + else + @warn "Failed to snap timestep for integrator $(integrator.t) with parent integrator hitting the tstop $(tstop)." + end +end +function step_footer!(integrator::AnySplitIntegrator) + ttmp = integrator.t + tdir(integrator) * integrator.dt footer_reset_flags!(integrator) setup_validity_flags!(integrator, ttmp) - if should_accept_step(integrator) integrator.last_step_failed = false integrator.tprev = integrator.t - integrator.t = ttmp #OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) - # OrdinaryDiffEqCore.handle_callbacks!(integrator) - step_accept_controller!(integrator) # Noop for non-adaptive algorithms + integrator.t = OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) + step_accept_controller!(integrator) elseif integrator.force_stepfail if isadaptive(integrator) step_reject_controller!(integrator) OrdinaryDiffEqCore.post_newton_controller!(integrator, integrator.alg) - elseif integrator.dtchangeable # Non-adaptive but can change dt + elseif integrator.dtchangeable integrator.dt /= integrator.opts.failfactor elseif integrator.last_step_failed return end integrator.last_step_failed = true end - - # integration_monitor_step(integrator) - + validate_time_point(integrator) return nothing end -# called by DiffEqBase.solve +# --------------------------------------------------------------------------- +# __solve / solve! / step! +# --------------------------------------------------------------------------- function SciMLBase.__solve( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, args...; kwargs... @@ -406,34 +613,29 @@ function SciMLBase.__solve( return DiffEqBase.solve!(integrator) end -# either called directly (after init), or by DiffEqBase.solve (via __solve) function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator) while !isempty(integrator.tstops) while tdir(integrator) * integrator.t < SciMLBase.first_tstop(integrator) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( ReturnCode.Success, ReturnCode.Default, - ) && return + ) && return integrator.sol __step!(integrator) step_footer!(integrator) - if !SciMLBase.has_tstop(integrator) - break - end + SciMLBase.has_tstop(integrator) || break end OrdinaryDiffEqCore.handle_tstop!(integrator) end SciMLBase.postamble!(integrator) - if integrator.sol.retcode != ReturnCode.Default - return integrator.sol - end + integrator.sol.retcode != ReturnCode.Default && return integrator.sol return integrator.sol = SciMLBase.solution_new_retcode( integrator.sol, ReturnCode.Success ) end -function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) +function DiffEqBase.step!(integrator::AnySplitIntegrator) @timeit_debug "step!" if integrator.advance_to_tstop - tstop = first_tstop(integrator) + tstop = SciMLBase.first_tstop(integrator) while !reached_tstop(integrator, tstop) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( @@ -441,9 +643,7 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) ) && return __step!(integrator) step_footer!(integrator) - if !SciMLBase.has_tstop(integrator) - break - end + SciMLBase.has_tstop(integrator) || break end else step_header!(integrator) @@ -461,78 +661,86 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) step_footer!(integrator) end end - return OrdinaryDiffEqCore.handle_tstop!(integrator) + OrdinaryDiffEqCore.handle_tstop!(integrator) + return +end + +function DiffEqBase.step!(integrator::AnySplitIntegrator, dt, stop_at_tdt = false) + @timeit_debug "step!" begin + dt <= zero(dt) && error("dt must be positive") + stop_at_tdt && !integrator.dtchangeable && + error("Cannot stop at t + dt if dtchangeable is false") + tnext = integrator.t + tdir(integrator) * dt + stop_at_tdt && DiffEqBase.add_tstop!(integrator, tnext) + while !reached_tstop(integrator, tnext, stop_at_tdt) + step_header!(integrator) + @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( + ReturnCode.Success, ReturnCode.Default, + ) && return + __step!(integrator) + step_footer!(integrator) + end + end + OrdinaryDiffEqCore.handle_tstop!(integrator) + return nothing end +# --------------------------------------------------------------------------- +# check_error +# --------------------------------------------------------------------------- function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) if !SciMLBase.successful_retcode(integrator.sol) && integrator.sol.retcode != ReturnCode.Default return integrator.sol.retcode end - - verbose = true # integrator.opts.verbose - - if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) # replace with https://github.com/SciML/OrdinaryDiffEq.jl/blob/373a8eec8024ef1acc6c5f0c87f479aa0cf128c3/lib/OrdinaryDiffEqCore/src/iterator_interface.jl#L5-L6 after moving to sciml integrators - if verbose + if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) + integrator.opts.verbose && @warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.") - end return ReturnCode.DtNaN end - - return check_error_subintegrators(integrator, integrator.subintegrator_tree) + return _check_error_children(integrator.sol.retcode, integrator.child_subintegrators) end -function check_error_subintegrators(integrator, subintegrator_tree::Tuple) - for subintegrator in subintegrator_tree - retcode = check_error_subintegrators(integrator, subintegrator) - if !SciMLBase.successful_retcode(retcode) && retcode != ReturnCode.Default - return retcode - end +function SciMLBase.check_error(integrator::SplitSubIntegrator) + if !SciMLBase.successful_retcode(integrator.status.retcode) && + integrator.status.retcode != ReturnCode.Default + return integrator.status.retcode + end + if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) + integrator.opts.verbose && + @warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.") + return ReturnCode.DtNaN end - return integrator.sol.retcode + return _check_error_children(integrator.status.retcode, integrator.child_subintegrators) end -function check_error_subintegrators(integrator, subintegrator::DEIntegrator) - return SciMLBase.check_error(subintegrator) +function SciMLBase.check_error!(integrator::SplitSubIntegrator) + code = SciMLBase.check_error(integrator) + integrator.status.retcode = code + return code end -function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false) - return @timeit_debug "step!" begin - # OridinaryDiffEq lets dt be negative if tdir is -1, but that's inconsistent - dt <= zero(dt) && error("dt must be positive") - stop_at_tdt && !integrator.dtchangeable && - error("Cannot stop at t + dt if dtchangeable is false") - tnext = integrator.t + tdir(integrator) * dt - stop_at_tdt && DiffEqBase.add_tstop!(integrator, tnext) - while !reached_tstop(integrator, tnext, stop_at_tdt) - step_header!(integrator) - @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default, - ) && return - __step!(integrator) - step_footer!(integrator) +@unroll function _check_error_children(current_retcode, children::Tuple) + @unroll for child in children + rc = _child_retcode(child) + if !SciMLBase.successful_retcode(rc) && rc != ReturnCode.Default + return rc end end + return current_retcode end +_child_retcode(child::DEIntegrator) = SciMLBase.check_error(child) +_child_retcode(child::SplitSubIntegrator) = child.status.retcode + function setup_u(prob::OperatorSplittingProblem, solver, alias_u0) - if alias_u0 - return prob.u0 - else - return RecursiveArrayTools.recursivecopy(prob.u0) - end + return alias_u0 ? prob.u0 : RecursiveArrayTools.recursivecopy(prob.u0) end -# TimeChoiceIterator API @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator) - # DiffEqBase.get_tmp_cache(integrator, integrator.alg, integrator.cache) return (integrator.tmp,) end -# @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator, ::AbstractOperatorSplittingAlgorithm, cache) -# return (cache.tmp,) -# end -# Interpolation -# TODO via https://github.com/SciML/SciMLBase.jl/blob/master/src/interpolation.jl + function linear_interpolation!(y, t, y1, y2, t1, t2) return y .= y1 + (t - t1) * (y2 - y1) / (t2 - t1) end @@ -542,62 +750,58 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t) ) end -""" - stepsize_controller!(::OperatorSplittingIntegrator) - -Updates the controller using the current state of the integrator if the operator splitting algorithm is adaptive. -""" -@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - isadaptive(algorithm) || return nothing - return stepsize_controller!(integrator, algorithm) +# Stepsize controller hooks +@inline function stepsize_controller!(integrator::AnySplitIntegrator) + isadaptive(integrator.alg) || return nothing + stepsize_controller!(integrator, integrator.alg) + return nothing end - -""" - step_accept_controller!(::OperatorSplittingIntegrator) - -Updates `dtcache` of the integrator if the step is accepted and the operator splitting algorithm is adaptive. -""" -@inline function step_accept_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - isadaptive(algorithm) || return nothing - return step_accept_controller!(integrator, algorithm, nothing) +@inline function stepsize_controller!(integrator::AnySplitIntegrator, alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) || return nothing + #stepsize_controller!(integrator, integrator.controller) + return nothing end - -""" - step_reject_controller!(::OperatorSplittingIntegrator) - -Updates `dtcache` of the integrator if the step is rejected and the the operator splitting algorithm is adaptive. -""" -@inline function step_reject_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - isadaptive(algorithm) || return nothing - return step_reject_controller!(integrator, algorithm, nothing) +@inline function step_accept_controller!(integrator::AnySplitIntegrator) + isadaptive(integrator.alg) || return nothing + step_accept_controller!(integrator, integrator.alg) + return nothing +end +@inline function step_accept_controller!(integrator::AnySplitIntegrator, alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) || return nothing + #step_accept_controller!(integrator, integrator.controller) + return nothing +end +@inline function step_reject_controller!(integrator::AnySplitIntegrator) + isadaptive(integrator.alg) || return nothing + step_reject_controller!(integrator, integrator.alg) + return nothing +end +@inline function step_reject_controller!(integrator::AnySplitIntegrator, alg::AbstractOperatorSplittingAlgorithm) + isadaptive(integrator.alg) || return nothing + # step_reject_controller!(integrator, integrator.controller) + return nothing end -# helper functions for dealing with time-reversed integrators in the same way -# that OrdinaryDiffEq.jl does -tdir(integrator) = integrator.tstops.ordering isa DataStructures.FasterForward ? 1 : -1 -is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) ≤ zero(integrator.t) + +# Time helpers +tdir(integrator) = + integrator.tstops.ordering isa DataStructures.FasterForward ? 1 : -1 +is_past_t(integrator, t) = + tdir(integrator) * (t - integrator.t) ≤ zero(integrator.t) function reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) if stop_at_tstop integrator.t > tstop && error("Integrator missed stop at $tstop (current time=$(integrator.t)). Aborting.") - return integrator.t == tstop # Check for exact hit - else #!stop_at_tstop + return integrator.t ≈ tstop + else return is_past_t(integrator, tstop) end end -# Dunno stuff +# SciMLBase integrator interface function SciMLBase.done(integrator::OperatorSplittingIntegrator) - if !( - integrator.sol.retcode in ( - ReturnCode.Default, ReturnCode.Success, - ) - ) - return true - elseif isempty(integrator.tstops) + integrator.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) && return true + if isempty(integrator.tstops) SciMLBase.postamble!(integrator) return true end @@ -608,107 +812,132 @@ function SciMLBase.postamble!(integrator::OperatorSplittingIntegrator) return DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator) end -function __step!(integrator) - tnext = integrator.t + integrator.dt - synchronize_subintegrator_tree!(integrator) - advance_solution_to!(integrator, tnext) - return stepsize_controller!(integrator) +function __step!(integrator::AnySplitIntegrator) + advance_solution_by!(integrator, integrator.dt) + stepsize_controller!(integrator) # FIXME this should go into the footer + return nothing end -# solvers need to define this interface -function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext) - return advance_solution_to!(integrator, integrator.cache, tnext) +# Entry point: dispatch to the algorithm's advance_solution_by! +function advance_solution_by!(integrator::AnySplitIntegrator, dt) + return advance_solution_by!(integrator, integrator.cache, dt) end -function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - integrator::DEIntegrator, solution_indices, sync, cache, tend +# Algorithm-level dispatch (implemented in solver.jl per algorithm) +function advance_solution_by!( + integrator::AnySplitIntegrator, + cache::AbstractOperatorSplittingCache, dt + ) + return advance_solution_by!( + integrator, integrator.child_subintegrators, cache, dt ) - dt = tend - integrator.t - return SciMLBase.step!(integrator, dt, true) end -# ----------------------------------- SciMLBase.jl Integrator Interface ------------------------------------ -SciMLBase.has_stats(::OperatorSplittingIntegrator) = true - -SciMLBase.has_tstop(integrator::OperatorSplittingIntegrator) = !isempty(integrator.tstops) -SciMLBase.first_tstop(integrator::OperatorSplittingIntegrator) = first(integrator.tstops) -SciMLBase.pop_tstop!(integrator::OperatorSplittingIntegrator) = pop!(integrator.tstops) +# --------------------------------------------------------------------------- +# advance_solution_by! for a SplitSubIntegrator node +# +# The SplitSubIntegrator is now the *parent* for its own children. +# It carries child_solution_indices and child_synchronizers directly. +# +# Entry point called from integrator.jl for a SplitSubIntegrator node +# --------------------------------------------------------------------------- +function advance_solution_by!( + outer::OperatorSplittingIntegrator, + children::Tuple, + cache::AbstractOperatorSplittingCache, + dt + ) + _perform_step!(outer, children, cache, dt) + + if outer.force_stepfail && all(isadaptive.(children)) + # We do not know recover at this point, as an decrease in the solve + # interval is unlikely to help here. + outer.sol = SciMLBase.solution_new_retcode( + outer.sol, + ReturnCode.Failure + ) + return + end -DiffEqBase.get_dt(integrator::OperatorSplittingIntegrator) = integrator.dt -function set_dt!(integrator::OperatorSplittingIntegrator, dt) - # TODO: figure out interface for recomputing other objects (linear operators, etc) - dt <= zero(dt) && error("dt must be positive") - return integrator.dt = dt + return end -function DiffEqBase.add_tstop!(integrator::OperatorSplittingIntegrator, t) - is_past_t(integrator, t) && - error("Cannot add a tstop at $t because that is behind the current \ - integrator time $(integrator.t)") - return push!(integrator.tstops, t) -end +function advance_solution_by!( + outer::SplitSubIntegrator, + children::Tuple, + cache::AbstractOperatorSplittingCache, + dt + ) + _perform_step!(outer, children, cache, dt) -function DiffEqBase.add_saveat!(integrator::OperatorSplittingIntegrator, t) - is_past_t(integrator, t) && - error("Cannot add a saveat point at $t because that is behind the \ - current integrator time $(integrator.t)") - return push!(integrator.saveat, t) -end + if outer.force_stepfail + outer.status = SplitSubIntegratorStatus(ReturnCode.Failure) + return + end -# not sure what this should do? -# defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 -DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing + # All children succeeded: advance this node's time and counter + outer.status = SplitSubIntegratorStatus(ReturnCode.Success) -function synchronize_subintegrator_tree!(integrator::OperatorSplittingIntegrator) - return synchronize_subintegrator!(integrator.subintegrator_tree, integrator) + return end -@unroll function synchronize_subintegrator!( - subintegrator_tree::Tuple, integrator::OperatorSplittingIntegrator +# Recursion dispatch +function advance_solution_by!( + outer::AnySplitIntegrator, + sub::SplitSubIntegrator, + dt ) - @unroll for subintegrator in subintegrator_tree - synchronize_subintegrator!(subintegrator, integrator) + SciMLBase.step!(sub, dt, true) + + # Unrecoverable failure: error immediately regardless of adaptive/non-adaptive + if !SciMLBase.successful_retcode(sub.status.retcode) && + sub.status.retcode != ReturnCode.Default + error("Inner integrator failed unrecoverably with retcode \ + $(sub.status.retcode) at t=$(child.t). Aborting.") end + return nothing end -function synchronize_subintegrator!( - subintegrator::DEIntegrator, integrator::OperatorSplittingIntegrator - ) - (; t, dt) = integrator - @assert subintegrator.t == t - return if !isadaptive(subintegrator) - SciMLBase.set_proposed_dt!(subintegrator, dt) +# Leaf disptach +function advance_solution_by!(outer::AnySplitIntegrator, child::DEIntegrator, dt) + SciMLBase.step!(child, dt, true) + + # Unrecoverable failure: error immediately regardless of adaptive/non-adaptive + if !SciMLBase.successful_retcode(child.sol.retcode) && + child.sol.retcode != ReturnCode.Default + error("Inner integrator failed unrecoverably with retcode \ + $(child.sol.retcode) at t=$(child.t). Aborting.") end + return nothing end -function advance_solution_to!( - integrator::OperatorSplittingIntegrator, - cache::AbstractOperatorSplittingCache, tnext::Number - ) - return advance_solution_to!( - integrator, integrator.subintegrator_tree, integrator.solution_index_tree, - integrator.synchronizer_tree, cache, tnext - ) -end -# Dispatch for tree node construction -function build_subintegrator_tree_with_cache( - prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, - uprevouter::AbstractVector, uouter::AbstractVector, +# --------------------------------------------------------------------------- +# Tree construction +# --------------------------------------------------------------------------- + +# Top-level builder: called from __init with the full problem. +# Returns (child_subintegrators::Tuple, cache::AbstractOperatorSplittingCache) +function build_subintegrators( + prob::OperatorSplittingProblem, + alg::AbstractOperatorSplittingAlgorithm, + uprevouter::AbstractVector, + uouter::AbstractVector, + u_master::AbstractVector, solution_indices, t0, dt, tf, tstops, saveat, d_discontinuities, callback, adaptive, verbose ) (; f, p) = prob - subintegrator_tree_with_caches = ntuple( - i -> build_subintegrator_tree_with_cache( + + child_subintegrators = ntuple( + i -> _build_child( prob, alg.inner_algs[i], get_operator(f, i), p[i], - uprevouter, uouter, + uprevouter, uouter, u_master, f.solution_indices[i], t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -717,24 +946,19 @@ function build_subintegrator_tree_with_cache( length(f.functions) ) - subintegrator_tree = ntuple( - i -> subintegrator_tree_with_caches[i][1], length(f.functions) - ) - caches = ntuple(i -> subintegrator_tree_with_caches[i][2], length(f.functions)) - - # TODO fix mixed device type problems we have to be smarter - return subintegrator_tree, - init_cache( - f, alg; - uprev = uprevouter, u = uouter, alias_u = true, - inner_caches = caches - ) + return child_subintegrators end -function build_subintegrator_tree_with_cache( - prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, +# Intermediate node: inner alg is an AbstractOperatorSplittingAlgorithm and +# f is a GenericSplitFunction → produce a SplitSubIntegrator +function _build_child( + prob::OperatorSplittingProblem, + alg::AbstractOperatorSplittingAlgorithm, + f::GenericSplitFunction, + p::Tuple, + uprevouter::AbstractVector, + uouter::AbstractVector, + u_master::AbstractVector, solution_indices, t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -742,13 +966,16 @@ function build_subintegrator_tree_with_cache( save_end = false, controller = nothing ) - subintegrator_tree_with_caches = ntuple( - i -> build_subintegrator_tree_with_cache( + tType = typeof(dt) + + # Recurse: build each consecutive child + child_subintegrators = ntuple( + i -> _build_child( prob, alg.inner_algs[i], get_operator(f, i), p[i], - uprevouter, uouter, + uprevouter, uouter, u_master, f.solution_indices[i], t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -757,25 +984,57 @@ function build_subintegrator_tree_with_cache( length(f.functions) ) - subintegrator_tree = first.(subintegrator_tree_with_caches) - inner_caches = last.(subintegrator_tree_with_caches) - - # TODO fix mixed device type problems we have to be smarter - uprev = @view uprevouter[solution_indices] - u = @view uouter[solution_indices] - return subintegrator_tree, - init_cache( - f, alg; - uprev = uprev, u = u, - inner_caches = inner_caches - ) + child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) + child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) + + u_sub = RecursiveArrayTools.recursivecopy(uouter[solution_indices]) + uprev_sub = RecursiveArrayTools.recursivecopy(uprevouter[solution_indices]) + + tstops_internal = OrdinaryDiffEqCore.initialize_tstops( + tType, tstops, d_discontinuities, prob.tspan + ) + + level_cache = init_cache( + f, alg; + uprev = uprev_sub, u = u_sub, + ) + + EEst_val = isadaptive(alg) ? one(tType) : tType(NaN) + + sub = SplitSubIntegrator( + alg, + u_sub, + uprev_sub, + u_master, + t0, t0, dt, dt, # t, tprev, dt, dtcache + isdtchangeable(alg), + tstops_internal, + 0, # iter + EEst_val, + controller, + false, false, false, # force_stepfail, last_step_failed, u_modified + SplitSubIntegratorStatus(), + IntegratorStats(), + level_cache, + child_subintegrators, + solution_indices, + child_solution_indices, + child_synchronizers, + IntegratorOptions(; verbose, adaptive), + one(tType), + ) + + return sub end -function build_subintegrator_tree_with_cache( +# Leaf node: inner alg is a plain SciMLBase.AbstractODEAlgorithm +# → produce an ODEIntegrator (existing behaviour) +function _build_child( prob::OperatorSplittingProblem, alg::SciMLBase.AbstractODEAlgorithm, f::F, p::P, uprevouter::S, uouter::S, + u_master::S, solution_indices, t0::T, dt::T, tf::T, tstops, saveat, d_discontinuities, callback, @@ -783,44 +1042,57 @@ function build_subintegrator_tree_with_cache( save_end = false, controller = nothing ) where {S, T, P, F} - uprev = @view uprevouter[solution_indices] - u = @view uouter[solution_indices] - - # When working with MTK, we want to pass f as a system down here. - # In that case ODEProblem constructs the correct parameter struct. - # If the system does not have parameters in first place, then - # The NullParameters object will be constructed automatically. + u = uouter[solution_indices] prob2 = if p isa NullParameters - SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf))) + SciMLBase.ODEProblem(f, u, (t0, tf)) else - SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf)), p) + SciMLBase.ODEProblem(f, u, (t0, tf), p) end integrator = SciMLBase.__init( - prob2, - alg; + prob2, alg; dt, + tstops, saveat = (), d_discontinuities, save_everystep = false, advance_to_tstop = false, - adaptive, - controller, - verbose + adaptive, controller, verbose ) + return integrator +end + +# --------------------------------------------------------------------------- +# SciMLBase API +# --------------------------------------------------------------------------- +SciMLBase.has_stats(::AnySplitIntegrator) = true + +SciMLBase.has_tstop(i::AnySplitIntegrator) = !isempty(i.tstops) +SciMLBase.first_tstop(i::AnySplitIntegrator) = first(i.tstops) +SciMLBase.pop_tstop!(i::AnySplitIntegrator) = pop!(i.tstops) - return integrator, integrator.cache +DiffEqBase.get_dt(i::AnySplitIntegrator) = i.dt +function set_dt!(i::DiffEqBase.DEIntegrator, dt) + dt <= zero(dt) && error("dt must be positive") + return i.dt = dt end -function forward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, subintegrator_tree::Tuple, - solution_indices::Tuple, synchronizers::Tuple - ) +function DiffEqBase.add_tstop!(i::AnySplitIntegrator, t) + is_past_t(i, t) && + error("Cannot add a tstop at $t because that is behind the current \ + integrator time $(i.t)") + DiffEqBase.add_tstop!.(i.child_subintegrators, t) + push!(i.tstops, t) return nothing end -function backward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - subintegrator_tree::Tuple, solution_indices::Tuple, synchronizer::Tuple - ) + +function DiffEqBase.add_saveat!(i::OperatorSplittingIntegrator, t) + is_past_t(i, t) && + error("Cannot add a saveat point at $t because that is behind the \ + current integrator time $(i.t)") + push!(i.saveat, t) return nothing end + +DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = i.u_modified = bool +DiffEqBase.u_modified!(i::SplitSubIntegrator, bool) = i.u_modified = bool diff --git a/src/precompilation.jl b/src/precompilation.jl index 3514554..f638e74 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -1,11 +1,19 @@ using PrecompileTools: @compile_workload using OrdinaryDiffEqLowOrderRK: Euler +import DiffEqBase: DiffEqBase, step!, solve! function _precompile_ode1(du, u, p, t) - return @. du = -0.1u + @. du = -0.1u + return end function _precompile_ode2(du, u, p, t) + du[1] = -0.01u[2] + du[2] = -0.01u[1] + return +end + +function _precompile_ode3(du, u, p, t) du[1] = -0.01u[2] return du[2] = -0.01u[1] end @@ -17,16 +25,19 @@ end f1 = DiffEqBase.ODEFunction(_precompile_ode1) f2 = DiffEqBase.ODEFunction(_precompile_ode2) + f3 = DiffEqBase.ODEFunction(_precompile_ode3) f1dofs = [1, 2, 3] f2dofs = [1, 3] - fsplit = GenericSplitFunction((f1, f2), (f1dofs, f2dofs)) + f3dofs = [2, 3] + fsplitinner = GenericSplitFunction((f2, f3), (f2dofs, f3dofs)) + fsplit = GenericSplitFunction((f1, fsplitinner), (f1dofs, [1, 2, 3])) prob = OperatorSplittingProblem(fsplit, u0, tspan) - tstepper = LieTrotterGodunov((Euler(), Euler())) + tstepper = LieTrotterGodunov((Euler(), LieTrotterGodunov((Euler(), Euler())))) # Precompile init and a few steps integrator = DiffEqBase.init(prob, tstepper, dt = 0.01, verbose = false) - DiffEqBase.step!(integrator) - DiffEqBase.solve!(integrator) + step!(integrator) + solve!(integrator) end diff --git a/src/problem.jl b/src/problem.jl index 59e5b1d..35365db 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -35,6 +35,9 @@ recursive_null_parameters(f::AbstractOperatorSplitFunction) = @error "Not implem function recursive_null_parameters(f::GenericSplitFunction) return ntuple(i -> recursive_null_parameters(get_operator(f, i)), length(f.functions)) end -function recursive_null_parameters(f) # Wildcard for leafs +function recursive_null_parameters(f::SciMLBase.AbstractDiffEqFunction) + return NullParameters() +end +function recursive_null_parameters(f) return NullParameters() end diff --git a/src/solver.jl b/src/solver.jl index b2b0070..3163474 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -1,56 +1,58 @@ -# Lie-Trotter-Godunov Splitting Implementation +# --------------------------------------------------------------------------- +# Lie-Trotter-Godunov operator splitting +# --------------------------------------------------------------------------- """ LieTrotterGodunov <: AbstractOperatorSplittingAlgorithm -A first order sequential operator splitting algorithm attributed to [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite). +First-order sequential operator splitting algorithm attributed to +[Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite). """ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm inner_algs::AlgTupleType # Tuple of timesteppers for inner problems - # transfer_algs::TransferTupleType # Tuple of transfer algorithms from the master solution into the individual ones end -struct LieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplittingCache +function Base.show(io::IO, alg::LieTrotterGodunov) + print(io, "LTG (") + for inner_alg in alg.inner_algs[1:(end - 1)] + Base.show(io, inner_alg) + print(io, " -> ") + end + length(alg.inner_algs) > 0 && Base.show(io, alg.inner_algs[end]) + return print(io, ")") +end + +struct LieTrotterGodunovCache{uType, uprevType} <: AbstractOperatorSplittingCache u::uType uprev::uprevType - inner_caches::iiType end function init_cache( f::GenericSplitFunction, alg::LieTrotterGodunov; uprev::AbstractArray, u::AbstractVector, - inner_caches, - alias_uprev = true, - alias_u = false ) - _uprev = alias_uprev ? uprev : RecursiveArrayTools.recursivecopy(uprev) - _u = alias_u ? u : RecursiveArrayTools.recursivecopy(u) - return LieTrotterGodunovCache(_u, _uprev, inner_caches) + return LieTrotterGodunovCache(u, uprev) end -@inline @unroll function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - subintegrators::Tuple, solution_indices::Tuple, - synchronizers::Tuple, cache::LieTrotterGodunovCache, tnext +@unroll function _perform_step!( + parent, + children::Tuple, + cache::LieTrotterGodunovCache, + dt ) - # We assume that the integrators are already synced - (; inner_caches) = cache - # For each inner operator i = 0 - @unroll for subinteg in subintegrators + @unroll for child in children i += 1 - synchronizer = synchronizers[i] - idxs = solution_indices[i] - cache = inner_caches[i] - - @timeit_debug "sync ->" forward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) - @timeit_debug "time solve" advance_solution_to!( - outer_integrator, subinteg, idxs, synchronizer, cache, tnext - ) - if !(subinteg isa Tuple) && - subinteg.sol.retcode ∉ - (ReturnCode.Default, ReturnCode.Success) + + idxs = parent.child_solution_indices[i] + sync = parent.child_synchronizers[i] + + @timeit_debug "sync ->" forward_sync_subintegrator!(parent, child, idxs, sync) + @timeit_debug "time solve" advance_solution_by!(parent, child, dt) + if _child_failed(child) + parent.force_stepfail = true return end - backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + + backward_sync_subintegrator!(parent, child, idxs, sync) end end diff --git a/src/utils.jl b/src/utils.jl index dcc9ee0..8dfdd69 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -26,8 +26,9 @@ end """ need_sync(a, b) -This function determines whether it is necessary to synchronize two objects with any solution information. -A possible reason when no synchronization is necessary might be that the vectors alias each other in memory. +Determines whether it is necessary to synchronize two objects with any +solution information. A possible reason when no synchronization is necessary +might be that the vectors alias each other in memory. """ need_sync @@ -39,156 +40,144 @@ need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent """ sync_vectors!(a, b) -Copies the information in object b into object a, if synchronization is necessary. +Copies the information in `b` into `a` if synchronization is necessary. """ function sync_vectors!(a, b) - return if need_sync(a, b) && a !== b + if need_sync(a, b) && a !== b a .= b end + return nothing end """ - forward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) + forward_sync_subintegrator!(parent_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) -This function is responsible of copying the solution and parameters of the outer integrator and the synchronized subintegrators with the information given into the inner integrator. +This function is responsible of copying the solution and parameters of the parent integrator and the synchronized subintegrators with the information given into the inner integrator. If the inner integrator is synchronized with other inner integrators using `sync`, the function `forward_sync_external!` shall be dispatched for `sync`. The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. -The `solution_indices` are global indices in the outer integrators solution vectors. +The `solution_indices` are indices into the parent integrators solution vectors. """ + function forward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices, sync + parent::AnySplitIntegrator, + child::DEIntegrator, + solution_indices, + sync ) - forward_sync_internal!(outer_integrator, inner_integrator, solution_indices) - return forward_sync_external!(outer_integrator, inner_integrator, sync) + forward_sync_internal!(parent.u, child, solution_indices) + forward_sync_external!(parent, child, sync) + return nothing end +# Shared internal helper: copy master u slice → leaf DEIntegrator u/uprev +function forward_sync_internal!(u_source, child::DEIntegrator, solution_indices) + @views usrc = u_source[solution_indices] + sync_vectors!(child.u, usrc) + sync_vectors!(child.uprev, child.u) + SciMLBase.u_modified!(child, true) + return nothing +end + + """ - backward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) + backward_sync_subintegrator!(parent_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) -This function is responsible of copying the solution of the inner integrator back into outer integrator and the synchronized subintegrators. +This function is responsible of copying the solution of the inner integrator back into parent integrator and the synchronized subintegrators. If the inner integrator is synchronized with other inner integrators using `sync`, the function `backward_sync_external!` shall be dispatched for `sync`. The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. -The `solution_indices` are global indices in the outer integrators solution vectors. +The `solution_indices` are indices in the parent integrators solution vectors. """ -function backward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices, sync - ) - backward_sync_internal!(outer_integrator, inner_integrator, solution_indices) - return backward_sync_external!(outer_integrator, inner_integrator, sync) -end -# This is a bit tricky, because per default the operator splitting integrators share their solution vector. However, there is also the case -# when part of the problem is on a different device (thing e.g. about operator A being on CPU and B being on GPU). -# This case should be handled with special synchronizers. -function forward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, solution_indices - ) - return nothing -end -function backward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, solution_indices +function backward_sync_subintegrator!( + parent::AnySplitIntegrator, + child::DEIntegrator, + solution_indices, + sync ) + @views udst = parent.u[solution_indices] + sync_vectors!(udst, child.u) + backward_sync_external!(parent, child, sync) return nothing end -function forward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices - ) - @views uouter = outer_integrator.u[solution_indices] - sync_vectors!(inner_integrator.uprev, uouter) - sync_vectors!(inner_integrator.u, uouter) - return SciMLBase.u_modified!(inner_integrator, true) -end -function backward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices - ) - @views uouter = outer_integrator.u[solution_indices] - return sync_vectors!(uouter, inner_integrator.u) -end +# --------------------------------------------------------------------------- +# forward_sync_external! / backward_sync_external! +# These handle parameter synchronisation via the `sync` object. +# --------------------------------------------------------------------------- -# This is a noop, because operator splitting integrators do not have parameters for now -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync - ) - return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) -end +# NoExternalSynchronization: no-op for all parent/child combinations +forward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +backward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +forward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +backward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing -function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync::NoExternalSynchronization +# OperatorSplittingIntegrator parent with DEIntegrator child: parameter sync +function forward_sync_external!( + parent::OperatorSplittingIntegrator, + child::DEIntegrator, + sync ) - return nothing + return synchronize_solution_with_parameters!(parent, child.p, sync) end function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync + parent::OperatorSplittingIntegrator, + child::DEIntegrator, + sync ) - return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) + return synchronize_solution_with_parameters!(parent, child.p, sync) end -function synchronize_solution_with_parameters!(outer_integrator::OperatorSplittingIntegrator, p, sync) + +function synchronize_solution_with_parameters!( + parent::OperatorSplittingIntegrator, p, sync + ) @warn "Outer synchronizer not dispatched for parameter type $(typeof(p)) with synchronizer type $(typeof(sync))." maxlog = 1 return nothing end -# If we encounter NullParameters, then we have the convention for the standard sync map that no external solution is necessary. function synchronize_solution_with_parameters!( - outer_integrator::OperatorSplittingIntegrator, p::NullParameters, sync + parent::OperatorSplittingIntegrator, ::NullParameters, sync ) return nothing end -# TODO this should go into a custom tree data structure instead of into a tuple-tree -function build_solution_index_tree(f::GenericSplitFunction) - return ntuple( - i -> build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), - length(f.functions) - ) +# Time stuff +function OrdinaryDiffEqCore.fix_dt_at_bounds!(integrator::AnySplitIntegrator) + if tdir(integrator) > 0 + integrator.dt = min(integrator.opts.dtmax, integrator.dt) + else + integrator.dt = max(integrator.opts.dtmax, integrator.dt) + end + dtmin = OrdinaryDiffEqCore.timedepentdtmin(integrator) + if tdir(integrator) > 0 + integrator.dt = max(integrator.dt, dtmin) + else + integrator.dt = min(integrator.dt, dtmin) + end + return nothing end -function build_solution_index_tree_recursion(f::GenericSplitFunction, solution_indices) - return ntuple( - i -> build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), - length(f.functions) - ) +# Check time-step information consistency +validate_time_point(integrator::AnySplitIntegrator) = validate_time_point(integrator, integrator.child_subintegrators) +function validate_time_point(parent, child::SplitSubIntegrator) + @assert parent.t == child.t "(parent.t=$(parent.t) != child.t=$(child.t))" + return validate_time_point(child, child.child_subintegrators) end -function build_solution_index_tree_recursion(f, solution_indices) - return solution_indices +@unroll function validate_time_point(parent, children::Tuple) + @unroll for child in children + validate_time_point(parent, child) + end end -function build_synchronizer_tree(f::GenericSplitFunction) - return ntuple(i -> build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) +function validate_time_point(parent, child::DEIntegrator) + return @assert child.t == parent.t "(parent.t=$(parent.t) != child.t=$(child.t))" end -function build_synchronizer_tree_recursion(f::GenericSplitFunction, synchronizers) - return ntuple(i -> build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) -end +# --------------------------------------------------------------------------- +# _child_failed: check whether a child reported a failure +# --------------------------------------------------------------------------- +_child_failed(child::DEIntegrator) = + child.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) -function build_synchronizer_tree_recursion(f, synchronizer) - return synchronizer -end +_child_failed(child::SplitSubIntegrator) = + child.status.retcode ∉ (ReturnCode.Default, ReturnCode.Success) diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 51537c5..2f11a30 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -8,7 +8,9 @@ using OrdinaryDiffEqLowOrderRK using OrdinaryDiffEqTsit5 using ModelingToolkit -# Reference +# --------------------------------------------------------------------------- +# Reference problem +# --------------------------------------------------------------------------- tspan = (0.0, 100.0) u0 = [ 0.7611944793397108 @@ -33,26 +35,22 @@ end trueu = exp((tspan[2] - tspan[1]) * (trueA + trueB)) * u0 # Setup individual functions -# Diagonal components function ode1(du, u, p, t) return @. du = -0.1u end f1 = ODEFunction(ode1) -# Off-diagonal components function ode2(du, u, p, t) du[1] = -0.01u[2] return du[2] = -0.01u[1] end f2 = ODEFunction(ode2) -# Now some recursive splitting function ode3(du, u, p, t) du[1] = -0.005u[2] return du[2] = -0.005u[1] end f3 = ODEFunction(ode3) -# The time stepper carries the individual solver information. @independent_variables time Dt = Differential(time) @@ -70,74 +68,75 @@ end testsys2 = mtkcompile(testmodel2; sort_eqs = false) # Test whether adaptive code path works in principle -struct FakeAdaptiveAlgorithm{T} <: OS.AbstractOperatorSplittingAlgorithm +struct FakeAdaptiveAlgorithm{T, T2} <: OS.AbstractOperatorSplittingAlgorithm alg::T + inner_algs::T2 # delegate inner_algs to the wrapped algorithm end +FakeAdaptiveAlgorithm(alg) = FakeAdaptiveAlgorithm(alg, alg.inner_algs) + struct FakeAdaptiveAlgorithmCache{T} <: OS.AbstractOperatorSplittingCache cache::T end + @inline DiffEqBase.isadaptive(::FakeAdaptiveAlgorithm) = true -@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm) +@inline function OS.stepsize_controller!( + integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm + ) return nothing end -@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q) +@inline function OS.step_accept_controller!( + integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q + ) integrator.dt = integrator.dtcache return nothing end -@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q) +@inline function OS.step_reject_controller!( + integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q + ) error("The tests should never run into this scenario!") - return nothing # Do nothing + return nothing end -function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - ) - subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( - prob, alg.alg, uprevouter, uouter, solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - ) - return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) +# Override init_cache to wrap the inner cache in FakeAdaptiveAlgorithmCache +function OS.init_cache( + f::GenericSplitFunction, alg::FakeAdaptiveAlgorithm; + kwargs... + ) + inner_cache = OS.init_cache(f, alg.alg; kwargs...) + return FakeAdaptiveAlgorithmCache(inner_cache) end -function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - save_end = false, - controller = nothing + +@inline DiffEqBase.get_tmp_cache( + integrator::OS.OperatorSplittingIntegrator, + alg::OS.AbstractOperatorSplittingAlgorithm, + cache::FakeAdaptiveAlgorithmCache +) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) + +@inline function OS._perform_step!( + outer_integrator, + subintegrators::Tuple, + cache::FakeAdaptiveAlgorithmCache, + dt ) - subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( - prob, alg.alg, f, p, uprevouter, uouter, solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, + return OS._perform_step!( + outer_integrator, subintegrators, cache.cache, dt ) - - return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) end + FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) -@inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::FakeAdaptiveAlgorithmCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) -@inline function OS.advance_solution_to!(outer_integrator::OS.OperatorSplittingIntegrator, subintegrators::Tuple, solution_indices::Tuple, synchronizers::Tuple, cache::FakeAdaptiveAlgorithmCache, tnext) - return OS.advance_solution_to!(outer_integrator, subintegrators, solution_indices, synchronizers, cache.cache, tnext) +function Base.show(io::IO, alg::FakeAdaptiveAlgorithm) + print(io, "FAKE (") + Base.show(io, alg.alg) + return print(io, ")") end + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- @testset "reinit and convergence" begin dt = 0.01π @@ -149,21 +148,21 @@ end fsplit1a = GenericSplitFunction((f1, f2), (f1dofs, f2dofs)) fsplit1b = GenericSplitFunction((f1, testsys2), (f1dofs, f2dofs)) - # Now the usual setup just with our new problem type. prob1a = OperatorSplittingProblem(fsplit1a, u0, tspan) prob1b = OperatorSplittingProblem(fsplit1b, u0, tspan) # Note that we define the dof indices w.r.t the parent function. # Hence the indices for `fsplit2_inner` are. - f1dofs = [1, 2, 3] - f2dofs = [1, 3] - f3dofs = [1, 3] + f3dofs = [1, 2] fsplit2_inner = GenericSplitFunction((f3, f3), (f3dofs, f3dofs)) fsplit2_outer = GenericSplitFunction((f1, fsplit2_inner), (f1dofs, f2dofs)) prob2 = OperatorSplittingProblem(fsplit2_outer, u0, tspan) + + nsteps = ceil(Int, (tspan[2] - tspan[1]) / dt) + for TimeStepperType in (LieTrotterGodunov, FakeAdaptiveLTG) - @testset "Solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( + @testset "$tstepper" for (prob, tstepper) in ( (prob1a, TimeStepperType((Euler(), Euler()))), (prob1a, TimeStepperType((Tsit5(), Euler()))), (prob1a, TimeStepperType((Euler(), Tsit5()))), @@ -179,20 +178,27 @@ end (prob2, TimeStepperType((Tsit5(), TimeStepperType((Euler(), Tsit5()))))), (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), ) - # The remaining code works as usual. integrator = DiffEqBase.init( prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive = false ) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default + + sub1 = integrator.child_subintegrators[1] + sub2 = integrator.child_subintegrators[2] + DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success ufinal = copy(integrator.u) @test isapprox(ufinal, trueu, atol = 1.0e-6) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps + + @test sub1.t ≈ tspan[2] + @test sub1.iter == nsteps + + @test sub2.t ≈ tspan[2] + @test sub2.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -200,10 +206,8 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -211,20 +215,22 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps + + @test sub1.t ≈ tspan[2] + @test sub1.iter == nsteps + + @test sub2.t ≈ tspan[2] + @test sub2.iter == nsteps end end @@ -233,7 +239,6 @@ end (prob1a, TimeStepperType((Tsit5(), Tsit5()))), (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), ) - # The remaining code works as usual. integrator = DiffEqBase.init( prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive = true ) @@ -243,21 +248,19 @@ end ufinal = copy(integrator.u) @test isapprox(ufinal, trueu, atol = 1.0e-6) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) + @test integrator.dt == dt + @test integrator.dt == integrator.dtcache @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (u, t) in DiffEqBase.TimeChoiceIterator(integrator, tspan[1]:5.0:tspan[2]) end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -265,24 +268,20 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps end end - @testset "Instbility detectioon" begin + @testset "Instability detection" begin dt = 0.01π function ode_NaN(du, u, p, t) @@ -291,11 +290,10 @@ end end f1dofs = [1, 2, 3] - f2dofs = [1, 3] + f3dofs = [1, 3] f_NaN = ODEFunction(ode_NaN) - f_NaN_dofs = f3dofs - fsplit_NaN = GenericSplitFunction((f1, f_NaN), (f1dofs, f_NaN_dofs)) + fsplit_NaN = GenericSplitFunction((f1, f_NaN), (f1dofs, f3dofs)) prob_NaN = OperatorSplittingProblem(fsplit_NaN, u0, tspan) for TimeStepperType in (LieTrotterGodunov,)