Skip to content

Commit d4194e6

Browse files
vchuravy, ranocha, and claude
authored
Simple Linesearch (#1)
* simple Backtracking line search after SIAMFANL * Update src/linesearches.jl Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com> * cleanup * rename n_res to norm_res * enable access to J and make backtracking line search fully in place * make alpha configurable * export linesearch and make logic cleaner * add tests * cleanup example * bump version * Remove Pluto notebooks from docs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * include Rosenbrock in docs * ensure neg_res is the same type as res * Revert "Remove Pluto notebooks from docs" This reverts commit 3929f34. --------- Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8807ee9 commit d4194e6

8 files changed

Lines changed: 252 additions & 39 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Ariadne"
22
uuid = "0be81120-40bf-4f8b-adf0-26103efb66f1"
33
authors = ["Valentin Churavy <v.churavy@gmail.com>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ const EXAMPLES_DIR = joinpath(@__DIR__, "..", "examples")
4040
const OUTPUT_DIR = joinpath(@__DIR__, "src/generated")
4141

4242
examples = [
43+
"Rosenbrock" => "rosenbrock",
4344
"Bratu -- 1D" => "bratu",
4445
"Bratu -- KernelAbstractions" => "bratu_ka",
4546
"Simple" => "simple",

docs/src/index.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ newton_krylov!
99
newton_krylov
1010
```
1111

12+
### Line Searches
13+
14+
```@docs
15+
Ariadne.LineSearches.AbstractLineSearch
16+
NoLineSearch
17+
BacktrackingLineSearch
18+
```
19+
1220
### Parameters
1321

1422
```@docs

examples/rosenbrock.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# # Generalized Rosenbrock
2+
3+
# This example is taken from Fig. 1 of:
4+
# > A. Pal et al., "NonlinearSolve.jl: High-performance and robust solvers for systems
5+
# > of nonlinear equations in Julia," arXiv [math.NA], 24-Mar-2024.
6+
# > https://arxiv.org/abs/2403.16341
7+
8+
# ## Packages
9+
10+
using Ariadne
11+
12+
# ## Problem definition
13+
14+
# The generalized Rosenbrock function in $N$ dimensions:
15+
# ```math
16+
# F(x)_1 = 1 - x_1, \quad F(x)_i = 10(x_i - x_{i-1}^2), \quad i = 2, \ldots, N
17+
# ```
18+
19+
# Residual of the generalized Rosenbrock function:
# F₁ = 1 - x₁ and Fᵢ = 10 (xᵢ - xᵢ₋₁²) for i = 2, …, N.
# The second positional argument is the (unused) parameter slot `p`
# expected by the solver interface.
function generalized_rosenbrock(x, _)
    return vcat(
        1 - x[1],
        # `.^ 2` slices `x[1:(end - 1)]` once instead of materializing
        # the same slice twice as `x[1:(end - 1)] .* x[1:(end - 1)]`.
        10 .* (x[2:end] .- x[1:(end - 1)] .^ 2)
    )
end
25+
26+
# The standard starting point is $x_1 = -1.2$, $x_i = 1$ for $i \geq 2$.

N = 12
x_start = [-1.2; ones(N - 1)]

# ## Without line search

# Solve with GMRES and no line search (`NoLineSearch`).
# The number of iterations required grows quickly with $N$, and the solver
# fails to converge for $N \geq 9$ within the iteration budget.

_, stats = newton_krylov(
    generalized_rosenbrock,
    copy(x_start);
    algo = :gmres,
    linesearch! = NoLineSearch(),
    max_niter = 100_000
)
stats

# ## With backtracking line search

# Using `BacktrackingLineSearch` stabilizes convergence for larger $N$.
# Pal et al. report that their backtracking implementation does not converge for $N = 10$
# (using `abstol = 1e-8`); with `abstol = 1e-12` our implementation converges for all
# $N \leq 12$.

_, stats = newton_krylov(
    generalized_rosenbrock,
    copy(x_start);
    algo = :gmres,
    linesearch! = BacktrackingLineSearch(),
    max_niter = 100_000
)
stats

examples/simple.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,23 @@ fig, ax = contour(xs, ys, (x, y) -> norm(F([x, y], nothing)); levels)
2424

2525
trace_1 = let x₀ = [2.0, 0.5]
    # Record the (x, y) iterates visited by the solver via the callback.
    path = Tuple{Float64, Float64}[]
    record(x, res, norm_res) = (push!(path, (x[1], x[2])); nothing)
    newton_krylov!(F!, x₀, nothing, callback = record)
    path
end
lines!(ax, trace_1)
3232

3333
trace_2 = let x₀ = [2.5, 3.0]
    # Record the (x, y) iterates visited by the solver via the callback.
    path = Tuple{Float64, Float64}[]
    record(x, res, norm_res) = (push!(path, (x[1], x[2])); nothing)
    newton_krylov!(F!, x₀, nothing, callback = record)
    path
end
lines!(ax, trace_2)
4040

4141
trace_3 = let x₀ = [3.0, 4.0]
4242
xs = Vector{Tuple{Float64, Float64}}(undef, 0)
43-
hist(x, res, n_res) = (push!(xs, (x[1], x[2])); nothing)
43+
hist(x, res, norm_res) = (push!(xs, (x[1], x[2])); nothing)
4444
x, stats = newton_krylov!(F!, x₀, nothing, callback = hist, forcing = Ariadne.EisenstatWalker(η_max = 0.68949), verbose = 1)
4545
@show stats.solved
4646
xs

src/Ariadne.jl

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,14 @@ function Base.collect(JOp::Union{Adjoint{<:Any, <:AbstractJacobianOperator}, Tra
227227
return J
228228
end
229229

230+
##
231+
# LineSearches
232+
##
233+
234+
include("linesearches.jl")
235+
import .LineSearches: AbstractLineSearch, NoLineSearch, BacktrackingLineSearch
236+
export NoLineSearch, BacktrackingLineSearch
237+
230238
##
231239
# Newton-Krylov
232240
##
@@ -270,15 +278,15 @@ end
270278
"""
271279
Compute the Eisenstat-Walker forcing term for n > 0
272280
"""
273-
function (F::EisenstatWalker)(η, tol, norm_res, norm_res_prior)
    # Candidate forcing term from the ratio of successive residual norms.
    η_res = F.γ * norm_res^2 / norm_res_prior^2
    # Eq 3.6: safeguard using the previous forcing term; only applied when
    # the previous η was not already small.
    η_prev = F.γ * η^2
    η_safe = η_prev <= 1 // 10 ? min(F.η_max, η_res) : min(F.η_max, max(η_res, η_prev))
    # Eq 3.5: bound η from below relative to the tolerance and cap at η_max.
    return min(F.η_max, max(η_safe, 1 // 2 * tol / norm_res))
end
283291
initial(F::EisenstatWalker) = F.η_max
284292

@@ -331,13 +339,13 @@ end
331339
# Solver progress counters: outer Newton iterations, accumulated inner
# (Krylov) iterations, and the most recent residual norm.
struct Stats
    outer_iterations::Int
    inner_iterations::Int
    norm_res::Float64
end

# Return a new `Stats` with one more outer iteration, `inner_iterations`
# added to the inner count, and the residual norm replaced.
function update(stats::Stats, inner_iterations, norm_res::Float64)
    outer = stats.outer_iterations + 1
    inner = stats.inner_iterations + inner_iterations
    return Stats(outer, inner, norm_res)
end
343351

@@ -357,6 +365,7 @@ function newton_krylov!(
357365
tol_abs = 1.0e-12, # Scipy uses 6e-6
358366
max_niter = 50,
359367
forcing::Union{Forcing, Nothing} = EisenstatWalker(),
368+
linesearch!::AbstractLineSearch = NoLineSearch(),
360369
verbose = 0,
361370
algo = :gmres,
362371
M = nothing,
@@ -366,25 +375,25 @@ function newton_krylov!(
366375
)
367376
t₀ = time_ns()
368377
F!(res, u, p) # res = F(u)
369-
n_res = norm(res)
370-
callback(u, res, n_res)
378+
norm_res = norm(res)
379+
callback(u, res, norm_res)
371380

372-
tol = tol_rel * n_res + tol_abs
381+
tol = tol_rel * norm_res + tol_abs
373382

374383
if forcing !== nothing
375384
η = initial(forcing)
376385
end
377386

378-
verbose > 0 && @info "Jacobian-Free Newton-Krylov" algo res₀ = n_res tol tol_rel tol_abs η
387+
verbose > 0 && @info "Jacobian-Free Newton-Krylov" algo res₀ = norm_res tol tol_rel tol_abs η
379388

380389
J = JacobianOperator(F!, res, u, p)
381390

382391
# TODO: Refactor to provide method that re-uses the cache here.
383392
kc = KrylovConstructor(res)
384393
workspace = krylov_workspace(algo, kc)
385394

386-
stats = Stats(0, 0, n_res)
387-
while n_res > tol && stats.outer_iterations <= max_niter
395+
stats = Stats(0, 0, norm_res)
396+
while norm_res > tol && stats.outer_iterations <= max_niter
388397
# Handle kwargs for Preconditioners
389398
kwargs = krylov_kwargs
390399
if N !== nothing
@@ -405,47 +414,41 @@ function newton_krylov!(
405414
kwargs = (; atol = zero(η), rtol = η, kwargs...)
406415
end
407416

408-
# Solve: J d = res = F(u)
409-
# Typically, the Newton method is formulated as J d = -F(u)
410-
# with update u = u + d.
411-
# To simplify the implementation, we solve J d = F(u)
412-
# and update u = u - d instead.
413-
# `res` is modified by J, so we create a copy `res`
414-
# TODO: provide a temporary storage for `res`
415-
krylov_solve!(workspace, J, copy(res); kwargs...)
416-
417-
d = workspace.x # (negative) Newton direction
418-
s = 1 # Scaling of the Newton step TODO: LineSearch
417+
# Solve: J d = -res = -F(u)
418+
# The Newton method is formulated as J d = -F(u)
419+
# `res` is modified by J, so we create a `neg_res` copy here.
420+
# TODO: provide cache for `neg_res` to avoid this allocation.
421+
neg_res = similar(res)
422+
@. neg_res = -res
423+
krylov_solve!(workspace, J, neg_res; kwargs...)
419424

420-
# Update u
421-
u .= muladd.(-s, d, u) # u = u - s * d
425+
d₀ = workspace.x # (negative) Newton direction
422426

423-
# Update residual and norm
424-
n_res_prior = n_res
427+
# Perform line search to find an appropriate step size and update `u` and `res` in-place
428+
norm_res_prior = norm_res
429+
norm_res = linesearch!(J, F!, res, norm_res_prior, u, p, d₀)
425430

426-
F!(res, u, p) # res = F(u)
427-
n_res = norm(res)
428-
callback(u, res, n_res)
431+
callback(u, res, norm_res)
429432

430-
if isinf(n_res) || isnan(n_res)
433+
if isinf(norm_res) || isnan(norm_res)
431434
@error "Inner solver blew up" stats
432435
break
433436
end
434437

435438
if forcing !== nothing
436-
η = forcing(η, tol, n_res, n_res_prior)
439+
η = forcing(η, tol, norm_res, norm_res_prior)
437440
end
438441

439442
# This is almost to be expected for implicit time-stepping
440443
if verbose > 0 && workspace.stats.niter == 0 && forcing !== nothing
441444
@info "Inexact Newton thinks our step is good enough " η stats
442445
end
443446

444-
stats = update(stats, workspace.stats.niter, n_res)
445-
verbose > 0 && @info "Newton" iter = n_res η stats
447+
stats = update(stats, workspace.stats.niter, norm_res)
448+
verbose > 0 && @info "Newton" iter = norm_res η stats
446449
end
447450
t = (time_ns() - t₀) / 1.0e9
448-
return u, (; solved = n_res <= tol, stats, t)
451+
return u, (; solved = norm_res <= tol, stats, t)
449452
end
450453

451454
end # module Ariadne

src/linesearches.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
module LineSearches

using LinearAlgebra

"""
    AbstractLineSearch

Abstract supertype for line searches used by the Newton–Krylov solver.

A line search may update the solution `u` and the residual `res` in-place,
given the function `F!`, parameters `p`, and the Newton direction `d`.

Implementations must call `F!(res, u, p)` to update the residual after
updating `u`, and must return the norm of the updated residual.

```julia
struct NewLineSearch <: AbstractLineSearch
    # parameters for the line search
end

function (ls::NewLineSearch)(J::AbstractJacobianOperator, F!, res, norm_res_prior, u, p, d)
    # perform line search to find an appropriate step size
    # ...
    # update u and res in-place
    F!(res, u, p)
    return norm(res)
end
```
"""
abstract type AbstractLineSearch end

"""
    NoLineSearch()

A line search that does not perform any line search: it simply takes the full Newton step.
"""
struct NoLineSearch <: AbstractLineSearch end

function (::NoLineSearch)(J, F!, res, norm_res_prior, u, p, d)
    # No line search: take the full Newton step
    u .+= d
    F!(res, u, p)
    return norm(res)
end

"""
    BacktrackingLineSearch(; n_iter_max = 10, alpha = 1.0e-4)

Backtracking (Armijo) line search: start from the full Newton step and halve
the step length `lambda` until the residual norm satisfies the
sufficient-decrease condition
`norm(res) <= (1 - alpha * lambda) * norm_res_prior`,
using at most `n_iter_max` residual evaluations.

## References

- Kelley, C. T. (2022).
  Solving nonlinear equations with iterative methods:
  Solvers and examples in Julia.
  Society for Industrial and Applied Mathematics.
- <https://github.com/ctkelley/SIAMFANLEquations.jl>
"""
Base.@kwdef struct BacktrackingLineSearch <: AbstractLineSearch
    n_iter_max::Int = 10
    alpha::Float64 = 1.0e-4
end

function (ls::BacktrackingLineSearch)(J, F!, res, norm_res_prior, u, p, d)
    alpha = ls.alpha
    lambda = 1.0

    # Validate parameters with proper exceptions: `@assert` may be disabled
    # at higher optimization levels and must not guard user-supplied input.
    ls.n_iter_max > 0 ||
        throw(ArgumentError("n_iter_max must be positive, got $(ls.n_iter_max)"))
    alpha > 0 || throw(ArgumentError("alpha must be positive, got $alpha"))

    # Take the full Newton step (lambda = 1.0)
    u .= muladd.(lambda, d, u) # u = u + lambda * d
    F!(res, u, p)
    norm_res = norm(res)

    for _ in 2:ls.n_iter_max
        # Armijo sufficient-decrease condition
        if norm_res <= (1 - alpha * lambda) * norm_res_prior
            return norm_res
        end

        # Halve lambda and retract the excess step incrementally:
        # u goes from u + old_lambda*d to u + new_lambda*d,
        # so the adjustment is (new_lambda - old_lambda)*d (negative).
        new_lambda = lambda * 0.5
        s = new_lambda - lambda
        u .= muladd.(s, d, u) # u = u + (new_lambda - old_lambda) * d
        lambda = new_lambda
        F!(res, u, p)
        norm_res = norm(res)
    end
    # Budget exhausted: return the norm of the last trial step, accepted or not.
    return norm_res
end

end # module LineSearches

0 commit comments

Comments
 (0)