From 94c64479f34304194e54a42a086cafff972c303c Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 23 Apr 2025 10:03:30 -0500 Subject: [PATCH] Add an option to choose the Krylov solver --- src/fft_model.jl | 6 ++++-- src/kkt.jl | 16 ++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/fft_model.jl b/src/fft_model.jl index e077105..6e5939f 100644 --- a/src/fft_model.jl +++ b/src/fft_model.jl @@ -24,9 +24,10 @@ mutable struct FFTNLPModel{T,VT,FFT,R,C} <: AbstractNLPModel{T,VT} rdft::Bool fft_timer::Ref{Float64} mapping_timer::Ref{Float64} + krylov_solver::Symbol end -function FFTNLPModel{T,VT}(parameters::FFTParameters; rdft::Bool=false) where {T,VT} +function FFTNLPModel{T,VT}(parameters::FFTParameters; krylov_solver::Symbol=:cg, rdft::Bool=false) where {T,VT} DFTdim = parameters.paramf[1] # problem size (1, 2, 3) DFTsize = parameters.paramf[2] # problem dimension N = prod(DFTsize) @@ -82,7 +83,8 @@ function FFTNLPModel{T,VT}(parameters::FFTParameters; rdft::Bool=false) where {T end fft_timer = Ref{Float64}(0.0) mapping_timer = Ref{Float64}(0.0) - return FFTNLPModel(meta, parameters, N, Counters(), op, buffer_real, buffer_complex1, buffer_complex2, rdft, fft_timer, mapping_timer) + return FFTNLPModel(meta, parameters, N, Counters(), op, buffer_real, buffer_complex1, + buffer_complex2, rdft, fft_timer, mapping_timer, krylov_solver) end function NLPModels.cons!(nlp::FFTNLPModel, x::AbstractVector, c::AbstractVector) diff --git a/src/kkt.jl b/src/kkt.jl index fb0c610..3540bca 100644 --- a/src/kkt.jl +++ b/src/kkt.jl @@ -104,7 +104,7 @@ end =# struct FFTKKTSystem{T, VI, VT, MT, LS} <: MadNLP.AbstractReducedKKTSystem{T, VT, MT, MadNLP.ExactHessian{T, VT}} - nlp::FFTNLPModel + nlp::FFTNLPModel{T, VT} # Operators K::MT P::FFTPreconditioner{T, VT} @@ -153,7 +153,7 @@ function MadNLP.create_kkt_system( l_lower = VT(undef, nlb) u_lower = VT(undef, nub) - workspace = Krylov.CgWorkspace(2*nβ, 2*nβ, VT) + workspace = Krylov.krylov_workspace(Val(nlp.krylov_solver), 2*nβ, 2*nβ, VT) z1 = VT(undef, nβ) z2 = VT(undef, 2*nβ) @@ -176,11 +176,11 @@ MadNLP.get_hessian(kkt::FFTKKTSystem) = nothing MadNLP.get_jacobian(kkt::FFTKKTSystem) = nothing # Dirty wrapper to MadNLP's linear solver -MadNLP.is_inertia(::Krylov.CgWorkspace) = true -MadNLP.inertia(::Krylov.CgWorkspace) = (0, 0, 0) -MadNLP.introduce(::Krylov.CgWorkspace) = "CG" -MadNLP.improve!(::Krylov.CgWorkspace) = true -MadNLP.factorize!(::Krylov.CgWorkspace) = nothing +MadNLP.is_inertia(::Krylov.KrylovWorkspace) = true +MadNLP.inertia(::Krylov.KrylovWorkspace) = (0, 0, 0) +MadNLP.introduce(::Krylov.KrylovWorkspace) = "Krylov" +MadNLP.improve!(::Krylov.KrylovWorkspace) = true +MadNLP.factorize!(::Krylov.KrylovWorkspace) = nothing MadNLP.is_inertia_correct(kkt::FFTKKTSystem, p, n, z) = true @@ -355,7 +355,7 @@ function MadNLP.solve!(kkt::FFTKKTSystem, w::MadNLP.AbstractKKTVector) bβ .= w1 .- w3 .+ w4 .- Σ1 .* w5 .+ Σ2 .* w6 bz .= w2 .- w3 .- w4 .- Σ1 .* w5 .- Σ2 .* w6 - # Solve with CG + # Solve with the Krylov solver (CG by default) Krylov.krylov_solve!(kkt.linear_solver, kkt.K, b, M=kkt.P, atol=1e-12, rtol=0.0, verbose=0) x = Krylov.solution(kkt.linear_solver) push!(kkt.krylov_iterations, kkt.linear_solver |> Krylov.iteration_count)