Skip to content

Commit c8f5d2a

Browse files
authored
Merge pull request #380 from dchorel/Fix_concatenate_3D_with_4D_images
[WIP] Concatenate 3d with 4d images into a 4d image
2 parents abcf770 + 53dbd94 commit c8f5d2a

File tree

3 files changed

+53
-11
lines changed

3 files changed

+53
-11
lines changed

scilpy/image/operations.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22

33
"""
4-
Utility operations provided for scil_image_math.py and scil_connectivity_math.py
4+
Utility operations provided for scil_image_math.py
5+
and scil_connectivity_math.py
56
They basically act as wrappers around numpy to avoid installing MRtrix/FSL
67
to apply simple operations on nibabel images or numpy arrays.
78
"""
@@ -54,7 +55,7 @@ def get_image_ops():
5455
"""Get a dictionary of all functions relating to image operations"""
5556
image_ops = get_array_ops()
5657
image_ops.update(OrderedDict([
57-
('concatenate', concat),
58+
('concatenate', concatenate),
5859
('dilation', dilation),
5960
('erosion', erosion),
6061
('closing', closing),
@@ -83,6 +84,13 @@ def _validate_imgs(*imgs):
8384
raise ValueError('Not all inputs have the same shape!')
8485

8586

87+
def _validate_imgs_concat(*imgs):
88+
"""Make sure that all inputs are images."""
89+
for img in imgs:
90+
if not isinstance(img, nib.Nifti1Image):
91+
raise ValueError('Inputs are not all images')
92+
93+
8694
def _validate_length(input_list, length, at_least=False):
8795
"""Make sure the the input list has the right number of arguments
8896
(length)."""
@@ -499,20 +507,30 @@ def invert(input_list, ref_img):
499507
return output_data
500508

501509

502-
def concat(input_list, ref_img):
510+
def concatenate(input_list, ref_img):
503511
"""
504-
concat: IMGs
505-
Concatenate a list of 3D images into a single 4D image.
512+
concatenate: IMGs
513+
Concatenate a list of 3D and 4D images into a single 4D image.
506514
"""
507-
_validate_imgs(*input_list, ref_img)
508-
if len(input_list[0].header.get_data_shape()) != 3:
509-
raise ValueError('Concatenate require 3D arrays.')
515+
516+
_validate_imgs_concat(*input_list, ref_img)
517+
if len(input_list[0].header.get_data_shape()) > 4:
518+
raise ValueError('Concatenate require 3D or 4D arrays.')
510519

511520
input_data = []
512521
for img in input_list:
522+
513523
data = img.get_fdata(dtype=np.float64)
514-
input_data.append(data)
524+
525+
if len(img.header.get_data_shape()) == 4:
526+
data = np.rollaxis(data, 3)
527+
for i in range(0, len(data)):
528+
input_data.append(data[i])
529+
else:
530+
input_data.append(data)
531+
515532
img.uncache()
533+
516534
return np.rollaxis(np.stack(input_data), axis=0, start=4)
517535

518536

scripts/scil_image_math.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ def main():
107107
found_ref = True
108108
break
109109

110+
# If there's a 4D image, replace the previous 3D image with
111+
# this one for reference
112+
for input_arg in args.in_images:
113+
if not is_float(input_arg):
114+
ref_img = nib.load(input_arg)
115+
if len(ref_img.shape) == 4:
116+
mask = np.zeros(ref_img.shape)
117+
break
118+
110119
if not found_ref:
111120
raise ValueError('Requires at least one nifti image.')
112121

@@ -137,7 +146,10 @@ def main():
137146

138147
if isinstance(img, nib.Nifti1Image):
139148
data = img.get_fdata(dtype=np.float64)
140-
mask[data > 0] = 1
149+
if data.ndim == 4:
150+
mask[np.sum(data, axis=3).astype(bool) > 0] = 1
151+
else:
152+
mask[data > 0] = 1
141153
img.uncache()
142154
input_img.append(img)
143155

scripts/tests/test_image_math.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_execution_low_mult(script_runner):
4545
assert ret.success
4646

4747

48-
def test_execution_concat(script_runner):
48+
def test_execution_concatenate(script_runner):
4949
os.chdir(os.path.expanduser(tmp_dir.name))
5050
in_img_1 = os.path.join(get_home(), 'atlas', 'ids', '10.nii.gz')
5151
in_img_2 = os.path.join(get_home(), 'atlas', 'ids', '11.nii.gz')
@@ -57,3 +57,15 @@ def test_execution_concat(script_runner):
5757
in_img_1, in_img_2, in_img_3, in_img_4, in_img_5,
5858
in_img_6, 'concat_ids.nii.gz')
5959
assert ret.success
60+
61+
62+
def test_execution_concatenate_4D(script_runner):
63+
os.chdir(os.path.expanduser(tmp_dir.name))
64+
in_img_1 = os.path.join(get_home(), 'atlas', 'ids', '10.nii.gz')
65+
in_img_2 = os.path.join(get_home(), 'atlas', 'ids', '8_10.nii.gz')
66+
in_img_3 = os.path.join(get_home(), 'atlas', 'ids', '12.nii.gz')
67+
in_img_4 = os.path.join(get_home(), 'atlas', 'ids', '8_10.nii.gz')
68+
ret = script_runner.run('scil_image_math.py', 'concatenate',
69+
in_img_1, in_img_2, in_img_3, in_img_4,
70+
'concat_ids_4d.nii.gz')
71+
assert ret.success

0 commit comments

Comments
 (0)