
Commit 248d3c3

Add optional conditioning support to inferers and corresponding unit tests
1 parent 0f5da11 commit 248d3c3

5 files changed: +625 −12 lines changed


Diff for: monai/inferers/inferer.py (+37 −9)
@@ -324,6 +324,9 @@ def __call__(
             kwargs: optional keyword args to be passed to ``network``.

         """
+        # check if there is a conditioning signal
+        condition = kwargs.pop("condition", None)
+
         patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
         if self.splitter is None:
             # handle situations where the splitter is not provided
@@ -350,14 +353,28 @@ def __call__(

         ratios: list[float] = []
         mergers: list[Merger] = []
-        for patches, locations, batch_size in self._batch_sampler(patches_locations):
-            # run inference
-            outputs = self._run_inference(network, patches, *args, **kwargs)
-            # initialize the mergers
-            if not mergers:
-                mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
-            # aggregate outputs
-            self._aggregate(outputs, locations, batch_size, mergers, ratios)
+        if condition is not None:
+            for (patches, locations, batch_size), (condition_patches, _, _) in zip(
+                self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)
+            ):
+                # add patched condition to kwargs
+                kwargs["condition"] = condition_patches
+                # run inference
+                outputs = self._run_inference(network, patches, *args, **kwargs)
+                # initialize the mergers
+                if not mergers:
+                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
+                # aggregate outputs
+                self._aggregate(outputs, locations, batch_size, mergers, ratios)
+        else:
+            for patches, locations, batch_size in self._batch_sampler(patches_locations):
+                # run inference
+                outputs = self._run_inference(network, patches, *args, **kwargs)
+                # initialize the mergers
+                if not mergers:
+                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
+                # aggregate outputs
+                self._aggregate(outputs, locations, batch_size, mergers, ratios)

         # finalize the mergers and get the results
         merged_outputs = [merger.finalize() for merger in mergers]
@@ -742,7 +759,18 @@ def __call__(
                 f"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported."
             )

-        return super().__call__(inputs=inputs, network=lambda x: self.network_wrapper(network, x, *args, **kwargs))
+        # check if there is a conditioning signal
+        condition = kwargs.get("condition", None)
+        if condition is not None:
+            return super().__call__(
+                inputs=inputs,
+                network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs),
+                condition=condition,
+            )
+        else:
+            return super().__call__(
+                inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs)
+            )

     def network_wrapper(
         self,
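
Taken together, the inferer.py changes let a conditioning tensor be passed to PatchInferer (and, via the last hunk, SliceInferer) as an extra `condition` keyword argument: it is split into patches alongside the inputs and handed to the network as `condition` for every patch batch. A minimal usage sketch, with a toy network and tensor shapes chosen for illustration (not part of the commit; it assumes `SlidingWindowSplitter` is importable from `monai.inferers` as in the tests):

import torch

from monai.inferers import PatchInferer, SlidingWindowSplitter

# toy "network" that consumes the per-patch conditioning signal
network = lambda x, condition: x + condition

image = torch.rand(1, 3, 4, 4)      # (N, C, H, W)
condition = torch.rand(1, 3, 4, 4)  # same shape, so it is split exactly like `image`

inferer = PatchInferer(splitter=SlidingWindowSplitter(patch_size=(2, 2)))
output = inferer(inputs=image, network=network, condition=condition)
print(output.shape)  # torch.Size([1, 3, 4, 4])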

Diff for: monai/inferers/utils.py (+13 −3)
@@ -153,6 +153,8 @@ def sliding_window_inference(
     device = device or inputs.device
     sw_device = sw_device or inputs.device

+    condition = kwargs.pop("condition", None)
+
     temp_meta = None
     if isinstance(inputs, MetaTensor):
         temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
@@ -168,6 +170,8 @@ def sliding_window_inference(
         pad_size.extend([half, diff - half])
     if any(pad_size):
         inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
+        if condition is not None:
+            condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

     # Store all slices
     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
@@ -220,13 +224,19 @@
         ]
         if sw_batch_size > 1:
             win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
+            if condition is not None:
+                win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
+                kwargs["condition"] = win_condition
         else:
             win_data = inputs[unravel_slice[0]].to(sw_device)
+            if condition is not None:
+                win_condition = condition[unravel_slice[0]].to(sw_device)
+                kwargs["condition"] = win_condition
+
         if with_coord:
-            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)  # batched patch
+            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
         else:
-            seg_prob_out = predictor(win_data, *args, **kwargs)  # batched patch
-
+            seg_prob_out = predictor(win_data, *args, **kwargs)
         # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
         dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
         if process_fn:
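
In `sliding_window_inference`, the new `condition` keyword is popped from `kwargs`, padded the same way as the inputs, cut into the same windows, and re-inserted into `kwargs` so that `predictor` receives a matching `condition` window for every image window. A minimal sketch of calling the function directly (the predictor and shapes are illustrative, not taken from the commit):

import torch

from monai.inferers import sliding_window_inference

def predictor(patch: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    # receives matching windows of the image and of the conditioning signal
    return patch + condition

image = torch.rand(1, 1, 64, 64)
cond = torch.rand(1, 1, 64, 64)

output = sliding_window_inference(
    inputs=image,
    roi_size=(32, 32),
    sw_batch_size=2,
    predictor=predictor,
    condition=cond,  # forwarded window-by-window via kwargs["condition"]
)
print(output.shape)  # torch.Size([1, 1, 64, 64])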

Diff for: tests/inferers/test_patch_inferer.py (+228)
@@ -305,5 +305,233 @@ def test_patch_inferer_errors(self, inputs, arguments, expected_error):
             inferer(inputs=inputs, network=lambda x: x)


+
+# ----------------------------------------------------------------------------
+# Test cases with conditioning
+# ----------------------------------------------------------------------------
+
+# no-overlapping 2x2 patches
+TEST_CASE_0_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# no-overlapping 2x2 patches using all default parameters (except for splitter)
+TEST_CASE_1_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# divisible batch_size
+TEST_CASE_2_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=2),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# non-divisible batch_size
+TEST_CASE_3_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=3),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# patches that are already split (Splitter should be None)
+TEST_CASE_4_SPLIT_LIST_c = [
+    [
+        (TENSOR_4x4[..., :2, :2], (0, 0)),
+        (TENSOR_4x4[..., :2, 2:], (0, 2)),
+        (TENSOR_4x4[..., 2:, :2], (2, 0)),
+        (TENSOR_4x4[..., 2:, 2:], (2, 2)),
+    ],
+    dict(splitter=None, merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# using all default parameters (patches are already split)
+TEST_CASE_5_SPLIT_LIST_c = [
+    [
+        (TENSOR_4x4[..., :2, :2], (0, 0)),
+        (TENSOR_4x4[..., :2, 2:], (0, 2)),
+        (TENSOR_4x4[..., 2:, :2], (2, 0)),
+        (TENSOR_4x4[..., 2:, 2:], (2, 2)),
+    ],
+    dict(merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# output smaller than input patches
+TEST_CASE_6_SMALLER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: torch.mean(x, dim=(-1, -2), keepdim=True) + torch.mean(condition, dim=(-1, -2), keepdim=True),
+    TENSOR_2x2 * 2,
+]
+
+# preprocess patches
+TEST_CASE_7_PREPROCESS_c = [
+    TENSOR_4x4,
+    dict(
+        splitter=SlidingWindowSplitter(patch_size=(2, 2)),
+        merger_cls=AvgMerger,
+        preprocessing=lambda x: 2 * x,
+        postprocessing=None,
+    ),
+    lambda x, condition: x + condition,
+    2 * TENSOR_4x4 + TENSOR_4x4,
+]
+
+# postprocess patches
+TEST_CASE_8_POSTPROCESS_c = [
+    TENSOR_4x4,
+    dict(
+        splitter=SlidingWindowSplitter(patch_size=(2, 2)),
+        merger_cls=AvgMerger,
+        preprocessing=None,
+        postprocessing=lambda x: 4 * x,
+    ),
+    lambda x, condition: x + condition,
+    4 * TENSOR_4x4 * 2,
+]
+
+# str merger as the class name
+TEST_CASE_9_STR_MERGER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="AvgMerger"),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# str merger as dotted path
+TEST_CASE_10_STR_MERGER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="monai.inferers.merger.AvgMerger"),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# non-divisible patch_size leading to larger image (without matching spatial shape)
+TEST_CASE_11_PADDING_c = [
+    TENSOR_4x4,
+    dict(
+        splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode="constant", pad_value=0.0),
+        merger_cls=AvgMerger,
+        match_spatial_shape=False,
+    ),
+    lambda x, condition: x + condition,
+    pad(TENSOR_4x4, (0, 2), value=0.0) * 2,
+]
+
+# non-divisible patch_size with matching spatial shapes
+TEST_CASE_12_MATCHING_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=None), merger_cls=AvgMerger),
+    lambda x, condition: x + condition,
+    pad(TENSOR_4x4[..., :3], (0, 1), value=float("nan")) * 2,
+]
+
+# non-divisible patch_size with matching spatial shapes
+TEST_CASE_13_PADDING_MATCHING_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3)), merger_cls=AvgMerger),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# multi-threading
+TEST_CASE_14_MULTITHREAD_BUFFER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# multi-threading with batch
+TEST_CASE_15_MULTITHREADD_BUFFER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# list of tensor output
+TEST_CASE_0_LIST_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: (x + condition, x + condition),
+    (TENSOR_4x4 * 2, TENSOR_4x4 * 2),
+]
+
+# dict of tensor output
+TEST_CASE_0_DICT_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: {"model_output": x + condition},
+    {"model_output": TENSOR_4x4 * 2},
+]
+
+
+class PatchInfererTests_cond(unittest.TestCase):
+    @parameterized.expand(
+        [
+            TEST_CASE_0_TENSOR_c,
+            TEST_CASE_1_TENSOR_c,
+            TEST_CASE_2_TENSOR_c,
+            TEST_CASE_3_TENSOR_c,
+            TEST_CASE_4_SPLIT_LIST_c,
+            TEST_CASE_5_SPLIT_LIST_c,
+            TEST_CASE_6_SMALLER_c,
+            TEST_CASE_7_PREPROCESS_c,
+            TEST_CASE_8_POSTPROCESS_c,
+            TEST_CASE_9_STR_MERGER_c,
+            TEST_CASE_10_STR_MERGER_c,
+            TEST_CASE_11_PADDING_c,
+            TEST_CASE_12_MATCHING_c,
+            TEST_CASE_13_PADDING_MATCHING_c,
+            TEST_CASE_14_MULTITHREAD_BUFFER_c,
+            TEST_CASE_15_MULTITHREADD_BUFFER_c,
+        ]
+    )
+    def test_patch_inferer_tensor(self, inputs, arguments, network, expected):
+        if isinstance(inputs, list):  # case 4 and 5
+            condition = [(x[0].clone(), x[1]) for x in inputs]
+        else:
+            condition = inputs.clone()
+        inferer = PatchInferer(**arguments)
+        output = inferer(inputs=inputs, network=network, condition=condition)
+        assert_allclose(output, expected)
+
+    @parameterized.expand([TEST_CASE_0_LIST_TENSOR_c])
+    def test_patch_inferer_list_tensor(self, inputs, arguments, network, expected):
+        if isinstance(inputs, list):  # case 4 and 5
+            condition = [(x[0].clone(), x[1]) for x in inputs]
+        else:
+            condition = inputs.clone()
+        inferer = PatchInferer(**arguments)
+        output = inferer(inputs=inputs, network=network, condition=condition)
+        for out, exp in zip(output, expected):
+            assert_allclose(out, exp)
+
+    @parameterized.expand([TEST_CASE_0_DICT_c])
+    def test_patch_inferer_dict(self, inputs, arguments, network, expected):
+        if isinstance(inputs, list):  # case 4 and 5
+            condition = [(x[0].clone(), x[1]) for x in inputs]
+        else:
+            condition = inputs.clone()
+        inferer = PatchInferer(**arguments)
+        output = inferer(inputs=inputs, network=network, condition=condition)
+        for k in expected:
+            assert_allclose(output[k], expected[k])
+
+
 if __name__ == "__main__":
     unittest.main()
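
The pre-split cases (TEST_CASE_4_SPLIT_LIST_c / TEST_CASE_5_SPLIT_LIST_c) show that when `splitter=None` the condition is supplied in the same form as the inputs, i.e. as a list of (patch, location) pairs whose locations line up with the input patches. A small sketch of that calling pattern (mirrors the test setup; the tensors and shapes are illustrative, not additional committed code):

import torch

from monai.inferers import PatchInferer

image = torch.rand(2, 3, 4, 4)
# inputs already split into four 2x2 patches with their top-left locations
patches = [
    (image[..., :2, :2], (0, 0)),
    (image[..., :2, 2:], (0, 2)),
    (image[..., 2:, :2], (2, 0)),
    (image[..., 2:, 2:], (2, 2)),
]
# the condition is split the same way: one (patch, location) pair per input patch
condition = [(p.clone(), loc) for p, loc in patches]

inferer = PatchInferer(splitter=None, merged_shape=(2, 3, 4, 4))
output = inferer(inputs=patches, network=lambda x, condition: x + condition, condition=condition)
print(output.shape)  # torch.Size([2, 3, 4, 4])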

Diff for: tests/inferers/test_slice_inferer.py (+33)
@@ -53,5 +53,38 @@ def test_shape(self, spatial_dim):
         result = inferer(input_volume, model)


+class TestSliceInferer_cond(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASES)
+    def test_shape(self, spatial_dim):
+        spatial_dim = int(spatial_dim)
+
+        model = UNet(
+            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2
+        )
+
+        # overwrite the forward method to test the inferer with a model that takes a condition
+        model.forward = lambda x, condition: x + condition if condition is not None else x
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+
+        # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)
+        input_volume = torch.ones(1, 1, 64, 256, 256, device=device)
+        condition_volume = torch.ones(1, 1, 64, 256, 256, device=device)
+        # Remove spatial dim to slide across from the roi_size
+        roi_size = list(input_volume.shape[2:])
+        roi_size.pop(spatial_dim)
+
+        # Initialize and run inferer
+        inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1)
+        result = inferer(input_volume, model, condition=condition_volume)
+
+        self.assertTupleEqual(result.shape, input_volume.shape)
+        self.assertEqual(result.sum(), (input_volume + condition_volume).sum())
+        # test that the inferer can be run multiple times
+        result = inferer(input_volume, model, condition=condition_volume)
+
 if __name__ == "__main__":
     unittest.main()
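
Assuming a development checkout of MONAI with pytest installed, the new conditioning tests added by this commit can be selected by keyword, for example:

python -m pytest tests/inferers/test_patch_inferer.py tests/inferers/test_slice_inferer.py -k "cond"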
