    write_obj_to_ipc_stream,
)
from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
+from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event

# Deprecation warning for FT_LAUNCHER_LOGLEVEL
if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:
@@ -142,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.
-
+
    Note: If your training script uses the nvrx logger, make sure to call
    ``setup_logger()`` at the beginning of your training function to ensure
    the logger is properly set up in each subprocess.
@@ -183,12 +184,12 @@ def trainer(args) -> str:
            # Ensure nvrx logger is set up in this subprocess
            from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
            setup_logger()
-
+
            # Use the nvrx logger
            import logging
            logger = logging.getLogger(LogConfig.name)
            logger.info("Training started")
-
+
            return "do train"

        def main():
@@ -255,6 +256,7 @@ def __init__(
        self._ft_cfg = fault_tol_cfg
        self._children_pgids: Set[int] = set()
        self._restart_policy = restart_policy
+        self._node_id = self._get_fq_hostname()

    DEFAULT_ROLE = "default"  # FIXME

@@ -326,6 +328,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
+                # Record failure detection event
+                record_profiling_event(
+                    ProfilingEvent.FAILURE_DETECTED,
+                    node_id=self._rdzv_handler._this_node,
+                    rank=self._worker_group.group_rank,
+                )
+
                if self._remaining_restarts > 0:
                    logger.info(
                        "[%s] Worker group %s. "
@@ -351,6 +360,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
+                    # Record failure detection event
+                    record_profiling_event(
+                        ProfilingEvent.FAILURE_DETECTED,
+                        node_id=self._rdzv_handler._this_node,
+                        rank=self._worker_group.group_rank,
+                    )
+
                    logger.info(
                        "[%s] Detected %s "
                        "new nodes from group_rank=%s; "
@@ -591,6 +607,13 @@ async def send_close_msg():

        self._shutdown(timeout=self._workers_stop_timeout)

+        # Record worker termination event after shutdown is complete
+        record_profiling_event(
+            ProfilingEvent.WORKER_TERMINATED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
@@ -600,6 +623,13 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        assert store is not None
        restart_count = spec.max_restarts - self._remaining_restarts

+        # Record worker start event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_STARTED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
        use_agent_store = spec.rdzv_handler.use_agent_store

        args: Dict[int, Tuple] = {}
@@ -671,8 +701,16 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:

        self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}

+        # Record worker start completion event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_COMPLETED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
        return self._pcontext.pids()

+
    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
@@ -1058,6 +1096,7 @@ def launch_agent(
            )

        logger.info(f"Agent.run() is OK. No failures in the result. {result=}")
+
        return result.return_values
    except UnhealthyNodeException as e:
        # do not shutdown rendezvous when an unhealthy node is leaving