diff --git a/Project.toml b/Project.toml index 8c98c01e6..5ebc4912c 100644 --- a/Project.toml +++ b/Project.toml @@ -81,6 +81,7 @@ LinearSolvePETScExt = ["PETSc", "SparseArrays"] LinearSolveParUExt = ["ParU_jll", "SparseArrays"] LinearSolvePardisoExt = ["Pardiso", "SparseArrays"] LinearSolveRecursiveFactorizationExt = "RecursiveFactorization" +LinearSolveSTRUMPACKExt = "SparseArrays" LinearSolveSparseArraysExt = "SparseArrays" LinearSolveSparspakExt = ["SparseArrays", "Sparspak"] diff --git a/ext/LinearSolveSTRUMPACKExt.jl b/ext/LinearSolveSTRUMPACKExt.jl new file mode 100644 index 000000000..26f23cc8e --- /dev/null +++ b/ext/LinearSolveSTRUMPACKExt.jl @@ -0,0 +1,239 @@ +module LinearSolveSTRUMPACKExt + +using LinearSolve: LinearSolve, LinearVerbosity, OperatorAssumptions +using SparseArrays: SparseArrays, AbstractSparseMatrixCSC, getcolptr, rowvals, nonzeros +using SciMLBase: SciMLBase, ReturnCode +using SciMLLogging: @SciMLMessage +using Libdl: Libdl + +const STRUMPACK_SUCCESS = Cint(0) +const STRUMPACK_MATRIX_NOT_SET = Cint(1) +const STRUMPACK_REORDERING_ERROR = Cint(2) +const STRUMPACK_ZERO_PIVOT = Cint(3) +const STRUMPACK_NO_CONVERGENCE = Cint(4) +const STRUMPACK_INACCURATE_INERTIA = Cint(5) + +const STRUMPACK_DOUBLE = Cint(1) +const STRUMPACK_MT = Cint(0) + +const _libstrumpack = Ref{Ptr{Cvoid}}(C_NULL) + +function _load_libstrumpack() + for name in ( + "libstrumpack.so", + "libstrumpack.so.8", + "libstrumpack.so.7", + "libstrumpack.dylib", + "strumpack", + ) + handle = Libdl.dlopen_e(name) + handle != C_NULL && return handle + end + return C_NULL +end + +function __init__() + return _libstrumpack[] = _load_libstrumpack() +end + +strumpack_isavailable() = _libstrumpack[] != C_NULL + +mutable struct STRUMPACKCache + solver::Ref{Ptr{Cvoid}} + rowptr::Vector{Int32} + colind::Vector{Int32} + nzval::Vector{Float64} + + function STRUMPACKCache() + cache = new(Ref{Ptr{Cvoid}}(C_NULL), Int32[], Int32[], Float64[]) + finalizer(_strumpack_destroy!, cache) + return cache + end +end + +function _strumpack_destroy!(cache::STRUMPACKCache) + _libstrumpack[] == C_NULL && return + cache.solver[] == C_NULL && return + ccall((:STRUMPACK_destroy, _libstrumpack[]), Cvoid, (Ref{Ptr{Cvoid}},), cache.solver) + cache.solver[] = C_NULL + return +end + +function _ensure_initialized!(cache::STRUMPACKCache) + cache.solver[] != C_NULL && return + ccall( + (:STRUMPACK_init_mt, _libstrumpack[]), + Cvoid, + (Ref{Ptr{Cvoid}}, Cint, Cint, Cint, Ptr{Ptr{UInt8}}, Cint), + cache.solver, + STRUMPACK_DOUBLE, + STRUMPACK_MT, + Cint(0), + Ptr{Ptr{UInt8}}(C_NULL), + Cint(0) + ) + return +end + +function _csc_to_csr_0based(A::AbstractSparseMatrixCSC) + n = size(A, 1) + colptr = getcolptr(A) + rowval = rowvals(A) + vals = nonzeros(A) + + nnz = length(vals) + rowptr = zeros(Int32, n + 1) + + @inbounds for idx in eachindex(rowval) + rowptr[Int(rowval[idx]) + 1] += 1 + end + + @inbounds for i in 1:n + rowptr[i + 1] += rowptr[i] + end + + nextidx = copy(rowptr) + colind = Vector{Int32}(undef, nnz) + outvals = Vector{Float64}(undef, nnz) + + @inbounds for j in 1:size(A, 2) + for p in colptr[j]:(colptr[j + 1] - 1) + row = Int(rowval[p]) + pos = Int(nextidx[row] + 1) + nextidx[row] += 1 + colind[pos] = Int32(j - 1) + outvals[pos] = Float64(vals[p]) + end + end + + return rowptr, colind, outvals +end + +function _retcode_from_strumpack(info::Cint) + return if info == STRUMPACK_SUCCESS + ReturnCode.Success + elseif info == STRUMPACK_ZERO_PIVOT + ReturnCode.Infeasible + elseif info == STRUMPACK_NO_CONVERGENCE + ReturnCode.ConvergenceFailure + elseif info == STRUMPACK_INACCURATE_INERTIA + ReturnCode.Unstable + elseif info == STRUMPACK_MATRIX_NOT_SET || info == STRUMPACK_REORDERING_ERROR + ReturnCode.Failure + else + ReturnCode.Failure + end +end + +function LinearSolve.init_cacheval( + ::LinearSolve.STRUMPACKFactorization, + A::AbstractSparseMatrixCSC{<:AbstractFloat}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, + verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions + ) + return STRUMPACKCache() +end + +function LinearSolve.init_cacheval( + ::LinearSolve.STRUMPACKFactorization, + A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, + verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions + ) + return nothing +end + +function SciMLBase.solve!( + cache::LinearSolve.LinearCache, + alg::LinearSolve.STRUMPACKFactorization; + kwargs... + ) + if _libstrumpack[] == C_NULL + error("STRUMPACKFactorization requires a discoverable STRUMPACK shared library (`libstrumpack`)") + end + + A = convert(AbstractMatrix, cache.A) + if !(A isa AbstractSparseMatrixCSC) + error("STRUMPACKFactorization currently supports only sparse CSC matrices") + end + size(A, 1) == size(A, 2) || error("STRUMPACKFactorization requires a square matrix") + + scache = LinearSolve.@get_cacheval(cache, :STRUMPACKFactorization) + if scache === nothing + error("STRUMPACKFactorization currently supports `AbstractSparseMatrixCSC{<:AbstractFloat}`") + end + + _ensure_initialized!(scache) + + if cache.isfresh + scache.rowptr, scache.colind, scache.nzval = _csc_to_csr_0based(A) + ccall( + (:STRUMPACK_set_csr_matrix, _libstrumpack[]), + Cvoid, + (Ptr{Cvoid}, Cint, Ref{Cint}, Ref{Cint}, Ref{Cdouble}, Cint), + scache.solver[], + Cint(size(A, 1)), + scache.rowptr, + scache.colind, + scache.nzval, + Cint(0) + ) + + info = ccall((:STRUMPACK_factor, _libstrumpack[]), Cint, (Ptr{Cvoid},), scache.solver[]) + if info != STRUMPACK_SUCCESS + @SciMLMessage( + "STRUMPACK factorization failed (code $(Int(info)))", + cache.verbose, + :solver_failure + ) + cache.isfresh = false + return SciMLBase.build_linear_solution( + alg, + cache.u, + nothing, + cache; + retcode = _retcode_from_strumpack(info) + ) + end + cache.isfresh = false + end + + bvec = Float64.(cache.b) + xvec = Float64.(cache.u) + + info = ccall( + (:STRUMPACK_solve, _libstrumpack[]), + Cint, + (Ptr{Cvoid}, Ref{Cdouble}, Ref{Cdouble}, Cint), + scache.solver[], + bvec, + xvec, + Cint(alg.use_initial_guess) + ) + + if info != STRUMPACK_SUCCESS + @SciMLMessage( + "STRUMPACK solve failed (code $(Int(info)))", + cache.verbose, + :solver_failure + ) + return SciMLBase.build_linear_solution( + alg, + cache.u, + nothing, + cache; + retcode = _retcode_from_strumpack(info) + ) + end + + copyto!(cache.u, xvec) + return SciMLBase.build_linear_solution( + alg, + cache.u, + nothing, + cache; + retcode = ReturnCode.Success + ) +end + +end diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 1307ac648..6b3a62a47 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -475,6 +475,7 @@ for alg in ( :DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization, :CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization, :MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization, :ParUFactorization, + :STRUMPACKFactorization, ) @eval needs_square_A(::$(alg)) = true end @@ -513,7 +514,8 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization, UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization, SparspakFactorization, DiagonalFactorization, CholeskyFactorization, BunchKaufmanFactorization, CHOLMODFactorization, LDLtFactorization, - CUSOLVERRFFactorization, CliqueTreesFactorization, ParUFactorization + CUSOLVERRFFactorization, CliqueTreesFactorization, ParUFactorization, + STRUMPACKFactorization export LinearSolveFunction, DirectLdiv!, show_algorithm_choices diff --git a/src/factorization.jl b/src/factorization.jl index 33cc0114e..71e9b3173 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -1482,6 +1482,52 @@ struct SparspakFactorization <: AbstractSparseFactorization end end +""" +`STRUMPACKFactorization(; use_initial_guess = false)` + +A sparse direct solver based on +[STRUMPACK](https://github.com/pghysels/STRUMPACK) via the +`LinearSolveSTRUMPACKExt` extension. + +This wrapper targets the single-node (`MT`) sparse interface and currently supports +real sparse matrices (`AbstractSparseMatrixCSC{<:AbstractFloat}`), solving in +`Float64` precision. + +!!! note + + Using this solver requires: + 1. `using SparseArrays` (to enable sparse matrix support), and + 2. a system installation of `libstrumpack` discoverable by the dynamic loader. +""" +struct STRUMPACKFactorization <: AbstractSparseFactorization + use_initial_guess::Bool + + function STRUMPACKFactorization(; use_initial_guess = false, throwerror = true) + ext = Base.get_extension(@__MODULE__, :LinearSolveSTRUMPACKExt) + return if throwerror && (ext === nothing || !ext.strumpack_isavailable()) + error("STRUMPACKFactorization requires a discoverable STRUMPACK shared library (`libstrumpack`) and `using SparseArrays`") + else + new(use_initial_guess) + end + end +end + +function init_cacheval( + ::STRUMPACKFactorization, + ::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, + verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions + ) + return nothing +end + +function init_cacheval( + ::STRUMPACKFactorization, ::StaticArray, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions + ) + return nothing +end + function init_cacheval( alg::SparspakFactorization, A::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr, diff --git a/test/defaults_loading.jl b/test/defaults_loading.jl index 05369acf1..817828d54 100644 --- a/test/defaults_loading.jl +++ b/test/defaults_loading.jl @@ -29,6 +29,13 @@ rhs[begin] = rhs[end] = -2 prob = LinearProblem(mat, rhs) @test_throws ["SparspakFactorization required", "using Sparspak"] sol = solve(prob).u +STRUMPACKExt = Base.get_extension(LinearSolve, :LinearSolveSTRUMPACKExt) +if STRUMPACKExt === nothing || !STRUMPACKExt.strumpack_isavailable() + @test_throws ["STRUMPACKFactorization", "libstrumpack"] STRUMPACKFactorization() +else + @test STRUMPACKFactorization() isa STRUMPACKFactorization +end + using Sparspak sol = solve(prob).u @test sol isa Vector{BigFloat} diff --git a/test/resolve.jl b/test/resolve.jl index 3782175f5..42987e72d 100644 --- a/test/resolve.jl +++ b/test/resolve.jl @@ -4,6 +4,9 @@ using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization, AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization, SparspakFactorization +const STRUMPACKExt = Base.get_extension(LinearSolve, :LinearSolveSTRUMPACKExt) +const HAS_STRUMPACK = STRUMPACKExt !== nothing && STRUMPACKExt.strumpack_isavailable() + # Function to check if an algorithm is mixed precision function is_mixed_precision_alg(alg) alg_name = string(alg) @@ -48,6 +51,7 @@ for alg in vcat( (!(alg == AppleAccelerate32MixedLUFactorization) || Sys.isapple()) && (!(alg == OpenBLAS32MixedLUFactorization) || LinearSolve.useopenblas) && (!(alg == SparspakFactorization) || false) && + (!(alg == STRUMPACKFactorization) || HAS_STRUMPACK) && ( !(alg == ParUFactorization) || Base.get_extension(LinearSolve, :LinearSolveParUExt) !== nothing @@ -55,7 +59,7 @@ for alg in vcat( A = [1.0 2.0; 3.0 4.0] alg in [ KLUFactorization, UMFPACKFactorization, SparspakFactorization, - ParUFactorization, + ParUFactorization, STRUMPACKFactorization, ] && (A = sparse(A)) A = A' * A @@ -84,7 +88,7 @@ for alg in vcat( A = [1.0 2.0; 3.0 4.0] alg in [ KLUFactorization, UMFPACKFactorization, SparspakFactorization, - ParUFactorization, + ParUFactorization, STRUMPACKFactorization, ] && (A = sparse(A)) A = A' * A diff --git a/test/runtests.jl b/test/runtests.jl index df24888c0..1820e03b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,10 @@ if GROUP == "DefaultsLoading" @time @safetestset "Defaults Loading Tests" include("defaults_loading.jl") end +if GROUP == "All" || GROUP == "LinearSolveSTRUMPACK" + @time @safetestset "LinearSolveSTRUMPACK" include("strumpack/strumpack.jl") +end + if GROUP == "LinearSolveAutotune" Pkg.activate(joinpath(dirname(@__DIR__), "lib", GROUP)) Pkg.test( diff --git a/test/strumpack/strumpack.jl b/test/strumpack/strumpack.jl new file mode 100644 index 000000000..be203b242 --- /dev/null +++ b/test/strumpack/strumpack.jl @@ -0,0 +1,56 @@ +using LinearSolve, LinearAlgebra, SparseArrays, SciMLBase +using Test + +@testset "STRUMPACK Factorization" begin + ext = Base.get_extension(LinearSolve, :LinearSolveSTRUMPACKExt) + @test ext !== nothing + + if ext === nothing || !ext.strumpack_isavailable() + @test_throws ["STRUMPACKFactorization", "libstrumpack"] STRUMPACKFactorization() + @test STRUMPACKFactorization(throwerror = false) isa STRUMPACKFactorization + + A = sparse([4.0 1.0; 2.0 3.0]) + b = [1.0, -1.0] + prob = LinearProblem(A, b) + @test_throws ["STRUMPACKFactorization", "libstrumpack"] solve( + prob, + STRUMPACKFactorization(throwerror = false) + ) + else + A = sparse( + [ + 7.0 1.0 0.0 + 2.0 8.0 1.0 + 0.0 3.0 9.0 + ] + ) + b = [1.0, -2.0, 3.0] + + prob = LinearProblem(A, b) + sol = solve(prob, STRUMPACKFactorization()) + @test sol.retcode == ReturnCode.Success + @test A * sol.u ≈ b atol = 1.0e-10 rtol = 1.0e-10 + + cache = init(prob, STRUMPACKFactorization()) + sol1 = solve!(cache) + @test sol1.retcode == ReturnCode.Success + @test A * sol1.u ≈ b atol = 1.0e-10 rtol = 1.0e-10 + + A2 = sparse( + [ + 8.0 1.0 0.0 + 2.0 9.0 1.0 + 0.0 3.0 10.0 + ] + ) + cache.A = A2 + sol2 = solve!(cache) + @test sol2.retcode == ReturnCode.Success + @test A2 * sol2.u ≈ b atol = 1.0e-10 rtol = 1.0e-10 + + prob_guess = LinearProblem(A, b; u0 = fill(1.0, length(b))) + sol_guess = solve(prob_guess, STRUMPACKFactorization(use_initial_guess = true)) + @test sol_guess.retcode == ReturnCode.Success + @test A * sol_guess.u ≈ b atol = 1.0e-10 rtol = 1.0e-10 + end +end