Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit a24bff1

Browse files
Merge pull request #163 from tansongchen/main
Add Householder's method
2 parents 64aff41 + d382ad4 commit a24bff1

File tree

5 files changed

+113
-2
lines changed

5 files changed

+113
-2
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2626
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2727
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2828
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
29+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
2930

3031
[extensions]
3132
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
3233
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
3334
SimpleNonlinearSolveTrackerExt = "Tracker"
3435
SimpleNonlinearSolveZygoteExt = "Zygote"
36+
SimpleNonlinearSolveTaylorDiffExt = "TaylorDiff"
3537

3638
[compat]
3739
ADTypes = "1.9"
@@ -63,6 +65,7 @@ SciMLBase = "2.37.0"
6365
Setfield = "1.1.1"
6466
StaticArrays = "1.9"
6567
StaticArraysCore = "1.4.2"
68+
TaylorDiff = "0.2.5"
6669
Test = "1.10"
6770
Tracker = "0.2.33"
6871
Zygote = "0.6.69"
@@ -86,9 +89,10 @@ ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
8689
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
8790
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8891
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
92+
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
8993
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9094
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9195
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9296

9397
[targets]
94-
test = ["AllocCheck", "Aqua", "DiffEqBase", "ExplicitImports", "FiniteDiff", "ForwardDiff", "Hwloc", "InteractiveUtils", "LinearAlgebra", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StaticArrays", "Test", "Tracker", "Zygote"]
98+
test = ["AllocCheck", "Aqua", "DiffEqBase", "ExplicitImports", "FiniteDiff", "ForwardDiff", "Hwloc", "InteractiveUtils", "LinearAlgebra", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "Reexport", "ReverseDiff", "StaticArrays", "TaylorDiff", "Test", "Tracker", "Zygote"]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
module SimpleNonlinearSolveTaylorDiffExt
2+
using SimpleNonlinearSolve
3+
using SimpleNonlinearSolve: ImmutableNonlinearProblem, ReturnCode, build_solution,
4+
check_termination, init_termination_cache
5+
using SimpleNonlinearSolve: __maybe_unaliased, _get_fx, __fixed_parameter_function
6+
using MaybeInplace: @bb
7+
using SciMLBase: isinplace
8+
9+
import TaylorDiff
10+
11+
@inline function __get_higher_order_derivatives(
12+
::SimpleHouseholder{N}, prob, f, x, fx) where {N}
13+
vN = Val(N)
14+
l = map(one, x)
15+
t = TaylorDiff.make_seed(x, l, vN)
16+
17+
if isinplace(prob)
18+
bundle = similar(fx, TaylorDiff.TaylorScalar{eltype(fx), N})
19+
f(bundle, t)
20+
map!(TaylorDiff.primal, fx, bundle)
21+
else
22+
bundle = f(t)
23+
fx = map(TaylorDiff.primal, bundle)
24+
end
25+
bundle = inv.(bundle)
26+
num = TaylorDiff.extract_derivative(bundle, N - 1)
27+
den = TaylorDiff.extract_derivative(bundle, N)
28+
return num, den, fx
29+
end
30+
31+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHouseholder{N},
32+
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
33+
termination_condition = nothing, alias_u0 = false, kwargs...) where {N}
34+
x = __maybe_unaliased(prob.u0, alias_u0)
35+
length(x) == 1 ||
36+
throw(ArgumentError("SimpleHouseholder only supports scalar problems"))
37+
fx = _get_fx(prob, x)
38+
@bb xo = copy(x)
39+
f = __fixed_parameter_function(prob)
40+
41+
abstol, reltol, tc_cache = init_termination_cache(
42+
prob, abstol, reltol, fx, x, termination_condition)
43+
44+
for i in 1:maxiters
45+
num, den, fx = __get_higher_order_derivatives(alg, prob, f, x, fx)
46+
47+
if i == 1
48+
iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
49+
else
50+
# Termination Checks
51+
tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
52+
tc_sol !== nothing && return tc_sol
53+
end
54+
55+
@bb copyto!(xo, x)
56+
@bb x .+= (N - 1) .* num ./ den
57+
end
58+
59+
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
60+
end
61+
62+
end

src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ include("nlsolve/lbroyden.jl")
4747
include("nlsolve/klement.jl")
4848
include("nlsolve/trustRegion.jl")
4949
include("nlsolve/halley.jl")
50+
include("nlsolve/householder.jl")
5051
include("nlsolve/dfsane.jl")
5152

5253
## Interval Nonlinear Solvers
@@ -139,6 +140,7 @@ end
139140
export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
140141
export SimpleBroyden, SimpleDFSane, SimpleGaussNewton, SimpleHalley, SimpleKlement,
141142
SimpleLimitedMemoryBroyden, SimpleNewtonRaphson, SimpleTrustRegion
143+
export SimpleHouseholder
142144
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
143145

144146
end # module

src/nlsolve/householder.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
SimpleHouseholder{order}()
3+
4+
A low-overhead implementation of Householder's method to arbitrary order.
5+
This method is non-allocating on scalar and static array problems.
6+
7+
!!! warning
8+
9+
Needs `TaylorDiff.jl` to be explicitly loaded before using this functionality.
10+
Internally, this uses TaylorDiff.jl for automatic differentiation.
11+
12+
### Type Parameters
13+
14+
- `order`: the convergence order of the Householder method. `order = 2` is the same as Newton's method, `order = 3` is the same as Halley's method, etc.
15+
"""
16+
struct SimpleHouseholder{order} <: AbstractNewtonAlgorithm end

test/core/rootfind_tests.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testsetup module RootfindingTesting
22
using Reexport
3-
@reexport using AllocCheck, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase
3+
@reexport using AllocCheck, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase,
4+
TaylorDiff
45
import PolyesterForwardDiff
56

67
quadratic_f(u, p) = u .* u .- p
@@ -91,6 +92,32 @@ end
9192
end
9293
end
9394

95+
@testitem "SimpleHouseholder" setup=[RootfindingTesting] tags=[:core] begin
96+
@testset "AutoDiff: TaylorDiff.jl" for order in (2, 3, 4)
97+
@testset "[OOP] u0: $(nameof(typeof(u0)))" for u0 in ([1.0], @SVector[1.0], 1.0)
98+
sol = benchmark_nlsolve_oop(
99+
quadratic_f, u0; solver = SimpleHouseholder{order}())
100+
@test SciMLBase.successful_retcode(sol)
101+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
102+
end
103+
104+
@testset "[IIP] u0: $(nameof(typeof(u0)))" for u0 in ([1.0],)
105+
sol = benchmark_nlsolve_iip(
106+
quadratic_f!, u0; solver = SimpleHouseholder{order}())
107+
@test SciMLBase.successful_retcode(sol)
108+
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
109+
end
110+
end
111+
112+
@testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
113+
u0 in (1.0, [1.0], @SVector[1.0])
114+
115+
probN = NonlinearProblem(quadratic_f, u0, 2.0)
116+
@test all(solve(probN, SimpleHouseholder{2}(); termination_condition).u .≈
117+
sqrt(2.0))
118+
end
119+
end
120+
94121
@testitem "Derivative Free Metods" setup=[RootfindingTesting] tags=[:core] begin
95122
@testset "$(nameof(typeof(alg)))" for alg in [
96123
SimpleBroyden(), SimpleKlement(), SimpleDFSane(),

0 commit comments

Comments
 (0)