# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from itertools import product

import torch
from parameterized import parameterized

from monai.data import DataLoader, Dataset, MetaTensor, ThreadDataLoader, create_test_image_2d
from monai.engines.evaluator import SupervisedEvaluator
from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd
from monai.utils.enums import CommonKeys
from tests.test_utils import TEST_DEVICES, SkipIfNoModule
| 26 | + |
class TestInvertDict(unittest.TestCase):
    """
    Tests that `Invertd` correctly inverts a preprocessing transform sequence, both when applied
    directly and when run as postprocessing inside a `SupervisedEvaluator` workflow.
    """

    def setUp(self):
        # Build a 60x60 synthetic 2D test image and wrap it as a MetaTensor carrying a 1.0mm
        # pixdim so that Spacingd has spacing metadata to resample against.
        self.orig_size = (60, 60)
        img, _ = create_test_image_2d(*self.orig_size, 2, 10, num_seg_classes=2)
        self.img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0]})
        self.key = CommonKeys.IMAGE
        self.pred = CommonKeys.PRED
        # Resampling from pixdim 1.0 to 2.0 halves each spatial dimension of the image.
        self.new_pixdim = 2.0

        # Preprocessing adds a channel dim then resamples; Invertd in postprocessing is expected
        # to undo exactly this sequence, restoring the original spatial shape.
        self.preprocessing = Compose([EnsureChannelFirstd(self.key), Spacingd(self.key, pixdim=[self.new_pixdim] * 2)])

        self.postprocessing = Compose([Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key)])

    @parameterized.expand(TEST_DEVICES)
    def test_simple_processing(self, device):
        """
        Tests postprocessing operations perform correctly, in particular that `Invertd` does inversion correctly.

        This will apply the preprocessing sequence which resizes the result, then the postprocess sequence which
        returns it to the original shape using Invertd. This tests that the shape of the output is the same as the
        original image.
        """

        item = {self.key: self.img.to(device)}
        pre = self.preprocessing(item)

        # Expected spatial size after resampling from pixdim 1.0 to self.new_pixdim (60 -> 30).
        nw = int(self.orig_size[0] / self.new_pixdim)
        nh = int(self.orig_size[1] / self.new_pixdim)

        self.assertTupleEqual(pre[self.key].shape, (1, nh, nw), "Pre-processing did not reshape input correctly")
        self.assertTrue(len(pre[self.key].applied_operations) > 0, "Pre-processing transforms did not trace correctly")

        pre[self.pred] = pre[self.key]  # the inputs are the prediction for this test

        post = self.postprocessing(pre)

        # Invertd should restore the original (channel-first) spatial shape.
        self.assertTupleEqual(
            post[self.pred].shape, (1, *self.orig_size), "Result does not have same shape as original input"
        )

    @parameterized.expand(product(sum(TEST_DEVICES, []), [True, False]))
    @SkipIfNoModule("ignite")
    def test_workflow(self, device, use_threads):
        """
        This tests the interaction between pre and postprocessing transform sequences being executed in parallel.

        When the `ThreadDataLoader` is used to load batches, this is done in parallel at times with the execution of
        the post-process transform sequence. Previously this encountered a race condition at times because the
        `TraceableTransform.tracing` variables of transforms was being toggled in different threads, so at times a
        pre-process transform wouldn't trace correctly and so confuse `Invertd`. Using a `SupervisedEvaluator` is
        the best way to induce this race condition, other methods didn't get the timing right.
        """
        batch_size = 2
        ds_size = 4
        test_data = [{self.key: self.img.clone().to(device)} for _ in range(ds_size)]
        ds = Dataset(test_data, transform=self.preprocessing)
        # Parameterised over both loader types: the threaded one is the case that exposed the race.
        dl_type = ThreadDataLoader if use_threads else DataLoader
        dl = dl_type(ds, num_workers=0, batch_size=batch_size)

        # Identity "network" that asserts every batch item carries a non-empty applied_operations
        # trace, i.e. preprocessing traced correctly before the batch reached the network.
        class AssertAppliedOps(torch.nn.Module):
            def forward(self, x):
                assert len(x.applied_operations) == x.shape[0]
                assert all(len(a) > 0 for a in x.applied_operations)
                return x

        evaluator = SupervisedEvaluator(
            device=device, network=AssertAppliedOps(), postprocessing=self.postprocessing, val_data_loader=dl
        )

        evaluator.run()

        # After Invertd runs as postprocessing, each output prediction has the original spatial shape.
        self.assertTupleEqual(evaluator.state.output[0][self.pred].shape, (1, *self.orig_size))
| 102 | + |
| 103 | + |
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments