Skip to content

Commit 372ea1e

Browse files
committed
Merge branch 'main' into pardiso-vendor
2 parents 835f224 + 1f8eb0d commit 372ea1e

File tree

3 files changed

+83
-34
lines changed

3 files changed

+83
-34
lines changed

ext/LinearSolvePardisoExt.jl

+39-30
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
2222
reltol,
2323
verbose::Bool,
2424
assumptions::LinearSolve.OperatorAssumptions)
25-
@unpack nprocs, solver_type, matrix_type, iparm, dparm, vendor = alg
25+
@unpack nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm, vendor = alg
2626
A = convert(AbstractMatrix, A)
2727

2828
if isnothing(vendor)
@@ -37,8 +37,12 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
3737
solver = if vendor == :MKL
3838
solver = if Pardiso.mkl_is_available()
3939
solver = Pardiso.MKLPardisoSolver()
40+
Pardiso.pardisoinit(solver)
4041
nprocs !== nothing && Pardiso.set_nprocs!(solver, nprocs)
41-
transposed_iparm = 2 # see https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-0/pardiso-iparm-parameter.html#IPARM37
42+
43+
# for mkl 1 means conjugated and 2 means transposed.
44+
# https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-0/pardiso-iparm-parameter.html#IPARM37
45+
transposed_iparm = 2
4246

4347
solver
4448
else
@@ -47,6 +51,7 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
4751
elseif vendor == :Panua
4852
solver = if Pardiso.panua_is_available()
4953
solver = Pardiso.PardisoSolver()
54+
Pardiso.pardisoinit(solver)
5055
solver_type !== nothing && Pardiso.set_solver!(solver, solver_type)
5156

5257
solver
@@ -57,8 +62,6 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
5762
error("Pardiso vendor must be either `:MKL` or `:Panua`")
5863
end
5964

60-
Pardiso.pardisoinit(solver) # default initialization
61-
6265
if matrix_type !== nothing
6366
Pardiso.set_matrixtype!(solver, matrix_type)
6467
else
@@ -72,22 +75,6 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
7275
end
7376
verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
7477

75-
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
76-
if iparm !== nothing
77-
for i in iparm
78-
Pardiso.set_iparm!(solver, i...)
79-
end
80-
end
81-
82-
if dparm !== nothing
83-
for d in dparm
84-
Pardiso.set_dparm!(solver, d...)
85-
end
86-
end
87-
88-
# Make sure to say it's transposed because its CSC not CSR
89-
Pardiso.set_iparm!(solver, 12, transposed_iparm)
90-
9178
#=
9279
Note: It is recommended to use IPARM(11)=1 (scaling) and IPARM(13)=1 (matchings) for
9380
highly indefinite symmetric matrices e.g. from interior point optimizations or saddle point problems.
@@ -99,36 +86,58 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
9986
be changed to Pardiso.ANALYSIS_NUM_FACT in the solver loop otherwise instabilities
10087
occur in the example https://github.com/SciML/OrdinaryDiffEq.jl/issues/1569
10188
=#
102-
Pardiso.set_iparm!(solver, 11, 0)
103-
Pardiso.set_iparm!(solver, 13, 0)
104-
105-
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
89+
if cache_analysis
90+
Pardiso.set_iparm!(solver, 11, 0)
91+
Pardiso.set_iparm!(solver, 13, 0)
92+
end
10693

10794
if alg.solver_type == 1
10895
# PARDISO uses a numerical factorization A = LU for the first system and
10996
# applies these exact factors L and U for the next steps in a
11097
# preconditioned Krylov-Subspace iteration. If the iteration does not
11198
# converge, the solver will automatically switch back to the numerical factorization.
99+
# Be aware that in the intel docs, iparm indexes are one lower.
112100
Pardiso.set_iparm!(solver, 4, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
113101
end
114102

115-
Pardiso.pardiso(solver,
116-
u,
117-
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
118-
b)
103+
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
104+
if iparm !== nothing
105+
for i in iparm
106+
Pardiso.set_iparm!(solver, i...)
107+
end
108+
end
119109

110+
if dparm !== nothing
111+
for d in dparm
112+
Pardiso.set_dparm!(solver, d...)
113+
end
114+
end
115+
116+
# Make sure to say it's transposed because its CSC not CSR
117+
# This is also the only value which should not be overwritten by users
118+
Pardiso.set_iparm!(solver, 12, transposed_iparm)
119+
120+
if cache_analysis
121+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
122+
Pardiso.pardiso(solver,
123+
u,
124+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
125+
b)
126+
end
127+
128+
>>>>>>> main
120129
return solver
121130
end
122131

123132
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs...)
124133
@unpack A, b, u = cache
125134
A = convert(AbstractMatrix, A)
126135
if cache.isfresh
127-
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
136+
phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT
137+
Pardiso.set_phase!(cache.cacheval, phase)
128138
Pardiso.pardiso(cache.cacheval, A, eltype(A)[])
129139
cache.isfresh = false
130140
end
131-
132141
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
133142
Pardiso.pardiso(cache.cacheval, u, A, b)
134143

src/extension_algs.jl

+21-3
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ end
8686
```julia
8787
MKLPardisoFactorize(; nprocs::Union{Int, Nothing} = nothing,
8888
matrix_type = nothing,
89+
cache_analysis = false,
8990
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
9091
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
9192
```
@@ -98,7 +99,11 @@ A sparse factorization method using MKL Pardiso.
9899
99100
## Keyword Arguments
100101
101-
For the definition of the keyword arguments, see the Pardiso.jl documentation.
102+
Setting `cache_analysis = true` disables Pardiso's scaling and matching defaults
103+
and caches the result of the initial analysis phase for all further computations
104+
with this solver.
105+
106+
For the definition of the other keyword arguments, see the Pardiso.jl documentation.
102107
All values default to `nothing` and the solver internally determines the values
103108
given the input types, and these keyword arguments are only for overriding the
104109
default handling process. This should not be required by most users.
@@ -109,6 +114,7 @@ MKLPardisoFactorize(; kwargs...) = PardisoJL(; vendor=:MKL, solver_type = 0, kwa
109114
```julia
110115
MKLPardisoIterate(; nprocs::Union{Int, Nothing} = nothing,
111116
matrix_type = nothing,
117+
cache_analysis = false,
112118
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
113119
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
114120
```
@@ -121,7 +127,11 @@ A mixed factorization+iterative method using MKL Pardiso.
121127
122128
## Keyword Arguments
123129
124-
For the definition of the keyword arguments, see the Pardiso.jl documentation.
130+
Setting `cache_analysis = true` disables Pardiso's scaling and matching defaults
131+
and caches the result of the initial analysis phase for all further computations
132+
with this solver.
133+
134+
For the definition of the other keyword arguments, see the Pardiso.jl documentation.
125135
All values default to `nothing` and the solver internally determines the values
126136
given the input types, and these keyword arguments are only for overriding the
127137
default handling process. This should not be required by most users.
@@ -133,6 +143,7 @@ MKLPardisoIterate(; kwargs...) = PardisoJL(; vendor=:MKL, solver_type = 1, kwarg
133143
```julia
134144
PanuaPardisoFactorize(; nprocs::Union{Int, Nothing} = nothing,
135145
matrix_type = nothing,
146+
cache_analysis = false,
136147
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
137148
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
138149
```
@@ -145,6 +156,10 @@ A sparse factorization method using Panua Pardiso.
145156
146157
## Keyword Arguments
147158
159+
Setting `cache_analysis = true` disables Pardiso's scaling and matching defaults
160+
and caches the result of the initial analysis phase for all further computations
161+
with this solver.
162+
148163
For the definition of the keyword arguments, see the Pardiso.jl documentation.
149164
All values default to `nothing` and the solver internally determines the values
150165
given the input types, and these keyword arguments are only for overriding the
@@ -207,13 +222,15 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
207222
nprocs::Union{Int, Nothing}
208223
solver_type::T1
209224
matrix_type::T2
225+
cache_analysis::Bool
210226
iparm::Union{Vector{Tuple{Int, Int}}, Nothing}
211227
dparm::Union{Vector{Tuple{Int, Int}}, Nothing}
212228
vendor::Union{Symbol,Nothing}
213229

214230
function PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
215231
solver_type = nothing,
216232
matrix_type = nothing,
233+
cache_analysis = false,
217234
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
218235
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
219236
vendor::Union{Symbol,Nothing}=nothing )
@@ -225,7 +242,8 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
225242
T2 = typeof(matrix_type)
226243
@assert T1 <: Union{Int, Nothing, ext.Pardiso.Solver}
227244
@assert T2 <: Union{Int, Nothing, ext.Pardiso.MatrixType}
228-
return new{T1, T2}(nprocs, solver_type, matrix_type, iparm, dparm, vendor)
245+
return new{T1, T2}(
246+
nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm, vendor)
229247
end
230248
end
231249
end

test/pardiso/pardiso.jl

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearSolve, SparseArrays, Random
1+
using LinearSolve, SparseArrays, Random, LinearAlgebra
22
import Pardiso
33

44
A1 = sparse([1.0 0 -2 3
@@ -38,6 +38,9 @@ for alg in algs
3838
@test A2 * u b2
3939
end
4040

41+
return
42+
43+
4144
Random.seed!(10)
4245
A = sprand(n, n, 0.8);
4346
A2 = 2.0 .* A;
@@ -64,6 +67,25 @@ sol33 = solve(linsolve)
6467
@test sol12.u sol32.u
6568
@test sol13.u sol33.u
6669

70+
71+
# Test for problem from #497
72+
function makeA()
73+
n = 60
74+
colptr = [1, 4, 7, 11, 15, 17, 22, 26, 30, 34, 38, 40, 46, 50, 54, 58, 62, 64, 70, 74, 78, 82, 86, 88, 94, 98, 102, 106, 110, 112, 118, 122, 126, 130, 134, 136, 142, 146, 150, 154, 158, 160, 166, 170, 174, 178, 182, 184, 190, 194, 198, 202, 206, 208, 214, 218, 222, 224, 226, 228, 232]
75+
rowval = [1, 3, 4, 1, 2, 4, 2, 4, 9, 10, 3, 5, 11, 12, 1, 3, 2, 4, 6, 11, 12, 2, 7, 9, 10, 2, 7, 8, 10, 8, 10, 15, 16, 9, 11, 17, 18, 7, 9, 2, 8, 10, 12, 17, 18, 8, 13, 15, 16, 8, 13, 14, 16, 14, 16, 21, 22, 15, 17, 23, 24, 13, 15, 8, 14, 16, 18, 23, 24, 14, 19, 21, 22, 14, 19, 20, 22, 20, 22, 27, 28, 21, 23, 29, 30, 19, 21, 14, 20, 22, 24, 29, 30, 20, 25, 27, 28, 20, 25, 26, 28, 26, 28, 33, 34, 27, 29, 35, 36, 25, 27, 20, 26, 28, 30, 35, 36, 26, 31, 33, 34, 26, 31, 32, 34, 32, 34, 39, 40, 33, 35, 41, 42, 31, 33, 26, 32, 34, 36, 41, 42, 32, 37, 39, 40, 32, 37, 38, 40, 38, 40, 45, 46, 39, 41, 47, 48, 37, 39, 32, 38, 40, 42, 47, 48, 38, 43, 45, 46, 38, 43, 44, 46, 44, 46, 51, 52, 45, 47, 53, 54, 43, 45, 38, 44, 46, 48, 53, 54, 44, 49, 51, 52, 44, 49, 50, 52, 50, 52, 57, 58, 51, 53, 59, 60, 49, 51, 44, 50, 52, 54, 59, 60, 50, 55, 57, 58, 50, 55, 56, 58, 56, 58, 57, 59, 55, 57, 50, 56, 58, 60]
76+
nzval = [-0.64, 1.0, -1.0, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -1.0806825309567203, 1.0, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0]
77+
A = SparseMatrixCSC(n, n, colptr, rowval, nzval)
78+
return(A)
79+
end
80+
81+
A=makeA()
82+
u0=fill(0.1,size(A,2))
83+
linprob = LinearProblem(A, A*u0)
84+
u = LinearSolve.solve(linprob, PardisoJL())
85+
@test norm(u-u0) < 1.0e-14
86+
87+
88+
6789
# Testing and demonstrating Pardiso.set_iparm! for MKLPardisoSolver
6890
solver = Pardiso.MKLPardisoSolver()
6991
iparm = [

0 commit comments

Comments
 (0)