|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import contextlib |
14 | 15 | import logging |
15 | 16 | from unittest.mock import Mock, patch |
16 | 17 |
|
@@ -264,10 +265,8 @@ def test_broadcast_sigterm_interval(n_steps): |
264 | 265 | # before it tries to fetch a batch and run training. |
265 | 266 | mock_fetcher = Mock() |
266 | 267 | mock_fetcher.__next__ = Mock(side_effect=StopIteration) |
267 | | - try: |
| 268 | + with contextlib.suppress(StopIteration, TypeError, AttributeError): |
268 | 269 | epoch_loop.advance(mock_fetcher) |
269 | | - except (StopIteration, TypeError, AttributeError): |
270 | | - pass |
271 | 270 |
|
272 | 271 | assert mock_broadcast.call_count == total_steps // n_steps |
273 | 272 | assert epoch_loop._sigterm_broadcast_step == total_steps % n_steps |
@@ -308,8 +307,9 @@ def test_broadcast_sigterm_forced_at_epoch_boundary(): |
308 | 307 | def test_broadcast_sigterm_interval_ddp(tmp_path): |
309 | 308 | """Test that broadcast_sigterm_every_n_steps controls broadcast frequency in real DDP training. |
310 | 309 |
|
311 | | - Uses ddp_spawn to exercise real torch.distributed broadcast paths (lines 300-304, 408-410). |
312 | | - After training, _sigterm_broadcast_step should be 0 because the epoch-end forced broadcast resets it. |
| 310 | + Uses ddp_spawn to exercise real torch.distributed broadcast paths (lines 300-304, 408-410). After training, |
| 311 | + _sigterm_broadcast_step should be 0 because the epoch-end forced broadcast resets it. |
| 312 | +
|
313 | 313 | """ |
314 | 314 | n_steps = 5 |
315 | 315 | limit_train_batches = 7 # 7 % 5 = 2 remaining steps, triggers epoch-end forced broadcast |
|
0 commit comments