Skip to content

Commit 8c12104

Browse files
committed
Make sparse solver choice verbose and document
1 parent 2bd4a14 commit 8c12104

File tree

1 file changed

+225
-26
lines changed

1 file changed

+225
-26
lines changed

xobjects/sparse/_sparse.py

Lines changed: 225 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from ..context_cupy import ContextCupy
77
from ..context_pyopencl import ContextPyopencl
88
from .solvers._abstract_solver import SuperLUlikeSolver
9+
from ..general import _print
10+
911
try:
1012
from cupy import ndarray as cparray
1113
import cupyx.scipy.sparse
@@ -31,11 +33,184 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
3133
"cupy",
3234
"pyopencl"],
3335
XContext
34-
] = None
36+
] = None,
37+
verbose: bool = False
3538
) -> SuperLUlikeSolver:
39+
"""
40+
Build and return a factorized sparse linear solver on CPU or GPU.
41+
42+
This function inspects the provided sparse matrix and execution context
43+
and returns a *factorized* solver object (SuperLU-like). The solver can
44+
then be reused to efficiently solve multiple linear systems with the same
45+
matrix `A`.
46+
47+
The actual backend is chosen automatically based on:
48+
* The requested/derived `context` ("cpu", "cupy", "pyopencl"),
49+
* The availability of optional libraries (PyKLU, cuDSS, CuPy/CUSPARSE),
50+
* And the optional `force_solver` argument.
51+
52+
On CPU:
53+
* Default: try PyKLU, fall back to `scipy.sparse.linalg.splu`.
54+
* You may force `"scipySLU"` or `"PyKLU"` explicitly.
55+
56+
On CUDA/CuPy:
57+
* Default: try cuDSS (`DirectSolverSuperLU`), then fall back to:
58+
- `cupyx.scipy.sparse.linalg.splu` if CUSPARSE `csrsm2` is available, or
59+
- a cached SuperLU-based solver (`luLU`) otherwise, falling back to `cupyx.splu` if that fails.
60+
* You may force `"cuDSS"`, `"CachedSLU"`, or `"cupySLU"` explicitly.
61+
62+
PyOpenCL:
63+
* Currently not supported and will raise `NotImplementedError`.
64+
65+
Parameters
66+
----------
67+
A : scipy.sparse.spmatrix or cupyx.scipy.sparse.spmatrix
68+
Sparse system matrix to factorize.
69+
70+
The matrix is internally converted to the format expected by the
71+
chosen backend:
72+
73+
* CPU context: converted to CSC (`scipy.sparse.csc_matrix`).
74+
* CuPy/GPU context: converted to CSR (`cupyx.scipy.sparse.csr_matrix`).
75+
76+
For **best performance**, you should pass `A` already in the
77+
preferred format to avoid extra conversions:
78+
79+
* If the context is (or will usually be) **CPU**, provide `A`
80+
as a CSC matrix (`A.tocsc()`).
81+
* If the context is **GPU/CuPy**, provide `A` as a CSR matrix
82+
(`A.tocsr()` or `cupyx.scipy.sparse.csr_matrix`).
83+
84+
If `context` is `None`, it is inferred from the type of `A`
85+
and the availability of CuPy, e.g.:
86+
87+
* SciPy sparse or NumPy array → `"cpu"`,
88+
* CuPy sparse or CuPy array → `"cupy"`,
89+
* otherwise a `TypeError` is raised.
90+
91+
n_batches : int, optional
92+
Controls the expected shape of the right-hand side (RHS) for GPU
93+
solvers and hence whether solves are treated as single or batched:
94+
95+
* If ``n_batches == 0`` (default), the solver is configured for
96+
single-RHS solves and expects a vector RHS of shape ``(n,)``.
97+
* If ``n_batches > 0``, the solver is configured for batched solves
98+
and expects a 2D RHS array of shape ``(n, n_batches)`` (i.e.
99+
``nrhs = n_batches``).
100+
101+
This argument is primarily used by CUDA-based solvers (e.g. cuDSS and
102+
cached SuperLU) to preconfigure internal data structures for batched
103+
solves. It has no effect for CPU-based solvers.
104+
105+
force_solver : {"scipySLU", "PyKLU", "cuDSS", "CachedSLU", "cupySLU"}, optional
106+
If provided, forces the use of a specific backend instead of the
107+
automatic selection:
108+
109+
* `"scipySLU"` : Use `scipy.sparse.linalg.splu` (CPU).
110+
* `"PyKLU"` : Use the `PyKLU.Klu` solver (CPU).
111+
* `"cuDSS"` : Use CUDA/cuDSS-based `DirectSolverSuperLU` (GPU).
112+
* `"CachedSLU"`: Use CUDA cached SuperLU (`luLU`) (GPU).
113+
* `"cupySLU"` : Use `cupyx.scipy.sparse.linalg.splu` (GPU).
114+
115+
Using a solver that does not match the current `context` will result
116+
in a `ValueError`.
117+
118+
solverKwargs : dict, optional
119+
Extra keyword arguments forwarded to the underlying solver constructor.
120+
If `None`, an empty dict is used.
121+
122+
Some backends make use of `permc_spec` (matrix permutation strategy).
123+
When not explicitly provided and appropriate, this function sets
124+
`permc_spec="MMD_AT_PLUS_A"` as a sensible default for the matrices that
125+
will typically be encountered in an xobjects workflow.
126+
127+
context : {"cpu", "cupy", "pyopencl"} or XContext, optional
128+
Execution context. Can be either:
129+
130+
* A string:
131+
- `"cpu"`: Use CPU-based solvers (SciPy / PyKLU).
132+
- `"cupy"`: Use CuPy/CUDA-based solvers.
133+
- `"pyopencl"`: PyOpenCL context (currently unsupported).
134+
* A context object instance:
135+
- `ContextCpu`
136+
- `ContextCupy`
137+
- `ContextPyopencl`
138+
139+
If `None`, the context is inferred from `A` as described above.
140+
141+
verbose : bool, optional
142+
If `True`, prints debug messages describing the solver-selection
143+
process, fallbacks, and the final solver that is returned.
144+
145+
Returns
146+
-------
147+
SuperLUlikeSolver
148+
A factorized solver object compatible with SciPy’s `splu`-like
149+
interface (i.e. typically exposing a `solve` method and related
150+
accessors). The exact concrete type depends on the backend:
151+
* CPU:
152+
- `scipy.sparse.linalg.SuperLU` (for `"scipySLU"`),
153+
- `PyKLU.Klu` (for `"PyKLU"`).
154+
* CUDA/CuPy:
155+
- `DirectSolverSuperLU` (cuDSS),
156+
- `luLU` (cached SuperLU),
157+
- `cupyx.scipy.sparse.linalg.SuperLU` (for `"cupySLU"`).
158+
159+
Raises
160+
------
161+
TypeError
162+
If the type of `A` is unsupported when inferring the context.
163+
164+
AssertionError
165+
If `A` does not match the required type for the chosen context (a `scipy.sparse` matrix on CPU, a `cupyx.scipy.sparse` matrix on CuPy).
166+
167+
ModuleNotFoundError
168+
If a requested solver backend depends on a module that is not
169+
installed (e.g. CuPy, PyKLU, cuDSS), and no fallback is available.
170+
171+
RuntimeError
172+
If a requested GPU solver fails during initialization.
173+
174+
NotImplementedError
175+
If `context` is `"pyopencl"` or `ContextPyopencl`, since no sparse
176+
solver is currently implemented for that backend.
177+
178+
ValueError
179+
If an invalid `context` string is provided, or if `force_solver`
180+
does not match any known solver for the active context.
181+
182+
Notes
183+
-----
184+
- For best performance on CPU, PyKLU is preferred when available.
185+
- For CuPy/CUDA, cuDSS is preferred when available, and this function
186+
will automatically fall back to other solvers if cuDSS is not present
187+
or fails at runtime.
188+
- The returned solver is *factorized* and should be reused to solve
189+
multiple right-hand sides efficiently.
190+
191+
Examples
192+
--------
193+
Factorize a SciPy sparse matrix on CPU with automatic solver selection:
194+
195+
>>> A = scipy.sparse.random(1000, 1000, density=0.01, format="csr")
196+
>>> solver = factorized_sparse_solver(A)
197+
>>> x = solver.solve(b)
198+
199+
Explicitly request the SciPy SuperLU solver on CPU:
200+
201+
>>> solver = factorized_sparse_solver(A, force_solver="scipySLU")
202+
203+
Using CuPy and cuDSS (requires CuPy and cuDSS bindings):
204+
205+
>>> A_gpu = cupyx.scipy.sparse.csr_matrix(A)
206+
>>> solver = factorized_sparse_solver(A_gpu, context="cupy")
207+
>>> x_gpu = solver.solve(b_gpu)
208+
"""
36209
if solverKwargs is None:
37210
solverKwargs = {}
38211
if context is None:
212+
dbugprint(verbose, "No context provided. " \
213+
"Context will be inferred from matrix")
39214
if isinstance(A, scipy.sparse.spmatrix) or isinstance(A, nparray):
40215
context = 'cpu'
41216
elif (_cupy_available and
@@ -48,17 +223,28 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
48223
assert isinstance(A, scipy.sparse.spmatrix), (
49224
"When using CPU context A must be a scipy.sparse matrix"
50225
)
226+
A = A.tocsc() # CPU Solvers require csc format
51227
if 'permc_spec' not in solverKwargs:
52228
solverKwargs = solverKwargs | {"permc_spec":"MMD_AT_PLUS_A"}
53-
if force_solver is None or force_solver == "scipySLU":
54-
if A.shape[0]*n_batches < 10**5 and force_solver is None:
55-
import warnings
56-
warnings.warn("For small matrices, using PyKLU "
57-
"can provide improved performance")
58-
solver = scipy.sparse.linalg.splu(A.tocsc(),**solverKwargs)
229+
if force_solver is None:
230+
dbugprint(verbose, "No solver requested. " \
231+
"Picking best solver for CPU Context")
232+
try:
233+
dbugprint(verbose, "Attempting to use PyKLU")
234+
import PyKLU
235+
solver = PyKLU.Klu(A)
236+
dbugprint(verbose, "PyKLU succeeded")
237+
except (ModuleNotFoundError, RuntimeError) as e:
238+
dbugprint(verbose, "PyKLU failed. " \
239+
"Falling back to scipy.splu \n"
240+
f"Encountered error: {e}")
241+
242+
solver = scipy.sparse.linalg.splu(A,**solverKwargs)
243+
elif force_solver == "scipySLU":
244+
solver = scipy.sparse.linalg.splu(A,**solverKwargs)
59245
elif force_solver == "PyKLU":
60246
import PyKLU
61-
solver = PyKLU.Klu(A.tocsc())
247+
solver = PyKLU.Klu(A)
62248
else:
63249
raise ValueError("Unrecognized CPU Sparse solver. Available options: "
64250
"scipySLU, PyKLU")
@@ -67,36 +253,44 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
67253
if not _cupy_available:
68254
raise ModuleNotFoundError("No cupy module found. " \
69255
"ContextCupy unavailable")
70-
assert isinstance(A ,cupyx.scipy.sparse.csr_matrix), (
256+
assert isinstance(A ,cupyx.scipy.sparse.spmatrix), (
71257
"When using ContextCupy, input must be "
72-
"cupyx.scipy.sparse.csr_matrix")
73-
258+
"cupyx.scipy.sparse matrix")
259+
260+
A = A.tocsr() # GPU solvers require csr format
74261
if force_solver is not None and force_solver != "cuDSS":
75262
if 'permc_spec' not in solverKwargs:
76263
solverKwargs = solverKwargs | {"permc_spec":"MMD_AT_PLUS_A"}
77264
if force_solver is None:
265+
dbugprint(verbose, "No solver requested. " \
266+
"Picking best solver for Cupy Context")
78267
import warnings
79268
try:
269+
dbugprint(verbose, "Attempting to use cuDSS Solver")
80270
from .solvers.CUDA._cuDSSLU import DirectSolverSuperLU
81271
solver = DirectSolverSuperLU(A, n_batches = n_batches, **solverKwargs)
272+
dbugprint(verbose, "cuDSS succeeded")
82273
except (ModuleNotFoundError, RuntimeError) as e:
83-
warnings.warn("cuDSS not available. "
84-
"Falling back to Cached-SuperLU (spsm) "
85-
f"Encountered Error: {e}")
274+
dbugprint(verbose, "cuDSS failed. \n"
275+
f"Encountered Error: {e}")
276+
warnings.warn("cuDSS not available. Performance will be degraded")
86277
if 'permc_spec' not in solverKwargs:
87278
solverKwargs = solverKwargs | {"permc_spec":"MMD_AT_PLUS_A"}
88-
try:
89-
if cusparse.check_availability('csrsm2'):
90-
raise RuntimeError("csrsm2 is avaiable. "
91-
"cupy SuperLU performs better "
92-
"than Cached-SuperLU (spsm)")
93-
from .solvers.CUDA._luLU import luLU
94-
solver = luLU(A, n_batches = n_batches, **solverKwargs)
95-
except RuntimeError as e:
96-
warnings.warn("Cached-SuperLU (spsm) solver failed. "
97-
"Falling back to cupy SuperLU. "
98-
f"Error encountered: {e}")
279+
if cusparse.check_availability('csrsm2'):
280+
dbugprint(verbose, "csrsm2 available. Using cupyx.splu solver")
99281
solver = cupyx.scipy.sparse.linalg.splu(A, **solverKwargs)
282+
else:
283+
try:
284+
dbugprint(verbose, "csrms2 unavailable. " \
285+
"Attempting to use CachedSuperLU (spsm)")
286+
from .solvers.CUDA._luLU import luLU
287+
solver = luLU(A, n_batches = n_batches, **solverKwargs)
288+
dbugprint(verbose, "CachedSuperLU succeeded")
289+
except RuntimeError as e:
290+
dbugprint(verbose, "CachedSuperLU failed. \n"
291+
f"Encountered error: {e} \n"
292+
"Falling back to cupyx.splu with spsm")
293+
solver = cupyx.scipy.sparse.linalg.splu(A, **solverKwargs)
100294
elif force_solver == "cuDSS":
101295
from .solvers.CUDA._cuDSSLU import DirectSolverSuperLU
102296
solver = DirectSolverSuperLU(A, n_batches = n_batches, **solverKwargs)
@@ -114,4 +308,9 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
114308
else:
115309
raise ValueError("Invalid context. Available contexts are: " \
116310
"cpu, cupy, pyopencl")
117-
return solver
311+
dbugprint(verbose, "Returning solver: " + str(solver))
312+
return solver
313+
314+
# Debug helper: when `verbose` is truthy, forward `text` to `_print` prefixed with "[xo.sparse] "; otherwise do nothing.
def dbugprint(verbose: bool, text: str):
315+
if verbose:
316+
_print("[xo.sparse] "+text)

0 commit comments

Comments
 (0)