Skip to content

Commit a074e1f

Browse files
deploy changes
1 parent e4ce4f3 commit a074e1f

15 files changed

Lines changed: 629 additions & 32 deletions

File tree

asparagus/functional/reverse_preprocessing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ def reverse_preprocessing(array, image_properties):
88
pad_bbox = image_properties["pad_box"]
99
crop_bbox = image_properties["crop_box"]
1010

11-
shape = array.shape[2:]
12-
if len(shape) == 2:
11+
ndim = len(array.shape[2:])
12+
if ndim == 2:
1313
mode = "bilinear"
14-
elif len(shape) == 3:
14+
elif ndim == 3:
1515
mode = "trilinear"
1616

1717
if len(pad_bbox) > 0:
1818
array = unpad_array(array, pad_bbox)
19-
verify_shapes_are_equal(reference_shape=shape, target_shape=image_properties["shape_before_pad"])
19+
verify_shapes_are_equal(reference_shape=array.shape[2:], target_shape=image_properties["shape_before_pad"])
2020

2121
array = F.interpolate(array, size=image_properties["size_before_resample"], mode=mode)
2222

asparagus/modules/lightning_modules/segmentation_module.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
volume_similarity,
2424
)
2525
from gardening_tools.functional.paths.write import save_json
26-
from gardening_tools.functional.transforms.cropping_and_padding import (
27-
fit_patch_size_to_image_size,
28-
)
2926
from gardening_tools.modules.losses.deep_supervision import DeepSupervisionLoss
3027
from gardening_tools.modules.losses.DiceCE import DiceCE
3128
from gardening_tools.modules.metrics import GeneralizedDiceScore
@@ -51,7 +48,6 @@ def __init__(
5148
val_transforms: Optional[transforms.Compose] = None,
5249
optimizer: str = "SGD",
5350
inference_patch_size: list = [],
54-
inference_mode: str = "3D",
5551
test_output_path: str = None,
5652
log_image_every_n_epochs: int = 50,
5753
weight_decay: float = 3e-5,
@@ -78,7 +74,6 @@ def __init__(
7874
load_decoder=load_decoder,
7975
repeat_stem_weights=repeat_stem_weights,
8076
)
81-
self.inference_mode = inference_mode
8277
self.inference_patch_size = inference_patch_size
8378
self.test_output_path = test_output_path
8479
self.num_classes = model.num_classes
@@ -161,7 +156,6 @@ def training_step(self, batch, batch_idx):
161156

162157
def validation_step(self, batch, batch_idx):
163158
x, y = batch["image"], batch["label"]
164-
165159
pred = self.model(x)
166160
loss = self.val_loss(pred, y)
167161
self.log(
@@ -221,7 +215,7 @@ def test_step(self, batch, batch_idx):
221215

222216
logits = self.model.sliding_window_predict(
223217
data=x,
224-
patch_size=fit_patch_size_to_image_size(self.inference_patch_size, list(x.shape[2:])),
218+
patch_size=self.inference_patch_size,
225219
overlap=0.5,
226220
)
227221

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from asparagus.modules.transforms.clamp import Torch_ClampTarget as Torch_ClampTarget
2+
from asparagus.modules.transforms.crop import Torch_Crop as Torch_Crop
3+
from asparagus.modules.transforms.pad import Torch_Pad as Torch_Pad
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import numpy as np
2+
import torch
3+
from gardening_tools.modules.transforms.BaseTransform import BaseTransform
4+
5+
6+
def select_foreground_voxel_to_include(foreground_locations):
7+
if isinstance(foreground_locations, list):
8+
locidx = np.random.choice(len(foreground_locations))
9+
location = foreground_locations[locidx]
10+
elif isinstance(foreground_locations, dict):
11+
selected_class = np.random.choice(list(foreground_locations.keys()))
12+
locidx = np.random.choice(len(foreground_locations[selected_class]))
13+
location = foreground_locations[selected_class][locidx]
14+
return location
15+
16+
17+
def torch_crop(
18+
image: torch.tensor,
19+
patch_size,
20+
input_dims: torch.tensor,
21+
target_image_shape: list | tuple,
22+
target_label_shape: list | tuple,
23+
p_oversample_foreground=0.0,
24+
foreground_locations=None,
25+
label: torch.tensor = None,
26+
):
27+
if foreground_locations is None:
28+
foreground_locations = []
29+
30+
if len(patch_size) == 3:
31+
image, label = torch_crop_3D_case_from_3D(
32+
image=image,
33+
foreground_locations=foreground_locations,
34+
label=label,
35+
patch_size=patch_size,
36+
p_oversample_foreground=p_oversample_foreground,
37+
target_image_shape=target_image_shape,
38+
target_label_shape=target_label_shape,
39+
)
40+
elif len(patch_size) == 2 and input_dims == 3:
41+
image, label = torch_crop_2D_case_from_3D(
42+
image=image,
43+
foreground_locations=foreground_locations,
44+
label=label,
45+
patch_size=patch_size,
46+
p_oversample_foreground=p_oversample_foreground,
47+
target_image_shape=target_image_shape,
48+
target_label_shape=target_label_shape,
49+
)
50+
elif len(patch_size) == 2 and input_dims == 2:
51+
image, label = torch_crop_2D_case_from_2D(
52+
image=image,
53+
foreground_locations=foreground_locations,
54+
label=label,
55+
patch_size=patch_size,
56+
p_oversample_foreground=p_oversample_foreground,
57+
target_image_shape=target_image_shape,
58+
target_label_shape=target_label_shape,
59+
)
60+
61+
return image, label
62+
63+
64+
def torch_crop_3D_case_from_3D(
65+
image,
66+
foreground_locations,
67+
label,
68+
patch_size,
69+
p_oversample_foreground,
70+
target_image_shape,
71+
target_label_shape,
72+
):
73+
image_out = torch.zeros(target_image_shape, device=image.device)
74+
label_out = torch.zeros(target_label_shape, device=image.device)
75+
76+
crop_start_idx = []
77+
if len(foreground_locations) == 0 or np.random.uniform() >= p_oversample_foreground:
78+
for d in range(3):
79+
if image.shape[d + 1] < patch_size[d]:
80+
crop_start_idx += [0]
81+
else:
82+
crop_start_idx += [np.random.randint(image.shape[d + 1] - patch_size[d] + 1)]
83+
else:
84+
location = select_foreground_voxel_to_include(foreground_locations)
85+
for d in range(3):
86+
if image.shape[d + 1] < patch_size[d]:
87+
crop_start_idx += [0]
88+
else:
89+
crop_start_idx += [
90+
np.random.randint(
91+
max(0, location[d] - patch_size[d]),
92+
min(location[d], image.shape[d + 1] - patch_size[d]) + 1,
93+
)
94+
]
95+
96+
image_out = image[
97+
:,
98+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
99+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
100+
crop_start_idx[2] : crop_start_idx[2] + patch_size[2],
101+
]
102+
if label is None:
103+
return image_out, None
104+
label_out = label[
105+
:,
106+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
107+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
108+
crop_start_idx[2] : crop_start_idx[2] + patch_size[2],
109+
]
110+
return image_out, label_out
111+
112+
113+
def torch_crop_2D_case_from_3D(
114+
image,
115+
foreground_locations,
116+
label,
117+
patch_size,
118+
p_oversample_foreground,
119+
target_image_shape,
120+
target_label_shape,
121+
):
122+
image_out = torch.zeros(target_image_shape, device=image.device)
123+
label_out = torch.zeros(target_label_shape, device=image.device)
124+
125+
crop_start_idx = []
126+
if len(foreground_locations) == 0 or np.random.uniform() >= p_oversample_foreground:
127+
x_idx = np.random.randint(image.shape[1])
128+
for d in range(2):
129+
if image.shape[d + 2] < patch_size[d]:
130+
crop_start_idx += [0]
131+
else:
132+
crop_start_idx += [np.random.randint(image.shape[d + 2] - patch_size[d] + 1)]
133+
else:
134+
location = select_foreground_voxel_to_include(foreground_locations)
135+
x_idx = location[0]
136+
for d in range(2):
137+
if image.shape[d + 2] < patch_size[d]:
138+
crop_start_idx += [0]
139+
else:
140+
crop_start_idx += [
141+
np.random.randint(
142+
max(0, location[d + 1] - patch_size[d]),
143+
min(location[d + 1], image.shape[d + 2] - patch_size[d]) + 1,
144+
)
145+
]
146+
147+
image_out[:, :, :] = image[
148+
:,
149+
x_idx,
150+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
151+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
152+
]
153+
154+
if label is None:
155+
return image_out, None
156+
157+
label_out[:, :, :] = label[
158+
:,
159+
x_idx,
160+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
161+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
162+
]
163+
164+
return image_out, label_out
165+
166+
167+
def torch_crop_2D_case_from_2D(
168+
image,
169+
foreground_locations,
170+
label,
171+
patch_size,
172+
p_oversample_foreground,
173+
target_image_shape,
174+
target_label_shape,
175+
):
176+
image_out = torch.zeros(target_image_shape, device=image.device)
177+
label_out = torch.zeros(target_label_shape, device=image.device)
178+
179+
crop_start_idx = []
180+
if len(foreground_locations) == 0 or np.random.uniform() >= p_oversample_foreground:
181+
for d in range(2):
182+
if image.shape[d + 1] < patch_size[d]:
183+
crop_start_idx += [0]
184+
else:
185+
crop_start_idx += [np.random.randint(image.shape[d + 1] - patch_size[d] + 1)]
186+
else:
187+
location = select_foreground_voxel_to_include(foreground_locations)
188+
for d in range(2):
189+
if image.shape[d + 1] < patch_size[d]:
190+
crop_start_idx += [0]
191+
else:
192+
crop_start_idx += [
193+
np.random.randint(
194+
max(0, location[d] - patch_size[d]),
195+
min(location[d], image.shape[d + 1] - patch_size[d]) + 1,
196+
)
197+
]
198+
199+
image_out[:, :, :] = image[
200+
:,
201+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
202+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
203+
]
204+
205+
if label is None:
206+
return image_out, None
207+
208+
label_out[:, :, :] = label[
209+
:,
210+
crop_start_idx[0] : crop_start_idx[0] + patch_size[0],
211+
crop_start_idx[1] : crop_start_idx[1] + patch_size[1],
212+
]
213+
214+
return image_out, label_out
215+
216+
217+
class Torch_Crop(BaseTransform):
218+
def __init__(
219+
self,
220+
data_key: str = "image",
221+
label_key: str = "label",
222+
patch_size: tuple | list = None,
223+
p_oversample_foreground: float = 0.0,
224+
):
225+
self.data_key = data_key
226+
self.label_key = label_key
227+
self.patch_size = patch_size
228+
self.p_oversample_foreground = p_oversample_foreground
229+
230+
@staticmethod
231+
def get_params(data, target_shape):
232+
input_shape = data.shape
233+
target_image_shape = (input_shape[0], *target_shape)
234+
target_label_shape = (1, *target_shape)
235+
return input_shape, target_image_shape, target_label_shape
236+
237+
def __crop__(
238+
self,
239+
data_dict,
240+
foreground_locations,
241+
input_shape,
242+
p_oversample_foreground,
243+
target_image_shape,
244+
target_label_shape,
245+
):
246+
image = data_dict[self.data_key]
247+
label = data_dict.get(self.label_key)
248+
image, label = torch_crop(
249+
image=image,
250+
patch_size=self.patch_size,
251+
input_dims=len(input_shape[1:]),
252+
target_image_shape=target_image_shape,
253+
target_label_shape=target_label_shape,
254+
p_oversample_foreground=p_oversample_foreground,
255+
foreground_locations=foreground_locations,
256+
label=label,
257+
)
258+
data_dict[self.data_key] = image
259+
if label is not None:
260+
data_dict[self.label_key] = label
261+
return data_dict
262+
263+
def __call__(self, data_dict: dict) -> dict:
264+
input_shape, target_image_shape, target_label_shape = self.get_params(
265+
data=data_dict[self.data_key],
266+
target_shape=self.patch_size,
267+
)
268+
return self.__crop__(
269+
data_dict=data_dict,
270+
foreground_locations=data_dict.get("foreground_locations"),
271+
input_shape=input_shape,
272+
p_oversample_foreground=self.p_oversample_foreground,
273+
target_image_shape=target_image_shape,
274+
target_label_shape=target_label_shape,
275+
)

0 commit comments

Comments
 (0)