@@ -42,7 +42,7 @@ def start_pipeline(self):
4242
4343 def load_model (self ):
4444 start_model_load_time = time .time ()
45-
45+
4646 # Set up the nnUNetPredictor
4747 self .predictor = nnUNetPredictor (
4848 tile_step_size = 0.5 ,
@@ -55,13 +55,13 @@ def load_model(self):
5555 )
5656 # Initialize the network architecture, loads the checkpoint
5757 self .predictor .initialize_from_trained_model_folder (
58- "/opt/ml/model/Dataset601_Full_128_64/nnUNetTrainer_ULS_500_QuarterLR__nnUNetPlans_shallow__3d_fullres_resenc " ,
58+ "/opt/ml/model/Dataset090_ULS23_Combined/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres " ,
5959 use_folds = ("all" ),
6060 checkpoint_name = "checkpoint_best.pth" ,
6161 )
6262 end_model_load_time = time .time ()
6363 print (f"Model loading runtime: { end_model_load_time - start_model_load_time } s" )
64-
64+
6565 def load_data (self ):
6666 """
6767 1) Loads the .mha files containing the VOI stacks in the input directory
@@ -75,7 +75,7 @@ def load_data(self):
7575 input_dir = Path ("/input/images/stacked-3d-ct-lesion-volumes/" )
7676
7777 # Load the spacings per VOI
78- with open (Path ("/input/stacked-3d-volumetric-spacings .json" ), 'r' ) as json_file :
78+ with open (Path ("/input/stacked_spacing_sample .json" ), 'r' ) as json_file :
7979 spacings = json .load (json_file )
8080
8181 for input_file in input_dir .glob ("*.mha" ):
@@ -88,48 +88,17 @@ def load_data(self):
8888 image_data = sitk .GetArrayFromImage (self .image_metadata )
8989 for i in range (int (image_data .shape [0 ] / self .z_size )):
9090 voi = image_data [self .z_size * i :self .z_size * (i + 1 ), :, :]
91+ # Note: spacings[i] contains the scan spacing for this VOI
92+
93+ # Unstack the VOI's, perform optional preprocessing and save
94+ # them to individual binary files for memory-efficient access
95+ print (voi .shape )
96+ voi = voi [32 :96 , 64 :192 , 64 :192 ]
97+ np .save (f"/tmp/voi_{ i } .npy" , np .array ([voi ])) # Add dummy batch dimension for nnUnet
9198
92- # Convert the VOI back to a SimpleITK image
93- voi_image = sitk .GetImageFromArray (voi )
94-
95- # Calculate and set the metadata for the unstacked VOI
96- original_origin = self .image_metadata .GetOrigin ()
97- original_spacing = self .image_metadata .GetSpacing ()
98- new_origin = [
99- original_origin [0 ], # x-origin remains the same
100- original_origin [1 ], # y-origin remains the same
101- original_origin [2 ] + i * self .z_size * original_spacing [2 ], # Adjust z-origin for each VOI
102- ]
103- voi_image .SetOrigin (new_origin )
104- voi_image .SetSpacing (original_spacing )
105- voi_image .SetDirection (self .image_metadata .GetDirection ())
106-
107- # Define the cropping region in physical space
108- voi_shape = voi_image .GetSize ()
109- start_index = [64 , 64 , 32 ] # Start indices for cropping
110- crop_size = [128 , 128 , 64 ] # Size of the cropped region
111-
112- # Perform cropping using SimpleITK
113- voi_cropped = sitk .RegionOfInterest (voi_image , size = crop_size , index = start_index )
114-
115- # Update the origin of the cropped VOI
116- cropped_origin = [
117- voi_image .GetOrigin ()[0 ] + start_index [0 ] * voi_image .GetSpacing ()[0 ],
118- voi_image .GetOrigin ()[1 ] + start_index [1 ] * voi_image .GetSpacing ()[1 ],
119- voi_image .GetOrigin ()[2 ] + start_index [2 ] * voi_image .GetSpacing ()[2 ],
120- ]
121- voi_cropped .SetOrigin (cropped_origin )
122- voi_cropped .SetSpacing (voi_image .GetSpacing ())
123- voi_cropped .SetDirection (voi_image .GetDirection ())
124-
125- # Save the cropped VOI to a binary file
126- voi_cropped_array = sitk .GetArrayFromImage (voi_cropped )
127- np .save (f"/tmp/voi_{ i } .npy" , np .array ([voi_cropped_array ])) # Add dummy batch dimension for nnUnet
128-
12999 end_load_time = time .time ()
130100 print (f"Data pre-processing runtime: { end_load_time - start_load_time } s" )
131101
132-
133102 return spacings
134103
135104 def predict (self , spacings ):
@@ -143,8 +112,8 @@ def predict(self, spacings):
143112
144113 for i , voi_spacing in enumerate (spacings ):
145114 # Load the 3D array from the binary file
146- voi = torch . from_numpy ( np .load (f"/tmp/voi_{ i } .npy" ) )
147- voi = voi .to (dtype = torch .float32 )
115+ voi = np .load (f"/tmp/voi_{ i } .npy" )
116+ # voi = voi.to(dtype=np .float32)
148117
149118 print (f'\n Predicting image of shape: { voi .shape } , spacing: { voi_spacing } ' )
150119 predictions .append (self .predictor .predict_single_npy_array (voi , {'spacing' : voi_spacing }, None , None , False ))
@@ -155,11 +124,13 @@ def predict(self, spacings):
155124
156125 def postprocess (self , predictions ):
157126 """
158- Runs post-processing and saves predictions for each VOI .
127+ Runs post-processing and saves stacked predictions .
159128 :param predictions: list of numpy arrays containing the predicted lesion masks per VOI
160129 """
161130 start_postprocessing_time = time .time ()
162-
131+
132+ # Run postprocessing code here, for the baseline we only remove any
133+ # segmentation outputs not connected to the center lesion prediction
163134 for i , segmentation in enumerate (predictions ):
164135 print (f"Post-processing prediction { i } " )
165136 instance_mask , num_features = ndimage .label (segmentation )
@@ -171,28 +142,27 @@ def postprocess(self, predictions):
171142
172143 # Pad segmentations to fit with original image size
173144 segmentation_pad = np .pad (segmentation ,
174- ((32 , 32 ),
175- (64 , 64 ),
176- (64 , 64 )),
177- mode = 'constant' , constant_values = 0 )
145+ ((32 , 32 ),
146+ (64 , 64 ),
147+ (64 , 64 )),
148+ mode = 'constant' , constant_values = 0 )
178149
179- # Convert padded segmentation and original segmentation back to a SimpleITK image
150+ # Convert padded segmentation back to a SimpleITK image
180151 segmentation_image = sitk .GetImageFromArray (segmentation_pad )
181- segmentation_original = sitk .GetImageFromArray (segmentation )
182-
152+
183153 # Update the origin to account for the padding
184- voi_origin = segmentation_original .GetOrigin ()
185- voi_spacing = segmentation_original .GetSpacing ()
186- voi_direction = segmentation_original .GetDirection ()
187-
154+ original_origin = self .image_metadata .GetOrigin ()
155+ original_spacing = self .image_metadata .GetSpacing ()
188156 new_origin = [
189- voi_origin [0 ] - 32 * voi_spacing [0 ], # Adjust for z padding
190- voi_origin [1 ] - 64 * voi_spacing [1 ], # Adjust for x padding
191- voi_origin [2 ] - 64 * voi_spacing [2 ], # Adjust for y padding
157+ original_origin [0 ] - 64 * original_spacing [0 ], # Adjust for x padding
158+ original_origin [1 ] - 64 * original_spacing [1 ], # Adjust for y padding
159+ original_origin [2 ] - 32 * original_spacing [2 ], # Adjust for z padding
192160 ]
193161 segmentation_image .SetOrigin (new_origin )
194- segmentation_image .SetDirection (voi_direction )
195- segmentation_image .SetSpacing (voi_spacing )
162+
163+ # Copy the direction and spacing from the original metadata
164+ segmentation_image .SetDirection (self .image_metadata .GetDirection ())
165+ segmentation_image .SetSpacing (self .image_metadata .GetSpacing ())
196166
197167 # Save the updated segmentation image
198168 predictions [i ] = sitk .GetArrayFromImage (segmentation_image )
@@ -201,8 +171,8 @@ def postprocess(self, predictions):
201171
202172 # Create mask image and copy over metadata
203173 mask = sitk .GetImageFromArray (predictions )
204- mask .CopyInformation (self .image_metadata )
205-
174+ mask .CopyInformation (self .image_metadata )
175+
206176 sitk .WriteImage (mask , f"/output/images/ct-binary-uls/{ self .id .name } " )
207177 print ("Output dir contents:" , os .listdir ("/output/images/ct-binary-uls/" ))
208178 print ("Output batched image shape:" , predictions .shape )
0 commit comments