
Commit a425493 (parent: d076a54)

fixes reduce_scatter function signature, refactors test and adds reduce_scatter test

2 files changed (+45, -6 lines)

torchft/process_group.py (+4, -5)
@@ -185,7 +185,7 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
     def reduce_scatter(
         self,
         output_tensors: List[torch.Tensor],
-        input_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
         opts: ReduceScatterOptions,
     ) -> Work:
         """
@@ -306,7 +306,7 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def reduce_scatter(
         self,
         output_tensors: List[torch.Tensor],
-        input_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
         opts: object,
     ) -> Work:
         return self.parent.reduce_scatter(output_tensors, input_tensors, opts)
@@ -424,10 +424,10 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def reduce_scatter(
         self,
         output_tensors: List[torch.Tensor],
-        input_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
         opts: object,
     ) -> Work:
-        for o, i in zip(output_tensors, input_tensors):
+        for o, i in zip(output_tensors, input_tensors[0]):
             o.copy_(i)

         res = _DummyWork(output_tensors)
@@ -1013,7 +1013,6 @@ def reduce_scatter(
         for tensor in tensor_list:
             if not tensor.is_shared():
                 tensor.share_memory_()
-
         return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

     def size(self) -> int:

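For context, the corrected signature matches how the diff itself consumes the argument (the dummy implementation reads input_tensors[0]): input_tensors is a list with one inner list per output tensor, and that inner list holds one tensor per rank. The sketch below is illustrative only and not part of the commit; it simulates the reduction locally with plain tensors, using a made-up world size and variable names, to show what the List[List[torch.Tensor]] layout means.

# Illustrative sketch (not from the commit): the reduce_scatter input layout,
# simulated locally with an assumed world_size of 2 and sum reduction.
import torch

world_size = 2

# Inputs as each rank would pass them: one inner list per output tensor
# (usually one); inputs[r][0][j] is rank r's contribution destined for rank j.
inputs = [
    [[torch.full((2, 3), float(r + 1)) for _ in range(world_size)]]
    for r in range(world_size)
]

# With a real backend, rank j's single output tensor would receive the
# elementwise sum of inputs[r][0][j] over all ranks r.
expected = [sum(inputs[r][0][j] for r in range(world_size)) for j in range(world_size)]
print(expected[0])  # every element is 1.0 + 2.0 = 3.0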
torchft/process_group_test.py (+41, -1)
@@ -61,6 +61,31 @@ def dummy_init_pg() -> None:
     )


+def _should_run_collective(collective_str: str, backend_str: str, device: str) -> bool:
+    """Verify if the collective is supported by the backend and device.
+
+    See https://pytorch.org/docs/stable/distributed.html#backends for the
+    supported collectives / backends / devices matrix.
+
+    """
+    if "nccl" in backend_str.lower():
+        # all collectives are supported for NCCL/CUDA but none on CPU.
+        return device == "cuda"
+    elif "gloo" in backend_str.lower():
+        if device == "cuda":
+            # GLOO/GPU only supports broadcast and all_reduce.
+            if collective_str in ["broadcast", "all_reduce"]:
+                return True
+            return False
+        else:  # cpu
+            if collective_str in ["reduce_scatter", "all_to_all"]:
+                return False
+            return True
+    else:
+        # Non defined backends (e.g. ErrorSwallowing) should continue to work.
+        return True
+
+
 def _test_pg(
     pg: ProcessGroup,
     example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
@@ -95,10 +120,25 @@ def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
         ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
         ("broadcast", (tensor_list, BroadcastOptions())),
         ("broadcast_one", (input_tensor, 0)),
-        ("reduce_scatter", (output_tensors, [input_tensor], ReduceScatterOptions())),
+        (
+            "reduce_scatter",
+            (output_tensors[0], [[input_tensor]], ReduceScatterOptions()),
+        ),
     ]
     works: Dict[str, dist._Work] = {}
+
+    try:
+        backend_str = pg.getBackendName()
+        device = example_tensor.device
+        if type(device) is torch.device:
+            device = device.type
+    except NotImplementedError as e:
+        backend_str = ""
+        device = ""
+
     for coll_str, args in collectives:
+        if not _should_run_collective(coll_str, backend_str=backend_str, device=device):
+            continue
         coll = getattr(pg, coll_str)
         work = coll(*args)
         works[coll_str] = work

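As a quick illustration of the gating logic above (not part of the commit), these checks spell out the helper's expected answers for a few backend/device combinations, assuming _should_run_collective from the test file is in scope:

# Sanity checks for the helper above (illustrative only; assumes
# _should_run_collective is imported or copied from the test file).
assert _should_run_collective("reduce_scatter", "nccl", "cuda")      # NCCL + CUDA: supported
assert not _should_run_collective("reduce_scatter", "nccl", "cpu")   # NCCL has no CPU collectives
assert _should_run_collective("all_reduce", "gloo", "cuda")          # Gloo on GPU: broadcast/all_reduce only
assert not _should_run_collective("reduce_scatter", "gloo", "cpu")   # Gloo on CPU lacks reduce_scatter
assert _should_run_collective("allgather", "gloo", "cpu")            # Gloo on CPU supports allgather
assert _should_run_collective("broadcast", "", "cpu")                # unknown backend: run everything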