Skip to content

Commit f27517b

Browse files
Inverse Threading Fix (#8418)
Fixes #8056. ### Description This fixes a race condition where the `tracing` member may be toggled in multiple threads. This is turned into a thread-local value that removes this issue. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8f3d8e8 commit f27517b

File tree

6 files changed

+146
-9
lines changed

6 files changed

+146
-9
lines changed

monai/transforms/inverse.py

+39-7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
import threading
1415
import warnings
1516
from collections.abc import Hashable, Mapping
1617
from contextlib import contextmanager
@@ -66,15 +67,41 @@ class TraceableTransform(Transform):
6667
The information in the stack of applied transforms must be compatible with the
6768
default collate, by only storing strings, numbers and arrays.
6869
69-
`tracing` could be enabled by `self.set_tracing` or setting
70+
`tracing` could be enabled by assigning to `self.tracing` or setting
7071
`MONAI_TRACE_TRANSFORM` when initializing the class.
7172
"""
7273

73-
tracing = MONAIEnvVars.trace_transform() != "0"
74+
def _init_trace_threadlocal(self):
75+
"""Create a `_tracing` instance member to store the thread-local tracing state value."""
76+
# needed since this class is meant to be a trait with no constructor
77+
if not hasattr(self, "_tracing"):
78+
self._tracing = threading.local()
79+
80+
# This is True while the above initialising _tracing is False when this is
81+
# called from a different thread than the one initialising _tracing.
82+
if not hasattr(self._tracing, "value"):
83+
self._tracing.value = MONAIEnvVars.trace_transform() != "0"
84+
85+
def __getstate__(self):
86+
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
87+
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
88+
_slots = {k: getattr(self, k) for k in getattr(self, "__slots__", [])}
89+
_dict.pop("_tracing", None) # remove tracing
90+
return _dict if len(_slots) == 0 else (_dict, _slots)
91+
92+
@property
93+
def tracing(self) -> bool:
94+
"""
95+
Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
96+
"""
97+
self._init_trace_threadlocal()
98+
return bool(self._tracing.value)
7499

75-
def set_tracing(self, tracing: bool) -> None:
76-
"""Set whether to trace transforms."""
77-
self.tracing = tracing
100+
@tracing.setter
101+
def tracing(self, val: bool):
102+
"""Sets the thread-local tracing state to `val`."""
103+
self._init_trace_threadlocal()
104+
self._tracing.value = val
78105

79106
@staticmethod
80107
def trace_key(key: Hashable = None):
@@ -291,7 +318,7 @@ def check_transforms_match(self, transform: Mapping) -> None:
291318

292319
def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
293320
"""
294-
Get most recent transform for the stack.
321+
Get most recent matching transform for the current class from the sequence of applied operations.
295322
296323
Args:
297324
data: dictionary of data or `MetaTensor`.
@@ -316,9 +343,14 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
316343
all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())
317344
else:
318345
raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
346+
347+
if not all_transforms:
348+
raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
349+
319350
if check:
320351
self.check_transforms_match(all_transforms[-1])
321-
return all_transforms.pop() if pop else all_transforms[-1]
352+
353+
return all_transforms.pop(-1) if pop else all_transforms[-1]
322354

323355
def pop_transform(self, data, key: Hashable = None, check: bool = True):
324356
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
from itertools import product
16+
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.data import DataLoader, Dataset, MetaTensor, ThreadDataLoader, create_test_image_2d
21+
from monai.engines.evaluator import SupervisedEvaluator
22+
from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd
23+
from monai.utils.enums import CommonKeys
24+
from tests.test_utils import TEST_DEVICES, SkipIfNoModule
25+
26+
27+
class TestInvertDict(unittest.TestCase):
28+
29+
def setUp(self):
30+
self.orig_size = (60, 60)
31+
img, _ = create_test_image_2d(*self.orig_size, 2, 10, num_seg_classes=2)
32+
self.img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0]})
33+
self.key = CommonKeys.IMAGE
34+
self.pred = CommonKeys.PRED
35+
self.new_pixdim = 2.0
36+
37+
self.preprocessing = Compose([EnsureChannelFirstd(self.key), Spacingd(self.key, pixdim=[self.new_pixdim] * 2)])
38+
39+
self.postprocessing = Compose([Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key)])
40+
41+
@parameterized.expand(TEST_DEVICES)
42+
def test_simple_processing(self, device):
43+
"""
44+
Tests postprocessing operations perform correctly, in particular that `Invertd` does inversion correctly.
45+
46+
This will apply the preprocessing sequence which resizes the result, then the postprocess sequence which
47+
returns it to the original shape using Invertd. This tests that the shape of the output is the same as the
48+
original image. This will also test that Invertd doesn't get confused if transforms in the postprocessing
49+
sequence are tracing and so adding information to `applied_operations`, this is what `Lambdad` is doing in
50+
`self.postprocessing`.
51+
"""
52+
53+
item = {self.key: self.img.to(device)}
54+
pre = self.preprocessing(item)
55+
56+
nw = int(self.orig_size[0] / self.new_pixdim)
57+
nh = int(self.orig_size[1] / self.new_pixdim)
58+
59+
self.assertTupleEqual(pre[self.key].shape, (1, nh, nw), "Pre-processing did not reshape input correctly")
60+
self.assertTrue(len(pre[self.key].applied_operations) > 0, "Pre-processing transforms did not trace correctly")
61+
62+
pre[self.pred] = pre[self.key] # the inputs are the prediction for this test
63+
64+
post = self.postprocessing(pre)
65+
66+
self.assertTupleEqual(
67+
post[self.pred].shape, (1, *self.orig_size), "Result does not have same shape as original input"
68+
)
69+
70+
@parameterized.expand(product(sum(TEST_DEVICES, []), [True, False]))
71+
@SkipIfNoModule("ignite")
72+
def test_workflow(self, device, use_threads):
73+
"""
74+
This tests the interaction between pre and postprocesing transform sequences being executed in parallel.
75+
76+
When the `ThreadDataLoader` is used to load batches, this is done in parallel at times with the execution of
77+
the post-process transform sequence. Previously this encountered a race condition at times because the
78+
`TraceableTransform.tracing` variables of transforms was being toggled in different threads, so at times a
79+
pre-process transform wouldn't trace correctly and so confuse `Invertd`. Using a `SupervisedEvaluator` is
80+
the best way to induce this race condition, other methods didn't get the timing right..
81+
"""
82+
batch_size = 2
83+
ds_size = 4
84+
test_data = [{self.key: self.img.clone().to(device)} for _ in range(ds_size)]
85+
ds = Dataset(test_data, transform=self.preprocessing)
86+
dl_type = ThreadDataLoader if use_threads else DataLoader
87+
dl = dl_type(ds, num_workers=0, batch_size=batch_size)
88+
89+
class AssertAppliedOps(torch.nn.Module):
90+
def forward(self, x):
91+
assert len(x.applied_operations) == x.shape[0]
92+
assert all(len(a) > 0 for a in x.applied_operations)
93+
return x
94+
95+
evaluator = SupervisedEvaluator(
96+
device=device, network=AssertAppliedOps(), postprocessing=self.postprocessing, val_data_loader=dl
97+
)
98+
99+
evaluator.run()
100+
101+
self.assertTupleEqual(evaluator.state.output[0][self.pred].shape, (1, *self.orig_size))
102+
103+
104+
if __name__ == "__main__":
105+
unittest.main()

tests/transforms/inverse/test_traceable_transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ def test_default(self):
4545
self.assertEqual(len(data[expected_key]), 2)
4646
self.assertEqual(data[expected_key][-1]["class"], "_TraceTest")
4747

48-
with self.assertRaises(IndexError):
48+
with self.assertRaises(ValueError):
4949
a.pop({"test": "test"}) # no stack in the data
5050
data = a.pop(data)
5151
data = a.pop(data)
5252
self.assertEqual(data[expected_key], [])
5353

54-
with self.assertRaises(IndexError): # no more items
54+
with self.assertRaises(ValueError): # no more items
5555
a.pop(data)
5656

5757

0 commit comments

Comments
 (0)