 import logging
 import queue
 import threading
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
@@ -586,29 +589,59 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    if isinstance(obj, torch.Tensor):
+        return obj.is_cuda
+    elif isinstance(obj, (list, tuple)):
+        return any(_is_any_cuda(o) for o in obj)
+    elif isinstance(obj, dict):
+        return any(_is_any_cuda(o) for o in obj.values())
+    else:
+        return False
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
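Two small helpers land in this hunk: `_is_any_cuda` recursively checks argument containers for CUDA tensors, and `_OpMetadata.set_stream()` enters the captured stream only when one exists. A brief hedged illustration (assumes this module is importable under the path shown and a CUDA device is present; the `work=None` field value is purely for demonstration):

```python
import torch

from torchft.process_group import _OpMetadata, _is_any_cuda  # module path assumed

# _is_any_cuda walks lists/tuples/dicts recursively; any CUDA tensor -> True.
print(_is_any_cuda([torch.zeros(2), {"k": (torch.zeros(2, device="cuda"),)}]))  # True
print(_is_any_cuda({"cpu_only": [torch.zeros(2)]}))                             # False

# _OpMetadata.set_stream() switches the current stream only for CUDA ops.
meta = _OpMetadata(work=None, stream=torch.cuda.Stream())  # work=None for illustration only
with meta.set_stream():
    assert torch.cuda.current_stream() == meta.stream
```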
@@ -617,11 +650,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     shareable and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -640,6 +670,10 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
 
         self._timeout: float = timeout
 
+        self._cuda_device_id: Optional[int] = (
+            torch.cuda.current_device() if torch.cuda.is_available() else None
+        )
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         if self._p is not None:
             self._p.kill()
@@ -679,7 +713,15 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+                self._cuda_device_id,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -704,8 +746,12 @@ def _worker(
         rx: mp.Queue,
         tx: mp.Queue,
         future_queue: mp.Queue,
+        cuda_device_id: Optional[int],
     ) -> None:
         try:
+            if cuda_device_id is not None:
+                torch.cuda.set_device(cuda_device_id)
+
             store = create_store_client(store_addr)
 
             try:
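The device id captured in `__init__` is now threaded through to the subprocess, which calls `torch.cuda.set_device()` before creating the wrapped process group, so NCCL communicators and allocations land on the parent's GPU. A simplified sketch of the handoff pattern (the real worker also sets up the store client, the process group, and the command loop):

```python
import torch
import torch.multiprocessing as mp


def _child(cuda_device_id):
    # Bind the child to the same GPU the parent was using before any CUDA work.
    if cuda_device_id is not None:
        torch.cuda.set_device(cuda_device_id)
    # ... the real worker creates the store client and process group here ...


if __name__ == "__main__":
    device_id = torch.cuda.current_device() if torch.cuda.is_available() else None
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=_child, args=(device_id,), daemon=True)
    p.start()
    p.join()
```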
@@ -716,23 +762,62 @@ def _worker(
                 return
             tx.put(None)
 
-            work = {}
+            streams = defaultdict(lambda: torch.cuda.Stream())
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0
 
             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    stream = streams[stream_id] if stream_id is not None else None
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
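The reworked `"func"`/`"wait"` handling is a two-way CUDA event handshake: the parent records an interprocess event on the issuing stream, the worker's per-`stream_id` stream waits on it before launching the collective, and the worker hands back an event the parent's stream later waits on. All of these waits are stream-level, so neither CPU thread blocks. A condensed, single-process sketch of the ordering primitive (the queue transport between parent and worker is elided; this is not the actual implementation):

```python
import torch

# "Parent" side (cf. _run_func): record an interprocess event on the issuing
# stream so pending work on it is ordered before the collective.
parent_stream = torch.cuda.current_stream()
start_event = parent_stream.record_event(torch.cuda.Event(interprocess=True))

# "Worker" side (cf. _worker): one dedicated stream per parent stream_id.
child_stream = torch.cuda.Stream()
with torch.cuda.stream(child_stream):
    start_event.wait()  # stream-level wait; the CPU thread keeps going
    # ... fn(*args, **kwargs) would launch the collective here ...
    done_event = torch.cuda.current_stream().record_event(
        torch.cuda.Event(interprocess=True)
    )

# "Parent" side again (cf. _BabyWork.wait): order later work after the collective.
done_event.wait()
```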
@@ -746,25 +831,12 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
+                elif cmd == "cuda_device_id":
+                    tx.put(torch.cuda.current_device())
                 else:
                     raise ValueError(f"unknown cmd: {cmd}")
 
@@ -792,6 +864,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +878,50 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
 
     def allreduce(
         self,
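With `_assert_alive()` now called from both `_run_func` and `_get_future`, issuing work against a dead worker raises immediately instead of hanging on the queue. A hedged, test-style sketch (the module path, the store address format, and the `allreduce(tensors, ReduceOp)` signature are assumptions; it assumes `allreduce` routes through `_run_func` as the surrounding context suggests):

```python
import torch
from torch.distributed import ReduceOp

from torchft.process_group import ProcessGroupBabyNCCL  # module path assumed

pg = ProcessGroupBabyNCCL()
pg.configure("localhost:29500/store", rank=0, world_size=1)  # placeholder store address

pg._p.kill()  # simulate the worker subprocess dying
pg._p.join()

try:
    pg.allreduce([torch.ones(1, device="cuda")], ReduceOp.SUM)
except RuntimeError as err:
    print(err)  # expected: "child process p.pid=... is dead p.exitcode=..."
```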
@@ -877,6 +979,13 @@ def num_active_work(self) -> int:
         assert self._rx is not None
         return cast(int, _get(self._rx, self._timeout))
 
+    def cuda_device_id(self) -> int:
+        assert self._tx is not None
+        self._tx.put(("cuda_device_id",), timeout=self._timeout)
+
+        assert self._rx is not None
+        return cast(int, _get(self._rx, self._timeout))
+
 
 @dataclass
 class _PickleSafeOptions:
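The parent-side `cuda_device_id()` method pairs with the worker's new `"cuda_device_id"` command and lets callers (and tests) confirm which device the worker bound to. A hedged check, under the same setup assumptions as the previous sketch:

```python
import torch

from torchft.process_group import ProcessGroupBabyNCCL  # module path assumed

pg = ProcessGroupBabyNCCL()
pg.configure("localhost:29500/store", rank=0, world_size=1)  # placeholder store address

# The worker should report the device captured by the parent in __init__.
assert pg.cuda_device_id() == torch.cuda.current_device()
```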
@@ -950,9 +1059,10 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
 
     WARNING: If the child process is killed while an operation is running, CUDA
     tensors may leak in the current PyTorch implementation. TODO fix
-    """
 
-    WORK_CLASS = _BabyWorkNCCL
+    If CUDA tensors are being used on a non-default device you must call
+    ``torch.cuda.set_device()`` prior to instantiating this ProcessGroup.
+    """
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
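Per the updated docstring, on a multi-GPU host the device has to be selected before the process group is constructed, because `__init__` snapshots `torch.cuda.current_device()` and forwards it to the worker. A hedged usage sketch (the module path, the store address, and the exact `allreduce` signature are assumptions):

```python
import torch
from torch.distributed import ReduceOp

from torchft.process_group import ProcessGroupBabyNCCL  # module path assumed

local_rank = 1                     # e.g. parsed from LOCAL_RANK by the launcher
torch.cuda.set_device(local_rank)  # must happen before constructing the PG

pg = ProcessGroupBabyNCCL()        # captures current_device() == local_rank
pg.configure("localhost:29500/store", rank=local_rank, world_size=2)  # placeholder address

t = torch.ones(8, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)  # assumed signature; routed through _run_func
work.wait()                            # blocks only the issuing CUDA stream, not the CPU
```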