@@ -19,13 +19,16 @@
 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -587,28 +590,56 @@ def __init__(
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
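
As a side note, here is a rough sketch of how the new `_is_any_cuda` helper behaves on nested containers. The import path and the presence of a CUDA device are assumptions for illustration; they are not established by this diff.

```python
import torch

# Assumes the helper is importable from torchft.process_group and that a CUDA
# device is available; both are assumptions made only for this example.
from torchft.process_group import _is_any_cuda

print(_is_any_cuda(torch.zeros(4)))                              # False: CPU tensor
print(_is_any_cuda([torch.zeros(4), ("x", 1)]))                  # False: no CUDA tensor anywhere
print(_is_any_cuda({"grads": [torch.zeros(4, device="cuda")]}))  # True: nested CUDA tensor
```

This is what `_run_func` uses below to decide whether it needs to capture the caller's stream id and record an interprocess event at all.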
@@ -620,8 +651,6 @@ class ProcessGroupBaby(ProcessGroup):
 
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -716,23 +745,62 @@ def _worker(
             return
         tx.put(None)
 
-        work = {}
+        streams = defaultdict(lambda: torch.cuda.Stream())
+        work: Dict[int, _OpMetadata] = {}
         next_op_id: int = 0
 
         while True:
             op = rx.get()
             cmd = op[0]
             if cmd == "func":
-                func_name, args, kwargs = op[1:]
-                args = _PickleSafeOptions.unsafe_args(args)
-                fn = getattr(pg, func_name)
-                work[next_op_id] = fn(*args, **kwargs)
+                func_name, args, kwargs, stream_id, event = op[1:]
+
+                # To avoid potential deadlocks we need to preserve the
+                # stream/synchronization behavior of the parent process.
+                # We allocate one Stream per stream_id to make sure that we
+                # don't accidentally introduce cross stream synchronization
+                # points.
+                stream = streams[stream_id] if stream_id is not None else None
+                with (
+                    torch.cuda.stream(stream)
+                    if stream is not None
+                    else nullcontext()
+                ):
+
+                    # Make the stream wait on the cuda event to make sure we
+                    # don't start the operation until the tensor is ready.
+                    if event is not None:
+                        event.wait()
+
+                    args = _PickleSafeOptions.unsafe_args(args)
+                    fn = getattr(pg, func_name)
+                    work[next_op_id] = _OpMetadata(
+                        work=fn(*args, **kwargs),
+                        stream=stream,
+                    )
                 tx.put(next_op_id)
                 next_op_id += 1
             elif cmd == "wait":
                 op_id: int = op[1]
-                work[op_id].wait()
-                tx.put(op_id)
+
+                metadata = work[op_id]
+
+                with metadata.set_stream():
+                    # With WorkNCCL this makes the stream wait not the CPU when
+                    # no timeout is passed.
+                    metadata.work.wait()
+
+                    # Register event on the stream that we can pass to the main
+                    # process.
+                    event = (
+                        torch.cuda.current_stream().record_event(
+                            torch.cuda.Event(interprocess=True)
+                        )
+                        if metadata.stream is not None
+                        else None
+                    )
+
+                tx.put((op_id, event))
             elif cmd == "del":
                 op_id: int = op[1]
                 del work[op_id]
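
The "func"/"wait" handshake above relies on stream-ordered CUDA events instead of blocking the CPU thread. A condensed single-process sketch of that pattern is below; it assumes a CUDA device, and the names (`parent_stream`, `worker_stream`, `ready`, `done`) are made up for illustration rather than taken from the code.

```python
import torch

parent_stream = torch.cuda.current_stream()
worker_stream = torch.cuda.Stream()

x = torch.ones(1024, device="cuda")

# "Parent" side: record an event after the input is produced so the other
# stream can order itself behind it without any CPU-side blocking.
ready = parent_stream.record_event(torch.cuda.Event(interprocess=True))

with torch.cuda.stream(worker_stream):
    ready.wait()  # the stream waits; the CPU thread keeps going
    x += 1        # stand-in for the collective the worker launches
    done = worker_stream.record_event(torch.cuda.Event(interprocess=True))

# "Parent" side again: waiting on the event only inserts a dependency on the
# current stream.
done.wait()
torch.cuda.current_stream().synchronize()  # block the CPU only when needed
```

In the actual change the two sides live in different processes, which is why the events are created with `interprocess=True` and shipped over the queues.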
@@ -746,23 +814,8 @@ def callback(fut: Future[object]) -> None:
                     except Exception as e:
                         future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                work[op_id].get_future().add_done_callback(callback)
+                work[op_id].work.get_future().add_done_callback(callback)
                 tx.put(op_id)
-            elif cmd == "synchronize":
-                # CUDA only, use events instead of waiting on CPU
-                op_id = op[1]
-
-                # With WorkNCCL this makes the stream wait not the CPU when
-                # no timeout is passed.
-                work[op_id].wait()
-
-                # Register event on the stream that we can pass to the main
-                # process.
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-
-                del work[op_id]
-                tx.put((op_id, event))
             elif cmd == "num_active_work":
                 tx.put(len(work))
             else:
@@ -809,17 +862,33 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
 
     def allreduce(
         self,
@@ -952,8 +1021,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
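
End to end, the parent-side behavior after this change looks roughly like the sketch below. The TCPStore setup, `configure()` address format, and the `allreduce` call shape are assumptions based on the surrounding file, not something this diff itself establishes.

```python
import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

# Assumed setup: an ephemeral TCPStore plus configure(addr, rank, world_size).
store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)
pg = ProcessGroupBabyNCCL()
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

t = torch.ones(1024, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)

# With the event-based wait() above, this only makes the current CUDA stream
# wait on the event recorded by the worker subprocess; the CPU thread is not
# blocked.
work.wait()
torch.cuda.current_stream().synchronize()  # synchronize only when the result is needed
```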