 import logging
 from typing import Type, List, Optional, Callable, Tuple
 from datetime import timedelta
+import threading
 
 from torch.futures import Future
 from torch.distributed import (
 
 logger = logging.getLogger(__name__)
 
+# TODO: use non strings which are cheaper
+_QUEUE_CLOSE = "queue_close"
+_FUTURE_RESULT = "fut_result"
+_FUTURE_EXCEPTION = "fut_exception"
+
 
 def _get(queue: mp.Queue, timeout) -> object:
     v = queue.get(timeout=timeout)
@@ -208,9 +214,17 @@ def getBackendName(self):
 
 
 class BabyWork(Work):
-    def __init__(self, tx: mp.Queue, rx: mp.Queue, op_id: int, timeout: float):
+    def __init__(
+        self,
+        pg: "ProcessGroupBaby",
+        tx: mp.Queue,
+        rx: mp.Queue,
+        op_id: int,
+        timeout: float,
+    ):
         super().__init__()
 
+        self._pg = pg
         self._tx = tx
         self._rx = rx
         self._op_id = op_id
@@ -221,6 +235,9 @@ def wait(self) -> bool:
         assert _get(self._rx, self._timeout) == self._op_id
         return True
 
+    def get_future(self) -> Future:
+        return self._pg._get_future(self._op_id)
+
 
 class BabyWorkNCCL(BabyWork):
     def wait(self) -> bool:
@@ -255,6 +272,8 @@ def __init__(self, timeout: float = 60.0) -> None:
         self._p = None
         self._tx = None
         self._rx = None
+        self._future_queue = None
+        self._future_thread = None
 
         self._timeout = timeout
 
@@ -264,20 +283,46 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._world_size = world_size
 
+        if self._tx is not None:
+            self._tx.close()
+        if self._rx is not None:
+            self._rx.close()
+        if self._future_queue is not None:
+            self._future_queue.put(_QUEUE_CLOSE)
+            self._future_queue.close()
+
         ctx = mp.get_context("spawn")
         self._tx = ctx.Queue()
         self._rx = ctx.Queue()
 
+        # futures need thread to fire callbacks
+        self._future_queue = ctx.Queue()
+        # this lock needs to be held when manipulating _futures
+        self._futures_lock = threading.Lock()
+        self._futures = {}
+        self._future_thread = threading.Thread(
+            target=self._future_handler,
+            args=(self._future_queue,),
+            daemon=True,
+        )
+        self._future_thread.start()
+
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx),
+            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
             daemon=True,
         )
         self._p.start()
 
     @classmethod
     def _worker(
-        cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
+        cls,
+        store_addr: str,
+        rank: int,
+        world_size: int,
+        rx: mp.Queue,
+        tx: mp.Queue,
+        future_queue: mp.Queue,
     ) -> None:
         try:
             store = create_store(store_addr)
@@ -291,15 +336,28 @@ def _worker(
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func, args, kwargs = op[1:]
-                    work[next_op_id] = getattr(pg, func)(*args, **kwargs)
+                    func_name, args, kwargs = op[1:]
+                    fn = getattr(pg, func_name)
+                    work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id = op[1]
                     work[op_id].wait()
                     del work[op_id]
                     tx.put(op_id)
+                elif cmd == "future":
+                    op_id = op[1]
+
+                    def callback(fut: Future):
+                        try:
+                            fut.wait()
+                            future_queue.put((op_id, _FUTURE_RESULT, None))
+                        except Exception as e:
+                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))
+
+                    work[op_id].get_future().add_done_callback(callback)
+                    tx.put(op_id)
                 elif cmd == "synchronize":
                     # CUDA only, use events instead of waiting on CPU
                     op_id = op[1]
@@ -322,12 +380,41 @@ def _worker(
             logger.exception("worker errored")
             tx.put(e)
 
+    def _future_handler(self, future_queue: mp.Queue) -> None:
+        try:
+            while True:
+                cmd = future_queue.get()
+                if cmd == _QUEUE_CLOSE:
+                    break
+                op_id, mode, data = cmd
+                with self._futures_lock:
+                    fut = self._futures[op_id]
+                    del self._futures[op_id]
+                if mode == _FUTURE_RESULT:
+                    fut.set_result(data)
+                elif mode == _FUTURE_EXCEPTION:
+                    fut.set_exception(data)
+                else:
+                    raise ValueError(f"unknown mode {mode}")
+        except Exception as e:
+            logger.exception(f"got unexpected error in future handler: {e}")
+
+    def _get_future(self, op_id: int) -> Future:
+        with self._futures_lock:
+            fut = Future()
+            self._futures[op_id] = fut
+            self._tx.put(("future", op_id), timeout=self._timeout)
+
+        assert _get(self._rx, self._timeout) == op_id
+        # TODO: return correct tensor instead of None
+        return fut
+
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
         op_id = _get(self._rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
         return self.WORK_CLASS(
-            tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
+            pg=self, tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
         )
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
@@ -366,7 +453,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    PG_CLASS = BaseProcessGroupGloo
+    PG_CLASS = BaseProcessGroupNCCL
     WORK_CLASS = BabyWorkNCCL
 
     def getBackendName(self):
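
A rough usage sketch of the futures path this diff adds (a sketch only, not part of the change: it assumes a `ProcessGroupBabyGloo` subclass of `ProcessGroupBaby` exists elsewhere in this file, that the module is importable as `torchft.process_group`, and that a store is actually reachable at the placeholder address):

# Illustration only: drives BabyWork.get_future() / ProcessGroupBaby._get_future().
import torch
from torch.distributed import AllreduceOptions
from torchft.process_group import ProcessGroupBabyGloo  # assumed subclass, not shown above

if __name__ == "__main__":
    pg = ProcessGroupBabyGloo(timeout=60.0)
    # placeholder store address; configure() spawns the worker subprocess
    pg.configure("localhost:29500/prefix", rank=0, world_size=1)

    work = pg.allreduce([torch.ones(3)], AllreduceOptions())

    # get_future() sends ("future", op_id) to the worker; the worker attaches a
    # done-callback that posts to future_queue, and _future_handler resolves this
    # local Future (with None as its value until the TODO above is addressed).
    fut = work.get_future()
    fut.wait()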