Skip to content

Commit bac8431

Browse files
committed
Improve test coverage for broadcast_sigterm_every_n_steps
- Fix epoch boundary test: mock world_size property instead of _devices_flag, which didn't affect trainer.world_size - Rewrite interval test to call real advance() with mocked distributed instead of reimplementing the logic in the test - Add ddp_spawn integration test exercising real NCCL broadcasts on 2 GPUs with non-aligned step count to trigger epoch-end flush
1 parent e2d7496 commit bac8431

1 file changed

Lines changed: 48 additions & 9 deletions

File tree

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
2323
from lightning.pytorch.demos.boring_classes import BoringModel
2424
from lightning.pytorch.trainer.trainer import Trainer
25+
from tests_pytorch.helpers.runif import RunIf
2526

2627

2728
def test_no_val_on_train_epoch_loop_restart(tmp_path):
@@ -246,20 +247,29 @@ def test_broadcast_sigterm_every_n_steps_default():
246247

247248
@pytest.mark.parametrize("n_steps", [1, 5, 10])
248249
def test_broadcast_sigterm_interval(n_steps):
249-
"""Test that _broadcast_sigterm_tensor is called at the correct interval."""
250+
"""Test that _broadcast_sigterm_tensor is called at the correct interval during advance()."""
250251
trainer = Trainer(broadcast_sigterm_every_n_steps=n_steps)
251252
epoch_loop = trainer.fit_loop.epoch_loop
252253

253254
total_steps = 20
254-
broadcast_call_count = 0
255255

256-
for _ in range(total_steps):
257-
epoch_loop._sigterm_broadcast_step += 1
258-
if epoch_loop._sigterm_broadcast_step >= trainer.broadcast_sigterm_every_n_steps:
259-
epoch_loop._sigterm_broadcast_step = 0
260-
broadcast_call_count += 1
261-
262-
assert broadcast_call_count == total_steps // n_steps
256+
with (
257+
patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast,
258+
patch("torch.distributed.is_available", return_value=True),
259+
patch("torch.distributed.is_initialized", return_value=True),
260+
patch.object(type(trainer), "world_size", new_callable=lambda: property(lambda self: 2)),
261+
):
262+
for _ in range(total_steps):
263+
# Raise StopIteration to exit advance() right after the broadcast check,
264+
# before it tries to fetch a batch and run training.
265+
mock_fetcher = Mock()
266+
mock_fetcher.__next__ = Mock(side_effect=StopIteration)
267+
try:
268+
epoch_loop.advance(mock_fetcher)
269+
except (StopIteration, TypeError, AttributeError):
270+
pass
271+
272+
assert mock_broadcast.call_count == total_steps // n_steps
263273
assert epoch_loop._sigterm_broadcast_step == total_steps % n_steps
264274

265275

@@ -292,3 +302,32 @@ def test_broadcast_sigterm_forced_at_epoch_boundary():
292302

293303
mock_broadcast.assert_called_once()
294304
assert epoch_loop._sigterm_broadcast_step == 0
305+
306+
307+
@RunIf(min_cuda_gpus=2)
308+
def test_broadcast_sigterm_interval_ddp(tmp_path):
309+
"""Test that broadcast_sigterm_every_n_steps controls broadcast frequency in real DDP training.
310+
311+
Uses ddp_spawn to exercise real torch.distributed broadcast paths (lines 300-304, 408-410).
312+
After training, _sigterm_broadcast_step should be 0 because the epoch-end forced broadcast resets it.
313+
"""
314+
n_steps = 5
315+
limit_train_batches = 7 # 7 % 5 = 2 remaining steps, triggers epoch-end forced broadcast
316+
317+
model = BoringModel()
318+
trainer = Trainer(
319+
default_root_dir=tmp_path,
320+
max_epochs=1,
321+
limit_train_batches=limit_train_batches,
322+
accelerator="gpu",
323+
devices=2,
324+
strategy="ddp_spawn",
325+
broadcast_sigterm_every_n_steps=n_steps,
326+
enable_progress_bar=False,
327+
enable_model_summary=False,
328+
enable_checkpointing=False,
329+
logger=False,
330+
)
331+
# Training should complete without hanging — the epoch-end forced broadcast
332+
# ensures all ranks stay in sync even when limit_train_batches is not a multiple of n_steps.
333+
trainer.fit(model)

0 commit comments

Comments
 (0)