Skip to content

Commit a722107

Browse files
committed
Croppad
1 parent 3de4c00 commit a722107

4 files changed

Lines changed: 433 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "yucca"
3-
version = "2.2.7"
3+
version = "2.2.8"
44
authors = [
55
{ name="Sebastian Llambias", email="llambias@live.com" },
66
{ name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .croppad import torch_croppad
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
import torch
2+
import numpy as np
3+
import torch.nn.functional as F
4+
5+
6+
def torch_croppad(
7+
image: torch.tensor,
8+
patch_size,
9+
input_dims: torch.tensor,
10+
target_image_shape: list | tuple,
11+
target_label_shape: list | tuple,
12+
p_oversample_foreground=0.0,
13+
foreground_locations=None,
14+
label: torch.tensor = None,
15+
**pad_kwargs,
16+
):
17+
"""
18+
Crops and pads the input image and label to the specified patch size.
19+
Input includes channel/modality dimension so 3D is actually 4D and 2D is 3D.
20+
"""
21+
22+
if len(patch_size) == 3:
23+
image, label = croppad_3D_case_from_3D(
24+
image=image,
25+
foreground_locations=foreground_locations,
26+
label=label,
27+
patch_size=patch_size,
28+
p_oversample_foreground=p_oversample_foreground,
29+
target_image_shape=target_image_shape,
30+
target_label_shape=target_label_shape,
31+
**pad_kwargs,
32+
)
33+
elif len(patch_size) == 2 and input_dims == 3:
34+
image, label = croppad_2D_case_from_3D(
35+
image=image,
36+
foreground_locations=foreground_locations,
37+
label=label,
38+
patch_size=patch_size,
39+
p_oversample_foreground=p_oversample_foreground,
40+
target_image_shape=target_image_shape,
41+
target_label_shape=target_label_shape,
42+
**pad_kwargs,
43+
)
44+
elif len(patch_size) == 2 and input_dims == 2:
45+
image, label = croppad_2D_case_from_2D(
46+
image=image,
47+
foreground_locations=foreground_locations,
48+
label=label,
49+
patch_size=patch_size,
50+
p_oversample_foreground=p_oversample_foreground,
51+
target_image_shape=target_image_shape,
52+
target_label_shape=target_label_shape,
53+
**pad_kwargs,
54+
)
55+
56+
return image, label
57+
58+
59+
def croppad_3D_case_from_3D(
60+
image,
61+
foreground_locations,
62+
label,
63+
patch_size,
64+
p_oversample_foreground,
65+
target_image_shape,
66+
target_label_shape,
67+
**pad_kwargs,
68+
):
69+
image_out = torch.zeros(target_image_shape)
70+
label_out = torch.zeros(target_label_shape)
71+
72+
# First we pad to ensure min size is met
73+
to_pad = []
74+
for d in range(3):
75+
if image.shape[d + 1] < patch_size[d]:
76+
to_pad += [patch_size[d] - image.shape[d + 1]]
77+
else:
78+
to_pad += [0]
79+
80+
pad_lb_x = to_pad[0] // 2
81+
pad_ub_x = to_pad[0] // 2 + to_pad[0] % 2
82+
pad_lb_y = to_pad[1] // 2
83+
pad_ub_y = to_pad[1] // 2 + to_pad[1] % 2
84+
pad_lb_z = to_pad[2] // 2
85+
pad_ub_z = to_pad[2] // 2 + to_pad[2] % 2
86+
87+
# This is where we should implement any patch selection biases.
88+
# The final patch excted after augmentation will always be the center of this patch
89+
# to avoid interpolation artifacts near the borders
90+
crop_start_idx = []
91+
if len(foreground_locations) == 0 or np.random.uniform() >= p_oversample_foreground:
92+
for d in range(3):
93+
if image.shape[d + 1] < patch_size[d]:
94+
crop_start_idx += [0]
95+
else:
96+
crop_start_idx += [np.random.randint(image.shape[d + 1] - patch_size[d] + 1)]
97+
else:
98+
location = select_foreground_voxel_to_include(foreground_locations)
99+
for d in range(3):
100+
if image.shape[d + 1] < patch_size[d]:
101+
crop_start_idx += [0]
102+
else:
103+
crop_start_idx += [
104+
np.random.randint(
105+
max(0, location[d] - patch_size[d]),
106+
min(location[d], image.shape[d + 1] - patch_size[d]) + 1,
107+
)
108+
]
109+
110+
image_out[
111+
:,
112+
:,
113+
:,
114+
:,
115+
] = F.pad(
116+
image[
117+
:,
118+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
119+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
120+
crop_start_idx[2] : crop_start_idx[2] + patch_size[2],
121+
],
122+
(
123+
pad_lb_z,
124+
pad_ub_z,
125+
pad_lb_y,
126+
pad_ub_y,
127+
pad_lb_x,
128+
pad_ub_x,
129+
0,
130+
0,
131+
),
132+
**pad_kwargs,
133+
)
134+
if label is None:
135+
return image_out, None
136+
label_out[
137+
:,
138+
:,
139+
:,
140+
:,
141+
] = F.pad(
142+
label[
143+
:,
144+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
145+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
146+
crop_start_idx[2] : crop_start_idx[2] + patch_size[2],
147+
],
148+
pad_lb_z,
149+
pad_ub_z,
150+
pad_lb_y,
151+
pad_ub_y,
152+
pad_lb_x,
153+
pad_ub_x,
154+
0,
155+
0,
156+
)
157+
return image_out, label_out
158+
159+
160+
def croppad_2D_case_from_3D(
161+
image,
162+
foreground_locations,
163+
label,
164+
patch_size,
165+
p_oversample_foreground,
166+
target_image_shape,
167+
target_label_shape,
168+
**pad_kwargs,
169+
):
170+
"""
171+
The possible input for this can be 2D or 3D data.
172+
For 2D we want to pad or crop as necessary.
173+
For 3D we want to first select a slice from the first dimension, i.e. volume[idx, :, :],
174+
then pad or crop as necessary.
175+
"""
176+
image_out = F.pad(target_image_shape)
177+
label_out = F.pad(target_label_shape)
178+
179+
# First we pad to ensure min size is met
180+
to_pad = []
181+
for d in range(2):
182+
if image.shape[d + 2] < patch_size[d]:
183+
to_pad += [patch_size[d] - image.shape[d + 2]]
184+
else:
185+
to_pad += [0]
186+
187+
pad_lb_y = to_pad[0] // 2
188+
pad_ub_y = to_pad[0] // 2 + to_pad[0] % 2
189+
pad_lb_z = to_pad[1] // 2
190+
pad_ub_z = to_pad[1] // 2 + to_pad[1] % 2
191+
192+
# This is where we should implement any patch selection biases.
193+
# The final patch extracted after augmentation will always be the center of this patch
194+
# as this is where augmentation-induced interpolation artefacts are least likely
195+
crop_start_idx = []
196+
if len(foreground_locations) == 0 or np.random.uniform() >= p_oversample_foreground:
197+
x_idx = np.random.randint(image.shape[1])
198+
for d in range(2):
199+
if image.shape[d + 2] < patch_size[d]:
200+
crop_start_idx += [0]
201+
else:
202+
crop_start_idx += [np.random.randint(image.shape[d + 2] - patch_size[d] + 1)]
203+
else:
204+
location = select_foreground_voxel_to_include(foreground_locations)
205+
x_idx = location[0]
206+
for d in range(2):
207+
if image.shape[d + 2] < patch_size[d]:
208+
crop_start_idx += [0]
209+
else:
210+
crop_start_idx += [
211+
np.random.randint(
212+
max(0, location[d + 1] - patch_size[d]),
213+
min(location[d + 1], image.shape[d + 2] - patch_size[d]) + 1,
214+
)
215+
]
216+
217+
image_out[:, :, :] = F.pad(
218+
image[
219+
:,
220+
x_idx,
221+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
222+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
223+
],
224+
pad_lb_z,
225+
pad_ub_z,
226+
pad_lb_y,
227+
pad_ub_y,
228+
0,
229+
0,
230+
**pad_kwargs,
231+
)
232+
233+
if label is None:
234+
return image_out, None
235+
236+
label_out[:, :, :] = F.pad(
237+
label[
238+
:,
239+
x_idx,
240+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
241+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
242+
],
243+
((0, 0), (pad_lb_y, pad_ub_y), (pad_lb_z, pad_ub_z)),
244+
)
245+
246+
return image_out, label_out
247+
248+
249+
def croppad_2D_case_from_2D(
250+
image,
251+
foreground_locations,
252+
label,
253+
patch_size,
254+
p_oversample_foreground,
255+
target_image_shape,
256+
target_label_shape,
257+
**pad_kwargs,
258+
):
259+
"""
260+
The possible input for this can be 2D or 3D data.
261+
For 2D we want to pad or crop as necessary.
262+
For 3D we want to first select a slice from the first dimension, i.e. volume[idx, :, :],
263+
then pad or crop as necessary.
264+
"""
265+
image_out = F.pad(target_image_shape)
266+
label_out = F.pad(target_label_shape)
267+
268+
# First we pad to ensure min size is met
269+
to_pad = []
270+
for d in range(2):
271+
if image.shape[d + 1] < patch_size[d]:
272+
to_pad += [patch_size[d] - image.shape[d + 1]]
273+
else:
274+
to_pad += [0]
275+
276+
pad_lb_x = to_pad[0] // 2
277+
pad_ub_x = to_pad[0] // 2 + to_pad[0] % 2
278+
pad_lb_y = to_pad[1] // 2
279+
pad_ub_y = to_pad[1] // 2 + to_pad[1] % 2
280+
281+
# This is where we should implement any patch selection biases.
282+
# The final patch extracted after augmentation will always be the center of this patch
283+
# as this is where artefacts are least present
284+
crop_start_idx = []
285+
if len(foreground_locations) == 0 or np.random.uniform() >= p_oversample_foreground:
286+
for d in range(2):
287+
if image.shape[d + 1] < patch_size[d]:
288+
crop_start_idx += [0]
289+
else:
290+
crop_start_idx += [np.random.randint(image.shape[d + 1] - patch_size[d] + 1)]
291+
else:
292+
location = select_foreground_voxel_to_include(foreground_locations)
293+
for d in range(2):
294+
if image.shape[d + 1] < patch_size[d]:
295+
crop_start_idx += [0]
296+
else:
297+
crop_start_idx += [
298+
np.random.randint(
299+
max(0, location[d] - patch_size[d]),
300+
min(location[d], image.shape[d + 1] - patch_size[d]) + 1,
301+
)
302+
]
303+
304+
image_out[:, :, :] = F.pad(
305+
image[
306+
:,
307+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
308+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
309+
],
310+
((0, 0), (pad_lb_x, pad_ub_x), (pad_lb_y, pad_ub_y)),
311+
**pad_kwargs,
312+
)
313+
314+
if label is None: # Reconstruction/inpainting
315+
return image_out, None
316+
317+
if len(label.shape) == 1: # Classification
318+
return image_out, label
319+
320+
label_out[:, :, :] = F.pad(
321+
label[
322+
:,
323+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
324+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
325+
],
326+
pad_lb_y,
327+
pad_ub_y,
328+
pad_lb_x,
329+
pad_ub_x,
330+
0,
331+
0,
332+
)
333+
334+
return image_out, label_out
335+
336+
337+
def select_foreground_voxel_to_include(foreground_locations):
338+
if isinstance(foreground_locations, list):
339+
locidx = np.random.choice(len(foreground_locations))
340+
location = foreground_locations[locidx]
341+
elif isinstance(foreground_locations, dict):
342+
selected_class = np.random.choice(list(foreground_locations.keys()))
343+
locidx = np.random.choice(len(foreground_locations[selected_class]))
344+
location = foreground_locations[selected_class][locidx]
345+
return location

0 commit comments

Comments
 (0)