|
22 | 22 | from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint |
23 | 23 | from lightning.pytorch.demos.boring_classes import BoringModel |
24 | 24 | from lightning.pytorch.trainer.trainer import Trainer |
| 25 | +from tests_pytorch.helpers.runif import RunIf |
25 | 26 |
|
26 | 27 |
|
27 | 28 | def test_no_val_on_train_epoch_loop_restart(tmp_path): |
@@ -246,20 +247,29 @@ def test_broadcast_sigterm_every_n_steps_default(): |
246 | 247 |
|
247 | 248 | @pytest.mark.parametrize("n_steps", [1, 5, 10]) |
248 | 249 | def test_broadcast_sigterm_interval(n_steps): |
249 | | - """Test that _broadcast_sigterm_tensor is called at the correct interval.""" |
| 250 | + """Test that _broadcast_sigterm_tensor is called at the correct interval during advance().""" |
250 | 251 | trainer = Trainer(broadcast_sigterm_every_n_steps=n_steps) |
251 | 252 | epoch_loop = trainer.fit_loop.epoch_loop |
252 | 253 |
|
253 | 254 | total_steps = 20 |
254 | | - broadcast_call_count = 0 |
255 | 255 |
|
256 | | - for _ in range(total_steps): |
257 | | - epoch_loop._sigterm_broadcast_step += 1 |
258 | | - if epoch_loop._sigterm_broadcast_step >= trainer.broadcast_sigterm_every_n_steps: |
259 | | - epoch_loop._sigterm_broadcast_step = 0 |
260 | | - broadcast_call_count += 1 |
261 | | - |
262 | | - assert broadcast_call_count == total_steps // n_steps |
| 256 | + with ( |
| 257 | + patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast, |
| 258 | + patch("torch.distributed.is_available", return_value=True), |
| 259 | + patch("torch.distributed.is_initialized", return_value=True), |
| 260 | + patch.object(type(trainer), "world_size", new_callable=lambda: property(lambda self: 2)), |
| 261 | + ): |
| 262 | + for _ in range(total_steps): |
| 263 | + # Raise StopIteration to exit advance() right after the broadcast check, |
| 264 | + # before it tries to fetch a batch and run training. |
| 265 | + mock_fetcher = Mock() |
| 266 | + mock_fetcher.__next__ = Mock(side_effect=StopIteration) |
| 267 | + try: |
| 268 | + epoch_loop.advance(mock_fetcher) |
| 269 | + except (StopIteration, TypeError, AttributeError): |
| 270 | + pass |
| 271 | + |
| 272 | + assert mock_broadcast.call_count == total_steps // n_steps |
263 | 273 | assert epoch_loop._sigterm_broadcast_step == total_steps % n_steps |
264 | 274 |
|
265 | 275 |
|
@@ -292,3 +302,32 @@ def test_broadcast_sigterm_forced_at_epoch_boundary(): |
292 | 302 |
|
293 | 303 | mock_broadcast.assert_called_once() |
294 | 304 | assert epoch_loop._sigterm_broadcast_step == 0 |
| 305 | + |
| 306 | + |
| 307 | +@RunIf(min_cuda_gpus=2) |
| 308 | +def test_broadcast_sigterm_interval_ddp(tmp_path): |
| 309 | + """Test that broadcast_sigterm_every_n_steps controls broadcast frequency in real DDP training. |
| 310 | +
|
| 311 | + Uses ddp_spawn to exercise real torch.distributed broadcast paths (lines 300-304, 408-410). |
| 312 | + After training, _sigterm_broadcast_step should be 0 because the epoch-end forced broadcast resets it. |
| 313 | + """ |
| 314 | + n_steps = 5 |
| 315 | + limit_train_batches = 7 # 7 % 5 = 2 remaining steps, triggers epoch-end forced broadcast |
| 316 | + |
| 317 | + model = BoringModel() |
| 318 | + trainer = Trainer( |
| 319 | + default_root_dir=tmp_path, |
| 320 | + max_epochs=1, |
| 321 | + limit_train_batches=limit_train_batches, |
| 322 | + accelerator="gpu", |
| 323 | + devices=2, |
| 324 | + strategy="ddp_spawn", |
| 325 | + broadcast_sigterm_every_n_steps=n_steps, |
| 326 | + enable_progress_bar=False, |
| 327 | + enable_model_summary=False, |
| 328 | + enable_checkpointing=False, |
| 329 | + logger=False, |
| 330 | + ) |
| 331 | + # Training should complete without hanging — the epoch-end forced broadcast |
| 332 | + # ensures all ranks stay in sync even when limit_train_batches is not a multiple of n_steps. |
| 333 | + trainer.fit(model) |
0 commit comments