Skip to content

Commit 166ff7c

Browse files
Flag to disable tracking for training task (#1104)
* Flag to disable tracking for training task Signed-off-by: Sachidanand Alle <[email protected]> * Flag to disable tracking for training task Signed-off-by: Sachidanand Alle <[email protected]> Signed-off-by: Sachidanand Alle <[email protected]>
1 parent a1a2b84 commit 166ff7c

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

monailabel/tasks/train/basic_train.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
PersistentDataset,
3333
SmartCacheDataset,
3434
ThreadDataLoader,
35+
get_track_meta,
3536
partition_dataset,
37+
set_track_meta,
3638
)
3739
from monai.engines import SupervisedEvaluator, SupervisedTrainer
3840
from monai.handlers import (
@@ -167,6 +169,7 @@ def __init__(
167169
self._find_unused_parameters = find_unused_parameters
168170
self._load_strict = load_strict
169171
self._labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
172+
self._disable_tracking = kwargs.get("disable_tracking", True)
170173

171174
@abstractmethod
172175
def network(self, context: Context):
@@ -455,7 +458,16 @@ def train(self, rank, world_size, request, datalist):
455458

456459
# Finalize and Run Training
457460
self.finalize(context)
458-
context.trainer.run()
461+
462+
# Disable Tracking
463+
meta_tracking = get_track_meta()
464+
if self._disable_tracking:
465+
set_track_meta(False)
466+
467+
try:
468+
context.trainer.run()
469+
finally:
470+
set_track_meta(meta_tracking) # In case of same process (restore)
459471

460472
if context.multi_gpu:
461473
torch.distributed.destroy_process_group()

0 commit comments

Comments
 (0)