Skip to content

Commit b3f5b93

Browse files
committed
ProcessGroupBaby: support full suite of PG tests
1 parent 4bdb8a7 commit b3f5b93

File tree

2 files changed

+143
-18
lines changed

2 files changed

+143
-18
lines changed

torchft/process_group.py

+121-9
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,20 @@
1919
import logging
2020
import queue
2121
import threading
22+
from dataclasses import dataclass
2223
from datetime import timedelta
23-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
24+
from typing import (
25+
TYPE_CHECKING,
26+
Any,
27+
Callable,
28+
Dict,
29+
List,
30+
Optional,
31+
Tuple,
32+
Type,
33+
TypeVar,
34+
Union,
35+
)
2436

2537
import torch
2638
import torch.distributed as dist
@@ -29,7 +41,6 @@
2941
# pyre-fixme[21]: no attribute ProcessGroupNCCL
3042
# pyre-fixme[21]: no attribute ProcessGroupGloo
3143
from torch.distributed import (
32-
BroadcastOptions,
3344
DeviceMesh,
3445
PrefixStore,
3546
ProcessGroup as BaseProcessGroup,
@@ -40,7 +51,14 @@
4051
get_rank,
4152
init_device_mesh,
4253
)
43-
from torch.distributed.distributed_c10d import Work, _world
54+
from torch.distributed.distributed_c10d import (
55+
AllgatherOptions,
56+
AllreduceOptions,
57+
BroadcastOptions,
58+
ReduceOp,
59+
Work,
60+
_world,
61+
)
4462
from torch.futures import Future
4563

4664
if TYPE_CHECKING:
@@ -54,6 +72,9 @@
5472
_FUTURE_EXCEPTION = "fut_exception"
5573

5674

75+
T = TypeVar("T")
76+
77+
5778
def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
5879
"""
5980
Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +143,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
122143
raise NotImplementedError("not implemented")
123144

124145
# pyre-fixme[14]: inconsistent override
125-
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
146+
def allreduce(
147+
self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
148+
) -> Work:
126149
raise NotImplementedError("not implemented")
127150

128151
# pyre-fixme[14]: inconsistent override
129152
def allgather(
130153
self,
131154
output_tensors: List[List[torch.Tensor]],
132155
input_tensor: List[torch.Tensor],
133-
opts: object,
156+
opts: AllgatherOptions,
134157
) -> Work:
135158
"""
136159
Gathers tensors from the whole group in a list.
@@ -140,7 +163,9 @@ def allgather(
140163
raise NotImplementedError("not implemented")
141164

142165
# pyre-fixme[14]: inconsistent override
143-
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
166+
def broadcast(
167+
self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
168+
) -> Work:
144169
"""
145170
Broadcasts the tensor to the whole group.
146171
@@ -567,6 +592,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
567592
def get_future(self) -> Future[object]:
568593
return self._pg._get_future(self._op_id)
569594

595+
def __del__(self) -> None:
    """
    Ask the worker to drop the work object stored under this op id.

    Sends a ``("del", op_id)`` command on the tx queue. The send is strictly
    best-effort: a finalizer may run on a partially constructed instance or
    after the queue has been closed, and an exception escaping ``__del__``
    cannot be handled by any caller -- Python only prints an
    "Exception ignored in ..." warning.
    """
    try:
        self._tx.put(("del", self._op_id), timeout=self._timeout)
    except Exception:
        # Deliberately swallowed: missing attributes (failed __init__), a
        # closed queue, or a full queue all mean there is nothing useful
        # left to notify, and raising from a finalizer helps nobody.
        pass
597+
570598

571599
class _BabyWorkNCCL(_BabyWork):
572600
def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +723,18 @@ def _worker(
695723
cmd = op[0]
696724
if cmd == "func":
697725
func_name, args, kwargs = op[1:]
726+
args = _PickleSafeOptions.unsafe_args(args)
698727
fn = getattr(pg, func_name)
699728
work[next_op_id] = fn(*args, **kwargs)
700729
tx.put(next_op_id)
701730
next_op_id += 1
702731
elif cmd == "wait":
703732
op_id: int = op[1]
704733
work[op_id].wait()
705-
del work[op_id]
706734
tx.put(op_id)
735+
elif cmd == "del":
736+
op_id: int = op[1]
737+
del work[op_id]
707738
elif cmd == "future":
708739
op_id: int = op[1]
709740

@@ -775,7 +806,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
775806
assert rx is not None
776807
assert tx is not None
777808

778-
tx.put(("func", func, args, kwargs), timeout=self._timeout)
809+
tx.put(
810+
("func", func, _PickleSafeOptions.safe_args(args), kwargs),
811+
timeout=self._timeout,
812+
)
779813

780814
op_id = _get(rx, self._timeout)
781815
assert isinstance(op_id, int), f"invalid return {op_id}"
@@ -784,7 +818,11 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
784818
pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
785819
)
786820

787-
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
821+
def allreduce(
822+
self,
823+
tensors: List[torch.Tensor],
824+
opts: Union[dist.AllreduceOptions, dist.ReduceOp],
825+
) -> Work:
788826
assert isinstance(tensors, list), "input must be list"
789827

790828
for tensor in tensors:
@@ -793,10 +831,84 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
793831

794832
return self._run_func("allreduce", tensors, opts)
795833

834+
def allgather(
835+
self,
836+
output_tensors: List[List[torch.Tensor]],
837+
input_tensor: List[torch.Tensor],
838+
opts: AllgatherOptions,
839+
) -> Work:
840+
assert isinstance(output_tensors, list), "input must be list"
841+
assert isinstance(input_tensor, list), "input must be list"
842+
843+
for tensor_list in output_tensors:
844+
for tensor in tensor_list:
845+
if not tensor.is_shared():
846+
tensor.share_memory_()
847+
848+
for tensor in input_tensor:
849+
if not tensor.is_shared():
850+
tensor.share_memory_()
851+
852+
return self._run_func("allgather", output_tensors, input_tensor, opts)
853+
854+
def broadcast(
    self,
    tensor_list: List[torch.Tensor],
    opts: BroadcastOptions,
) -> Work:
    """
    Queue a broadcast through ``_run_func``.

    Each tensor in ``tensor_list`` that is not already in shared memory is
    moved there in place before the call is forwarded.

    Args:
        tensor_list: list of tensors to broadcast
        opts: broadcast options, forwarded unchanged

    Returns:
        the Work handle produced by ``_run_func``

    Raises:
        AssertionError: if ``tensor_list`` is not a list
    """
    assert isinstance(tensor_list, list), "input must be list"

    for t in tensor_list:
        if t.is_shared():
            continue
        t.share_memory_()

    return self._run_func("broadcast", tensor_list, opts)
866+
796867
def size(self) -> int:
797868
return self._world_size
798869

799870

871+
@dataclass
872+
class _PickleSafeOptions:
873+
func: Callable[[], object]
874+
fields: Dict[str, object]
875+
876+
@classmethod
877+
def safe_args(cls, args: T) -> T:
878+
if isinstance(args, tuple):
879+
return tuple(cls.safe_args(arg) for arg in args)
880+
elif isinstance(args, list):
881+
return [cls.safe_args(arg) for arg in args]
882+
elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
883+
return cls.from_torch(args)
884+
else:
885+
return args
886+
887+
@classmethod
888+
def unsafe_args(cls, args: T) -> T:
889+
if isinstance(args, tuple):
890+
return tuple(cls.unsafe_args(arg) for arg in args)
891+
elif isinstance(args, list):
892+
return [cls.unsafe_args(arg) for arg in args]
893+
elif isinstance(args, cls):
894+
return args.to_torch()
895+
else:
896+
return args
897+
898+
@classmethod
899+
def from_torch(cls, opts: object) -> "_PickleSafeOptions":
900+
return cls(
901+
func=opts.__class__,
902+
fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
903+
)
904+
905+
def to_torch(self) -> object:
906+
opts = self.func()
907+
for k, v in self.fields.items():
908+
setattr(opts, k, v)
909+
return opts
910+
911+
800912
class ProcessGroupBabyGloo(ProcessGroupBaby):
801913
"""
802914
This is a ProcessGroup that runs Gloo in a subprocess.

torchft/process_group_test.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,15 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
8686
check_tensors(item)
8787

8888
# Test collectives
89-
collectives = {
90-
"allreduce": ([input_tensor], AllreduceOptions()),
91-
"allgather": (output_tensors, [input_tensor], AllgatherOptions()),
92-
"broadcast": (tensor_list, BroadcastOptions()),
93-
"broadcast_one": (input_tensor, 0),
94-
}
89+
collectives = [
90+
("allreduce", ([input_tensor], AllreduceOptions())),
91+
("allreduce", ([input_tensor], ReduceOp.SUM)),
92+
("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
93+
("broadcast", (tensor_list, BroadcastOptions())),
94+
("broadcast_one", (input_tensor, 0)),
95+
]
9596
works: Dict[str, dist._Work] = {}
96-
for coll_str, args in collectives.items():
97+
for coll_str, args in collectives:
9798
coll = getattr(pg, coll_str)
9899
work = coll(*args)
99100
works[coll_str] = work
@@ -246,6 +247,18 @@ def test_reconfigure_baby_process_group(self) -> None:
246247
assert p_2 is not None
247248
self.assertTrue(p_2.is_alive())
248249

250+
def test_baby_gloo_opts(self) -> None:
    """
    Run the shared ``_test_pg`` checks against a ProcessGroupBabyGloo
    configured as rank 0 of a world of size 1.
    """
    # The store stays referenced for the whole test; the address handed to
    # configure() is built from whatever port the store reports.
    store = TCPStore(
        host_name="localhost",
        port=0,
        is_master=True,
        wait_for_workers=False,
    )
    addr = f"localhost:{store.port}/prefix"

    pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
    pg.configure(addr, 0, 1)

    _test_pg(pg)
261+
249262
def test_dummy(self) -> None:
250263
pg = ProcessGroupDummy(0, 1)
251264
m = nn.Linear(3, 4)
@@ -367,8 +380,8 @@ def test_managed_process_group(self) -> None:
367380
self.assertIsInstance(list(works.values())[0], _ManagedWork)
368381

369382
self.assertEqual(manager.report_error.call_count, 0)
370-
self.assertEqual(manager.wrap_future.call_count, 1)
371-
self.assertEqual(manager.wait_quorum.call_count, 1)
383+
self.assertEqual(manager.wrap_future.call_count, 2)
384+
self.assertEqual(manager.wait_quorum.call_count, 2)
372385

373386

374387
class DeviceMeshTest(TestCase):

0 commit comments

Comments
 (0)