@@ -266,8 +266,9 @@ def test_broadcast_sigterm_interval(n_steps):
266266def test_broadcast_sigterm_forced_at_epoch_boundary ():
267267 """Test that a SIGTERM broadcast is forced at epoch end even if the interval hasn't been reached.
268268
269- This prevents hanging ranks when broadcast_sigterm_every_n_steps > 1 and SIGTERM
270- arrives between broadcasts near the end of an epoch.
269+ This prevents hanging ranks when broadcast_sigterm_every_n_steps > 1 and SIGTERM arrives between broadcasts near the
270+ end of an epoch.
271+
271272 """
272273 trainer = Trainer (broadcast_sigterm_every_n_steps = 100 )
273274 epoch_loop = trainer .fit_loop .epoch_loop
@@ -278,13 +279,14 @@ def test_broadcast_sigterm_forced_at_epoch_boundary():
278279 mock_fetcher = Mock ()
279280 mock_fetcher .done = True # epoch is ending
280281
281- with patch .object (epoch_loop , "_broadcast_sigterm_tensor" ) as mock_broadcast , patch .object (
282- epoch_loop , "_should_check_val_fx" , return_value = False
283- ), patch .object (epoch_loop , "_should_accumulate" , return_value = False ), patch .object (
284- epoch_loop , "_save_loggers_on_train_batch_end"
285- ), patch (
286- "torch.distributed.is_available" , return_value = True
287- ), patch ("torch.distributed.is_initialized" , return_value = True ):
282+ with (
283+ patch .object (epoch_loop , "_broadcast_sigterm_tensor" ) as mock_broadcast ,
284+ patch .object (epoch_loop , "_should_check_val_fx" , return_value = False ),
285+ patch .object (epoch_loop , "_should_accumulate" , return_value = False ),
286+ patch .object (epoch_loop , "_save_loggers_on_train_batch_end" ),
287+ patch ("torch.distributed.is_available" , return_value = True ),
288+ patch ("torch.distributed.is_initialized" , return_value = True ),
289+ ):
288290 trainer ._accelerator_connector ._devices_flag = 2
289291 epoch_loop .on_advance_end (mock_fetcher )
290292
0 commit comments