-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathPreconditioners.jl
48 lines (37 loc) · 1.92 KB
/
Preconditioners.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
module Preconditioners
using SparseArrays, SparseMatricesCSR
using LinearSolve
import LinearSolve: \
using Adapt
using UnPack
import KernelAbstractions: Backend, @kernel, @index, @ndrange, @groupsize, @print, functional,
CPU,synchronize
import SparseArrays: getcolptr,getnzval
import SparseMatricesCSR: getnzval
import LinearAlgebra: Symmetric
## Generic Code #
# For a symmetric matrix the CSR and CSC layouts are exactly the same, so the symmetry
# information has to be kept separately; it can then be exploited in cases where one
# format has a better access pattern than the other.
"""
    AbstractMatrixSymmetry

Supertype of the field-less traits that carry the symmetry information of a matrix.
Concrete subtypes declared here: [`SymmetricMatrix`](@ref) and
[`NonSymmetricMatrix`](@ref).
"""
abstract type AbstractMatrixSymmetry end

"""
    SymmetricMatrix <: AbstractMatrixSymmetry

Field-less trait tagging a matrix as symmetric.
"""
struct SymmetricMatrix <: AbstractMatrixSymmetry
end

"""
    NonSymmetricMatrix <: AbstractMatrixSymmetry

Field-less trait tagging a matrix as not symmetric.
"""
struct NonSymmetricMatrix <: AbstractMatrixSymmetry
end
"""
    AbstractMatrixFormat

Supertype of the field-less traits that name a sparse-matrix storage format.
Concrete subtypes declared here: [`CSRFormat`](@ref) and [`CSCFormat`](@ref).
"""
abstract type AbstractMatrixFormat end

"""
    CSRFormat <: AbstractMatrixFormat

Field-less trait for the CSR storage format.
"""
struct CSRFormat <: AbstractMatrixFormat
end

"""
    CSCFormat <: AbstractMatrixFormat

Field-less trait for the CSC storage format.
"""
struct CSCFormat <: AbstractMatrixFormat
end
# Why use these traits?
# Multiple backends are targeted, but the sparse CSC/CSR matrix types of those backends
# do not all share a common supertype (e.g. AbstractSparseMatrixCSC /
# AbstractSparseMatrixCSR). For instance:
#   CUSPARSE.CuSparseDeviceMatrixCSC <: SparseArrays.AbstractSparseMatrixCSC → false
# Hence the module defines its own traits to identify the format of a sparse matrix.
"""
    sparsemat_format_type(A)

Return the storage-format trait of the sparse matrix `A`: `CSCFormat` for a
`SparseMatrixCSC`, `CSRFormat` for a `SparseMatrixCSR`.

The trait *type* itself is returned, not an instance of it.
"""
function sparsemat_format_type(::SparseMatrixCSC)
    return CSCFormat
end

function sparsemat_format_type(::SparseMatrixCSR)
    return CSRFormat
end
#TODO: remove once https://github.com/JuliaGPU/CUDA.jl/pull/2740 is merged
"""
    convert_to_backend(backend::Backend, A::AbstractSparseMatrix)

Fallback method: return `adapt(backend, A)`.

Specific backends are meant to extend this function in their corresponding extensions.
"""
function convert_to_backend(backend::Backend, A::AbstractSparseMatrix)
    return adapt(backend, A)
end
# Why wrap these accessors? The wrappers are owned by this module, so they can be extended
# for device backends (e.g. CuSparseDeviceMatrixCSR) without committing type piracy.
# TODO: find a more robust solution to dispatch the correct function
"""
    colvals(A::SparseMatrixCSR)

Forward to `SparseMatricesCSR.colvals(A)`.
"""
function colvals(A::SparseMatrixCSR)
    return SparseMatricesCSR.colvals(A)
end

"""
    getrowptr(A::SparseMatrixCSR)

Forward to `SparseMatricesCSR.getrowptr(A)`.
"""
function getrowptr(A::SparseMatrixCSR)
    return SparseMatricesCSR.getrowptr(A)
end
# Evaluate the contents of `l1_gauss_seidel.jl` inside this module. It is included after
# the trait types and accessor wrappers above have been defined.
include("l1_gauss_seidel.jl")
# Public API of the module.
# NOTE(review): `L1GSPrecBuilder` is presumably defined in the included file — confirm.
export L1GSPrecBuilder
end