@@ -85,21 +85,30 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
 
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupWrapper(ProcessGroup):
+    PG_CLASS: Type[BaseProcessGroup]
     """
-    This is a wrapper around ProcessGroupGloo with a reconfiguration argument.
+    This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, timeout: float = 60.0) -> None:
+        """
+        Args:
+            timeout: the timeout to use for the process group
+        """
         super().__init__(0, 1)
         self._pg = None
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._pg is not None:
+            if hasattr(self._pg, "abort"):
+                self._pg.abort()
+            self._pg = None
+
         store = create_store(store_addr)
 
-        # TODO: set lower timeout
-        # pyre-fixme[16]: no attribute ProcessGroupGloo
-        self._pg = BaseProcessGroupGloo(store, rank, world_size)
+        # TODO: set global timeout
+        self._pg = self.PG_CLASS(store, rank, world_size)
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self._pg.allreduce(tensors, opts)
@@ -118,10 +127,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()
 
+
+class ProcessGroupGloo(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupGloo.
+    """
+
+    PG_CLASS = BaseProcessGroupGloo
+
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
 
+class ProcessGroupNCCL(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupNCCL.
+
+    WARNING: this may result in deadlocks due to NCCL error handling. This is
+    provided for completeness, but your mileage may vary.
+
+    TODO: verify shutdown correctness with the latest NCCL. This currently
+    calls abort when reconfiguring; we need to ensure this is safe.
+    """
+
+    PG_CLASS = BaseProcessGroupNCCL
+
+    def getBackendName(self) -> str:
+        return "torchft-nccl"
+
+
 class DummyWork(dist._Work):
     def __init__(self, result):
         super().__init__()
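
Reviewer note: a minimal usage sketch of the reconfiguration flow, not part of this commit. The store address format and the AllreduceOptions usage are assumptions about how create_store() and the c10d bindings are typically driven.

import torch
from torch.distributed import AllreduceOptions

# Hypothetical sketch: reuse one wrapper across two "generations" of a job.
pg = ProcessGroupGloo(timeout=60.0)

# Initial configuration: 2 workers. The store address below is a
# placeholder; create_store() defines what it actually accepts.
pg.configure("localhost:29500/run_0", rank=0, world_size=2)
t = torch.ones(3)
pg.allreduce([t], AllreduceOptions()).wait()

# After a failure, configure() can be called again in place: it aborts the
# old backend PG (when abort is supported) and builds a fresh one against
# the new store, without constructing a new wrapper object.
pg.configure("localhost:29500/run_1", rank=0, world_size=1)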
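The PG_CLASS hook keeps new backends cheap to add: a subclass supplies only the concrete ProcessGroup type and a backend name. A sketch, assuming a hypothetical BaseProcessGroupUCC binding that is not part of this commit or of c10d:

# Hypothetical: BaseProcessGroupUCC stands in for any concrete c10d
# ProcessGroup binding you want to make reconfigurable.
class ProcessGroupUCC(ProcessGroupWrapper):
    """
    A reconfigurable version of a UCC-backed ProcessGroup (illustrative only).
    """

    PG_CLASS = BaseProcessGroupUCC

    def getBackendName(self) -> str:
        return "torchft-ucc"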