17
17
"""
18
18
19
19
import logging
20
- import queue
20
+ import sys
21
21
import threading
22
22
from contextlib import contextmanager , nullcontext
23
23
from dataclasses import dataclass
63
63
from torch .futures import Future
64
64
from torch .utils ._pytree import tree_any
65
65
66
+ from torchft .multiprocessing import _MonitoredQueue
67
+
66
68
if TYPE_CHECKING :
67
69
from torchft .manager import Manager
68
70
77
79
T = TypeVar ("T" )
78
80
79
81
80
- def _get (q : mp .Queue , timeout : Union [float , timedelta ]) -> object :
81
- """
82
- Gets an item from a queue with a timeout. If the timeout is exceeded then
83
- a TimeoutError is raised.
84
-
85
- If an exception is returned from the queue then it is raised.
86
-
87
- Args:
88
- q: queue to get from
89
- timeout: timeout in seconds
90
- """
91
- if isinstance (timeout , timedelta ):
92
- timeout = timeout .total_seconds ()
93
- try :
94
- v = q .get (timeout = timeout )
95
- except queue .Empty as e :
96
- raise TimeoutError (f"queue.get() timed out after { timeout } seconds" ) from e
97
- if isinstance (v , Exception ):
98
- raise v
99
- return v
100
-
101
-
102
82
def create_store_client (store_addr : str ) -> Store :
103
83
"""
104
84
Creates a PrefixStore(TCPStore(...)) client from an address in the format:
@@ -573,8 +553,8 @@ class _BabyWork(Work):
573
553
def __init__ (
574
554
self ,
575
555
pg : "ProcessGroupBaby" ,
576
- tx : mp . Queue ,
577
- rx : mp . Queue ,
556
+ tx : _MonitoredQueue ,
557
+ rx : _MonitoredQueue ,
578
558
op_id : int ,
579
559
timeout : float ,
580
560
) -> None :
@@ -592,7 +572,7 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
592
572
self ._tx .put (("wait" , self ._op_id ), timeout = self ._timeout )
593
573
op_id , event = cast (
594
574
Tuple [int , Optional [torch .cuda .Event ]],
595
- _get ( self ._rx , timeout or self ._timeout ),
575
+ self ._rx . get ( timeout or self ._timeout ),
596
576
)
597
577
assert op_id == self ._op_id
598
578
if event is not None :
@@ -649,9 +629,9 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
649
629
self ._world_size = - 1
650
630
651
631
self ._p : Optional [mp .Process ] = None
652
- self ._tx : Optional [mp . Queue ] = None
653
- self ._rx : Optional [mp . Queue ] = None
654
- self ._future_queue : Optional [mp . Queue ] = None
632
+ self ._tx : Optional [_MonitoredQueue ] = None
633
+ self ._rx : Optional [_MonitoredQueue ] = None
634
+ self ._future_queue : Optional [_MonitoredQueue ] = None
655
635
self ._future_thread : Optional [threading .Thread ] = None
656
636
self ._futures : Dict [int , Future [object ]] = {}
657
637
self ._futures_lock = threading .Lock ()
@@ -661,60 +641,80 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
661
641
662
642
self ._timeout : float = timeout
663
643
664
def shutdown(self) -> None:
    """
    Shut down the process group. This kills the underlying process and
    closes all queues.

    This is a no-op if the process group is already shut down (or was never
    configured): every handle is cleared before it is released, so a
    repeated call finds nothing left to tear down.

    ProcessGroup can be reconfigured after shutdown.

    Raises:
        RuntimeError: if the future handler thread does not exit within
            10 seconds of being asked to stop.
    """
    # Take ownership of the current handles and clear the attributes up
    # front. This keeps shutdown() idempotent even when the cleanup below
    # raises part-way through.
    tx, rx = self._tx, self._rx
    future_queue, future_thread = self._future_queue, self._future_thread
    p = self._p
    self._tx = None
    self._rx = None
    self._future_queue = None
    self._future_thread = None
    self._p = None

    try:
        if tx is not None:
            tx.close()
        if rx is not None:
            rx.close()

        if future_queue is not None:
            # wait for the future thread to exit and then close the queue
            future_queue.put(_QUEUE_CLOSE, timeout=timedelta(seconds=10.0))

            assert future_thread is not None
            future_thread.join(timeout=10.0)
            if future_thread.is_alive():
                raise RuntimeError("future thread did not exit")

            future_queue.close()
    finally:
        # Kill after closing queues to avoid log spam. Done in `finally` so
        # the worker process is never leaked when the teardown above fails.
        if p is not None:
            p.kill()


def configure(self, store_addr: str, rank: int, world_size: int) -> None:
    """
    (Re)configure the process group by starting a fresh worker process.

    Any previous worker, queues and future handler thread are torn down via
    shutdown() first. A new worker is then started with the "spawn" context,
    its three queues are wrapped in _MonitoredQueue together with the worker
    process handle, and a daemon thread running _future_handler is started
    on the future queue. Finally this blocks until the worker reports the
    result of process group initialization.

    Args:
        store_addr: store address forwarded unchanged to the worker
        rank: rank forwarded unchanged to the worker
        world_size: world size forwarded to the worker and recorded on self
    """
    self._world_size = world_size

    self.shutdown()

    ctx = mp.get_context("spawn")
    raw_tx = ctx.Queue()
    raw_rx = ctx.Queue()
    raw_future_queue = ctx.Queue()

    # The worker receives the raw queues; the monitored wrappers below need
    # the process handle and are only used on this side.
    self._p = p = ctx.Process(
        target=self._worker,
        args=(
            store_addr,
            rank,
            world_size,
            raw_tx,
            raw_rx,
            raw_future_queue,
        ),
        daemon=True,
    )
    p.start()

    self._tx = _MonitoredQueue(p, raw_tx)
    self._rx = rx = _MonitoredQueue(p, raw_rx)
    self._future_queue = future_queue = _MonitoredQueue(p, raw_future_queue)

    # futures need thread to fire callbacks
    # this lock needs to be held when manipulating _futures
    self._futures_lock = threading.Lock()
    self._futures = {}
    self._future_thread = threading.Thread(
        target=self._future_handler,
        args=(future_queue,),
        daemon=True,
    )
    self._future_thread.start()

    # fetch the status of the PG init
    # NOTE(review): per the original comment, get() raises if the worker
    # sent back an exception -- confirm against _MonitoredQueue.get.
    # The read is kept outside the assert so the status message is still
    # consumed when assertions are stripped (python -O).
    status = rx.get(self._timeout)
    assert status is None, f"unexpected init status from worker: {status!r}"
718
718
719
719
@classmethod
720
720
def _create_pg (cls , store : Store , rank : int , world_size : int ) -> BaseProcessGroup :
@@ -739,7 +739,7 @@ def _worker(
739
739
try :
740
740
pg = cls ._create_pg (store , rank , world_size )
741
741
except Exception as e :
742
- logger . exception (f"got exception in worker: { e } " )
742
+ print (f"got exception in worker: { e } " , file = sys . stderr )
743
743
tx .put (e )
744
744
return
745
745
tx .put (None )
@@ -829,17 +829,21 @@ def callback(fut: Future[object]) -> None:
829
829
raise ValueError (f"unknown cmd: { cmd } " )
830
830
831
831
except Exception as e :
832
- logger . exception ( "worker errored" )
832
+ print ( f "worker errored: { e } " , file = sys . stderr )
833
833
tx .put (e )
834
834
raise
835
835
836
- def _future_handler (self , future_queue : mp . Queue ) -> None :
836
+ def _future_handler (self , future_queue : _MonitoredQueue ) -> None :
837
837
try :
838
838
while True :
839
- cmd = future_queue .get ()
839
+ try :
840
+ # the timeout value is arbitrary: a TimeoutError just retries the get
841
+ cmd = future_queue .get (timeout = timedelta (seconds = 10.0 ))
842
+ except TimeoutError :
843
+ continue
840
844
if cmd == _QUEUE_CLOSE :
841
845
break
842
- op_id , mode , data = cmd
846
+ op_id , mode , data = cast ( Tuple [ int , str , object ], cmd )
843
847
with self ._futures_lock :
844
848
fut = self ._futures [op_id ]
845
849
del self ._futures [op_id ]
@@ -862,7 +866,7 @@ def _get_future(self, op_id: int) -> Future[object]:
862
866
self ._tx .put (("future" , op_id ), timeout = self ._timeout )
863
867
864
868
assert self ._rx is not None
865
- assert _get ( self ._rx , self ._timeout ) == op_id
869
+ assert self ._rx . get ( self ._timeout ) == op_id
866
870
# TODO: return correct tensor instead of None
867
871
return fut
868
872
@@ -899,7 +903,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
899
903
timeout = self ._timeout ,
900
904
)
901
905
902
- op_id = _get ( rx , self ._timeout )
906
+ op_id = rx . get ( self ._timeout )
903
907
assert isinstance (op_id , int ), f"invalid return { op_id } "
904
908
905
909
return _BabyWork (pg = self , tx = tx , rx = rx , op_id = op_id , timeout = self ._timeout )
@@ -968,7 +972,7 @@ def num_active_work(self) -> int:
968
972
self ._tx .put (("num_active_work" ,), timeout = self ._timeout )
969
973
970
974
assert self ._rx is not None
971
- return cast (int , _get ( self ._rx , self ._timeout ))
975
+ return cast (int , self ._rx . get ( self ._timeout ))
972
976
973
977
974
978
@dataclass
0 commit comments