diff --git a/nibabel/nifti1.py b/nibabel/nifti1.py index 548ad34658..b7e81405a5 100644 --- a/nibabel/nifti1.py +++ b/nibabel/nifti1.py @@ -2014,13 +2014,9 @@ def as_reoriented(self, ornt): return img # Also apply the transform to the dim_info fields - new_dim = list(img.header.get_dim_info()) - for idx, value in enumerate(new_dim): - # For each value, leave as None if it was that way, - # otherwise check where we have mapped it to - if value is None: - continue - new_dim[idx] = np.where(ornt[:, 0] == idx)[0] + new_dim = [ + None if orig_dim is None else int(ornt[orig_dim, 0]) + for orig_dim in img.header.get_dim_info()] img.header.set_dim_info(*new_dim) diff --git a/nibabel/tests/test_nifti1.py b/nibabel/tests/test_nifti1.py index 78f876ec7d..8ae9d35116 100644 --- a/nibabel/tests/test_nifti1.py +++ b/nibabel/tests/test_nifti1.py @@ -28,8 +28,10 @@ from nibabel.spatialimages import HeaderDataError from nibabel.tmpdirs import InTemporaryDirectory from ..freesurfer import load as mghload +from ..orientations import aff2axcodes from .test_arraywriters import rt_err_estimate, IUINT_TYPES +from .test_orientations import ALL_ORNTS from .test_helpers import bytesio_filemap, bytesio_round_trip from .nibabel_data import get_nibabel_data, needs_nibabel_data @@ -1403,6 +1405,36 @@ def test_rt_bias(self): bias_thresh = np.max([max_miss / np.sqrt(count), eps]) assert_true(np.abs(bias) < bias_thresh) + def test_reoriented_dim_info(self): + # Check that dim_info is reoriented correctly + arr = np.arange(24).reshape((2, 3, 4)) + # Start as RAS + aff = np.diag([2, 3, 4, 1]) + simg = self.single_class(arr, aff) + for freq, phas, slic in ((0, 1, 2), + (0, 2, 1), + (1, 0, 2), + (2, 0, 1), + (None, None, None), + (0, 2, None), + (0, None, None), + (None, 2, 1), + (None, None, 1), + ): + simg.header.set_dim_info(freq, phas, slic) + fdir = 'RAS'[freq] if freq is not None else None + pdir = 'RAS'[phas] if phas is not None else None + sdir = 'RAS'[slic] if slic is not None else None + for ornt in ALL_ORNTS: + rimg = simg.as_reoriented(np.array(ornt)) + axcode = aff2axcodes(rimg.affine) + dirs = ''.join(axcode).replace('P', 'A').replace('I', 'S').replace('L', 'R') + new_freq, new_phas, new_slic = rimg.header.get_dim_info() + new_fdir = dirs[new_freq] if new_freq is not None else None + new_pdir = dirs[new_phas] if new_phas is not None else None + new_sdir = dirs[new_slic] if new_slic is not None else None + assert_equal((new_fdir, new_pdir, new_sdir), (fdir, pdir, sdir)) + @runif_extra_has('slow') def test_large_nifti1(): diff --git a/nibabel/tests/test_orientations.py b/nibabel/tests/test_orientations.py index 58c5e5f9e2..0605d33f20 100644 --- a/nibabel/tests/test_orientations.py +++ b/nibabel/tests/test_orientations.py @@ -83,6 +83,18 @@ OUT_ORNTS = [np.array(ornt) for ornt in OUT_ORNTS] +_LABELS = ['RL', 'AP', 'SI'] +ALL_AXCODES = [(_LABELS[i0][j0], _LABELS[i1][j1], _LABELS[i2][j2]) + for i0 in range(3) for i1 in range(3) for i2 in range(3) + if i0 != i1 != i2 != i0 + for j0 in range(2) for j1 in range(2) for j2 in range(2)] + +ALL_ORNTS = [[[i0, j0], [i1, j1], [i2, j2]] + for i0 in range(3) for i1 in range(3) for i2 in range(3) + if i0 != i1 != i2 != i0 + for j0 in [1, -1] for j1 in [1, -1] for j2 in [1, -1]] + + def same_transform(taff, ornt, shape): # Applying transformations implied by `ornt` to a made-up array # ``arr`` of shape `shape`, results in ``t_arr``. When the point @@ -125,6 +137,10 @@ def test_apply(): apply_orientation, a, [[0, 1], [np.nan, np.nan], [2, 1]]) + shape = np.array(a.shape) + for ornt in ALL_ORNTS: + t_arr = apply_orientation(a, ornt) + assert_array_equal(a.shape, np.array(t_arr.shape)[np.array(ornt)[:, 0]]) def test_flip_axis(): @@ -282,6 +298,9 @@ def test_ornt2axcodes(): # As do directions not in range assert_raises(ValueError, ornt2axcodes, [[0, 0]]) + for axcodes, ornt in zip(ALL_AXCODES, ALL_ORNTS): + assert_equal(ornt2axcodes(ornt), axcodes) + def test_axcodes2ornt(): # Go from axcodes back to orientations @@ -340,6 +359,9 @@ def test_axcodes2ornt(): assert_raises(ValueError, axcodes2ornt, 'blD', ('SD', 'BF', 'lD')) assert_raises(ValueError, axcodes2ornt, 'blD', ('SD', 'SF', 'lD')) + for axcodes, ornt in zip(ALL_AXCODES, ALL_ORNTS): + assert_array_equal(axcodes2ornt(axcodes), ornt) + def test_aff2axcodes(): assert_equal(aff2axcodes(np.eye(4)), tuple('RAS'))