Commit fa1630d

ProcessGroupBaby: support full suite of PG tests (#89)
1 parent 2a67d66 commit fa1630d

2 files changed: +206 -18 lines

torchft/process_group.py

+131 -9
@@ -19,8 +19,21 @@
 import logging
 import queue
 import threading
+from dataclasses import dataclass
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
 import torch.distributed as dist
@@ -29,7 +42,6 @@
 # pyre-fixme[21]: no attribute ProcessGroupNCCL
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
-    BroadcastOptions,
     DeviceMesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
@@ -40,7 +52,14 @@
     get_rank,
     init_device_mesh,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    Work,
+    _world,
+)
 from torch.futures import Future
 
 if TYPE_CHECKING:
@@ -54,6 +73,9 @@
 _FUTURE_EXCEPTION = "fut_exception"
 
 
+T = TypeVar("T")
+
+
 def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
     """
     Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +144,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
+    ) -> Work:
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
     def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         """
         Gathers tensors from the whole group in a list.
@@ -140,7 +164,9 @@ def allgather(
         raise NotImplementedError("not implemented")
 
     # pyre-fixme[14]: inconsistent override
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+    def broadcast(
+        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
+    ) -> Work:
         """
         Broadcasts the tensor to the whole group.
 
@@ -567,6 +593,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
+    def __del__(self) -> None:
+        self._tx.put(("del", self._op_id), timeout=self._timeout)
+
 
 class _BabyWorkNCCL(_BabyWork):
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +724,18 @@ def _worker(
                 cmd = op[0]
                 if cmd == "func":
                     func_name, args, kwargs = op[1:]
+                    args = _PickleSafeOptions.unsafe_args(args)
                     fn = getattr(pg, func_name)
                     work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
                     work[op_id].wait()
-                    del work[op_id]
                     tx.put(op_id)
+                elif cmd == "del":
+                    op_id: int = op[1]
+                    del work[op_id]
                 elif cmd == "future":
                     op_id: int = op[1]
 
@@ -731,6 +763,8 @@ def callback(fut: Future[object]) -> None:
 
                     del work[op_id]
                     tx.put((op_id, event))
+                elif cmd == "num_active_work":
+                    tx.put(len(work))
                 else:
                     raise ValueError(f"unknown cmd: {cmd}")
 
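The worker loop above changes when a pending op is released: "wait" no longer deletes the entry from the work dict; instead the owning _BabyWork enqueues a "del" command from its __del__, and the new "num_active_work" command reports how many entries are still tracked. The sketch below is a minimal, single-process illustration of that dispatch logic; FakeWork, handle_op, and replies are hypothetical names, not part of torchft.

from typing import Dict, List, Tuple


class FakeWork:
    """Toy stand-in for a pending collective (hypothetical)."""

    def wait(self) -> None:
        pass


def handle_op(work: Dict[int, FakeWork], op: Tuple, replies: List) -> None:
    # Simplified dispatcher mirroring the "wait" / "del" / "num_active_work" commands.
    cmd = op[0]
    if cmd == "wait":
        op_id = op[1]
        work[op_id].wait()
        replies.append(op_id)  # the entry stays alive for later future/del commands
    elif cmd == "del":
        del work[op[1]]  # dropped only once the owning handle goes away
    elif cmd == "num_active_work":
        replies.append(len(work))
    else:
        raise ValueError(f"unknown cmd: {cmd}")


work: Dict[int, FakeWork] = {0: FakeWork()}
replies: List = []
handle_op(work, ("wait", 0), replies)            # waits, but keeps tracking op 0
handle_op(work, ("num_active_work",), replies)   # -> 1
handle_op(work, ("del", 0), replies)             # owner was garbage collected
handle_op(work, ("num_active_work",), replies)   # -> 0
assert replies == [0, 1, 0]
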
@@ -775,7 +809,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         assert rx is not None
         assert tx is not None
 
-        tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        tx.put(
+            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            timeout=self._timeout,
+        )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
@@ -784,7 +821,11 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
             pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
         )
 
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        opts: Union[dist.AllreduceOptions, dist.ReduceOp],
+    ) -> Work:
         assert isinstance(tensors, list), "input must be list"
 
         for tensor in tensors:
@@ -793,9 +834,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
 
         return self._run_func("allreduce", tensors, opts)
 
+    def allgather(
+        self,
+        output_tensors: List[List[torch.Tensor]],
+        input_tensor: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "input must be list"
+        assert isinstance(input_tensor, list), "input must be list"
+
+        for tensor_list in output_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+
+        for tensor in input_tensor:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("allgather", output_tensors, input_tensor, opts)
+
+    def broadcast(
+        self,
+        tensor_list: List[torch.Tensor],
+        opts: BroadcastOptions,
+    ) -> Work:
+        assert isinstance(tensor_list, list), "input must be list"
+
+        for tensor in tensor_list:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("broadcast", tensor_list, opts)
+
     def size(self) -> int:
         return self._world_size
 
+    def num_active_work(self) -> int:
+        assert self._tx is not None
+        self._tx.put(("num_active_work",), timeout=self._timeout)
+
+        assert self._rx is not None
+        return cast(int, _get(self._rx, self._timeout))
+
+
+@dataclass
+class _PickleSafeOptions:
+    func: Callable[[], object]
+    fields: Dict[str, object]
+
+    @classmethod
+    def safe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.safe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.safe_args(arg) for arg in args]
+        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+            return cls.from_torch(args)
+        else:
+            return args
+
+    @classmethod
+    def unsafe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.unsafe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.unsafe_args(arg) for arg in args]
+        elif isinstance(args, cls):
+            return args.to_torch()
+        else:
+            return args
+
+    @classmethod
+    def from_torch(cls, opts: object) -> "_PickleSafeOptions":
+        return cls(
+            func=opts.__class__,
+            fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
+        )
+
+    def to_torch(self) -> object:
+        opts = self.func()
+        for k, v in self.fields.items():
+            setattr(opts, k, v)
+        return opts
+
 
 class ProcessGroupBabyGloo(ProcessGroupBaby):
     """

torchft/process_group_test.py

+75 -9
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import gc
 import io
 import multiprocessing
 import os
@@ -87,14 +88,15 @@ def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
             check_tensors(item)
 
     # Test collectives
-    collectives = {
-        "allreduce": ([input_tensor], AllreduceOptions()),
-        "allgather": (output_tensors, [input_tensor], AllgatherOptions()),
-        "broadcast": (tensor_list, BroadcastOptions()),
-        "broadcast_one": (input_tensor, 0),
-    }
+    collectives = [
+        ("allreduce", ([input_tensor], AllreduceOptions())),
+        ("allreduce", ([input_tensor], ReduceOp.SUM)),
+        ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
+        ("broadcast", (tensor_list, BroadcastOptions())),
+        ("broadcast_one", (input_tensor, 0)),
+    ]
     works: Dict[str, dist._Work] = {}
-    for coll_str, args in collectives.items():
+    for coll_str, args in collectives:
         coll = getattr(pg, coll_str)
         work = coll(*args)
         works[coll_str] = work
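
Switching collectives from a dict to a list of (name, args) tuples allows the same collective to appear twice, so the suite can now exercise both call forms that the Union[AllreduceOptions, ReduceOp] signature accepts. A short usage sketch, assuming the usual torchft.process_group import path; everything else mirrors the test_baby_gloo_apis test added below.

from datetime import timedelta

import torch
from torch.distributed import TCPStore
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp

from torchft.process_group import ProcessGroupBabyGloo

store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

t = torch.rand(3).share_memory_()
# Both call forms are accepted by the baby process group after this change.
pg.allreduce([t], AllreduceOptions()).wait()
pg.allreduce([t], ReduceOp.SUM).wait()
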
@@ -247,6 +249,23 @@ def test_reconfigure_baby_process_group(self) -> None:
         assert p_2 is not None
         self.assertTrue(p_2.is_alive())
 
+    def test_baby_gloo_apis(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a)
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -368,5 +387,52 @@ def test_managed_process_group(self) -> None:
         self.assertIsInstance(list(works.values())[0], _ManagedWork)
 
         self.assertEqual(manager.report_error.call_count, 0)
-        self.assertEqual(manager.wrap_future.call_count, 1)
-        self.assertEqual(manager.wait_quorum.call_count, 1)
+        self.assertEqual(manager.wrap_future.call_count, 2)
+        self.assertEqual(manager.wait_quorum.call_count, 2)
+
+
+class DeviceMeshTest(TestCase):
+    @staticmethod
+    def _test_init_device_mesh(world_size: int, rank: int) -> None:
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(12346)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(4)
+
+        testcase = TestCase()
+
+        manager = Mock(spec=Manager)
+        # Even though we only have 4 workers, we can still initialize (2, 4) mesh.
+        # That's because the replicate group is NOT physically created in the
+        # real mesh but is virtually added to the mesh via ManagedDeviceMesh.
+        device_mesh = ft_init_device_mesh(
+            device_type="cpu",
+            mesh_shape=(2, world_size),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+
+        testcase.assertTrue(
+            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
+        )
+        testcase.assertTrue(
+            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
+        )
+        replicate_group = device_mesh.get_group("dp_replicate")
+        testcase.assertEqual(
+            cast(ManagedProcessGroup, replicate_group)._manager, manager
+        )
+        replicate_mesh = device_mesh["dp_replicate"]
+        testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
+        flatten_mesh = device_mesh._flatten("dp")
+        manager.num_participants.return_value = 1
+        testcase.assertEqual(flatten_mesh.size(), world_size)
+        testcase.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())
+
+    def test_init_device_mesh(self) -> None:
+        with ProcessPoolExecutor(max_workers=4) as executor:
+            futures = []
+            for i in range(4):
+                future = executor.submit(self._test_init_device_mesh, 4, i)
+                futures.append(future)
