diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index cdc420045109e..4437ff81a44a2 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed non-zero process exits in `CombinedLoader.reset()` with large tensors and persistent spawned workers by avoiding explicit `_shutdown_workers()` calls and relying on iterator cleanup via `del` [#21708](https://github.com/Lightning-AI/pytorch-lightning/issues/21708) + - Fixed `SIGTERMException` producing a zero exit code instead of 143 (128 + SIGTERM) ([#21623](https://github.com/Lightning-AI/pytorch-lightning/issues/21623)) --- diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 9c89c998aa913..06425047bffba 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -397,7 +397,7 @@ def _load_state_dicts(self, states: list[dict[str, Any]]) -> None: def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: if hasattr(dataloader, "_iterator"): if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): - dataloader._iterator._shutdown_workers() + del dataloader._iterator dataloader._iterator = None diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 384ae2b47859b..f3d333a3f412b 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -1111,9 +1111,9 @@ def __init__(self, *args, dataloader, **kwargs): super().__init__(*args, **kwargs) self.dataloader = dataloader - def _shutdown_workers(self): + def __del__(self): self.dataloader.shutdown_workers_epochs.append(trainer.current_epoch) - super()._shutdown_workers() + super().__del__() class TestDataLoader(DataLoader): def __init__(self, *args, **kwargs): @@ -1137,8 +1137,8 @@ def _get_iterator(self): trainer.fit(model, train_dataloader, val_dataloader) if persistent_workers: - # workers get created and persist until the teardown in the final epoch - expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown + # workers persist across epochs and are shut down exactly once via __del__. + expected = [trainer.current_epoch] elif should_fail: expected = [ # <-- iter() on epoch 0, workers get created @@ -1155,8 +1155,7 @@ def _get_iterator(self): assert train_dataloader.shutdown_workers_epochs == expected if persistent_workers: - # workers get created and persist until the teardown in the final epoch - expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown + expected = [trainer.current_epoch] elif should_fail: expected = [ # <-- iter() on sanity check, workers get created diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index da168be1e3e8a..acf890919c022 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -657,3 +657,27 @@ def test_load_state_dicts(): cl._load_state_dicts([state1, state2]) stateful1.load_state_dict.assert_called_with(state1) stateful2.load_state_dict.assert_called_with(state2) + + +def test_combined_loader_reset_uses_del_not_shutdown_workers(): + """Test that `combined_loader.reset()` uses `del` to reset the dataloader iterator instead of calling + `_shutdown_workers()` explicitly. + + This is a regression test for https://github.com/Lightning-AI/pytorch-lightning/issues/21703 + + """ + from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter + + dataloader = DataLoader(range(10), num_workers=2, persistent_workers=True, multiprocessing_context="spawn") + combined_loader = CombinedLoader([dataloader]) + + mock_iterator = Mock(spec=_MultiProcessingDataLoaderIter) + mock_iterator._shutdown_workers = Mock() + dataloader._iterator = mock_iterator + + iterator_ref = dataloader._iterator + + combined_loader.reset() + + iterator_ref._shutdown_workers.assert_not_called() + assert dataloader._iterator is None