Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions scipy_doctest/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,39 @@ def try_convert_namedtuple(got):


def try_convert_printed_array(got):
"""Printed arrays: reinsert commas.
"""Printed arrays (no commas): reinsert commas.

Handles arbitrary N-dimensional arrays (including 3D and higher), where
numpy separates sub-arrays with blank lines. We never parse numeric
values -- we only look at square brackets and whitespace. Walking the
string, each run of whitespace is replaced with either ``", "`` (when it
sits between two values or sub-arrays, e.g. ``number number`` or ``] [``)
or nothing (when it abuts an opening/closing bracket, e.g. ``[ 0`` or
``0 ]``).
"""
# a minimal version is `s_got = ", ".join(got[1:-1].split())`
# but it fails if there's a space after the opening bracket: "[ 0 1 2 ]"
# For 2D arrays, split into rows, drop spurious entries, then reassemble.
if not got.startswith('['):
return got

g1 = got[1:-1] # strip outer "[...]"-s
rows = [x for x in g1.split("[") if x]
rows2 = [", ".join(row.split()) for row in rows]

if got.startswith("[["):
# was a 2D array, restore the opening brackets in rows; XXX clean up
rows3 = ["[" + row for row in rows2]
else:
rows3 = rows2

# add back the outer brackets
s_got = "[" + ", ".join(rows3) + "]"
return s_got
out = []
prev = '' # last emitted non-whitespace character
i, n = 0, len(got)
while i < n:
ch = got[i]
if ch.isspace():
j = i
while j < n and got[j].isspace():
j += 1
nxt = got[j] if j < n else ''
prev_closes = prev.isalnum() or prev in '.]'
next_opens = nxt.isalnum() or nxt in '.+-['
if prev_closes and next_opens:
out.append(', ')
i = j
else:
out.append(ch)
prev = ch
i += 1
return ''.join(out)


def has_masked(got):
Expand Down
37 changes: 37 additions & 0 deletions scipy_doctest/tests/module_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,43 @@ def rank_3_array_repr():
"""


def rank_3_printed_array():
"""Check recovery of a printed (no commas) rank-3 array.

See https://github.com/scipy/scipy_doctest/issues/21

>>> import numpy as np
>>> print(np.arange(24).reshape(2, 3, 4))
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
<BLANKLINE>
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
"""


def rank_4_printed_array():
"""Check recovery of a printed (no commas) rank-4 array.

>>> import numpy as np
>>> print(np.arange(16).reshape(2, 2, 2, 2))
[[[[ 0 1]
[ 2 3]]
<BLANKLINE>
[[ 4 5]
[ 6 7]]]
<BLANKLINE>
<BLANKLINE>
[[[ 8 9]
[10 11]]
<BLANKLINE>
[[12 13]
[14 15]]]]
"""


# This is used by test_testmod.py::test_public_object_discovery
# While in test we only need __all__ to be not empty, let's make it correct, too.
__all__ = [x for x in vars().keys() if not x.startswith("_")]
88 changes: 88 additions & 0 deletions scipy_doctest/tests/test_ndim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Tests for N-dimensional array handling in DTChecker.

Regression tests for https://github.com/scipy/scipy_doctest/issues/21
"""

import doctest

import numpy as np

from ..impl import DTChecker, try_convert_printed_array


def test_try_convert_printed_array_1d():
s = "[0 1 2]"
out = try_convert_printed_array(s)
assert eval(out) == [0, 1, 2]


def test_try_convert_printed_array_2d():
s = "[[0 1 2]\n [3 4 5]]"
out = try_convert_printed_array(s)
assert eval(out) == [[0, 1, 2], [3, 4, 5]]


def test_try_convert_printed_array_3d():
a = np.arange(24).reshape(2, 3, 4)
out = try_convert_printed_array(str(a))
assert np.array_equal(np.array(eval(out)), a)


def test_try_convert_printed_array_4d():
a = np.arange(16).reshape(2, 2, 2, 2)
out = try_convert_printed_array(str(a))
assert np.array_equal(np.array(eval(out)), a)


def test_try_convert_printed_array_3d_floats():
a = np.linspace(0.0, 1.0, 12).reshape(2, 2, 3)
out = try_convert_printed_array(str(a))
assert np.allclose(np.array(eval(out)), a)


def test_try_convert_printed_array_negatives():
a = np.array([[[-1, 2], [3, -4]], [[5, -6], [-7, 8]]])
out = try_convert_printed_array(str(a))
assert np.array_equal(np.array(eval(out)), a)


def test_check_output_3d_printed_array():
"""A printed 3D array repr (no commas, blank line between slabs)
should round-trip through the checker."""
checker = DTChecker()
a = np.arange(24).reshape(2, 3, 4)
s = str(a)
assert checker.check_output(s, s, doctest.ELLIPSIS)


def test_check_output_3d_array_repr_blankline():
"""The numpy `array(...)` repr with <BLANKLINE> directives passes."""
checker = DTChecker()
want = (
"array([[[ 0, 1, 2, 3],\n"
" [ 4, 5, 6, 7],\n"
" [ 8, 9, 10, 11]],\n"
"<BLANKLINE>\n"
" [[12, 13, 14, 15],\n"
" [16, 17, 18, 19],\n"
" [20, 21, 22, 23]]])"
)
got = repr(np.arange(24).reshape(2, 3, 4))
assert checker.check_output(want, got, doctest.ELLIPSIS)


def test_check_output_3d_printed_with_tolerance():
"""Printed 3D arrays with floats slightly off should pass via np.allclose."""
checker = DTChecker()
a = np.linspace(0.0, 1.0, 12).reshape(2, 2, 3)
want = str(a)
b = a + 1e-9
got = str(b)
assert checker.check_output(want, got, doctest.ELLIPSIS)


def test_check_output_4d_printed_array():
checker = DTChecker()
a = np.arange(16).reshape(2, 2, 2, 2)
s = str(a)
assert checker.check_output(s, s, doctest.ELLIPSIS)