Skip to content

Commit 29eccf6

Browse files
committed
Force SIGTERM broadcast at epoch/validation boundaries
When broadcast_sigterm_every_n_steps > 1, SIGTERM could arrive between broadcasts near the end of an epoch. Without a forced check, rank 0 would exit while other ranks hang waiting at the next collective (e.g. validation barrier). This adds a forced broadcast whenever the epoch ends or validation is about to start.
1 parent b2ea3e2 commit 29eccf6

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,18 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
397397
should_check_val = False
398398
self._skip_next_val = False
399399

400+
# Force a SIGTERM broadcast at major boundaries (validation, epoch end)
401+
# to prevent hanging ranks when broadcast_sigterm_every_n_steps > 1.
402+
if (
403+
torch.distributed.is_available()
404+
and torch.distributed.is_initialized()
405+
and self.trainer.world_size > 1
406+
and self._sigterm_broadcast_step > 0
407+
and (should_check_val or data_fetcher.done)
408+
):
409+
self._sigterm_broadcast_step = 0
410+
self._broadcast_sigterm_tensor()
411+
400412
if should_check_val:
401413
# this needs to be set so the correct `trainer._active_loop` is picked
402414
self.trainer.validating = True

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,32 @@ def test_broadcast_sigterm_interval(n_steps):
261261

262262
assert broadcast_call_count == total_steps // n_steps
263263
assert epoch_loop._sigterm_broadcast_step == total_steps % n_steps
264+
265+
266+
def test_broadcast_sigterm_forced_at_epoch_boundary():
267+
"""Test that a SIGTERM broadcast is forced at epoch end even if the interval hasn't been reached.
268+
269+
This prevents hanging ranks when broadcast_sigterm_every_n_steps > 1 and SIGTERM
270+
arrives between broadcasts near the end of an epoch.
271+
"""
272+
trainer = Trainer(broadcast_sigterm_every_n_steps=100)
273+
epoch_loop = trainer.fit_loop.epoch_loop
274+
275+
# Simulate 5 steps taken (well below interval of 100)
276+
epoch_loop._sigterm_broadcast_step = 5
277+
278+
mock_fetcher = Mock()
279+
mock_fetcher.done = True # epoch is ending
280+
281+
with patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast, patch.object(
282+
epoch_loop, "_should_check_val_fx", return_value=False
283+
), patch.object(epoch_loop, "_should_accumulate", return_value=False), patch.object(
284+
epoch_loop, "_save_loggers_on_train_batch_end"
285+
), patch(
286+
"torch.distributed.is_available", return_value=True
287+
), patch("torch.distributed.is_initialized", return_value=True):
288+
trainer._accelerator_connector._devices_flag = 2
289+
epoch_loop.on_advance_end(mock_fetcher)
290+
291+
mock_broadcast.assert_called_once()
292+
assert epoch_loop._sigterm_broadcast_step == 0

0 commit comments

Comments
 (0)