forked from DIAGNijmegen/oncology-ULS-fast-for-challenge
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathprocess.py
More file actions
183 lines (151 loc) · 7.79 KB
/
process.py
File metadata and controls
183 lines (151 loc) · 7.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import time
import json
import torch
from scipy import ndimage
import SimpleITK as sitk
import numpy as np
from pathlib import Path
from evalutils import SegmentationAlgorithm
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.export_prediction import convert_predicted_logits_to_segmentation_with_correct_shape
from nnunetv2.utilities.helpers import empty_cache
class Uls23(SegmentationAlgorithm):
    """Lesion-segmentation inference pipeline around a trained nnU-Net model.

    Stages, run in order by ``start_pipeline``:
      1. ``load_model``  - build the nnUNetPredictor and load its checkpoint.
      2. ``load_data``   - split the stacked input volume into VOIs, center-crop
                           each VOI to the model input size and cache it in /tmp.
      3. ``predict``     - run nnU-Net on every cached VOI.
      4. ``postprocess`` - keep the center lesion component, restore each mask to
                           the full VOI size and write the re-stacked mask.
    """

    def __init__(self):
        self.image_metadata = None  # SimpleITK image of the stacked input; its information is copied to the output
        self.id = None  # Path of the stacked input file; its name is reused for the output file
        self.z_size = 128  # Number of voxels in the z-dimension for each input VOI
        self.xy_size = 256  # Number of voxels in the xy-dimensions for each input VOI
        self.z_size_model = 64  # Number of voxels in the z-dimension that the model takes
        self.xy_size_model = 128  # Number of voxels in the xy-dimensions that the model takes
        self.device = torch.device("cuda")
        self.predictor = None  # nnUNetPredictor, created in load_model()

    def _model_crop(self):
        """Return the index tuple selecting the centered, model-sized region of a VOI.

        The same region is used to crop VOIs before inference (``load_data``)
        and to place predictions back into a full-size VOI (``postprocess``),
        so cropping and restoring cannot disagree. With the default sizes this
        is ``[32:96, 64:192, 64:192]``.
        """
        z_start = (self.z_size - self.z_size_model) // 2
        xy_start = (self.xy_size - self.xy_size_model) // 2
        xy_slice = slice(xy_start, xy_start + self.xy_size_model)
        return slice(z_start, z_start + self.z_size_model), xy_slice, xy_slice

    @staticmethod
    def _keep_center_component(segmentation):
        """Binarize ``segmentation``, keeping only the component at its center voxel.

        If more than one connected component is found, every voxel whose
        component label differs from the label at the array's center voxel is
        zeroed (so when the center voxel is background, all foreground is
        removed). Remaining foreground voxels are set to 1. Modifies
        ``segmentation`` in place and returns it.
        """
        instance_mask, num_features = ndimage.label(segmentation)
        if num_features > 1:
            print("Found multiple lesion predictions")
            center = tuple(dim // 2 for dim in segmentation.shape)
            segmentation[instance_mask != instance_mask[center]] = 0
        segmentation[segmentation != 0] = 1
        return segmentation

    def start_pipeline(self):
        """
        Starts inference algorithm: load model, load data, predict, post-process.
        """
        start_time = time.time()
        # We need to create the correct output folder, determined by the interface, ourselves
        os.makedirs("/output/images/ct-binary-uls/", exist_ok=True)
        self.load_model()
        spacings = self.load_data()
        predictions = self.predict(spacings)
        self.postprocess(predictions)
        end_time = time.time()
        print(f"Total job runtime: {end_time - start_time}s")

    def load_model(self):
        """Create the nnUNetPredictor and load the trained checkpoint into it."""
        start_model_load_time = time.time()
        # Set up the nnUNetPredictor
        self.predictor = nnUNetPredictor(
            tile_step_size=0.5,
            use_gaussian=True,
            use_mirroring=False,  # False is faster but less accurate
            device=self.device,
            verbose=False,
            verbose_preprocessing=False,
            allow_tqdm=False
        )
        # Initialize the network architecture, loads the checkpoint.
        # use_folds must be a tuple: ("all") would just be the string "all".
        self.predictor.initialize_from_trained_model_folder(
            "/opt/ml/model/nnUNet_results/Dataset090_ULS23_Combined/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres",
            use_folds=("all",),
            checkpoint_name="checkpoint_best.pth",
        )
        end_model_load_time = time.time()
        print(f"Model loading runtime: {end_model_load_time - start_model_load_time}s")

    def load_data(self):
        """Split the stacked input volume into VOIs and cache their model-sized crops.

        1) Reads the per-VOI spacings JSON and the stacked ``.mha`` volume(s)
           from the input directory.
        2) Unstacks each volume along axis 0 (z) into ``self.z_size``-slice VOIs.
        3) Center-crops every VOI to the model input size.
        4) Saves each crop with a leading singleton axis (added for nnU-Net) to
           ``/tmp/voi_<i>.npy`` for memory-efficient access in ``predict``.

        :return: the spacings loaded from the JSON file (one entry per VOI)
        :raises FileNotFoundError: if no ``.mha`` file is found in the input directory
        """
        start_load_time = time.time()
        # Input directory is determined by the algorithm interface on GC
        input_dir = Path("/input/images/stacked-3d-ct-lesion-volumes/")
        # Load the spacings per VOI; spacings[i] is the scan spacing of VOI i
        with open(Path("/input/stacked-3d-volumetric-spacings.json"), 'r') as json_file:
            spacings = json.load(json_file)

        input_files = sorted(input_dir.glob("*.mha"))
        if not input_files:
            # Fail early instead of letting predict() read missing or stale /tmp files
            raise FileNotFoundError(f"No .mha input volumes found in {input_dir}")

        crop = self._model_crop()
        # NOTE(review): self.id, self.image_metadata and /tmp/voi_<i>.npy hold a
        # single volume, so if several .mha files are present the last one wins.
        for input_file in input_files:
            self.id = input_file
            # Load and keep track of the image metadata (glob already yields full paths)
            self.image_metadata = sitk.ReadImage(str(input_file))
            image_data = sitk.GetArrayFromImage(self.image_metadata)
            num_vois = image_data.shape[0] // self.z_size
            if num_vois != len(spacings):
                print(f"Warning: {input_file.name} contains {num_vois} VOIs "
                      f"but {len(spacings)} spacings were provided")
            for i in range(num_vois):
                voi = image_data[self.z_size * i:self.z_size * (i + 1), :, :]
                # Keep only the centered region the model was trained on
                voi = voi[crop]
                np.save(f"/tmp/voi_{i}.npy", np.array([voi]))  # Add dummy batch dimension for nnUnet
        end_load_time = time.time()
        print(f"Data pre-processing runtime: {end_load_time - start_load_time}s")
        return spacings

    def predict(self, spacings):
        """
        Runs nnUnet inference on every cached VOI.
        :param spacings: list containing the spacing per VOI; VOI i is read from /tmp/voi_{i}.npy
        :return: list of numpy arrays containing the predicted lesion masks per VOI
        """
        start_inference_time = time.time()
        predictions = []
        for i, voi_spacing in enumerate(spacings):
            # Load the 3D array from the binary file
            voi = np.load(f"/tmp/voi_{i}.npy")
            print(f'\nPredicting image of shape: {voi.shape}, spacing: {voi_spacing}')
            # Positional args: no previous-stage segmentation, no output file,
            # return the segmentation only (no probabilities)
            predictions.append(self.predictor.predict_single_npy_array(voi, {'spacing': voi_spacing}, None, None, False))
        end_inference_time = time.time()
        print(f"Total inference runtime: {end_inference_time - start_inference_time}s")
        return predictions

    def postprocess(self, predictions):
        """Clean up the per-VOI predictions and write the stacked output mask.

        For each prediction: keep only the connected component at the center
        voxel, binarize it, and place it back at its crop position inside an
        all-zero VOI of size (z_size, xy_size, xy_size). The VOIs are then
        concatenated along z, and the image information of the input volume is
        copied onto the result (``CopyInformation``) before writing.

        :param predictions: list of numpy arrays containing the predicted lesion masks per VOI
        :raises ValueError: if ``predictions`` is empty
        """
        start_postprocessing_time = time.time()
        if not predictions:
            raise ValueError("No VOI predictions to post-process")

        crop = self._model_crop()
        full_voi_shape = (self.z_size, self.xy_size, self.xy_size)
        restored = []
        # Run postprocessing code here, for the baseline we only remove any
        # segmentation outputs not connected to the center lesion prediction
        for i, segmentation in enumerate(predictions):
            print(f"Post-processing prediction {i}")
            segmentation = self._keep_center_component(segmentation)
            # Zero-pad back to the original VOI size: the same region that was
            # cropped in load_data receives the prediction, the rest stays 0.
            full_voi = np.zeros(full_voi_shape, dtype=segmentation.dtype)
            full_voi[crop] = segmentation
            restored.append(full_voi)

        stacked = np.concatenate(restored, axis=0)  # Stack predictions
        # Create mask image and copy over metadata of the stacked input volume
        mask = sitk.GetImageFromArray(stacked)
        mask.CopyInformation(self.image_metadata)
        sitk.WriteImage(mask, f"/output/images/ct-binary-uls/{self.id.name}")
        print("Output dir contents:", os.listdir("/output/images/ct-binary-uls/"))
        print("Output batched image shape:", stacked.shape)
        end_postprocessing_time = time.time()
        print(f"Postprocessing & saving runtime: {end_postprocessing_time - start_postprocessing_time}s")
if __name__ == "__main__":
    # Script entry point: build the pipeline object, then run every stage.
    pipeline = Uls23()
    pipeline.start_pipeline()