From b0dc6525d18355230213c2493c530b3e08a3b26c Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 20 Feb 2026 14:35:41 +0100 Subject: [PATCH 01/17] Initial try --- src/integrator.jl | 456 +++++++++++++++++++++++++++------ src/solver.jl | 112 +++++++- src/utils.jl | 134 ++++++---- test/operator_splitting_api.jl | 181 ++++++------- 4 files changed, 663 insertions(+), 220 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 45b2df0..af196b7 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -15,10 +15,95 @@ Base.@kwdef mutable struct IntegratorOptions{tType, fType, F3} isoutofdomain::F3 = DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN end +# --------------------------------------------------------------------------- +# SplitSubIntegratorStatus +# --------------------------------------------------------------------------- +""" + SplitSubIntegratorStatus + +Minimal error-communication object carried by a [`SplitSubIntegrator`](@ref). +It contains only the `retcode` so that failure can be propagated up the +operator-splitting tree without carrying an actual solution vector. +""" +mutable struct SplitSubIntegratorStatus + retcode::ReturnCode.T +end + +SplitSubIntegratorStatus() = SplitSubIntegratorStatus(ReturnCode.Default) + +# --------------------------------------------------------------------------- +# SplitSubIntegrator +# --------------------------------------------------------------------------- +""" + SplitSubIntegrator + +An intermediate node in the operator-splitting subintegrator tree. It is +self-contained: it knows its own solution indices, its child synchronizers, +and the child solution-index tree. It does **not** carry an `f` field +(operator information lives in the cache / algorithm). + +Fields +------ +- `alg` — the `AbstractOperatorSplittingAlgorithm` at this level +- `u` — view into the *master* solution vector for this sub-problem +- `uprev` — copy of `u` at the start of a step (for rollback) +- `u_master` — reference to the full master solution vector of the + outermost `OperatorSplittingIntegrator` (needed during sync) +- `t`, `dt`, `dtcache` — time tracking +- `iter` — step counter +- `EEst` — error estimate (`NaN` for non-adaptive, `1.0` default for adaptive) +- `controller` — step-size controller (or `nothing` for non-adaptive) +- `force_stepfail` — flag set when a step must be re-tried +- `last_step_failed` — flag set after a failed step to detect double-failure +- `status` — [`SplitSubIntegratorStatus`](@ref) for retcode communication +- `cache` — `AbstractOperatorSplittingCache` for the algorithm at this level +- `subintegrator_tree` — tuple of child integrators (`SplitSubIntegrator` or `DEIntegrator`) +- `solution_indices` — global indices (into master `u`) owned by this sub-integrator +- `solution_index_tree`— per-child global solution indices +- `synchronizer_tree` — per-child synchronizer objects +""" +mutable struct SplitSubIntegrator{ + algType, + uType, + tType, + EEstType, + controllerType, + cacheType, + subintTreeType, + solidxType, + solidxTreeType, + syncTreeType, + } + alg::algType + u::uType # view into master u for this sub-problem + uprev::uType # local copy for rollback (same element type, plain Array) + u_master::uType # reference to the outermost master u + t::tType + dt::tType + dtcache::tType + iter::Int + EEst::EEstType + controller::controllerType + force_stepfail::Bool + last_step_failed::Bool + status::SplitSubIntegratorStatus + cache::cacheType + subintegrator_tree::subintTreeType # Tuple + solution_indices::solidxType + solution_index_tree::solidxTreeType # Tuple + synchronizer_tree::syncTreeType # Tuple +end + +# Convenience predicate +@inline SciMLBase.isadaptive(sub::SplitSubIntegrator) = isadaptive(sub.alg) + +# --------------------------------------------------------------------------- +# OperatorSplittingIntegrator +# --------------------------------------------------------------------------- """ OperatorSplittingIntegrator <: AbstractODEIntegrator -A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) to perform opeartor splitting. +A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) to perform operator splitting. Derived from https://github.com/CliMA/ClimaTimeSteppers.jl/blob/ef3023747606d2750e674d321413f80638136632/src/integrators.jl. """ @@ -35,26 +120,24 @@ mutable struct OperatorSplittingIntegrator{ cacheType, solType, subintTreeType, - solidxTreeType, - syncTreeType, controllerType, optionsType, } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} const f::fType const alg::algType - u::uType # Master Solution - uprev::uType # Master Solution - tmp::uType # Interpolation buffer + u::uType # Master Solution + uprev::uType # Master Solution + tmp::uType # Interpolation buffer p::pType - t::tType # Current time + t::tType # Current time tprev::tType - dt::tType # This is the time step length which which we use during time marching - dtcache::tType # This is the proposed time step length + dt::tType # Time step length used during time marching + dtcache::tType # Proposed time step length const dtchangeable::Bool # Indicator whether dtcache can be changed tstops::heapType - _tstops::tstopsType # argument to __init used as default argument to reinit! + _tstops::tstopsType # argument to __init used as default argument to reinit! saveat::heapType - _saveat::saveatType # argument to __init used as default argument to reinit! + _saveat::saveatType # argument to __init used as default argument to reinit! callback::callbackType advance_to_tstop::Bool # TODO group these into some internal flag struct @@ -65,9 +148,11 @@ mutable struct OperatorSplittingIntegrator{ # DiffEqBase.initialize! and DiffEqBase.finalize! cache::cacheType sol::solType + # NOTE: solution_index_tree and synchronizer_tree have been moved into + # the SplitSubIntegrator nodes. The flat subintegrator_tree here is a + # Tuple of SplitSubIntegrator (or DEIntegrator for the degenerate + # single-level case). subintegrator_tree::subintTreeType - solution_index_tree::solidxTreeType - synchronizer_tree::syncTreeType iter::Int controller::controllerType opts::optionsType @@ -75,6 +160,9 @@ mutable struct OperatorSplittingIntegrator{ tdir::tType end +# --------------------------------------------------------------------------- +# __init +# --------------------------------------------------------------------------- # called by DiffEqBase.init and DiffEqBase.solve function SciMLBase.__init( prob::OperatorSplittingProblem, @@ -131,10 +219,13 @@ function SciMLBase.__init( callback = DiffEqBase.CallbackSet(callback) - subintegrator_tree, - cache = build_subintegrator_tree_with_cache( + # Build the subintegrator tree. Each SplitSubIntegrator is now + # self-contained: it holds its own solution_indices, solution_index_tree, + # and synchronizer_tree. + subintegrator_tree, cache = build_subintegrator_tree_with_cache( prob, alg, uprev, u, + u, # u_master == u at the outermost level 1:length(u), t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -166,15 +257,13 @@ function SciMLBase.__init( cache, sol, subintegrator_tree, - build_solution_index_tree(prob.f), - build_synchronizer_tree(prob.f), 0, controller, IntegratorOptions(; verbose, adaptive), IntegratorStats(), tType(tstops_internal.ordering isa DataStructures.FasterForward ? 1 : -1) ) - DiffEqBase.initialize!(callback, u0, t0, integrator) # Do I need this? + DiffEqBase.initialize!(callback, u0, t0, integrator) return integrator end @@ -219,7 +308,6 @@ function DiffEqBase.reinit!( return subreinit!( integrator.f, u0, - 1:length(u0), integrator.subintegrator_tree; t0, tf, dt, erase_sol, @@ -230,10 +318,10 @@ function DiffEqBase.reinit!( ) end +# subreinit! for a leaf DEIntegrator function subreinit!( f, u0, - solution_indices, subintegrator::DEIntegrator; dt, kwargs... @@ -242,23 +330,97 @@ function subreinit!( if dt !== nothing subintegrator.dt = dt end - return DiffEqBase.reinit!(subintegrator, u0[solution_indices]; kwargs...) + # solution_indices are carried by the parent SplitSubIntegrator + error("subreinit! called directly on a DEIntegrator — should be reached only via SplitSubIntegrator") +end + +# subreinit! for an intermediate SplitSubIntegrator +function subreinit!( + f, + u0, + sub::SplitSubIntegrator; + t0, + tf, + dt, + kwargs... + ) + idxs = sub.solution_indices + sub.u .= @view u0[idxs] + sub.uprev .= @view u0[idxs] + sub.t = t0 + if dt !== nothing + sub.dt = dt + sub.dtcache = dt + end + sub.iter = 0 + sub.force_stepfail = false + sub.last_step_failed = false + sub.status = SplitSubIntegratorStatus(ReturnCode.Default) + # Reset EEst to the appropriate default + if isadaptive(sub) + sub.EEst = one(sub.EEst) + else + sub.EEst = sub.EEst # keep NaN sentinel + end + return subreinit_children!(f, u0, sub; t0, tf, dt, kwargs...) +end + +@unroll function subreinit_children!( + f, + u0, + sub::SplitSubIntegrator; + kwargs... + ) + i = 1 + @unroll for child in sub.subintegrator_tree + _subreinit_child!(get_operator(f, i), u0, child, sub.solution_index_tree[i]; kwargs...) + i += 1 + end +end + +# Dispatch for leaf DEIntegrator children +function _subreinit_child!( + f_child, + u0, + child::DEIntegrator, + child_solution_indices; + dt, + kwargs... + ) + if dt !== nothing + child.dt = dt + end + return DiffEqBase.reinit!(child, @view(u0[child_solution_indices]); kwargs...) +end + +# Dispatch for nested SplitSubIntegrator children +function _subreinit_child!( + f_child, + u0, + child::SplitSubIntegrator, + _child_solution_indices; # ignored — child carries its own + kwargs... + ) + return subreinit!(f_child, u0, child; kwargs...) end +# Top-level subreinit! over a tuple of subintegrators (called from reinit!) @unroll function subreinit!( f, u0, - solution_indices, subintegrators::Tuple; kwargs... ) i = 1 - @unroll for subintegrator in subintegrators - subreinit!(get_operator(f, i), u0, f.solution_indices[i], subintegrator; kwargs...) + @unroll for sub in subintegrators + subreinit!(get_operator(f, i), u0, sub; kwargs...) i += 1 end end +# --------------------------------------------------------------------------- +# handle_tstop! +# --------------------------------------------------------------------------- function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrator) if SciMLBase.has_tstop(integrator) tdir_t = tdir(integrator) * integrator.t @@ -291,7 +453,9 @@ notify_integrator_hit_tstop!(integrator::OperatorSplittingIntegrator) = nothing is_first_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter == 0 increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter += 1 -# Controller interface +# --------------------------------------------------------------------------- +# Controller interface — outermost integrator +# --------------------------------------------------------------------------- function reject_step!(integrator::OperatorSplittingIntegrator) OrdinaryDiffEqCore.increment_reject!(integrator.stats) return reject_step!(integrator, integrator.cache, integrator.controller) @@ -334,6 +498,41 @@ function update_uprev!(integrator::OperatorSplittingIntegrator) return nothing end +# --------------------------------------------------------------------------- +# Controller interface — SplitSubIntegrator +# --------------------------------------------------------------------------- +function reject_step!(sub::SplitSubIntegrator) + sub.u .= sub.uprev + # Propagate rollback to all leaf DEIntegrators within this subtree so + # their state is consistent before the next attempt. + _rollback_subintegrator_tree!(sub.subintegrator_tree, sub.u_master) +end + +function _rollback_subintegrator_tree!(subintegrators::Tuple, u_master) + @unroll for child in subintegrators + _rollback_child!(child, u_master) + end +end + +function _rollback_child!(child::SplitSubIntegrator, u_master) + child.u .= child.uprev + _rollback_subintegrator_tree!(child.subintegrator_tree, u_master) +end + +function _rollback_child!(child::DEIntegrator, u_master) + # The leaf integrator's uprev already holds the correct state because + # forward_sync_internal! copies u_master into it before each sub-step. + # Nothing to do here beyond letting the view aliasing keep things consistent. + return nothing +end + +function accept_step!(sub::SplitSubIntegrator) + RecursiveArrayTools.recursivecopy!(sub.uprev, sub.u) +end + +# --------------------------------------------------------------------------- +# step_header! / step_footer! — outermost integrator +# --------------------------------------------------------------------------- function step_header!(integrator::OperatorSplittingIntegrator) # Accept or reject the step if !is_first_iteration(integrator) @@ -377,8 +576,7 @@ function step_footer!(integrator::OperatorSplittingIntegrator) if should_accept_step(integrator) integrator.last_step_failed = false integrator.tprev = integrator.t - integrator.t = ttmp #OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) - # OrdinaryDiffEqCore.handle_callbacks!(integrator) + integrator.t = ttmp step_accept_controller!(integrator) # Noop for non-adaptive algorithms elseif integrator.force_stepfail if isadaptive(integrator) @@ -392,11 +590,12 @@ function step_footer!(integrator::OperatorSplittingIntegrator) integrator.last_step_failed = true end - # integration_monitor_step(integrator) - return nothing end +# --------------------------------------------------------------------------- +# __solve / solve! / step! +# --------------------------------------------------------------------------- # called by DiffEqBase.solve function SciMLBase.__solve( prob::OperatorSplittingProblem, @@ -472,7 +671,7 @@ function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) verbose = true # integrator.opts.verbose - if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) # replace with https://github.com/SciML/OrdinaryDiffEq.jl/blob/373a8eec8024ef1acc6c5f0c87f479aa0cf128c3/lib/OrdinaryDiffEqCore/src/iterator_interface.jl#L5-L6 after moving to sciml integrators + if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) if verbose @warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.") end @@ -482,9 +681,10 @@ function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) return check_error_subintegrators(integrator, integrator.subintegrator_tree) end +# Recurse over a tuple of children function check_error_subintegrators(integrator, subintegrator_tree::Tuple) - for subintegrator in subintegrator_tree - retcode = check_error_subintegrators(integrator, subintegrator) + for sub in subintegrator_tree + retcode = check_error_subintegrators(integrator, sub) if !SciMLBase.successful_retcode(retcode) && retcode != ReturnCode.Default return retcode end @@ -492,13 +692,23 @@ function check_error_subintegrators(integrator, subintegrator_tree::Tuple) return integrator.sol.retcode end -function check_error_subintegrators(integrator, subintegrator::DEIntegrator) - return SciMLBase.check_error(subintegrator) +# Leaf: read retcode from the DEIntegrator's solution +function check_error_subintegrators(integrator, sub::DEIntegrator) + return SciMLBase.check_error(sub) +end + +# Intermediate node: read retcode from the SplitSubIntegrator status object +function check_error_subintegrators(integrator, sub::SplitSubIntegrator) + rc = sub.status.retcode + if !SciMLBase.successful_retcode(rc) && rc != ReturnCode.Default + return rc + end + # Also recurse into children + return check_error_subintegrators(integrator, sub.subintegrator_tree) end function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false) return @timeit_debug "step!" begin - # OridinaryDiffEq lets dt be negative if tdir is -1, but that's inconsistent dt <= zero(dt) && error("dt must be positive") stop_at_tdt && !integrator.dtchangeable && error("Cannot stop at t + dt if dtchangeable is false") @@ -525,14 +735,10 @@ end # TimeChoiceIterator API @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator) - # DiffEqBase.get_tmp_cache(integrator, integrator.alg, integrator.cache) return (integrator.tmp,) end -# @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator, ::AbstractOperatorSplittingAlgorithm, cache) -# return (cache.tmp,) -# end + # Interpolation -# TODO via https://github.com/SciML/SciMLBase.jl/blob/master/src/interpolation.jl function linear_interpolation!(y, t, y1, y2, t1, t2) return y .= y1 + (t - t1) * (y2 - y1) / (t2 - t1) end @@ -542,6 +748,9 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t) ) end +# --------------------------------------------------------------------------- +# Stepsize controller hooks — outermost integrator +# --------------------------------------------------------------------------- """ stepsize_controller!(::OperatorSplittingIntegrator) @@ -575,21 +784,24 @@ Updates `dtcache` of the integrator if the step is rejected and the the operator return step_reject_controller!(integrator, algorithm, nothing) end -# helper functions for dealing with time-reversed integrators in the same way -# that OrdinaryDiffEq.jl does +# --------------------------------------------------------------------------- +# Time helpers +# --------------------------------------------------------------------------- tdir(integrator) = integrator.tstops.ordering isa DataStructures.FasterForward ? 1 : -1 is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) ≤ zero(integrator.t) function reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) if stop_at_tstop integrator.t > tstop && error("Integrator missed stop at $tstop (current time=$(integrator.t)). Aborting.") - return integrator.t == tstop # Check for exact hit - else #!stop_at_tstop + return integrator.t == tstop + else return is_past_t(integrator, tstop) end end -# Dunno stuff +# --------------------------------------------------------------------------- +# SciMLBase integrator interface +# --------------------------------------------------------------------------- function SciMLBase.done(integrator::OperatorSplittingIntegrator) if !( integrator.sol.retcode in ( @@ -620,6 +832,16 @@ function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext) return advance_solution_to!(integrator, integrator.cache, tnext) end +function advance_solution_to!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, sync, cache, tend + ) + # Advance a SplitSubIntegrator node using its own advance_solution_to! dispatch + dt = tend - sub.t + sub.dt = dt + return advance_solution_to!(outer_integrator, sub, tend) +end + function advance_solution_to!( outer_integrator::OperatorSplittingIntegrator, integrator::DEIntegrator, solution_indices, sync, cache, tend @@ -628,7 +850,9 @@ function advance_solution_to!( return SciMLBase.step!(integrator, dt, true) end -# ----------------------------------- SciMLBase.jl Integrator Interface ------------------------------------ +# --------------------------------------------------------------------------- +# SciMLBase.jl integrator interface +# --------------------------------------------------------------------------- SciMLBase.has_stats(::OperatorSplittingIntegrator) = true SciMLBase.has_tstop(integrator::OperatorSplittingIntegrator) = !isempty(integrator.tstops) @@ -637,7 +861,6 @@ SciMLBase.pop_tstop!(integrator::OperatorSplittingIntegrator) = pop!(integrator. DiffEqBase.get_dt(integrator::OperatorSplittingIntegrator) = integrator.dt function set_dt!(integrator::OperatorSplittingIntegrator, dt) - # TODO: figure out interface for recomputing other objects (linear operators, etc) dt <= zero(dt) && error("dt must be positive") return integrator.dt = dt end @@ -656,10 +879,11 @@ function DiffEqBase.add_saveat!(integrator::OperatorSplittingIntegrator, t) return push!(integrator.saveat, t) end -# not sure what this should do? -# defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing +# --------------------------------------------------------------------------- +# Synchronization +# --------------------------------------------------------------------------- function synchronize_subintegrator_tree!(integrator::OperatorSplittingIntegrator) return synchronize_subintegrator!(integrator.subintegrator_tree, integrator) end @@ -667,41 +891,87 @@ end @unroll function synchronize_subintegrator!( subintegrator_tree::Tuple, integrator::OperatorSplittingIntegrator ) - @unroll for subintegrator in subintegrator_tree - synchronize_subintegrator!(subintegrator, integrator) + @unroll for sub in subintegrator_tree + synchronize_subintegrator!(sub, integrator) end end +# Sync a SplitSubIntegrator node: update its t/dt then recurse into children function synchronize_subintegrator!( - subintegrator::DEIntegrator, integrator::OperatorSplittingIntegrator + sub::SplitSubIntegrator, integrator::OperatorSplittingIntegrator ) (; t, dt) = integrator - @assert subintegrator.t == t - return if !isadaptive(subintegrator) - SciMLBase.set_proposed_dt!(subintegrator, dt) + @assert sub.t == t "SplitSubIntegrator time $(sub.t) out of sync with outer integrator time $t" + if !isadaptive(sub) + sub.dt = dt + sub.dtcache = dt + end + # Recurse: sync children against the *sub-integrator* (not outer) time + @unroll for child in sub.subintegrator_tree + synchronize_subintegrator_child!(child, sub) end end +function synchronize_subintegrator_child!( + child::DEIntegrator, parent::SplitSubIntegrator + ) + @assert child.t == parent.t "Child integrator time $(child.t) out of sync with parent time $(parent.t)" + if !isadaptive(child) + SciMLBase.set_proposed_dt!(child, parent.dt) + end +end + +function synchronize_subintegrator_child!( + child::SplitSubIntegrator, parent::SplitSubIntegrator + ) + @assert child.t == parent.t "Nested SplitSubIntegrator time $(child.t) out of sync with parent time $(parent.t)" + if !isadaptive(child) + child.dt = parent.dt + child.dtcache = parent.dt + end +end + +# --------------------------------------------------------------------------- +# advance_solution_to! for AbstractOperatorSplittingCache +# (dispatches into the algorithm-specific method in solver.jl) +# --------------------------------------------------------------------------- function advance_solution_to!( integrator::OperatorSplittingIntegrator, cache::AbstractOperatorSplittingCache, tnext::Number ) return advance_solution_to!( - integrator, integrator.subintegrator_tree, integrator.solution_index_tree, - integrator.synchronizer_tree, cache, tnext + integrator, integrator.subintegrator_tree, cache, tnext ) end -# Dispatch for tree node construction +# advance_solution_to! for a SplitSubIntegrator node +# (the algorithm-specific method in solver.jl calls this signature) +function advance_solution_to!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, tend + ) + return advance_solution_to!( + outer_integrator, sub, sub.subintegrator_tree, + sub.solution_index_tree, sub.synchronizer_tree, + sub.cache, tend + ) +end + +# --------------------------------------------------------------------------- +# Tree construction +# --------------------------------------------------------------------------- +# Top-level dispatch: builds a Tuple of SplitSubIntegrators function build_subintegrator_tree_with_cache( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, uprevouter::AbstractVector, uouter::AbstractVector, + u_master::AbstractVector, solution_indices, t0, dt, tf, tstops, saveat, d_discontinuities, callback, adaptive, verbose ) (; f, p) = prob + subintegrator_tree_with_caches = ntuple( i -> build_subintegrator_tree_with_cache( prob, @@ -709,6 +979,7 @@ function build_subintegrator_tree_with_cache( get_operator(f, i), p[i], uprevouter, uouter, + u_master, f.solution_indices[i], t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -722,7 +993,6 @@ function build_subintegrator_tree_with_cache( ) caches = ntuple(i -> subintegrator_tree_with_caches[i][2], length(f.functions)) - # TODO fix mixed device type problems we have to be smarter return subintegrator_tree, init_cache( f, alg; @@ -731,10 +1001,13 @@ function build_subintegrator_tree_with_cache( ) end +# Intermediate node: inner algorithm is also an AbstractOperatorSplittingAlgorithm +# wrapping a GenericSplitFunction → produce a SplitSubIntegrator function build_subintegrator_tree_with_cache( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, f::GenericSplitFunction, p::Tuple, uprevouter::AbstractVector, uouter::AbstractVector, + u_master::AbstractVector, solution_indices, t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -742,13 +1015,17 @@ function build_subintegrator_tree_with_cache( save_end = false, controller = nothing ) - subintegrator_tree_with_caches = ntuple( + tType = typeof(dt) + + # Build children recursively + child_results = ntuple( i -> build_subintegrator_tree_with_cache( prob, alg.inner_algs[i], get_operator(f, i), p[i], uprevouter, uouter, + u_master, f.solution_indices[i], t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -757,25 +1034,57 @@ function build_subintegrator_tree_with_cache( length(f.functions) ) - subintegrator_tree = first.(subintegrator_tree_with_caches) - inner_caches = last.(subintegrator_tree_with_caches) + child_subintegrators = ntuple(i -> child_results[i][1], length(f.functions)) + child_caches = ntuple(i -> child_results[i][2], length(f.functions)) - # TODO fix mixed device type problems we have to be smarter - uprev = @view uprevouter[solution_indices] - u = @view uouter[solution_indices] - return subintegrator_tree, - init_cache( - f, alg; - uprev = uprev, u = u, - inner_caches = inner_caches - ) + # Build per-child solution_index_tree and synchronizer_tree + child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) + child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) + + # Cache for *this* level + u_sub = @view uouter[solution_indices] + uprev_sub = @view uprevouter[solution_indices] + level_cache = init_cache( + f, alg; + uprev = uprev_sub, u = u_sub, + inner_caches = child_caches + ) + + # EEst default + EEst_val = isadaptive(alg) ? one(tType) : tType(NaN) + + sub = SplitSubIntegrator( + alg, + u_sub, + RecursiveArrayTools.recursivecopy(u_sub), # uprev: local copy for rollback + u_master, + t0, + dt, + dt, # dtcache + 0, # iter + EEst_val, + controller, + false, # force_stepfail + false, # last_step_failed + SplitSubIntegratorStatus(), + level_cache, + child_subintegrators, + solution_indices, + child_solution_indices, + child_synchronizers + ) + + return sub, level_cache end +# Leaf node: inner algorithm is a plain SciMLBase.AbstractODEAlgorithm +# → produce an ODEIntegrator (existing behaviour) function build_subintegrator_tree_with_cache( prob::OperatorSplittingProblem, alg::SciMLBase.AbstractODEAlgorithm, f::F, p::P, uprevouter::S, uouter::S, + u_master::S, solution_indices, t0::T, dt::T, tf::T, tstops, saveat, d_discontinuities, callback, @@ -786,10 +1095,6 @@ function build_subintegrator_tree_with_cache( uprev = @view uprevouter[solution_indices] u = @view uouter[solution_indices] - # When working with MTK, we want to pass f as a system down here. - # In that case ODEProblem constructs the correct parameter struct. - # If the system does not have parameters in first place, then - # The NullParameters object will be constructed automatically. prob2 = if p isa NullParameters SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf))) else @@ -812,6 +1117,7 @@ function build_subintegrator_tree_with_cache( return integrator, integrator.cache end +# forward/backward sync no-ops for tuple nodes (handled inside SplitSubIntegrator) function forward_sync_subintegrator!( outer_integrator::OperatorSplittingIntegrator, subintegrator_tree::Tuple, solution_indices::Tuple, synchronizers::Tuple diff --git a/src/solver.jl b/src/solver.jl index b2b0070..2d7cfcd 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -5,8 +5,7 @@ A first order sequential operator splitting algorithm attributed to [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite). """ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm - inner_algs::AlgTupleType # Tuple of timesteppers for inner problems - # transfer_algs::TransferTupleType # Tuple of transfer algorithms from the master solution into the individual ones + inner_algs::AlgTupleType end struct LieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplittingCache @@ -27,30 +26,117 @@ function init_cache( return LieTrotterGodunovCache(_u, _uprev, inner_caches) end +# --------------------------------------------------------------------------- +# advance_solution_to! for the outermost integrator with a Tuple of children +# This is the top-level dispatch when the outer integrator's cache is a +# LieTrotterGodunovCache and children are SplitSubIntegrators or DEIntegrators. +# --------------------------------------------------------------------------- @inline @unroll function advance_solution_to!( outer_integrator::OperatorSplittingIntegrator, - subintegrators::Tuple, solution_indices::Tuple, - synchronizers::Tuple, cache::LieTrotterGodunovCache, tnext + subintegrators::Tuple, cache::LieTrotterGodunovCache, tnext + ) + (; inner_caches) = cache + i = 0 + @unroll for subinteg in subintegrators + i += 1 + inner_cache = inner_caches[i] + _advance_child!(outer_integrator, subinteg, inner_cache, tnext) + # Check for failure after each child + if _child_failed(outer_integrator, subinteg) + outer_integrator.force_stepfail = true + return + end + end +end + +# --------------------------------------------------------------------------- +# advance_solution_to! for a SplitSubIntegrator node with LieTrotterGodunov +# This is the recursive dispatch when a SplitSubIntegrator's own cache is LTG. +# --------------------------------------------------------------------------- +@inline @unroll function advance_solution_to!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, + subintegrators::Tuple, + solution_indices::Tuple, + synchronizers::Tuple, + cache::LieTrotterGodunovCache, + tnext ) - # We assume that the integrators are already synced (; inner_caches) = cache - # For each inner operator i = 0 @unroll for subinteg in subintegrators i += 1 synchronizer = synchronizers[i] idxs = solution_indices[i] - cache = inner_caches[i] + inner_cache = inner_caches[i] - @timeit_debug "sync ->" forward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) - @timeit_debug "time solve" advance_solution_to!( - outer_integrator, subinteg, idxs, synchronizer, cache, tnext + @timeit_debug "sync ->" forward_sync_subintegrator!( + outer_integrator, subinteg, idxs, synchronizer ) - if !(subinteg isa Tuple) && - subinteg.sol.retcode ∉ - (ReturnCode.Default, ReturnCode.Success) + @timeit_debug "time solve" _advance_child!( + outer_integrator, subinteg, inner_cache, tnext + ) + if _child_failed(outer_integrator, subinteg) + sub.status = SplitSubIntegratorStatus(ReturnCode.Failure) return end backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) end + # All children succeeded: mark this node as successful + sub.status = SplitSubIntegratorStatus(ReturnCode.Success) + # Accept the sub-step: copy u into uprev for potential future rollback + accept_step!(sub) + sub.t = tnext + sub.iter += 1 +end + +# --------------------------------------------------------------------------- +# _advance_child!: dispatch on child type +# --------------------------------------------------------------------------- + +# Child is a SplitSubIntegrator: call its own advance_solution_to! +function _advance_child!( + outer_integrator::OperatorSplittingIntegrator, + child::SplitSubIntegrator, + _inner_cache, # ignored — child uses its own cache + tnext + ) + # Forward sync from the outer master u into this child's u + forward_sync_subintegrator!( + outer_integrator, child, child.solution_indices, NoExternalSynchronization() + ) + advance_solution_to!(outer_integrator, child, tnext) + backward_sync_subintegrator!( + outer_integrator, child, child.solution_indices, NoExternalSynchronization() + ) +end + +# Child is a leaf DEIntegrator +function _advance_child!( + outer_integrator::OperatorSplittingIntegrator, + child::DEIntegrator, + _inner_cache, + tnext + ) + dt = tnext - child.t + SciMLBase.step!(child, dt, true) + # If the leaf adaptive integrator failed unrecoverably, error immediately + if !SciMLBase.successful_retcode(child.sol.retcode) && + child.sol.retcode != ReturnCode.Default + if isadaptive(child) + error("Adaptive inner integrator failed unrecoverably with retcode $(child.sol.retcode). Aborting.") + end + # non-adaptive failure: signal to parent + end +end + +# --------------------------------------------------------------------------- +# _child_failed: check whether a child reported a failure +# --------------------------------------------------------------------------- +function _child_failed(outer_integrator, child::DEIntegrator) + return child.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) +end + +function _child_failed(outer_integrator, child::SplitSubIntegrator) + return child.status.retcode ∉ (ReturnCode.Default, ReturnCode.Success) end diff --git a/src/utils.jl b/src/utils.jl index dcc9ee0..b1c83cd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,8 +14,6 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat) saveat = tf > t0 ? saveat : -saveat saveat = [t0:saveat:tf..., tf] else - # We do not need to filter saveat like tstops because the saving - # callback will ignore any times that are not between t0 and tf. saveat = collect(saveat) end saveat = DataStructures.BinaryHeap{FT, ordering}(saveat) @@ -26,8 +24,8 @@ end """ need_sync(a, b) -This function determines whether it is necessary to synchronize two objects with any solution information. -A possible reason when no synchronization is necessary might be that the vectors alias each other in memory. +Determines whether it is necessary to synchronize two objects with any +solution information. """ need_sync @@ -39,7 +37,7 @@ need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent """ sync_vectors!(a, b) -Copies the information in object b into object a, if synchronization is necessary. +Copies the information in object `b` into object `a`, if synchronization is necessary. """ function sync_vectors!(a, b) return if need_sync(a, b) && a !== b @@ -47,13 +45,14 @@ function sync_vectors!(a, b) end end +# --------------------------------------------------------------------------- +# forward_sync_subintegrator! +# --------------------------------------------------------------------------- """ - forward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) + forward_sync_subintegrator!(outer_integrator, inner, solution_indices, sync) -This function is responsible of copying the solution and parameters of the outer integrator and the synchronized subintegrators with the information given into the inner integrator. -If the inner integrator is synchronized with other inner integrators using `sync`, the function `forward_sync_external!` shall be dispatched for `sync`. -The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. -The `solution_indices` are global indices in the outer integrators solution vectors. +Copy state from the outer integrator into the inner integrator before a +sub-step, and apply any external parameter synchronisation via `sync`. """ function forward_sync_subintegrator!( outer_integrator::OperatorSplittingIntegrator, @@ -64,12 +63,31 @@ function forward_sync_subintegrator!( end """ - backward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) + forward_sync_subintegrator! for SplitSubIntegrator -This function is responsible of copying the solution of the inner integrator back into outer integrator and the synchronized subintegrators. -If the inner integrator is synchronized with other inner integrators using `sync`, the function `backward_sync_external!` shall be dispatched for `sync`. -The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. -The `solution_indices` are global indices in the outer integrators solution vectors. +When the inner node is a `SplitSubIntegrator` we only need to copy the master +solution vector slice into its `u` (the `SplitSubIntegrator.u` is already a +view, but on a different device or after a rollback it may need refreshing). +""" +function forward_sync_subintegrator!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, solution_indices, sync + ) + # Sync the view: master → sub.u (noop if they already alias) + @views uouter = outer_integrator.u[solution_indices] + sync_vectors!(sub.u, uouter) + sync_vectors!(sub.uprev, uouter) + return forward_sync_external!(outer_integrator, sub, sync) +end + +# --------------------------------------------------------------------------- +# backward_sync_subintegrator! +# --------------------------------------------------------------------------- +""" + backward_sync_subintegrator!(outer_integrator, inner, solution_indices, sync) + +Copy state from the inner integrator back into the outer integrator after a +sub-step, and apply any external parameter synchronisation via `sync`. """ function backward_sync_subintegrator!( outer_integrator::OperatorSplittingIntegrator, @@ -79,9 +97,18 @@ function backward_sync_subintegrator!( return backward_sync_external!(outer_integrator, inner_integrator, sync) end -# This is a bit tricky, because per default the operator splitting integrators share their solution vector. However, there is also the case -# when part of the problem is on a different device (thing e.g. about operator A being on CPU and B being on GPU). -# This case should be handled with special synchronizers. +function backward_sync_subintegrator!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, solution_indices, sync + ) + @views uouter = outer_integrator.u[solution_indices] + sync_vectors!(uouter, sub.u) + return backward_sync_external!(outer_integrator, sub, sync) +end + +# --------------------------------------------------------------------------- +# forward_sync_internal! / backward_sync_internal! +# --------------------------------------------------------------------------- function forward_sync_internal!( outer_integrator::OperatorSplittingIntegrator, inner_integrator::OperatorSplittingIntegrator, solution_indices @@ -112,7 +139,9 @@ function backward_sync_internal!( return sync_vectors!(uouter, inner_integrator.u) end -# This is a noop, because operator splitting integrators do not have parameters for now +# --------------------------------------------------------------------------- +# forward_sync_external! / backward_sync_external! +# --------------------------------------------------------------------------- function forward_sync_external!( outer_integrator::OperatorSplittingIntegrator, inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization @@ -125,12 +154,26 @@ function forward_sync_external!( ) return nothing end +# SplitSubIntegrator has no parameters for now → no-op +function forward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, sync::NoExternalSynchronization + ) + return nothing +end function forward_sync_external!( outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, sync ) return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) end +function forward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, sync + ) + # SplitSubIntegrator does not carry p for now; dispatch on sync type if needed + return nothing +end function backward_sync_external!( outer_integrator::OperatorSplittingIntegrator, @@ -144,51 +187,50 @@ function backward_sync_external!( ) return nothing end +function backward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, sync::NoExternalSynchronization + ) + return nothing +end function backward_sync_external!( outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, sync ) return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) end +function backward_sync_external!( + outer_integrator::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, sync + ) + return nothing +end -function synchronize_solution_with_parameters!(outer_integrator::OperatorSplittingIntegrator, p, sync) +function synchronize_solution_with_parameters!( + outer_integrator::OperatorSplittingIntegrator, p, sync + ) @warn "Outer synchronizer not dispatched for parameter type $(typeof(p)) with synchronizer type $(typeof(sync))." maxlog = 1 return nothing end -# If we encounter NullParameters, then we have the convention for the standard sync map that no external solution is necessary. function synchronize_solution_with_parameters!( outer_integrator::OperatorSplittingIntegrator, p::NullParameters, sync ) return nothing end -# TODO this should go into a custom tree data structure instead of into a tuple-tree +# --------------------------------------------------------------------------- +# NOTE: build_solution_index_tree and build_synchronizer_tree are NO LONGER +# needed as standalone functions — the information is now embedded directly +# into each SplitSubIntegrator during build_subintegrator_tree_with_cache. +# They are kept here (no-ops returning nothing) only so that any external +# code that might call them does not hard-error. +# --------------------------------------------------------------------------- function build_solution_index_tree(f::GenericSplitFunction) - return ntuple( - i -> build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), - length(f.functions) - ) -end - -function build_solution_index_tree_recursion(f::GenericSplitFunction, solution_indices) - return ntuple( - i -> build_solution_index_tree_recursion(f.functions[i], f.solution_indices[i]), - length(f.functions) - ) -end - -function build_solution_index_tree_recursion(f, solution_indices) - return solution_indices + # Deprecated: solution index trees now live inside SplitSubIntegrator. + return nothing end function build_synchronizer_tree(f::GenericSplitFunction) - return ntuple(i -> build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) -end - -function build_synchronizer_tree_recursion(f::GenericSplitFunction, synchronizers) - return ntuple(i -> build_synchronizer_tree_recursion(f.functions[i], f.synchronizers[i]), length(f.functions)) -end - -function build_synchronizer_tree_recursion(f, synchronizer) - return synchronizer + # Deprecated: synchronizer trees now live inside SplitSubIntegrator. + return nothing end diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 51537c5..57a514c 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -8,7 +8,9 @@ using OrdinaryDiffEqLowOrderRK using OrdinaryDiffEqTsit5 using ModelingToolkit -# Reference +# --------------------------------------------------------------------------- +# Reference problem +# --------------------------------------------------------------------------- tspan = (0.0, 100.0) u0 = [ 0.7611944793397108 @@ -33,26 +35,22 @@ end trueu = exp((tspan[2] - tspan[1]) * (trueA + trueB)) * u0 # Setup individual functions -# Diagonal components function ode1(du, u, p, t) return @. du = -0.1u end f1 = ODEFunction(ode1) -# Off-diagonal components function ode2(du, u, p, t) du[1] = -0.01u[2] return du[2] = -0.01u[1] end f2 = ODEFunction(ode2) -# Now some recursive splitting function ode3(du, u, p, t) du[1] = -0.005u[2] return du[2] = -0.005u[1] end f3 = ODEFunction(ode3) -# The time stepper carries the individual solver information. @independent_variables time Dt = Differential(time) @@ -69,99 +67,122 @@ end @named testmodel2 = TestModelODE2() testsys2 = mtkcompile(testmodel2; sort_eqs = false) -# Test whether adaptive code path works in principle +# --------------------------------------------------------------------------- +# FakeAdaptiveAlgorithm — tests adaptive code path +# +# With the new interface FakeAdaptiveAlgorithm no longer needs to override +# build_subintegrator_tree_with_cache. It just wraps the standard cache in +# its own FakeAdaptiveAlgorithmCache. +# --------------------------------------------------------------------------- struct FakeAdaptiveAlgorithm{T} <: OS.AbstractOperatorSplittingAlgorithm alg::T + inner_algs::T # delegate inner_algs to the wrapped algorithm end +FakeAdaptiveAlgorithm(alg::T) where {T} = FakeAdaptiveAlgorithm{T}(alg, alg.inner_algs) + struct FakeAdaptiveAlgorithmCache{T} <: OS.AbstractOperatorSplittingCache cache::T end + @inline DiffEqBase.isadaptive(::FakeAdaptiveAlgorithm) = true -@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm) +@inline function OS.stepsize_controller!( + integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm + ) return nothing end -@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q) +@inline function OS.step_accept_controller!( + integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q + ) integrator.dt = integrator.dtcache return nothing end -@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q) +@inline function OS.step_reject_controller!( + integrator::OS.OperatorSplittingIntegrator, alg::FakeAdaptiveAlgorithm, q + ) error("The tests should never run into this scenario!") - return nothing # Do nothing + return nothing end -function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - ) - subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( - prob, alg.alg, uprevouter, uouter, solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - ) - return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) +# Override init_cache to wrap the inner cache in FakeAdaptiveAlgorithmCache +function OS.init_cache( + f::GenericSplitFunction, alg::FakeAdaptiveAlgorithm; + kwargs... + ) + inner_cache = OS.init_cache(f, alg.alg; kwargs...) + return FakeAdaptiveAlgorithmCache(inner_cache) end -function OS.build_subintegrator_tree_with_cache( - prob::OS.OperatorSplittingProblem, alg::FakeAdaptiveAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, - solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, - save_end = false, - controller = nothing + +@inline DiffEqBase.get_tmp_cache( + integrator::OS.OperatorSplittingIntegrator, + alg::OS.AbstractOperatorSplittingAlgorithm, + cache::FakeAdaptiveAlgorithmCache + ) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) + +@inline function OS.advance_solution_to!( + outer_integrator::OS.OperatorSplittingIntegrator, + subintegrators::Tuple, + cache::FakeAdaptiveAlgorithmCache, + tnext ) - subintegrators, inner_cache = OS.build_subintegrator_tree_with_cache( - prob, alg.alg, f, p, uprevouter, uouter, solution_indices, - t0, dt, tf, - tstops, saveat, d_discontinuities, callback, - adaptive, verbose, + return OS.advance_solution_to!( + outer_integrator, subintegrators, cache.cache, tnext ) +end - return subintegrators, FakeAdaptiveAlgorithmCache( - inner_cache, - ) +# For SplitSubIntegrator nodes whose cache was wrapped by FakeAdaptiveAlgorithmCache +@inline function OS.advance_solution_to!( + outer_integrator::OS.OperatorSplittingIntegrator, + sub::OS.SplitSubIntegrator, + subintegrators::Tuple, + solution_indices::Tuple, + synchronizers::Tuple, + cache::FakeAdaptiveAlgorithmCache, + tnext + ) + return OS.advance_solution_to!( + outer_integrator, sub, subintegrators, solution_indices, + synchronizers, cache.cache, tnext + ) end + FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) -@inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::FakeAdaptiveAlgorithmCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) -@inline function OS.advance_solution_to!(outer_integrator::OS.OperatorSplittingIntegrator, subintegrators::Tuple, solution_indices::Tuple, synchronizers::Tuple, cache::FakeAdaptiveAlgorithmCache, tnext) - return OS.advance_solution_to!(outer_integrator, subintegrators, solution_indices, synchronizers, cache.cache, tnext) +# --------------------------------------------------------------------------- +# Helper: given the outer integrator, walk into the first SplitSubIntegrator +# and find the first leaf DEIntegrator. +# --------------------------------------------------------------------------- +function first_leaf_integrator(integrator) + node = integrator.subintegrator_tree[1] + while node isa OS.SplitSubIntegrator + node = node.subintegrator_tree[1] + end + return node end +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- @testset "reinit and convergence" begin dt = 0.01π - # Here we describe index sets f1dofs and f2dofs that map the - # local indices in f1 and f2 into the global problem. Just put - # ode_true and ode1/ode2 side by side to see how they connect. f1dofs = [1, 2, 3] f2dofs = [1, 3] fsplit1a = GenericSplitFunction((f1, f2), (f1dofs, f2dofs)) fsplit1b = GenericSplitFunction((f1, testsys2), (f1dofs, f2dofs)) - # Now the usual setup just with our new problem type. prob1a = OperatorSplittingProblem(fsplit1a, u0, tspan) prob1b = OperatorSplittingProblem(fsplit1b, u0, tspan) - # Note that we define the dof indices w.r.t the parent function. - # Hence the indices for `fsplit2_inner` are. - f1dofs = [1, 2, 3] - f2dofs = [1, 3] f3dofs = [1, 3] fsplit2_inner = GenericSplitFunction((f3, f3), (f3dofs, f3dofs)) fsplit2_outer = GenericSplitFunction((f1, fsplit2_inner), (f1dofs, f2dofs)) prob2 = OperatorSplittingProblem(fsplit2_outer, u0, tspan) + + nsteps = ceil(Int, (tspan[2] - tspan[1]) / dt) + for TimeStepperType in (LieTrotterGodunov, FakeAdaptiveLTG) @testset "Solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( (prob1a, TimeStepperType((Euler(), Euler()))), @@ -179,20 +200,24 @@ end (prob2, TimeStepperType((Tsit5(), TimeStepperType((Euler(), Tsit5()))))), (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), ) - # The remaining code works as usual. integrator = DiffEqBase.init( prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive = false ) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default + DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success ufinal = copy(integrator.u) @test isapprox(ufinal, trueu, atol = 1.0e-6) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps + + # SplitSubIntegrators now carry t and iter at each level + sub1 = integrator.subintegrator_tree[1] + @test sub1 isa OS.SplitSubIntegrator + @test sub1.t ≈ tspan[2] + @test sub1.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -200,10 +225,8 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -211,20 +234,16 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps end end @@ -233,7 +252,6 @@ end (prob1a, TimeStepperType((Tsit5(), Tsit5()))), (prob2, TimeStepperType((Tsit5(), TimeStepperType((Tsit5(), Tsit5()))))), ) - # The remaining code works as usual. integrator = DiffEqBase.init( prob, tstepper, dt = dt, verbose = true, alias_u0 = false, adaptive = true ) @@ -243,10 +261,8 @@ end ufinal = copy(integrator.u) @test isapprox(ufinal, trueu, atol = 1.0e-6) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -254,10 +270,8 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default @@ -265,24 +279,20 @@ end end @test isapprox(ufinal, integrator.u, atol = 1.0e-12) @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success @test integrator.t ≈ tspan[2] - @test integrator.subintegrator_tree[1].t ≈ tspan[2] @test integrator.dtcache ≈ dt - @test integrator.iter == ceil(Int, (tspan[2] - tspan[1]) / dt) - @test integrator.subintegrator_tree[1].iter == ceil(Int, (tspan[2] - tspan[1]) / dt) + @test integrator.iter == nsteps end end - @testset "Instbility detectioon" begin + @testset "Instability detection" begin dt = 0.01π function ode_NaN(du, u, p, t) @@ -291,11 +301,10 @@ end end f1dofs = [1, 2, 3] - f2dofs = [1, 3] + f3dofs = [1, 3] f_NaN = ODEFunction(ode_NaN) - f_NaN_dofs = f3dofs - fsplit_NaN = GenericSplitFunction((f1, f_NaN), (f1dofs, f_NaN_dofs)) + fsplit_NaN = GenericSplitFunction((f1, f_NaN), (f1dofs, f3dofs)) prob_NaN = OperatorSplittingProblem(fsplit_NaN, u0, tspan) for TimeStepperType in (LieTrotterGodunov,) From a74998e4f9f372b720898171e4435309dff01e6c Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 20 Feb 2026 20:01:05 +0100 Subject: [PATCH 02/17] Try to recover tests --- Project.toml | 6 +- src/integrator.jl | 852 ++++++++++++++------------------- src/solver.jl | 204 ++++---- src/utils.jl | 278 +++++------ test/operator_splitting_api.jl | 1 - 5 files changed, 605 insertions(+), 736 deletions(-) diff --git a/Project.toml b/Project.toml index 46754c0..74eb749 100644 --- a/Project.toml +++ b/Project.toml @@ -20,13 +20,14 @@ CommonSolve = "0.2.4" DataStructures = "0.18.22, 0.19" DiffEqBase = "6.165.1" ExplicitImports = "1" -ModelingToolkit = "10" +ModelingToolkit = "10, 11" OrdinaryDiffEqCore = "1.19.0, 2, 3.1" OrdinaryDiffEqLowOrderRK = "1.7" OrdinaryDiffEqTsit5 = "1.1.0" PrecompileTools = "1.0" RecursiveArrayTools = "3.39.0" SafeTestsets = "0.1.0" +SciCompDSL = "1" SciMLBase = "2.77.0" TimerOutputs = "0.5.28" Unrolled = "0.1.5" @@ -37,7 +38,8 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SciCompDSL = "91a8cdf1-4ca6-467b-a780-87fda3fff15e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ExplicitImports", "ModelingToolkit", "OrdinaryDiffEqTsit5", "SafeTestsets", "Test"] +test = ["ExplicitImports", "ModelingToolkit", "OrdinaryDiffEqTsit5", "SafeTestsets", "SciCompDSL", "Test"] diff --git a/src/integrator.jl b/src/integrator.jl index af196b7..4b66e55 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -22,8 +22,8 @@ end SplitSubIntegratorStatus Minimal error-communication object carried by a [`SplitSubIntegrator`](@ref). -It contains only the `retcode` so that failure can be propagated up the -operator-splitting tree without carrying an actual solution vector. +Holds only `retcode` so that failure can be propagated up the operator-splitting +tree without carrying an actual solution vector. """ mutable struct SplitSubIntegratorStatus retcode::ReturnCode.T @@ -37,30 +37,34 @@ SplitSubIntegratorStatus() = SplitSubIntegratorStatus(ReturnCode.Default) """ SplitSubIntegrator -An intermediate node in the operator-splitting subintegrator tree. It is -self-contained: it knows its own solution indices, its child synchronizers, -and the child solution-index tree. It does **not** carry an `f` field -(operator information lives in the cache / algorithm). - -Fields ------- -- `alg` — the `AbstractOperatorSplittingAlgorithm` at this level -- `u` — view into the *master* solution vector for this sub-problem -- `uprev` — copy of `u` at the start of a step (for rollback) -- `u_master` — reference to the full master solution vector of the - outermost `OperatorSplittingIntegrator` (needed during sync) -- `t`, `dt`, `dtcache` — time tracking -- `iter` — step counter -- `EEst` — error estimate (`NaN` for non-adaptive, `1.0` default for adaptive) -- `controller` — step-size controller (or `nothing` for non-adaptive) -- `force_stepfail` — flag set when a step must be re-tried -- `last_step_failed` — flag set after a failed step to detect double-failure -- `status` — [`SplitSubIntegratorStatus`](@ref) for retcode communication -- `cache` — `AbstractOperatorSplittingCache` for the algorithm at this level -- `subintegrator_tree` — tuple of child integrators (`SplitSubIntegrator` or `DEIntegrator`) -- `solution_indices` — global indices (into master `u`) owned by this sub-integrator -- `solution_index_tree`— per-child global solution indices -- `synchronizer_tree` — per-child synchronizer objects +An intermediate node in the operator-splitting subintegrator tree. + +Each `SplitSubIntegrator` is self-contained: it knows its own solution indices, +its children's synchronizers, solution indices, and sub-integrators. It does +**not** carry an `f` field (operator information lives in the cache/algorithm). + +## Fields +- `alg` — `AbstractOperatorSplittingAlgorithm` at this level +- `u` — local solution buffer for this sub-problem (may be a + view *or* an independent array, e.g. for GPU sub-problems) +- `uprev` — copy of `u` at the start of a step (for rollback) +- `u_master` — reference to the full master solution vector of the + outermost `OperatorSplittingIntegrator` (needed for sync) +- `t`, `dt`, `dtcache` — time tracking +- `iter` — step counter at this level +- `EEst` — error estimate (`NaN` for non-adaptive, `1.0` default + for adaptive) +- `controller` — step-size controller (or `nothing` for non-adaptive) +- `force_stepfail` — flag: current step must be retried +- `last_step_failed` — flag: previous step failed (double-failure detection) +- `status` — [`SplitSubIntegratorStatus`](@ref) for retcode communication +- `cache` — `AbstractOperatorSplittingCache` for the algorithm at + this level +- `child_subintegrators` — tuple of direct children (`SplitSubIntegrator` or + `DEIntegrator`) +- `solution_indices` — global indices (into master `u`) **owned by this node** +- `child_solution_indices` — tuple of per-child global solution indices +- `child_synchronizers` — tuple of per-child synchronizer objects """ mutable struct SplitSubIntegrator{ algType, @@ -69,15 +73,15 @@ mutable struct SplitSubIntegrator{ EEstType, controllerType, cacheType, - subintTreeType, + childSubintType, solidxType, - solidxTreeType, - syncTreeType, + childSolidxType, + childSyncType, } alg::algType - u::uType # view into master u for this sub-problem - uprev::uType # local copy for rollback (same element type, plain Array) - u_master::uType # reference to the outermost master u + u::uType # local solution buffer + uprev::uType # local rollback buffer + u_master::uType # reference to outermost master u t::tType dt::tType dtcache::tType @@ -88,24 +92,40 @@ mutable struct SplitSubIntegrator{ last_step_failed::Bool status::SplitSubIntegratorStatus cache::cacheType - subintegrator_tree::subintTreeType # Tuple + child_subintegrators::childSubintType # Tuple solution_indices::solidxType - solution_index_tree::solidxTreeType # Tuple - synchronizer_tree::syncTreeType # Tuple + child_solution_indices::childSolidxType # Tuple + child_synchronizers::childSyncType # Tuple end -# Convenience predicate +# --- SplitSubIntegrator interface --- + @inline SciMLBase.isadaptive(sub::SplitSubIntegrator) = isadaptive(sub.alg) +# proposed-dt interface (mirrors ODEIntegrator) +function SciMLBase.set_proposed_dt!(sub::SplitSubIntegrator, dt) + if sub.dtcache != dt # only touch if actually changing + sub.dtcache = dt + if !isadaptive(sub) + sub.dt = dt + end + end + return nothing +end + # --------------------------------------------------------------------------- # OperatorSplittingIntegrator # --------------------------------------------------------------------------- """ OperatorSplittingIntegrator <: AbstractODEIntegrator -A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) to perform operator splitting. +A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) +to perform operator splitting. Derived from https://github.com/CliMA/ClimaTimeSteppers.jl/blob/ef3023747606d2750e674d321413f80638136632/src/integrators.jl. + +Note: `solution_index_tree` and `synchronizer_tree` have been removed; this +information now lives inside each [`SplitSubIntegrator`](@ref) child node. """ mutable struct OperatorSplittingIntegrator{ fType, @@ -120,39 +140,38 @@ mutable struct OperatorSplittingIntegrator{ cacheType, solType, subintTreeType, + childSolidxType, + childSyncType, controllerType, optionsType, } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} const f::fType const alg::algType - u::uType # Master Solution - uprev::uType # Master Solution - tmp::uType # Interpolation buffer + u::uType # Master solution + uprev::uType # Master solution previous step + tmp::uType # Interpolation buffer p::pType - t::tType # Current time + t::tType # Current time tprev::tType - dt::tType # Time step length used during time marching - dtcache::tType # Proposed time step length - const dtchangeable::Bool # Indicator whether dtcache can be changed + dt::tType # Time step length used during time marching + dtcache::tType # Proposed time step length + const dtchangeable::Bool tstops::heapType - _tstops::tstopsType # argument to __init used as default argument to reinit! + _tstops::tstopsType saveat::heapType - _saveat::saveatType # argument to __init used as default argument to reinit! + _saveat::saveatType callback::callbackType advance_to_tstop::Bool - # TODO group these into some internal flag struct last_step_failed::Bool force_stepfail::Bool isout::Bool u_modified::Bool - # DiffEqBase.initialize! and DiffEqBase.finalize! cache::cacheType sol::solType - # NOTE: solution_index_tree and synchronizer_tree have been moved into - # the SplitSubIntegrator nodes. The flat subintegrator_tree here is a - # Tuple of SplitSubIntegrator (or DEIntegrator for the degenerate - # single-level case). - subintegrator_tree::subintTreeType + # Tuple of SplitSubIntegrator nodes (one per top-level operator). + child_subintegrators::subintTreeType + child_solution_indices::childSolidxType # Tuple + child_synchronizers::childSyncType # Tuple iter::Int controller::controllerType opts::optionsType @@ -160,10 +179,15 @@ mutable struct OperatorSplittingIntegrator{ tdir::tType end +# Convenience: the old field name `subintegrator_tree` was used in tests and +# docs; alias it so external code still compiles during the transition. +# (Remove in a future breaking release.) +@inline Base.getproperty(i::OperatorSplittingIntegrator, s::Symbol) = + s === :subintegrator_tree ? getfield(i, :child_subintegrators) : getfield(i, s) + # --------------------------------------------------------------------------- # __init # --------------------------------------------------------------------------- -# called by DiffEqBase.init and DiffEqBase.solve function SciMLBase.__init( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, @@ -189,10 +213,10 @@ function SciMLBase.__init( dt = tf > t0 ? dt : -dt tType = typeof(dt) - # Warn if the algorithm is non-adaptive but the user tries to make it adaptive. - (!isadaptive(alg) && adaptive && verbose) && warn("The algorithm $alg is not adaptive.") + (!isadaptive(alg) && adaptive && verbose) && + @warn("The algorithm $alg is not adaptive.") - dtchangeable = true # isadaptive(alg) + dtchangeable = true if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number _tstops = nothing @@ -201,62 +225,52 @@ function SciMLBase.__init( tstops = () end - # Setup tstop logic tstops_internal = OrdinaryDiffEqCore.initialize_tstops( tType, tstops, d_discontinuities, prob.tspan ) - saveat_internal = OrdinaryDiffEqCore.initialize_saveat(tType, saveat, prob.tspan) + saveat_internal = OrdinaryDiffEqCore.initialize_saveat(tType, saveat, prob.tspan) d_discontinuities_internal = OrdinaryDiffEqCore.initialize_d_discontinuities( tType, d_discontinuities, prob.tspan ) - u = setup_u(prob, alg, alias_u0) + u = setup_u(prob, alg, alias_u0) uprev = setup_u(prob, alg, false) - tmp = setup_u(prob, alg, false) + tmp = setup_u(prob, alg, false) uType = typeof(u) - sol = SciMLBase.build_solution(prob, alg, tType[], uType[]) - + sol = SciMLBase.build_solution(prob, alg, tType[], uType[]) callback = DiffEqBase.CallbackSet(callback) - # Build the subintegrator tree. Each SplitSubIntegrator is now - # self-contained: it holds its own solution_indices, solution_index_tree, - # and synchronizer_tree. - subintegrator_tree, cache = build_subintegrator_tree_with_cache( + child_subintegrators, cache = build_subintegrators( prob, alg, uprev, u, - u, # u_master == u at the outermost level + u, # u_master == u at the outermost level 1:length(u), t0, dt, tf, tstops, saveat, d_discontinuities, callback, adaptive, verbose ) + child_solution_indices = ntuple(i -> prob.f.solution_indices[i], length(prob.f.functions)) + child_synchronizers = ntuple(i -> prob.f.synchronizers[i], length(prob.f.functions)) + integrator = OperatorSplittingIntegrator( prob.f, alg, - u, - uprev, - tmp, + u, uprev, tmp, p, - t0, - copy(t0), - dt, - dtcache, + t0, copy(dt), + dt, dtcache, dtchangeable, - tstops_internal, - tstops, - saveat_internal, - saveat, + tstops_internal, tstops, + saveat_internal, saveat, callback, advance_to_tstop, - false, - false, - false, - false, - cache, - sol, - subintegrator_tree, + false, false, false, false, + cache, sol, + child_subintegrators, + child_solution_indices, + child_synchronizers, 0, controller, IntegratorOptions(; verbose, adaptive), @@ -267,7 +281,11 @@ function SciMLBase.__init( return integrator end +# --------------------------------------------------------------------------- +# reinit! +# --------------------------------------------------------------------------- SciMLBase.has_reinit(integrator::OperatorSplittingIntegrator) = true + function DiffEqBase.reinit!( integrator::OperatorSplittingIntegrator, u0 = integrator.sol.prob.u0; @@ -280,14 +298,15 @@ function DiffEqBase.reinit!( reinit_callbacks = true, reinit_retcode = true ) - integrator.u .= u0 + integrator.u .= u0 integrator.uprev .= u0 - integrator.t = t0 + integrator.t = t0 integrator.tprev = t0 if dt !== nothing integrator.dt = dt end - integrator.tstops, integrator.saveat = tstops_and_saveat_heaps(t0, tf, tstops, saveat) + integrator.tstops, integrator.saveat = + tstops_and_saveat_heaps(t0, tf, tstops, saveat) integrator.iter = 0 if erase_sol resize!(integrator.sol.t, 0) @@ -295,7 +314,7 @@ function DiffEqBase.reinit!( end if reinit_callbacks DiffEqBase.initialize!(integrator.callback, u0, t0, integrator) - else # always reinit the saving callback so that t0 can be saved if needed + else saving_callback = integrator.callback.discrete_callbacks[end] DiffEqBase.initialize!(saving_callback, u0, t0, integrator) end @@ -305,117 +324,86 @@ function DiffEqBase.reinit!( ) end - return subreinit!( + _subreinit_tuple!( integrator.f, u0, - integrator.subintegrator_tree; + integrator.child_subintegrators; t0, tf, dt, - erase_sol, - tstops, - saveat, - reinit_callbacks, - reinit_retcode + erase_sol, tstops, saveat, + reinit_callbacks, reinit_retcode ) + return nothing end -# subreinit! for a leaf DEIntegrator -function subreinit!( - f, - u0, - subintegrator::DEIntegrator; - dt, - kwargs... - ) - # dt is not reset as expected in reinit! - if dt !== nothing - subintegrator.dt = dt - end - # solution_indices are carried by the parent SplitSubIntegrator - error("subreinit! called directly on a DEIntegrator — should be reached only via SplitSubIntegrator") -end - -# subreinit! for an intermediate SplitSubIntegrator -function subreinit!( - f, - u0, - sub::SplitSubIntegrator; - t0, - tf, - dt, - kwargs... - ) - idxs = sub.solution_indices - sub.u .= @view u0[idxs] - sub.uprev .= @view u0[idxs] - sub.t = t0 - if dt !== nothing - sub.dt = dt - sub.dtcache = dt - end - sub.iter = 0 - sub.force_stepfail = false - sub.last_step_failed = false - sub.status = SplitSubIntegratorStatus(ReturnCode.Default) - # Reset EEst to the appropriate default - if isadaptive(sub) - sub.EEst = one(sub.EEst) - else - sub.EEst = sub.EEst # keep NaN sentinel - end - return subreinit_children!(f, u0, sub; t0, tf, dt, kwargs...) -end +# --- subreinit! helpers --- -@unroll function subreinit_children!( +# Iterate over a tuple of children (outermost call from reinit!) +@unroll function _subreinit_tuple!( f, u0, - sub::SplitSubIntegrator; + children::Tuple; kwargs... ) i = 1 - @unroll for child in sub.subintegrator_tree - _subreinit_child!(get_operator(f, i), u0, child, sub.solution_index_tree[i]; kwargs...) + @unroll for child in children + _subreinit_child!(get_operator(f, i), u0, child; kwargs...) i += 1 end end -# Dispatch for leaf DEIntegrator children +# Reinitialise a leaf DEIntegrator child function _subreinit_child!( f_child, u0, - child::DEIntegrator, - child_solution_indices; + child::DEIntegrator; dt, kwargs... ) - if dt !== nothing - child.dt = dt + if dt !== nothing && child.dtchangeable + SciMLBase.set_proposed_dt!(child, dt) end - return DiffEqBase.reinit!(child, @view(u0[child_solution_indices]); kwargs...) + # solution_indices live on the parent SplitSubIntegrator (or on the outer + # integrator for top-level children) — they were baked into child at init. + # reinit! on an ODEIntegrator resets u from its prob.u0; we need to pass + # the correct slice here. The parent calls us with the correct f_child + # but not the indices — those are embedded in child.sol.prob.u0 already + # because we constructed child with a view/copy of the right slice. + return DiffEqBase.reinit!(child; kwargs...) end -# Dispatch for nested SplitSubIntegrator children +# Reinitialise an intermediate SplitSubIntegrator child function _subreinit_child!( f_child, u0, - child::SplitSubIntegrator, - _child_solution_indices; # ignored — child carries its own + sub::SplitSubIntegrator; + t0, + tf, + dt, kwargs... ) - return subreinit!(f_child, u0, child; kwargs...) -end - -# Top-level subreinit! over a tuple of subintegrators (called from reinit!) -@unroll function subreinit!( - f, + idxs = sub.solution_indices + sub.u .= @view u0[idxs] + sub.uprev .= @view u0[idxs] + sub.t = t0 + if dt !== nothing + SciMLBase.set_proposed_dt!(sub, dt) + end + sub.iter = 0 + sub.force_stepfail = false + sub.last_step_failed = false + sub.status = SplitSubIntegratorStatus(ReturnCode.Default) + # Reset EEst to its appropriate default + if isadaptive(sub) + sub.EEst = one(sub.EEst) + end + # Recurse into this node's children + _subreinit_tuple!( + f_child, u0, - subintegrators::Tuple; - kwargs... + sub.child_subintegrators; + t0, tf, dt, kwargs... ) - i = 1 - @unroll for sub in subintegrators - subreinit!(get_operator(f, i), u0, sub; kwargs...) - i += 1 - end + return nothing end # --------------------------------------------------------------------------- @@ -423,11 +411,11 @@ end # --------------------------------------------------------------------------- function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrator) if SciMLBase.has_tstop(integrator) - tdir_t = tdir(integrator) * integrator.t + tdir_t = tdir(integrator) * integrator.t tdir_tstop = SciMLBase.first_tstop(integrator) if tdir_t == tdir_tstop - while tdir_t == tdir_tstop #remove all redundant copies - res = SciMLBase.pop_tstop!(integrator) + while tdir_t == tdir_tstop + SciMLBase.pop_tstop!(integrator) SciMLBase.has_tstop(integrator) ? (tdir_tstop = SciMLBase.first_tstop(integrator)) : break end @@ -436,8 +424,8 @@ function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrato if !integrator.dtchangeable SciMLBase.change_t_via_interpolation!( integrator, - tdir(integrator) * - SciMLBase.pop_tstop!(integrator), Val{true} + tdir(integrator) * SciMLBase.pop_tstop!(integrator), + Val{true} ) notify_integrator_hit_tstop!(integrator) else @@ -450,36 +438,36 @@ end notify_integrator_hit_tstop!(integrator::OperatorSplittingIntegrator) = nothing -is_first_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter == 0 +is_first_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter == 0 increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter += 1 # --------------------------------------------------------------------------- -# Controller interface — outermost integrator +# Step accept/reject — outermost integrator # --------------------------------------------------------------------------- function reject_step!(integrator::OperatorSplittingIntegrator) OrdinaryDiffEqCore.increment_reject!(integrator.stats) return reject_step!(integrator, integrator.cache, integrator.controller) end function reject_step!(integrator::OperatorSplittingIntegrator, cache, controller) - return integrator.u .= integrator.uprev - # TODO what do we need to do with the subintegrators? + integrator.u .= integrator.uprev + # TODO: roll back sub-integrators + return nothing end function reject_step!(integrator::OperatorSplittingIntegrator, cache, ::Nothing) - return if length(integrator.uprev) == 0 + if length(integrator.uprev) == 0 error("Cannot roll back integrator. Aborting time integration step at $(integrator.t).") end + return nothing end -# Solution looping interface function should_accept_step(integrator::OperatorSplittingIntegrator) - if integrator.force_stepfail || integrator.isout - return false - end + integrator.force_stepfail || integrator.isout && return false return should_accept_step(integrator, integrator.cache, integrator.controller) end function should_accept_step(integrator::OperatorSplittingIntegrator, cache, ::Nothing) return !(integrator.force_stepfail) end + function accept_step!(integrator::OperatorSplittingIntegrator) OrdinaryDiffEqCore.increment_accept!(integrator.stats) return accept_step!(integrator, integrator.cache, integrator.controller) @@ -488,66 +476,60 @@ function accept_step!(integrator::OperatorSplittingIntegrator, cache, controller return store_previous_info!(integrator) end function store_previous_info!(integrator::OperatorSplittingIntegrator) - return if length(integrator.uprev) > 0 # Integrator can rollback + if length(integrator.uprev) > 0 update_uprev!(integrator) end + return nothing end - function update_uprev!(integrator::OperatorSplittingIntegrator) RecursiveArrayTools.recursivecopy!(integrator.uprev, integrator.u) return nothing end -# --------------------------------------------------------------------------- -# Controller interface — SplitSubIntegrator -# --------------------------------------------------------------------------- +# Step accept/reject — SplitSubIntegrator +function accept_step!(sub::SplitSubIntegrator) + RecursiveArrayTools.recursivecopy!(sub.uprev, sub.u) + return nothing +end function reject_step!(sub::SplitSubIntegrator) sub.u .= sub.uprev - # Propagate rollback to all leaf DEIntegrators within this subtree so - # their state is consistent before the next attempt. - _rollback_subintegrator_tree!(sub.subintegrator_tree, sub.u_master) + _rollback_children!(sub.child_subintegrators, sub.u_master) + return nothing end -function _rollback_subintegrator_tree!(subintegrators::Tuple, u_master) - @unroll for child in subintegrators +# Roll back each child's local buffer to match master u. +# For DEIntegrators the leaf will be re-synced via forward_sync before the +# next attempt, so there is nothing to do here. +@unroll function _rollback_children!(children::Tuple, u_master) + @unroll for child in children _rollback_child!(child, u_master) end end - function _rollback_child!(child::SplitSubIntegrator, u_master) - child.u .= child.uprev - _rollback_subintegrator_tree!(child.subintegrator_tree, u_master) + child.u .= @view u_master[child.solution_indices] + RecursiveArrayTools.recursivecopy!(child.uprev, child.u) + _rollback_children!(child.child_subintegrators, u_master) + return nothing end - function _rollback_child!(child::DEIntegrator, u_master) - # The leaf integrator's uprev already holds the correct state because - # forward_sync_internal! copies u_master into it before each sub-step. - # Nothing to do here beyond letting the view aliasing keep things consistent. + # forward_sync before the next sub-step will restore this correctly. return nothing end -function accept_step!(sub::SplitSubIntegrator) - RecursiveArrayTools.recursivecopy!(sub.uprev, sub.u) -end - # --------------------------------------------------------------------------- # step_header! / step_footer! — outermost integrator # --------------------------------------------------------------------------- function step_header!(integrator::OperatorSplittingIntegrator) - # Accept or reject the step if !is_first_iteration(integrator) if should_accept_step(integrator) accept_step!(integrator) - else # Step should be rejected and hence repeated + else reject_step!(integrator) end - elseif integrator.u_modified # && integrator.iter == 0 + elseif integrator.u_modified update_uprev!(integrator) end - - # Before stepping we might need to adjust the dt increment_iteration(integrator) - # OrdinaryDiffEqCore.choose_algorithm!(integrator, integrator.cache) OrdinaryDiffEqCore.fix_dt_at_bounds!(integrator) OrdinaryDiffEqCore.modify_dt_for_tstops!(integrator) return integrator.force_stepfail = false @@ -557,46 +539,43 @@ function footer_reset_flags!(integrator) return integrator.u_modified = false end function setup_validity_flags!(integrator, t_next) - return integrator.isout = false #integrator.opts.isoutofdomain(integrator.u, integrator.p, t_next) + return integrator.isout = false end function fix_solution_buffer_sizes!(integrator, sol) resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) - return if !(integrator.sol isa SciMLBase.DAESolution) + if !(integrator.sol isa SciMLBase.DAESolution) resize!(integrator.sol.k, integrator.saveiter_dense) end + return nothing end function step_footer!(integrator::OperatorSplittingIntegrator) ttmp = integrator.t + tdir(integrator) * integrator.dt - footer_reset_flags!(integrator) setup_validity_flags!(integrator, ttmp) - if should_accept_step(integrator) integrator.last_step_failed = false integrator.tprev = integrator.t - integrator.t = ttmp - step_accept_controller!(integrator) # Noop for non-adaptive algorithms + integrator.t = ttmp + step_accept_controller!(integrator) elseif integrator.force_stepfail if isadaptive(integrator) step_reject_controller!(integrator) OrdinaryDiffEqCore.post_newton_controller!(integrator, integrator.alg) - elseif integrator.dtchangeable # Non-adaptive but can change dt + elseif integrator.dtchangeable integrator.dt /= integrator.opts.failfactor elseif integrator.last_step_failed return end integrator.last_step_failed = true end - return nothing end # --------------------------------------------------------------------------- # __solve / solve! / step! # --------------------------------------------------------------------------- -# called by DiffEqBase.solve function SciMLBase.__solve( prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, args...; kwargs... @@ -605,26 +584,21 @@ function SciMLBase.__solve( return DiffEqBase.solve!(integrator) end -# either called directly (after init), or by DiffEqBase.solve (via __solve) function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator) while !isempty(integrator.tstops) while tdir(integrator) * integrator.t < SciMLBase.first_tstop(integrator) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( ReturnCode.Success, ReturnCode.Default, - ) && return + ) && return integrator.sol __step!(integrator) step_footer!(integrator) - if !SciMLBase.has_tstop(integrator) - break - end + SciMLBase.has_tstop(integrator) || break end OrdinaryDiffEqCore.handle_tstop!(integrator) end SciMLBase.postamble!(integrator) - if integrator.sol.retcode != ReturnCode.Default - return integrator.sol - end + integrator.sol.retcode != ReturnCode.Default && return integrator.sol return integrator.sol = SciMLBase.solution_new_retcode( integrator.sol, ReturnCode.Success ) @@ -632,7 +606,7 @@ end function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) @timeit_debug "step!" if integrator.advance_to_tstop - tstop = first_tstop(integrator) + tstop = SciMLBase.first_tstop(integrator) while !reached_tstop(integrator, tstop) step_header!(integrator) @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( @@ -640,9 +614,7 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) ) && return __step!(integrator) step_footer!(integrator) - if !SciMLBase.has_tstop(integrator) - break - end + SciMLBase.has_tstop(integrator) || break end else step_header!(integrator) @@ -660,55 +632,12 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) step_footer!(integrator) end end - return OrdinaryDiffEqCore.handle_tstop!(integrator) -end - -function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) - if !SciMLBase.successful_retcode(integrator.sol) && - integrator.sol.retcode != ReturnCode.Default - return integrator.sol.retcode - end - - verbose = true # integrator.opts.verbose - - if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) - if verbose - @warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.") - end - return ReturnCode.DtNaN - end - - return check_error_subintegrators(integrator, integrator.subintegrator_tree) -end - -# Recurse over a tuple of children -function check_error_subintegrators(integrator, subintegrator_tree::Tuple) - for sub in subintegrator_tree - retcode = check_error_subintegrators(integrator, sub) - if !SciMLBase.successful_retcode(retcode) && retcode != ReturnCode.Default - return retcode - end - end - return integrator.sol.retcode -end - -# Leaf: read retcode from the DEIntegrator's solution -function check_error_subintegrators(integrator, sub::DEIntegrator) - return SciMLBase.check_error(sub) -end - -# Intermediate node: read retcode from the SplitSubIntegrator status object -function check_error_subintegrators(integrator, sub::SplitSubIntegrator) - rc = sub.status.retcode - if !SciMLBase.successful_retcode(rc) && rc != ReturnCode.Default - return rc - end - # Also recurse into children - return check_error_subintegrators(integrator, sub.subintegrator_tree) + OrdinaryDiffEqCore.handle_tstop!(integrator) + return end function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false) - return @timeit_debug "step!" begin + @timeit_debug "step!" begin dt <= zero(dt) && error("dt must be positive") stop_at_tdt && !integrator.dtchangeable && error("Cannot stop at t + dt if dtchangeable is false") @@ -723,22 +652,49 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_t step_footer!(integrator) end end + return nothing end -function setup_u(prob::OperatorSplittingProblem, solver, alias_u0) - if alias_u0 - return prob.u0 - else - return RecursiveArrayTools.recursivecopy(prob.u0) +# --------------------------------------------------------------------------- +# check_error +# --------------------------------------------------------------------------- +function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) + if !SciMLBase.successful_retcode(integrator.sol) && + integrator.sol.retcode != ReturnCode.Default + return integrator.sol.retcode + end + if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) + integrator.opts.verbose && + @warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.") + return ReturnCode.DtNaN end + return _check_error_children(integrator.sol.retcode, integrator.child_subintegrators) +end + +@unroll function _check_error_children(current_retcode, children::Tuple) + @unroll for child in children + rc = _child_retcode(child) + if !SciMLBase.successful_retcode(rc) && rc != ReturnCode.Default + return rc + end + end + return current_retcode +end + +_child_retcode(child::DEIntegrator) = SciMLBase.check_error(child) +_child_retcode(child::SplitSubIntegrator) = child.status.retcode + +# --------------------------------------------------------------------------- +# Internal step +# --------------------------------------------------------------------------- +function setup_u(prob::OperatorSplittingProblem, solver, alias_u0) + alias_u0 ? prob.u0 : RecursiveArrayTools.recursivecopy(prob.u0) end -# TimeChoiceIterator API @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator) return (integrator.tmp,) end -# Interpolation function linear_interpolation!(y, t, y1, y2, t1, t2) return y .= y1 + (t - t1) * (y2 - y1) / (t2 - t1) end @@ -748,47 +704,25 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t) ) end -# --------------------------------------------------------------------------- # Stepsize controller hooks — outermost integrator -# --------------------------------------------------------------------------- -""" - stepsize_controller!(::OperatorSplittingIntegrator) - -Updates the controller using the current state of the integrator if the operator splitting algorithm is adaptive. -""" @inline function stepsize_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - isadaptive(algorithm) || return nothing - return stepsize_controller!(integrator, algorithm) + isadaptive(integrator.alg) || return nothing + return stepsize_controller!(integrator, integrator.alg) end - -""" - step_accept_controller!(::OperatorSplittingIntegrator) - -Updates `dtcache` of the integrator if the step is accepted and the operator splitting algorithm is adaptive. -""" @inline function step_accept_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - isadaptive(algorithm) || return nothing - return step_accept_controller!(integrator, algorithm, nothing) + isadaptive(integrator.alg) || return nothing + return step_accept_controller!(integrator, integrator.alg, nothing) end - -""" - step_reject_controller!(::OperatorSplittingIntegrator) - -Updates `dtcache` of the integrator if the step is rejected and the the operator splitting algorithm is adaptive. -""" @inline function step_reject_controller!(integrator::OperatorSplittingIntegrator) - algorithm = integrator.alg - isadaptive(algorithm) || return nothing - return step_reject_controller!(integrator, algorithm, nothing) + isadaptive(integrator.alg) || return nothing + return step_reject_controller!(integrator, integrator.alg, nothing) end -# --------------------------------------------------------------------------- # Time helpers -# --------------------------------------------------------------------------- -tdir(integrator) = integrator.tstops.ordering isa DataStructures.FasterForward ? 1 : -1 -is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) ≤ zero(integrator.t) +tdir(integrator) = + integrator.tstops.ordering isa DataStructures.FasterForward ? 1 : -1 +is_past_t(integrator, t) = + tdir(integrator) * (t - integrator.t) ≤ zero(integrator.t) function reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) if stop_at_tstop integrator.t > tstop && @@ -799,17 +733,10 @@ function reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeabl end end -# --------------------------------------------------------------------------- # SciMLBase integrator interface -# --------------------------------------------------------------------------- function SciMLBase.done(integrator::OperatorSplittingIntegrator) - if !( - integrator.sol.retcode in ( - ReturnCode.Default, ReturnCode.Success, - ) - ) - return true - elseif isempty(integrator.tstops) + integrator.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) && return true + if isempty(integrator.tstops) SciMLBase.postamble!(integrator) return true end @@ -820,150 +747,70 @@ function SciMLBase.postamble!(integrator::OperatorSplittingIntegrator) return DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator) end -function __step!(integrator) +function __step!(integrator::OperatorSplittingIntegrator) tnext = integrator.t + integrator.dt - synchronize_subintegrator_tree!(integrator) + _sync_children!(integrator) advance_solution_to!(integrator, tnext) - return stepsize_controller!(integrator) -end - -# solvers need to define this interface -function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext) - return advance_solution_to!(integrator, integrator.cache, tnext) -end - -function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, sync, cache, tend - ) - # Advance a SplitSubIntegrator node using its own advance_solution_to! dispatch - dt = tend - sub.t - sub.dt = dt - return advance_solution_to!(outer_integrator, sub, tend) -end - -function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - integrator::DEIntegrator, solution_indices, sync, cache, tend - ) - dt = tend - integrator.t - return SciMLBase.step!(integrator, dt, true) -end - -# --------------------------------------------------------------------------- -# SciMLBase.jl integrator interface -# --------------------------------------------------------------------------- -SciMLBase.has_stats(::OperatorSplittingIntegrator) = true - -SciMLBase.has_tstop(integrator::OperatorSplittingIntegrator) = !isempty(integrator.tstops) -SciMLBase.first_tstop(integrator::OperatorSplittingIntegrator) = first(integrator.tstops) -SciMLBase.pop_tstop!(integrator::OperatorSplittingIntegrator) = pop!(integrator.tstops) - -DiffEqBase.get_dt(integrator::OperatorSplittingIntegrator) = integrator.dt -function set_dt!(integrator::OperatorSplittingIntegrator, dt) - dt <= zero(dt) && error("dt must be positive") - return integrator.dt = dt -end - -function DiffEqBase.add_tstop!(integrator::OperatorSplittingIntegrator, t) - is_past_t(integrator, t) && - error("Cannot add a tstop at $t because that is behind the current \ - integrator time $(integrator.t)") - return push!(integrator.tstops, t) -end - -function DiffEqBase.add_saveat!(integrator::OperatorSplittingIntegrator, t) - is_past_t(integrator, t) && - error("Cannot add a saveat point at $t because that is behind the \ - current integrator time $(integrator.t)") - return push!(integrator.saveat, t) + stepsize_controller!(integrator) + return nothing end -DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing - -# --------------------------------------------------------------------------- -# Synchronization -# --------------------------------------------------------------------------- -function synchronize_subintegrator_tree!(integrator::OperatorSplittingIntegrator) - return synchronize_subintegrator!(integrator.subintegrator_tree, integrator) +# Sync all direct children of the outermost integrator +function _sync_children!(integrator::OperatorSplittingIntegrator) + _sync_children_tuple!(integrator.child_subintegrators, integrator) end -@unroll function synchronize_subintegrator!( - subintegrator_tree::Tuple, integrator::OperatorSplittingIntegrator +@unroll function _sync_children_tuple!( + children::Tuple, + parent::OperatorSplittingIntegrator ) - @unroll for sub in subintegrator_tree - synchronize_subintegrator!(sub, integrator) + @unroll for child in children + _sync_child_to_parent!(child, parent) end end -# Sync a SplitSubIntegrator node: update its t/dt then recurse into children -function synchronize_subintegrator!( - sub::SplitSubIntegrator, integrator::OperatorSplittingIntegrator - ) - (; t, dt) = integrator - @assert sub.t == t "SplitSubIntegrator time $(sub.t) out of sync with outer integrator time $t" - if !isadaptive(sub) - sub.dt = dt - sub.dtcache = dt - end - # Recurse: sync children against the *sub-integrator* (not outer) time - @unroll for child in sub.subintegrator_tree - synchronize_subintegrator_child!(child, sub) +function _sync_child_to_parent!(child::DEIntegrator, parent::OperatorSplittingIntegrator) + @assert child.t == parent.t "($(child.t) != $(parent.t))" + if !isadaptive(child) && child.dtchangeable + SciMLBase.set_proposed_dt!(child, parent.dt) end end -function synchronize_subintegrator_child!( - child::DEIntegrator, parent::SplitSubIntegrator +function _sync_child_to_parent!( + child::SplitSubIntegrator, parent::OperatorSplittingIntegrator ) - @assert child.t == parent.t "Child integrator time $(child.t) out of sync with parent time $(parent.t)" + @assert child.t == parent.t "($(child.t) != $(parent.t))" if !isadaptive(child) SciMLBase.set_proposed_dt!(child, parent.dt) end end -function synchronize_subintegrator_child!( - child::SplitSubIntegrator, parent::SplitSubIntegrator - ) - @assert child.t == parent.t "Nested SplitSubIntegrator time $(child.t) out of sync with parent time $(parent.t)" - if !isadaptive(child) - child.dt = parent.dt - child.dtcache = parent.dt - end +# Entry point: dispatch to the algorithm's advance_solution_to! +function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext) + return advance_solution_to!(integrator, integrator.cache, tnext) end -# --------------------------------------------------------------------------- -# advance_solution_to! for AbstractOperatorSplittingCache -# (dispatches into the algorithm-specific method in solver.jl) -# --------------------------------------------------------------------------- +# Algorithm-level dispatch (implemented in solver.jl per algorithm) function advance_solution_to!( integrator::OperatorSplittingIntegrator, cache::AbstractOperatorSplittingCache, tnext::Number ) return advance_solution_to!( - integrator, integrator.subintegrator_tree, cache, tnext - ) -end - -# advance_solution_to! for a SplitSubIntegrator node -# (the algorithm-specific method in solver.jl calls this signature) -function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, tend - ) - return advance_solution_to!( - outer_integrator, sub, sub.subintegrator_tree, - sub.solution_index_tree, sub.synchronizer_tree, - sub.cache, tend + integrator, integrator.child_subintegrators, cache, tnext ) end # --------------------------------------------------------------------------- # Tree construction # --------------------------------------------------------------------------- -# Top-level dispatch: builds a Tuple of SplitSubIntegrators -function build_subintegrator_tree_with_cache( - prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, - uprevouter::AbstractVector, uouter::AbstractVector, + +# Top-level builder: called from __init with the full problem. +# Returns (child_subintegrators::Tuple, cache::AbstractOperatorSplittingCache) +function build_subintegrators( + prob::OperatorSplittingProblem, + alg::AbstractOperatorSplittingAlgorithm, + uprevouter::AbstractVector, + uouter::AbstractVector, u_master::AbstractVector, solution_indices, t0, dt, tf, @@ -972,14 +819,13 @@ function build_subintegrator_tree_with_cache( ) (; f, p) = prob - subintegrator_tree_with_caches = ntuple( - i -> build_subintegrator_tree_with_cache( + results = ntuple( + i -> _build_child( prob, alg.inner_algs[i], get_operator(f, i), p[i], - uprevouter, uouter, - u_master, + uprevouter, uouter, u_master, f.solution_indices[i], t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -988,25 +834,27 @@ function build_subintegrator_tree_with_cache( length(f.functions) ) - subintegrator_tree = ntuple( - i -> subintegrator_tree_with_caches[i][1], length(f.functions) + child_subintegrators = ntuple(i -> results[i][1], length(f.functions)) + child_caches = ntuple(i -> results[i][2], length(f.functions)) + + cache = init_cache( + f, alg; + uprev = uprevouter, u = uouter, alias_u = true, + inner_caches = child_caches ) - caches = ntuple(i -> subintegrator_tree_with_caches[i][2], length(f.functions)) - return subintegrator_tree, - init_cache( - f, alg; - uprev = uprevouter, u = uouter, alias_u = true, - inner_caches = caches - ) + return child_subintegrators, cache end -# Intermediate node: inner algorithm is also an AbstractOperatorSplittingAlgorithm -# wrapping a GenericSplitFunction → produce a SplitSubIntegrator -function build_subintegrator_tree_with_cache( - prob::OperatorSplittingProblem, alg::AbstractOperatorSplittingAlgorithm, - f::GenericSplitFunction, p::Tuple, - uprevouter::AbstractVector, uouter::AbstractVector, +# Intermediate node: inner alg is an AbstractOperatorSplittingAlgorithm and +# f is a GenericSplitFunction → produce a SplitSubIntegrator +function _build_child( + prob::OperatorSplittingProblem, + alg::AbstractOperatorSplittingAlgorithm, + f::GenericSplitFunction, + p::Tuple, + uprevouter::AbstractVector, + uouter::AbstractVector, u_master::AbstractVector, solution_indices, t0, dt, tf, @@ -1017,15 +865,14 @@ function build_subintegrator_tree_with_cache( ) tType = typeof(dt) - # Build children recursively - child_results = ntuple( - i -> build_subintegrator_tree_with_cache( + # Recurse: build each grandchild + grandchild_results = ntuple( + i -> _build_child( prob, alg.inner_algs[i], get_operator(f, i), p[i], - uprevouter, uouter, - u_master, + uprevouter, uouter, u_master, f.solution_indices[i], t0, dt, tf, tstops, saveat, d_discontinuities, callback, @@ -1034,38 +881,34 @@ function build_subintegrator_tree_with_cache( length(f.functions) ) - child_subintegrators = ntuple(i -> child_results[i][1], length(f.functions)) - child_caches = ntuple(i -> child_results[i][2], length(f.functions)) - - # Build per-child solution_index_tree and synchronizer_tree - child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) - child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) + child_subintegrators = ntuple(i -> grandchild_results[i][1], length(f.functions)) + child_caches = ntuple(i -> grandchild_results[i][2], length(f.functions)) + child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) + child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) - # Cache for *this* level u_sub = @view uouter[solution_indices] uprev_sub = @view uprevouter[solution_indices] + level_cache = init_cache( f, alg; uprev = uprev_sub, u = u_sub, inner_caches = child_caches ) - # EEst default EEst_val = isadaptive(alg) ? one(tType) : tType(NaN) sub = SplitSubIntegrator( alg, - u_sub, - RecursiveArrayTools.recursivecopy(u_sub), # uprev: local copy for rollback + # u and uprev: independent copies so that rollback works even when + # u_sub is a view into a device-local buffer. + RecursiveArrayTools.recursivecopy(Array(u_sub)), + RecursiveArrayTools.recursivecopy(Array(u_sub)), u_master, - t0, - dt, - dt, # dtcache - 0, # iter + t0, dt, dt, # t, dt, dtcache + 0, # iter EEst_val, controller, - false, # force_stepfail - false, # last_step_failed + false, false, # force_stepfail, last_step_failed SplitSubIntegratorStatus(), level_cache, child_subintegrators, @@ -1077,9 +920,9 @@ function build_subintegrator_tree_with_cache( return sub, level_cache end -# Leaf node: inner algorithm is a plain SciMLBase.AbstractODEAlgorithm +# Leaf node: inner alg is a plain SciMLBase.AbstractODEAlgorithm # → produce an ODEIntegrator (existing behaviour) -function build_subintegrator_tree_with_cache( +function _build_child( prob::OperatorSplittingProblem, alg::SciMLBase.AbstractODEAlgorithm, f::F, p::P, @@ -1092,9 +935,7 @@ function build_subintegrator_tree_with_cache( save_end = false, controller = nothing ) where {S, T, P, F} - uprev = @view uprevouter[solution_indices] - u = @view uouter[solution_indices] - + u = @view uouter[solution_indices] prob2 = if p isa NullParameters SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf))) else @@ -1102,31 +943,44 @@ function build_subintegrator_tree_with_cache( end integrator = SciMLBase.__init( - prob2, - alg; + prob2, alg; dt, saveat = (), d_discontinuities, save_everystep = false, advance_to_tstop = false, - adaptive, - controller, - verbose + adaptive, controller, verbose ) - return integrator, integrator.cache end -# forward/backward sync no-ops for tuple nodes (handled inside SplitSubIntegrator) -function forward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, subintegrator_tree::Tuple, - solution_indices::Tuple, synchronizers::Tuple - ) - return nothing +# --------------------------------------------------------------------------- +# SciMLBase API +# --------------------------------------------------------------------------- +SciMLBase.has_stats(::OperatorSplittingIntegrator) = true + +SciMLBase.has_tstop(i::OperatorSplittingIntegrator) = !isempty(i.tstops) +SciMLBase.first_tstop(i::OperatorSplittingIntegrator) = first(i.tstops) +SciMLBase.pop_tstop!(i::OperatorSplittingIntegrator) = pop!(i.tstops) + +DiffEqBase.get_dt(i::OperatorSplittingIntegrator) = i.dt +function set_dt!(i::OperatorSplittingIntegrator, dt) + dt <= zero(dt) && error("dt must be positive") + return i.dt = dt end -function backward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - subintegrator_tree::Tuple, solution_indices::Tuple, synchronizer::Tuple - ) - return nothing + +function DiffEqBase.add_tstop!(i::OperatorSplittingIntegrator, t) + is_past_t(i, t) && + error("Cannot add a tstop at $t because that is behind the current \ + integrator time $(i.t)") + return push!(i.tstops, t) end + +function DiffEqBase.add_saveat!(i::OperatorSplittingIntegrator, t) + is_past_t(i, t) && + error("Cannot add a saveat point at $t because that is behind the \ + current integrator time $(i.t)") + return push!(i.saveat, t) +end + +DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing diff --git a/src/solver.jl b/src/solver.jl index 2d7cfcd..8b219a4 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -1,8 +1,11 @@ -# Lie-Trotter-Godunov Splitting Implementation +# --------------------------------------------------------------------------- +# Lie-Trotter-Godunov operator splitting +# --------------------------------------------------------------------------- """ LieTrotterGodunov <: AbstractOperatorSplittingAlgorithm -A first order sequential operator splitting algorithm attributed to [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite). +First-order sequential operator splitting algorithm attributed to +[Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite). """ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm inner_algs::AlgTupleType @@ -22,121 +25,150 @@ function init_cache( alias_u = false ) _uprev = alias_uprev ? uprev : RecursiveArrayTools.recursivecopy(uprev) - _u = alias_u ? u : RecursiveArrayTools.recursivecopy(u) + _u = alias_u ? u : RecursiveArrayTools.recursivecopy(u) return LieTrotterGodunovCache(_u, _uprev, inner_caches) end # --------------------------------------------------------------------------- -# advance_solution_to! for the outermost integrator with a Tuple of children -# This is the top-level dispatch when the outer integrator's cache is a -# LieTrotterGodunovCache and children are SplitSubIntegrators or DEIntegrators. +# advance_solution_to! for a SplitSubIntegrator node +# +# The SplitSubIntegrator is now the *parent* for its own children. +# It carries child_solution_indices and child_synchronizers directly. +# +# Entry point called from integrator.jl for a SplitSubIntegrator node # --------------------------------------------------------------------------- -@inline @unroll function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - subintegrators::Tuple, cache::LieTrotterGodunovCache, tnext +function advance_solution_to!( + outer::OperatorSplittingIntegrator, + children::Tuple, + cache::AbstractOperatorSplittingCache, + tnext ) - (; inner_caches) = cache - i = 0 - @unroll for subinteg in subintegrators - i += 1 - inner_cache = inner_caches[i] - _advance_child!(outer_integrator, subinteg, inner_cache, tnext) - # Check for failure after each child - if _child_failed(outer_integrator, subinteg) - outer_integrator.force_stepfail = true - return - end + _perform_step!(outer, children, cache, tnext) + + if outer.force_stepfail + outer.sol = SciMLBase.solution_new_retcode( + outer.sol, + ReturnCode.Failure + ) + return end + + # All children succeeded: advance this node's time and counter + # outer.sol = SciMLBase.solution_new_retcode( + # outer.sol, + # ReturnCode.Success + # ) + return end -# --------------------------------------------------------------------------- -# advance_solution_to! for a SplitSubIntegrator node with LieTrotterGodunov -# This is the recursive dispatch when a SplitSubIntegrator's own cache is LTG. -# --------------------------------------------------------------------------- -@inline @unroll function advance_solution_to!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, - subintegrators::Tuple, - solution_indices::Tuple, - synchronizers::Tuple, - cache::LieTrotterGodunovCache, - tnext - ) - (; inner_caches) = cache +function advance_solution_to!( + outer::SplitSubIntegrator, + children::Tuple, + cache::AbstractOperatorSplittingCache, + tnext +) + _perform_step!(outer, children, cache, tnext) + + if outer.force_stepfail + outer.status = SplitSubIntegratorStatus(ReturnCode.Failure) + return + end + + # All children succeeded: advance this node's time and counter + outer.status = SplitSubIntegratorStatus(ReturnCode.Success) + + return +end + +@unroll function _perform_step!( + outer, + children::Tuple, + cache::LieTrotterGodunovCache, + tnext +) i = 0 - @unroll for subinteg in subintegrators + @unroll for child in children i += 1 - synchronizer = synchronizers[i] - idxs = solution_indices[i] - inner_cache = inner_caches[i] + idxs = outer.child_solution_indices[i] + sync = outer.child_synchronizers[i] - @timeit_debug "sync ->" forward_sync_subintegrator!( - outer_integrator, subinteg, idxs, synchronizer - ) - @timeit_debug "time solve" _advance_child!( - outer_integrator, subinteg, inner_cache, tnext - ) - if _child_failed(outer_integrator, subinteg) - sub.status = SplitSubIntegratorStatus(ReturnCode.Failure) + @timeit_debug "sync ->" forward_sync_subintegrator!(outer, child, idxs, sync) + @timeit_debug "time solve" _do_step!(outer, child, tnext) + if _child_failed(child) + outer.force_stepfail = true return end - backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer) + backward_sync_subintegrator!(outer, child, idxs, sync) end - # All children succeeded: mark this node as successful - sub.status = SplitSubIntegratorStatus(ReturnCode.Success) - # Accept the sub-step: copy u into uprev for potential future rollback - accept_step!(sub) - sub.t = tnext - sub.iter += 1 end # --------------------------------------------------------------------------- -# _advance_child!: dispatch on child type +# _do_step!: pure integration, no sync. +# The caller (advance_children_*) owns forward/backward sync around this. # --------------------------------------------------------------------------- -# Child is a SplitSubIntegrator: call its own advance_solution_to! -function _advance_child!( - outer_integrator::OperatorSplittingIntegrator, - child::SplitSubIntegrator, - _inner_cache, # ignored — child uses its own cache - tnext - ) - # Forward sync from the outer master u into this child's u - forward_sync_subintegrator!( - outer_integrator, child, child.solution_indices, NoExternalSynchronization() - ) - advance_solution_to!(outer_integrator, child, tnext) - backward_sync_subintegrator!( - outer_integrator, child, child.solution_indices, NoExternalSynchronization() - ) -end - -# Child is a leaf DEIntegrator -function _advance_child!( - outer_integrator::OperatorSplittingIntegrator, +# Leaf: DEIntegrator +function _do_step!( + outer::OperatorSplittingIntegrator, child::DEIntegrator, - _inner_cache, tnext ) dt = tnext - child.t SciMLBase.step!(child, dt, true) - # If the leaf adaptive integrator failed unrecoverably, error immediately + + # Unrecoverable failure: error immediately regardless of adaptive/non-adaptive if !SciMLBase.successful_retcode(child.sol.retcode) && child.sol.retcode != ReturnCode.Default - if isadaptive(child) - error("Adaptive inner integrator failed unrecoverably with retcode $(child.sol.retcode). Aborting.") - end - # non-adaptive failure: signal to parent + error("Inner integrator failed unrecoverably with retcode \ + $(child.sol.retcode) at t=$(child.t). Aborting.") + end + return nothing +end + +# Intermediate: SplitSubIntegrator — recurse +function _do_step!( + outer::OperatorSplittingIntegrator, + sub::SplitSubIntegrator, + tnext + ) + # Sync sub's children among themselves before recursing. + # (The parent already synced sub.u from master u via forward_sync before + # calling _do_step!; here we propagate sub.dt down to sub's own children.) + _sync_sub_children!(sub) + advance_solution_to!(outer, sub, tnext) + return nothing +end + +# Propagate time-step information from sub to its own children +function _sync_sub_children!(sub::SplitSubIntegrator) + _sync_sub_children_tuple!(sub.child_subintegrators, sub) +end + +@unroll function _sync_sub_children_tuple!(children::Tuple, parent::SplitSubIntegrator) + @unroll for child in children + _sync_child_to_sub_parent!(child, parent) + end +end + +function _sync_child_to_sub_parent!(child::DEIntegrator, parent::SplitSubIntegrator) + @assert child.t == parent.t "($(child.t) != $(parent.t))" + if !isadaptive(child) && child.dtchangeable + SciMLBase.set_proposed_dt!(child, parent.dt) + end +end + +function _sync_child_to_sub_parent!(child::SplitSubIntegrator, parent::SplitSubIntegrator) + @assert child.t == parent.t "($(child.t) != $(parent.t))" + if !isadaptive(child) + SciMLBase.set_proposed_dt!(child, parent.dt) end end # --------------------------------------------------------------------------- # _child_failed: check whether a child reported a failure # --------------------------------------------------------------------------- -function _child_failed(outer_integrator, child::DEIntegrator) - return child.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) -end +_child_failed(child::DEIntegrator) = + child.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) -function _child_failed(outer_integrator, child::SplitSubIntegrator) - return child.status.retcode ∉ (ReturnCode.Default, ReturnCode.Success) -end +_child_failed(child::SplitSubIntegrator) = + child.status.retcode ∉ (ReturnCode.Default, ReturnCode.Success) diff --git a/src/utils.jl b/src/utils.jl index b1c83cd..d87f000 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,6 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat) FT = typeof(tf) ordering = tf > t0 ? DataStructures.FasterForward : DataStructures.FasterReverse - # ensure that tstops includes tf and only has values ahead of t0 tstops = [filter(t -> t0 < t < tf || tf < t < t0, tstops)..., tf] tstops = DataStructures.BinaryHeap{FT, ordering}(tstops) @@ -25,212 +24,195 @@ end need_sync(a, b) Determines whether it is necessary to synchronize two objects with any -solution information. +solution information. A possible reason when no synchronization is necessary +might be that the vectors alias each other in memory. """ need_sync need_sync(a::AbstractVector, b::AbstractVector) = true -need_sync(a::SubArray, b::AbstractVector) = a.parent !== b -need_sync(a::AbstractVector, b::SubArray) = a !== b.parent -need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent +need_sync(a::SubArray, b::AbstractVector) = a.parent !== b +need_sync(a::AbstractVector, b::SubArray) = a !== b.parent +need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent """ sync_vectors!(a, b) -Copies the information in object `b` into object `a`, if synchronization is necessary. +Copies the information in `b` into `a` if synchronization is necessary. """ function sync_vectors!(a, b) - return if need_sync(a, b) && a !== b + if need_sync(a, b) && a !== b a .= b end + return nothing end # --------------------------------------------------------------------------- # forward_sync_subintegrator! +# +# The *parent* (OperatorSplittingIntegrator or SplitSubIntegrator) calls this +# before each child's step. It copies the relevant slice of the master +# solution into the child and applies any external parameter synchronisation. # --------------------------------------------------------------------------- -""" - forward_sync_subintegrator!(outer_integrator, inner, solution_indices, sync) -Copy state from the outer integrator into the inner integrator before a -sub-step, and apply any external parameter synchronisation via `sync`. -""" +# Parent = outermost OperatorSplittingIntegrator, child = DEIntegrator function forward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices, sync + parent::OperatorSplittingIntegrator, + child::DEIntegrator, + solution_indices, + sync ) - forward_sync_internal!(outer_integrator, inner_integrator, solution_indices) - return forward_sync_external!(outer_integrator, inner_integrator, sync) + _forward_sync_internal_leaf!(parent.u, child, solution_indices) + return forward_sync_external!(parent, child, sync) end -""" - forward_sync_subintegrator! for SplitSubIntegrator +# Parent = outermost OperatorSplittingIntegrator, child = SplitSubIntegrator +function forward_sync_subintegrator!( + parent::OperatorSplittingIntegrator, + child::SplitSubIntegrator, + solution_indices, + sync + ) + @views uparent = parent.u[solution_indices] + sync_vectors!(child.u, uparent) + sync_vectors!(child.uprev, uparent) + forward_sync_external!(parent, child, sync) + return nothing +end -When the inner node is a `SplitSubIntegrator` we only need to copy the master -solution vector slice into its `u` (the `SplitSubIntegrator.u` is already a -view, but on a different device or after a rollback it may need refreshing). -""" +# Parent = SplitSubIntegrator, child = DEIntegrator function forward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, solution_indices, sync - ) - # Sync the view: master → sub.u (noop if they already alias) - @views uouter = outer_integrator.u[solution_indices] - sync_vectors!(sub.u, uouter) - sync_vectors!(sub.uprev, uouter) - return forward_sync_external!(outer_integrator, sub, sync) + parent::SplitSubIntegrator, + child::DEIntegrator, + solution_indices, + sync + ) + # parent.u is this level's buffer; solution_indices are relative to + # the master u. We read from the master via u_master. + _forward_sync_internal_leaf!(parent.u_master, child, solution_indices) + forward_sync_external!(parent, child, sync) + return nothing +end + +# Parent = SplitSubIntegrator, child = SplitSubIntegrator +function forward_sync_subintegrator!( + parent::SplitSubIntegrator, + child::SplitSubIntegrator, + solution_indices, + sync + ) + @views umaster = parent.u_master[solution_indices] + sync_vectors!(child.u, umaster) + sync_vectors!(child.uprev, umaster) + forward_sync_external!(parent, child, sync) + return nothing +end + +# Shared internal helper: copy master u slice → leaf DEIntegrator u/uprev +function _forward_sync_internal_leaf!(u_source, child::DEIntegrator, solution_indices) + @views usrc = u_source[solution_indices] + sync_vectors!(child.uprev, usrc) + sync_vectors!(child.u, usrc) + SciMLBase.u_modified!(child, true) + return nothing end # --------------------------------------------------------------------------- # backward_sync_subintegrator! +# +# The *parent* calls this after each child's step to copy the child's updated +# state back into the master solution vector. # --------------------------------------------------------------------------- -""" - backward_sync_subintegrator!(outer_integrator, inner, solution_indices, sync) -Copy state from the inner integrator back into the outer integrator after a -sub-step, and apply any external parameter synchronisation via `sync`. -""" +# Parent = outermost OperatorSplittingIntegrator, child = DEIntegrator function backward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices, sync - ) - backward_sync_internal!(outer_integrator, inner_integrator, solution_indices) - return backward_sync_external!(outer_integrator, inner_integrator, sync) + parent::OperatorSplittingIntegrator, + child::DEIntegrator, + solution_indices, + sync + ) + @views uparent = parent.u[solution_indices] + sync_vectors!(uparent, child.u) + backward_sync_external!(parent, child, sync) + return nothing end +# Parent = outermost OperatorSplittingIntegrator, child = SplitSubIntegrator function backward_sync_subintegrator!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, solution_indices, sync + parent::OperatorSplittingIntegrator, + child::SplitSubIntegrator, + solution_indices, + sync ) - @views uouter = outer_integrator.u[solution_indices] - sync_vectors!(uouter, sub.u) - return backward_sync_external!(outer_integrator, sub, sync) + @views uparent = parent.u[solution_indices] + sync_vectors!(uparent, child.u) + return backward_sync_external!(parent, child, sync) end -# --------------------------------------------------------------------------- -# forward_sync_internal! / backward_sync_internal! -# --------------------------------------------------------------------------- -function forward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, solution_indices - ) - return nothing -end -function backward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, solution_indices +# Parent = SplitSubIntegrator, child = DEIntegrator +function backward_sync_subintegrator!( + parent::SplitSubIntegrator, + child::DEIntegrator, + solution_indices, + sync ) - return nothing + @views umaster = parent.u_master[solution_indices] + sync_vectors!(umaster, child.u) + # Also keep parent.u consistent + @views ulocal = parent.u[solution_indices .- first(parent.solution_indices) .+ 1] + sync_vectors!(ulocal, child.u) + return backward_sync_external!(parent, child, sync) end -function forward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices - ) - @views uouter = outer_integrator.u[solution_indices] - sync_vectors!(inner_integrator.uprev, uouter) - sync_vectors!(inner_integrator.u, uouter) - return SciMLBase.u_modified!(inner_integrator, true) -end -function backward_sync_internal!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, solution_indices +# Parent = SplitSubIntegrator, child = SplitSubIntegrator +function backward_sync_subintegrator!( + parent::SplitSubIntegrator, + child::SplitSubIntegrator, + solution_indices, + sync ) - @views uouter = outer_integrator.u[solution_indices] - return sync_vectors!(uouter, inner_integrator.u) + @views umaster = parent.u_master[solution_indices] + sync_vectors!(umaster, child.u) + @views ulocal = parent.u[solution_indices .- first(parent.solution_indices) .+ 1] + sync_vectors!(ulocal, child.u) + return backward_sync_external!(parent, child, sync) end # --------------------------------------------------------------------------- # forward_sync_external! / backward_sync_external! +# These handle parameter synchronisation via the `sync` object. # --------------------------------------------------------------------------- -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -# SplitSubIntegrator has no parameters for now → no-op -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync - ) - return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) -end -function forward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, sync - ) - # SplitSubIntegrator does not carry p for now; dispatch on sync type if needed - return nothing -end -function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::OperatorSplittingIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, sync::NoExternalSynchronization - ) - return nothing -end -function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - inner_integrator::DEIntegrator, sync +# NoExternalSynchronization: no-op for all parent/child combinations +forward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +backward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +forward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +backward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing + +# OperatorSplittingIntegrator parent with DEIntegrator child: parameter sync +function forward_sync_external!( + parent::OperatorSplittingIntegrator, + child::DEIntegrator, + sync ) - return synchronize_solution_with_parameters!(outer_integrator, inner_integrator.p, sync) + return synchronize_solution_with_parameters!(parent, child.p, sync) end function backward_sync_external!( - outer_integrator::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, sync + parent::OperatorSplittingIntegrator, + child::DEIntegrator, + sync ) - return nothing + return synchronize_solution_with_parameters!(parent, child.p, sync) end + function synchronize_solution_with_parameters!( - outer_integrator::OperatorSplittingIntegrator, p, sync + parent::OperatorSplittingIntegrator, p, sync ) @warn "Outer synchronizer not dispatched for parameter type $(typeof(p)) with synchronizer type $(typeof(sync))." maxlog = 1 return nothing end function synchronize_solution_with_parameters!( - outer_integrator::OperatorSplittingIntegrator, p::NullParameters, sync + parent::OperatorSplittingIntegrator, ::NullParameters, sync ) return nothing end - -# --------------------------------------------------------------------------- -# NOTE: build_solution_index_tree and build_synchronizer_tree are NO LONGER -# needed as standalone functions — the information is now embedded directly -# into each SplitSubIntegrator during build_subintegrator_tree_with_cache. -# They are kept here (no-ops returning nothing) only so that any external -# code that might call them does not hard-error. -# --------------------------------------------------------------------------- -function build_solution_index_tree(f::GenericSplitFunction) - # Deprecated: solution index trees now live inside SplitSubIntegrator. - return nothing -end - -function build_synchronizer_tree(f::GenericSplitFunction) - # Deprecated: synchronizer trees now live inside SplitSubIntegrator. - return nothing -end diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 57a514c..f0125db 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -215,7 +215,6 @@ end # SplitSubIntegrators now carry t and iter at each level sub1 = integrator.subintegrator_tree[1] - @test sub1 isa OS.SplitSubIntegrator @test sub1.t ≈ tspan[2] @test sub1.iter == nsteps From 6bdb2c786b777177dfe5b156dd8344d1f1e2a50e Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 20 Feb 2026 23:30:35 +0100 Subject: [PATCH 03/17] Cut down some redundancies and fix tstop handling. --- src/OrdinaryDiffEqOperatorSplitting.jl | 4 +- src/function.jl | 15 +- src/integrator.jl | 321 ++++++++++++++++++------- src/precompilation.jl | 23 +- src/problem.jl | 2 +- src/solver.jl | 126 +--------- src/utils.jl | 51 +++- test/operator_splitting_api.jl | 32 +-- 8 files changed, 324 insertions(+), 250 deletions(-) diff --git a/src/OrdinaryDiffEqOperatorSplitting.jl b/src/OrdinaryDiffEqOperatorSplitting.jl index 7647074..a386d75 100644 --- a/src/OrdinaryDiffEqOperatorSplitting.jl +++ b/src/OrdinaryDiffEqOperatorSplitting.jl @@ -11,13 +11,15 @@ import SciMLBase: DEIntegrator, NullParameters, isadaptive import RecursiveArrayTools -import OrdinaryDiffEqCore +import OrdinaryDiffEqCore: OrdinaryDiffEqCore, isdtchangeable, + stepsize_controller!, step_accept_controller!, step_reject_controller! abstract type AbstractOperatorSplitFunction <: SciMLBase.AbstractODEFunction{true} end abstract type AbstractOperatorSplittingAlgorithm end abstract type AbstractOperatorSplittingCache end @inline SciMLBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false +@inline isdtchangeable(alg::AbstractOperatorSplittingAlgorithm) = all(isdtchangeable.(alg.inner_algs)) include("function.jl") include("problem.jl") diff --git a/src/function.jl b/src/function.jl index 31e18ac..6128478 100644 --- a/src/function.jl +++ b/src/function.jl @@ -13,11 +13,24 @@ struct GenericSplitFunction{fSetType <: Tuple, idxSetType <: Tuple, sSetType <: # Operators to update the ode function parameters. synchronizers::sSetType function GenericSplitFunction(fs::Tuple, drs::Tuple, syncers::Tuple) - @assert length(fs) == length(drs) == length(syncers) + @assert length(fs) == length(drs) == length(syncers) "Number of input tuples does not match." + gsf_recursive_function_type_safety_check.(fs) return new{typeof(fs), typeof(drs), typeof(syncers)}(fs, drs, syncers) end end +function gsf_recursive_function_type_safety_check(f::GenericSplitFunction) + gsf_recursive_function_type_safety_check.(f.functions) +end + +function gsf_recursive_function_type_safety_check(dunno) + error("Failed to construct GenericSplitFunction. One of the inner functions is of type $(typeof(dunno)) which is not a subtype of SciMLBase.AbstractDiffEqFunction.") +end + +function gsf_recursive_function_type_safety_check(::SciMLBase.AbstractDiffEqFunction) + # OK +end + num_operators(f::GenericSplitFunction) = length(f.functions) """ diff --git a/src/integrator.jl b/src/integrator.jl index 4b66e55..1996483 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -35,7 +35,7 @@ SplitSubIntegratorStatus() = SplitSubIntegratorStatus(ReturnCode.Default) # SplitSubIntegrator # --------------------------------------------------------------------------- """ - SplitSubIntegrator + SplitSubIntegrator <: AbstractODEIntegrator An intermediate node in the operator-splitting subintegrator tree. @@ -77,29 +77,38 @@ mutable struct SplitSubIntegrator{ solidxType, childSolidxType, childSyncType, - } + optionsType + } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} alg::algType u::uType # local solution buffer uprev::uType # local rollback buffer u_master::uType # reference to outermost master u t::tType + tprev::tType dt::tType dtcache::tType + const dtchangeable::Bool iter::Int EEst::EEstType controller::controllerType force_stepfail::Bool last_step_failed::Bool + u_modified::Bool # TODO we can probably remove this status::SplitSubIntegratorStatus + stats::IntegratorStats cache::cacheType child_subintegrators::childSubintType # Tuple solution_indices::solidxType child_solution_indices::childSolidxType # Tuple child_synchronizers::childSyncType # Tuple + opts::optionsType end # --- SplitSubIntegrator interface --- +tdir(integrator::SplitSubIntegrator) = sign(integrator.dt) + +@inline SciMLBase.has_tstop(::SplitSubIntegrator) = false @inline SciMLBase.isadaptive(sub::SplitSubIntegrator) = isadaptive(sub.alg) # proposed-dt interface (mirrors ODEIntegrator) @@ -121,11 +130,6 @@ end A variant of [`ODEIntegrator`](https://github.com/SciML/OrdinaryDiffEq.jl/blob/6ec5a55bda26efae596bf99bea1a1d729636f412/src/integrators/type.jl#L77-L123) to perform operator splitting. - -Derived from https://github.com/CliMA/ClimaTimeSteppers.jl/blob/ef3023747606d2750e674d321413f80638136632/src/integrators.jl. - -Note: `solution_index_tree` and `synchronizer_tree` have been removed; this -information now lives inside each [`SplitSubIntegrator`](@ref) child node. """ mutable struct OperatorSplittingIntegrator{ fType, @@ -179,11 +183,7 @@ mutable struct OperatorSplittingIntegrator{ tdir::tType end -# Convenience: the old field name `subintegrator_tree` was used in tests and -# docs; alias it so external code still compiles during the transition. -# (Remove in a future breaking release.) -@inline Base.getproperty(i::OperatorSplittingIntegrator, s::Symbol) = - s === :subintegrator_tree ? getfield(i, :child_subintegrators) : getfield(i, s) +const AnySplitIntegrator = Union{SplitSubIntegrator, OperatorSplittingIntegrator} # --------------------------------------------------------------------------- # __init @@ -216,7 +216,7 @@ function SciMLBase.__init( (!isadaptive(alg) && adaptive && verbose) && @warn("The algorithm $alg is not adaptive.") - dtchangeable = true + dtchangeable = isdtchangeable(alg) if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number _tstops = nothing @@ -436,24 +436,24 @@ function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrato return nothing end -notify_integrator_hit_tstop!(integrator::OperatorSplittingIntegrator) = nothing +notify_integrator_hit_tstop!(integrator::AnySplitIntegrator) = nothing -is_first_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter == 0 -increment_iteration(integrator::OperatorSplittingIntegrator) = integrator.iter += 1 +is_first_iteration(integrator::AnySplitIntegrator) = integrator.iter == 0 +increment_iteration(integrator::AnySplitIntegrator) = integrator.iter += 1 # --------------------------------------------------------------------------- -# Step accept/reject — outermost integrator +# Step accept/reject # --------------------------------------------------------------------------- -function reject_step!(integrator::OperatorSplittingIntegrator) +function reject_step!(integrator::AnySplitIntegrator) OrdinaryDiffEqCore.increment_reject!(integrator.stats) return reject_step!(integrator, integrator.cache, integrator.controller) end -function reject_step!(integrator::OperatorSplittingIntegrator, cache, controller) +function reject_step!(integrator::AnySplitIntegrator, cache, controller) integrator.u .= integrator.uprev # TODO: roll back sub-integrators return nothing end -function reject_step!(integrator::OperatorSplittingIntegrator, cache, ::Nothing) +function reject_step!(integrator::AnySplitIntegrator, cache, ::Nothing) if length(integrator.uprev) == 0 error("Cannot roll back integrator. Aborting time integration step at $(integrator.t).") end @@ -464,62 +464,55 @@ function should_accept_step(integrator::OperatorSplittingIntegrator) integrator.force_stepfail || integrator.isout && return false return should_accept_step(integrator, integrator.cache, integrator.controller) end -function should_accept_step(integrator::OperatorSplittingIntegrator, cache, ::Nothing) +function should_accept_step(integrator::SplitSubIntegrator) + integrator.force_stepfail && return false + return should_accept_step(integrator, integrator.cache, integrator.controller) +end +function should_accept_step(integrator::AnySplitIntegrator, cache, ::Nothing) return !(integrator.force_stepfail) end -function accept_step!(integrator::OperatorSplittingIntegrator) +function accept_step!(integrator::AnySplitIntegrator) OrdinaryDiffEqCore.increment_accept!(integrator.stats) return accept_step!(integrator, integrator.cache, integrator.controller) end -function accept_step!(integrator::OperatorSplittingIntegrator, cache, controller) +function accept_step!(integrator::AnySplitIntegrator, cache, controller) return store_previous_info!(integrator) end -function store_previous_info!(integrator::OperatorSplittingIntegrator) +function store_previous_info!(integrator::AnySplitIntegrator) if length(integrator.uprev) > 0 update_uprev!(integrator) end return nothing end -function update_uprev!(integrator::OperatorSplittingIntegrator) +function update_uprev!(integrator::AnySplitIntegrator) RecursiveArrayTools.recursivecopy!(integrator.uprev, integrator.u) return nothing end -# Step accept/reject — SplitSubIntegrator -function accept_step!(sub::SplitSubIntegrator) - RecursiveArrayTools.recursivecopy!(sub.uprev, sub.u) - return nothing -end -function reject_step!(sub::SplitSubIntegrator) - sub.u .= sub.uprev - _rollback_children!(sub.child_subintegrators, sub.u_master) - return nothing -end - # Roll back each child's local buffer to match master u. # For DEIntegrators the leaf will be re-synced via forward_sync before the # next attempt, so there is nothing to do here. -@unroll function _rollback_children!(children::Tuple, u_master) +@unroll function rollback_children!(children::Tuple, u_master) @unroll for child in children - _rollback_child!(child, u_master) + rollback_child!(child, u_master) end end -function _rollback_child!(child::SplitSubIntegrator, u_master) +function rollback_child!(child::SplitSubIntegrator, u_master) child.u .= @view u_master[child.solution_indices] RecursiveArrayTools.recursivecopy!(child.uprev, child.u) _rollback_children!(child.child_subintegrators, u_master) return nothing end -function _rollback_child!(child::DEIntegrator, u_master) +function rollback_child!(child::DEIntegrator, u_master) # forward_sync before the next sub-step will restore this correctly. return nothing end # --------------------------------------------------------------------------- -# step_header! / step_footer! — outermost integrator +# step_header! / step_footer! # --------------------------------------------------------------------------- -function step_header!(integrator::OperatorSplittingIntegrator) +function step_header!(integrator::AnySplitIntegrator) if !is_first_iteration(integrator) if should_accept_step(integrator) accept_step!(integrator) @@ -532,15 +525,20 @@ function step_header!(integrator::OperatorSplittingIntegrator) increment_iteration(integrator) OrdinaryDiffEqCore.fix_dt_at_bounds!(integrator) OrdinaryDiffEqCore.modify_dt_for_tstops!(integrator) - return integrator.force_stepfail = false + integrator.force_stepfail = false + return nothing end function footer_reset_flags!(integrator) - return integrator.u_modified = false + integrator.u_modified = false + return end +footer_reset_flags!(::SplitSubIntegrator) = nothing function setup_validity_flags!(integrator, t_next) - return integrator.isout = false + integrator.isout = false + return end +setup_validity_flags!(::SplitSubIntegrator, _) = nothing function fix_solution_buffer_sizes!(integrator, sol) resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) @@ -550,14 +548,49 @@ function fix_solution_buffer_sizes!(integrator, sol) return nothing end -function step_footer!(integrator::OperatorSplittingIntegrator) +function OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator::OperatorSplittingIntegrator, ttmp) + return if DiffEqBase.has_tstop(integrator) + tstop = integrator.tdir * DiffEqBase.first_tstop(integrator) + if abs(ttmp - tstop) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * + oneunit(integrator.t) + # We have to update the floating point errors of the subintegrator nodes, because + # they do not have the tstop logic. + try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) + tstop + else + ttmp + end + else + ttmp + end +end +function try_snap_children_to_tstop!(integrator::SplitSubIntegrator, tstop) + if abs(tstop - integrator.t) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) + integrator.t = tstop + else + @warn "Failed to snap timestep for integrator $(integrator.t) with parent integrator hitting the tstop $(tstop)." + end + try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) +end +function try_snap_children_to_tstop!(integrator::DEIntegrator, tstop) + if abs(tstop - integrator.t) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) + integrator.t = tstop + else + @warn "Failed to snap timestep for integrator $(integrator.t) with parent integrator hitting the tstop $(tstop)." + end +end + +function step_footer!(integrator::AnySplitIntegrator) ttmp = integrator.t + tdir(integrator) * integrator.dt footer_reset_flags!(integrator) setup_validity_flags!(integrator, ttmp) if should_accept_step(integrator) integrator.last_step_failed = false integrator.tprev = integrator.t - integrator.t = ttmp + integrator.t = OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) step_accept_controller!(integrator) elseif integrator.force_stepfail if isadaptive(integrator) @@ -570,6 +603,7 @@ function step_footer!(integrator::OperatorSplittingIntegrator) end integrator.last_step_failed = true end + validate_time_point(integrator) return nothing end @@ -652,6 +686,25 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_t step_footer!(integrator) end end + OrdinaryDiffEqCore.handle_tstop!(integrator) + return nothing +end + +function DiffEqBase.step!(integrator::SplitSubIntegrator, dt, stop_at_tdt = false) + @timeit_debug "step!" begin + dt <= zero(dt) && error("dt must be positive") + stop_at_tdt && !integrator.dtchangeable && + error("Cannot stop at t + dt if dtchangeable is false") + tnext = integrator.t + tdir(integrator) * dt + while !reached_tstop(integrator, tnext, stop_at_tdt) + step_header!(integrator) + @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( + ReturnCode.Success, ReturnCode.Default, + ) && return + __step!(integrator) + step_footer!(integrator) + end + end return nothing end @@ -671,6 +724,25 @@ function SciMLBase.check_error(integrator::OperatorSplittingIntegrator) return _check_error_children(integrator.sol.retcode, integrator.child_subintegrators) end +function SciMLBase.check_error(integrator::SplitSubIntegrator) + if !SciMLBase.successful_retcode(integrator.status.retcode) && + integrator.status.retcode != ReturnCode.Default + return integrator.status.retcode + end + if DiffEqBase.NAN_CHECK(integrator.dtcache) || DiffEqBase.NAN_CHECK(integrator.dt) + integrator.opts.verbose && + @warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.") + return ReturnCode.DtNaN + end + return _check_error_children(integrator.status.retcode, integrator.child_subintegrators) +end + +function SciMLBase.check_error!(integrator::SplitSubIntegrator) + code = SciMLBase.check_error(integrator) + integrator.status.retcode = code + return code +end + @unroll function _check_error_children(current_retcode, children::Tuple) @unroll for child in children rc = _child_retcode(child) @@ -704,16 +776,18 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t) ) end -# Stepsize controller hooks — outermost integrator -@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator) +# Stepsize controller hooks +@inline function stepsize_controller!(integrator::AnySplitIntegrator) isadaptive(integrator.alg) || return nothing - return stepsize_controller!(integrator, integrator.alg) + stepsize_controller!(integrator, integrator.alg) + return nothing end -@inline function step_accept_controller!(integrator::OperatorSplittingIntegrator) +@inline function step_accept_controller!(integrator::AnySplitIntegrator) isadaptive(integrator.alg) || return nothing - return step_accept_controller!(integrator, integrator.alg, nothing) + step_accept_controller!(integrator, integrator.alg, nothing) + return nothing end -@inline function step_reject_controller!(integrator::OperatorSplittingIntegrator) +@inline function step_reject_controller!(integrator::AnySplitIntegrator) isadaptive(integrator.alg) || return nothing return step_reject_controller!(integrator, integrator.alg, nothing) end @@ -727,7 +801,7 @@ function reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeabl if stop_at_tstop integrator.t > tstop && error("Integrator missed stop at $tstop (current time=$(integrator.t)). Aborting.") - return integrator.t == tstop + return integrator.t ≈ tstop else return is_past_t(integrator, tstop) end @@ -747,59 +821,106 @@ function SciMLBase.postamble!(integrator::OperatorSplittingIntegrator) return DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator) end -function __step!(integrator::OperatorSplittingIntegrator) - tnext = integrator.t + integrator.dt - _sync_children!(integrator) - advance_solution_to!(integrator, tnext) +function __step!(integrator::AnySplitIntegrator) + advance_solution_by!(integrator, integrator.dt) stepsize_controller!(integrator) return nothing end -# Sync all direct children of the outermost integrator -function _sync_children!(integrator::OperatorSplittingIntegrator) - _sync_children_tuple!(integrator.child_subintegrators, integrator) +# Entry point: dispatch to the algorithm's advance_solution_by! +function advance_solution_by!(integrator::AnySplitIntegrator, dt) + return advance_solution_by!(integrator, integrator.cache, dt) end -@unroll function _sync_children_tuple!( +# Algorithm-level dispatch (implemented in solver.jl per algorithm) +function advance_solution_by!( + integrator::AnySplitIntegrator, + cache::AbstractOperatorSplittingCache, dt + ) + return advance_solution_by!( + integrator, integrator.child_subintegrators, cache, dt + ) +end + +# --------------------------------------------------------------------------- +# advance_solution_by! for a SplitSubIntegrator node +# +# The SplitSubIntegrator is now the *parent* for its own children. +# It carries child_solution_indices and child_synchronizers directly. +# +# Entry point called from integrator.jl for a SplitSubIntegrator node +# --------------------------------------------------------------------------- +function advance_solution_by!( + outer::OperatorSplittingIntegrator, children::Tuple, - parent::OperatorSplittingIntegrator + cache::AbstractOperatorSplittingCache, + dt ) - @unroll for child in children - _sync_child_to_parent!(child, parent) + _perform_step!(outer, children, cache, dt) + + if outer.force_stepfail && all(isadaptive.(children)) + # We do not know recover at this point, as an decrease in the solve + # interval is unlikely to help here. + outer.sol = SciMLBase.solution_new_retcode( + outer.sol, + ReturnCode.Failure + ) + return end + + return end -function _sync_child_to_parent!(child::DEIntegrator, parent::OperatorSplittingIntegrator) - @assert child.t == parent.t "($(child.t) != $(parent.t))" - if !isadaptive(child) && child.dtchangeable - SciMLBase.set_proposed_dt!(child, parent.dt) +function advance_solution_by!( + outer::SplitSubIntegrator, + children::Tuple, + cache::AbstractOperatorSplittingCache, + dt +) + _perform_step!(outer, children, cache, dt) + + if outer.force_stepfail + outer.status = SplitSubIntegratorStatus(ReturnCode.Failure) + return end + + # All children succeeded: advance this node's time and counter + outer.status = SplitSubIntegratorStatus(ReturnCode.Success) + + return end -function _sync_child_to_parent!( - child::SplitSubIntegrator, parent::OperatorSplittingIntegrator +# Recursion dispatch +function advance_solution_by!( + outer::AnySplitIntegrator, + sub::SplitSubIntegrator, + dt ) - @assert child.t == parent.t "($(child.t) != $(parent.t))" - if !isadaptive(child) - SciMLBase.set_proposed_dt!(child, parent.dt) + SciMLBase.step!(sub, dt, true) + + # Unrecoverable failure: error immediately regardless of adaptive/non-adaptive + if !SciMLBase.successful_retcode(sub.status.retcode) && + sub.status.retcode != ReturnCode.Default + error("Inner integrator failed unrecoverably with retcode \ + $(sub.status.retcode) at t=$(child.t). Aborting.") end + return nothing end -# Entry point: dispatch to the algorithm's advance_solution_to! -function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext) - return advance_solution_to!(integrator, integrator.cache, tnext) -end +# Leaf disptach +function advance_solution_by!(outer::AnySplitIntegrator, child::DEIntegrator, dt) + SciMLBase.step!(child, dt, true) -# Algorithm-level dispatch (implemented in solver.jl per algorithm) -function advance_solution_to!( - integrator::OperatorSplittingIntegrator, - cache::AbstractOperatorSplittingCache, tnext::Number - ) - return advance_solution_to!( - integrator, integrator.child_subintegrators, cache, tnext - ) + # Unrecoverable failure: error immediately regardless of adaptive/non-adaptive + if !SciMLBase.successful_retcode(child.sol.retcode) && + child.sol.retcode != ReturnCode.Default + error("Inner integrator failed unrecoverably with retcode \ + $(child.sol.retcode) at t=$(child.t). Aborting.") + end + return nothing end + # --------------------------------------------------------------------------- # Tree construction # --------------------------------------------------------------------------- @@ -904,17 +1025,20 @@ function _build_child( RecursiveArrayTools.recursivecopy(Array(u_sub)), RecursiveArrayTools.recursivecopy(Array(u_sub)), u_master, - t0, dt, dt, # t, dt, dtcache + t0, t0, dt, dt, # t, tprev, dt, dtcache + isdtchangeable(alg), 0, # iter EEst_val, controller, - false, false, # force_stepfail, last_step_failed + false, false, false, # force_stepfail, last_step_failed, u_modified SplitSubIntegratorStatus(), + IntegratorStats(), level_cache, child_subintegrators, solution_indices, child_solution_indices, - child_synchronizers + child_synchronizers, + IntegratorOptions(; verbose, adaptive), ) return sub, level_cache @@ -937,14 +1061,15 @@ function _build_child( ) where {S, T, P, F} u = @view uouter[solution_indices] prob2 = if p isa NullParameters - SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf))) + SciMLBase.ODEProblem(f, u, (t0, tf)) else - SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf)), p) + SciMLBase.ODEProblem(f, u, (t0, tf), p) end integrator = SciMLBase.__init( prob2, alg; dt, + tstops, saveat = (), d_discontinuities, save_everystep = false, @@ -973,6 +1098,20 @@ function DiffEqBase.add_tstop!(i::OperatorSplittingIntegrator, t) is_past_t(i, t) && error("Cannot add a tstop at $t because that is behind the current \ integrator time $(i.t)") + DiffEqBase.add_tstop!.(i.child_subintegrators, t) + push!(i.tstops, t) + return nothing +end + +function DiffEqBase.add_tstop!(i::SplitSubIntegrator, t) + DiffEqBase.add_tstop!.(i.child_subintegrators, t) +end + +function _add_tstop!(i::OperatorSplittingIntegrator, t) + is_past_t(i, t) && + error("Cannot add a tstop at $t because that is behind the current \ + integrator time $(i.t)") + _add_tstop!(i.child_subintegrators, t) return push!(i.tstops, t) end diff --git a/src/precompilation.jl b/src/precompilation.jl index 3514554..11dc2e8 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -1,13 +1,21 @@ using PrecompileTools: @compile_workload using OrdinaryDiffEqLowOrderRK: Euler +import DiffEqBase: DiffEqBase, step!, solve! function _precompile_ode1(du, u, p, t) - return @. du = -0.1u + @. du = -0.1u + return end function _precompile_ode2(du, u, p, t) du[1] = -0.01u[2] - return du[2] = -0.01u[1] + du[2] = -0.01u[1] + return +end + +function _precompile_ode3(du, u, p, t) + du[1] = -0.01u[2] + du[2] = -0.01u[1] end @compile_workload begin @@ -17,16 +25,19 @@ end f1 = DiffEqBase.ODEFunction(_precompile_ode1) f2 = DiffEqBase.ODEFunction(_precompile_ode2) + f3 = DiffEqBase.ODEFunction(_precompile_ode3) f1dofs = [1, 2, 3] f2dofs = [1, 3] - fsplit = GenericSplitFunction((f1, f2), (f1dofs, f2dofs)) + f3dofs = [2, 3] + fsplitinner = GenericSplitFunction((f2, f3), (f2dofs, f3dofs)) + fsplit = GenericSplitFunction((f1, fsplitinner), (f1dofs, [1,2,3])) prob = OperatorSplittingProblem(fsplit, u0, tspan) - tstepper = LieTrotterGodunov((Euler(), Euler())) + tstepper = LieTrotterGodunov((Euler(), LieTrotterGodunov((Euler(), Euler())))) # Precompile init and a few steps integrator = DiffEqBase.init(prob, tstepper, dt = 0.01, verbose = false) - DiffEqBase.step!(integrator) - DiffEqBase.solve!(integrator) + # step!(integrator) + # solve!(integrator) end diff --git a/src/problem.jl b/src/problem.jl index 59e5b1d..d3d8539 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -35,6 +35,6 @@ recursive_null_parameters(f::AbstractOperatorSplitFunction) = @error "Not implem function recursive_null_parameters(f::GenericSplitFunction) return ntuple(i -> recursive_null_parameters(get_operator(f, i)), length(f.functions)) end -function recursive_null_parameters(f) # Wildcard for leafs +function recursive_null_parameters(f::SciMLBase.AbstractDiffEqFunction) # Wildcard for leafs return NullParameters() end diff --git a/src/solver.jl b/src/solver.jl index 8b219a4..0a7b969 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -29,62 +29,11 @@ function init_cache( return LieTrotterGodunovCache(_u, _uprev, inner_caches) end -# --------------------------------------------------------------------------- -# advance_solution_to! for a SplitSubIntegrator node -# -# The SplitSubIntegrator is now the *parent* for its own children. -# It carries child_solution_indices and child_synchronizers directly. -# -# Entry point called from integrator.jl for a SplitSubIntegrator node -# --------------------------------------------------------------------------- -function advance_solution_to!( - outer::OperatorSplittingIntegrator, - children::Tuple, - cache::AbstractOperatorSplittingCache, - tnext - ) - _perform_step!(outer, children, cache, tnext) - - if outer.force_stepfail - outer.sol = SciMLBase.solution_new_retcode( - outer.sol, - ReturnCode.Failure - ) - return - end - - # All children succeeded: advance this node's time and counter - # outer.sol = SciMLBase.solution_new_retcode( - # outer.sol, - # ReturnCode.Success - # ) - return -end - -function advance_solution_to!( - outer::SplitSubIntegrator, - children::Tuple, - cache::AbstractOperatorSplittingCache, - tnext -) - _perform_step!(outer, children, cache, tnext) - - if outer.force_stepfail - outer.status = SplitSubIntegratorStatus(ReturnCode.Failure) - return - end - - # All children succeeded: advance this node's time and counter - outer.status = SplitSubIntegratorStatus(ReturnCode.Success) - - return -end - @unroll function _perform_step!( outer, children::Tuple, cache::LieTrotterGodunovCache, - tnext + dt ) i = 0 @unroll for child in children @@ -93,7 +42,7 @@ end sync = outer.child_synchronizers[i] @timeit_debug "sync ->" forward_sync_subintegrator!(outer, child, idxs, sync) - @timeit_debug "time solve" _do_step!(outer, child, tnext) + @timeit_debug "time solve" advance_solution_by!(outer, child, dt) if _child_failed(child) outer.force_stepfail = true return @@ -101,74 +50,3 @@ end backward_sync_subintegrator!(outer, child, idxs, sync) end end - -# --------------------------------------------------------------------------- -# _do_step!: pure integration, no sync. -# The caller (advance_children_*) owns forward/backward sync around this. -# --------------------------------------------------------------------------- - -# Leaf: DEIntegrator -function _do_step!( - outer::OperatorSplittingIntegrator, - child::DEIntegrator, - tnext - ) - dt = tnext - child.t - SciMLBase.step!(child, dt, true) - - # Unrecoverable failure: error immediately regardless of adaptive/non-adaptive - if !SciMLBase.successful_retcode(child.sol.retcode) && - child.sol.retcode != ReturnCode.Default - error("Inner integrator failed unrecoverably with retcode \ - $(child.sol.retcode) at t=$(child.t). Aborting.") - end - return nothing -end - -# Intermediate: SplitSubIntegrator — recurse -function _do_step!( - outer::OperatorSplittingIntegrator, - sub::SplitSubIntegrator, - tnext - ) - # Sync sub's children among themselves before recursing. - # (The parent already synced sub.u from master u via forward_sync before - # calling _do_step!; here we propagate sub.dt down to sub's own children.) - _sync_sub_children!(sub) - advance_solution_to!(outer, sub, tnext) - return nothing -end - -# Propagate time-step information from sub to its own children -function _sync_sub_children!(sub::SplitSubIntegrator) - _sync_sub_children_tuple!(sub.child_subintegrators, sub) -end - -@unroll function _sync_sub_children_tuple!(children::Tuple, parent::SplitSubIntegrator) - @unroll for child in children - _sync_child_to_sub_parent!(child, parent) - end -end - -function _sync_child_to_sub_parent!(child::DEIntegrator, parent::SplitSubIntegrator) - @assert child.t == parent.t "($(child.t) != $(parent.t))" - if !isadaptive(child) && child.dtchangeable - SciMLBase.set_proposed_dt!(child, parent.dt) - end -end - -function _sync_child_to_sub_parent!(child::SplitSubIntegrator, parent::SplitSubIntegrator) - @assert child.t == parent.t "($(child.t) != $(parent.t))" - if !isadaptive(child) - SciMLBase.set_proposed_dt!(child, parent.dt) - end -end - -# --------------------------------------------------------------------------- -# _child_failed: check whether a child reported a failure -# --------------------------------------------------------------------------- -_child_failed(child::DEIntegrator) = - child.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) - -_child_failed(child::SplitSubIntegrator) = - child.status.retcode ∉ (ReturnCode.Default, ReturnCode.Success) diff --git a/src/utils.jl b/src/utils.jl index d87f000..43263b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -61,8 +61,9 @@ function forward_sync_subintegrator!( solution_indices, sync ) - _forward_sync_internal_leaf!(parent.u, child, solution_indices) - return forward_sync_external!(parent, child, sync) + forward_sync_internal!(parent.u, child, solution_indices) + forward_sync_external!(parent, child, sync) + return nothing end # Parent = outermost OperatorSplittingIntegrator, child = SplitSubIntegrator @@ -88,7 +89,7 @@ function forward_sync_subintegrator!( ) # parent.u is this level's buffer; solution_indices are relative to # the master u. We read from the master via u_master. - _forward_sync_internal_leaf!(parent.u_master, child, solution_indices) + forward_sync_internal!(parent.u_master, child, solution_indices) forward_sync_external!(parent, child, sync) return nothing end @@ -108,7 +109,7 @@ function forward_sync_subintegrator!( end # Shared internal helper: copy master u slice → leaf DEIntegrator u/uprev -function _forward_sync_internal_leaf!(u_source, child::DEIntegrator, solution_indices) +function forward_sync_internal!(u_source, child::DEIntegrator, solution_indices) @views usrc = u_source[solution_indices] sync_vectors!(child.uprev, usrc) sync_vectors!(child.u, usrc) @@ -216,3 +217,45 @@ function synchronize_solution_with_parameters!( ) return nothing end + +# Time stuff +function OrdinaryDiffEqCore.fix_dt_at_bounds!(integrator::AnySplitIntegrator) + if tdir(integrator) > 0 + integrator.dt = min(integrator.opts.dtmax, integrator.dt) + else + integrator.dt = max(integrator.opts.dtmax, integrator.dt) + end + dtmin = OrdinaryDiffEqCore.timedepentdtmin(integrator) + if tdir(integrator) > 0 + integrator.dt = max(integrator.dt, dtmin) + else + integrator.dt = min(integrator.dt, dtmin) + end + return nothing +end + +# Check time-step information consistency +validate_time_point(integrator::AnySplitIntegrator) = validate_time_point(integrator, integrator.child_subintegrators) +function validate_time_point(parent, child::SplitSubIntegrator) + @assert parent.t == child.t "(parent.t=$(parent.t) != child.t=$(child.t))" + validate_time_point(child, child.child_subintegrators) +end + +@unroll function validate_time_point(parent, children::Tuple) + @unroll for child in children + validate_time_point(parent, child) + end +end + +function validate_time_point(parent, child::DEIntegrator) + @assert child.t == parent.t "(parent.t=$(parent.t) != child.t=$(child.t))" +end + +# --------------------------------------------------------------------------- +# _child_failed: check whether a child reported a failure +# --------------------------------------------------------------------------- +_child_failed(child::DEIntegrator) = + child.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success) + +_child_failed(child::SplitSubIntegrator) = + child.status.retcode ∉ (ReturnCode.Default, ReturnCode.Success) diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index f0125db..641629b 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -6,7 +6,7 @@ import SciMLBase: ReturnCode import DiffEqBase: DiffEqBase, ODEFunction, ODEProblem using OrdinaryDiffEqLowOrderRK using OrdinaryDiffEqTsit5 -using ModelingToolkit +using ModelingToolkit, SciCompDSL # --------------------------------------------------------------------------- # Reference problem @@ -120,47 +120,35 @@ end cache::FakeAdaptiveAlgorithmCache ) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) -@inline function OS.advance_solution_to!( +@inline function OS.advance_solution_by!( outer_integrator::OS.OperatorSplittingIntegrator, subintegrators::Tuple, cache::FakeAdaptiveAlgorithmCache, - tnext + dt ) - return OS.advance_solution_to!( - outer_integrator, subintegrators, cache.cache, tnext + return OS.advance_solution_by!( + outer_integrator, subintegrators, cache.cache, dt ) end # For SplitSubIntegrator nodes whose cache was wrapped by FakeAdaptiveAlgorithmCache -@inline function OS.advance_solution_to!( +@inline function OS.advance_solution_by!( outer_integrator::OS.OperatorSplittingIntegrator, sub::OS.SplitSubIntegrator, subintegrators::Tuple, solution_indices::Tuple, synchronizers::Tuple, cache::FakeAdaptiveAlgorithmCache, - tnext + dt ) - return OS.advance_solution_to!( + return OS.advance_solution_by!( outer_integrator, sub, subintegrators, solution_indices, - synchronizers, cache.cache, tnext + synchronizers, cache.cache, dt ) end FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) -# --------------------------------------------------------------------------- -# Helper: given the outer integrator, walk into the first SplitSubIntegrator -# and find the first leaf DEIntegrator. -# --------------------------------------------------------------------------- -function first_leaf_integrator(integrator) - node = integrator.subintegrator_tree[1] - while node isa OS.SplitSubIntegrator - node = node.subintegrator_tree[1] - end - return node -end - # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -214,7 +202,7 @@ end @test integrator.iter == nsteps # SplitSubIntegrators now carry t and iter at each level - sub1 = integrator.subintegrator_tree[1] + sub1 = integrator.child_subintegrators[1] @test sub1.t ≈ tspan[2] @test sub1.iter == nsteps From a47201d0cb4bc78ff0e57f0e6a22a262d470c82d Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 20 Feb 2026 23:57:09 +0100 Subject: [PATCH 04/17] Relax error --- src/function.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/function.jl b/src/function.jl index 6128478..0760de0 100644 --- a/src/function.jl +++ b/src/function.jl @@ -24,7 +24,7 @@ function gsf_recursive_function_type_safety_check(f::GenericSplitFunction) end function gsf_recursive_function_type_safety_check(dunno) - error("Failed to construct GenericSplitFunction. One of the inner functions is of type $(typeof(dunno)) which is not a subtype of SciMLBase.AbstractDiffEqFunction.") + @warn "One of the inner functions in GenericSplitFunction is of type $(typeof(dunno)) which is not a subtype of SciMLBase.AbstractDiffEqFunction." end function gsf_recursive_function_type_safety_check(::SciMLBase.AbstractDiffEqFunction) From c92e8cc29d9960c1dfe05a53ac33d0979ea4a25d Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:02:32 +0100 Subject: [PATCH 05/17] Fix parameter init for wildcards --- src/problem.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/problem.jl b/src/problem.jl index d3d8539..35365db 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -35,6 +35,9 @@ recursive_null_parameters(f::AbstractOperatorSplitFunction) = @error "Not implem function recursive_null_parameters(f::GenericSplitFunction) return ntuple(i -> recursive_null_parameters(get_operator(f, i)), length(f.functions)) end -function recursive_null_parameters(f::SciMLBase.AbstractDiffEqFunction) # Wildcard for leafs +function recursive_null_parameters(f::SciMLBase.AbstractDiffEqFunction) + return NullParameters() +end +function recursive_null_parameters(f) return NullParameters() end From 0cb14ff7e549d4f9743abebe10dba7960c3edfde Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:08:28 +0100 Subject: [PATCH 06/17] Pretty printing for LTG --- src/solver.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/solver.jl b/src/solver.jl index 0a7b969..2814d49 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -11,6 +11,16 @@ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm inner_algs::AlgTupleType end +function Base.show(io::IO, alg::LieTrotterGodunov) + print(io, "LTG (") + for inner_alg in alg.inner_algs[1:end-1] + Base.show(io, inner_alg) + print(io, " -> ") + end + length(alg.inner_algs) > 0 && Base.show(io, alg.inner_algs[end]) + print(io, ")") +end + struct LieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplittingCache u::uType uprev::uprevType From 31766135ac89ea3ef0d26f8d442b40fc4883a766 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:15:20 +0100 Subject: [PATCH 07/17] Fake faulty sync --- src/utils.jl | 5 ----- test/operator_splitting_api.jl | 6 +++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 43263b6..f1ecff3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -158,9 +158,6 @@ function backward_sync_subintegrator!( ) @views umaster = parent.u_master[solution_indices] sync_vectors!(umaster, child.u) - # Also keep parent.u consistent - @views ulocal = parent.u[solution_indices .- first(parent.solution_indices) .+ 1] - sync_vectors!(ulocal, child.u) return backward_sync_external!(parent, child, sync) end @@ -173,8 +170,6 @@ function backward_sync_subintegrator!( ) @views umaster = parent.u_master[solution_indices] sync_vectors!(umaster, child.u) - @views ulocal = parent.u[solution_indices .- first(parent.solution_indices) .+ 1] - sync_vectors!(ulocal, child.u) return backward_sync_external!(parent, child, sync) end diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 641629b..945392d 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -74,11 +74,11 @@ testsys2 = mtkcompile(testmodel2; sort_eqs = false) # build_subintegrator_tree_with_cache. It just wraps the standard cache in # its own FakeAdaptiveAlgorithmCache. # --------------------------------------------------------------------------- -struct FakeAdaptiveAlgorithm{T} <: OS.AbstractOperatorSplittingAlgorithm +struct FakeAdaptiveAlgorithm{T,T2} <: OS.AbstractOperatorSplittingAlgorithm alg::T - inner_algs::T # delegate inner_algs to the wrapped algorithm + inner_algs::T2 # delegate inner_algs to the wrapped algorithm end -FakeAdaptiveAlgorithm(alg::T) where {T} = FakeAdaptiveAlgorithm{T}(alg, alg.inner_algs) +FakeAdaptiveAlgorithm(alg) = FakeAdaptiveAlgorithm(alg, alg.inner_algs) struct FakeAdaptiveAlgorithmCache{T} <: OS.AbstractOperatorSplittingCache cache::T From 9481821ab83220c7093410a9f4279a6ca21ae26e Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:15:31 +0100 Subject: [PATCH 08/17] Roll back to MTKv10 --- Project.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 74eb749..46754c0 100644 --- a/Project.toml +++ b/Project.toml @@ -20,14 +20,13 @@ CommonSolve = "0.2.4" DataStructures = "0.18.22, 0.19" DiffEqBase = "6.165.1" ExplicitImports = "1" -ModelingToolkit = "10, 11" +ModelingToolkit = "10" OrdinaryDiffEqCore = "1.19.0, 2, 3.1" OrdinaryDiffEqLowOrderRK = "1.7" OrdinaryDiffEqTsit5 = "1.1.0" PrecompileTools = "1.0" RecursiveArrayTools = "3.39.0" SafeTestsets = "0.1.0" -SciCompDSL = "1" SciMLBase = "2.77.0" TimerOutputs = "0.5.28" Unrolled = "0.1.5" @@ -38,8 +37,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SciCompDSL = "91a8cdf1-4ca6-467b-a780-87fda3fff15e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ExplicitImports", "ModelingToolkit", "OrdinaryDiffEqTsit5", "SafeTestsets", "SciCompDSL", "Test"] +test = ["ExplicitImports", "ModelingToolkit", "OrdinaryDiffEqTsit5", "SafeTestsets", "Test"] From f1d0bb02394690196bb030d00b0ef5c944336a0b Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:15:47 +0100 Subject: [PATCH 09/17] Recover precompilation --- src/precompilation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/precompilation.jl b/src/precompilation.jl index 11dc2e8..ad92e82 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -38,6 +38,6 @@ end # Precompile init and a few steps integrator = DiffEqBase.init(prob, tstepper, dt = 0.01, verbose = false) - # step!(integrator) - # solve!(integrator) + step!(integrator) + solve!(integrator) end From b30f7458c84b77fac88b8fa70231fda7c83bb7a9 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:16:01 +0100 Subject: [PATCH 10/17] Remove MTKv10 from tests --- test/operator_splitting_api.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 945392d..f3b1ce6 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -6,7 +6,7 @@ import SciMLBase: ReturnCode import DiffEqBase: DiffEqBase, ODEFunction, ODEProblem using OrdinaryDiffEqLowOrderRK using OrdinaryDiffEqTsit5 -using ModelingToolkit, SciCompDSL +using ModelingToolkit # --------------------------------------------------------------------------- # Reference problem From 6e1b38d6ae01209c7dc228ec1a7b4bd443ae2a10 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 00:44:45 +0100 Subject: [PATCH 11/17] Okay I give up. Just take my tstops at the subintegration level. --- src/integrator.jl | 81 +++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 52 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 1996483..07fa870 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -51,6 +51,7 @@ its children's synchronizers, solution indices, and sub-integrators. It does - `u_master` — reference to the full master solution vector of the outermost `OperatorSplittingIntegrator` (needed for sync) - `t`, `dt`, `dtcache` — time tracking + `dtchangeable`, `stops` - `iter` — step counter at this level - `EEst` — error estimate (`NaN` for non-adaptive, `1.0` default for adaptive) @@ -70,6 +71,7 @@ mutable struct SplitSubIntegrator{ algType, uType, tType, + tstopsType, EEstType, controllerType, cacheType, @@ -88,6 +90,7 @@ mutable struct SplitSubIntegrator{ dt::tType dtcache::tType const dtchangeable::Bool + tstops::tstopsType iter::Int EEst::EEstType controller::controllerType @@ -102,13 +105,14 @@ mutable struct SplitSubIntegrator{ child_solution_indices::childSolidxType # Tuple child_synchronizers::childSyncType # Tuple opts::optionsType + tdir::tType end # --- SplitSubIntegrator interface --- tdir(integrator::SplitSubIntegrator) = sign(integrator.dt) -@inline SciMLBase.has_tstop(::SplitSubIntegrator) = false +@inline SciMLBase.has_tstop(::SplitSubIntegrator) = true @inline SciMLBase.isadaptive(sub::SplitSubIntegrator) = isadaptive(sub.alg) # proposed-dt interface (mirrors ODEIntegrator) @@ -409,7 +413,7 @@ end # --------------------------------------------------------------------------- # handle_tstop! # --------------------------------------------------------------------------- -function OrdinaryDiffEqCore.handle_tstop!(integrator::OperatorSplittingIntegrator) +function OrdinaryDiffEqCore.handle_tstop!(integrator::AnySplitIntegrator) if SciMLBase.has_tstop(integrator) tdir_t = tdir(integrator) * integrator.t tdir_tstop = SciMLBase.first_tstop(integrator) @@ -548,14 +552,12 @@ function fix_solution_buffer_sizes!(integrator, sol) return nothing end -function OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator::OperatorSplittingIntegrator, ttmp) +function OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator::AnySplitIntegrator, ttmp) return if DiffEqBase.has_tstop(integrator) tstop = integrator.tdir * DiffEqBase.first_tstop(integrator) if abs(ttmp - tstop) < 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) - # We have to update the floating point errors of the subintegrator nodes, because - # they do not have the tstop logic. try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) tstop else @@ -638,7 +640,7 @@ function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator) ) end -function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) +function DiffEqBase.step!(integrator::AnySplitIntegrator) @timeit_debug "step!" if integrator.advance_to_tstop tstop = SciMLBase.first_tstop(integrator) while !reached_tstop(integrator, tstop) @@ -670,7 +672,7 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator) return end -function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false) +function DiffEqBase.step!(integrator::AnySplitIntegrator, dt, stop_at_tdt = false) @timeit_debug "step!" begin dt <= zero(dt) && error("dt must be positive") stop_at_tdt && !integrator.dtchangeable && @@ -690,24 +692,6 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_t return nothing end -function DiffEqBase.step!(integrator::SplitSubIntegrator, dt, stop_at_tdt = false) - @timeit_debug "step!" begin - dt <= zero(dt) && error("dt must be positive") - stop_at_tdt && !integrator.dtchangeable && - error("Cannot stop at t + dt if dtchangeable is false") - tnext = integrator.t + tdir(integrator) * dt - while !reached_tstop(integrator, tnext, stop_at_tdt) - step_header!(integrator) - @timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ ( - ReturnCode.Success, ReturnCode.Default, - ) && return - __step!(integrator) - step_footer!(integrator) - end - end - return nothing -end - # --------------------------------------------------------------------------- # check_error # --------------------------------------------------------------------------- @@ -1007,8 +991,12 @@ function _build_child( child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) - u_sub = @view uouter[solution_indices] - uprev_sub = @view uprevouter[solution_indices] + u_sub = RecursiveArrayTools.recursivecopy(uouter[solution_indices]) + uprev_sub = RecursiveArrayTools.recursivecopy(uprevouter[solution_indices]) + + tstops_internal = OrdinaryDiffEqCore.initialize_tstops( + tType, tstops, d_discontinuities, prob.tspan + ) level_cache = init_cache( f, alg; @@ -1020,13 +1008,12 @@ function _build_child( sub = SplitSubIntegrator( alg, - # u and uprev: independent copies so that rollback works even when - # u_sub is a view into a device-local buffer. - RecursiveArrayTools.recursivecopy(Array(u_sub)), - RecursiveArrayTools.recursivecopy(Array(u_sub)), + u_sub, + uprev_sub, u_master, t0, t0, dt, dt, # t, tprev, dt, dtcache isdtchangeable(alg), + tstops_internal, 0, # iter EEst_val, controller, @@ -1039,6 +1026,7 @@ function _build_child( child_solution_indices, child_synchronizers, IntegratorOptions(; verbose, adaptive), + one(tType), ) return sub, level_cache @@ -1082,19 +1070,19 @@ end # --------------------------------------------------------------------------- # SciMLBase API # --------------------------------------------------------------------------- -SciMLBase.has_stats(::OperatorSplittingIntegrator) = true +SciMLBase.has_stats(::AnySplitIntegrator) = true -SciMLBase.has_tstop(i::OperatorSplittingIntegrator) = !isempty(i.tstops) -SciMLBase.first_tstop(i::OperatorSplittingIntegrator) = first(i.tstops) -SciMLBase.pop_tstop!(i::OperatorSplittingIntegrator) = pop!(i.tstops) +SciMLBase.has_tstop(i::AnySplitIntegrator) = !isempty(i.tstops) +SciMLBase.first_tstop(i::AnySplitIntegrator) = first(i.tstops) +SciMLBase.pop_tstop!(i::AnySplitIntegrator) = pop!(i.tstops) -DiffEqBase.get_dt(i::OperatorSplittingIntegrator) = i.dt -function set_dt!(i::OperatorSplittingIntegrator, dt) +DiffEqBase.get_dt(i::AnySplitIntegrator) = i.dt +function set_dt!(i::AnySplitIntegrator, dt) dt <= zero(dt) && error("dt must be positive") return i.dt = dt end -function DiffEqBase.add_tstop!(i::OperatorSplittingIntegrator, t) +function DiffEqBase.add_tstop!(i::AnySplitIntegrator, t) is_past_t(i, t) && error("Cannot add a tstop at $t because that is behind the current \ integrator time $(i.t)") @@ -1103,23 +1091,12 @@ function DiffEqBase.add_tstop!(i::OperatorSplittingIntegrator, t) return nothing end -function DiffEqBase.add_tstop!(i::SplitSubIntegrator, t) - DiffEqBase.add_tstop!.(i.child_subintegrators, t) -end - -function _add_tstop!(i::OperatorSplittingIntegrator, t) - is_past_t(i, t) && - error("Cannot add a tstop at $t because that is behind the current \ - integrator time $(i.t)") - _add_tstop!(i.child_subintegrators, t) - return push!(i.tstops, t) -end - function DiffEqBase.add_saveat!(i::OperatorSplittingIntegrator, t) is_past_t(i, t) && error("Cannot add a saveat point at $t because that is behind the \ current integrator time $(i.t)") - return push!(i.saveat, t) + push!(i.saveat, t) + return nothing end -DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing +DiffEqBase.u_modified!(i::AnySplitIntegrator, bool) = nothing From 0c16b840f663d133257f214ac966b76e9515121d Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 01:13:32 +0100 Subject: [PATCH 12/17] Fix faulty tstops and adaptivity handling --- src/integrator.jl | 6 ++---- test/operator_splitting_api.jl | 31 +++++++++++-------------------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 07fa870..96bd842 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -112,9 +112,6 @@ end tdir(integrator::SplitSubIntegrator) = sign(integrator.dt) -@inline SciMLBase.has_tstop(::SplitSubIntegrator) = true -@inline SciMLBase.isadaptive(sub::SplitSubIntegrator) = isadaptive(sub.alg) - # proposed-dt interface (mirrors ODEIntegrator) function SciMLBase.set_proposed_dt!(sub::SplitSubIntegrator, dt) if sub.dtcache != dt # only touch if actually changing @@ -1099,4 +1096,5 @@ function DiffEqBase.add_saveat!(i::OperatorSplittingIntegrator, t) return nothing end -DiffEqBase.u_modified!(i::AnySplitIntegrator, bool) = nothing +DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = i.u_modified = bool +DiffEqBase.u_modified!(i::SplitSubIntegrator, bool) = i.u_modified = bool diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index f3b1ce6..9872895 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -74,7 +74,7 @@ testsys2 = mtkcompile(testmodel2; sort_eqs = false) # build_subintegrator_tree_with_cache. It just wraps the standard cache in # its own FakeAdaptiveAlgorithmCache. # --------------------------------------------------------------------------- -struct FakeAdaptiveAlgorithm{T,T2} <: OS.AbstractOperatorSplittingAlgorithm +struct FakeAdaptiveAlgorithm{T, T2} <: OS.AbstractOperatorSplittingAlgorithm alg::T inner_algs::T2 # delegate inner_algs to the wrapped algorithm end @@ -120,34 +120,25 @@ end cache::FakeAdaptiveAlgorithmCache ) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) -@inline function OS.advance_solution_by!( - outer_integrator::OS.OperatorSplittingIntegrator, +@inline function OS._perform_step!( + outer_integrator, subintegrators::Tuple, cache::FakeAdaptiveAlgorithmCache, dt ) - return OS.advance_solution_by!( + return OS._perform_step!( outer_integrator, subintegrators, cache.cache, dt ) end -# For SplitSubIntegrator nodes whose cache was wrapped by FakeAdaptiveAlgorithmCache -@inline function OS.advance_solution_by!( - outer_integrator::OS.OperatorSplittingIntegrator, - sub::OS.SplitSubIntegrator, - subintegrators::Tuple, - solution_indices::Tuple, - synchronizers::Tuple, - cache::FakeAdaptiveAlgorithmCache, - dt - ) - return OS.advance_solution_by!( - outer_integrator, sub, subintegrators, solution_indices, - synchronizers, cache.cache, dt - ) +FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) + +function Base.show(io::IO, alg::FakeAdaptiveAlgorithm) + print(io, "FAKE (") + Base.show(io, alg.alg) + print(io, ")") end -FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) # --------------------------------------------------------------------------- # Tests @@ -172,7 +163,7 @@ FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) nsteps = ceil(Int, (tspan[2] - tspan[1]) / dt) for TimeStepperType in (LieTrotterGodunov, FakeAdaptiveLTG) - @testset "Solver type $TimeStepperType | $tstepper" for (prob, tstepper) in ( + @testset "$tstepper" for (prob, tstepper) in ( (prob1a, TimeStepperType((Euler(), Euler()))), (prob1a, TimeStepperType((Tsit5(), Euler()))), (prob1a, TimeStepperType((Euler(), Tsit5()))), From a6d8bb4f3491835a95b43863e4462f53a892f8f9 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 01:36:00 +0100 Subject: [PATCH 13/17] Fix controller crash as we do not have controllers now --- src/integrator.jl | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 96bd842..87ec4a2 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -202,6 +202,7 @@ function SciMLBase.__init( advance_to_tstop = false, adaptive = isadaptive(alg), controller = nothing, + # controller = OrdinaryDiffEqCore.PIController(0.14, 0.08), alias_u0 = false, verbose = true, kwargs... @@ -383,7 +384,7 @@ function _subreinit_child!( kwargs... ) idxs = sub.solution_indices - sub.u .= @view u0[idxs] + sub.u .= @view u0[idxs] sub.uprev .= @view u0[idxs] sub.t = t0 if dt !== nothing @@ -763,15 +764,32 @@ end stepsize_controller!(integrator, integrator.alg) return nothing end +@inline function stepsize_controller!(integrator::AnySplitIntegrator, alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) || return nothing + #stepsize_controller!(integrator, integrator.controller) + return nothing +end @inline function step_accept_controller!(integrator::AnySplitIntegrator) isadaptive(integrator.alg) || return nothing - step_accept_controller!(integrator, integrator.alg, nothing) + step_accept_controller!(integrator, integrator.alg) + return nothing +end +@inline function step_accept_controller!(integrator::AnySplitIntegrator, alg::AbstractOperatorSplittingAlgorithm) + isadaptive(alg) || return nothing + #step_accept_controller!(integrator, integrator.controller) return nothing end @inline function step_reject_controller!(integrator::AnySplitIntegrator) isadaptive(integrator.alg) || return nothing - return step_reject_controller!(integrator, integrator.alg, nothing) + step_reject_controller!(integrator, integrator.alg) + return nothing end +@inline function step_reject_controller!(integrator::AnySplitIntegrator, alg::AbstractOperatorSplittingAlgorithm) + isadaptive(integrator.alg) || return nothing + # step_reject_controller!(integrator, integrator.controller) + return nothing +end + # Time helpers tdir(integrator) = @@ -1044,7 +1062,7 @@ function _build_child( save_end = false, controller = nothing ) where {S, T, P, F} - u = @view uouter[solution_indices] + u = uouter[solution_indices] prob2 = if p isa NullParameters SciMLBase.ODEProblem(f, u, (t0, tf)) else From cd05df083241bbec6cee0369986d8f1e4d0b0304 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 01:40:06 +0100 Subject: [PATCH 14/17] Hook into rollback --- src/integrator.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/integrator.jl b/src/integrator.jl index 87ec4a2..a375ae3 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -452,7 +452,7 @@ function reject_step!(integrator::AnySplitIntegrator) end function reject_step!(integrator::AnySplitIntegrator, cache, controller) integrator.u .= integrator.uprev - # TODO: roll back sub-integrators + rollback_children!(integrator) return nothing end function reject_step!(integrator::AnySplitIntegrator, cache, ::Nothing) @@ -495,6 +495,7 @@ end # Roll back each child's local buffer to match master u. # For DEIntegrators the leaf will be re-synced via forward_sync before the # next attempt, so there is nothing to do here. +rollback_children!(integrator::OperatorSplittingIntegrator) = rollback_children!(integrator.child_subintegrators, integrator.u) @unroll function rollback_children!(children::Tuple, u_master) @unroll for child in children rollback_child!(child, u_master) From 6433795889e9682faf7a33c74f988e2be462badc Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 03:49:15 +0100 Subject: [PATCH 15/17] Fix another fault in sync and reinit --- src/integrator.jl | 51 +++++++------------ src/solver.jl | 26 +++++----- src/utils.jl | 90 ++-------------------------------- test/operator_splitting_api.jl | 4 +- 4 files changed, 38 insertions(+), 133 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index a375ae3..6c360f6 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -63,7 +63,7 @@ its children's synchronizers, solution indices, and sub-integrators. It does this level - `child_subintegrators` — tuple of direct children (`SplitSubIntegrator` or `DEIntegrator`) -- `solution_indices` — global indices (into master `u`) **owned by this node** +- `solution_indices` — global indices (into parent `u`) **owned by this node** - `child_solution_indices` — tuple of per-child global solution indices - `child_synchronizers` — tuple of per-child synchronizer objects """ @@ -243,7 +243,7 @@ function SciMLBase.__init( sol = SciMLBase.build_solution(prob, alg, tType[], uType[]) callback = DiffEqBase.CallbackSet(callback) - child_subintegrators, cache = build_subintegrators( + child_subintegrators = build_subintegrators( prob, alg, uprev, u, u, # u_master == u at the outermost level @@ -253,6 +253,11 @@ function SciMLBase.__init( adaptive, verbose ) + cache = init_cache( + prob.f, alg; + uprev = uprev, u = u, + ) + child_solution_indices = ntuple(i -> prob.f.solution_indices[i], length(prob.f.functions)) child_synchronizers = ntuple(i -> prob.f.synchronizers[i], length(prob.f.functions)) @@ -363,13 +368,9 @@ function _subreinit_child!( ) if dt !== nothing && child.dtchangeable SciMLBase.set_proposed_dt!(child, dt) + # Reinit does not touch this, so we reset it manually. + set_dt!(child, dt) end - # solution_indices live on the parent SplitSubIntegrator (or on the outer - # integrator for top-level children) — they were baked into child at init. - # reinit! on an ODEIntegrator resets u from its prob.u0; we need to pass - # the correct slice here. The parent calls us with the correct f_child - # but not the indices — those are embedded in child.sol.prob.u0 already - # because we constructed child with a view/copy of the right slice. return DiffEqBase.reinit!(child; kwargs...) end @@ -383,12 +384,10 @@ function _subreinit_child!( dt, kwargs... ) - idxs = sub.solution_indices - sub.u .= @view u0[idxs] - sub.uprev .= @view u0[idxs] - sub.t = t0 + sub.t = t0 if dt !== nothing SciMLBase.set_proposed_dt!(sub, dt) + set_dt!(sub, dt) end sub.iter = 0 sub.force_stepfail = false @@ -823,7 +822,7 @@ end function __step!(integrator::AnySplitIntegrator) advance_solution_by!(integrator, integrator.dt) - stepsize_controller!(integrator) + stepsize_controller!(integrator) # FIXME this should go into the footer return nothing end @@ -940,7 +939,7 @@ function build_subintegrators( ) (; f, p) = prob - results = ntuple( + child_subintegrators = ntuple( i -> _build_child( prob, alg.inner_algs[i], @@ -955,16 +954,7 @@ function build_subintegrators( length(f.functions) ) - child_subintegrators = ntuple(i -> results[i][1], length(f.functions)) - child_caches = ntuple(i -> results[i][2], length(f.functions)) - - cache = init_cache( - f, alg; - uprev = uprevouter, u = uouter, alias_u = true, - inner_caches = child_caches - ) - - return child_subintegrators, cache + return child_subintegrators end # Intermediate node: inner alg is an AbstractOperatorSplittingAlgorithm and @@ -986,8 +976,8 @@ function _build_child( ) tType = typeof(dt) - # Recurse: build each grandchild - grandchild_results = ntuple( + # Recurse: build each consecutive child + child_subintegrators = ntuple( i -> _build_child( prob, alg.inner_algs[i], @@ -1002,8 +992,6 @@ function _build_child( length(f.functions) ) - child_subintegrators = ntuple(i -> grandchild_results[i][1], length(f.functions)) - child_caches = ntuple(i -> grandchild_results[i][2], length(f.functions)) child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) @@ -1017,7 +1005,6 @@ function _build_child( level_cache = init_cache( f, alg; uprev = uprev_sub, u = u_sub, - inner_caches = child_caches ) EEst_val = isadaptive(alg) ? one(tType) : tType(NaN) @@ -1045,7 +1032,7 @@ function _build_child( one(tType), ) - return sub, level_cache + return sub end # Leaf node: inner alg is a plain SciMLBase.AbstractODEAlgorithm @@ -1080,7 +1067,7 @@ function _build_child( advance_to_tstop = false, adaptive, controller, verbose ) - return integrator, integrator.cache + return integrator end # --------------------------------------------------------------------------- @@ -1093,7 +1080,7 @@ SciMLBase.first_tstop(i::AnySplitIntegrator) = first(i.tstops) SciMLBase.pop_tstop!(i::AnySplitIntegrator) = pop!(i.tstops) DiffEqBase.get_dt(i::AnySplitIntegrator) = i.dt -function set_dt!(i::AnySplitIntegrator, dt) +function set_dt!(i::DiffEqBase.DEIntegrator, dt) dt <= zero(dt) && error("dt must be positive") return i.dt = dt end diff --git a/src/solver.jl b/src/solver.jl index 2814d49..abf0c9e 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -21,26 +21,20 @@ function Base.show(io::IO, alg::LieTrotterGodunov) print(io, ")") end -struct LieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplittingCache +struct LieTrotterGodunovCache{uType, uprevType} <: AbstractOperatorSplittingCache u::uType uprev::uprevType - inner_caches::iiType end function init_cache( f::GenericSplitFunction, alg::LieTrotterGodunov; uprev::AbstractArray, u::AbstractVector, - inner_caches, - alias_uprev = true, - alias_u = false ) - _uprev = alias_uprev ? uprev : RecursiveArrayTools.recursivecopy(uprev) - _u = alias_u ? u : RecursiveArrayTools.recursivecopy(u) - return LieTrotterGodunovCache(_u, _uprev, inner_caches) + return LieTrotterGodunovCache(u, uprev) end @unroll function _perform_step!( - outer, + parent, children::Tuple, cache::LieTrotterGodunovCache, dt @@ -48,15 +42,17 @@ end i = 0 @unroll for child in children i += 1 - idxs = outer.child_solution_indices[i] - sync = outer.child_synchronizers[i] - @timeit_debug "sync ->" forward_sync_subintegrator!(outer, child, idxs, sync) - @timeit_debug "time solve" advance_solution_by!(outer, child, dt) + idxs = parent.child_solution_indices[i] + sync = parent.child_synchronizers[i] + + @timeit_debug "sync ->" forward_sync_subintegrator!(parent, child, idxs, sync) + @timeit_debug "time solve" advance_solution_by!(parent, child, dt) if _child_failed(child) - outer.force_stepfail = true + parent.force_stepfail = true return end - backward_sync_subintegrator!(outer, child, idxs, sync) + + backward_sync_subintegrator!(parent, child, idxs, sync) end end diff --git a/src/utils.jl b/src/utils.jl index f1ecff3..a442a50 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -54,9 +54,8 @@ end # solution into the child and applies any external parameter synchronisation. # --------------------------------------------------------------------------- -# Parent = outermost OperatorSplittingIntegrator, child = DEIntegrator function forward_sync_subintegrator!( - parent::OperatorSplittingIntegrator, + parent::AnySplitIntegrator, child::DEIntegrator, solution_indices, sync @@ -66,53 +65,11 @@ function forward_sync_subintegrator!( return nothing end -# Parent = outermost OperatorSplittingIntegrator, child = SplitSubIntegrator -function forward_sync_subintegrator!( - parent::OperatorSplittingIntegrator, - child::SplitSubIntegrator, - solution_indices, - sync - ) - @views uparent = parent.u[solution_indices] - sync_vectors!(child.u, uparent) - sync_vectors!(child.uprev, uparent) - forward_sync_external!(parent, child, sync) - return nothing -end - -# Parent = SplitSubIntegrator, child = DEIntegrator -function forward_sync_subintegrator!( - parent::SplitSubIntegrator, - child::DEIntegrator, - solution_indices, - sync - ) - # parent.u is this level's buffer; solution_indices are relative to - # the master u. We read from the master via u_master. - forward_sync_internal!(parent.u_master, child, solution_indices) - forward_sync_external!(parent, child, sync) - return nothing -end - -# Parent = SplitSubIntegrator, child = SplitSubIntegrator -function forward_sync_subintegrator!( - parent::SplitSubIntegrator, - child::SplitSubIntegrator, - solution_indices, - sync - ) - @views umaster = parent.u_master[solution_indices] - sync_vectors!(child.u, umaster) - sync_vectors!(child.uprev, umaster) - forward_sync_external!(parent, child, sync) - return nothing -end - # Shared internal helper: copy master u slice → leaf DEIntegrator u/uprev function forward_sync_internal!(u_source, child::DEIntegrator, solution_indices) @views usrc = u_source[solution_indices] - sync_vectors!(child.uprev, usrc) sync_vectors!(child.u, usrc) + sync_vectors!(child.uprev, child.u) SciMLBase.u_modified!(child, true) return nothing end @@ -124,55 +81,18 @@ end # state back into the master solution vector. # --------------------------------------------------------------------------- -# Parent = outermost OperatorSplittingIntegrator, child = DEIntegrator function backward_sync_subintegrator!( - parent::OperatorSplittingIntegrator, + parent::AnySplitIntegrator, child::DEIntegrator, solution_indices, sync ) - @views uparent = parent.u[solution_indices] - sync_vectors!(uparent, child.u) + @views udst = parent.u[solution_indices] + sync_vectors!(udst, child.u) backward_sync_external!(parent, child, sync) return nothing end -# Parent = outermost OperatorSplittingIntegrator, child = SplitSubIntegrator -function backward_sync_subintegrator!( - parent::OperatorSplittingIntegrator, - child::SplitSubIntegrator, - solution_indices, - sync - ) - @views uparent = parent.u[solution_indices] - sync_vectors!(uparent, child.u) - return backward_sync_external!(parent, child, sync) -end - -# Parent = SplitSubIntegrator, child = DEIntegrator -function backward_sync_subintegrator!( - parent::SplitSubIntegrator, - child::DEIntegrator, - solution_indices, - sync - ) - @views umaster = parent.u_master[solution_indices] - sync_vectors!(umaster, child.u) - return backward_sync_external!(parent, child, sync) -end - -# Parent = SplitSubIntegrator, child = SplitSubIntegrator -function backward_sync_subintegrator!( - parent::SplitSubIntegrator, - child::SplitSubIntegrator, - solution_indices, - sync - ) - @views umaster = parent.u_master[solution_indices] - sync_vectors!(umaster, child.u) - return backward_sync_external!(parent, child, sync) -end - # --------------------------------------------------------------------------- # forward_sync_external! / backward_sync_external! # These handle parameter synchronisation via the `sync` object. diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index 9872895..e5c16bc 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -154,7 +154,7 @@ end prob1a = OperatorSplittingProblem(fsplit1a, u0, tspan) prob1b = OperatorSplittingProblem(fsplit1b, u0, tspan) - f3dofs = [1, 3] + f3dofs = [1, 2] fsplit2_inner = GenericSplitFunction((f3, f3), (f3dofs, f3dofs)) fsplit2_outer = GenericSplitFunction((f1, fsplit2_inner), (f1dofs, f2dofs)) @@ -243,6 +243,8 @@ end @test integrator.iter == nsteps DiffEqBase.reinit!(integrator; dt = dt) + @test integrator.dt == dt + @test integrator.dt == integrator.dtcache @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (u, t) in DiffEqBase.TimeChoiceIterator(integrator, tspan[1]:5.0:tspan[2]) end From 65a2c6f9aa82a5b63c2fe3e6d15812c38658e755 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Feb 2026 04:05:46 +0100 Subject: [PATCH 16/17] Revive some comments. --- src/integrator.jl | 20 ++++++-------------- src/solver.jl | 2 +- src/utils.jl | 33 ++++++++++++++++++++------------- test/operator_splitting_api.jl | 29 +++++++++++++++++++---------- 4 files changed, 46 insertions(+), 38 deletions(-) diff --git a/src/integrator.jl b/src/integrator.jl index 6c360f6..2d206fa 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -15,9 +15,7 @@ Base.@kwdef mutable struct IntegratorOptions{tType, fType, F3} isoutofdomain::F3 = DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN end -# --------------------------------------------------------------------------- -# SplitSubIntegratorStatus -# --------------------------------------------------------------------------- + """ SplitSubIntegratorStatus @@ -31,9 +29,7 @@ end SplitSubIntegratorStatus() = SplitSubIntegratorStatus(ReturnCode.Default) -# --------------------------------------------------------------------------- -# SplitSubIntegrator -# --------------------------------------------------------------------------- + """ SplitSubIntegrator <: AbstractODEIntegrator @@ -123,9 +119,7 @@ function SciMLBase.set_proposed_dt!(sub::SplitSubIntegrator, dt) return nothing end -# --------------------------------------------------------------------------- -# OperatorSplittingIntegrator -# --------------------------------------------------------------------------- + """ OperatorSplittingIntegrator <: AbstractODEIntegrator @@ -439,8 +433,6 @@ end notify_integrator_hit_tstop!(integrator::AnySplitIntegrator) = nothing -is_first_iteration(integrator::AnySplitIntegrator) = integrator.iter == 0 -increment_iteration(integrator::AnySplitIntegrator) = integrator.iter += 1 # --------------------------------------------------------------------------- # Step accept/reject @@ -531,6 +523,9 @@ function step_header!(integrator::AnySplitIntegrator) return nothing end +is_first_iteration(integrator::AnySplitIntegrator) = integrator.iter == 0 +increment_iteration(integrator::AnySplitIntegrator) = integrator.iter += 1 + function footer_reset_flags!(integrator) integrator.u_modified = false return @@ -738,9 +733,6 @@ end _child_retcode(child::DEIntegrator) = SciMLBase.check_error(child) _child_retcode(child::SplitSubIntegrator) = child.status.retcode -# --------------------------------------------------------------------------- -# Internal step -# --------------------------------------------------------------------------- function setup_u(prob::OperatorSplittingProblem, solver, alias_u0) alias_u0 ? prob.u0 : RecursiveArrayTools.recursivecopy(prob.u0) end diff --git a/src/solver.jl b/src/solver.jl index abf0c9e..052ce09 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -8,7 +8,7 @@ First-order sequential operator splitting algorithm attributed to [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite). """ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm - inner_algs::AlgTupleType + inner_algs::AlgTupleType # Tuple of timesteppers for inner problems end function Base.show(io::IO, alg::LieTrotterGodunov) diff --git a/src/utils.jl b/src/utils.jl index a442a50..6a7769f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,6 +3,7 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat) FT = typeof(tf) ordering = tf > t0 ? DataStructures.FasterForward : DataStructures.FasterReverse + # ensure that tstops includes tf and only has values ahead of t0 tstops = [filter(t -> t0 < t < tf || tf < t < t0, tstops)..., tf] tstops = DataStructures.BinaryHeap{FT, ordering}(tstops) @@ -13,6 +14,8 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat) saveat = tf > t0 ? saveat : -saveat saveat = [t0:saveat:tf..., tf] else + # We do not need to filter saveat like tstops because the saving + # callback will ignore any times that are not between t0 and tf. saveat = collect(saveat) end saveat = DataStructures.BinaryHeap{FT, ordering}(saveat) @@ -46,13 +49,14 @@ function sync_vectors!(a, b) return nothing end -# --------------------------------------------------------------------------- -# forward_sync_subintegrator! -# -# The *parent* (OperatorSplittingIntegrator or SplitSubIntegrator) calls this -# before each child's step. It copies the relevant slice of the master -# solution into the child and applies any external parameter synchronisation. -# --------------------------------------------------------------------------- +""" + forward_sync_subintegrator!(parent_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) + +This function is responsible of copying the solution and parameters of the parent integrator and the synchronized subintegrators with the information given into the inner integrator. +If the inner integrator is synchronized with other inner integrators using `sync`, the function `forward_sync_external!` shall be dispatched for `sync`. +The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. +The `solution_indices` are indices into the parent integrators solution vectors. +""" function forward_sync_subintegrator!( parent::AnySplitIntegrator, @@ -74,12 +78,15 @@ function forward_sync_internal!(u_source, child::DEIntegrator, solution_indices) return nothing end -# --------------------------------------------------------------------------- -# backward_sync_subintegrator! -# -# The *parent* calls this after each child's step to copy the child's updated -# state back into the master solution vector. -# --------------------------------------------------------------------------- + +""" + backward_sync_subintegrator!(parent_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync) + +This function is responsible of copying the solution of the inner integrator back into parent integrator and the synchronized subintegrators. +If the inner integrator is synchronized with other inner integrators using `sync`, the function `backward_sync_external!` shall be dispatched for `sync`. +The `sync` object is passed from the outside and is the main entry point to dispatch custom types on for parameter synchronization. +The `solution_indices` are indices in the parent integrators solution vectors. +""" function backward_sync_subintegrator!( parent::AnySplitIntegrator, diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index e5c16bc..f4d19e8 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -67,13 +67,7 @@ end @named testmodel2 = TestModelODE2() testsys2 = mtkcompile(testmodel2; sort_eqs = false) -# --------------------------------------------------------------------------- -# FakeAdaptiveAlgorithm — tests adaptive code path -# -# With the new interface FakeAdaptiveAlgorithm no longer needs to override -# build_subintegrator_tree_with_cache. It just wraps the standard cache in -# its own FakeAdaptiveAlgorithmCache. -# --------------------------------------------------------------------------- +# Test whether adaptive code path works in principle struct FakeAdaptiveAlgorithm{T, T2} <: OS.AbstractOperatorSplittingAlgorithm alg::T inner_algs::T2 # delegate inner_algs to the wrapped algorithm @@ -146,6 +140,9 @@ end @testset "reinit and convergence" begin dt = 0.01π + # Here we describe index sets f1dofs and f2dofs that map the + # local indices in f1 and f2 into the global problem. Just put + # ode_true and ode1/ode2 side by side to see how they connect. f1dofs = [1, 2, 3] f2dofs = [1, 3] fsplit1a = GenericSplitFunction((f1, f2), (f1dofs, f2dofs)) @@ -153,7 +150,9 @@ end prob1a = OperatorSplittingProblem(fsplit1a, u0, tspan) prob1b = OperatorSplittingProblem(fsplit1b, u0, tspan) - + + # Note that we define the dof indices w.r.t the parent function. + # Hence the indices for `fsplit2_inner` are. f3dofs = [1, 2] fsplit2_inner = GenericSplitFunction((f3, f3), (f3dofs, f3dofs)) fsplit2_outer = GenericSplitFunction((f1, fsplit2_inner), (f1dofs, f2dofs)) @@ -184,6 +183,9 @@ end ) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default + sub1 = integrator.child_subintegrators[1] + sub2 = integrator.child_subintegrators[2] + DiffEqBase.solve!(integrator) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Success ufinal = copy(integrator.u) @@ -192,11 +194,12 @@ end @test integrator.dtcache ≈ dt @test integrator.iter == nsteps - # SplitSubIntegrators now carry t and iter at each level - sub1 = integrator.child_subintegrators[1] @test sub1.t ≈ tspan[2] @test sub1.iter == nsteps + @test sub2.t ≈ tspan[2] + @test sub2.iter == nsteps + DiffEqBase.reinit!(integrator; dt = dt) @test integrator.sol.retcode == DiffEqBase.ReturnCode.Default for (u, t) in DiffEqBase.TimeChoiceIterator(integrator, tspan[1]:5.0:tspan[2]) @@ -222,6 +225,12 @@ end @test integrator.t ≈ tspan[2] @test integrator.dtcache ≈ dt @test integrator.iter == nsteps + + @test sub1.t ≈ tspan[2] + @test sub1.iter == nsteps + + @test sub2.t ≈ tspan[2] + @test sub2.iter == nsteps end end From 2dda569a475529ca93c2430fef990eca8d932022 Mon Sep 17 00:00:00 2001 From: Kyle Beggs Date: Mon, 23 Feb 2026 14:57:12 -0500 Subject: [PATCH 17/17] format --- src/function.jl | 4 +-- src/integrator.jl | 66 +++++++++++++++++----------------- src/precompilation.jl | 4 +-- src/solver.jl | 14 ++++---- src/utils.jl | 16 ++++----- test/operator_splitting_api.jl | 12 +++---- 6 files changed, 58 insertions(+), 58 deletions(-) diff --git a/src/function.jl b/src/function.jl index 0760de0..b06b4b5 100644 --- a/src/function.jl +++ b/src/function.jl @@ -20,11 +20,11 @@ struct GenericSplitFunction{fSetType <: Tuple, idxSetType <: Tuple, sSetType <: end function gsf_recursive_function_type_safety_check(f::GenericSplitFunction) - gsf_recursive_function_type_safety_check.(f.functions) + return gsf_recursive_function_type_safety_check.(f.functions) end function gsf_recursive_function_type_safety_check(dunno) - @warn "One of the inner functions in GenericSplitFunction is of type $(typeof(dunno)) which is not a subtype of SciMLBase.AbstractDiffEqFunction." + return @warn "One of the inner functions in GenericSplitFunction is of type $(typeof(dunno)) which is not a subtype of SciMLBase.AbstractDiffEqFunction." end function gsf_recursive_function_type_safety_check(::SciMLBase.AbstractDiffEqFunction) diff --git a/src/integrator.jl b/src/integrator.jl index 2d206fa..ffd6c10 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -75,7 +75,7 @@ mutable struct SplitSubIntegrator{ solidxType, childSolidxType, childSyncType, - optionsType + optionsType, } <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType} alg::algType u::uType # local solution buffer @@ -224,17 +224,17 @@ function SciMLBase.__init( tstops_internal = OrdinaryDiffEqCore.initialize_tstops( tType, tstops, d_discontinuities, prob.tspan ) - saveat_internal = OrdinaryDiffEqCore.initialize_saveat(tType, saveat, prob.tspan) + saveat_internal = OrdinaryDiffEqCore.initialize_saveat(tType, saveat, prob.tspan) d_discontinuities_internal = OrdinaryDiffEqCore.initialize_d_discontinuities( tType, d_discontinuities, prob.tspan ) - u = setup_u(prob, alg, alias_u0) + u = setup_u(prob, alg, alias_u0) uprev = setup_u(prob, alg, false) - tmp = setup_u(prob, alg, false) + tmp = setup_u(prob, alg, false) uType = typeof(u) - sol = SciMLBase.build_solution(prob, alg, tType[], uType[]) + sol = SciMLBase.build_solution(prob, alg, tType[], uType[]) callback = DiffEqBase.CallbackSet(callback) child_subintegrators = build_subintegrators( @@ -252,8 +252,8 @@ function SciMLBase.__init( uprev = uprev, u = u, ) - child_solution_indices = ntuple(i -> prob.f.solution_indices[i], length(prob.f.functions)) - child_synchronizers = ntuple(i -> prob.f.synchronizers[i], length(prob.f.functions)) + child_solution_indices = ntuple(i -> prob.f.solution_indices[i], length(prob.f.functions)) + child_synchronizers = ntuple(i -> prob.f.synchronizers[i], length(prob.f.functions)) integrator = OperatorSplittingIntegrator( prob.f, @@ -264,7 +264,7 @@ function SciMLBase.__init( dt, dtcache, dtchangeable, tstops_internal, tstops, - saveat_internal, saveat, + saveat_internal, saveat, callback, advance_to_tstop, false, false, false, false, @@ -299,9 +299,9 @@ function DiffEqBase.reinit!( reinit_callbacks = true, reinit_retcode = true ) - integrator.u .= u0 + integrator.u .= u0 integrator.uprev .= u0 - integrator.t = t0 + integrator.t = t0 integrator.tprev = t0 if dt !== nothing integrator.dt = dt @@ -383,10 +383,10 @@ function _subreinit_child!( SciMLBase.set_proposed_dt!(sub, dt) set_dt!(sub, dt) end - sub.iter = 0 - sub.force_stepfail = false + sub.iter = 0 + sub.force_stepfail = false sub.last_step_failed = false - sub.status = SplitSubIntegratorStatus(ReturnCode.Default) + sub.status = SplitSubIntegratorStatus(ReturnCode.Default) # Reset EEst to its appropriate default if isadaptive(sub) sub.EEst = one(sub.EEst) @@ -406,7 +406,7 @@ end # --------------------------------------------------------------------------- function OrdinaryDiffEqCore.handle_tstop!(integrator::AnySplitIntegrator) if SciMLBase.has_tstop(integrator) - tdir_t = tdir(integrator) * integrator.t + tdir_t = tdir(integrator) * integrator.t tdir_tstop = SciMLBase.first_tstop(integrator) if tdir_t == tdir_tstop while tdir_t == tdir_tstop @@ -523,7 +523,7 @@ function step_header!(integrator::AnySplitIntegrator) return nothing end -is_first_iteration(integrator::AnySplitIntegrator) = integrator.iter == 0 +is_first_iteration(integrator::AnySplitIntegrator) = integrator.iter == 0 increment_iteration(integrator::AnySplitIntegrator) = integrator.iter += 1 function footer_reset_flags!(integrator) @@ -562,16 +562,16 @@ function OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator::AnySpli end function try_snap_children_to_tstop!(integrator::SplitSubIntegrator, tstop) if abs(tstop - integrator.t) < - 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) integrator.t = tstop else @warn "Failed to snap timestep for integrator $(integrator.t) with parent integrator hitting the tstop $(tstop)." end - try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) + return try_snap_children_to_tstop!.(integrator.child_subintegrators, tstop) end function try_snap_children_to_tstop!(integrator::DEIntegrator, tstop) - if abs(tstop - integrator.t) < - 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) + return if abs(tstop - integrator.t) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * oneunit(integrator.t) integrator.t = tstop else @warn "Failed to snap timestep for integrator $(integrator.t) with parent integrator hitting the tstop $(tstop)." @@ -585,7 +585,7 @@ function step_footer!(integrator::AnySplitIntegrator) if should_accept_step(integrator) integrator.last_step_failed = false integrator.tprev = integrator.t - integrator.t = OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) + integrator.t = OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!(integrator, ttmp) step_accept_controller!(integrator) elseif integrator.force_stepfail if isadaptive(integrator) @@ -730,11 +730,11 @@ end return current_retcode end -_child_retcode(child::DEIntegrator) = SciMLBase.check_error(child) +_child_retcode(child::DEIntegrator) = SciMLBase.check_error(child) _child_retcode(child::SplitSubIntegrator) = child.status.retcode function setup_u(prob::OperatorSplittingProblem, solver, alias_u0) - alias_u0 ? prob.u0 : RecursiveArrayTools.recursivecopy(prob.u0) + return alias_u0 ? prob.u0 : RecursiveArrayTools.recursivecopy(prob.u0) end @inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator) @@ -863,11 +863,11 @@ function advance_solution_by!( end function advance_solution_by!( - outer::SplitSubIntegrator, - children::Tuple, - cache::AbstractOperatorSplittingCache, - dt -) + outer::SplitSubIntegrator, + children::Tuple, + cache::AbstractOperatorSplittingCache, + dt + ) _perform_step!(outer, children, cache, dt) if outer.force_stepfail @@ -984,10 +984,10 @@ function _build_child( length(f.functions) ) - child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) - child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) + child_solution_indices = ntuple(i -> f.solution_indices[i], length(f.functions)) + child_synchronizers = ntuple(i -> f.synchronizers[i], length(f.functions)) - u_sub = RecursiveArrayTools.recursivecopy(uouter[solution_indices]) + u_sub = RecursiveArrayTools.recursivecopy(uouter[solution_indices]) uprev_sub = RecursiveArrayTools.recursivecopy(uprevouter[solution_indices]) tstops_internal = OrdinaryDiffEqCore.initialize_tstops( @@ -1067,9 +1067,9 @@ end # --------------------------------------------------------------------------- SciMLBase.has_stats(::AnySplitIntegrator) = true -SciMLBase.has_tstop(i::AnySplitIntegrator) = !isempty(i.tstops) -SciMLBase.first_tstop(i::AnySplitIntegrator) = first(i.tstops) -SciMLBase.pop_tstop!(i::AnySplitIntegrator) = pop!(i.tstops) +SciMLBase.has_tstop(i::AnySplitIntegrator) = !isempty(i.tstops) +SciMLBase.first_tstop(i::AnySplitIntegrator) = first(i.tstops) +SciMLBase.pop_tstop!(i::AnySplitIntegrator) = pop!(i.tstops) DiffEqBase.get_dt(i::AnySplitIntegrator) = i.dt function set_dt!(i::DiffEqBase.DEIntegrator, dt) diff --git a/src/precompilation.jl b/src/precompilation.jl index ad92e82..f638e74 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -15,7 +15,7 @@ end function _precompile_ode3(du, u, p, t) du[1] = -0.01u[2] - du[2] = -0.01u[1] + return du[2] = -0.01u[1] end @compile_workload begin @@ -31,7 +31,7 @@ end f2dofs = [1, 3] f3dofs = [2, 3] fsplitinner = GenericSplitFunction((f2, f3), (f2dofs, f3dofs)) - fsplit = GenericSplitFunction((f1, fsplitinner), (f1dofs, [1,2,3])) + fsplit = GenericSplitFunction((f1, fsplitinner), (f1dofs, [1, 2, 3])) prob = OperatorSplittingProblem(fsplit, u0, tspan) tstepper = LieTrotterGodunov((Euler(), LieTrotterGodunov((Euler(), Euler())))) diff --git a/src/solver.jl b/src/solver.jl index 052ce09..3163474 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -13,12 +13,12 @@ end function Base.show(io::IO, alg::LieTrotterGodunov) print(io, "LTG (") - for inner_alg in alg.inner_algs[1:end-1] + for inner_alg in alg.inner_algs[1:(end - 1)] Base.show(io, inner_alg) print(io, " -> ") end length(alg.inner_algs) > 0 && Base.show(io, alg.inner_algs[end]) - print(io, ")") + return print(io, ")") end struct LieTrotterGodunovCache{uType, uprevType} <: AbstractOperatorSplittingCache @@ -34,11 +34,11 @@ function init_cache( end @unroll function _perform_step!( - parent, - children::Tuple, - cache::LieTrotterGodunovCache, - dt -) + parent, + children::Tuple, + cache::LieTrotterGodunovCache, + dt + ) i = 0 @unroll for child in children i += 1 diff --git a/src/utils.jl b/src/utils.jl index 6a7769f..8dfdd69 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -33,9 +33,9 @@ might be that the vectors alias each other in memory. need_sync need_sync(a::AbstractVector, b::AbstractVector) = true -need_sync(a::SubArray, b::AbstractVector) = a.parent !== b -need_sync(a::AbstractVector, b::SubArray) = a !== b.parent -need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent +need_sync(a::SubArray, b::AbstractVector) = a.parent !== b +need_sync(a::AbstractVector, b::SubArray) = a !== b.parent +need_sync(a::SubArray, b::SubArray) = a.parent !== b.parent """ sync_vectors!(a, b) @@ -72,7 +72,7 @@ end # Shared internal helper: copy master u slice → leaf DEIntegrator u/uprev function forward_sync_internal!(u_source, child::DEIntegrator, solution_indices) @views usrc = u_source[solution_indices] - sync_vectors!(child.u, usrc) + sync_vectors!(child.u, usrc) sync_vectors!(child.uprev, child.u) SciMLBase.u_modified!(child, true) return nothing @@ -106,9 +106,9 @@ end # --------------------------------------------------------------------------- # NoExternalSynchronization: no-op for all parent/child combinations -forward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +forward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing backward_sync_external!(parent::DEIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing -forward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing +forward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing backward_sync_external!(parent::OperatorSplittingIntegrator, child::DEIntegrator, ::NoExternalSynchronization) = nothing # OperatorSplittingIntegrator parent with DEIntegrator child: parameter sync @@ -160,7 +160,7 @@ end validate_time_point(integrator::AnySplitIntegrator) = validate_time_point(integrator, integrator.child_subintegrators) function validate_time_point(parent, child::SplitSubIntegrator) @assert parent.t == child.t "(parent.t=$(parent.t) != child.t=$(child.t))" - validate_time_point(child, child.child_subintegrators) + return validate_time_point(child, child.child_subintegrators) end @unroll function validate_time_point(parent, children::Tuple) @@ -170,7 +170,7 @@ end end function validate_time_point(parent, child::DEIntegrator) - @assert child.t == parent.t "(parent.t=$(parent.t) != child.t=$(child.t))" + return @assert child.t == parent.t "(parent.t=$(parent.t) != child.t=$(child.t))" end # --------------------------------------------------------------------------- diff --git a/test/operator_splitting_api.jl b/test/operator_splitting_api.jl index f4d19e8..2f11a30 100644 --- a/test/operator_splitting_api.jl +++ b/test/operator_splitting_api.jl @@ -109,10 +109,10 @@ function OS.init_cache( end @inline DiffEqBase.get_tmp_cache( - integrator::OS.OperatorSplittingIntegrator, - alg::OS.AbstractOperatorSplittingAlgorithm, - cache::FakeAdaptiveAlgorithmCache - ) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) + integrator::OS.OperatorSplittingIntegrator, + alg::OS.AbstractOperatorSplittingAlgorithm, + cache::FakeAdaptiveAlgorithmCache +) = DiffEqBase.get_tmp_cache(integrator, alg, cache.cache) @inline function OS._perform_step!( outer_integrator, @@ -130,7 +130,7 @@ FakeAdaptiveLTG(inner) = FakeAdaptiveAlgorithm(LieTrotterGodunov(inner)) function Base.show(io::IO, alg::FakeAdaptiveAlgorithm) print(io, "FAKE (") Base.show(io, alg.alg) - print(io, ")") + return print(io, ")") end @@ -150,7 +150,7 @@ end prob1a = OperatorSplittingProblem(fsplit1a, u0, tspan) prob1b = OperatorSplittingProblem(fsplit1b, u0, tspan) - + # Note that we define the dof indices w.r.t the parent function. # Hence the indices for `fsplit2_inner` are. f3dofs = [1, 2]