Skip to content

Commit ace82be

Browse files
Fix resize, algorithm-switch, and OOP W-caching for Jacobian reuse
Three robustness fixes for Rosenbrock-W Jacobian reuse: 1. Resize detection: Track last_u_length in JacReuseState. After a callback resize!, u_modified is cleared before perform_step! runs, so we check length(u) directly to detect dimension changes. 2. CompositeAlgorithm switch detection: Track last_step_iter to detect gaps in integrator.iter, which indicate another algorithm (e.g., Vern8) ran between Rosenbrock steps. Forces J recomputation on switch-back to avoid using a stale Jacobian. 3. OOP W caching: Honor the new_W=false flag by caching and reusing the factorized W (LU), matching IIP behavior. Previously the OOP path always rebuilt W from stale J, which masked inaccuracy and prevented the self-correcting rejection feedback loop. This fixes AutoVern8(Rosenbrock23()) OOP hitting MaxIters on stiff problems. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ef15360 commit ace82be

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,26 @@ function _rosenbrock_jac_reuse_decision(integrator, cache, dtgamma)
6363
return (true, true)
6464
end
6565

66+
# Detect algorithm switch in CompositeAlgorithm: if integrator.iter jumped
67+
# by more than 1 since our last Rosenbrock step, another algorithm ran in
68+
# between and the cached Jacobian is evaluated at a stale u.
69+
if jac_reuse.last_step_iter != 0 && integrator.iter > jac_reuse.last_step_iter + 1
70+
return (true, true)
71+
end
72+
6673
# Callback modification: recompute
6774
if integrator.u_modified
6875
return (true, true)
6976
end
7077

78+
# Resize detection: if u changed length since last J computation,
79+
# the cached LU factorization has wrong dimensions.
80+
# (u_modified is already cleared by reeval_internals_due_to_modification!
81+
# before perform_step! runs, so we need this explicit check.)
82+
if length(integrator.u) != jac_reuse.last_u_length && jac_reuse.last_u_length != 0
83+
return (true, true)
84+
end
85+
7186
# Previous step was rejected (EEst > 1): the old W wasn't good enough.
7287
# Recompute everything since we're retrying with a different dt anyway.
7388
if integrator.EEst > 1
@@ -792,8 +807,12 @@ function calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repe
792807
# _rosenbrock_jac_reuse_decision). This tracks the dtgamma at the
793808
# last J computation for the gamma ratio heuristic.
794809
jac_reuse = get_jac_reuse(cache)
795-
if jac_reuse !== nothing && new_jac
796-
jac_reuse.pending_dtgamma = dtgamma
810+
if jac_reuse !== nothing
811+
jac_reuse.last_step_iter = integrator.iter
812+
if new_jac
813+
jac_reuse.pending_dtgamma = dtgamma
814+
jac_reuse.last_u_length = length(integrator.u)
815+
end
797816
end
798817
end
799818
# If the Jacobian is not updated, we won't have to update ∂/∂t either.
@@ -829,6 +848,9 @@ function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step
829848

830849
new_jac, new_W = newJW
831850

851+
# Track iteration for algorithm-switch detection in CompositeAlgorithm
852+
jac_reuse.last_step_iter = integrator.iter
853+
832854
# For complex W types (operators), delegate to standard calc_W
833855
if cache.W isa StaticWOperator || cache.W isa WOperator ||
834856
cache.W isa AbstractSciMLOperator
@@ -841,10 +863,14 @@ function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step
841863
mass_matrix = integrator.f.mass_matrix
842864
update_coefficients!(mass_matrix, integrator.uprev, integrator.p, integrator.t)
843865

844-
# Safety: if cached_J is nothing (e.g. first use after algorithm switch),
866+
# Safety: if cached_J or cached_W is nothing (e.g. first use after algorithm switch),
845867
# force a fresh computation regardless of the decision.
846868
if !new_jac && jac_reuse.cached_J === nothing
847869
new_jac = true
870+
new_W = true
871+
end
872+
if !new_W && jac_reuse.cached_W === nothing
873+
new_W = true
848874
end
849875

850876
if new_jac
@@ -864,14 +890,23 @@ function calc_rosenbrock_differentiation(integrator, cache, dtgamma, repeat_step
864890
# committed as last_dtgamma on the next accepted step.
865891
if new_jac
866892
jac_reuse.pending_dtgamma = dtgamma
893+
jac_reuse.last_u_length = length(integrator.u)
867894
end
868895

869-
# Build W from (possibly cached) J
870-
W = J - mass_matrix * inv(dtgamma)
871-
if !isa(W, Number)
872-
W = DiffEqBase.default_factorize(W)
896+
# Build and cache W, or reuse cached W (including LU factorization).
897+
# Reusing old W (with old dtgamma) mirrors IIP behavior: if the old W
898+
# is too inaccurate, the step will be rejected and EEst > 1 triggers
899+
# a fresh J+W computation on the retry.
900+
if new_W
901+
W = J - mass_matrix * inv(dtgamma)
902+
if !isa(W, Number)
903+
W = DiffEqBase.default_factorize(W)
904+
end
905+
integrator.stats.nw += 1
906+
jac_reuse.cached_W = W
907+
else
908+
W = jac_reuse.cached_W
873909
end
874-
integrator.stats.nw += 1
875910

876911
return dT, W
877912
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Fields:
1616
- `cached_J`: Cached Jacobian for OOP reuse (type-erased for flexibility)
1717
- `cached_dT`: Cached time derivative for OOP reuse
1818
- `cached_W`: Cached factorized W for OOP reuse (LU factorization is expensive)
19+
- `last_u_length`: Length of u when J/W were last computed (detects resize!)
20+
- `last_step_iter`: `integrator.iter` at last Rosenbrock step (detects algorithm switches)
1921
"""
2022
mutable struct JacReuseState{T}
2123
last_dtgamma::T
@@ -25,10 +27,12 @@ mutable struct JacReuseState{T}
2527
cached_J::Any
2628
cached_dT::Any
2729
cached_W::Any
30+
last_u_length::Int
31+
last_step_iter::Int
2832
end
2933

3034
function JacReuseState(dtgamma::T) where {T}
31-
return JacReuseState{T}(dtgamma, dtgamma, 0, 50, nothing, nothing, nothing)
35+
return JacReuseState{T}(dtgamma, dtgamma, 0, 50, nothing, nothing, nothing, 0, 0)
3236
end
3337

3438
# Fake values since non-FSAL

lib/OrdinaryDiffEqRosenbrock/test/jacobian_reuse_test.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ strict_rosenbrock = [
124124
@test jr.cached_J === nothing
125125
@test jr.cached_dT === nothing
126126
@test jr.cached_W === nothing
127+
@test jr.last_u_length == 0
128+
@test jr.last_step_iter == 0
127129
end
128130

129131
# ========================================================================

0 commit comments

Comments
 (0)