Skip to content

Commit 0b46a5d

Browse files
authored
Merge pull request #1 from DIAGNijmegen/nnunet-preprocessing-strategy
New preprocessing strategy
2 parents 28c9bb4 + 9cf6a47 commit 0b46a5d

File tree

85 files changed

+442
-305
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

85 files changed

+442
-305
lines changed

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.mha filter=lfs diff=lfs merge=lfs -text
2+
*.nii.gz filter=lfs diff=lfs merge=lfs -text

.github/workflows/tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ jobs:
1414

1515
steps:
1616
- uses: actions/checkout@v2
17+
with:
18+
lfs: 'true'
1719
- name: Set up Python ${{ matrix.python-version }}
1820
uses: actions/setup-python@v2
1921
with:

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
long_description = fh.read()
66

77
setuptools.setup(
8-
version='1.1.6',
8+
version='1.2',
99
author_email='[email protected]',
1010
long_description=long_description,
1111
long_description_content_type="text/markdown",
@@ -14,5 +14,7 @@
1414
"Bug Tracker": "https://github.com/DIAGNijmegen/picai_prep/issues"
1515
},
1616
license='Apache License, Version 2.0',
17-
packages=['picai_prep', 'picai_prep.resources', 'picai_prep.examples.dcm2mha', 'picai_prep.examples.mha2nnunet'],
17+
package_dir={"": "src"}, # our packages live under src, but src is not a package itself
18+
packages=setuptools.find_packages('src', exclude=['tests']),
19+
exclude_package_data={'': ['tests']},
1820
)

src/picai_prep/examples/mha2nnunet/picai_archive.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,17 @@ def generate_mha2nnunet_settings(
109109
}
110110
},
111111
"preprocessing": {
112-
"matrix_size": [
113-
20,
114-
320,
115-
320
116-
],
117-
"spacing": [
118-
3.0,
119-
0.5,
120-
0.5
121-
]
112+
# optionally, resample and perform centre crop:
113+
# "matrix_size": [
114+
# 20,
115+
# 320,
116+
# 320
117+
# ],
118+
# "spacing": [
119+
# 3.0,
120+
# 0.5,
121+
# 0.5
122+
# ],
122123
},
123124
"archive": archive_list
124125
}

src/picai_prep/examples/mha2nnunet/picai_archive_inference.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,17 @@ def generate_mha2nnunet_settings(
9797
}
9898
},
9999
"preprocessing": {
100-
"matrix_size": [
101-
20,
102-
160,
103-
160
104-
],
105-
"spacing": [
106-
3.6,
107-
0.5,
108-
0.5
109-
]
100+
# optionally, resample and perform centre crop:
101+
# "matrix_size": [
102+
# 20,
103+
# 160,
104+
# 160
105+
# ],
106+
# "spacing": [
107+
# 3.6,
108+
# 0.5,
109+
# 0.5
110+
# ]
110111
},
111112
"archive": archive_list
112113
}

src/picai_prep/preprocessing.py

Lines changed: 49 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616
import SimpleITK as sitk
1717
import numpy as np
18+
from numpy.testing import assert_allclose
1819
from dataclasses import dataclass
1920
from scipy import ndimage
2021

21-
from typing import List, Tuple, Callable, Optional, Union, Any, Iterable, cast
22+
from typing import List, Callable, Optional, Union, Any, Iterable
2223
try:
2324
import numpy.typing as npt
2425
except ImportError: # pragma: no cover
@@ -32,21 +33,15 @@ class PreprocessingSettings():
3233
- matrix_size: number of voxels output volume (z, y, x)
3334
- spacing: output voxel spacing in mm (z, y, x)
3435
- physical_size: size in mm of the target volume (z, y, x)
35-
- align_physical_space: whether to align sequences to eachother, based on metadata
36-
- crop_to_first_physical_centre: whether to crop to physical centre of first sequence,
37-
or to the new centre after aligning sequences
3836
- align_segmentation: whether to align the scans using the centroid of the provided segmentation
3937
"""
40-
matrix_size: Iterable[int] = (20, 160, 160)
38+
matrix_size: Optional[Iterable[int]] = None
4139
spacing: Optional[Iterable[float]] = None
4240
physical_size: Optional[Iterable[float]] = None
43-
align_physical_space: bool = False
44-
crop_to_first_physical_centre: bool = False
4541
align_segmentation: Optional[sitk.Image] = None
4642

4743
def __post_init__(self):
48-
if self.physical_size is None:
49-
assert self.spacing, "Need either physical_size or spacing"
44+
if self.physical_size is None and self.spacing is not None and self.matrix_size is not None:
5045
# calculate physical size
5146
self.physical_size = [
5247
voxel_spacing * num_voxels
@@ -56,8 +51,7 @@ def __post_init__(self):
5651
)
5752
]
5853

59-
if self.spacing is None:
60-
assert self.physical_size, "Need either physical_size or spacing"
54+
if self.spacing is None and self.physical_size is not None and self.matrix_size is not None:
6155
# calculate spacing
6256
self.spacing = [
6357
size / num_voxels
@@ -67,13 +61,8 @@ def __post_init__(self):
6761
)
6862
]
6963

70-
@property
71-
def _spacing(self) -> Iterable[float]:
72-
return cast(Iterable[float], self.spacing)
73-
74-
@property
75-
def _physical_size(self) -> Iterable[float]:
76-
return cast(Iterable[float], self.physical_size)
64+
if self.align_segmentation is not None:
65+
raise NotImplementedError("Alignment of scans based on segmentation is not implemented yet.")
7766

7867

7968
def resample_img(
@@ -170,83 +159,6 @@ def crop_or_pad(
170159
return np.pad(image[tuple(slicer)], padding)
171160

172161

173-
def get_overlap_start_indices(img_main: sitk.Image, img_secondary: sitk.Image):
174-
# convert start index from main image to secondary image
175-
point_secondary = img_secondary.TransformIndexToPhysicalPoint((0, 0, 0))
176-
index_main = img_main.TransformPhysicalPointToContinuousIndex(point_secondary)
177-
178-
# clip index
179-
index_main = np.clip(index_main, a_min=0, a_max=None)
180-
181-
# convert main index back to secondary image
182-
point_main = img_main.TransformContinuousIndexToPhysicalPoint(index_main)
183-
index_secondary = img_secondary.TransformPhysicalPointToContinuousIndex(point_main)
184-
185-
# round secondary index up (round to 5 decimals for e.g. 18.999999999999996)
186-
index_secondary = np.ceil(np.round(index_secondary, decimals=5))
187-
188-
# convert secondary index once again to main image
189-
point_secondary = img_secondary.TransformContinuousIndexToPhysicalPoint(index_secondary)
190-
index_main = img_main.TransformPhysicalPointToIndex(point_secondary)
191-
192-
# convert and return result
193-
return np.array(index_secondary).astype(int), np.array(index_main).astype(int)
194-
195-
196-
def get_overlap_end_indices(img_main: sitk.Image, img_secondary: sitk.Image):
197-
# convert end index from secondary image to primary image
198-
point_secondary = img_secondary.TransformIndexToPhysicalPoint(img_secondary.GetSize())
199-
index_main = img_main.TransformPhysicalPointToContinuousIndex(point_secondary)
200-
201-
# clip index
202-
index_main = [min(sz, i) for (i, sz) in zip(index_main, img_main.GetSize())]
203-
204-
# convert primary index back to secondary image
205-
point_main = img_main.TransformContinuousIndexToPhysicalPoint(index_main)
206-
index_secondary = img_secondary.TransformPhysicalPointToContinuousIndex(point_main)
207-
208-
# round secondary index down (round to 5 decimals for e.g. 18.999999999999996)
209-
index_secondary = np.floor(np.round(index_secondary, decimals=5))
210-
211-
# convert secondary index once again to primary image
212-
point_secondary = img_secondary.TransformContinuousIndexToPhysicalPoint(index_secondary)
213-
index_main = img_main.TransformPhysicalPointToIndex(point_secondary)
214-
215-
# convert and return result
216-
return np.array(index_secondary).astype(int), np.array(index_main).astype(int)
217-
218-
219-
def crop_to_common_physical_space(
220-
img_main: sitk.Image,
221-
img_sec: sitk.Image
222-
) -> Tuple[sitk.Image, sitk.Image]:
223-
"""
224-
Crop SimpleITK images to the largest shared physical volume
225-
"""
226-
# determine crop indices
227-
idx_start_sec, idx_start_main = get_overlap_start_indices(img_main, img_sec)
228-
idx_end_sec, idx_end_main = get_overlap_end_indices(img_main, img_sec)
229-
230-
# check extracted indices
231-
assert ((idx_end_sec - idx_start_sec) > np.array(img_sec.GetSize()) / 2).all(), \
232-
"Found unrealistically little overlap when aligning scans, aborting."
233-
assert ((idx_end_main - idx_start_main) > np.array(img_main.GetSize()) / 2).all(), \
234-
"Found unrealistically little overlap when aligning scans, aborting."
235-
236-
# apply crop
237-
slices = [slice(idx_start, idx_end) for (idx_start, idx_end) in zip(idx_start_main, idx_end_main)]
238-
img_main = img_main[slices]
239-
240-
slices = [slice(idx_start, idx_end) for (idx_start, idx_end) in zip(idx_start_sec, idx_end_sec)]
241-
img_sec = img_sec[slices]
242-
243-
return img_main, img_sec
244-
245-
246-
def get_physical_centre(image: sitk.Image):
247-
return image.TransformContinuousIndexToPhysicalPoint(np.array(image.GetSize()) / 2.0)
248-
249-
250162
@dataclass
251163
class Sample:
252164
scans: List[sitk.Image]
@@ -260,53 +172,42 @@ class Sample:
260172
num_gt_lesions: Optional[int] = None
261173

262174
def __post_init__(self):
263-
# determine main centre
264-
self.main_centre = get_physical_centre(self.scans[0])
265-
266175
if self.lbl is not None:
267176
# keep track of connected components
268177
lbl = sitk.GetArrayFromImage(self.lbl)
269178
_, num_gt_lesions = ndimage.label(lbl, structure=np.ones((3, 3, 3)))
270179
self.num_gt_lesions = num_gt_lesions
271180

272-
def crop_to_common_physical_space(self):
273-
"""
274-
Align physical centre of the first scan (e.g., T2W) with subsequent scans (e.g., ADC, high b-value)
275-
"""
276-
main_centre = get_physical_centre(self.scans[0])
277-
278-
should_align_scans = False
279-
for scan in self.scans[1:]:
280-
secondary_centre = get_physical_centre(scan)
281-
282-
# calculate distance from center of first scan (e.g., T2W) to center of secondary scan (e.g., ADC, high b-value)
283-
distance = np.sqrt(np.sum((np.array(main_centre) - np.array(secondary_centre))**2))
284-
285-
# if difference in center coordinates is more than 2mm, align the scans
286-
if distance > 2:
287-
print(f"Aligning scans with distance of {distance:.1f} mm between centers for {self.name}.")
288-
should_align_scans = True
289-
290-
if should_align_scans:
291-
for i, main_scan in enumerate(self.scans):
292-
for j, secondary_scan in enumerate(self.scans):
293-
if i == j:
294-
continue
295-
296-
# align scans
297-
img_main, img_sec = crop_to_common_physical_space(main_scan, secondary_scan)
298-
self.scans[i] = img_main
299-
self.scans[j] = img_sec
300-
301-
def resample(self):
302-
"""Resample scans and label"""
181+
def resample_to_first_scan(self):
182+
"""Resample scans and label to the first scan"""
183+
# set up resampler to resolution, field of view, etc. of first scan
184+
resampler = sitk.ResampleImageFilter() # default linear
185+
resampler.SetReferenceImage(self.scans[0])
186+
resampler.SetInterpolator(sitk.sitkBSpline)
187+
188+
# resample other images
189+
self.scans[1:] = [resampler.Execute(scan) for scan in self.scans[1:]]
190+
191+
# resample annotation
192+
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
193+
if self.lbl is not None:
194+
self.lbl = resampler.Execute(self.lbl)
195+
196+
def resample_spacing(self, spacing: Optional[Iterable[float]] = None):
197+
"""Resample scans and label to the target spacing"""
198+
if spacing is None:
199+
assert self.settings.spacing is not None
200+
spacing = self.settings.spacing
201+
202+
# resample scans to target resolution
303203
self.scans = [
304-
resample_img(scan, out_spacing=self.settings._spacing, is_label=False)
204+
resample_img(scan, out_spacing=spacing, is_label=False)
305205
for scan in self.scans
306206
]
307207

208+
# resample annotation to target resolution
308209
if self.lbl is not None:
309-
self.lbl = resample_img(self.lbl, out_spacing=self.settings._spacing, is_label=True)
210+
self.lbl = resample_img(self.lbl, out_spacing=spacing, is_label=True)
310211

311212
def centre_crop(self):
312213
"""Centre crop scans and label"""
@@ -318,7 +219,7 @@ def centre_crop(self):
318219
if self.lbl is not None:
319220
self.lbl = crop_or_pad(self.lbl, size=self.settings.matrix_size)
320221

321-
def copy_physical_metadata(self):
222+
def align_physical_metadata(self, check_almost_equal=True):
322223
"""Align the origin and direction of each scan, and label"""
323224
case_origin, case_direction, case_spacing = None, None, None
324225
for img in self.scans:
@@ -328,6 +229,13 @@ def copy_physical_metadata(self):
328229
case_direction = img.GetDirection()
329230
case_spacing = img.GetSpacing()
330231
else:
232+
if check_almost_equal:
233+
# check if current scan's metadata is almost equal to the first scan
234+
assert_allclose(img.GetOrigin(), case_origin)
235+
assert_allclose(img.GetDirection(), case_direction)
236+
assert_allclose(img.GetSpacing(), case_spacing)
237+
238+
# copy over first scan's metadata to current scan
331239
img.SetOrigin(case_origin)
332240
img.SetDirection(case_direction)
333241
img.SetSpacing(case_spacing)
@@ -348,18 +256,19 @@ def preprocess(self):
348256
# apply scan transformation
349257
self.scans = [self.scan_preprocess_func(scan) for scan in self.scans]
350258

351-
if self.settings.align_physical_space:
352-
# align sequences based on metadata
353-
self.crop_to_common_physical_space()
259+
if self.settings.spacing is not None:
260+
# resample scans and label to specified spacing
261+
self.resample_spacing()
354262

355-
# resample scans and label
356-
self.resample()
263+
if self.settings.matrix_size is not None:
264+
# perform centre crop
265+
self.centre_crop()
357266

358-
# perform centre crop
359-
self.centre_crop()
267+
# resample scans and label to first scan's spacing, field-of-view, etc.
268+
self.resample_to_first_scan()
360269

361270
# copy physical metadata to align subvoxel differences between sequences
362-
self.copy_physical_metadata()
271+
self.align_physical_metadata()
363272

364273
if self.lbl is not None:
365274
# check connected components of annotation

src/picai_prep/resources/mha2nnunet_schema.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,16 @@
3838
"type": "object",
3939
"description": "Preprocessing parameters",
4040
"properties": {
41-
"align_physical_space": {
42-
"description": "...",
43-
"type": "boolean"
44-
},
45-
"crop_to_first_physical_centre": {
46-
"description": "...",
47-
"type": "boolean"
48-
},
4941
"physical_size": {
50-
"description": "...",
42+
"description": "Target field-of-view in mm (z, y, x). Automatically calculated if `matrix_size` and `spacing` are set.",
5143
"$ref": "#/$defs/3d"
5244
},
5345
"matrix_size": {
54-
"description": "Defaults to [20, 160, 160] if neither this or 'physical_size' is set.",
46+
"description": "Target matrix size. Automatically calculated if `physical_size` and `spacing` are set.",
5547
"$ref": "#/$defs/3d"
5648
},
5749
"spacing": {
58-
"description": "...",
50+
"description": "Target resolution in mm/voxel (z, y, x). Automatically calculated if `physical_size` and `matrix_size` are set.",
5951
"$ref": "#/$defs/3d"
6052
}
6153
},
-5.5 KB
Binary file not shown.
-212 Bytes
Binary file not shown.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:f948ae1b3b346bc6889b21fca54970c2c3fd89498cd16e74f8208db7c692469d
3+
size 6364

0 commit comments

Comments
 (0)