@@ -305,5 +305,233 @@ def test_patch_inferer_errors(self, inputs, arguments, expected_error):
305
305
inferer (inputs = inputs , network = lambda x : x )
306
306
307
307
308
+
309
+ # ----------------------------------------------------------------------------
310
+ # Error test cases with conditionign
311
+ # ----------------------------------------------------------------------------
312
+
313
+ # no-overlapping 2x2 patches
314
+ TEST_CASE_0_TENSOR_c = [
315
+ TENSOR_4x4 ,
316
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
317
+ lambda x , condition : x + condition ,
318
+ TENSOR_4x4 * 2 ,
319
+ ]
320
+
321
+ # no-overlapping 2x2 patches using all default parameters (except for splitter)
322
+ TEST_CASE_1_TENSOR_c = [
323
+ TENSOR_4x4 ,
324
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 ))),
325
+ lambda x , condition : x + condition ,
326
+ TENSOR_4x4 * 2 ,
327
+ ]
328
+
329
+ # divisible batch_size
330
+ TEST_CASE_2_TENSOR_c = [
331
+ TENSOR_4x4 ,
332
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , batch_size = 2 ),
333
+ lambda x , condition : x + condition ,
334
+ TENSOR_4x4 * 2 ,
335
+ ]
336
+
337
+ # non-divisible batch_size
338
+ TEST_CASE_3_TENSOR_c = [
339
+ TENSOR_4x4 ,
340
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , batch_size = 3 ),
341
+ lambda x , condition : x + condition ,
342
+ TENSOR_4x4 * 2 ,
343
+ ]
344
+
345
+ # patches that are already split (Splitter should be None)
346
+ TEST_CASE_4_SPLIT_LIST_c = [
347
+ [
348
+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
349
+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
350
+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
351
+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
352
+ ],
353
+ dict (splitter = None , merger_cls = AvgMerger , merged_shape = (2 , 3 , 4 , 4 )),
354
+ lambda x , condition : x + condition ,
355
+ TENSOR_4x4 * 2 ,
356
+ ]
357
+
358
+ # using all default parameters (patches are already split)
359
+ TEST_CASE_5_SPLIT_LIST_c = [
360
+ [
361
+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
362
+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
363
+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
364
+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
365
+ ],
366
+ dict (merger_cls = AvgMerger , merged_shape = (2 , 3 , 4 , 4 )),
367
+ lambda x , condition : x + condition ,
368
+ TENSOR_4x4 * 2 ,
369
+ ]
370
+
371
+ # output smaller than input patches
372
+ TEST_CASE_6_SMALLER_c = [
373
+ TENSOR_4x4 ,
374
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
375
+ lambda x , condition : torch .mean (x , dim = (- 1 , - 2 ), keepdim = True ) + torch .mean (condition , dim = (- 1 , - 2 ), keepdim = True ),
376
+ TENSOR_2x2 * 2 ,
377
+ ]
378
+
379
+ # preprocess patches
380
+ TEST_CASE_7_PREPROCESS_c = [
381
+ TENSOR_4x4 ,
382
+ dict (
383
+ splitter = SlidingWindowSplitter (patch_size = (2 , 2 )),
384
+ merger_cls = AvgMerger ,
385
+ preprocessing = lambda x : 2 * x ,
386
+ postprocessing = None ,
387
+ ),
388
+ lambda x , condition : x + condition ,
389
+ 2 * TENSOR_4x4 + TENSOR_4x4 ,
390
+ ]
391
+
392
+ # preprocess patches
393
+ TEST_CASE_8_POSTPROCESS_c = [
394
+ TENSOR_4x4 ,
395
+ dict (
396
+ splitter = SlidingWindowSplitter (patch_size = (2 , 2 )),
397
+ merger_cls = AvgMerger ,
398
+ preprocessing = None ,
399
+ postprocessing = lambda x : 4 * x ,
400
+ ),
401
+ lambda x , condition : x + condition ,
402
+ 4 * TENSOR_4x4 * 2 ,
403
+ ]
404
+
405
+ # str merger as the class name
406
+ TEST_CASE_9_STR_MERGER_c = [
407
+ TENSOR_4x4 ,
408
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = "AvgMerger" ),
409
+ lambda x , condition : x + condition ,
410
+ TENSOR_4x4 * 2 ,
411
+ ]
412
+
413
+ # str merger as dotted patch
414
+ TEST_CASE_10_STR_MERGER_c = [
415
+ TENSOR_4x4 ,
416
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = "monai.inferers.merger.AvgMerger" ),
417
+ lambda x , condition : x + condition ,
418
+ TENSOR_4x4 * 2 ,
419
+ ]
420
+
421
+ # non-divisible patch_size leading to larger image (without matching spatial shape)
422
+ TEST_CASE_11_PADDING_c = [
423
+ TENSOR_4x4 ,
424
+ dict (
425
+ splitter = SlidingWindowSplitter (patch_size = (2 , 3 ), pad_mode = "constant" , pad_value = 0.0 ),
426
+ merger_cls = AvgMerger ,
427
+ match_spatial_shape = False ,
428
+ ),
429
+ lambda x , condition : x + condition ,
430
+ pad (TENSOR_4x4 , (0 , 2 ), value = 0.0 ) * 2 ,
431
+ ]
432
+
433
+ # non-divisible patch_size with matching spatial shapes
434
+ TEST_CASE_12_MATCHING_c = [
435
+ TENSOR_4x4 ,
436
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 3 ), pad_mode = None ), merger_cls = AvgMerger ),
437
+ lambda x , condition : x + condition ,
438
+ pad (TENSOR_4x4 [..., :3 ], (0 , 1 ), value = float ("nan" )) * 2 ,
439
+ ]
440
+
441
+ # non-divisible patch_size with matching spatial shapes
442
+ TEST_CASE_13_PADDING_MATCHING_c = [
443
+ TENSOR_4x4 ,
444
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 3 )), merger_cls = AvgMerger ),
445
+ lambda x , condition : x + condition ,
446
+ TENSOR_4x4 * 2 ,
447
+ ]
448
+
449
+ # multi-threading
450
+ TEST_CASE_14_MULTITHREAD_BUFFER_c = [
451
+ TENSOR_4x4 ,
452
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , buffer_size = 2 ),
453
+ lambda x , condition : x + condition ,
454
+ TENSOR_4x4 * 2 ,
455
+ ]
456
+
457
+ # multi-threading with batch
458
+ TEST_CASE_15_MULTITHREADD_BUFFER_c = [
459
+ TENSOR_4x4 ,
460
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger , buffer_size = 4 , batch_size = 4 ),
461
+ lambda x , condition : x + condition ,
462
+ TENSOR_4x4 * 2 ,
463
+ ]
464
+
465
+ # list of tensor output
466
+ TEST_CASE_0_LIST_TENSOR_c = [
467
+ TENSOR_4x4 ,
468
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
469
+ lambda x , condition : (x + condition , x + condition ),
470
+ (TENSOR_4x4 * 2 , TENSOR_4x4 * 2 ),
471
+ ]
472
+
473
+ # list of tensor output
474
+ TEST_CASE_0_DICT_c = [
475
+ TENSOR_4x4 ,
476
+ dict (splitter = SlidingWindowSplitter (patch_size = (2 , 2 )), merger_cls = AvgMerger ),
477
+ lambda x , condition : {"model_output" : x + condition },
478
+ {"model_output" : TENSOR_4x4 * 2 },
479
+ ]
480
+
481
+
482
+
483
+ class PatchInfererTests_cond (unittest .TestCase ):
484
+ @parameterized .expand (
485
+ [
486
+ TEST_CASE_0_TENSOR_c ,
487
+ TEST_CASE_1_TENSOR_c ,
488
+ TEST_CASE_2_TENSOR_c ,
489
+ TEST_CASE_3_TENSOR_c ,
490
+ TEST_CASE_4_SPLIT_LIST_c ,
491
+ TEST_CASE_5_SPLIT_LIST_c ,
492
+ TEST_CASE_6_SMALLER_c ,
493
+ TEST_CASE_7_PREPROCESS_c ,
494
+ TEST_CASE_8_POSTPROCESS_c ,
495
+ TEST_CASE_9_STR_MERGER_c ,
496
+ TEST_CASE_10_STR_MERGER_c ,
497
+ TEST_CASE_11_PADDING_c ,
498
+ TEST_CASE_12_MATCHING_c ,
499
+ TEST_CASE_13_PADDING_MATCHING_c ,
500
+ TEST_CASE_14_MULTITHREAD_BUFFER_c ,
501
+ TEST_CASE_15_MULTITHREADD_BUFFER_c ,
502
+ ]
503
+ )
504
+ def test_patch_inferer_tensor (self , inputs , arguments , network , expected ):
505
+ if isinstance (inputs , list ): # case 4 and 5
506
+ condition = [(x [0 ].clone (), x [1 ]) for x in inputs ]
507
+ else :
508
+ condition = inputs .clone ()
509
+ inferer = PatchInferer (** arguments )
510
+ output = inferer (inputs = inputs , network = network , condition = condition )
511
+ assert_allclose (output , expected )
512
+
513
+ @parameterized .expand ([TEST_CASE_0_LIST_TENSOR_c ])
514
+ def test_patch_inferer_list_tensor (self , inputs , arguments , network , expected ):
515
+ if isinstance (inputs , list ): # case 4 and 5
516
+ condition = [(x [0 ].clone (), x [1 ]) for x in inputs ]
517
+ else :
518
+ condition = inputs .clone ()
519
+ inferer = PatchInferer (** arguments )
520
+ output = inferer (inputs = inputs , network = network , condition = condition )
521
+ for out , exp in zip (output , expected ):
522
+ assert_allclose (out , exp )
523
+
524
+ @parameterized .expand ([TEST_CASE_0_DICT_c ])
525
+ def test_patch_inferer_dict (self , inputs , arguments , network , expected ):
526
+ if isinstance (inputs , list ): # case 4 and 5
527
+ condition = [(x [0 ].clone (), x [1 ]) for x in inputs ]
528
+ else :
529
+ condition = inputs .clone ()
530
+ inferer = PatchInferer (** arguments )
531
+ output = inferer (inputs = inputs , network = network , condition = condition )
532
+ for k in expected :
533
+ assert_allclose (output [k ], expected [k ])
534
+
535
+
308
536
if __name__ == "__main__" :
309
537
unittest .main ()
0 commit comments