     write_obj_to_ipc_stream,
 )
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
+from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event
 
 # Deprecation warning for FT_LAUNCHER_LOGLEVEL
 if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:
@@ -101,6 +102,10 @@ def _register_ft_rdzv_handler():
     from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import create_backend
 
     from ._ft_rendezvous import FtRendezvousHandler, create_handler
+    from .c10d_monkey_patch import apply_c10d_patch
+
+    # Apply monkey patch to add use_libuv support to c10d backend
+    apply_c10d_patch()
 
     def _create_ft_rdzv_handler(params: RendezvousParameters) -> FtRendezvousHandler:
         backend, store = create_backend(params)
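For orientation, the body of `apply_c10d_patch` is not part of this hunk. A minimal sketch of what such a patch could look like, assuming it only needs to thread a `use_libuv` default into the `TCPStore` used by the c10d rendezvous backend (the real `c10d_monkey_patch` module may differ):

```python
# Hypothetical sketch of apply_c10d_patch(); the actual c10d_monkey_patch
# module in this PR may be implemented differently.
import torch.distributed as dist
import torch.distributed.elastic.rendezvous.c10d_rendezvous_backend as c10d


def apply_c10d_patch(use_libuv: bool = False) -> None:
    """Make every TCPStore created by the c10d rendezvous backend
    honor a default use_libuv setting."""
    original_tcp_store = dist.TCPStore

    def patched_tcp_store(*args, **kwargs):
        # Inject the default only when the caller did not set it.
        kwargs.setdefault("use_libuv", use_libuv)
        return original_tcp_store(*args, **kwargs)

    # The backend module imports TCPStore by name, so rebinding the
    # name on that module is enough for this sketch.
    c10d.TCPStore = patched_tcp_store
```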
@@ -138,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
     python multiprocessing compatible. To pass multiprocessing data structures
     to the workers you may create the data structure in the same multiprocessing
     context as the specified ``start_method`` and pass it as a function argument.
-
+
     Note: If your training script uses the nvrx logger, make sure to call
     ``setup_logger()`` at the beginning of your training function to ensure
     the logger is properly set up in each subprocess.
@@ -179,12 +184,12 @@ def trainer(args) -> str:
             # Ensure nvrx logger is set up in this subprocess
             from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
             setup_logger()
-
+
             # Use the nvrx logger
             import logging
             logger = logging.getLogger(LogConfig.name)
             logger.info("Training started")
-
+
             return "do train"
 
         def main():
@@ -251,6 +256,7 @@ def __init__(
         self._ft_cfg = fault_tol_cfg
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
+        self._node_id = self._get_fq_hostname()
 
 DEFAULT_ROLE = "default"  # FIXME
 
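`_get_fq_hostname` is inherited from torch elastic's agent API; it reduces to roughly the following, so `_node_id` is the node's fully-qualified hostname (a simplified sketch, not the exact torch implementation):

```python
import socket


def _get_fq_hostname() -> str:
    # Resolve the fully-qualified domain name of the current host,
    # used above as a stable per-node identifier.
    return socket.getfqdn(socket.gethostname())
```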
@@ -322,6 +328,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunResult:
                 self._exit_barrier()
                 return run_result
             elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
+                # Record failure detection event
+                record_profiling_event(
+                    ProfilingEvent.FAILURE_DETECTED,
+                    node_id=self._rdzv_handler._this_node,
+                    rank=self._worker_group.group_rank,
+                )
+
                 if self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
@@ -347,6 +360,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunResult:
             num_nodes_waiting = rdzv_handler.num_nodes_waiting()
             group_rank = self._worker_group.group_rank
             if num_nodes_waiting > 0:
+                # Record failure detection event (new nodes waiting trigger a restart)
+                record_profiling_event(
+                    ProfilingEvent.FAILURE_DETECTED,
+                    node_id=self._rdzv_handler._this_node,
+                    rank=self._worker_group.group_rank,
+                )
+
                 logger.info(
                     "[%s] Detected %s "
                     "new nodes from group_rank=%s; "
@@ -587,6 +607,13 @@ async def send_close_msg():
 
         self._shutdown(timeout=self._workers_stop_timeout)
 
+        # Record worker termination event after shutdown is complete
+        record_profiling_event(
+            ProfilingEvent.WORKER_TERMINATED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
     # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
     # `torch.distributed.elastic.metrics.prof`.
     @prof
@@ -596,6 +623,13 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
         assert store is not None
         restart_count = spec.max_restarts - self._remaining_restarts
 
+        # Record worker start event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_STARTED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         use_agent_store = spec.rdzv_handler.use_agent_store
 
         args: Dict[int, Tuple] = {}
@@ -667,8 +701,16 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
 
         self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}
 
+        # Record worker start completion event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_COMPLETED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         return self._pcontext.pids()
 
+
     def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
         if self._worker_watchdog is not None:
             self._worker_watchdog.stop()
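Together, the events recorded in this file bracket the full restart path: `FAILURE_DETECTED`, then `WORKER_TERMINATED` once shutdown completes, then `WORKER_START_STARTED` and `WORKER_START_COMPLETED` around relaunch. Assuming the profiling backend can export events as timestamped records (a hypothetical format; the storage side is not shown in this PR), per-node restart latency could be derived like this:

```python
# Hypothetical post-processing sketch; assumes profiling events were
# exported as (event_name, node_id, rank, timestamp) tuples. The real
# nvidia_resiliency_ext storage format is not shown in this diff.
from typing import Dict, List, Tuple


def restart_latency(events: List[Tuple[str, str, int, float]]) -> Dict[str, float]:
    """Seconds from failure detection to worker start completion, per node."""
    detected: Dict[str, float] = {}
    latency: Dict[str, float] = {}
    for name, node_id, _rank, ts in sorted(events, key=lambda e: e[-1]):
        if name == "FAILURE_DETECTED":
            detected.setdefault(node_id, ts)  # keep earliest detection
        elif name == "WORKER_START_COMPLETED" and node_id in detected:
            latency[node_id] = ts - detected.pop(node_id)
    return latency
```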
@@ -1054,6 +1096,7 @@ def launch_agent(
             )
 
         logger.info(f"Agent.run() is OK. No failures in the result. {result=}")
+
         return result.return_values
     except UnhealthyNodeException as e:
         # do not shutdown rendezvous when an unhealthy node is leaving
@@ -1987,6 +2030,10 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
 
     rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)
 
+    # Add use_libuv=False for c10d backend
+    if args.rdzv_backend == 'c10d':
+        rdzv_configs['use_libuv'] = False
+
     if args.rdzv_backend == "static":
         rdzv_configs["rank"] = args.node_rank
 
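For reference, `use_libuv` is a real `TCPStore` option in recent PyTorch releases (the libuv-based store server became the default in PyTorch 2.4); pinning it to `False` keeps the c10d rendezvous on the legacy store server that the monkey patch above targets. At the store level the flag amounts to the following (host and port are illustrative only):

```python
from torch.distributed import TCPStore

# Legacy (non-libuv) store server, matching the use_libuv=False
# forced by the config above.
store = TCPStore("localhost", 29500, world_size=1, is_master=True, use_libuv=False)
```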