Skip to content

Commit 9f392e7

Browse files
committed
Refactor test functions
1 parent af1f749 commit 9f392e7

File tree

4 files changed

+60
-75
lines changed

4 files changed

+60
-75
lines changed

tests/test_sparse_solvers.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import scipy.sparse as sp
88
import xobjects as xo
99
from xobjects.test_helpers import fix_random_seed
10-
from xobjects.sparse import _test_helpers as sptest
1110
from xobjects.context import ModuleNotAvailableError
1211
import warnings
1312
import pytest
@@ -36,6 +35,41 @@
3635
but if testing larger systems, could potentially be omitted.
3736
'''
3837

38+
# ---- Helper functions ----
39+
def issymmetric(A, tol=0):
40+
if A.shape[0] != A.shape[1]:
41+
return False
42+
diff = A - A.T
43+
if tol == 0:
44+
return diff.nnz == 0
45+
else:
46+
# tolerance-based check
47+
return abs(diff).max() <= tol
48+
49+
50+
def assert_residual_ok(res_ref, res_solver,
51+
abs_tol=1e-12,
52+
factor=10):
53+
"""
54+
Check that our solver's residual is both:
55+
- absolutely small enough (abs_tol),
56+
- not catastrophically worse than the reference (factor * res_ref).
57+
"""
58+
# sanity: reference solver itself should be good
59+
assert res_ref < abs_tol, f"Reference residual too large: {res_ref}"
60+
61+
# absolute bound
62+
assert res_solver < abs_tol, (
63+
f"Residual {res_solver} exceeds absolute tolerance {abs_tol}"
64+
)
65+
66+
# relative bound vs reference
67+
assert res_solver <= factor * res_ref, (
68+
f"Residual {res_solver} not within factor {factor} of "
69+
f"reference residual {res_ref}"
70+
)
71+
72+
# ---- Tests ----
3973
cpu_tests = [
4074
("scipySLU", xo.ContextCpu()),
4175
("PyKLU", xo.ContextCpu()),
@@ -118,9 +152,9 @@ def make_tridiagonal_system(n, nbatches):
118152
@pytest.mark.parametrize("sparse_system", [random_system, tridiag_system])
119153
def test_vector_solve(test_solver, test_context, sparse_system):
120154
A_sp, b_sp, x_sp, _ = sparse_system
121-
assert not sptest.issymmetric(A_sp)
155+
assert not issymmetric(A_sp)
122156

123-
scipy_residual = sptest.rel_residual(A_sp,x_sp,b_sp)
157+
scipy_residual = xo.sparse.rel_residual(A_sp,x_sp,b_sp)
124158

125159
if "Cpu" in str(test_context):
126160
A = test_context.splike_lib.sparse.csc_matrix(A_sp)
@@ -134,9 +168,9 @@ def test_vector_solve(test_solver, test_context, sparse_system):
134168
)
135169
x = solver.solve(b)
136170

137-
solver_residual = sptest.rel_residual(A,x,b)
138-
sptest.assert_residual_ok(scipy_residual,solver_residual,
139-
abs_tol = ABS_TOL, factor = TOLERANCE_FACTOR)
171+
solver_residual = xo.sparse.rel_residual(A,x,b)
172+
assert_residual_ok(scipy_residual,solver_residual,
173+
abs_tol = ABS_TOL, factor = TOLERANCE_FACTOR)
140174

141175
random_system = make_random_sparse_system(SPARSE_SYSTEM_SIZE, NUM_BATCHES)
142176
tridiag_system = make_tridiagonal_system(SPARSE_SYSTEM_SIZE, NUM_BATCHES)
@@ -145,8 +179,8 @@ def test_vector_solve(test_solver, test_context, sparse_system):
145179
@pytest.mark.parametrize("sparse_system", [random_system, tridiag_system])
146180
def test_batched_solve(test_solver, test_context, sparse_system):
147181
A_sp, b_sp, x_sp, _ = sparse_system
148-
assert not sptest.issymmetric(A_sp)
149-
scipy_residual = sptest.rel_residual(A_sp,x_sp,b_sp)
182+
assert not issymmetric(A_sp)
183+
scipy_residual = xo.sparse.rel_residual(A_sp,x_sp,b_sp)
150184
if "Cpu" in str(test_context):
151185
A = test_context.splike_lib.sparse.csc_matrix(A_sp)
152186
if "Cupy" in str(test_context):
@@ -160,6 +194,6 @@ def test_batched_solve(test_solver, test_context, sparse_system):
160194
)
161195
x = solver.solve(b)
162196

163-
solver_residual = sptest.rel_residual(A,x,b)
164-
sptest.assert_residual_ok(scipy_residual,solver_residual,
165-
abs_tol = ABS_TOL, factor = TOLERANCE_FACTOR)
197+
solver_residual = xo.sparse.rel_residual(A,x,b)
198+
assert_residual_ok(scipy_residual,solver_residual,
199+
abs_tol = ABS_TOL, factor = TOLERANCE_FACTOR)

xobjects/sparse/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from ._sparse import factorized_sparse_solver
1+
from ._sparse import factorized_sparse_solver, rel_residual
22
from . import solvers
3-
__all__ = ["factorized_sparse_solver","solvers"]
3+
__all__ = ["factorized_sparse_solver","solvers", "rel_residual"]

xobjects/sparse/_sparse.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# ########################################### #
55

66
import scipy.sparse
7+
import numpy.linalg as npl
78
from numpy import ndarray as nparray
89
from typing import Optional, Literal, Union
910
from ..context import XContext
@@ -330,4 +331,15 @@ def factorized_sparse_solver(A: Union[scipy.sparse.csr_matrix,
330331

331332
def dbugprint(verbose: bool, text: str):
332333
if verbose:
333-
_print("[xo.sparse] "+text)
334+
_print("[xo.sparse] "+text)
335+
336+
def rel_residual(A,x,b):
337+
if hasattr(A, "get"):
338+
A = A.get()
339+
if hasattr(x, "get"):
340+
x = x.get()
341+
if hasattr(b, "get"):
342+
b = b.get()
343+
assert scipy.sparse.issparse(A), ("A must be a sparse matrix")
344+
345+
return npl.norm(A@x - b) / (npl.norm(b))

xobjects/sparse/_test_helpers.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

0 commit comments

Comments
 (0)