import logging
import queue
import threading
+from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
+    Generator,
    List,
    Optional,
    Tuple,
-    Type,
    TypeVar,
    Union,
    cast,
@@ ... @@
    BroadcastOptions,
    ReduceOp,
    Work,
-    _world,
)
from torch.futures import Future
@@ -586,29 +586,59 @@ def __init__(
        self._timeout = timeout

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
        self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
        return True

+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
    def get_future(self) -> Future[object]:
        return self._pg._get_future(self._op_id)

    def __del__(self) -> None:
        self._tx.put(("del", self._op_id), timeout=self._timeout)


-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
-
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
-
-        return True
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
+
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False
+
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield


class ProcessGroupBaby(ProcessGroup):
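The `_OpMetadata.set_stream` helper added above is the usual pattern for conditionally entering a CUDA stream context. A minimal standalone sketch of the same idea (the `_maybe_stream` name is illustrative, not part of this patch):

```python
from contextlib import contextmanager, nullcontext
from typing import Generator, Optional

import torch


@contextmanager
def _maybe_stream(stream: Optional[torch.cuda.Stream]) -> Generator[None, None, None]:
    # Run the body on `stream` if one was captured, otherwise leave the
    # current (default) stream untouched.
    with torch.cuda.stream(stream) if stream is not None else nullcontext():
        yield


if torch.cuda.is_available():
    s = torch.cuda.Stream()
    with _maybe_stream(s):
        # Kernels launched here are queued on `s`, not the default stream.
        x = torch.ones(4, device="cuda") * 2
```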
@@ -617,11 +647,8 @@ class ProcessGroupBaby(ProcessGroup):
    subprocess. Since it's running in a subprocess all tensors need to be in
    shared memory or will be moved to shared memory. CUDA tensors are implicitly
    shareable and don't need any changes.
-
    """

-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
    def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
        super().__init__(0, 1)
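The docstring's shared-memory requirement can be shown directly: a CPU tensor has to be moved into shared memory before the subprocess can see it, whereas CUDA tensors are passed as IPC handles by torch.multiprocessing. A small sketch:

```python
import torch

t = torch.zeros(8)   # ordinary CPU tensor, private to this process
t.share_memory_()    # move its storage into shared memory, in place
assert t.is_shared()
# CUDA tensors need no such step: torch.multiprocessing shares them through
# CUDA IPC handles when they are sent to the child process.
```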
@@ -679,7 +706,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

        self._p = ctx.Process(
            target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
            daemon=True,
        )
        self._p.start()
@@ -716,23 +750,76 @@ def _worker(
            return
        tx.put(None)

-        work = {}
+        streams: Dict[str, torch.cuda.Stream] = {}
+        work: Dict[int, _OpMetadata] = {}
        next_op_id: int = 0

        while True:
            op = rx.get()
            cmd = op[0]
            if cmd == "func":
-                func_name, args, kwargs = op[1:]
-                args = _PickleSafeOptions.unsafe_args(args)
-                fn = getattr(pg, func_name)
-                work[next_op_id] = fn(*args, **kwargs)
+                func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                # To avoid potential deadlocks we need to preserve the
+                # stream/synchronization behavior of the parent process.
+                # We allocate one Stream per stream_id to make sure that we
+                # don't accidentally introduce cross stream synchronization
+                # points.
+                if stream_id is not None:
+                    stream_key = f"{stream_device}/{stream_id}"
+                    if stream_key not in streams:
+                        streams[stream_key] = torch.cuda.Stream(device=stream_device)
+                    stream = streams[stream_key]
+                else:
+                    stream = None
+
+                with (
+                    torch.cuda.stream(stream)
+                    if stream is not None
+                    else nullcontext()
+                ):
+                    # Make the stream wait on the cuda event to make sure we
+                    # don't start the operation until the tensor is ready.
+                    if event is not None:
+                        event.wait()
+
+                    args = _PickleSafeOptions.unsafe_args(args)
+                    fn = getattr(pg, func_name)
+                    work[next_op_id] = _OpMetadata(
+                        work=fn(*args, **kwargs),
+                        stream=stream,
+                    )
                tx.put(next_op_id)
                next_op_id += 1
            elif cmd == "wait":
                op_id: int = op[1]
-                work[op_id].wait()
-                tx.put(op_id)
+
+                metadata = work[op_id]
+
+                with metadata.set_stream():
+                    # With WorkNCCL this makes the stream wait not the CPU when
+                    # no timeout is passed.
+                    metadata.work.wait()
+
+                    # Register event on the stream that we can pass to the main
+                    # process.
+                    event = (
+                        torch.cuda.current_stream().record_event(
+                            torch.cuda.Event(interprocess=True)
+                        )
+                        if metadata.stream is not None
+                        else None
+                    )
+
+                tx.put((op_id, event))
            elif cmd == "del":
                op_id: int = op[1]
                del work[op_id]
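The rewritten worker loop above replaces CPU-side waiting with CUDA event ordering: the parent records an event before dispatching, the worker's per-`stream_id` stream waits on it before launching the collective, and after `work.wait()` the worker records a fresh event for the parent's stream to wait on. The same record/wait discipline, collapsed into a single process with two streams so it runs standalone (the real code crosses the process boundary with `interprocess=True` events sent over multiprocessing queues):

```python
import torch

if torch.cuda.is_available():
    producer = torch.cuda.Stream()  # stands in for the parent's current stream
    consumer = torch.cuda.Stream()  # stands in for the worker's per-stream_id stream

    with torch.cuda.stream(producer):
        x = torch.randn(1024, device="cuda")
        # Record "x is ready" on the producer stream.
        ready = torch.cuda.current_stream().record_event(torch.cuda.Event())

    with torch.cuda.stream(consumer):
        # Don't start until the producer's writes to x are visible.
        ready.wait()
        y = x * 2
        # Record "the work is done" so the other side can wait on its stream.
        done = torch.cuda.current_stream().record_event(torch.cuda.Event())

    # Back on the producer side: make the stream (not the CPU) wait for `y`.
    done.wait(producer)
```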
@@ -746,23 +833,8 @@ def callback(fut: Future[object]) -> None:
                    except Exception as e:
                        future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                work[op_id].get_future().add_done_callback(callback)
+                work[op_id].work.get_future().add_done_callback(callback)
                tx.put(op_id)
-            elif cmd == "synchronize":
-                # CUDA only, use events instead of waiting on CPU
-                op_id = op[1]
-
-                # With WorkNCCL this makes the stream wait not the CPU when
-                # no timeout is passed.
-                work[op_id].wait()
-
-                # Register event on the stream that we can pass to the main
-                # process.
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-
-                del work[op_id]
-                tx.put((op_id, event))
            elif cmd == "num_active_work":
                tx.put(len(work))
            else:
@@ -792,6 +864,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
                logger.exception(f"got unexpected error in future handler: {e}")

    def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
        with self._futures_lock:
            fut = Future()  # pyre-fixme[29]: is not a function
            self._futures[op_id] = fut
@@ -804,22 +878,52 @@ def _get_future(self, op_id: int) -> Future[object]:
        return fut

    def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
        rx = self._rx
        tx = self._tx
        assert rx is not None
        assert tx is not None

+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
        tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
            timeout=self._timeout,
        )

        op_id = _get(rx, self._timeout)
        assert isinstance(op_id, int), f"invalid return {op_id}"

-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

    def allreduce(
        self,
@@ -952,8 +1056,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
    tensors may leak in the current PyTorch implementation. TODO fix
    """

-    WORK_CLASS = _BabyWorkNCCL
-
    @classmethod
    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        # pyre-fixme[16]: no attribute ProcessGroupNCCL
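For orientation, a rough usage sketch of the class this patch touches. The store address format, rank, and world size are placeholders, and the call pattern is inferred from the surrounding torchft code rather than shown in this diff:

```python
import torch
from torch.distributed import ReduceOp
from torchft.process_group import ProcessGroupBabyNCCL

# Placeholder: a TCPStore must already be listening at this host:port/prefix.
store_addr = "localhost:29500/prefix"
rank, world_size = 0, 1

pg = ProcessGroupBabyNCCL(timeout=60.0)
pg.configure(store_addr, rank, world_size)

t = torch.ones(8, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)  # dispatched to the NCCL subprocess
# With this patch, wait() blocks the current CUDA stream on an IPC event
# recorded in the subprocess instead of blocking the CPU thread.
work.wait()
```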