Skip to content

Commit 87290f5

Browse files
authored
manager: expose participating_rank (#94)
1 parent 6e5dcbd commit 87290f5

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

torchft/manager.py

+20
Original file line number | Diff line number | Diff line change
@@ -652,15 +652,35 @@ def batches_committed(self) -> int:
652652
"""
653653
return self._batches_committed
654654

655+
def participating_rank(self) -> Optional[int]:
    """
    Return this replica group's rank within the current quorum.

    Every rank inside the replica group observes the same value. ``None``
    is returned when the replica group is not participating in the
    current quorum.

    The call blocks on the async quorum until it is ready.

    Returns:
        the replica group's rank in the current quorum, or None
    """
    # Wait for the quorum first, then read the rank.
    self.wait_quorum()
    rank = self._participating_rank
    return rank
670+
655671
def num_participants(self) -> int:
    """
    Return how many replicas participate in the current quorum.

    This equals the number of replicas taking part in the current step.

    The call blocks on the async quorum until it is ready.

    Returns:
        the participant count of the current quorum
    """
    # Wait for the quorum first, then read the world size.
    self.wait_quorum()

    world_size = self._participating_world_size
    # A negative size is treated as an internal error.
    assert world_size >= 0, "internal error"
    return world_size
666686

torchft/manager_test.py

+8-1
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import concurrent.futures
78
from datetime import timedelta
89
from typing import Optional
910
from unittest import TestCase
@@ -521,10 +522,16 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
521522
def test_manager_numerics(self, client_mock: MagicMock) -> None:
522523
manager = self._create_manager()
523524

524-
manager._quorum_future = MagicMock()
525+
manager._quorum_future = quorum_future = MagicMock(
526+
spec=concurrent.futures.Future
527+
)
525528
manager._participating_rank = 1
526529
manager._participating_world_size = 5
527530
self.assertEqual(manager.num_participants(), 5)
531+
self.assertEqual(quorum_future.result.call_count, 1)
532+
self.assertEqual(manager.participating_rank(), 1)
533+
self.assertEqual(quorum_future.result.call_count, 2)
534+
528535
# pyre-ignore[16]: _pg is mocked
529536
manager._pg.allreduce.return_value = _DummyWork(None)
530537

0 commit comments

Comments
 (0)