@@ -61,6 +61,31 @@ def dummy_init_pg() -> None:
61
61
)
62
62
63
63
64
+ def _should_run_collective (collective_str : str , backend_str : str , device : str ) -> bool :
65
+ """Verify if the collective is supported by the backend and device.
66
+
67
+ See https://pytorch.org/docs/stable/distributed.html#backends for the
68
+ supported collectives / backends / devices matrix.
69
+
70
+ """
71
+ if "nccl" in backend_str .lower ():
72
+ # all collectives are supported for NCCL/CUDA but none on CPU.
73
+ return device == "cuda"
74
+ elif "gloo" in backend_str .lower ():
75
+ if device == "cuda" :
76
+ # GLOO/GPU only supports broadcast and all_reduce.
77
+ if collective_str in ["broadcast" , "all_reduce" ]:
78
+ return True
79
+ return False
80
+ else : # cpu
81
+ if collective_str in ["reduce_scatter" , "all_to_all" ]:
82
+ return False
83
+ return True
84
+ else :
85
+ # Non defined backends (e.g. ErrorSwallowing) should continue to work.
86
+ return True
87
+
88
+
64
89
def _test_pg (
65
90
pg : ProcessGroup ,
66
91
example_tensor : torch .Tensor = torch .randn ((2 , 3 ), dtype = torch .float32 ),
@@ -95,10 +120,25 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
95
120
("allgather" , (output_tensors , [input_tensor ], AllgatherOptions ())),
96
121
("broadcast" , (tensor_list , BroadcastOptions ())),
97
122
("broadcast_one" , (input_tensor , 0 )),
98
- ("reduce_scatter" , (output_tensors , [input_tensor ], ReduceScatterOptions ())),
123
+ (
124
+ "reduce_scatter" ,
125
+ (output_tensors [0 ], [[input_tensor ]], ReduceScatterOptions ()),
126
+ ),
99
127
]
100
128
works : Dict [str , dist ._Work ] = {}
129
+
130
+ try :
131
+ backend_str = pg .getBackendName ()
132
+ device = example_tensor .device
133
+ if type (device ) is torch .device :
134
+ device = device .type
135
+ except NotImplementedError as e :
136
+ backend_str = ""
137
+ device = ""
138
+
101
139
for coll_str , args in collectives :
140
+ if not _should_run_collective (coll_str , backend_str = backend_str , device = device ):
141
+ continue
102
142
coll = getattr (pg , coll_str )
103
143
work = coll (* args )
104
144
works [coll_str ] = work
0 commit comments