Skip to content

Commit 42133f8

Browse files
Add facility to xo.assert_allclose with outliers
1 parent 5fe063e commit 42133f8

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

xobjects/general.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __call__(self, *args, **kwargs):
1717
_print = Print()
1818

1919

20-
def assert_allclose(a, b, rtol=1e-7, atol=1e-7):
20+
def assert_allclose(a, b, rtol=0, atol=0, max_outliers=0):
2121
if hasattr(a, "get"):
2222
a = a.get()
2323
if hasattr(b, "get"):
@@ -30,4 +30,23 @@ def assert_allclose(a, b, rtol=1e-7, atol=1e-7):
3030
b = np.squeeze(b)
3131
except:
3232
pass
33-
np_assert_allclose(a, b, rtol=rtol, atol=atol)
33+
34+
try:
35+
np_assert_allclose(a, b, rtol=rtol, atol=atol)
36+
except AssertionError as e:
37+
if max_outliers == 0:
38+
raise e
39+
if not allclose_with_outliers(a, b, rtol, atol, max_outliers):
40+
raise AssertionError(
41+
"Arrays are not close enough, even with outliers allowed."
42+
) from e
43+
44+
45+
def allclose_with_outliers(a, b, rtol=1e-7, atol=0, max_outliers=0):
46+
a = np.asanyarray(a)
47+
b = np.asanyarray(b)
48+
diff = np.abs(a - b)
49+
allowed = atol + rtol * np.abs(b)
50+
mask = diff > allowed
51+
num_outliers = np.count_nonzero(mask)
52+
return num_outliers <= max_outliers

0 commit comments

Comments
 (0)