Skip to content

Commit 9a33df9

Browse files
committed
Add an option to choose the Krylov solver
1 parent 6f74156 commit 9a33df9

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/fft_model.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ mutable struct FFTNLPModel{T,VT,FFT,R,C} <: AbstractNLPModel{T,VT}
2424
rdft::Bool
2525
fft_timer::Ref{Float64}
2626
mapping_timer::Ref{Float64}
27+
krylov_solver::Symbol
2728
end
2829

29-
function FFTNLPModel{T,VT}(parameters::FFTParameters; rdft::Bool=false) where {T,VT}
30+
function FFTNLPModel{T,VT}(parameters::FFTParameters; krylov_solver::Symbol=:cg, rdft::Bool=false) where {T,VT}
3031
DFTdim = parameters.paramf[1] # problem size (1, 2, 3)
3132
DFTsize = parameters.paramf[2] # problem dimension
3233
N = prod(DFTsize)
@@ -82,7 +83,8 @@ function FFTNLPModel{T,VT}(parameters::FFTParameters; rdft::Bool=false) where {T
8283
end
8384
fft_timer = Ref{Float64}(0.0)
8485
mapping_timer = Ref{Float64}(0.0)
85-
return FFTNLPModel(meta, parameters, N, Counters(), op, buffer_real, buffer_complex1, buffer_complex2, rdft, fft_timer, mapping_timer)
86+
return FFTNLPModel(meta, parameters, N, Counters(), op, buffer_real, buffer_complex1,
87+
buffer_complex2, rdft, fft_timer, mapping_timer, krylov_solver)
8688
end
8789

8890
function NLPModels.cons!(nlp::FFTNLPModel, x::AbstractVector, c::AbstractVector)

src/kkt.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ end
104104
=#
105105

106106
struct FFTKKTSystem{T, VI, VT, MT, LS} <: MadNLP.AbstractReducedKKTSystem{T, VT, MT, MadNLP.ExactHessian{T, VT}}
107-
nlp::FFTNLPModel
107+
nlp::FFTNLPModel{T, VT}
108108
# Operators
109109
K::MT
110110
P::FFTPreconditioner{T, VT}
@@ -153,7 +153,7 @@ function MadNLP.create_kkt_system(
153153
l_lower = VT(undef, nlb)
154154
u_lower = VT(undef, nub)
155155

156-
workspace = Krylov.CgWorkspace(2*nβ, 2*nβ, VT)
156+
workspace = Krylov.krylov_workspace(Val(nlp.krylov_solver), 2*nβ, 2*nβ, VT)
157157

158158
z1 = VT(undef, nβ)
159159
z2 = VT(undef, 2*nβ)
@@ -176,11 +176,11 @@ MadNLP.get_hessian(kkt::FFTKKTSystem) = nothing
176176
MadNLP.get_jacobian(kkt::FFTKKTSystem) = nothing
177177

178178
# Dirty wrapper to MadNLP's linear solver
179-
MadNLP.is_inertia(::Krylov.CgWorkspace) = true
180-
MadNLP.inertia(::Krylov.CgWorkspace) = (0, 0, 0)
181-
MadNLP.introduce(::Krylov.CgWorkspace) = "CG"
182-
MadNLP.improve!(::Krylov.CgWorkspace) = true
183-
MadNLP.factorize!(::Krylov.CgWorkspace) = nothing
179+
MadNLP.is_inertia(::Krylov.KrylovWorkspace) = true
180+
MadNLP.inertia(::Krylov.KrylovWorkspace) = (0, 0, 0)
181+
MadNLP.introduce(::Krylov.KrylovWorkspace) = "Krylov"
182+
MadNLP.improve!(::Krylov.KrylovWorkspace) = true
183+
MadNLP.factorize!(::Krylov.KrylovWorkspace) = nothing
184184

185185
MadNLP.is_inertia_correct(kkt::FFTKKTSystem, p, n, z) = true
186186

@@ -355,7 +355,7 @@ function MadNLP.solve!(kkt::FFTKKTSystem, w::MadNLP.AbstractKKTVector)
355355
bβ .= w1 .- w3 .+ w4 .- Σ1 .* w5 .+ Σ2 .* w6
356356
bz .= w2 .- w3 .- w4 .- Σ1 .* w5 .- Σ2 .* w6
357357

358-
# Solve with CG
358+
# Solve with the Krylov solver (CG by default)
359359
Krylov.krylov_solve!(kkt.linear_solver, kkt.K, b, M=kkt.P, atol=1e-12, rtol=0.0, verbose=0)
360360
x = Krylov.solution(kkt.linear_solver)
361361
push!(kkt.krylov_iterations, kkt.linear_solver |> Krylov.iteration_count)

0 commit comments

Comments
 (0)