 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -586,29 +589,59 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
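
For context, a minimal standalone sketch (not part of the diff) of the CUDA event semantics that _OpMetadata and the new _BabyWork.wait() rely on: Event.wait() makes a stream wait without blocking the CPU thread, which is what lets the child process hand an event back to the parent instead of synchronizing eagerly. All names here are illustrative, and it assumes a CUDA device is available.

    import torch

    def event_handshake_sketch() -> None:
        if not torch.cuda.is_available():
            return
        producer = torch.cuda.Stream()
        consumer = torch.cuda.Stream()
        with torch.cuda.stream(producer):
            x = torch.ones(1024, device="cuda")
            y = x * 2  # GPU work queued on the producer stream
            ready = torch.cuda.Event()
            ready.record(producer)
        with torch.cuda.stream(consumer):
            # The consumer stream waits on the event; the CPU thread returns
            # immediately, mirroring how wait() above only gates the stream.
            ready.wait(consumer)
            z = y + 1
        torch.cuda.synchronize()
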
@@ -617,11 +650,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     shareable and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
@@ -679,7 +709,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +753,76 @@ def _worker(
                 return
             tx.put(None)
 
-            work = {}
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
            next_op_id: int = 0
 
            while True:
                op = rx.get()
                cmd = op[0]
                if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    print(f"func {func_name=}")
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        print("stream created")
+
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        print("waited")
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                    tx.put(next_op_id)
                    next_op_id += 1
                elif cmd == "wait":
                    op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                elif cmd == "del":
                    op_id: int = op[1]
                    del work[op_id]
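
The per-stream bookkeeping above can also be read in isolation. Below is an illustrative sketch (not part of the diff, with hypothetical names) of how the worker maps a parent-side (device, stream_id) pair onto a dedicated child-side stream, so operations issued on different parent streams never serialize against each other in the child process.

    from typing import Dict, Optional
    import torch

    _streams: Dict[str, torch.cuda.Stream] = {}

    def stream_for(
        stream_device: Optional[torch.device], stream_id: Optional[int]
    ) -> Optional[torch.cuda.Stream]:
        # No stream info means the op came from a CPU-only caller; run inline.
        if stream_id is None:
            return None
        key = f"{stream_device}/{stream_id}"
        # Reuse one child stream per parent stream so ordering within a parent
        # stream is preserved without adding cross-stream sync points.
        if key not in _streams:
            _streams[key] = torch.cuda.Stream(device=stream_device)
        return _streams[key]
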
@@ -746,23 +836,8 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -792,6 +867,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
                 logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +881,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
 
     def allreduce(
         self,
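
A caller-side usage sketch (not part of the diff) of what _run_func's new stream/event capture enables, assuming the module's existing allreduce(tensors, op) signature, a configured ProcessGroupBabyNCCL, and a CUDA host; the function name and shapes are placeholders.

    import torch
    from torch.distributed import ReduceOp

    def allreduce_on_stream_sketch(pg: "ProcessGroupBabyNCCL") -> torch.Tensor:
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            t = torch.ones(8, device="cuda")
            work = pg.allreduce([t], ReduceOp.SUM)
            # wait() only makes this stream wait on the child's event; the CPU
            # thread keeps running and can queue more work.
            work.wait()
            t = t * 2  # ordered after the allreduce on the same stream
        return t
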
@@ -952,8 +1059,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
         tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL