@@ -25,6 +25,7 @@
 import socket
 import sys
 import tempfile
+import threading
 import time
 import uuid
 import warnings
@@ -69,11 +70,12 @@
     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
+from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     patched_method,
+    read_obj_from_ipc_stream,
     terminate_mp_processes,
-    write_obj_to_ipc_stream,
 )
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
 from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event
@@ -192,6 +194,15 @@ def _wrap_entrypoint_with_numactl(
 # https://github.com/pytorch/pytorch/blob/release/2.3/torch/distributed/elastic/agent/server/local_elastic_agent.py


+@dataclass
+class RankMonitorState:
+    """State for a single rank monitor process and its IPC connections."""
+    process: Any  # multiprocessing.Process
+    socket_path: str = ""
+    listener_thread: Optional[threading.Thread] = None
+    stop_event: Optional[threading.Event] = None
+
+
 class LocalElasticAgent(SimpleElasticAgent):
     """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.

@@ -317,8 +328,15 @@ def __init__(
         self._term_timeout = term_timeout
         self._workers_stop_timeout = workers_stop_timeout
         self._is_store_host = is_store_host
-        self._local_rank_to_rmon: Dict[int, Any] = dict()
+        # Rank monitor state (process, IPC connection, listener thread) per local rank
+        self._rank_monitors: Dict[int, RankMonitorState] = dict()
         self._ft_cfg = fault_tol_cfg
+        # Centralized progress tracking (always instantiated; active only if configured)
+        self._progress_tracker = TrainingProgressTracker(
+            min_progress_iterations=fault_tol_cfg.min_progress_iterations,
+            max_no_progress_restarts=fault_tol_cfg.max_no_progress_restarts,
+        )
+        self._rank_iterations: Dict[int, int] = dict()  # Highest iteration seen per local rank
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
         self._node_id = self._get_fq_hostname()
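Note: `TrainingProgressTracker` is imported above from `fault_tolerance/progress_tracker.py`, which is not part of this diff. A minimal sketch of the interface this `__init__` and the restart logic below rely on; the method names, constructor arguments, and disable semantics come from this diff and the CLI help further down, while the internals are hypothetical:

```python
# Hypothetical sketch of the TrainingProgressTracker interface assumed by this
# diff; the real class lives in fault_tolerance/progress_tracker.py (not shown).
from dataclasses import dataclass, field


@dataclass
class TrainingProgressTracker:
    min_progress_iterations: int = 200   # iterations a cycle must advance to count as progress
    max_no_progress_restarts: int = 3    # consecutive no-progress cycles tolerated; <= 0 disables
    _cycle_start: int = field(init=False, default=0)
    _current: int = field(init=False, default=0)
    _no_progress_restarts: int = field(init=False, default=0)

    def update_iteration(self, iteration: int) -> None:
        # Called with the MIN-across-ranks iteration (see _update_progress_iteration below).
        self._current = max(self._current, iteration)

    def analyze_previous_cycle(self) -> None:
        # Called once per failure, before deciding whether to restart.
        progressed = self._current - self._cycle_start >= self.min_progress_iterations
        self._no_progress_restarts = 0 if progressed else self._no_progress_restarts + 1
        self._cycle_start = self._current

    def should_terminate_early(self) -> bool:
        # Tracking is disabled when max_no_progress_restarts <= 0.
        return (
            self.max_no_progress_restarts > 0
            and self._no_progress_restarts >= self.max_no_progress_restarts
        )
```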
@@ -367,7 +385,7 @@ def _open_rendezvous_for_restart(self):
                 self._worker_group.group_rank if self._worker_group else "N/A"
             )
         except Exception as e:
-            logger.warning(f"Failed to open rendezvous: {e}")
+            logger.error(f"Failed to open rendezvous: {e}")
         # For legacy rendezvous, no action is needed - it uses a different mechanism

     def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
@@ -420,7 +438,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     rank=self._worker_group.group_rank,
                 )

-                if self._remaining_restarts > 0:
+                self._progress_tracker.analyze_previous_cycle()
+                should_terminate_early = self._progress_tracker.should_terminate_early()
+
+                if should_terminate_early:
+                    logger.error(
+                        "[%s] Progress tracker detected no progress across restarts. "
+                        "No more restarts will be attempted.",
+                        role,
+                    )
+                elif self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
                         "%s/%s attempts left;"
@@ -434,14 +461,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     # Open rendezvous before restarting (for barrier-based rendezvous)
                     self._open_rendezvous_for_restart()
                     self._restart_workers(self._worker_group)
-                else:
-                    self._stop_workers(self._worker_group)
-                    self._worker_group.state = WorkerState.FAILED
-                    # to preserve torchrun's behaviour, should not return WorkerState.UNHEALTHY.
-                    # we use WorkerState.UNHEALTHY to denote a worker group that is still
-                    # running but has some failed workers. torchrun does not use WorkerState.UNHEALTHY
-                    run_result = self._monitor_workers(self._worker_group)
-                    return run_result
+                    continue  # Continue monitoring after restart
+
+                # No more restarts (either exhausted or early termination)
+                self._stop_workers(self._worker_group)
+                self._worker_group.state = WorkerState.FAILED
+                run_result = self._monitor_workers(self._worker_group)
+                return run_result
             elif state == WorkerState.HEALTHY:
                 # Check for cluster-wide issues: unhealthy nodes or new nodes waiting
                 unhealthy_count = self._check_cluster_unhealthy_count()
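A hypothetical walk-through of the decision flow above, using the import path and API calls that appear in this diff. The expected outputs assume the semantics described in the CLI help below: a cycle counts as progress only if it advances at least `min_progress_iterations`, and `max_no_progress_restarts` consecutive no-progress cycles trigger early termination.

```python
# Walk-through with the defaults registered by this diff (200 / 3).
from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker

tracker = TrainingProgressTracker(min_progress_iterations=200, max_no_progress_restarts=3)

for cycle in range(1, 4):
    tracker.update_iteration(50)       # every cycle dies before reaching 200 iterations
    tracker.analyze_previous_cycle()   # called where this hunk calls it: on worker failure
    print(cycle, tracker.should_terminate_early())
# Expected (assuming the semantics above): 1 False, 2 False, 3 True -> no further restart
```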
@@ -579,31 +605,132 @@ def get_rank_mon_socket_path(self, local_rank):

     def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
         fork_mp_ctx = torch.multiprocessing.get_context("fork")
+        new_monitors = []  # Track newly started monitors
+
         for worker_env in envs.values():
             # Start rank monitors if not already started
             # Each rank (re)connects to its rank monitor when it starts
             # The monitor of local rank 0 on the store-hosting node is the restarter logger
             local_rank = int(worker_env['LOCAL_RANK'])
             is_restarter_logger = self._is_store_host and local_rank == 0
             rmon_ipc_socket = worker_env[FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR]
-            if local_rank not in self._local_rank_to_rmon:
-                self._local_rank_to_rmon[local_rank] = RankMonitorServer.run_in_subprocess(
+            if local_rank not in self._rank_monitors:
+                rmon_proc = RankMonitorServer.run_in_subprocess(
                     cfg=self._ft_cfg,
                     ipc_socket_path=rmon_ipc_socket,
                     is_restarter_logger=is_restarter_logger,
                     mp_ctx=fork_mp_ctx,
                     env=worker_env,
                 )
+                self._rank_monitors[local_rank] = RankMonitorState(process=rmon_proc)
+                new_monitors.append((local_rank, rmon_proc))
+
+        # Establish bidirectional IPC connections to the new rank monitors
+        if new_monitors:
+            async def connect_all():
+                await asyncio.gather(
+                    *[self._connect_to_rank_monitor(lr, rmon) for lr, rmon in new_monitors]
+                )
+            asyncio.run(connect_all())

     def shutdown_rank_monitors(self):
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
+        # Stop listener threads, then terminate the rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
+            # Signal the listener thread to stop
+            if state.stop_event:
+                state.stop_event.set()
+
+            # Wait for the listener thread to finish (it closes its connection gracefully)
+            if state.listener_thread and state.listener_thread.is_alive():
+                state.listener_thread.join(timeout=2.0)
+
+        # Terminate the rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
             with contextlib.suppress(Exception):
-                rmon_proc.terminate()
+                state.process.terminate()
             with contextlib.suppress(Exception):
-                rmon_proc.join()
+                state.process.join()
             with contextlib.suppress(Exception):
                 os.unlink(self.get_rank_mon_socket_path(local_rank))

+    async def _connect_to_rank_monitor(self, local_rank: int, rmon_proc) -> None:
+        """Start the listener thread for a rank monitor's bidirectional IPC.
+
+        Note: this is called after rank_monitor_ready_event is set, which guarantees
+        that the socket file already exists. The actual connection is created inside
+        the background thread's event loop to avoid event loop conflicts.
+        """
+        launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
+
+        state = self._rank_monitors[local_rank]
+        state.socket_path = launcher_to_rmon_socket
+        state.stop_event = threading.Event()
+
+        # Start the listener thread (it creates the connection in its own event loop)
+        state.listener_thread = threading.Thread(
+            target=self._listen_to_rank_monitor_thread,
+            args=(local_rank, launcher_to_rmon_socket, state.stop_event),
+            daemon=True,
+            name=f"RankMonitor-{local_rank}-Listener",
+        )
+        state.listener_thread.start()
+
+    def _update_progress_iteration(self, local_rank: int, iteration: int):
+        """Update the iteration for a specific rank and aggregate using a MIN strategy."""
+        # Update this rank's highest observed iteration
+        self._rank_iterations[local_rank] = max(
+            self._rank_iterations.get(local_rank, 0), iteration
+        )
+
+        # Use the minimum across all ranks (most conservative: the slowest rank determines progress)
+        min_iteration = min(self._rank_iterations.values()) if self._rank_iterations else 0
+        self._progress_tracker.update_iteration(min_iteration)
+
+    def _listen_to_rank_monitor_thread(self, local_rank: int, socket_path: str, stop_event: threading.Event) -> None:
+        """Listen for messages from the rank monitor in a background thread.
+
+        This runs in a separate thread with its own event loop to receive messages
+        from the rank monitor server. The connection is created in this thread's
+        event loop to avoid cross-loop conflicts.
+        """
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async def listen_loop():
+            try:
+                # Create the connection in THIS thread's event loop
+                reader, writer = await asyncio.open_unix_connection(socket_path)
+
+                try:
+                    while not stop_event.is_set():
+                        # Use wait_for with a timeout so stop_event is checked periodically
+                        try:
+                            msg = await asyncio.wait_for(read_obj_from_ipc_stream(reader), timeout=1.0)
+                            if isinstance(msg, dict) and msg.get("type") == "iteration_update":
+                                # Handle an iteration update from the rank monitor
+                                iteration = msg["iteration"]
+                                self._update_progress_iteration(local_rank, iteration)
+                            else:
+                                logger.debug(f"Received message from rank monitor {local_rank}: {msg}")
+                        except asyncio.TimeoutError:
+                            # Timeouts are expected; just check stop_event and continue
+                            continue
+                finally:
+                    # Clean up the connection
+                    writer.close()
+                    await writer.wait_closed()
+            except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
+                logger.debug(f"Rank monitor {local_rank} connection closed")
+            except Exception as e:
+                if not stop_event.is_set():
+                    logger.error(f"Error listening to rank monitor {local_rank}: {e}")
+
+        try:
+            loop.run_until_complete(listen_loop())
+        finally:
+            loop.close()
+
     def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
         watchdog_enabled = os.getenv(enable_watchdog_env_name)
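For context on the wire protocol: the listener above connects as a client to `_ft_launcher{pid}_to_rmon.socket` and deserializes one object at a time with `read_obj_from_ipc_stream`, treating dicts of the form `{"type": "iteration_update", "iteration": n}` specially. The monitor-side counterpart lives in `rank_monitor_server.py` and is not part of this hunk; a rough sketch of what it presumably does, assuming the monitor serves the socket and pushes updates with `write_obj_to_ipc_stream` (the writer counterpart this diff removes from the launcher's imports, since only the monitor side still needs it):

```python
# Hypothetical sketch of the rank-monitor side of the iteration-update protocol.
# The function name, queue, and server setup are illustrative, not the real code.
import asyncio
import os
import tempfile

from nvidia_resiliency_ext.fault_tolerance.utils import write_obj_to_ipc_stream


async def serve_iteration_updates(iteration_queue: asyncio.Queue) -> None:
    # Same path the launcher's listener thread connects to (keyed by monitor PID)
    socket_path = f"{tempfile.gettempdir()}/_ft_launcher{os.getpid()}_to_rmon.socket"

    async def on_launcher_connected(reader, writer):
        # Forward every iteration reported by the local worker to the launcher
        while True:
            iteration = await iteration_queue.get()
            await write_obj_to_ipc_stream(
                {"type": "iteration_update", "iteration": iteration}, writer
            )

    server = await asyncio.start_unix_server(on_launcher_connected, path=socket_path)
    async with server:
        await server.serve_forever()
```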
@@ -682,23 +809,7 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None:

         logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.")

-        # Send close message to rank monitors
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
-            try:
-                launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
-                if os.path.exists(launcher_to_rmon_socket):
-                    async def send_close_msg():
-                        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
-                        try:
-                            await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
-                        finally:
-                            writer.close()
-                            await writer.wait_closed()
-
-                    asyncio.run(send_close_msg())
-            except Exception as e:
-                logger.warning(f"Failed to send close message to rank monitor {local_rank}: {e}")
-
+        # Rank monitors will detect worker shutdown when the worker processes disconnect
         self._shutdown(timeout=self._workers_stop_timeout)

         # Record worker termination event after shutdown is complete
@@ -2074,6 +2185,32 @@ def get_args_parser() -> ArgumentParser:
         help="Do not raise an error if there is no Fault Tolerance pkg config provided, just use default settings.",
     )

+    #
+    # Progress tracking arguments
+    #
+
+    parser.add_argument(
+        "--ft-max-no-progress-restarts",
+        "--ft-max_no_progress_restarts",
+        type=int,
+        default=3,
+        dest="ft_max_no_progress_restarts",
+        help="Maximum number of consecutive restarts without progress before early termination. "
+        "Progress tracking is enabled when this value is > 0. "
+        "Set it to 0 or -1 to disable progress tracking. "
+        "Default: 3 (progress tracking enabled).",
+    )
+
+    parser.add_argument(
+        "--ft-min-progress-iterations",
+        "--ft-min_progress_iterations",
+        type=int,
+        default=200,
+        dest="ft_min_progress_iterations",
+        help="Minimum number of iterations required to consider a restart as making progress. "
+        "Default: 200.",
+    )
+
     #
     # Positional arguments.
     #
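A quick sanity check of the new flags, assuming `get_args_parser()` is imported from this launcher module; the defaults and `dest` names come straight from the hunk above:

```python
# get_default() looks arguments up by their dest name.
parser = get_args_parser()
assert parser.get_default("ft_max_no_progress_restarts") == 3
assert parser.get_default("ft_min_progress_iterations") == 200

# Per the help text, passing --ft-max-no-progress-restarts 0 (or -1)
# disables progress tracking entirely.
```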