66from ..context_cupy import ContextCupy
77from ..context_pyopencl import ContextPyopencl
88from .solvers ._abstract_solver import SuperLUlikeSolver
9+ from ..general import _print
10+
911try :
1012 from cupy import ndarray as cparray
1113 import cupyx .scipy .sparse
@@ -31,11 +33,184 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
3133 "cupy" ,
3234 "pyopencl" ],
3335 XContext
34- ] = None
36+ ] = None ,
37+ verbose : bool = False
3538 ) -> SuperLUlikeSolver :
39+ """
40+ Build and return a factorized sparse linear solver on CPU or GPU.
41+
42+ This function inspects the provided sparse matrix and execution context
43+ and returns a *factorized* solver object (SuperLU-like). The solver can
44+ then be reused to efficiently solve multiple linear systems with the same
45+ matrix `A`.
46+
47+ The actual backend is chosen automatically based on:
48+ * The requested/derived `context` ("cpu", "cupy", "pyopencl"),
49+ * The availability of optional libraries (PyKLU, cuDSS, CuPy/CUSPARSE),
50+ * And the optional `force_solver` argument.
51+
52+ On CPU:
53+ * Default: try PyKLU, fall back to `scipy.sparse.linalg.splu`.
54+ * You may force `"scipySLU"` or `"PyKLU"` explicitly.
55+
56+ On CUDA/CuPy:
57+ * Default: try cuDSS (`DirectSolverSuperLU`), then fall back to:
58+ - `cupyx.scipy.sparse.linalg.splu` if CUSPARSE `csrsm2` is available, or
59+ - a cached SuperLU-based solver (`luLU`) and finally `cupyx.splu`.
60+ * You may force `"cuDSS"`, `"CachedSLU"`, or `"cupySLU"` explicitly.
61+
62+ PyOpenCL:
63+ * Currently not supported and will raise `NotImplementedError`.
64+
65+ Parameters
66+ ----------
67+ A : scipy.sparse.spmatrix or cupyx.scipy.sparse.spmatrix
68+ Sparse system matrix to factorize.
69+
70+ The matrix is internally converted to the format expected by the
71+ chosen backend:
72+
73+ * CPU context: converted to CSC (`scipy.sparse.csc_matrix`).
74+ * CuPy/GPU context: converted to CSR (`cupyx.scipy.sparse.csr_matrix`).
75+
76+ For **best performance**, you should pass `A` already in the
77+ preferred format to avoid extra conversions:
78+
79+ * If the context is (or will usually be) **CPU**, provide `A`
80+ as a CSC matrix (`A.tocsc()`).
81+ * If the context is **GPU/CuPy**, provide `A` as a CSR matrix
82+ (`A.tocsr()` or `cupyx.scipy.sparse.csr_matrix`).
83+
84+ If `context` is `None`, it is still inferred from the type of `A`
85+ and the availability of CuPy, e.g.:
86+
87+ * SciPy sparse or NumPy array → `"cpu"`,
88+ * CuPy sparse or CuPy array → `"cupy"`,
89+ * otherwise a `TypeError` is raised.
90+
91+ n_batches : int, optional
92+ Controls the expected shape of the right-hand side (RHS) for GPU
93+ solvers and hence whether solves are treated as single or batched:
94+
95+ * If ``n_batches == 0`` (default), the solver is configured for
96+ single-RHS solves and expects a vector RHS of shape ``(n,)``.
97+ * If ``n_batches > 0``, the solver is configured for batched solves
98+ and expects a 2D RHS array of shape ``(n, n_batches)`` (i.e.
99+ ``nrhs = n_batches``).
100+
101+ This argument is primarily used by CUDA-based solvers (e.g. cuDSS and
102+ cached SuperLU) to preconfigure internal data structures for batched
103+ solves. It has no effect for CPU-based solvers.
104+
105+ force_solver : {"scipySLU", "PyKLU", "cuDSS", "CachedSLU", "cupySLU"}, optional
106+ If provided, forces the use of a specific backend instead of the
107+ automatic selection:
108+
109+ * `"scipySLU"` : Use `scipy.sparse.linalg.splu` (CPU).
110+ * `"PyKLU"` : Use the `PyKLU.Klu` solver (CPU).
111+ * `"cuDSS"` : Use CUDA/cuDSS-based `DirectSolverSuperLU` (GPU).
112+ * `"CachedSLU"`: Use CUDA cached SuperLU (`luLU`) (GPU).
113+ * `"cupySLU"` : Use `cupyx.scipy.sparse.linalg.splu` (GPU).
114+
115+ Using a solver that does not match the current `context` will result
116+ in a `ValueError`.
117+
118+ solverKwargs : dict, optional
119+ Extra keyword arguments forwarded to the underlying solver constructor.
120+ If `None`, an empty dict is used.
121+
122+ Some backends make use of `permc_spec` (matrix permutation strategy).
123+ When not explicitly provided and appropriate, this function sets
124+ `permc_spec="MMD_AT_PLUS_A"` as a sensible default for the matrices that
125+ will typically be encountered in an xobjects workflow.
126+
127+ context : {"cpu", "cupy", "pyopencl"} or XContext, optional
128+ Execution context. Can be either:
129+
130+ * A string:
131+ - `"cpu"`: Use CPU-based solvers (SciPy / PyKLU).
132+ - `"cupy"`: Use CuPy/CUDA-based solvers.
133+ - `"pyopencl"`: PyOpenCL context (currently unsupported).
134+ * A context object instance:
135+ - `ContextCpu`
136+ - `ContextCupy`
137+ - `ContextPyopencl`
138+
139+ If `None`, the context is inferred from `A` as described above.
140+
141+ verbose : bool, optional
142+ If `True`, prints debug messages describing the solver-selection
143+ process, fallbacks, and the final solver that is returned.
144+
145+ Returns
146+ -------
147+ SuperLUlikeSolver
148+ A factorized solver object compatible with SciPy’s `splu`-like
149+ interface (i.e. typically exposing a `solve` method and related
150+ accessors). The exact concrete type depends on the backend:
151+ * CPU:
152+ - `scipy.sparse.linalg.SuperLU` (for `"scipySLU"`),
153+ - `PyKLU.Klu` (for `"PyKLU"`).
154+ * CUDA/CuPy:
155+ - `DirectSolverSuperLU` (cuDSS),
156+ - `luLU` (cached SuperLU),
157+ - `cupyx.scipy.sparse.linalg.SuperLU` (for `"cupySLU"`).
158+
159+ Raises
160+ ------
161+ TypeError
162+ If the type of `A` is unsupported when inferring the context.
163+
164+ AssertionError
165+ If `A` does not match the required type for the chosen context
166+
167+ ModuleNotFoundError
168+ If a requested solver backend depends on a module that is not
169+ installed (e.g. CuPy, PyKLU, cuDSS), and no fallback is available.
170+
171+ RuntimeError
172+ If a requested GPU solver fails during initialization.
173+
174+ NotImplementedError
175+ If `context` is `"pyopencl"` or `ContextPyopencl`, since no sparse
176+ solver is currently implemented for that backend.
177+
178+ ValueError
179+ If an invalid `context` string is provided, or if `force_solver`
180+ does not match any known solver for the active context.
181+
182+ Notes
183+ -----
184+ - For best performance on CPU, PyKLU is preferred when available.
185+ - For CuPy/CUDA, cuDSS is preferred when available, and this function
186+ will automatically fall back to other solvers if cuDSS is not present
187+ or fails at runtime.
188+ - The returned solver is *factorized* and should be reused to solve
189+ multiple right-hand sides efficiently.
190+
191+ Examples
192+ --------
193+ Factorize a SciPy sparse matrix on CPU with automatic solver selection:
194+
195+ >>> A = scipy.sparse.random(1000, 1000, density=0.01, format="csr")
196+ >>> solver = factorized_sparse_solver(A)
197+ >>> x = solver.solve(b)
198+
199+ Explicitly request the SciPy SuperLU solver on CPU:
200+
201+ >>> solver = factorized_sparse_solver(A, force_solver="scipySLU")
202+
203+ Using CuPy and cuDSS (requires CuPy and cuDSS bindings):
204+
205+ >>> A_gpu = cupyx.scipy.sparse.csr_matrix(A)
206+ >>> solver = factorized_sparse_solver(A_gpu, context="cupy")
207+ >>> x_gpu = solver.solve(b_gpu)
208+ """
36209 if solverKwargs is None :
37210 solverKwargs = {}
38211 if context is None :
212+ dbugprint (verbose , "No context provided. " \
213+ "Context will be inferred from matrix" )
39214 if isinstance (A , scipy .sparse .spmatrix ) or isinstance (A , nparray ):
40215 context = 'cpu'
41216 elif (_cupy_available and
@@ -48,17 +223,28 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
48223 assert isinstance (A , scipy .sparse .spmatrix ), (
49224 "When using CPU context A must be a scipy.sparse matrix"
50225 )
226+ A = A .tocsc () # CPU Solvers require csc format
51227 if 'permc_spec' not in solverKwargs :
52228 solverKwargs = solverKwargs | {"permc_spec" :"MMD_AT_PLUS_A" }
53- if force_solver is None or force_solver == "scipySLU" :
54- if A .shape [0 ]* n_batches < 10 ** 5 and force_solver is None :
55- import warnings
56- warnings .warn ("For small matrices, using PyKLU "
57- "can provide improved performance" )
58- solver = scipy .sparse .linalg .splu (A .tocsc (),** solverKwargs )
229+ if force_solver is None :
230+ dbugprint (verbose , "No solver requested. " \
231+ "Picking best solver for CPU Context" )
232+ try :
233+ dbugprint (verbose , "Attempting to use PyKLU" )
234+ import PyKLU
235+ solver = PyKLU .Klu (A )
236+ dbugprint (verbose , "PyKLU succeeded" )
237+ except (ModuleNotFoundError , RuntimeError ) as e :
238+ dbugprint (verbose , "PyKLU failed. " \
239+ "Falling back to scipy.splu \n "
240+ f"Encountered error: { e } " )
241+
242+ solver = scipy .sparse .linalg .splu (A ,** solverKwargs )
243+ elif force_solver == "scipySLU" :
244+ solver = scipy .sparse .linalg .splu (A ,** solverKwargs )
59245 elif force_solver == "PyKLU" :
60246 import PyKLU
61- solver = PyKLU .Klu (A . tocsc () )
247+ solver = PyKLU .Klu (A )
62248 else :
63249 raise ValueError ("Unrecognized CPU Sparse solver. Available options: "
64250 "scipySLU, PyKLU" )
@@ -67,36 +253,44 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
67253 if not _cupy_available :
68254 raise ModuleNotFoundError ("No cupy module found. " \
69255 "ContextCupy unavailable" )
70- assert isinstance (A ,cupyx .scipy .sparse .csr_matrix ), (
256+ assert isinstance (A ,cupyx .scipy .sparse .spmatrix ), (
71257 "When using ContextCupy, input must be "
72- "cupyx.scipy.sparse.csr_matrix" )
73-
258+ "cupyx.scipy.sparse matrix" )
259+
260+ A = A .tocsr () # GPU solvers require csr format
74261 if force_solver is not None and force_solver != "cuDSS" :
75262 if 'permc_spec' not in solverKwargs :
76263 solverKwargs = solverKwargs | {"permc_spec" :"MMD_AT_PLUS_A" }
77264 if force_solver is None :
265+ dbugprint (verbose , "No solver requested. " \
266+ "Picking best solver for Cupy Context" )
78267 import warnings
79268 try :
269+ dbugprint (verbose , "Attempting to use cuDSS Solver" )
80270 from .solvers .CUDA ._cuDSSLU import DirectSolverSuperLU
81271 solver = DirectSolverSuperLU (A , n_batches = n_batches , ** solverKwargs )
272+ dbugprint (verbose , "cuDSS succeeded" )
82273 except (ModuleNotFoundError , RuntimeError ) as e :
83- warnings . warn ( "cuDSS not available. "
84- "Falling back to Cached-SuperLU (spsm) "
85- f"Encountered Error: { e } " )
274+ dbugprint ( verbose , "cuDSS failed. \n "
275+ f"Encountered Error: { e } " )
276+ warnings . warn ( "cuDSS not available. Performance will be degraded" )
86277 if 'permc_spec' not in solverKwargs :
87278 solverKwargs = solverKwargs | {"permc_spec" :"MMD_AT_PLUS_A" }
88- try :
89- if cusparse .check_availability ('csrsm2' ):
90- raise RuntimeError ("csrsm2 is avaiable. "
91- "cupy SuperLU performs better "
92- "than Cached-SuperLU (spsm)" )
93- from .solvers .CUDA ._luLU import luLU
94- solver = luLU (A , n_batches = n_batches , ** solverKwargs )
95- except RuntimeError as e :
96- warnings .warn ("Cached-SuperLU (spsm) solver failed. "
97- "Falling back to cupy SuperLU. "
98- f"Error encountered: { e } " )
279+ if cusparse .check_availability ('csrsm2' ):
280+ dbugprint (verbose , "csrsm2 available. Using cupyx.splu solver" )
99281 solver = cupyx .scipy .sparse .linalg .splu (A , ** solverKwargs )
282+ else :
283+ try :
284+ dbugprint (verbose , "csrms2 unavailable. " \
285+ "Attempting to use CachedSuperLU (spsm)" )
286+ from .solvers .CUDA ._luLU import luLU
287+ solver = luLU (A , n_batches = n_batches , ** solverKwargs )
288+ dbugprint (verbose , "CachedSuperLU succeeded" )
289+ except RuntimeError as e :
290+ dbugprint (verbose , "CachedSuperLU failed. \n "
291+ f"Encountered error: { e } \n "
292+ "Falling back to cupyx.splu with spsm" )
293+ solver = cupyx .scipy .sparse .linalg .splu (A , ** solverKwargs )
100294 elif force_solver == "cuDSS" :
101295 from .solvers .CUDA ._cuDSSLU import DirectSolverSuperLU
102296 solver = DirectSolverSuperLU (A , n_batches = n_batches , ** solverKwargs )
@@ -114,4 +308,9 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
114308 else :
115309 raise ValueError ("Invalid context. Available contexts are: " \
116310 "cpu, cupy, pyopencl" )
117- return solver
311+ dbugprint (verbose , "Returning solver: " + str (solver ))
312+ return solver
313+
def dbugprint(verbose: bool, text: str):
    """
    Emit a tagged debug message when verbose output is enabled.

    Parameters
    ----------
    verbose : bool
        When falsy, the call returns immediately without printing.
    text : str
        Message body. It is prefixed with ``"[xo.sparse] "`` and handed
        to ``_print``.
    """
    # Guard clause: stay silent unless the caller asked for debug output.
    if not verbose:
        return
    tagged_message = "[xo.sparse] " + text
    _print(tagged_message)
0 commit comments