Skip to content

Commit de6e4ce

Browse files
committed
BF: array images return array if OK float type
The `get_fdata` method should return the contained array if the array is the correct (matching) floating point type. This was not happening because of I used the `astype` method, which, by default, does a copy. Fix and test.
1 parent 32c5f2e commit de6e4ce

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

nibabel/dataobj_images.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def get_fdata(self, caching='fill', dtype=np.float64):
344344
if self._fdata_cache is not None:
345345
if self._fdata_cache.dtype.type == dtype.type:
346346
return self._fdata_cache
347-
data = np.asanyarray(self._dataobj).astype(dtype)
347+
data = np.asanyarray(self._dataobj).astype(dtype, copy=False)
348348
if caching == 'fill':
349349
self._fdata_cache = data
350350
return data

nibabel/tests/test_image_api.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ def _check_array_caching(self, imaker, meth_name, caching):
342342
# Returned data same object as underlying dataobj if using
343343
# old ``get_data`` method, or using newer ``get_fdata``
344344
# method, where original array was float64.
345-
dataobj_is_data = (img.dataobj.dtype == np.float64
346-
or method == img.get_data)
345+
arr_dtype = img.dataobj.dtype
346+
dataobj_is_data = arr_dtype == np.float64 or method == img.get_data
347347
# Set something to the output array.
348348
data[:] = 42
349349
get_result_changed = np.all(get_data_func() == 42)
@@ -367,6 +367,16 @@ def _check_array_caching(self, imaker, meth_name, caching):
367367
# cache state.
368368
img.uncache()
369369
assert_true(img.in_memory)
370+
if meth_name != 'get_fdata':
371+
return
372+
# Return original array from get_fdata only if the input array is the
373+
# requested dtype.
374+
float_types = np.sctypes['float']
375+
if arr_dtype not in float_types:
376+
return
377+
for float_type in float_types:
378+
data = get_data_func(dtype=float_type)
379+
assert_equal(data is img.dataobj, arr_dtype == float_type)
370380

371381
def validate_data_deprecated(self, imaker, params):
372382
# Check _data property still exists, but raises warning
@@ -542,6 +552,8 @@ class TestAnalyzeAPI(ImageHeaderAPI):
542552
has_scaling = False
543553
can_save = True
544554
standard_extension = '.img'
555+
# Supported dtypes for storing to disk
556+
storable_dtypes = (np.uint8, np.int16, np.int32, np.float32, np.float64)
545557

546558

547559
class TestSpatialImageAPI(TestAnalyzeAPI):

0 commit comments

Comments
 (0)