Skip to content

Commit ec43633

Browse files
committed
ProcessGroupBaby: support full suite of PG tests
1 parent 4bdb8a7 commit ec43633

File tree

2 files changed

+169
-28
lines changed

2 files changed

+169
-28
lines changed

torchft/process_group.py

+131-9
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,21 @@
1919
import logging
2020
import queue
2121
import threading
22+
from dataclasses import dataclass
2223
from datetime import timedelta
23-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
24+
from typing import (
25+
TYPE_CHECKING,
26+
Any,
27+
Callable,
28+
Dict,
29+
List,
30+
Optional,
31+
Tuple,
32+
Type,
33+
TypeVar,
34+
Union,
35+
cast,
36+
)
2437

2538
import torch
2639
import torch.distributed as dist
@@ -29,7 +42,6 @@
2942
# pyre-fixme[21]: no attribute ProcessGroupNCCL
3043
# pyre-fixme[21]: no attribute ProcessGroupGloo
3144
from torch.distributed import (
32-
BroadcastOptions,
3345
DeviceMesh,
3446
PrefixStore,
3547
ProcessGroup as BaseProcessGroup,
@@ -40,7 +52,14 @@
4052
get_rank,
4153
init_device_mesh,
4254
)
43-
from torch.distributed.distributed_c10d import Work, _world
55+
from torch.distributed.distributed_c10d import (
56+
AllgatherOptions,
57+
AllreduceOptions,
58+
BroadcastOptions,
59+
ReduceOp,
60+
Work,
61+
_world,
62+
)
4463
from torch.futures import Future
4564

4665
if TYPE_CHECKING:
@@ -54,6 +73,9 @@
5473
_FUTURE_EXCEPTION = "fut_exception"
5574

5675

76+
T = TypeVar("T")
77+
78+
5779
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
5880
"""
5981
Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +144,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
122144
raise NotImplementedError("not implemented")
123145

124146
# pyre-fixme[14]: inconsistent override
125-
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
147+
def allreduce(
148+
self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
149+
) -> Work:
126150
raise NotImplementedError("not implemented")
127151

128152
# pyre-fixme[14]: inconsistent override
129153
def allgather(
130154
self,
131155
output_tensors: List[List[torch.Tensor]],
132156
input_tensor: List[torch.Tensor],
133-
opts: object,
157+
opts: AllgatherOptions,
134158
) -> Work:
135159
"""
136160
Gathers tensors from the whole group in a list.
@@ -140,7 +164,9 @@ def allgather(
140164
raise NotImplementedError("not implemented")
141165

142166
# pyre-fixme[14]: inconsistent override
143-
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
167+
def broadcast(
168+
self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
169+
) -> Work:
144170
"""
145171
Broadcasts the tensor to the whole group.
146172
@@ -567,6 +593,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
567593
def get_future(self) -> Future[object]:
568594
return self._pg._get_future(self._op_id)
569595

596+
def __del__(self) -> None:
597+
self._tx.put(("del", self._op_id), timeout=self._timeout)
598+
570599

571600
class _BabyWorkNCCL(_BabyWork):
572601
def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +724,18 @@ def _worker(
695724
cmd = op[0]
696725
if cmd == "func":
697726
func_name, args, kwargs = op[1:]
727+
args = _PickleSafeOptions.unsafe_args(args)
698728
fn = getattr(pg, func_name)
699729
work[next_op_id] = fn(*args, **kwargs)
700730
tx.put(next_op_id)
701731
next_op_id += 1
702732
elif cmd == "wait":
703733
op_id: int = op[1]
704734
work[op_id].wait()
705-
del work[op_id]
706735
tx.put(op_id)
736+
elif cmd == "del":
737+
op_id: int = op[1]
738+
del work[op_id]
707739
elif cmd == "future":
708740
op_id: int = op[1]
709741

@@ -731,6 +763,8 @@ def callback(fut: Future[object]) -> None:
731763

732764
del work[op_id]
733765
tx.put((op_id, event))
766+
elif cmd == "num_active_work":
767+
tx.put(len(work))
734768
else:
735769
raise ValueError(f"unknown cmd: {cmd}")
736770

@@ -775,7 +809,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
775809
assert rx is not None
776810
assert tx is not None
777811

778-
tx.put(("func", func, args, kwargs), timeout=self._timeout)
812+
tx.put(
813+
("func", func, _PickleSafeOptions.safe_args(args), kwargs),
814+
timeout=self._timeout,
815+
)
779816

780817
op_id = _get(rx, self._timeout)
781818
assert isinstance(op_id, int), f"invalid return {op_id}"
@@ -784,7 +821,11 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
784821
pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
785822
)
786823

787-
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
824+
def allreduce(
825+
self,
826+
tensors: List[torch.Tensor],
827+
opts: Union[dist.AllreduceOptions, dist.ReduceOp],
828+
) -> Work:
788829
assert isinstance(tensors, list), "input must be list"
789830

790831
for tensor in tensors:
@@ -793,9 +834,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
793834

794835
return self._run_func("allreduce", tensors, opts)
795836

837+
def allgather(
838+
self,
839+
output_tensors: List[List[torch.Tensor]],
840+
input_tensor: List[torch.Tensor],
841+
opts: AllgatherOptions,
842+
) -> Work:
843+
assert isinstance(output_tensors, list), "input must be list"
844+
assert isinstance(input_tensor, list), "input must be list"
845+
846+
for tensor_list in output_tensors:
847+
for tensor in tensor_list:
848+
if not tensor.is_shared():
849+
tensor.share_memory_()
850+
851+
for tensor in input_tensor:
852+
if not tensor.is_shared():
853+
tensor.share_memory_()
854+
855+
return self._run_func("allgather", output_tensors, input_tensor, opts)
856+
857+
def broadcast(
858+
self,
859+
tensor_list: List[torch.Tensor],
860+
opts: BroadcastOptions,
861+
) -> Work:
862+
assert isinstance(tensor_list, list), "input must be list"
863+
864+
for tensor in tensor_list:
865+
if not tensor.is_shared():
866+
tensor.share_memory_()
867+
868+
return self._run_func("broadcast", tensor_list, opts)
869+
796870
def size(self) -> int:
797871
return self._world_size
798872

873+
def num_active_work(self) -> int:
874+
assert self._tx is not None
875+
self._tx.put(("num_active_work",), timeout=self._timeout)
876+
877+
assert self._rx is not None
878+
return cast(int, _get(self._rx, self._timeout))
879+
880+
881+
@dataclass
882+
class _PickleSafeOptions:
883+
func: Callable[[], object]
884+
fields: Dict[str, object]
885+
886+
@classmethod
887+
def safe_args(cls, args: T) -> T:
888+
if isinstance(args, tuple):
889+
return tuple(cls.safe_args(arg) for arg in args)
890+
elif isinstance(args, list):
891+
return [cls.safe_args(arg) for arg in args]
892+
elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
893+
return cls.from_torch(args)
894+
else:
895+
return args
896+
897+
@classmethod
898+
def unsafe_args(cls, args: T) -> T:
899+
if isinstance(args, tuple):
900+
return tuple(cls.unsafe_args(arg) for arg in args)
901+
elif isinstance(args, list):
902+
return [cls.unsafe_args(arg) for arg in args]
903+
elif isinstance(args, cls):
904+
return args.to_torch()
905+
else:
906+
return args
907+
908+
@classmethod
909+
def from_torch(cls, opts: object) -> "_PickleSafeOptions":
910+
return cls(
911+
func=opts.__class__,
912+
fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
913+
)
914+
915+
def to_torch(self) -> object:
916+
opts = self.func()
917+
for k, v in self.fields.items():
918+
setattr(opts, k, v)
919+
return opts
920+
799921

800922
class ProcessGroupBabyGloo(ProcessGroupBaby):
801923
"""

torchft/process_group_test.py

+38-19
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,43 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import gc
78
import multiprocessing
89
import os
910
import unittest
1011
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
1112
from datetime import timedelta
12-
from typing import Any, Dict, Tuple, cast
13-
from unittest import TestCase, skipUnless
13+
from typing import Any, cast, Dict, Tuple
14+
from unittest import skipUnless, TestCase
1415
from unittest.mock import Mock
1516

1617
import torch
1718
import torch.distributed as dist
1819
from torch import nn
1920
from torch._C._distributed_c10d import (
21+
_resolve_process_group,
2022
AllgatherOptions,
2123
AllreduceOptions,
2224
BroadcastOptions,
2325
ReduceOp,
24-
_resolve_process_group,
2526
)
2627
from torch.distributed import (
28+
_functional_collectives,
29+
get_world_size,
2730
ReduceOp,
2831
TCPStore,
2932
Work,
30-
_functional_collectives,
31-
get_world_size,
3233
)
3334
from torch.distributed.device_mesh import init_device_mesh
3435

3536
from torchft.manager import Manager
3637
from torchft.process_group import (
38+
_DummyWork,
39+
_ErrorSwallowingWork,
40+
_ManagedWork,
3741
ErrorSwallowingProcessGroupWrapper,
42+
extend_device_mesh,
43+
ft_init_device_mesh,
3844
ManagedProcessGroup,
3945
ProcessGroup,
4046
ProcessGroupBabyGloo,
@@ -43,11 +49,6 @@
4349
ProcessGroupGloo,
4450
ProcessGroupNCCL,
4551
ProcessGroupWrapper,
46-
_DummyWork,
47-
_ErrorSwallowingWork,
48-
_ManagedWork,
49-
extend_device_mesh,
50-
ft_init_device_mesh,
5152
)
5253

5354

@@ -86,14 +87,15 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
8687
check_tensors(item)
8788

8889
# Test collectives
89-
collectives = {
90-
"allreduce": ([input_tensor], AllreduceOptions()),
91-
"allgather": (output_tensors, [input_tensor], AllgatherOptions()),
92-
"broadcast": (tensor_list, BroadcastOptions()),
93-
"broadcast_one": (input_tensor, 0),
94-
}
90+
collectives = [
91+
("allreduce", ([input_tensor], AllreduceOptions())),
92+
("allreduce", ([input_tensor], ReduceOp.SUM)),
93+
("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
94+
("broadcast", (tensor_list, BroadcastOptions())),
95+
("broadcast_one", (input_tensor, 0)),
96+
]
9597
works: Dict[str, dist._Work] = {}
96-
for coll_str, args in collectives.items():
98+
for coll_str, args in collectives:
9799
coll = getattr(pg, coll_str)
98100
work = coll(*args)
99101
works[coll_str] = work
@@ -246,6 +248,23 @@ def test_reconfigure_baby_process_group(self) -> None:
246248
assert p_2 is not None
247249
self.assertTrue(p_2.is_alive())
248250

251+
def test_baby_gloo_apis(self) -> None:
252+
store = TCPStore(
253+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
254+
)
255+
256+
store_addr = f"localhost:{store.port}/prefix"
257+
258+
a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
259+
a.configure(store_addr, 0, 1)
260+
261+
_test_pg(a)
262+
263+
# force collection to ensure no BabyWork objects remain
264+
gc.collect()
265+
266+
self.assertEqual(a.num_active_work(), 0)
267+
249268
def test_dummy(self) -> None:
250269
pg = ProcessGroupDummy(0, 1)
251270
m = nn.Linear(3, 4)
@@ -367,8 +386,8 @@ def test_managed_process_group(self) -> None:
367386
self.assertIsInstance(list(works.values())[0], _ManagedWork)
368387

369388
self.assertEqual(manager.report_error.call_count, 0)
370-
self.assertEqual(manager.wrap_future.call_count, 1)
371-
self.assertEqual(manager.wait_quorum.call_count, 1)
389+
self.assertEqual(manager.wrap_future.call_count, 2)
390+
self.assertEqual(manager.wait_quorum.call_count, 2)
372391

373392

374393
class DeviceMeshTest(TestCase):

0 commit comments

Comments
 (0)