Skip to content

Commit 1e41afc

Browse files
Michiel CottaarMichiel Cottaar
Michiel Cottaar
authored and
Michiel Cottaar
committed
ENH: allow numpy integer scalars to index fileslice
This is done by allow 0-dimensional numpy integers to pass. Tests are added to show identical behaviour as normal integers for canonical_slicers and fileslice (the two places where `is_fancy` is used).
1 parent e2f50b4 commit 1e41afc

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

nibabel/fileslice.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def is_fancy(sliceobj):
4848
if not isinstance(sliceobj, tuple):
4949
sliceobj = (sliceobj,)
5050
for slicer in sliceobj:
51-
if hasattr(slicer, 'dtype'): # ndarray always fancy
51+
if hasattr(slicer, 'dtype') and slicer.ndim > 0: # ndarray always fancy
5252
return True
5353
# slice or Ellipsis or None OK for basic
5454
if isinstance(slicer, slice) or slicer in (None, Ellipsis):

nibabel/tests/test_fileslice.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _check_slice(sliceobj):
3333

3434

3535
def test_is_fancy():
36-
slices = (2, [2], [2, 3], Ellipsis, np.array(2), np.array((2, 3)))
36+
slices = (2, [2], [2, 3], Ellipsis, np.array((2, 3)))
3737
for slice0 in slices:
3838
_check_slice(slice0)
3939
_check_slice((slice0,)) # tuple is same
@@ -46,7 +46,7 @@ def test_is_fancy():
4646
assert not is_fancy((None,))
4747
assert not is_fancy((None, 1))
4848
assert not is_fancy((1, None))
49-
# Chack that actual False returned (rather than falsey)
49+
# Check that actual False returned (rather than falsey)
5050
assert is_fancy(1) is False
5151

5252

@@ -57,7 +57,9 @@ def test_canonical_slicers():
5757
slice(0, 9),
5858
slice(1, 10),
5959
slice(1, 10, 2),
60-
2)
60+
2,
61+
np.array(2))
62+
6163
shape = (10, 10)
6264
for slice0 in slicers:
6365
assert canonical_slicers((slice0,), shape) == (slice0, slice(None))
@@ -93,9 +95,9 @@ def test_canonical_slicers():
9395
assert canonical_slicers(slice(None), shape) == (slice(None), slice(None))
9496
# Check fancy indexing raises error
9597
with pytest.raises(ValueError):
96-
canonical_slicers((np.array(1), 1), shape)
98+
canonical_slicers((np.array([1]), 1), shape)
9799
with pytest.raises(ValueError):
98-
canonical_slicers((1, np.array(1)), shape)
100+
canonical_slicers((1, np.array([1])), shape)
99101
# Check out of range integer raises error
100102
with pytest.raises(ValueError):
101103
canonical_slicers((10,), shape)
@@ -111,6 +113,11 @@ def test_canonical_slicers():
111113
# Check negative -> positive
112114
assert canonical_slicers(-1, shape) == (9, slice(None))
113115
assert canonical_slicers((slice(None), -1), shape) == (slice(None), 9)
116+
# check numpy integer scalars behave the same as numpy integers
117+
assert canonical_slicers(np.array(2), shape) == canonical_slicers(2, shape)
118+
assert canonical_slicers((np.array(2), np.array(1)), shape) == canonical_slicers((2, 1), shape)
119+
assert canonical_slicers((2, np.array(1)), shape) == canonical_slicers((2, 1), shape)
120+
assert canonical_slicers((np.array(2), 1), shape) == canonical_slicers((2, 1), shape)
114121

115122

116123
def test_slice2outax():
@@ -664,20 +671,29 @@ def slicer_samples(shape):
664671
if ndim == 0:
665672
return
666673
yield (None, 0)
674+
yield (None, np.array(0))
667675
yield (0, None)
676+
yield (np.array(0), None)
668677
yield (Ellipsis, -1)
678+
yield (Ellipsis, np.array(-1))
669679
yield (-1, Ellipsis)
680+
yield (np.array(-1), Ellipsis)
670681
yield (None, Ellipsis)
671682
yield (Ellipsis, None)
672683
yield (Ellipsis, None, None)
673684
if ndim == 1:
674685
return
675686
yield (0, None, slice(None))
687+
yield (np.array(0), None, slice(None))
676688
yield (Ellipsis, -1, None)
689+
yield (Ellipsis, np.array(-1), None)
677690
yield (0, Ellipsis, None)
691+
yield (np.array(0), Ellipsis, None)
678692
if ndim == 2:
679693
return
680694
yield (slice(None), 0, -1, None)
695+
yield (slice(None), np.array(0), np.array(-1), None)
696+
yield (np.array(0), slice(None), np.array(-1), None)
681697

682698

683699
def test_fileslice():
@@ -711,7 +727,7 @@ def test_fileslice_errors():
711727
_check_slicer((1,), arr, fobj, 0, 'C')
712728
# Fancy indexing raises error
713729
with pytest.raises(ValueError):
714-
fileslice(fobj, (np.array(1),), (2, 3, 4), arr.dtype)
730+
fileslice(fobj, (np.array([1]),), (2, 3, 4), arr.dtype)
715731

716732

717733
def test_fileslice_heuristic():

0 commit comments

Comments
 (0)