Skip to content

Commit 491e4b6

Browse files
committed
Add broadcast_sigterm_every_n_steps Trainer parameter
Allow users to control how often SIGTERM status is broadcast across ranks, reducing CPU-GPU sync overhead for fast training loops while preserving the default every-step behavior. Closes #21487
1 parent bb7820f commit 491e4b6

3 files changed

Lines changed: 51 additions & 1 deletion

File tree

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
9494
self._batches_that_stepped: int = 0
9595
self._restart_stage = RestartStage.NONE
9696
self._skip_next_val = False
97+
self._sigterm_broadcast_step: int = 0
9798

9899
@property
99100
def total_batch_idx(self) -> int:
@@ -297,7 +298,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
297298
# =====================================================================
298299

299300
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.trainer.world_size > 1:
300-
self._broadcast_sigterm_tensor()
301+
self._sigterm_broadcast_step += 1
302+
if self._sigterm_broadcast_step >= self.trainer.broadcast_sigterm_every_n_steps:
303+
self._sigterm_broadcast_step = 0
304+
self._broadcast_sigterm_tensor()
301305

302306
# =====================================================================
303307

src/lightning/pytorch/trainer/trainer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def __init__(
129129
plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None,
130130
sync_batchnorm: bool = False,
131131
reload_dataloaders_every_n_epochs: int = 0,
132+
broadcast_sigterm_every_n_steps: int = 1,
132133
default_root_dir: Optional[_PATH] = None,
133134
enable_autolog_hparams: bool = True,
134135
model_registry: Optional[str] = None,
@@ -300,6 +301,13 @@ def __init__(
300301
reload_dataloaders_every_n_epochs: Set to a positive integer to reload dataloaders every n epochs.
301302
Default: ``0``.
302303
304+
broadcast_sigterm_every_n_steps: How often (in training steps) to broadcast SIGTERM status across
305+
ranks in distributed training. The default ``1`` broadcasts every step. Higher values reduce
306+
the overhead of the NCCL broadcast at the cost of increased SIGTERM detection latency
307+
(worst case: ``(N-1) * step_time``). This is useful for fast training loops where the
308+
per-step broadcast cost is significant relative to the step time.
309+
Default: ``1``.
310+
303311
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
304312
Default: ``os.getcwd()``.
305313
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -447,6 +455,11 @@ def __init__(
447455
self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode)
448456

449457
self.accumulate_grad_batches = accumulate_grad_batches
458+
if broadcast_sigterm_every_n_steps < 1:
459+
raise ValueError(
460+
f"`broadcast_sigterm_every_n_steps` must be >= 1, got {broadcast_sigterm_every_n_steps}."
461+
)
462+
self.broadcast_sigterm_every_n_steps = broadcast_sigterm_every_n_steps
450463

451464
# init callbacks
452465
# Declare attributes to be set in _callback_connector on_trainer_init

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,36 @@ def train_and_resume(dataloader, resume_step, expected_warning):
228228

229229
# Resume mid-epoch, stateful dataloader -> no warning
230230
train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False)
231+
232+
233+
def test_broadcast_sigterm_every_n_steps_validation():
234+
"""Test that invalid values for broadcast_sigterm_every_n_steps are rejected."""
235+
with pytest.raises(ValueError, match="broadcast_sigterm_every_n_steps` must be >= 1"):
236+
Trainer(broadcast_sigterm_every_n_steps=0)
237+
with pytest.raises(ValueError, match="broadcast_sigterm_every_n_steps` must be >= 1"):
238+
Trainer(broadcast_sigterm_every_n_steps=-1)
239+
240+
241+
def test_broadcast_sigterm_every_n_steps_default():
242+
"""Test that the default value broadcasts every step."""
243+
trainer = Trainer()
244+
assert trainer.broadcast_sigterm_every_n_steps == 1
245+
246+
247+
@pytest.mark.parametrize("n_steps", [1, 5, 10])
248+
def test_broadcast_sigterm_interval(n_steps):
249+
"""Test that _broadcast_sigterm_tensor is called at the correct interval."""
250+
trainer = Trainer(broadcast_sigterm_every_n_steps=n_steps)
251+
epoch_loop = trainer.fit_loop.epoch_loop
252+
253+
total_steps = 20
254+
broadcast_call_count = 0
255+
256+
for _ in range(total_steps):
257+
epoch_loop._sigterm_broadcast_step += 1
258+
if epoch_loop._sigterm_broadcast_step >= trainer.broadcast_sigterm_every_n_steps:
259+
epoch_loop._sigterm_broadcast_step = 0
260+
broadcast_call_count += 1
261+
262+
assert broadcast_call_count == total_steps // n_steps
263+
assert epoch_loop._sigterm_broadcast_step == total_steps % n_steps

0 commit comments

Comments
 (0)