Commit e55542a
Adds reduce_scatter into torchft (#102)
* initial commit for reduce_scatter
* fixes reduce_scatter function signature, refactors test and adds reduce_scatter test
* fixes test
* adds explicit NotImplementedError to reduce_scatter in gloo, simplify the test suite
* fix tests after merge
* add explicit error for ProcessGroupGloo
* notimplementederror -> runtimeerror
1 parent 8f0d125 commit e55542a

File tree: 2 files changed, +145 -23 lines

torchft/process_group.py

+107 -6
@@ -57,6 +57,7 @@
     AllreduceOptions,
     BroadcastOptions,
     ReduceOp,
+    ReduceScatterOptions,
     Work,
 )
 from torch.futures import Future
@@ -159,6 +160,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
         opts.rootRank = root
         return self.broadcast([tensor], opts)
 
+    # pyre-fixme[14]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        """
+        Reduces, then scatters a list of tensors to all processes in a group.
+
+        See torch.distributed.reduce_scatter for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     def size(self) -> int:
         raise NotImplementedError("not implemented")
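The signature above follows the c10d convention: `output_tensors` holds one result tensor per call, and each entry of `input_tensors` is a list of `world_size` shards. As a minimal sketch (not part of the commit), the expected SUM semantics can be illustrated with plain tensor ops, no process group required:

```python
# Hypothetical illustration of reduce_scatter semantics (not from the diff):
# rank r ends up with the element-wise reduction of every rank's r-th shard.
import torch

world_size = 2
# shards[sender][r] is the shard that `sender` contributes toward rank r.
shards = [
    [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],      # rank 0's inputs
    [torch.tensor([10.0, 20.0]), torch.tensor([30.0, 40.0])],  # rank 1's inputs
]

# With ReduceOp.SUM, rank r's output is the sum over senders of shards[sender][r].
expected = [
    sum(shards[sender][r] for sender in range(world_size)) for r in range(world_size)
]
print(expected)  # [tensor([11., 22.]), tensor([33., 44.])]
```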

@@ -267,6 +282,14 @@ def allgather(
     def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
         return self.parent.broadcast(tensor_list, opts)
 
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: object,
+    ) -> Work:
+        return self.parent.reduce_scatter(output_tensors, input_tensors, opts)
+
     def size(self) -> int:
         return self.parent.size()

@@ -295,6 +318,25 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
+    # pyre-fixme[14,15]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter operation in the
+        ProcessGroupGloo class. However, this operation is not supported by the
+        Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupGloo.
+        """
+        raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
+
 
 class ProcessGroupNCCL(ProcessGroupWrapper):
     """
@@ -354,11 +396,6 @@ def __init__(self, rank: int, world: int) -> None:
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         self.configure_count += 1
 
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
-        res = _DummyWork(tensor_list)
-        self._work.append(res)
-        return res
-
     def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
@@ -377,6 +414,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         self._work.append(res)
         return res
 
+    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+        res = _DummyWork(tensor_list)
+        self._work.append(res)
+        return res
+
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: object,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors[0]):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def size(self) -> int:
         return self._world
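The dummy implementation does not reduce anything: it copies the first inner list of shards straight into the outputs so tests get deterministic results. A tiny standalone check of that copy behavior (illustration only, not from the commit):

```python
# Illustration only: mirror of the copy loop in the dummy reduce_scatter above.
import torch

input_tensors = [[torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]]
output_tensors = [torch.empty(2), torch.empty(2)]

for o, i in zip(output_tensors, input_tensors[0]):
    o.copy_(i)

assert torch.equal(output_tensors[0], torch.tensor([1.0, 2.0]))
assert torch.equal(output_tensors[1], torch.tensor([3.0, 4.0]))
```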

@@ -970,6 +1025,25 @@ def broadcast(
 
         return self._run_func("broadcast", tensor_list, opts)
 
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "input must be list"
+        assert isinstance(input_tensors, list), "input must be list"
+
+        for tensor in output_tensors:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        for tensor_list in input_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+        return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
+
     def size(self) -> int:
         return self._world_size
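The Baby process groups run the real collective in a subprocess, so both the output tensors and every input shard are moved into shared memory before `_run_func` hands them across the process boundary. A quick illustration of the in-place call guarded by `is_shared()` above:

```python
# Minimal illustration: share_memory_() moves a CPU tensor's storage into
# shared memory in place; is_shared() is the guard used above to skip tensors
# that are already shared.
import torch

t = torch.zeros(4)
print(t.is_shared())  # False
t.share_memory_()
print(t.is_shared())  # True
```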

@@ -992,7 +1066,15 @@ def safe_args(cls, args: T) -> T:
             return tuple(cls.safe_args(arg) for arg in args)
         elif isinstance(args, list):
             return [cls.safe_args(arg) for arg in args]
-        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+        elif isinstance(
+            args,
+            (
+                AllreduceOptions,
+                AllgatherOptions,
+                BroadcastOptions,
+                ReduceScatterOptions,
+            ),
+        ):
             return cls.from_torch(args)
         else:
             return args
@@ -1038,6 +1120,25 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGrou
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
 
+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter operation in the
+        ProcessGroupBabyGloo class. However, this operation is not supported by
+        the Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupBabyGloo.
+        """
+        raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
+
 
 class ProcessGroupBabyNCCL(ProcessGroupBaby):
     """

torchft/process_group_test.py

+38 -17
@@ -23,6 +23,7 @@
     AllreduceOptions,
     BroadcastOptions,
     ReduceOp,
+    ReduceScatterOptions,
     _resolve_process_group,
 )
 from torch.distributed import (
@@ -94,18 +95,28 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
         ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
         ("broadcast", (tensor_list, BroadcastOptions())),
         ("broadcast_one", (input_tensor, 0)),
+        (
+            "reduce_scatter",
+            (output_tensors[0], [[input_tensor]], ReduceScatterOptions()),
+        ),
     ]
     works: Dict[str, dist._Work] = {}
-    for coll_str, args in collectives:
-        coll = getattr(pg, coll_str)
-        work = coll(*args)
-        works[coll_str] = work
-        work.wait()
-        fut = work.get_future()
-        fut.wait()
 
-        # Check that all tensor arguments have the expected shapes and dtypes
-        check_tensors(args)
+    for coll_str, args in collectives:
+        try:
+            coll = getattr(pg, coll_str)
+            work = coll(*args)
+            works[coll_str] = work
+            work.wait()
+            fut = work.get_future()
+            fut.wait()
+            # Check that all tensor arguments have the expected shapes and dtypes
+            check_tensors(args)
+        except RuntimeError as e:
+            if f"does not support {coll_str}" in str(e):
+                # Skip collectives that are not supported by the backend.
+                continue
+            raise e
 
     print(works)
     return works
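The skip-on-RuntimeError pattern above can also serve outside the test suite as a feature check. A hedged sketch, where the helper name and import paths are assumptions rather than part of the commit:

```python
# Hypothetical helper (assumed import paths): returns False instead of raising
# when the backend reports that it does not support reduce_scatter.
from typing import List

import torch
from torch.distributed.distributed_c10d import ReduceScatterOptions
from torchft.process_group import ProcessGroup  # assumed location of the base class


def try_reduce_scatter(
    pg: ProcessGroup,
    outputs: List[torch.Tensor],
    inputs: List[List[torch.Tensor]],
) -> bool:
    try:
        pg.reduce_scatter(outputs, inputs, ReduceScatterOptions()).wait()
        return True
    except RuntimeError as e:
        if "does not support reduce_scatter" in str(e):
            return False
        raise
```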
@@ -306,7 +317,7 @@ def test_baby_nccl_2gpu(self) -> None:
 
         store_addr: str = f"localhost:{store.port}/prefix"
 
-        def run(rank: int) -> Tuple[torch.Tensor, Work]:
+        def run(rank: int) -> Tuple[ProcessGroupBabyNCCL, torch.Tensor, Work]:
             a = ProcessGroupBabyNCCL(
                 timeout=timedelta(seconds=10.0),
             )
@@ -318,19 +329,29 @@ def run(rank: int) -> Tuple[torch.Tensor, Work]:
             at = torch.tensor([rank + 1], device="cuda")
 
             a_work = a.allreduce([at], ReduceOp.SUM)
-            return at, a_work
+            return a, at, a_work
 
         with ThreadPoolExecutor(max_workers=2) as executor:
             a_fut = executor.submit(run, 0)
             b_fut = executor.submit(run, 1)
 
-        at, a_work = a_fut.result()
-        bt, b_work = b_fut.result()
-
-        a_work.wait()
-        b_work.get_future().wait()
+        a, at, a_work = a_fut.result()
+        b, bt, b_work = b_fut.result()
 
-        torch.testing.assert_close(at.cpu(), bt.cpu())
+        try:
+            a_work.wait()
+            b_work.get_future().wait()
+            torch.testing.assert_close(at.cpu(), bt.cpu())
+        finally:
+            # cleanup - first ensure that babywork is deleted before shutting down PGs
+            # note futures must be deleted as they hold references to babywork
+            del a_fut
+            del b_fut
+            del a_work
+            del b_work
+            gc.collect()
+            b.shutdown()
+            a.shutdown()
 
     def test_device_mesh(self) -> None:
         os.environ["MASTER_ADDR"] = "localhost"
