Skip to content

Commit dca8c51

Browse files
Handle BigFloat error messages and loading in sparse defaults
1 parent be89d6f commit dca8c51

File tree

6 files changed

+53
-13
lines changed

6 files changed

+53
-13
lines changed

.github/workflows/Tests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
- "pre"
3030
group:
3131
- "Core"
32+
- "DefaultsLoading"
3233
- "LinearSolveHYPRE"
3334
- "LinearSolvePardiso"
3435
- "LinearSolveBandedMatrices"

ext/LinearSolveSparseArraysExt.jl

+8-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module LinearSolveSparseArraysExt
33
using LinearSolve, LinearAlgebra
44
using SparseArrays
55
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
6+
using LinearSolve: BLASELTYPES
67

78
# Can't `using KLU` because cannot have a dependency in there without
89
# requiring the user does `using KLU`
@@ -39,7 +40,7 @@ function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
3940
end
4041

4142
function LinearSolve.defaultalg(
42-
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
43+
A::Symmetric{<:BLASELTYPES, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
4344
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization)
4445
end
4546

@@ -78,7 +79,7 @@ function LinearSolve.init_cacheval(
7879
end
7980

8081
function LinearSolve.init_cacheval(
81-
alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
82+
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr,
8283
maxiters::Int, abstol,
8384
reltol,
8485
verbose::Bool, assumptions::OperatorAssumptions)
@@ -136,7 +137,7 @@ function LinearSolve.init_cacheval(
136137
end
137138

138139
function LinearSolve.init_cacheval(
139-
alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
140+
alg::KLUFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr,
140141
maxiters::Int, abstol,
141142
reltol,
142143
verbose::Bool, assumptions::OperatorAssumptions)
@@ -186,15 +187,15 @@ function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
186187
Pl, Pr,
187188
maxiters::Int, abstol, reltol,
188189
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
189-
Union{Float32, Float64}}
190+
BLASELTYPES}
190191
PREALLOCATED_CHOLMOD
191192
end
192193

193194
function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
194-
A::Union{AbstractSparseArray, LinearSolve.GPUArraysCore.AnyGPUArray,
195-
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
195+
A::Union{AbstractSparseArray{T}, LinearSolve.GPUArraysCore.AnyGPUArray,
196+
Symmetric{T, <:AbstractSparseArray{T}}}, b, u, Pl, Pr,
196197
maxiters::Int, abstol, reltol, verbose::Bool,
197-
assumptions::OperatorAssumptions)
198+
assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
198199
LinearSolve.ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
199200
end
200201

src/default.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ function algchoice_to_alg(alg::Symbol)
229229
elseif alg === :DirectLdiv!
230230
DirectLdiv!()
231231
elseif alg === :SparspakFactorization
232-
SparspakFactorization()
232+
SparspakFactorization(throwerror = false)
233233
elseif alg === :KLUFactorization
234234
KLUFactorization()
235235
elseif alg === :UMFPACKFactorization

src/factorization.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray,
319319
cholesky(A; check = false)
320320
end
321321

322-
function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
322+
function init_cacheval(alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr,
323323
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
324324
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
325325
end
@@ -333,7 +333,7 @@ function init_cacheval(alg::CholeskyFactorization, A::Matrix{Float64}, b, u, Pl,
333333
end
334334

335335
function init_cacheval(alg::CholeskyFactorization,
336-
A::Union{Diagonal, AbstractSciMLOperator}, b, u, Pl, Pr,
336+
A::Union{Diagonal, AbstractSciMLOperator, AbstractArray}, b, u, Pl, Pr,
337337
maxiters::Int, abstol, reltol, verbose::Bool,
338338
assumptions::OperatorAssumptions)
339339
nothing
@@ -1046,10 +1046,10 @@ This e.g. allows for Automatic Differentiation (AD) of a sparse-matrix solve.
10461046
"""
10471047
struct SparspakFactorization <: AbstractSparseFactorization
10481048
reuse_symbolic::Bool
1049-
1050-
function SparspakFactorization(;reuse_symbolic = true)
1049+
1050+
function SparspakFactorization(;reuse_symbolic = true, throwerror = true)
10511051
ext = Base.get_extension(@__MODULE__, :LinearSolveSparspakExt)
1052-
if ext === nothing
1052+
if throwerror && ext === nothing
10531053
error("SparspakFactorization requires that Sparspak is loaded, i.e. `using Sparspak`")
10541054
else
10551055
new(reuse_symbolic)

test/defaults_loading.jl

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using SparseArrays
2+
using LinearSolve
3+
using Test
4+
5+
n = 10
6+
dx = 1/n
7+
dx2 = dx^-2
8+
vals = Vector{BigFloat}(undef, 0)
9+
cols = Vector{Int}(undef, 0)
10+
rows = Vector{Int}(undef, 0)
11+
for i in 1:n
12+
if i != 1
13+
push!(vals, dx2)
14+
push!(cols, i-1)
15+
push!(rows, i)
16+
end
17+
push!(vals, -2dx2)
18+
push!(cols, i)
19+
push!(rows, i)
20+
if i != n
21+
push!(vals, dx2)
22+
push!(cols, i+1)
23+
push!(rows, i)
24+
end
25+
end
26+
mat = sparse(rows, cols, vals, n, n)
27+
rhs = big.(zeros(n))
28+
rhs[begin] = rhs[end] = -2
29+
prob = LinearProblem(mat, rhs)
30+
@test_throws ["SparspakFactorization required", "using Sparspak"] sol = solve(prob).u
31+
32+
using Sparspak
33+
sol = solve(prob).u
34+
@test sol isa Vector{BigFloat}

test/runtests.jl

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ if GROUP == "All" || GROUP == "Enzyme"
2525
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
2626
end
2727

28+
if GROUP == "All" || GROUP == "DefaultsLoading"
29+
@time @safetestset "Enzyme Derivative Rules" include("defaults_loading.jl")
30+
end
31+
2832
if GROUP == "LinearSolveCUDA"
2933
Pkg.activate("gpu")
3034
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))

0 commit comments

Comments
 (0)