Skip to content

Commit 1bd737a

Browse files
fix: use del instead of explicit shutdown in CombinedLoader.reset() (#21708)
* fix: use del instead of explicit shutdown in CombinedLoader.reset() * Update src/lightning/pytorch/CHANGELOG.md --------- Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent 932b7e3 commit 1bd737a

4 files changed

Lines changed: 32 additions & 7 deletions

File tree

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424

2525
### Fixed
2626

27+
- Fixed non-zero process exits in `CombinedLoader.reset()` with large tensors and persistent spawned workers by avoiding explicit `_shutdown_workers()` calls and relying on iterator cleanup via `del` ([#21708](https://github.com/Lightning-AI/pytorch-lightning/issues/21708))
28+
2729
- Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623))
2830

2931
---

src/lightning/pytorch/utilities/combined_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def _load_state_dicts(self, states: list[dict[str, Any]]) -> None:
397397
def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
398398
if hasattr(dataloader, "_iterator"):
399399
if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
400-
dataloader._iterator._shutdown_workers()
400+
del dataloader._iterator
401401
dataloader._iterator = None
402402

403403

tests/tests_pytorch/loops/test_loops.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,9 +1111,9 @@ def __init__(self, *args, dataloader, **kwargs):
11111111
super().__init__(*args, **kwargs)
11121112
self.dataloader = dataloader
11131113

1114-
def _shutdown_workers(self):
1114+
def __del__(self):
11151115
self.dataloader.shutdown_workers_epochs.append(trainer.current_epoch)
1116-
super()._shutdown_workers()
1116+
super().__del__()
11171117

11181118
class TestDataLoader(DataLoader):
11191119
def __init__(self, *args, **kwargs):
@@ -1137,8 +1137,8 @@ def _get_iterator(self):
11371137
trainer.fit(model, train_dataloader, val_dataloader)
11381138

11391139
if persistent_workers:
1140-
# workers get created and persist until the teardown in the final epoch
1141-
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
1140+
# workers persist across epochs and are shut down exactly once via __del__.
1141+
expected = [trainer.current_epoch]
11421142
elif should_fail:
11431143
expected = [
11441144
# <-- iter() on epoch 0, workers get created
@@ -1155,8 +1155,7 @@ def _get_iterator(self):
11551155
assert train_dataloader.shutdown_workers_epochs == expected
11561156

11571157
if persistent_workers:
1158-
# workers get created and persist until the teardown in the final epoch
1159-
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
1158+
expected = [trainer.current_epoch]
11601159
elif should_fail:
11611160
expected = [
11621161
# <-- iter() on sanity check, workers get created

tests/tests_pytorch/utilities/test_combined_loader.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,27 @@ def test_load_state_dicts():
657657
cl._load_state_dicts([state1, state2])
658658
stateful1.load_state_dict.assert_called_with(state1)
659659
stateful2.load_state_dict.assert_called_with(state2)
660+
661+
662+
def test_combined_loader_reset_uses_del_not_shutdown_workers():
663+
"""Test that `combined_loader.reset()` uses `del` to reset the dataloader iterator instead of calling
664+
`_shutdown_workers()` explicitly.
665+
666+
This is a regression test for https://github.com/Lightning-AI/pytorch-lightning/issues/21703
667+
668+
"""
669+
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
670+
671+
dataloader = DataLoader(range(10), num_workers=2, persistent_workers=True, multiprocessing_context="spawn")
672+
combined_loader = CombinedLoader([dataloader])
673+
674+
mock_iterator = Mock(spec=_MultiProcessingDataLoaderIter)
675+
mock_iterator._shutdown_workers = Mock()
676+
dataloader._iterator = mock_iterator
677+
678+
iterator_ref = dataloader._iterator
679+
680+
combined_loader.reset()
681+
682+
iterator_ref._shutdown_workers.assert_not_called()
683+
assert dataloader._iterator is None

0 commit comments

Comments
 (0)