Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/lightning/pytorch/loops/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
self._batches_that_stepped: int = 0
self._restart_stage = RestartStage.NONE
self._skip_next_val = False
self._sigterm_broadcast_step: int = 0

@property
def total_batch_idx(self) -> int:
Expand Down Expand Up @@ -297,7 +298,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
# =====================================================================

if torch.distributed.is_available() and torch.distributed.is_initialized() and self.trainer.world_size > 1:
self._broadcast_sigterm_tensor()
self._sigterm_broadcast_step += 1
if self._sigterm_broadcast_step >= self.trainer.broadcast_sigterm_every_n_steps:
self._sigterm_broadcast_step = 0
self._broadcast_sigterm_tensor()

# =====================================================================

Expand Down Expand Up @@ -393,6 +397,18 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
should_check_val = False
self._skip_next_val = False

# Force a SIGTERM broadcast at major boundaries (validation, epoch end)
# to prevent hanging ranks when broadcast_sigterm_every_n_steps > 1.
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and self.trainer.world_size > 1
and self._sigterm_broadcast_step > 0
and (should_check_val or data_fetcher.done)
):
self._sigterm_broadcast_step = 0
self._broadcast_sigterm_tensor()

if should_check_val:
# this needs to be set so the correct `trainer._active_loop` is picked
self.trainer.validating = True
Expand Down
11 changes: 11 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
sync_batchnorm: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
broadcast_sigterm_every_n_steps: int = 1,
default_root_dir: Optional[_PATH] = None,
enable_autolog_hparams: bool = True,
model_registry: Optional[str] = None,
Expand Down Expand Up @@ -300,6 +301,13 @@ def __init__(
reload_dataloaders_every_n_epochs: Set to a positive integer to reload dataloaders every n epochs.
Default: ``0``.

broadcast_sigterm_every_n_steps: How often (in training steps) to broadcast SIGTERM status across
ranks in distributed training. The default ``1`` broadcasts every step. Higher values reduce
the overhead of the NCCL broadcast at the cost of increased SIGTERM detection latency
(worst case: ``(N-1) * step_time``). This is useful for fast training loops where the
per-step broadcast cost is significant relative to the step time.
Default: ``1``.

default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
Expand Down Expand Up @@ -447,6 +455,9 @@ def __init__(
self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode)

self.accumulate_grad_batches = accumulate_grad_batches
if broadcast_sigterm_every_n_steps < 1:
raise ValueError(f"`broadcast_sigterm_every_n_steps` must be >= 1, got {broadcast_sigterm_every_n_steps}.")
self.broadcast_sigterm_every_n_steps = broadcast_sigterm_every_n_steps

# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
Expand Down
103 changes: 103 additions & 0 deletions tests/tests_pytorch/loops/test_training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from unittest.mock import Mock, patch

Expand All @@ -22,6 +23,7 @@
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.trainer import Trainer
from tests_pytorch.helpers.runif import RunIf


def test_no_val_on_train_epoch_loop_restart(tmp_path):
Expand Down Expand Up @@ -228,3 +230,104 @@ def train_and_resume(dataloader, resume_step, expected_warning):

# Resume mid-epoch, stateful dataloader -> no warning
train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False)


def test_broadcast_sigterm_every_n_steps_validation():
    """Constructing a ``Trainer`` with a non-positive ``broadcast_sigterm_every_n_steps`` raises ``ValueError``."""
    expected_message = "broadcast_sigterm_every_n_steps` must be >= 1"
    # Same order as before: zero first, then a negative value.
    for invalid_interval in (0, -1):
        with pytest.raises(ValueError, match=expected_message):
            Trainer(broadcast_sigterm_every_n_steps=invalid_interval)


def test_broadcast_sigterm_every_n_steps_default():
    """A ``Trainer`` built without arguments exposes ``broadcast_sigterm_every_n_steps == 1``."""
    default_interval = Trainer().broadcast_sigterm_every_n_steps
    assert default_interval == 1


@pytest.mark.parametrize("n_steps", [1, 5, 10])
def test_broadcast_sigterm_interval(n_steps):
    """Check how often ``advance()`` triggers ``_broadcast_sigterm_tensor`` for a given interval.

    ``advance()`` is invoked 20 times with distributed availability, initialization and a world size of 2 all
    patched in. The broadcast mock must have been called ``20 // n_steps`` times and the loop's
    ``_sigterm_broadcast_step`` counter must be left at ``20 % n_steps``.

    """
    total_steps = 20
    expected_broadcasts, expected_leftover = divmod(total_steps, n_steps)

    trainer = Trainer(broadcast_sigterm_every_n_steps=n_steps)
    epoch_loop = trainer.fit_loop.epoch_loop

    def make_exhausted_fetcher():
        # A fetcher whose ``__next__`` raises StopIteration, so ``advance()`` bails out
        # once it tries to pull a batch instead of running a real training step.
        fetcher = Mock()
        fetcher.__next__ = Mock(side_effect=StopIteration)
        return fetcher

    with (
        patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast,
        patch("torch.distributed.is_available", return_value=True),
        patch("torch.distributed.is_initialized", return_value=True),
        patch.object(type(trainer), "world_size", new=property(lambda self: 2)),
    ):
        for _ in range(total_steps):
            # Each call gets a fresh fetcher; whatever error ends ``advance()`` early is ignored.
            with contextlib.suppress(StopIteration, TypeError, AttributeError):
                epoch_loop.advance(make_exhausted_fetcher())

    assert mock_broadcast.call_count == expected_broadcasts
    assert epoch_loop._sigterm_broadcast_step == expected_leftover


def test_broadcast_sigterm_forced_at_epoch_boundary():
    """``on_advance_end`` must broadcast SIGTERM status at epoch end even before the interval elapses.

    With ``broadcast_sigterm_every_n_steps=100`` and only 5 steps counted, a fetcher reporting ``done=True`` is
    expected to trigger exactly one broadcast and reset the step counter to 0. The intent is to avoid ranks
    hanging when SIGTERM arrives between two periodic broadcasts close to the end of an epoch.

    """
    trainer = Trainer(broadcast_sigterm_every_n_steps=100)
    epoch_loop = trainer.fit_loop.epoch_loop

    # The data fetcher signals that the epoch is finishing.
    finished_fetcher = Mock(done=True)

    # Pretend 5 steps have passed since the last broadcast, far below the interval of 100.
    epoch_loop._sigterm_broadcast_step = 5

    with (
        patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast,
        patch.object(epoch_loop, "_should_check_val_fx", return_value=False),
        patch.object(epoch_loop, "_should_accumulate", return_value=False),
        patch.object(epoch_loop, "_save_loggers_on_train_batch_end"),
        patch("torch.distributed.is_available", return_value=True),
        patch("torch.distributed.is_initialized", return_value=True),
        patch.object(type(trainer), "world_size", new=property(lambda self: 2)),
    ):
        epoch_loop.on_advance_end(finished_fetcher)

    mock_broadcast.assert_called_once()
    assert epoch_loop._sigterm_broadcast_step == 0


@RunIf(min_cuda_gpus=2)
def test_broadcast_sigterm_interval_ddp(tmp_path):
    """Smoke-test ``broadcast_sigterm_every_n_steps`` with a real two-GPU ``ddp_spawn`` fit.

    The run uses 7 training batches with a broadcast interval of 5, so 2 steps remain after the last periodic
    broadcast. This is meant to exercise the forced broadcast at the end of the epoch. The test makes no explicit
    assertion: it passes when ``fit`` returns instead of hanging.

    """
    broadcast_interval = 5
    num_batches = 7  # 7 % 5 == 2 steps are left over when the epoch ends

    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "max_epochs": 1,
        "limit_train_batches": num_batches,
        "accelerator": "gpu",
        "devices": 2,
        "strategy": "ddp_spawn",
        "broadcast_sigterm_every_n_steps": broadcast_interval,
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "enable_checkpointing": False,
        "logger": False,
    }
    trainer = Trainer(**trainer_kwargs)

    # Completing without a hang is the success criterion, even though the number of
    # batches is not a multiple of the broadcast interval.
    trainer.fit(BoringModel())
Loading