@@ -187,7 +187,6 @@ def __repr__(self) -> str:
187
187
188
188
189
189
class ProcessGroupWrapper (ProcessGroup ):
190
- PG_CLASS : Type [BaseProcessGroup ] # pyre-fixme[13]: never initialized
191
190
"""
192
191
This is a wrapper around any ProcessGroup with a reconfiguration method.
193
192
"""
@@ -209,9 +208,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
209
208
210
209
store = create_store_client (store_addr )
211
210
212
- # TODO: set global timeout
213
- # pyre-fixme[20]: expects argument options
214
- self ._pg = self .PG_CLASS (store , rank , world_size )
211
+ self ._pg = self ._create_pg (store , rank , world_size )
212
+
213
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
214
+ raise NotImplementedError ("not implemented" )
215
215
216
216
def allreduce (self , tensors : List [torch .Tensor ], opts : object ) -> Work :
217
217
return self .parent .allreduce (tensors , opts )
@@ -244,9 +244,12 @@ class ProcessGroupGloo(ProcessGroupWrapper):
244
244
This is a reconfigurable version of ProcessGroupGloo.
245
245
"""
246
246
247
- PG_CLASS : Type [BaseProcessGroup ] = (
248
- BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo
249
- )
247
+ def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
248
+ super ().__init__ ()
249
+ self ._timeout = timeout
250
+
251
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
252
+ return BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
250
253
251
254
def getBackendName (self ) -> str :
252
255
return "torchft-gloo"
@@ -263,9 +266,12 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
263
266
abort when reconfiguring, we need to ensure this is safe.
264
267
"""
265
268
266
- PG_CLASS : Type [BaseProcessGroup ] = (
267
- BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL
268
- )
269
+ def __init__ (self , timeout : timedelta = timedelta (seconds = 60.0 )) -> None :
270
+ super ().__init__ ()
271
+ self ._timeout = timeout
272
+
273
+ def _create_pg (self , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
274
+ return BaseProcessGroupNCCL (store , rank , world_size , self ._timeout )
269
275
270
276
def getBackendName (self ) -> str :
271
277
return "torchft-nccl"
0 commit comments