Skip to content

Commit 866873a

Browse files
authored
process_group/ManagedProcessGroup: ensure quorum and PG is configured before operations (#83)
1 parent ccf74d4 commit 866873a

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

torchft/process_group.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import logging
2020
import queue
2121
import threading
22-
from abc import ABC
2322
from datetime import timedelta
2423
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
2524

@@ -507,6 +506,10 @@ def __init__(self, manager: "Manager") -> None:
507506
self._manager = manager
508507

509508
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
509+
# Ensure we have a valid quorum and are configured before trying to do
510+
# any work.
511+
self._manager.wait_quorum()
512+
510513
if self._manager.errored() is not None:
511514
return _DummyWork(tensors)
512515

torchft/process_group_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def test_managed_process_group(self) -> None:
368368

369369
self.assertEqual(manager.report_error.call_count, 0)
370370
self.assertEqual(manager.wrap_future.call_count, 1)
371+
self.assertEqual(manager.wait_quorum.call_count, 1)
371372

372373

373374
class DeviceMeshTest(TestCase):

0 commit comments

Comments
 (0)