1515
1616import SimpleITK as sitk
1717import numpy as np
18+ from numpy .testing import assert_allclose
1819from dataclasses import dataclass
1920from scipy import ndimage
2021
21- from typing import List , Tuple , Callable , Optional , Union , Any , Iterable , cast
22+ from typing import List , Callable , Optional , Union , Any , Iterable
2223try :
2324 import numpy .typing as npt
2425except ImportError : # pragma: no cover
@@ -32,21 +33,15 @@ class PreprocessingSettings():
3233 - matrix_size: number of voxels output volume (z, y, x)
3334 - spacing: output voxel spacing in mm (z, y, x)
3435 - physical_size: size in mm/voxel of the target volume (z, y, x)
35- - align_physical_space: whether to align sequences to eachother, based on metadata
36- - crop_to_first_physical_centre: whether to crop to physical centre of first sequence,
37- or to the new centre after aligning sequences
3836 - align_segmentation: whether to align the scans using the centroid of the provided segmentation
3937 """
40- matrix_size : Iterable [int ] = ( 20 , 160 , 160 )
38+ matrix_size : Optional [ Iterable [int ]] = None
4139 spacing : Optional [Iterable [float ]] = None
4240 physical_size : Optional [Iterable [float ]] = None
43- align_physical_space : bool = False
44- crop_to_first_physical_centre : bool = False
4541 align_segmentation : Optional [sitk .Image ] = None
4642
4743 def __post_init__ (self ):
48- if self .physical_size is None :
49- assert self .spacing , "Need either physical_size or spacing"
44+ if self .physical_size is None and self .spacing is not None and self .matrix_size is not None :
5045 # calculate physical size
5146 self .physical_size = [
5247 voxel_spacing * num_voxels
@@ -56,8 +51,7 @@ def __post_init__(self):
5651 )
5752 ]
5853
59- if self .spacing is None :
60- assert self .physical_size , "Need either physical_size or spacing"
54+ if self .spacing is None and self .physical_size is not None and self .matrix_size is not None :
6155 # calculate spacing
6256 self .spacing = [
6357 size / num_voxels
@@ -67,13 +61,8 @@ def __post_init__(self):
6761 )
6862 ]
6963
70- @property
71- def _spacing (self ) -> Iterable [float ]:
72- return cast (Iterable [float ], self .spacing )
73-
74- @property
75- def _physical_size (self ) -> Iterable [float ]:
76- return cast (Iterable [float ], self .physical_size )
64+ if self .align_segmentation is not None :
65+ raise NotImplementedError ("Alignment of scans based on segmentation is not implemented yet." )
7766
7867
7968def resample_img (
@@ -170,83 +159,6 @@ def crop_or_pad(
170159 return np .pad (image [tuple (slicer )], padding )
171160
172161
173- def get_overlap_start_indices (img_main : sitk .Image , img_secondary : sitk .Image ):
174- # convert start index from main image to secondary image
175- point_secondary = img_secondary .TransformIndexToPhysicalPoint ((0 , 0 , 0 ))
176- index_main = img_main .TransformPhysicalPointToContinuousIndex (point_secondary )
177-
178- # clip index
179- index_main = np .clip (index_main , a_min = 0 , a_max = None )
180-
181- # convert main index back to secondary image
182- point_main = img_main .TransformContinuousIndexToPhysicalPoint (index_main )
183- index_secondary = img_secondary .TransformPhysicalPointToContinuousIndex (point_main )
184-
185- # round secondary index up (round to 5 decimals for e.g. 18.999999999999996)
186- index_secondary = np .ceil (np .round (index_secondary , decimals = 5 ))
187-
188- # convert secondary index once again to main image
189- point_secondary = img_secondary .TransformContinuousIndexToPhysicalPoint (index_secondary )
190- index_main = img_main .TransformPhysicalPointToIndex (point_secondary )
191-
192- # convert and return result
193- return np .array (index_secondary ).astype (int ), np .array (index_main ).astype (int )
194-
195-
196- def get_overlap_end_indices (img_main : sitk .Image , img_secondary : sitk .Image ):
197- # convert end index from secondary image to primary image
198- point_secondary = img_secondary .TransformIndexToPhysicalPoint (img_secondary .GetSize ())
199- index_main = img_main .TransformPhysicalPointToContinuousIndex (point_secondary )
200-
201- # clip index
202- index_main = [min (sz , i ) for (i , sz ) in zip (index_main , img_main .GetSize ())]
203-
204- # convert primary index back to secondary image
205- point_main = img_main .TransformContinuousIndexToPhysicalPoint (index_main )
206- index_secondary = img_secondary .TransformPhysicalPointToContinuousIndex (point_main )
207-
208- # round secondary index down (round to 5 decimals for e.g. 18.999999999999996)
209- index_secondary = np .floor (np .round (index_secondary , decimals = 5 ))
210-
211- # convert secondary index once again to primary image
212- point_secondary = img_secondary .TransformContinuousIndexToPhysicalPoint (index_secondary )
213- index_main = img_main .TransformPhysicalPointToIndex (point_secondary )
214-
215- # convert and return result
216- return np .array (index_secondary ).astype (int ), np .array (index_main ).astype (int )
217-
218-
219- def crop_to_common_physical_space (
220- img_main : sitk .Image ,
221- img_sec : sitk .Image
222- ) -> Tuple [sitk .Image , sitk .Image ]:
223- """
224- Crop SimpleITK images to the largest shared physical volume
225- """
226- # determine crop indices
227- idx_start_sec , idx_start_main = get_overlap_start_indices (img_main , img_sec )
228- idx_end_sec , idx_end_main = get_overlap_end_indices (img_main , img_sec )
229-
230- # check extracted indices
231- assert ((idx_end_sec - idx_start_sec ) > np .array (img_sec .GetSize ()) / 2 ).all (), \
232- "Found unrealistically little overlap when aligning scans, aborting."
233- assert ((idx_end_main - idx_start_main ) > np .array (img_main .GetSize ()) / 2 ).all (), \
234- "Found unrealistically little overlap when aligning scans, aborting."
235-
236- # apply crop
237- slices = [slice (idx_start , idx_end ) for (idx_start , idx_end ) in zip (idx_start_main , idx_end_main )]
238- img_main = img_main [slices ]
239-
240- slices = [slice (idx_start , idx_end ) for (idx_start , idx_end ) in zip (idx_start_sec , idx_end_sec )]
241- img_sec = img_sec [slices ]
242-
243- return img_main , img_sec
244-
245-
246- def get_physical_centre (image : sitk .Image ):
247- return image .TransformContinuousIndexToPhysicalPoint (np .array (image .GetSize ()) / 2.0 )
248-
249-
250162@dataclass
251163class Sample :
252164 scans : List [sitk .Image ]
@@ -260,53 +172,42 @@ class Sample:
260172 num_gt_lesions : Optional [int ] = None
261173
262174 def __post_init__ (self ):
263- # determine main centre
264- self .main_centre = get_physical_centre (self .scans [0 ])
265-
266175 if self .lbl is not None :
267176 # keep track of connected components
268177 lbl = sitk .GetArrayFromImage (self .lbl )
269178 _ , num_gt_lesions = ndimage .label (lbl , structure = np .ones ((3 , 3 , 3 )))
270179 self .num_gt_lesions = num_gt_lesions
271180
272- def crop_to_common_physical_space (self ):
273- """
274- Align physical centre of the first scan (e.g., T2W) with subsequent scans (e.g., ADC, high b-value)
275- """
276- main_centre = get_physical_centre (self .scans [0 ])
277-
278- should_align_scans = False
279- for scan in self .scans [1 :]:
280- secondary_centre = get_physical_centre (scan )
281-
282- # calculate distance from center of first scan (e.g., T2W) to center of secondary scan (e.g., ADC, high b-value)
283- distance = np .sqrt (np .sum ((np .array (main_centre ) - np .array (secondary_centre ))** 2 ))
284-
285- # if difference in center coordinates is more than 2mm, align the scans
286- if distance > 2 :
287- print (f"Aligning scans with distance of { distance :.1f} mm between centers for { self .name } ." )
288- should_align_scans = True
289-
290- if should_align_scans :
291- for i , main_scan in enumerate (self .scans ):
292- for j , secondary_scan in enumerate (self .scans ):
293- if i == j :
294- continue
295-
296- # align scans
297- img_main , img_sec = crop_to_common_physical_space (main_scan , secondary_scan )
298- self .scans [i ] = img_main
299- self .scans [j ] = img_sec
300-
301- def resample (self ):
302- """Resample scans and label"""
181+ def resample_to_first_scan (self ):
182+ """Resample scans and label to the first scan"""
183+ # set up resampler to resolution, field of view, etc. of first scan
184+ resampler = sitk .ResampleImageFilter () # default linear
185+ resampler .SetReferenceImage (self .scans [0 ])
186+ resampler .SetInterpolator (sitk .sitkBSpline )
187+
188+ # resample other images
189+ self .scans [1 :] = [resampler .Execute (scan ) for scan in self .scans [1 :]]
190+
191+ # resample annotation
192+ resampler .SetInterpolator (sitk .sitkNearestNeighbor )
193+ if self .lbl is not None :
194+ self .lbl = resampler .Execute (self .lbl )
195+
196+ def resample_spacing (self , spacing : Optional [Iterable [float ]] = None ):
197+ """Resample scans and label to the target spacing"""
198+ if spacing is None :
199+ assert self .settings .spacing is not None
200+ spacing = self .settings .spacing
201+
202+ # resample scans to target resolution
303203 self .scans = [
304- resample_img (scan , out_spacing = self . settings . _spacing , is_label = False )
204+ resample_img (scan , out_spacing = spacing , is_label = False )
305205 for scan in self .scans
306206 ]
307207
208+ # resample annotation to target resolution
308209 if self .lbl is not None :
309- self .lbl = resample_img (self .lbl , out_spacing = self . settings . _spacing , is_label = True )
210+ self .lbl = resample_img (self .lbl , out_spacing = spacing , is_label = True )
310211
311212 def centre_crop (self ):
312213 """Centre crop scans and label"""
@@ -318,7 +219,7 @@ def centre_crop(self):
318219 if self .lbl is not None :
319220 self .lbl = crop_or_pad (self .lbl , size = self .settings .matrix_size )
320221
321- def copy_physical_metadata (self ):
222+ def align_physical_metadata (self , check_almost_equal = True ):
322223 """Align the origin and direction of each scan, and label"""
323224 case_origin , case_direction , case_spacing = None , None , None
324225 for img in self .scans :
@@ -328,6 +229,13 @@ def copy_physical_metadata(self):
328229 case_direction = img .GetDirection ()
329230 case_spacing = img .GetSpacing ()
330231 else :
232+ if check_almost_equal :
233+ # check if current scan's metadata is almost equal to the first scan
234+ assert_allclose (img .GetOrigin (), case_origin )
235+ assert_allclose (img .GetDirection (), case_direction )
236+ assert_allclose (img .GetSpacing (), case_spacing )
237+
238+ # copy over first scan's metadata to current scan
331239 img .SetOrigin (case_origin )
332240 img .SetDirection (case_direction )
333241 img .SetSpacing (case_spacing )
@@ -348,18 +256,19 @@ def preprocess(self):
348256 # apply scan transformation
349257 self .scans = [self .scan_preprocess_func (scan ) for scan in self .scans ]
350258
351- if self .settings .align_physical_space :
352- # align sequences based on metadata
353- self .crop_to_common_physical_space ()
259+ if self .settings .spacing is not None :
260+ # resample scans and label to specified spacing
261+ self .resample_spacing ()
354262
355- # resample scans and label
356- self .resample ()
263+ if self .settings .matrix_size is not None :
264+ # perform centre crop
265+ self .centre_crop ()
357266
358- # perform centre crop
359- self .centre_crop ()
267+ # resample scans and label to first scan's spacing, field-of-view, etc.
268+ self .resample_to_first_scan ()
360269
361270 # copy physical metadata to align subvoxel differences between sequences
362- self .copy_physical_metadata ()
271+ self .align_physical_metadata ()
363272
364273 if self .lbl is not None :
365274 # check connected components of annotation
0 commit comments