@@ -25,6 +25,7 @@
 import socket
 import sys
 import tempfile
+import threading
 import time
 import uuid
 import warnings
@@ -69,9 +70,11 @@
     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
+from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     patched_method,
+    read_obj_from_ipc_stream,
     terminate_mp_processes,
     write_obj_to_ipc_stream,
 )
@@ -192,6 +195,15 @@ def _wrap_entrypoint_with_numactl(
 # https://github.com/pytorch/pytorch/blob/release/2.3/torch/distributed/elastic/agent/server/local_elastic_agent.py


+@dataclass
+class RankMonitorState:
+    """State for a single rank monitor process and its IPC connections."""
+    process: Any  # multiprocessing.Process
+    socket_path: str = ""
+    listener_thread: Optional[threading.Thread] = None
+    stop_event: Optional[threading.Event] = None
+
+
 class LocalElasticAgent(SimpleElasticAgent):
     """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.

@@ -317,8 +329,15 @@ def __init__(
         self._term_timeout = term_timeout
         self._workers_stop_timeout = workers_stop_timeout
         self._is_store_host = is_store_host
-        self._local_rank_to_rmon: Dict[int, Any] = dict()
+        # Rank monitor state (process, IPC connections, listener tasks) per local rank
+        self._rank_monitors: Dict[int, RankMonitorState] = dict()
         self._ft_cfg = fault_tol_cfg
+        # Centralized progress tracking (always instantiated, active only if configured)
+        self._progress_tracker = TrainingProgressTracker(
+            min_progress_iterations=fault_tol_cfg.min_progress_iterations,
+            max_no_progress_restarts=fault_tol_cfg.max_no_progress_restarts,
+        )
+        self._rank_iterations: Dict[int, int] = dict()  # Track max iteration per rank
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
         self._node_id = self._get_fq_hostname()
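
Note: TrainingProgressTracker is imported from fault_tolerance/progress_tracker.py, which is not part of this diff. As a reading aid, the following is a minimal sketch of the interface the launcher relies on; the class name, field names, and exact counting semantics are assumptions inferred from the calls made in this commit (update_iteration, analyze_previous_cycle, should_terminate_early), not the shipped implementation.

# Hypothetical sketch only -- the real tracker lives in
# nvidia_resiliency_ext/fault_tolerance/progress_tracker.py and may differ.
from dataclasses import dataclass

@dataclass
class TrainingProgressTrackerSketch:
    min_progress_iterations: int = 200   # iterations a cycle must advance to count as progress
    max_no_progress_restarts: int = 3    # consecutive no-progress cycles tolerated (<= 0 disables)
    _current_iteration: int = 0          # highest iteration reported in the current cycle
    _last_cycle_iteration: int = 0       # iteration reached when the previous cycle ended
    _no_progress_cycles: int = 0         # consecutive cycles without enough progress

    def update_iteration(self, iteration: int) -> None:
        # Fed with the cluster-wide minimum iteration (see _update_progress_iteration below).
        self._current_iteration = max(self._current_iteration, iteration)

    def analyze_previous_cycle(self) -> None:
        # Called once per observed worker-group failure, before deciding on a restart.
        advanced = self._current_iteration - self._last_cycle_iteration
        if advanced >= self.min_progress_iterations:
            self._no_progress_cycles = 0
        else:
            self._no_progress_cycles += 1
        self._last_cycle_iteration = self._current_iteration

    def should_terminate_early(self) -> bool:
        # Inactive when max_no_progress_restarts <= 0, matching the CLI help added below.
        if self.max_no_progress_restarts <= 0:
            return False
        return self._no_progress_cycles >= self.max_no_progress_restarts
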
@@ -367,7 +386,7 @@ def _open_rendezvous_for_restart(self):
                 self._worker_group.group_rank if self._worker_group else "N/A"
             )
         except Exception as e:
-            logger.warning(f"Failed to open rendezvous: {e}")
+            logger.error(f"Failed to open rendezvous: {e}")
         # For legacy rendezvous, no action needed - it uses different mechanism

     def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
@@ -420,7 +439,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     rank=self._worker_group.group_rank,
                 )

-                if self._remaining_restarts > 0:
+                self._progress_tracker.analyze_previous_cycle()
+                should_terminate_early = self._progress_tracker.should_terminate_early()
+
+                if should_terminate_early:
+                    logger.error(
+                        "[%s] Progress tracker detected no progress across restarts. "
+                        "No more restarts will be attempted.",
+                        role
+                    )
+                elif self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
                         "%s/%s attempts left;"
@@ -434,14 +462,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     # Open rendezvous before restarting (for barrier-based rendezvous)
                     self._open_rendezvous_for_restart()
                     self._restart_workers(self._worker_group)
-                else:
-                    self._stop_workers(self._worker_group)
-                    self._worker_group.state = WorkerState.FAILED
-                    # to preserve torchrun's behaviour, should not return WorkerState.UNHEALTHY.
-                    # we use WorkerState.UNHEALTHY to denote a worker group that is still
-                    # running but has some failed workers. torchrun does not use WorkerState.UNHEALTHY
-                    run_result = self._monitor_workers(self._worker_group)
-                    return run_result
+                    continue  # Continue monitoring after restart
+
+                # No more restarts (either exhausted or early termination)
+                self._stop_workers(self._worker_group)
+                self._worker_group.state = WorkerState.FAILED
+                run_result = self._monitor_workers(self._worker_group)
+                return run_result
             elif state == WorkerState.HEALTHY:
                 # Check for cluster-wide issues: unhealthy nodes or new nodes waiting
                 unhealthy_count = self._check_cluster_unhealthy_count()
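
A scenario sketch of the new decision flow, reusing the hypothetical TrainingProgressTrackerSketch from above (again an assumption, not launcher code): with the defaults of 200 iterations and 3 restarts, three consecutive failure cycles that each advance fewer than 200 iterations flip should_terminate_early() to True, so the branch above stops the workers even while _remaining_restarts is still positive.

tracker = TrainingProgressTrackerSketch(min_progress_iterations=200, max_no_progress_restarts=3)

# Cluster-min iteration reached in three successive cycles: 150, then 180, then 190.
for cycle, reached in enumerate([150, 180, 190], start=1):
    tracker.update_iteration(reached)   # normally fed by the rank-monitor listener thread
    tracker.analyze_previous_cycle()    # a worker-group failure was just observed
    print(cycle, tracker.should_terminate_early())
# Prints: 1 False, 2 False, 3 True -> no further restart is attempted after the third failure.
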
@@ -579,31 +606,138 @@ def get_rank_mon_socket_path(self, local_rank):

     def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
         fork_mp_ctx = torch.multiprocessing.get_context("fork")
+        new_monitors = []  # Track newly started monitors
+
         for worker_env in envs.values():
             # Start rank monitors if not already started
             # Each rank (re)connects to its rank monitor when it starts
             # Monitor of the local rank0 on the store hosting node is the restarter logger
             local_rank = int(worker_env['LOCAL_RANK'])
             is_restarter_logger = self._is_store_host and local_rank == 0
             rmon_ipc_socket = worker_env[FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR]
-            if local_rank not in self._local_rank_to_rmon:
-                self._local_rank_to_rmon[local_rank] = RankMonitorServer.run_in_subprocess(
+            if local_rank not in self._rank_monitors:
+                rmon_proc = RankMonitorServer.run_in_subprocess(
                     cfg=self._ft_cfg,
                     ipc_socket_path=rmon_ipc_socket,
                     is_restarter_logger=is_restarter_logger,
                     mp_ctx=fork_mp_ctx,
                     env=worker_env,
                 )
+                self._rank_monitors[local_rank] = RankMonitorState(process=rmon_proc)
+                new_monitors.append((local_rank, rmon_proc))
+
+        # Establish bidirectional IPC connections to new rank monitors
+        if new_monitors:
+            async def connect_all():
+                await asyncio.gather(
+                    *[self._connect_to_rank_monitor(lr, rmon) for lr, rmon in new_monitors]
+                )
+            asyncio.run(connect_all())
     def shutdown_rank_monitors(self):
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
+        # Stop listener threads and terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
+            # Signal listener thread to stop
+            if state.stop_event:
+                state.stop_event.set()
+
+            # Wait for listener thread to finish (will close connection gracefully)
+            if state.listener_thread and state.listener_thread.is_alive():
+                state.listener_thread.join(timeout=2.0)
+
+        # Terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
             with contextlib.suppress(Exception):
-                rmon_proc.terminate()
+                state.process.terminate()
             with contextlib.suppress(Exception):
-                rmon_proc.join()
+                state.process.join()
             with contextlib.suppress(Exception):
                 os.unlink(self.get_rank_mon_socket_path(local_rank))

+    async def _connect_to_rank_monitor(self, local_rank: int, rmon_proc) -> None:
+        """Start listener thread for rank monitor bidirectional IPC.
+
+        Note: This is called after rank_monitor_ready_event is set, which guarantees
+        the socket file already exists. The actual connection is created inside the
+        background thread's event loop to avoid event loop conflicts.
+        """
+        launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
+
+        state = self._rank_monitors[local_rank]
+        state.socket_path = launcher_to_rmon_socket
+        state.stop_event = threading.Event()
+
+        # Start listener thread (will create connection in its own event loop)
+        state.listener_thread = threading.Thread(
+            target=self._listen_to_rank_monitor_thread,
+            args=(local_rank, launcher_to_rmon_socket, state.stop_event),
+            daemon=True,
+            name=f"RankMonitor-{local_rank}-Listener"
+        )
+        state.listener_thread.start()
+
+    def _update_progress_iteration(self, local_rank: int, iteration: int):
+        """Update iteration for a specific rank and aggregate using MIN strategy."""
+        # Update this rank's max iteration
+        self._rank_iterations[local_rank] = max(
+            self._rank_iterations.get(local_rank, 0), iteration
+        )
+
+        # Use minimum across all ranks (most conservative - slowest rank determines progress)
+        min_iteration = min(self._rank_iterations.values()) if self._rank_iterations else 0
+        self._progress_tracker.update_iteration(min_iteration)
+
+        logger.debug(
+            f"Updated iteration for rank {local_rank}={iteration}, "
+            f"cluster min={min_iteration}, all_ranks={self._rank_iterations}"
+        )
+
+    def _listen_to_rank_monitor_thread(self, local_rank: int, socket_path: str, stop_event: threading.Event) -> None:
+        """Listen for messages from rank monitor in a background thread.
+
+        This runs in a separate thread with its own event loop to receive messages
+        from the rank monitor server. The connection is created in this thread's
+        event loop to avoid cross-loop conflicts.
+        """
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async def listen_loop():
+            try:
+                # Create connection in THIS thread's event loop
+                reader, writer = await asyncio.open_unix_connection(socket_path)
+
+                try:
+                    while not stop_event.is_set():
+                        # Use wait_for with timeout to allow checking stop_event periodically
+                        try:
+                            msg = await asyncio.wait_for(read_obj_from_ipc_stream(reader), timeout=1.0)
+                            if isinstance(msg, dict) and msg.get("type") == "iteration_update":
+                                # Handle iteration update from rank monitor
+                                iteration = msg["iteration"]
+                                self._update_progress_iteration(local_rank, iteration)
+                                # Note: Don't log every iteration update - too chatty during training
+                            else:
+                                logger.debug(f"Received message from rank monitor {local_rank}: {msg}")
+                        except asyncio.TimeoutError:
+                            # Timeout is expected, just check stop_event and continue
+                            continue
+                finally:
+                    # Clean up connection
+                    writer.close()
+                    await writer.wait_closed()
+            except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
+                logger.debug(f"Rank monitor {local_rank} connection closed")
+            except Exception as e:
+                if not stop_event.is_set():
+                    logger.error(f"Error listening to rank monitor {local_rank}: {e}")
+
+        try:
+            loop.run_until_complete(listen_loop())
+        finally:
+            loop.close()
+
     def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
         watchdog_enabled = os.getenv(enable_watchdog_env_name)
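
A standalone illustration of the aggregation in _update_progress_iteration above (toy code, not part of the commit): each rank's reports are reduced with max() so a late or duplicate message cannot move that rank backwards, while the value handed to the progress tracker is the min() across ranks, so the slowest rank defines cluster progress.

# Toy reproduction of the per-rank max / cross-rank min aggregation (not launcher code).
rank_iterations = {}

def record(local_rank, iteration):
    rank_iterations[local_rank] = max(rank_iterations.get(local_rank, 0), iteration)
    return min(rank_iterations.values())

assert record(0, 120) == 120   # only rank 0 has reported so far
assert record(1, 95) == 95     # rank 1 lags, so cluster progress is 95
assert record(0, 130) == 95    # rank 0 advancing does not change the minimum
assert record(1, 140) == 130   # once rank 1 catches up, rank 0's 130 becomes the floor
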
@@ -682,23 +816,7 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None:

         logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.")

-        # Send close message to rank monitors
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
-            try:
-                launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
-                if os.path.exists(launcher_to_rmon_socket):
-                    async def send_close_msg():
-                        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
-                        try:
-                            await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
-                        finally:
-                            writer.close()
-                            await writer.wait_closed()
-
-                    asyncio.run(send_close_msg())
-            except Exception as e:
-                logger.warning(f"Failed to send close message to rank monitor {local_rank}: {e}")
-
+        # Rank monitors will detect worker shutdown when worker processes disconnect
         self._shutdown(timeout=self._workers_stop_timeout)

         # Record worker termination event after shutdown is complete
@@ -2074,6 +2192,32 @@ def get_args_parser() -> ArgumentParser:
         help="Do not raise an error if there is no Fault Tolerance pkg config provided, just use default settings.",
     )

+    #
+    # Progress tracking arguments
+    #
+
+    parser.add_argument(
+        "--ft-max-no-progress-restarts",
+        "--ft-max_no_progress_restarts",
+        type=int,
+        default=3,
+        dest="ft_max_no_progress_restarts",
+        help="Maximum consecutive restarts without progress before early termination. "
+        "Progress tracking is enabled when this value > 0. "
+        "Set to 0 or -1 to disable progress tracking. "
+        "Default: 3 (progress tracking enabled).",
+    )
+
+    parser.add_argument(
+        "--ft-min-progress-iterations",
+        "--ft-min_progress_iterations",
+        type=int,
+        default=200,
+        dest="ft_min_progress_iterations",
+        help="Minimum iterations required to consider a restart as making progress. "
+        "Default: 200.",
+    )
+
     #
     # Positional arguments.
     #
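
For reference, a standalone sketch of how the two dual-spelling options added above parse; this isolates just those flags and is not part of the commit (the real parser defines many more arguments).

from argparse import ArgumentParser

p = ArgumentParser()
p.add_argument("--ft-max-no-progress-restarts", "--ft-max_no_progress_restarts",
               type=int, default=3, dest="ft_max_no_progress_restarts")
p.add_argument("--ft-min-progress-iterations", "--ft-min_progress_iterations",
               type=int, default=200, dest="ft_min_progress_iterations")

args = p.parse_args(["--ft-max-no-progress-restarts", "5"])
assert args.ft_max_no_progress_restarts == 5    # overridden on the command line
assert args.ft_min_progress_iterations == 200   # default; per the help text, 0 or -1 disables tracking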