diff --git a/Project.toml b/Project.toml index 5d9d184..9428bf3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.12.3" +version = "1.12.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -10,10 +10,12 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -37,7 +39,7 @@ SimpleNonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "1.9" -AllocCheck = "0.1.1" +AllocCheck = "0.1.1, 0.2" Aqua = "0.8" ArrayInterface = "7.9" ChainRulesCore = "1.23" @@ -45,6 +47,7 @@ ConcreteStructs = "0.2.3" DiffEqBase = "6.149" DiffResults = "1.1" DifferentiationInterface = "0.6.1" +EnumX = "1.0.4" ExplicitImports = "1.5.0" FastClosures = "0.3.2" FiniteDiff = "2.23.1" @@ -52,6 +55,7 @@ ForwardDiff = "0.10.36" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" +Markdown = "1.11.0" MaybeInplace = "0.1.3" NonlinearProblemLibrary = "0.1.2" Pkg = "1.10" @@ -65,7 +69,7 @@ SciMLBase = "2.37.0" Setfield = "1.1.1" StaticArrays = "1.9" StaticArraysCore = "1.4.2" -TaylorDiff = "0.2.5, 0.3" +TaylorDiff = "0.2.5" Test = "1.10" Tracker = "0.2.33" Zygote = "0.6.69" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index e6c939b..fa69cf5 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -6,17 +6,16 @@ using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff using ArrayInterface: ArrayInterface using ConcreteStructs: @concrete -using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, - AbstractSafeNonlinearTerminationMode, - AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode, - NONLINEARSOLVE_DEFAULT_NORM +using DiffEqBase: DiffEqBase using DifferentiationInterface: DifferentiationInterface using DiffResults: DiffResults using FastClosures: @closure using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff, Dual +using EnumX using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess, lu, mul!, norm, transpose +using Markdown using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex using Reexport: @reexport using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, @@ -36,6 +35,8 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end @inline __is_extension_loaded(::Val) = false +include("termination_conditions_deprecated.jl") +include("termination_conditions.jl") include("immutable_nonlinear_problem.jl") include("utils.jl") include("linesearch.jl") @@ -143,4 +144,13 @@ export SimpleBroyden, SimpleDFSane, SimpleGaussNewton, SimpleHalley, SimpleKleme export SimpleHouseholder export Alefeld, Bisection, Brent, Falsi, ITP, Ridder +export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode, + NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode, + AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode, + RelSafeBestTerminationMode, AbsSafeBestTerminationMode +# Deprecated API +export NLSolveTerminationMode, + NLSolveSafeTerminationOptions, NLSolveTerminationCondition, + NLSolveSafeTerminationResult + end # module diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl new file mode 100644 index 0000000..f933158 --- /dev/null +++ b/src/termination_conditions.jl @@ -0,0 +1,568 @@ +abstract type AbstractNonlinearTerminationMode end +abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end +abstract type AbstractSafeBestNonlinearTerminationMode <: + AbstractSafeNonlinearTerminationMode end + +""" + SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode + +Check if all values of the derivative is close to zero wrt both relative and absolute +tolerance. + +!!! danger + + This has been deprecated. +""" +struct SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode + function SteadyStateDiffEqTerminationMode() + Base.depwarn("`SteadyStateDiffEqTerminationMode` is deprecated and isn't used \ + in any upstream library. Remove uses of this.", + :SteadyStateDiffEqTerminationMode) + return new() + end +end + +""" + SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode + +Check if all values of the derivative is close to zero wrt both relative and absolute +tolerance. Or check that the value of the current and previous state is within the specified +tolerances. + +!!! danger + + This has been deprecated. +""" +struct SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode + function SimpleNonlinearSolveTerminationMode() + Base.depwarn("`SimpleNonlinearSolveTerminationMode` is deprecated and isn't used \ + in any upstream library. Remove uses of this.", + :SimpleNonlinearSolveTerminationMode) + return new() + end +end + +@inline set_termination_mode_internalnorm(mode, ::F) where {F} = mode + +@inline __norm_type(::typeof(Base.Fix2(norm, Inf))) = :Inf +@inline __norm_type(::typeof(Base.Fix1(maximum, abs))) = :Inf +@inline __norm_type(::typeof(Base.Fix2(norm, 2))) = :L2 +@inline __norm_type(::F) where {F} = F + +const TERM_DOCS = Dict( + :Norm => doc"``\| \Delta u \| \leq reltol \times \| \Delta u + u \|`` or ``\| \Delta u \| \leq abstol``.", + :Rel => doc"``all \left(| \Delta u | \leq reltol \times | u | \right)``.", + :RelNorm => doc"``\| \Delta u \| \leq reltol \times \| \Delta u + u \|``.", + :Abs => doc"``all \left( | \Delta u | \leq abstol \right)``.", + :AbsNorm => doc"``\| \Delta u \| \leq abstol``." +) + +const __TERM_INTERNALNORM_DOCS = """ +where `internalnorm` is the norm to use for the termination condition. Special handling is +done for `norm(_, 2)`, `norm(_, Inf)`, and `maximum(abs, _)`. + +Default is left as `nothing`, which allows upstream frameworks to choose the correct norm +based on the problem type. If directly using the `init` API, a proper norm must be +provided""" + +for name in (:Rel, :Abs) + struct_name = Symbol(name, :TerminationMode) + doctring = TERM_DOCS[name] + + @eval begin + """ + $($struct_name) <: AbstractNonlinearTerminationMode + + Terminates if $($doctring). + + ``\\Delta u`` denotes the increment computed by the nonlinear solver and ``u`` denotes the solution. + """ + struct $(struct_name) <: AbstractNonlinearTerminationMode end + end +end + +for name in (:Norm, :RelNorm, :AbsNorm) + struct_name = Symbol(name, :TerminationMode) + doctring = TERM_DOCS[name] + + @eval begin + """ + $($struct_name) <: AbstractNonlinearTerminationMode + + Terminates if $($doctring). + + ``\\Delta u`` denotes the increment computed by the inner nonlinear solver. + + ## Constructor + + $($struct_name)(internalnorm = nothing) + + $($__TERM_INTERNALNORM_DOCS). + """ + @concrete struct $(struct_name){F} <: AbstractNonlinearTerminationMode + internalnorm + + $(struct_name)(f::F = nothing) where {F} = new{__norm_type(f), F}(f) + end + + @inline function set_termination_mode_internalnorm( + ::$(struct_name), internalnorm::F) where {F} + return $(struct_name)(internalnorm) + end + end +end + +for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest) + struct_name = Symbol(norm_type, safety, :TerminationMode) + supertype_name = Symbol(:Abstract, safety, :NonlinearTerminationMode) + + doctring = safety == :Safe ? + "Essentially [`$(norm_type)NormTerminationMode`](@ref) + terminate if there \ + has been no improvement for the last `patience_steps` + terminate if the \ + solution blows up (diverges)." : + "Essentially [`$(norm_type)SafeTerminationMode`](@ref), but caches the best\ + solution found so far." + + @eval begin + """ + $($struct_name) <: $($supertype_name) + + $($doctring) + + ## Constructor + + $($struct_name)(internalnorm = nothing; protective_threshold = nothing, + patience_steps = 100, patience_objective_multiplier = 3, + min_max_factor = 1.3, max_stalled_steps = nothing) + + $($__TERM_INTERNALNORM_DOCS). + """ + @concrete struct $(struct_name){F, T <: Union{Nothing, Int}} <: $(supertype_name) + internalnorm + protective_threshold + patience_steps::Int + patience_objective_multiplier + min_max_factor + max_stalled_steps::T + + function $(struct_name)(f::F = nothing; protective_threshold = nothing, + patience_steps = 100, patience_objective_multiplier = 3, + min_max_factor = 1.3, max_stalled_steps = nothing) where {F} + return new{__norm_type(f), typeof(max_stalled_steps), F, + typeof(protective_threshold), typeof(patience_objective_multiplier), + typeof(min_max_factor)}(f, protective_threshold, patience_steps, + patience_objective_multiplier, min_max_factor, max_stalled_steps) + end + end + + @inline function set_termination_mode_internalnorm( + mode::$(struct_name), internalnorm::F) where {F} + return $(struct_name)(internalnorm; mode.protective_threshold, + mode.patience_steps, mode.patience_objective_multiplier, + mode.min_max_factor, mode.max_stalled_steps) + end + end +end + +@concrete mutable struct NonlinearTerminationModeCache{dep_retcode, + M <: AbstractNonlinearTerminationMode, + R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T}} + u + retcode::R + abstol + reltol + best_objective_value + mode::M + initial_objective + objectives_trace + nsteps::Int + saved_values + u0_norm + step_norm_trace + max_stalled_steps + u_diff_cache +end + +@inline get_termination_mode(cache::NonlinearTerminationModeCache) = cache.mode +@inline get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol +@inline get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol +@inline get_saved_values(cache::NonlinearTerminationModeCache) = cache.saved_values + +function __update_u!!(cache::NonlinearTerminationModeCache, u) + cache.u === nothing && return + if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u) + copyto!(cache.u, u) + else + cache.u = u + end +end + +@inline __cvt_real(::Type{T}, ::Nothing) where {T} = nothing +@inline __cvt_real(::Type{T}, x) where {T} = real(T(x)) + +@inline _get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η) +@inline function _get_tolerance(::Nothing, ::Type{T}) where {T} + η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) + return _get_tolerance(η, T) +end + +function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T}, + mode::AbstractNonlinearTerminationMode, saved_value_prototype...; + use_deprecated_retcodes::Val{D} = Val(true), # Remove in v8, warn in v7 + abstol = nothing, reltol = nothing, kwargs...) where {D, T <: Number} + abstol = _get_tolerance(abstol, T) + reltol = _get_tolerance(reltol, T) + TT = typeof(abstol) + u_ = mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + if mode isa AbstractSafeNonlinearTerminationMode + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + initial_objective = __apply_termination_internalnorm(mode.internalnorm, du) + u0_norm = nothing + else + initial_objective = __apply_termination_internalnorm(mode.internalnorm, du) / + (__add_and_norm(mode.internalnorm, du, u) + eps(TT)) + u0_norm = mode.max_stalled_steps === nothing ? nothing : norm(u, 2) + end + objectives_trace = Vector{TT}(undef, mode.patience_steps) + step_norm_trace = mode.max_stalled_steps === nothing ? nothing : + Vector{TT}(undef, mode.max_stalled_steps) + best_value = initial_objective + max_stalled_steps = mode.max_stalled_steps + if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && + step_norm_trace !== nothing + u_diff_cache = similar(u_) + else + u_diff_cache = u_ + end + else + initial_objective = nothing + objectives_trace = nothing + u0_norm = nothing + step_norm_trace = nothing + best_value = __cvt_real(T, Inf) + max_stalled_steps = nothing + u_diff_cache = u_ + end + + length(saved_value_prototype) == 0 && (saved_value_prototype = nothing) + + retcode = ifelse(D, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default) + + return NonlinearTerminationModeCache{D}(u_, retcode, abstol, reltol, best_value, mode, + initial_objective, objectives_trace, 0, saved_value_prototype, u0_norm, + step_norm_trace, max_stalled_steps, u_diff_cache) +end + +function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{dep_retcode}, du, + u, saved_value_prototype...; abstol = nothing, reltol = nothing, + kwargs...) where {dep_retcode} + T = eltype(cache.abstol) + length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) + + u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + cache.u = u_ + cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, + ReturnCode.Default) + + cache.abstol = _get_tolerance(abstol, T) + cache.reltol = _get_tolerance(reltol, T) + cache.nsteps = 0 + + mode = get_termination_mode(cache) + if mode isa AbstractSafeNonlinearTerminationMode + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + initial_objective = __apply_termination_internalnorm( + cache.mode.internalnorm, du) + else + initial_objective = __apply_termination_internalnorm( + cache.mode.internalnorm, du) / + (__add_and_norm(cache.mode.internalnorm, du, u) + eps(TT)) + cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2)) + end + best_value = initial_objective + else + initial_objective = nothing + best_value = __cvt_real(T, Inf) + end + cache.best_objective_value = best_value + cache.initial_objective = initial_objective + return cache +end + +# This dispatch is needed based on how Terminating Callback works! +# This intentially drops the `abstol` and `reltol` arguments +function (cache::NonlinearTerminationModeCache)(integrator::SciMLBase.AbstractODEIntegrator, + abstol::Number, reltol::Number, min_t) + retval = cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev) + (min_t === nothing || integrator.t ≥ min_t) && return retval + return false +end +function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...) + return cache(cache.mode, du, u, uprev, args...) +end + +function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du, + u, uprev, args...) + return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) +end + +function (cache::NonlinearTerminationModeCache{dep_retcode})( + mode::AbstractSafeNonlinearTerminationMode, + du, u, uprev, args...) where {dep_retcode} + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + objective = __apply_termination_internalnorm(mode.internalnorm, du) + criteria = cache.abstol + else + objective = __apply_termination_internalnorm(mode.internalnorm, du) / + (__add_and_norm(mode.internalnorm, du, u) + eps(cache.abstol)) + criteria = cache.reltol + end + + # Protective Break + if isinf(objective) || isnan(objective) + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable) + return true + end + ## By default we turn this off since it has the potential for false positives + if cache.mode.protective_threshold !== nothing && + (objective > cache.initial_objective * cache.mode.protective_threshold * length(du)) + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable) + return true + end + + # Check if best solution + if mode isa AbstractSafeBestNonlinearTerminationMode && + objective < cache.best_objective_value + cache.best_objective_value = objective + __update_u!!(cache, u) + if cache.saved_values !== nothing && length(args) ≥ 1 + cache.saved_values = args + end + end + + # Main Termination Condition + if objective ≤ criteria + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success) + return true + end + + # Terminate if there has been no improvement for the last `patience_steps` + cache.nsteps += 1 + cache.nsteps == 1 && (cache.initial_objective = objective) + cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective + + if objective ≤ cache.mode.patience_objective_multiplier * criteria + if cache.nsteps ≥ cache.mode.patience_steps + if cache.nsteps < length(cache.objectives_trace) + min_obj, max_obj = extrema(@view(cache.objectives_trace[1:(cache.nsteps)])) + else + min_obj, max_obj = extrema(cache.objectives_trace) + end + if min_obj < cache.mode.min_max_factor * max_obj + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.PatienceTermination, + ReturnCode.Stalled) + return true + end + end + end + + # Test for stalling if that is not disabled + if cache.step_norm_trace !== nothing + if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number) + @. cache.u_diff_cache = u - uprev + else + cache.u_diff_cache = u .- uprev + end + du_norm = norm(cache.u_diff_cache, 2) + cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm + if cache.nsteps ≥ cache.mode.max_stalled_steps + max_step_norm = maximum(cache.step_norm_trace) + if cache.mode isa AbsSafeTerminationMode || + cache.mode isa AbsSafeBestTerminationMode + stalled_step = max_step_norm ≤ cache.abstol + else + stalled_step = max_step_norm ≤ + cache.reltol * (max_step_norm + cache.u0_norm) + end + if stalled_step + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.PatienceTermination, + ReturnCode.Stalled) + return true + end + end + end + + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure) + return false +end + +# Check Convergence +function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, + reltol) + if __fast_scalar_indexing(duₙ, uₙ) + return all(@closure(xy->begin + x, y = xy + return (abs(x) ≤ abstol) | (abs(x) ≤ reltol * abs(y)) + end), + zip(duₙ, uₙ)) + else + return all(@. (abs(duₙ) ≤ abstol) | (abs(duₙ) ≤ reltol * abs(uₙ))) + end +end + +function check_convergence( + ::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + if __fast_scalar_indexing(duₙ, uₙ) + return all(@closure(xy->begin + x, y = xy + return (abs(x) ≤ abstol) | (abs(x) ≤ reltol * abs(y)) + end), + zip(duₙ, uₙ)) || + __nonlinearsolve_is_approx(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) + else + return all(@. (abs(duₙ) ≤ abstol) | (abs(duₙ) ≤ reltol * abs(uₙ))) || + __nonlinearsolve_is_approx(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) + end +end +function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + if __fast_scalar_indexing(duₙ, uₙ) + return all(@closure(xy->begin + x, y = xy + return abs(x) ≤ reltol * abs(y) + end), zip(duₙ, uₙ)) + else + return all(@. abs(duₙ) ≤ reltol * abs(uₙ + duₙ)) + end +end +function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return all(@closure(x->abs(x) ≤ abstol), duₙ) +end + +function check_convergence(mode::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + du_norm = __apply_termination_internalnorm(mode.internalnorm, duₙ) + return (du_norm ≤ abstol) || + (du_norm ≤ reltol * __add_and_norm(mode.internalnorm, duₙ, uₙ)) +end +function check_convergence( + mode::Union{ + RelNormTerminationMode, RelSafeTerminationMode, RelSafeBestTerminationMode}, + duₙ, uₙ, uₙ₋₁, abstol, reltol) + return __apply_termination_internalnorm(mode.internalnorm, duₙ) ≤ + reltol * __add_and_norm(mode.internalnorm, duₙ, uₙ) +end +function check_convergence( + mode::Union{AbsNormTerminationMode, AbsSafeTerminationMode, + AbsSafeBestTerminationMode}, + duₙ, uₙ, uₙ₋₁, abstol, reltol) + return __apply_termination_internalnorm(mode.internalnorm, duₙ) ≤ abstol +end + +@inline function __apply_termination_internalnorm(::Nothing, u) + return __apply_termination_internalnorm(Base.Fix1(maximum, abs), u) +end +@inline __apply_termination_internalnorm(f::F, u) where {F} = f(u) + +# Nonlinear Solve Norm (norm(_, 2)) +NONLINEARSOLVE_DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u) +function NONLINEARSOLVE_DEFAULT_NORM(f::F, + u::Union{AbstractFloat, Complex}) where {F} + return @fastmath abs(f(u)) +end + +function NONLINEARSOLVE_DEFAULT_NORM(u::Array{ + T}) where {T <: Union{AbstractFloat, Complex}} + x = zero(T) + @inbounds @fastmath for ui in u + x += abs2(ui) + end + return Base.FastMath.sqrt_fast(real(x)) +end + +function NONLINEARSOLVE_DEFAULT_NORM(f::F, + u::Union{Array{T}, Iterators.Zip{<:Tuple{Vararg{Array{T}}}}}) where { + F, T <: Union{AbstractFloat, Complex}} + x = zero(T) + @inbounds @fastmath for ui in u + x += abs2(f(ui)) + end + return Base.FastMath.sqrt_fast(real(x)) +end + +function NONLINEARSOLVE_DEFAULT_NORM(u::StaticArray{ + <:Tuple, T}) where {T <: Union{AbstractFloat, Complex}} + return Base.FastMath.sqrt_fast(real(sum(abs2, u))) +end + +function NONLINEARSOLVE_DEFAULT_NORM(f::F, + u::StaticArray{<:Tuple, T}) where { + F, T <: Union{AbstractFloat, Complex}} + return Base.FastMath.sqrt_fast(real(sum(abs2 ∘ f, u))) +end + +function NONLINEARSOLVE_DEFAULT_NORM(u::AbstractArray) + return Base.FastMath.sqrt_fast(UNITLESS_ABS2(u)) +end + +function NONLINEARSOLVE_DEFAULT_NORM(f::F, u::AbstractArray) where {F} + return Base.FastMath.sqrt_fast(UNITLESS_ABS2(f, u)) +end + +NONLINEARSOLVE_DEFAULT_NORM(u) = norm(u) +NONLINEARSOLVE_DEFAULT_NORM(f::F, u) where {F} = norm(f.(u)) + +@inline __fast_scalar_indexing(args...) = all(ArrayInterface.fast_scalar_indexing, args) + +@inline __maximum_abs(op::F, x, y) where {F} = __maximum(abs ∘ op, x, y) +## Nonallocating version of maximum(op.(x, y)) +@inline function __maximum(op::F, x, y) where {F} + if __fast_scalar_indexing(x, y) + return maximum(@closure((xᵢyᵢ)->begin + xᵢ, yᵢ = xᵢyᵢ + return op(xᵢ, yᵢ) + end), zip(x, y)) + else + return mapreduce(@closure((xᵢ, yᵢ)->op(xᵢ, yᵢ)), max, x, y) + end +end + +@inline function __norm_op(::typeof(Base.Fix2(norm, 2)), op::F, x, y) where {F} + if __fast_scalar_indexing(x, y) + return sqrt(sum(@closure((xᵢyᵢ)->begin + xᵢ, yᵢ = xᵢyᵢ + return op(xᵢ, yᵢ)^2 + end), zip(x, y))) + else + return sqrt(mapreduce(@closure((xᵢ, yᵢ)->(op(xᵢ, yᵢ)^2)), +, x, y)) + end +end + +@inline __norm_op(norm::N, op::F, x, y) where {N, F} = norm(op.(x, y)) + +function __nonlinearsolve_is_approx(x::Number, y::Number; atol = false, + rtol = atol > 0 ? false : sqrt(eps(promote_type(typeof(x), typeof(y))))) + return isapprox(x, y; atol, rtol) +end +function __nonlinearsolve_is_approx(x, y; atol = false, + rtol = atol > 0 ? false : sqrt(eps(promote_type(eltype(x), eltype(y))))) + length(x) != length(y) && return false + d = __maximum_abs(-, x, y) + return d ≤ max(atol, rtol * max(maximum(abs, x), maximum(abs, y))) +end + +@inline function __add_and_norm(::Nothing, x, y) + Base.depwarn("Not specifying the internal norm of termination conditions has been \ + deprecated. Using inf-norm currently.", + :__add_and_norm) + return __maximum_abs(+, x, y) +end +@inline __add_and_norm(::typeof(Base.Fix1(maximum, abs)), x, y) = __maximum_abs(+, x, y) +@inline __add_and_norm(::typeof(Base.Fix2(norm, Inf)), x, y) = __maximum_abs(+, x, y) +@inline __add_and_norm(f::F, x, y) where {F} = __norm_op(f, +, x, y) diff --git a/src/termination_conditions_deprecated.jl b/src/termination_conditions_deprecated.jl new file mode 100644 index 0000000..451b84c --- /dev/null +++ b/src/termination_conditions_deprecated.jl @@ -0,0 +1,307 @@ +""" + NonlinearSafeTerminationReturnCode + +Return Codes for the safe nonlinear termination conditions. + +These return codes have been deprecated. Termination Conditions will return +`SciMLBase.Retcode.T` starting from v7. +""" +@enumx NonlinearSafeTerminationReturnCode begin + """ + NonlinearSafeTerminationReturnCode.Success + + Termination Condition was satisfied! + """ + Success + """ + NonlinearSafeTerminationReturnCode.Default + + Default Return Code. Used for type stability and conveys no additional information! + """ + Default + """ + NonlinearSafeTerminationReturnCode.PatienceTermination + + Terminate if there has been no improvement for the last `patience_steps`. + """ + PatienceTermination + """ + NonlinearSafeTerminationReturnCode.ProtectiveTermination + + Terminate if the objective value increased by this factor wrt initial objective or the + value diverged. + """ + ProtectiveTermination + """ + NonlinearSafeTerminationReturnCode.Failure + + Termination Condition was not satisfied! + """ + Failure +end + +# NOTE: Deprecate the following API eventually. This API leads to quite a bit of type +# instability +@enumx NLSolveSafeTerminationReturnCode begin + Success + PatienceTermination + ProtectiveTermination + Failure +end + +# SteadyStateDefault and NLSolveDefault are needed to be compatible with the existing +# termination conditions in NonlinearSolve and SteadyStateDiffEq +@enumx NLSolveTerminationMode begin + SteadyStateDefault + NLSolveDefault + Norm + Rel + RelNorm + Abs + AbsNorm + RelSafe + RelSafeBest + AbsSafe + AbsSafeBest +end + +struct NLSolveSafeTerminationOptions{T1, T2, T3} + protective_threshold::T1 + patience_steps::Int + patience_objective_multiplier::T2 + min_max_factor::T3 +end + +mutable struct NLSolveSafeTerminationResult{T, uType} + u::uType + best_objective_value::T + best_objective_value_iteration::Int + return_code::NLSolveSafeTerminationReturnCode.T +end + +function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64, + best_objective_value_iteration = 0, + return_code = NLSolveSafeTerminationReturnCode.Failure) + u = u !== nothing ? copy(u) : u + Base.depwarn( + "NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", + :NLSolveSafeTerminationResult) + return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(u, + best_objective_value, best_objective_value_iteration, return_code) +end + +const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault, + NLSolveTerminationMode.NLSolveDefault, + NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel, + NLSolveTerminationMode.RelNorm, + NLSolveTerminationMode.Abs, NLSolveTerminationMode.AbsNorm) + +const SAFE_TERMINATION_MODES = (NLSolveTerminationMode.RelSafe, + NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafe, + NLSolveTerminationMode.AbsSafeBest) + +const SAFE_BEST_TERMINATION_MODES = (NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + +@doc doc""" + NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, + protective_threshold = 1e3, patience_steps::Int = 30, + patience_objective_multiplier = 3, min_max_factor = 1.3) + +Define the termination criteria for the NonlinearProblem or SteadyStateProblem. + +## Termination Conditions + +#### Termination on Absolute Tolerance + + * `NLSolveTerminationMode.Abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)`` + * `NLSolveTerminationMode.AbsNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol`` + * `NLSolveTerminationMode.AbsSafe`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) + * `NLSolveTerminationMode.AbsSafeBest`: Same as `NLSolveTerminationMode.AbsSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged + +#### Termination on Relative Tolerance + + * `NLSolveTerminationMode.Rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)`` + * `NLSolveTerminationMode.RelNorm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` + * `NLSolveTerminationMode.RelSafe`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) + * `NLSolveTerminationMode.RelSafeBest`: Same as `NLSolveTerminationMode.RelSafe` but uses the best solution found so far, i.e. deviates only if the solution has not converged + +#### Termination using both Absolute and Relative Tolerances + + * `NLSolveTerminationMode.Norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` or ``\| \frac{\partial u}{\partial t} \| \leq abstol`` + * `NLSolveTerminationMode.SteadyStateDefault`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems but doesn't scale well for neural networks. + * `NLSolveTerminationMode.NLSolveDefault`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. Or check that the value of the current and previous state is within the specified tolerances. This is usable for small problems but doesn't scale well for neural networks. + +## General Arguments + + * `abstol`: Absolute Tolerance + * `reltol`: Relative Tolerance + +## Arguments specific to `*Safe*` modes + + * `protective_threshold`: If the objective value increased by this factor wrt initial objective terminate immediately. + * `patience_steps`: If objective is within `patience_objective_multiplier` factor of the criteria and no improvement within `min_max_factor` has happened then terminate. + +!!! warning + This has been deprecated and will be removed in the next major release. Please use the new dispatch based termination conditions API. +""" +struct NLSolveTerminationCondition{mode, T, + S <: Union{<:NLSolveSafeTerminationOptions, Nothing}} + abstol::T + reltol::T + safe_termination_options::S +end + +function Base.show(io::IO, s::NLSolveTerminationCondition{mode}) where {mode} + print(io, + "NLSolveTerminationCondition(mode = $(mode), abstol = $(s.abstol), reltol = $(s.reltol)") + if mode ∈ SAFE_TERMINATION_MODES + print(io, ", safe_termination_options = ", s.safe_termination_options, ")") + else + print(io, ")") + end +end + +get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode} = mode + +# Don't specify `mode` since the defaults would depend on the package +function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, + protective_threshold = 1e3, patience_steps::Int = 30, + patience_objective_multiplier = 3, + min_max_factor = 1.3) where {T} + Base.depwarn( + "NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", + :NLSolveTerminationCondition) + @assert mode ∈ instances(NLSolveTerminationMode.T) + options = if mode ∈ SAFE_TERMINATION_MODES + NLSolveSafeTerminationOptions(protective_threshold, patience_steps, + patience_objective_multiplier, min_max_factor) + else + nothing + end + return NLSolveTerminationCondition{mode, T, typeof(options)}(abstol, reltol, options) +end + +function (cond::NLSolveTerminationCondition)(storage::Union{ + NLSolveSafeTerminationResult, + Nothing +}) + mode = get_termination_mode(cond) + # We need both the dispatches to support solvers that don't use the integrator + # interface like SimpleNonlinearSolve + if mode in BASIC_TERMINATION_MODES + function _termination_condition_closure_basic(integrator, abstol, reltol, min_t) + return _termination_condition_closure_basic(get_du(integrator), integrator.u, + integrator.uprev, abstol, reltol) + end + function _termination_condition_closure_basic(du, u, uprev, abstol, reltol) + return _has_converged(du, u, uprev, cond, abstol, reltol) + end + return _termination_condition_closure_basic + else + mode ∈ SAFE_BEST_TERMINATION_MODES && @assert storage !== nothing + nstep::Int = 0 + + function _termination_condition_closure_safe(integrator, abstol, reltol, min_t) + return _termination_condition_closure_safe(get_du(integrator), integrator.u, + integrator.uprev, abstol, reltol) + end + @inbounds function _termination_condition_closure_safe(du, u, uprev, abstol, reltol) + aType = typeof(abstol) + protective_threshold = aType(cond.safe_termination_options.protective_threshold) + objective_values = aType[] + patience_objective_multiplier = cond.safe_termination_options.patience_objective_multiplier + + if mode ∈ SAFE_BEST_TERMINATION_MODES + storage.best_objective_value = aType(Inf) + storage.best_objective_value_iteration = 0 + end + + if mode ∈ SAFE_BEST_TERMINATION_MODES + objective = NONLINEARSOLVE_DEFAULT_NORM(du) + criteria = abstol + else + objective = NONLINEARSOLVE_DEFAULT_NORM(du) / + (NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + eps(aType)) + criteria = reltol + end + + if mode ∈ SAFE_BEST_TERMINATION_MODES + if objective < storage.best_objective_value + storage.best_objective_value = objective + storage.best_objective_value_iteration = nstep + 1 + if storage.u !== nothing + storage.u .= u + end + end + end + + # Main Termination Criteria + if objective ≤ criteria + storage.return_code = NLSolveSafeTerminationReturnCode.Success + return true + end + + # Terminate if there has been no improvement for the last `patience_steps` + nstep += 1 + push!(objective_values, objective) + + if objective ≤ typeof(criteria)(patience_objective_multiplier) * criteria + if nstep ≥ cond.safe_termination_options.patience_steps + last_k_values = objective_values[max(1, + length(objective_values) - + cond.safe_termination_options.patience_steps):end] + if maximum(last_k_values) < + typeof(criteria)(cond.safe_termination_options.min_max_factor) * + minimum(last_k_values) + storage.return_code = NLSolveSafeTerminationReturnCode.PatienceTermination + return true + end + end + end + + # Protective break + if objective ≥ objective_values[1] * protective_threshold * length(du) + storage.return_code = NLSolveSafeTerminationReturnCode.ProtectiveTermination + return true + end + + storage.return_code = NLSolveSafeTerminationReturnCode.Failure + return false + end + return _termination_condition_closure_safe + end +end + +# Convergence Criteria +@inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition{mode}, + abstol = cond.abstol, reltol = cond.reltol) where {mode} + return _has_converged(du, u, uprev, mode, abstol, reltol) +end + +@inline @inbounds function _has_converged(du, u, uprev, mode, abstol, reltol) + if mode == NLSolveTerminationMode.Norm + du_norm = NONLINEARSOLVE_DEFAULT_NORM(du) + return du_norm ≤ abstol || du_norm ≤ reltol * NONLINEARSOLVE_DEFAULT_NORM(du + u) + elseif mode == NLSolveTerminationMode.Rel + return all(abs.(du) .≤ reltol .* abs.(u)) + elseif mode ∈ (NLSolveTerminationMode.RelNorm, NLSolveTerminationMode.RelSafe, + NLSolveTerminationMode.RelSafeBest) + return NONLINEARSOLVE_DEFAULT_NORM(du) ≤ + reltol * NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + elseif mode == NLSolveTerminationMode.Abs + return all(abs.(du) .≤ abstol) + elseif mode ∈ (NLSolveTerminationMode.AbsNorm, NLSolveTerminationMode.AbsSafe, + NLSolveTerminationMode.AbsSafeBest) + return NONLINEARSOLVE_DEFAULT_NORM(du) ≤ abstol + elseif mode == NLSolveTerminationMode.SteadyStateDefault + return all((abs.(du) .≤ abstol) .| (abs.(du) .≤ reltol .* abs.(u))) + elseif mode == NLSolveTerminationMode.NLSolveDefault + atol, rtol = abstol, reltol + return all((abs.(du) .≤ abstol) .| (abs.(du) .≤ reltol .* abs.(u))) || + isapprox(u, uprev; atol, rtol) + else + throw(ArgumentError("Unknown termination mode: $mode")) + end +end diff --git a/src/utils.jl b/src/utils.jl index 5e967fa..8a2e2e4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -165,17 +165,17 @@ function init_termination_cache( tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing internalnorm = ifelse( prob isa ImmutableNonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2)) - DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm) + set_termination_mode_internalnorm(tc, internalnorm) else tc end tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false)) - return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache + return get_abstol(tc_cache), get_reltol(tc_cache), tc_cache end function check_termination(tc_cache, fx, x, xo, prob, alg) return check_termination( - tc_cache, fx, x, xo, prob, alg, DiffEqBase.get_termination_mode(tc_cache)) + tc_cache, fx, x, xo, prob, alg, get_termination_mode(tc_cache)) end function check_termination( tc_cache, fx, x, xo, prob, alg, ::AbstractNonlinearTerminationMode) @@ -237,7 +237,7 @@ end @inline __reshape(x::AbstractArray, args...) = reshape(x, args...) # Override cases which might be used in a kernel launch -__get_tolerance(x, η, ::Type{T}) where {T} = DiffEqBase._get_tolerance(η, T) +__get_tolerance(x, η, ::Type{T}) where {T} = _get_tolerance(η, T) function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {T} η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8)) return T(η) diff --git a/test/core/23_test_problems_tests.jl b/test/core/23_test_problems_tests.jl index 6d19808..6fb3ea5 100644 --- a/test/core/23_test_problems_tests.jl +++ b/test/core/23_test_problems_tests.jl @@ -1,5 +1,6 @@ @testsetup module RobustnessTesting using LinearAlgebra, NonlinearProblemLibrary, DiffEqBase, Test +using SimpleNonlinearSolve problems = NonlinearProblemLibrary.problems dicts = NonlinearProblemLibrary.dicts diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 721d7fa..b78a7fa 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -3,6 +3,7 @@ using Reexport @reexport using AllocCheck, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase, TaylorDiff import PolyesterForwardDiff +using SimpleNonlinearSolve quadratic_f(u, p) = u .* u .- p quadratic_f!(du, u, p) = (du .= u .* u .- p)