Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand Down
1 change: 1 addition & 0 deletions src/OrdinaryDiffEqOperatorSplitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ include("solver.jl")
include("utils.jl")

export GenericSplitFunction, OperatorSplittingProblem, LieTrotterGodunov
export OperatorSplittingMinimalSolution

include("precompilation.jl")

Expand Down
104 changes: 98 additions & 6 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,47 @@ function subreinit!(
return DiffEqBase.reinit!(subintegrator, u0[solution_indices]; kwargs...)
end

# subreinit! for enhanced cache (subintegrator)
function subreinit!(
f,
u0,
solution_indices,
cache::AbstractOperatorSplittingCache;
t0 = nothing,
dt = nothing,
kwargs...
)
# Update cache state
@views u_sub = u0[solution_indices]
cache.u .= u_sub
cache.uprev .= u_sub

if t0 !== nothing
cache.t = t0
cache.tprev = t0
end

if dt !== nothing
cache.dt = dt
cache.dtcache = dt
end

cache.iter = 0
cache.sol.retcode = ReturnCode.Default

# Recursively reinit subintegrators
if cache isa LieTrotterGodunovCache && cache.subintegrator_tree isa Tuple
# Use manual iteration instead of @unroll for this nested call
i = 1
for subintegrator in cache.subintegrator_tree
subreinit!(get_operator(f, i), u0, cache.solution_index_tree[i], subintegrator; t0, dt, kwargs...)
i += 1
end
end

return nothing
end

@unroll function subreinit!(
f,
u0,
Expand Down Expand Up @@ -496,6 +537,22 @@ function check_error_subintegrators(integrator, subintegrator::DEIntegrator)
return SciMLBase.check_error(subintegrator)
end

# check_error for enhanced cache (subintegrator)
function check_error_subintegrators(integrator, cache::AbstractOperatorSplittingCache)
# Check the cache's own retcode
if !SciMLBase.successful_retcode(cache.sol.retcode) &&
cache.sol.retcode != ReturnCode.Default
return cache.sol.retcode
end

# Recursively check subintegrators if cache has them
if cache isa LieTrotterGodunovCache
return check_error_subintegrators(integrator, cache.subintegrator_tree)
end

return cache.sol.retcode
end

function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false)
return @timeit_debug "step!" begin
# OridinaryDiffEq lets dt be negative if tdir is -1, but that's inconsistent
Expand Down Expand Up @@ -682,6 +739,18 @@ function synchronize_subintegrator!(
end
end

# Synchronize enhanced cache (subintegrator)
function synchronize_subintegrator!(
cache::AbstractOperatorSplittingCache, integrator::OperatorSplittingIntegrator
)
(; t, dt) = integrator
@assert cache.t == t
# Update dt if needed
cache.dtcache = dt
cache.dt = dt
return nothing
end

function advance_solution_to!(
integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, tnext::Number
Expand Down Expand Up @@ -759,16 +828,39 @@ function build_subintegrator_tree_with_cache(

subintegrator_tree = first.(subintegrator_tree_with_caches)
inner_caches = last.(subintegrator_tree_with_caches)

# Build the trees for this level
solution_index_tree = ntuple(
i -> build_solution_index_tree_recursion(get_operator(f, i), f.solution_indices[i]),
length(f.functions)
)
synchronizer_tree = ntuple(
i -> build_synchronizer_tree_recursion(get_operator(f, i), f.synchronizers[i]),
length(f.functions)
)

# TODO fix mixed device type problems we have to be smarter
uprev = @view uprevouter[solution_indices]
u = @view uouter[solution_indices]
return subintegrator_tree,
init_cache(
f, alg;
uprev = uprev, u = u,
inner_caches = inner_caches
)

# Create the enhanced cache with all fields
cache = init_cache(
f, alg;
uprev = uprev, u = u,
inner_caches = inner_caches,
subintegrator_tree = subintegrator_tree,
solution_index_tree = solution_index_tree,
synchronizer_tree = synchronizer_tree,
t = t0,
dt = dt,
controller = controller
)

# Return (subintegrator, cache) tuple where both are the same enhanced cache.
# The first element is used as the subintegrator in the tree (the cache acts as a subintegrator),
# and the second element is the cache that will be stored in init_cache's inner_caches.
# This maintains the expected return signature while allowing the cache to function as a subintegrator.
return cache, cache
end

function build_subintegrator_tree_with_cache(
Expand Down
94 changes: 85 additions & 9 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,74 @@ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm
# transfer_algs::TransferTupleType # Tuple of transfer algorithms from the master solution into the individual ones
end

struct LieTrotterGodunovCache{uType, uprevType, iiType} <: AbstractOperatorSplittingCache
"""
OperatorSplittingMinimalSolution

Minimal solution struct for subintegrators that carries just a retcode field for failure communication.
"""
mutable struct OperatorSplittingMinimalSolution{R}
retcode::R
end

OperatorSplittingMinimalSolution() = OperatorSplittingMinimalSolution(ReturnCode.Default)

"""
LieTrotterGodunovCache

Enhanced cache for Lie-Trotter-Godunov splitting that can act as a subintegrator.
Contains fields for adaptive time stepping and nested problem handling.
"""
mutable struct LieTrotterGodunovCache{uType, uprevType, tType, dtType, solType, controllerType, EEstType, iiType, subintTreeType, solidxTreeType, syncTreeType, statsType} <: AbstractOperatorSplittingCache
# Solution state
u::uType
uprev::uprevType

# Time stepping state
t::tType
tprev::tType
dt::dtType
dtcache::dtType

# Minimal solution for retcode communication
sol::solType

# Adaptive stepping fields
controller::controllerType
EEst::EEstType
iter::Int
stats::statsType

# Inner caches and subintegrator trees
inner_caches::iiType
subintegrator_tree::subintTreeType
solution_index_tree::solidxTreeType
synchronizer_tree::syncTreeType
end

function init_cache(
f::GenericSplitFunction, alg::LieTrotterGodunov;
uprev::AbstractArray, u::AbstractVector,
inner_caches,
subintegrator_tree = inner_caches,
solution_index_tree = ntuple(i -> nothing, length(inner_caches)),
synchronizer_tree = ntuple(i -> NoExternalSynchronization(), length(inner_caches)),
t = zero(eltype(u)),
dt = zero(eltype(u)),
controller = nothing,
alias_uprev = true,
alias_u = false
)
_uprev = alias_uprev ? uprev : RecursiveArrayTools.recursivecopy(uprev)
_u = alias_u ? u : RecursiveArrayTools.recursivecopy(u)
return LieTrotterGodunovCache(_u, _uprev, inner_caches)
tType = typeof(t)
sol = OperatorSplittingMinimalSolution()
return LieTrotterGodunovCache(
_u, _uprev,
t, copy(t), dt, copy(dt),
sol,
controller, zero(eltype(u)), 0, IntegratorStats(),
inner_caches, subintegrator_tree, solution_index_tree, synchronizer_tree
)
end

@inline @unroll function advance_solution_to!(
Expand All @@ -33,24 +85,48 @@ end
synchronizers::Tuple, cache::LieTrotterGodunovCache, tnext
)
# We assume that the integrators are already synced
(; inner_caches) = cache
(; inner_caches, subintegrator_tree, solution_index_tree, synchronizer_tree) = cache

# Update cache's own time state
cache.tprev = cache.t
cache.t = tnext

# Reset sol.retcode to default before attempting the step
cache.sol.retcode = ReturnCode.Default

# For each inner operator
i = 0
@unroll for subinteg in subintegrators
i += 1
synchronizer = synchronizers[i]
idxs = solution_indices[i]
cache = inner_caches[i]
inner_cache = inner_caches[i]

@timeit_debug "sync ->" forward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer)
@timeit_debug "time solve" advance_solution_to!(
outer_integrator, subinteg, idxs, synchronizer, cache, tnext
outer_integrator, subinteg, idxs, synchronizer, inner_cache, tnext
)
if !(subinteg isa Tuple) &&
subinteg.sol.retcode ∉
(ReturnCode.Default, ReturnCode.Success)
return

# Check return code and propagate failure
# Check for enhanced cache first, then DEIntegrator
if subinteg isa AbstractOperatorSplittingCache
# For enhanced cache acting as subintegrator
if subinteg.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success)
cache.sol.retcode = subinteg.sol.retcode
return
end
elseif !(subinteg isa Tuple)
# For single DEIntegrator
if subinteg.sol.retcode ∉ (ReturnCode.Default, ReturnCode.Success)
cache.sol.retcode = subinteg.sol.retcode
return
end
end

backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer)
end

# If we got here, mark success
cache.sol.retcode = ReturnCode.Success
cache.iter += 1
end
53 changes: 53 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ function forward_sync_subintegrator!(
return forward_sync_external!(outer_integrator, inner_integrator, sync)
end

# Forward sync for enhanced cache (subintegrator)
function forward_sync_subintegrator!(
outer_integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, solution_indices, sync
)
forward_sync_internal!(outer_integrator, cache, solution_indices)
return forward_sync_external!(outer_integrator, cache, sync)
end

"""
backward_sync_subintegrator!(outer_integrator::OperatorSplittingIntegrator, inner_integrator::DEIntegrator, solution_indices, sync)

Expand All @@ -79,6 +88,15 @@ function backward_sync_subintegrator!(
return backward_sync_external!(outer_integrator, inner_integrator, sync)
end

# Backward sync for enhanced cache (subintegrator)
function backward_sync_subintegrator!(
outer_integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, solution_indices, sync
)
backward_sync_internal!(outer_integrator, cache, solution_indices)
return backward_sync_external!(outer_integrator, cache, sync)
end

# This is a bit tricky, because per default the operator splitting integrators share their solution vector. However, there is also the case
# when part of the problem is on a different device (thing e.g. about operator A being on CPU and B being on GPU).
# This case should be handled with special synchronizers.
Expand All @@ -95,6 +113,29 @@ function backward_sync_internal!(
return nothing
end

# Forward sync for enhanced cache (subintegrator)
function forward_sync_internal!(
outer_integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, solution_indices
)
@views uouter = outer_integrator.u[solution_indices]
sync_vectors!(cache.uprev, uouter)
sync_vectors!(cache.u, uouter)
# Update time state
cache.t = outer_integrator.t
cache.dt = outer_integrator.dt
return nothing
end

# Backward sync for enhanced cache (subintegrator)
function backward_sync_internal!(
outer_integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, solution_indices
)
@views uouter = outer_integrator.u[solution_indices]
return sync_vectors!(uouter, cache.u)
end

function forward_sync_internal!(
outer_integrator::OperatorSplittingIntegrator,
inner_integrator::DEIntegrator, solution_indices
Expand Down Expand Up @@ -125,6 +166,12 @@ function forward_sync_external!(
)
return nothing
end
function forward_sync_external!(
outer_integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, sync::NoExternalSynchronization
)
return nothing
end
function forward_sync_external!(
outer_integrator::OperatorSplittingIntegrator,
inner_integrator::DEIntegrator, sync
Expand All @@ -144,6 +191,12 @@ function backward_sync_external!(
)
return nothing
end
function backward_sync_external!(
outer_integrator::OperatorSplittingIntegrator,
cache::AbstractOperatorSplittingCache, sync::NoExternalSynchronization
)
return nothing
end
function backward_sync_external!(
outer_integrator::OperatorSplittingIntegrator,
inner_integrator::DEIntegrator, sync
Expand Down
Loading
Loading