Skip to content

Commit

Permalink
process_group/ManagedProcessGroup: ensure quorum and PG is configured…
Browse files Browse the repository at this point in the history
… before operations (#83)
  • Loading branch information
d4l3k authored Jan 28, 2025
1 parent ccf74d4 commit 866873a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
import queue
import threading
from abc import ABC
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -507,6 +506,10 @@ def __init__(self, manager: "Manager") -> None:
self._manager = manager

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
# Ensure we have a valid quorum and are configured before trying to do
# any work.
self._manager.wait_quorum()

if self._manager.errored() is not None:
return _DummyWork(tensors)

Expand Down
1 change: 1 addition & 0 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def test_managed_process_group(self) -> None:

self.assertEqual(manager.report_error.call_count, 0)
self.assertEqual(manager.wrap_future.call_count, 1)
self.assertEqual(manager.wait_quorum.call_count, 1)


class DeviceMeshTest(TestCase):
Expand Down

0 comments on commit 866873a

Please sign in to comment.