Skip to content

Commit 79bc349

Browse files
committed
Update solver interface so that it is consistent accross platforms
1 parent 8c12104 commit 79bc349

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed
Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1-
__all__ = []
1+
from ....context import ModuleNotAvailableError
22
try:
33
from ._cuDSSLU import DirectSolverSuperLU as cuDSSSuperLU
4-
__all__.append("cuDSSSuperLU")
5-
except (ModuleNotFoundError,ImportError):
6-
pass
4+
except (ModuleNotFoundError,ImportError) as e:
5+
def cuDSSSuperLU(*args, **kwargs):
6+
raise ModuleNotAvailableError(
7+
"cuDSSSuperLU is not available. Could not import required backend."
8+
) from e
79
try:
810
from ._luLU import luLU as CachedSuperLU
9-
__all__.append("CachedSuperLU")
11+
except (ModuleNotFoundError,ImportError):
12+
def CachedSuperLU(*args, **kwargs):
13+
raise ModuleNotAvailableError(
14+
"CachedSuperLU is not available. Could not import required backend."
15+
) from e
16+
try:
1017
from cupyx.scipy.sparse.linalg import splu as CupySuperLU
11-
__all__.append("CupySuperLU")
1218
except (ModuleNotFoundError,ImportError):
13-
pass
19+
def CupySuperLU(*args, **kwargs):
20+
raise ModuleNotAvailableError(
21+
"CupySuperLU is not available. Could not import required backend."
22+
) from e
1423

15-
if not __all__:
16-
raise ImportError("Failed to import any CUDA-based sparse solver")
24+
__all__ = ["cuDSSSuperLU", "CachedSuperLU", "CupySuperLU"]
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
from . import CPU
2-
__all__ = ["CPU"]
3-
try:
4-
from . import CUDA
5-
__all__.append("CUDA")
6-
except ImportError:
7-
pass
2+
from . import CUDA
3+
__all__ = ["CPU", "CUDA"]

0 commit comments

Comments
 (0)