Skip to content

Commit 7a8bee8

Browse files
committed
Update dockerfile, process and requirements to work with newer nnunetv2 version
1 parent eba4686 commit 7a8bee8

3 files changed

Lines changed: 38 additions & 67 deletions

File tree

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ RUN python3.9 -m pip install --no-cache-dir -r /tmp/requirements.txt -f https://
2828
RUN git config --global advice.detachedHead false && \
2929
git clone --no-checkout https://github.com/MIC-DKFZ/nnUNet.git /opt/algorithm/nnunet/ && \
3030
cd /opt/algorithm/nnunet/ && \
31-
git checkout 947eafbb9adb5eb06b9171330b4688e006e6f301
31+
git checkout v2.5.1
3232

3333
# Install a few dependencies that are not automatically installed
3434
RUN pip3 install \
@@ -60,7 +60,7 @@ COPY --chown=user:user export2onnx.py /opt/app/
6060
COPY --chown=user:user ./architecture/extensions/nnunetv2/ /opt/algorithm/nnunet/nnunetv2/
6161

6262
# Copy model checkpoint to docker (uncomment if you put the model weights directly in this repo)
63-
#COPY --chown=user:user ./architecture/nnUNet_results/ /opt/ml/model/
63+
#COPY --chown=user:user ./architecture/nnUNet_results/ /opt/algorithm/nnunet/nnUNet_results/
6464

6565
# Copy container testing data to docker (uncomment if you want to see if the model works and put a test image and spacing in this repo)
6666
#COPY --chown=user:user /architecture/input/ /input/

process.py

Lines changed: 34 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def start_pipeline(self):
4242

4343
def load_model(self):
4444
start_model_load_time = time.time()
45-
45+
4646
# Set up the nnUNetPredictor
4747
self.predictor = nnUNetPredictor(
4848
tile_step_size=0.5,
@@ -55,13 +55,13 @@ def load_model(self):
5555
)
5656
# Initialize the network architecture, loads the checkpoint
5757
self.predictor.initialize_from_trained_model_folder(
58-
"/opt/ml/model/Dataset601_Full_128_64/nnUNetTrainer_ULS_500_QuarterLR__nnUNetPlans_shallow__3d_fullres_resenc",
58+
"/opt/ml/model/Dataset090_ULS23_Combined/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres",
5959
use_folds=("all"),
6060
checkpoint_name="checkpoint_best.pth",
6161
)
6262
end_model_load_time = time.time()
6363
print(f"Model loading runtime: {end_model_load_time - start_model_load_time}s")
64-
64+
6565
def load_data(self):
6666
"""
6767
1) Loads the .mha files containing the VOI stacks in the input directory
@@ -75,7 +75,7 @@ def load_data(self):
7575
input_dir = Path("/input/images/stacked-3d-ct-lesion-volumes/")
7676

7777
# Load the spacings per VOI
78-
with open(Path("/input/stacked-3d-volumetric-spacings.json"), 'r') as json_file:
78+
with open(Path("/input/stacked_spacing_sample.json"), 'r') as json_file:
7979
spacings = json.load(json_file)
8080

8181
for input_file in input_dir.glob("*.mha"):
@@ -88,48 +88,17 @@ def load_data(self):
8888
image_data = sitk.GetArrayFromImage(self.image_metadata)
8989
for i in range(int(image_data.shape[0] / self.z_size)):
9090
voi = image_data[self.z_size * i:self.z_size * (i + 1), :, :]
91+
# Note: spacings[i] contains the scan spacing for this VOI
92+
93+
# Unstack the VOI's, perform optional preprocessing and save
94+
# them to individual binary files for memory-efficient access
95+
print(voi.shape)
96+
voi = voi[32:96, 64:192, 64:192]
97+
np.save(f"/tmp/voi_{i}.npy", np.array([voi])) # Add dummy batch dimension for nnUnet
9198

92-
# Convert the VOI back to a SimpleITK image
93-
voi_image = sitk.GetImageFromArray(voi)
94-
95-
# Calculate and set the metadata for the unstacked VOI
96-
original_origin = self.image_metadata.GetOrigin()
97-
original_spacing = self.image_metadata.GetSpacing()
98-
new_origin = [
99-
original_origin[0], # x-origin remains the same
100-
original_origin[1], # y-origin remains the same
101-
original_origin[2] + i * self.z_size * original_spacing[2], # Adjust z-origin for each VOI
102-
]
103-
voi_image.SetOrigin(new_origin)
104-
voi_image.SetSpacing(original_spacing)
105-
voi_image.SetDirection(self.image_metadata.GetDirection())
106-
107-
# Define the cropping region in physical space
108-
voi_shape = voi_image.GetSize()
109-
start_index = [64, 64, 32] # Start indices for cropping
110-
crop_size = [128, 128, 64] # Size of the cropped region
111-
112-
# Perform cropping using SimpleITK
113-
voi_cropped = sitk.RegionOfInterest(voi_image, size=crop_size, index=start_index)
114-
115-
# Update the origin of the cropped VOI
116-
cropped_origin = [
117-
voi_image.GetOrigin()[0] + start_index[0] * voi_image.GetSpacing()[0],
118-
voi_image.GetOrigin()[1] + start_index[1] * voi_image.GetSpacing()[1],
119-
voi_image.GetOrigin()[2] + start_index[2] * voi_image.GetSpacing()[2],
120-
]
121-
voi_cropped.SetOrigin(cropped_origin)
122-
voi_cropped.SetSpacing(voi_image.GetSpacing())
123-
voi_cropped.SetDirection(voi_image.GetDirection())
124-
125-
# Save the cropped VOI to a binary file
126-
voi_cropped_array = sitk.GetArrayFromImage(voi_cropped)
127-
np.save(f"/tmp/voi_{i}.npy", np.array([voi_cropped_array])) # Add dummy batch dimension for nnUnet
128-
12999
end_load_time = time.time()
130100
print(f"Data pre-processing runtime: {end_load_time - start_load_time}s")
131101

132-
133102
return spacings
134103

135104
def predict(self, spacings):
@@ -143,8 +112,8 @@ def predict(self, spacings):
143112

144113
for i, voi_spacing in enumerate(spacings):
145114
# Load the 3D array from the binary file
146-
voi = torch.from_numpy(np.load(f"/tmp/voi_{i}.npy"))
147-
voi = voi.to(dtype=torch.float32)
115+
voi = np.load(f"/tmp/voi_{i}.npy")
116+
#voi = voi.to(dtype=np.float32)
148117

149118
print(f'\nPredicting image of shape: {voi.shape}, spacing: {voi_spacing}')
150119
predictions.append(self.predictor.predict_single_npy_array(voi, {'spacing': voi_spacing}, None, None, False))
@@ -155,11 +124,13 @@ def predict(self, spacings):
155124

156125
def postprocess(self, predictions):
157126
"""
158-
Runs post-processing and saves predictions for each VOI.
127+
Runs post-processing and saves stacked predictions.
159128
:param predictions: list of numpy arrays containing the predicted lesion masks per VOI
160129
"""
161130
start_postprocessing_time = time.time()
162-
131+
132+
# Run postprocessing code here, for the baseline we only remove any
133+
# segmentation outputs not connected to the center lesion prediction
163134
for i, segmentation in enumerate(predictions):
164135
print(f"Post-processing prediction {i}")
165136
instance_mask, num_features = ndimage.label(segmentation)
@@ -171,28 +142,27 @@ def postprocess(self, predictions):
171142

172143
# Pad segmentations to fit with original image size
173144
segmentation_pad = np.pad(segmentation,
174-
((32, 32),
175-
(64, 64),
176-
(64, 64)),
177-
mode='constant', constant_values=0)
145+
((32, 32),
146+
(64, 64),
147+
(64, 64)),
148+
mode='constant', constant_values=0)
178149

179-
# Convert padded segmentation and original segmentation back to a SimpleITK image
150+
# Convert padded segmentation back to a SimpleITK image
180151
segmentation_image = sitk.GetImageFromArray(segmentation_pad)
181-
segmentation_original = sitk.GetImageFromArray(segmentation)
182-
152+
183153
# Update the origin to account for the padding
184-
voi_origin = segmentation_original.GetOrigin()
185-
voi_spacing = segmentation_original.GetSpacing()
186-
voi_direction = segmentation_original.GetDirection()
187-
154+
original_origin = self.image_metadata.GetOrigin()
155+
original_spacing = self.image_metadata.GetSpacing()
188156
new_origin = [
189-
voi_origin[0] - 32 * voi_spacing[0], # Adjust for z padding
190-
voi_origin[1] - 64 * voi_spacing[1], # Adjust for x padding
191-
voi_origin[2] - 64 * voi_spacing[2], # Adjust for y padding
157+
original_origin[0] - 64 * original_spacing[0], # Adjust for x padding
158+
original_origin[1] - 64 * original_spacing[1], # Adjust for y padding
159+
original_origin[2] - 32 * original_spacing[2], # Adjust for z padding
192160
]
193161
segmentation_image.SetOrigin(new_origin)
194-
segmentation_image.SetDirection(voi_direction)
195-
segmentation_image.SetSpacing(voi_spacing)
162+
163+
# Copy the direction and spacing from the original metadata
164+
segmentation_image.SetDirection(self.image_metadata.GetDirection())
165+
segmentation_image.SetSpacing(self.image_metadata.GetSpacing())
196166

197167
# Save the updated segmentation image
198168
predictions[i] = sitk.GetArrayFromImage(segmentation_image)
@@ -201,8 +171,8 @@ def postprocess(self, predictions):
201171

202172
# Create mask image and copy over metadata
203173
mask = sitk.GetImageFromArray(predictions)
204-
mask.CopyInformation(self.image_metadata)
205-
174+
mask.CopyInformation(self.image_metadata)
175+
206176
sitk.WriteImage(mask, f"/output/images/ct-binary-uls/{self.id.name}")
207177
print("Output dir contents:", os.listdir("/output/images/ct-binary-uls/"))
208178
print("Output batched image shape:", predictions.shape)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ numpy==1.24.4
55
scikit-image==0.19.3
66
scipy==1.10.1
77
click==8.1.5
8-
batchgenerators==0.23
8+
batchgenerators>=0.25
99
blosc2==2.5.1
1010
acvl_utils==0.2
1111
torch==2.5.1
12+
nnunetv2==2.5.1

0 commit comments

Comments
 (0)