    AllreduceOptions,
    BroadcastOptions,
    ReduceOp,
+    ReduceScatterOptions,
    Work,
)
from torch.futures import Future
@@ -159,6 +160,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
        opts.rootRank = root
        return self.broadcast([tensor], opts)

+    # pyre-fixme[14]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        """
+        Reduces, then scatters a list of tensors to all processes in a group.
+
+        See torch.distributed.reduce_scatter for more details.
+        """
+        raise NotImplementedError("not implemented")
+
    def size(self) -> int:
        raise NotImplementedError("not implemented")

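For context on the call shape introduced above, here is a minimal usage sketch (not part of this change). It assumes an already-configured torchft process group `pg`; the variable names, tensor sizes, and the SUM reduction are illustrative only.

```python
import torch
from torch.distributed.distributed_c10d import ReduceOp, ReduceScatterOptions

# Assumption: `pg` is a configured torchft ProcessGroup with pg.size() ranks.
world_size = pg.size()

# One input shard per rank; the shards are reduced elementwise and rank r
# receives the r-th reduced shard in its single output tensor.
inputs = [torch.full((4,), float(r)) for r in range(world_size)]
output = torch.empty(4)

opts = ReduceScatterOptions()
opts.reduceOp = ReduceOp.SUM

work = pg.reduce_scatter([output], [inputs], opts)
work.wait()  # `output` now holds this rank's SUM-reduced shard
```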
@@ -267,6 +282,14 @@ def allgather(
    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
        return self.parent.broadcast(tensor_list, opts)

+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: object,
+    ) -> Work:
+        return self.parent.reduce_scatter(output_tensors, input_tensors, opts)
+
    def size(self) -> int:
        return self.parent.size()

@@ -295,6 +318,25 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
    def getBackendName(self) -> str:
        return "torchft-gloo"

+    # pyre-fixme[14,15]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter operation in the
+        ProcessGroupGloo class. However, this operation is not supported by the
+        Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupGloo.
+        """
+        raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
+

class ProcessGroupNCCL(ProcessGroupWrapper):
    """
@@ -354,11 +396,6 @@ def __init__(self, rank: int, world: int) -> None:
    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
        self.configure_count += 1

-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
-        res = _DummyWork(tensor_list)
-        self._work.append(res)
-        return res
-
    def allgather(
        self,
        output_tensors: List[List[torch.Tensor]],
@@ -377,6 +414,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        self._work.append(res)
        return res

+    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+        res = _DummyWork(tensor_list)
+        self._work.append(res)
+        return res
+
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: object,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors[0]):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
    def size(self) -> int:
        return self._world

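To make the test double's semantics concrete: with a single rank, the reduce_scatter added above simply copies each shard in `input_tensors[0]` into the matching output tensor. A small check along those lines is sketched below; the class name `_DummyProcessGroup` and the constructor call are assumptions based on the `__init__(self, rank, world)` signature visible in the hunk header, not names taken from this diff.

```python
import torch

# Assumed test double: rank 0 in a world of size 1.
dummy_pg = _DummyProcessGroup(0, 1)  # name assumed, not shown in this diff

inputs = [[torch.tensor([1.0, 2.0])]]  # input_tensors[0]: one shard per rank
outputs = [torch.zeros(2)]

work = dummy_pg.reduce_scatter(outputs, inputs, object())
work.wait()
assert torch.equal(outputs[0], torch.tensor([1.0, 2.0]))  # shard copied through
```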
@@ -970,6 +1025,25 @@ def broadcast(

        return self._run_func("broadcast", tensor_list, opts)

+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "output must be list"
+        assert isinstance(input_tensors, list), "input must be list"
+
+        for tensor in output_tensors:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        for tensor_list in input_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+        return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
+
    def size(self) -> int:
        return self._world_size

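The `share_memory_()` calls above matter because ProcessGroupBaby dispatches the real collective to a background process via `_run_func`; putting the tensors in shared memory lets that process operate on the same storage rather than a pickled copy. A standalone sketch of the underlying pattern, independent of this class (names and values are illustrative):

```python
import torch
import torch.multiprocessing as mp


def _child(t: torch.Tensor) -> None:
    # Mutate in place; because the storage is shared, the parent sees the result.
    t.add_(1)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    t = torch.zeros(4)
    t.share_memory_()  # same idea as the is_shared()/share_memory_() guard above
    p = ctx.Process(target=_child, args=(t,))
    p.start()
    p.join()
    print(t)  # tensor([1., 1., 1., 1.])
```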
@@ -992,7 +1066,15 @@ def safe_args(cls, args: T) -> T:
            return tuple(cls.safe_args(arg) for arg in args)
        elif isinstance(args, list):
            return [cls.safe_args(arg) for arg in args]
-        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+        elif isinstance(
+            args,
+            (
+                AllreduceOptions,
+                AllgatherOptions,
+                BroadcastOptions,
+                ReduceScatterOptions,
+            ),
+        ):
            return cls.from_torch(args)
        else:
            return args
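`ReduceScatterOptions` is added here so that `safe_args` converts it like the other c10d option objects before the arguments cross the pipe to the worker process, presumably because the pybind-backed option objects do not pickle cleanly. The toy sketch below only illustrates that conversion idea; the `_PickleableReduceScatterOpts` dataclass is a hypothetical stand-in, not this module's real `from_torch`.

```python
from dataclasses import dataclass

from torch.distributed.distributed_c10d import ReduceOp, ReduceScatterOptions


@dataclass
class _PickleableReduceScatterOpts:
    """Hypothetical plain container for the fields that need to be forwarded."""

    reduce_op: ReduceOp  # assumes ReduceOp values can cross the process boundary

    @classmethod
    def from_torch(cls, opts: ReduceScatterOptions) -> "_PickleableReduceScatterOpts":
        # Parent side: capture the fields from the pybind options object.
        return cls(reduce_op=opts.reduceOp)

    def to_torch(self) -> ReduceScatterOptions:
        # Worker side: rebuild a real options object before calling the backend.
        opts = ReduceScatterOptions()
        opts.reduceOp = self.reduce_op
        return opts
```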
@@ -1038,6 +1120,25 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
    def getBackendName(self) -> str:
        return "torchft-baby-gloo"

+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter operation in the
+        ProcessGroupBabyGloo class. However, this operation is not supported by
+        the Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupBabyGloo.
+        """
+        raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
+

class ProcessGroupBabyNCCL(ProcessGroupBaby):
    """