Skip to content

Commit 4af18f5

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 29eccf6 commit 4af18f5

1 file changed

Lines changed: 11 additions & 9 deletions

File tree

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,9 @@ def test_broadcast_sigterm_interval(n_steps):
266266
def test_broadcast_sigterm_forced_at_epoch_boundary():
267267
"""Test that a SIGTERM broadcast is forced at epoch end even if the interval hasn't been reached.
268268
269-
This prevents hanging ranks when broadcast_sigterm_every_n_steps > 1 and SIGTERM
270-
arrives between broadcasts near the end of an epoch.
269+
This prevents hanging ranks when broadcast_sigterm_every_n_steps > 1 and SIGTERM arrives between broadcasts near the
270+
end of an epoch.
271+
271272
"""
272273
trainer = Trainer(broadcast_sigterm_every_n_steps=100)
273274
epoch_loop = trainer.fit_loop.epoch_loop
@@ -278,13 +279,14 @@ def test_broadcast_sigterm_forced_at_epoch_boundary():
278279
mock_fetcher = Mock()
279280
mock_fetcher.done = True # epoch is ending
280281

281-
with patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast, patch.object(
282-
epoch_loop, "_should_check_val_fx", return_value=False
283-
), patch.object(epoch_loop, "_should_accumulate", return_value=False), patch.object(
284-
epoch_loop, "_save_loggers_on_train_batch_end"
285-
), patch(
286-
"torch.distributed.is_available", return_value=True
287-
), patch("torch.distributed.is_initialized", return_value=True):
282+
with (
283+
patch.object(epoch_loop, "_broadcast_sigterm_tensor") as mock_broadcast,
284+
patch.object(epoch_loop, "_should_check_val_fx", return_value=False),
285+
patch.object(epoch_loop, "_should_accumulate", return_value=False),
286+
patch.object(epoch_loop, "_save_loggers_on_train_batch_end"),
287+
patch("torch.distributed.is_available", return_value=True),
288+
patch("torch.distributed.is_initialized", return_value=True),
289+
):
288290
trainer._accelerator_connector._devices_flag = 2
289291
epoch_loop.on_advance_end(mock_fetcher)
290292

0 commit comments

Comments
 (0)