Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 89e640c

Browse files
final fixes
1 parent 5ef1936 commit 89e640c

File tree

4 files changed

+53
-2
lines changed

4 files changed

+53
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ SciMLBase = "2.37.0"
6969
Setfield = "1.1.1"
7070
StaticArrays = "1.9"
7171
StaticArraysCore = "1.4.2"
72-
TaylorDiff = "0.2.5, 0.3"
72+
TaylorDiff = "0.2.5"
7373
Test = "1.10"
7474
Tracker = "0.2.33"
7575
Zygote = "0.6.69"

src/termination_conditions.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ end
293293

294294
# This dispatch is needed based on how Terminating Callback works!
295295
# This intentially drops the `abstol` and `reltol` arguments
296-
function (cache::NonlinearTerminationModeCache)(integrator::DiffEqBase.AbstractODEIntegrator,
296+
function (cache::NonlinearTerminationModeCache)(integrator::SciMLBase.AbstractODEIntegrator,
297297
abstol::Number, reltol::Number, min_t)
298298
retval = cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev)
299299
(min_t === nothing || integrator.t min_t) && return retval
@@ -517,3 +517,52 @@ end
517517

518518
NONLINEARSOLVE_DEFAULT_NORM(u) = norm(u)
519519
NONLINEARSOLVE_DEFAULT_NORM(f::F, u) where {F} = norm(f.(u))
520+
521+
@inline __fast_scalar_indexing(args...) = all(ArrayInterface.fast_scalar_indexing, args)
522+
523+
@inline __maximum_abs(op::F, x, y) where {F} = __maximum(abs op, x, y)
524+
## Nonallocating version of maximum(op.(x, y))
525+
@inline function __maximum(op::F, x, y) where {F}
526+
if __fast_scalar_indexing(x, y)
527+
return maximum(@closure((xᵢyᵢ)->begin
528+
xᵢ, yᵢ = xᵢyᵢ
529+
return op(xᵢ, yᵢ)
530+
end), zip(x, y))
531+
else
532+
return mapreduce(@closure((xᵢ, yᵢ)->op(xᵢ, yᵢ)), max, x, y)
533+
end
534+
end
535+
536+
@inline function __norm_op(::typeof(Base.Fix2(norm, 2)), op::F, x, y) where {F}
537+
if __fast_scalar_indexing(x, y)
538+
return sqrt(sum(@closure((xᵢyᵢ)->begin
539+
xᵢ, yᵢ = xᵢyᵢ
540+
return op(xᵢ, yᵢ)^2
541+
end), zip(x, y)))
542+
else
543+
return sqrt(mapreduce(@closure((xᵢ, yᵢ)->(op(xᵢ, yᵢ)^2)), +, x, y))
544+
end
545+
end
546+
547+
@inline __norm_op(norm::N, op::F, x, y) where {N, F} = norm(op.(x, y))
548+
549+
function __nonlinearsolve_is_approx(x::Number, y::Number; atol = false,
550+
rtol = atol > 0 ? false : sqrt(eps(promote_type(typeof(x), typeof(y)))))
551+
return isapprox(x, y; atol, rtol)
552+
end
553+
function __nonlinearsolve_is_approx(x, y; atol = false,
554+
rtol = atol > 0 ? false : sqrt(eps(promote_type(eltype(x), eltype(y)))))
555+
length(x) != length(y) && return false
556+
d = __maximum_abs(-, x, y)
557+
return d max(atol, rtol * max(maximum(abs, x), maximum(abs, y)))
558+
end
559+
560+
@inline function __add_and_norm(::Nothing, x, y)
561+
Base.depwarn("Not specifying the internal norm of termination conditions has been \
562+
deprecated. Using inf-norm currently.",
563+
:__add_and_norm)
564+
return __maximum_abs(+, x, y)
565+
end
566+
@inline __add_and_norm(::typeof(Base.Fix1(maximum, abs)), x, y) = __maximum_abs(+, x, y)
567+
@inline __add_and_norm(::typeof(Base.Fix2(norm, Inf)), x, y) = __maximum_abs(+, x, y)
568+
@inline __add_and_norm(f::F, x, y) where {F} = __norm_op(f, +, x, y)

test/core/23_test_problems_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testsetup module RobustnessTesting
22
using LinearAlgebra, NonlinearProblemLibrary, DiffEqBase, Test
3+
using SimpleNonlinearSolve
34

45
problems = NonlinearProblemLibrary.problems
56
dicts = NonlinearProblemLibrary.dicts

test/core/rootfind_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Reexport
33
@reexport using AllocCheck, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase,
44
TaylorDiff
55
import PolyesterForwardDiff
6+
using SimpleNonlinearSolve
67

78
quadratic_f(u, p) = u .* u .- p
89
quadratic_f!(du, u, p) = (du .= u .* u .- p)

0 commit comments

Comments
 (0)