     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
+from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     patched_method,
+    read_obj_from_ipc_stream,
     terminate_mp_processes,
     write_obj_to_ipc_stream,
 )
@@ -192,6 +194,15 @@ def _wrap_entrypoint_with_numactl(
 # https://github.com/pytorch/pytorch/blob/release/2.3/torch/distributed/elastic/agent/server/local_elastic_agent.py
 
 
+@dataclass
+class RankMonitorState:
+    """State for a single rank monitor process and its IPC connections."""
+    process: Any  # multiprocessing.Process
+    reader: Optional[asyncio.StreamReader] = None
+    writer: Optional[asyncio.StreamWriter] = None
+    listener_task: Optional[asyncio.Task] = None
+
+
 class LocalElasticAgent(SimpleElasticAgent):
     """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.
 
@@ -317,8 +328,15 @@ def __init__(
         self._term_timeout = term_timeout
         self._workers_stop_timeout = workers_stop_timeout
         self._is_store_host = is_store_host
-        self._local_rank_to_rmon: Dict[int, Any] = dict()
+        # Rank monitor state (process, IPC connections, listener tasks) per local rank
+        self._rank_monitors: Dict[int, RankMonitorState] = dict()
         self._ft_cfg = fault_tol_cfg
+        # Centralized progress tracking (always instantiated, active only if configured)
+        self._progress_tracker = TrainingProgressTracker(
+            min_progress_iterations=fault_tol_cfg.min_progress_iterations,
+            max_no_progress_restarts=fault_tol_cfg.max_no_progress_restarts,
+        )
+        self._rank_iterations: Dict[int, int] = dict()  # Track max iteration per rank
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
         self._node_id = self._get_fq_hostname()
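
A quick illustration (not part of the diff) of the aggregation that `_rank_iterations` feeds into: `_update_progress_iteration` further down keeps a running max per local rank and reports the minimum across ranks to the tracker, so the slowest rank defines cluster progress. Values below are hypothetical.

# Illustration only: per-rank running max, then cluster-wide min (slowest rank gates progress).
rank_iterations = {}

def record(local_rank, iteration):
    rank_iterations[local_rank] = max(rank_iterations.get(local_rank, 0), iteration)
    return min(rank_iterations.values())

record(0, 250)                  # -> 250 (only rank 0 has reported so far)
record(1, 180)                  # -> 180 (rank 1 lags, the cluster value drops)
record(1, 170)                  # -> 180 (stale update absorbed by the running max)
assert record(0, 300) == 180    # rank 0 advances, but the slowest rank still gates progress
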
@@ -367,7 +385,7 @@ def _open_rendezvous_for_restart(self):
                 self._worker_group.group_rank if self._worker_group else "N/A"
             )
         except Exception as e:
-            logger.warning(f"Failed to open rendezvous: {e}")
+            logger.error(f"Failed to open rendezvous: {e}")
         # For legacy rendezvous, no action needed - it uses different mechanism
 
     def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
@@ -420,7 +438,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     rank=self._worker_group.group_rank,
                 )
 
-                if self._remaining_restarts > 0:
+                self._progress_tracker.analyze_previous_cycle()
+                should_terminate_early = self._progress_tracker.should_terminate_early()
+
+                if should_terminate_early:
+                    logger.error(
+                        "[%s] Progress tracker detected no progress across restarts. "
+                        "No more restarts will be attempted.",
+                        role
+                    )
+                elif self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
                         "%s/%s attempts left;"
@@ -434,14 +461,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     # Open rendezvous before restarting (for barrier-based rendezvous)
                     self._open_rendezvous_for_restart()
                     self._restart_workers(self._worker_group)
-                else:
-                    self._stop_workers(self._worker_group)
-                    self._worker_group.state = WorkerState.FAILED
-                    # to preserve torchrun's behaviour, should not return WorkerState.UNHEALTHY.
-                    # we use WorkerState.UNHEALTHY to denote a worker group that is still
-                    # running but has some failed workers. torchrun does not use WorkerState.UNHEALTHY
-                    run_result = self._monitor_workers(self._worker_group)
-                    return run_result
+                    continue  # Continue monitoring after restart
+
+                # No more restarts (either exhausted or early termination)
+                self._stop_workers(self._worker_group)
+                self._worker_group.state = WorkerState.FAILED
+                run_result = self._monitor_workers(self._worker_group)
+                return run_result
             elif state == WorkerState.HEALTHY:
                 # Check for cluster-wide issues: unhealthy nodes or new nodes waiting
                 unhealthy_count = self._check_cluster_unhealthy_count()
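
`TrainingProgressTracker` itself is defined in `progress_tracker.py`, which is outside this diff. The sketch below is only an assumption of how its interface could behave, inferred from the call sites above and from the `--ft-min-progress-iterations` / `--ft-max-no-progress-restarts` help text added further down; only the method and parameter names are taken from the diff.

# Hedged sketch of the assumed TrainingProgressTracker semantics; the real
# implementation lives in nvidia_resiliency_ext.fault_tolerance.progress_tracker.
class TrainingProgressTrackerSketch:
    def __init__(self, min_progress_iterations: int, max_no_progress_restarts: int):
        self.min_progress_iterations = min_progress_iterations
        self.max_no_progress_restarts = max_no_progress_restarts
        self._iteration = 0             # latest cluster-wide iteration in the current cycle
        self._prev_iteration = 0        # iteration reached before the last restart
        self._no_progress_restarts = 0  # consecutive restarts without enough progress

    def update_iteration(self, iteration: int) -> None:
        # Fed by the launcher with the cluster-wide minimum iteration.
        self._iteration = max(self._iteration, iteration)

    def analyze_previous_cycle(self) -> None:
        # Called when a worker-group failure is handled: did the cycle that just
        # ended advance at least min_progress_iterations beyond the previous one?
        if self._iteration - self._prev_iteration >= self.min_progress_iterations:
            self._no_progress_restarts = 0
        else:
            self._no_progress_restarts += 1
        self._prev_iteration = self._iteration

    def should_terminate_early(self) -> bool:
        # Disabled when max_no_progress_restarts is 0 or negative (see the CLI help text).
        if self.max_no_progress_restarts <= 0:
            return False
        return self._no_progress_restarts >= self.max_no_progress_restarts
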
@@ -579,31 +605,113 @@ def get_rank_mon_socket_path(self, local_rank):
 
     def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
         fork_mp_ctx = torch.multiprocessing.get_context("fork")
+        new_monitors = []  # Track newly started monitors
+
         for worker_env in envs.values():
             # Start rank monitors if not already started
             # Each rank (re)connects to its rank monitor when it starts
             # Monitor of the local rank0 on the store hosting node is the restarter logger
             local_rank = int(worker_env['LOCAL_RANK'])
             is_restarter_logger = self._is_store_host and local_rank == 0
             rmon_ipc_socket = worker_env[FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR]
-            if local_rank not in self._local_rank_to_rmon:
-                self._local_rank_to_rmon[local_rank] = RankMonitorServer.run_in_subprocess(
+            if local_rank not in self._rank_monitors:
+                rmon_proc = RankMonitorServer.run_in_subprocess(
                     cfg=self._ft_cfg,
                     ipc_socket_path=rmon_ipc_socket,
                     is_restarter_logger=is_restarter_logger,
                     mp_ctx=fork_mp_ctx,
                     env=worker_env,
                 )
+                self._rank_monitors[local_rank] = RankMonitorState(process=rmon_proc)
+                new_monitors.append((local_rank, rmon_proc))
+
+        # Establish bidirectional IPC connections to new rank monitors
+        if new_monitors:
+            async def connect_all():
+                await asyncio.gather(
+                    *[self._connect_to_rank_monitor(lr, rmon) for lr, rmon in new_monitors]
+                )
+            asyncio.run(connect_all())
 
     def shutdown_rank_monitors(self):
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
+        # Stop listener tasks, close connections, and send shutdown messages
+        for local_rank, state in self._rank_monitors.items():
+            # Cancel listener task
+            if state.listener_task and not state.listener_task.done():
+                state.listener_task.cancel()
+
+            # Close connection with shutdown message
+            if state.writer:
+                try:
+                    async def send_shutdown():
+                        await write_obj_to_ipc_stream("shutdown", state.writer)
+                        state.writer.close()
+                        await state.writer.wait_closed()
+                    asyncio.run(send_shutdown())
+                except Exception as e:
+                    logger.debug(f"Error closing rank monitor connection for rank {local_rank}: {e}")
+
+        # Terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
             with contextlib.suppress(Exception):
-                rmon_proc.terminate()
+                state.process.terminate()
             with contextlib.suppress(Exception):
-                rmon_proc.join()
+                state.process.join()
             with contextlib.suppress(Exception):
                 os.unlink(self.get_rank_mon_socket_path(local_rank))
 
+    async def _connect_to_rank_monitor(self, local_rank: int, rmon_proc) -> None:
+        """Establish persistent connection to rank monitor for bidirectional IPC.
+
+        Note: This is called after rank_monitor_ready_event is set, which guarantees
+        the socket file already exists.
+        """
+        launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
+
+        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
+        state = self._rank_monitors[local_rank]
+        state.reader = reader
+        state.writer = writer
+        logger.debug(f"Connected to rank monitor {local_rank} at {launcher_to_rmon_socket}")
+
+        # Start listener task for this connection
+        state.listener_task = asyncio.create_task(self._listen_to_rank_monitor(local_rank, reader))
+
+    def _update_progress_iteration(self, local_rank: int, iteration: int):
+        """Update iteration for a specific rank and aggregate using MIN strategy."""
+        # Update this rank's max iteration
+        self._rank_iterations[local_rank] = max(
+            self._rank_iterations.get(local_rank, 0), iteration
+        )
+
+        # Use minimum across all ranks (most conservative - slowest rank determines progress)
+        min_iteration = min(self._rank_iterations.values()) if self._rank_iterations else 0
+        self._progress_tracker.update_iteration(min_iteration)
+
+        logger.debug(
+            f"Updated iteration for rank {local_rank}={iteration}, "
+            f"cluster min={min_iteration}, all_ranks={self._rank_iterations}"
+        )
+
+    async def _listen_to_rank_monitor(self, local_rank: int, reader) -> None:
+        """Listen for messages from rank monitor."""
+        try:
+            while True:
+                msg = await read_obj_from_ipc_stream(reader)
+                if isinstance(msg, dict) and msg.get("type") == "iteration_update":
+                    # Handle iteration update from rank monitor
+                    iteration = msg["iteration"]
+                    self._update_progress_iteration(local_rank, iteration)
+                    logger.debug(f"[Rank {local_rank}] Received iteration update: {iteration}")
+                else:
+                    logger.debug(f"Received message from rank monitor {local_rank}: {msg}")
+        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
+            logger.debug(f"Rank monitor {local_rank} connection closed")
+        except asyncio.CancelledError:
+            logger.debug(f"Listener for rank monitor {local_rank} cancelled")
+        except Exception as e:
+            logger.error(f"Error listening to rank monitor {local_rank}: {e}")
+
     def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
         watchdog_enabled = os.getenv(enable_watchdog_env_name)
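
The sender side of the `iteration_update` messages lives in the rank monitor process (`rank_monitor_server.py`), which this diff does not show. The following is only a hedged sketch of what such a sender could look like, reusing the `write_obj_to_ipc_stream` helper and the `_ft_launcher<pid>_to_rmon.socket` path format seen above; the queue and function names are hypothetical.

# Hedged sketch (not taken from rank_monitor_server.py): a rank monitor serving the
# launcher connection and pushing {"type": "iteration_update", ...} messages.
import asyncio
import os
import tempfile

from nvidia_resiliency_ext.fault_tolerance.utils import write_obj_to_ipc_stream

async def serve_launcher_connection(iteration_queue: asyncio.Queue) -> None:
    # Same path format the launcher connects to in _connect_to_rank_monitor.
    socket_path = f"{tempfile.gettempdir()}/_ft_launcher{os.getpid()}_to_rmon.socket"

    async def on_launcher_connected(reader, writer):
        # Forward each iteration observed from the local worker to the launcher.
        while True:
            iteration = await iteration_queue.get()
            await write_obj_to_ipc_stream({"type": "iteration_update", "iteration": iteration}, writer)

    server = await asyncio.start_unix_server(on_launcher_connected, path=socket_path)
    async with server:
        await server.serve_forever()
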
@@ -682,22 +790,27 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None:
 
         logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.")
 
-        # Send close message to rank monitors
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
-            try:
-                launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
-                if os.path.exists(launcher_to_rmon_socket):
-                    async def send_close_msg():
-                        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
-                        try:
-                            await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
-                        finally:
-                            writer.close()
-                            await writer.wait_closed()
-
-                    asyncio.run(send_close_msg())
-            except Exception as e:
-                logger.warning(f"Failed to send close message to rank monitor {local_rank}: {e}")
+        # Send close message to rank monitors through persistent connections
+        async def send_close_messages():
+            tasks = []
+            for local_rank, state in self._rank_monitors.items():
+                if state.writer:
+                    async def send_msg(writer, local_rank):
+                        await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
+                    tasks.append(send_msg(state.writer, local_rank))
+            if tasks:
+                # return_exceptions=True catches exceptions from send_msg, no need for try-except inside
+                results = await asyncio.gather(*tasks, return_exceptions=True)
+                for local_rank, result in zip([lr for lr, s in self._rank_monitors.items() if s.writer], results):
+                    if isinstance(result, Exception):
+                        # Connection errors during shutdown are expected (rank monitor may be dead)
+                        if isinstance(result, (ConnectionError, BrokenPipeError, OSError)):
+                            logger.debug(f"Rank monitor {local_rank} already disconnected: {result}")
+                        else:
+                            logger.warning(f"Unexpected error sending close message to rank monitor {local_rank}: {result}")
+
+        if self._rank_monitors:
+            asyncio.run(send_close_messages())
 
         self._shutdown(timeout=self._workers_stop_timeout)
 
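
The fan-out above relies on `asyncio.gather(..., return_exceptions=True)` returning failures in order alongside successes. A standalone illustration of that pattern (not launcher code; names are hypothetical):

# Standalone illustration: with return_exceptions=True, exceptions come back as
# in-order results instead of propagating and cancelling the other sends.
import asyncio

async def send(target: str) -> str:
    if target == "dead":
        raise ConnectionError(f"{target} is gone")
    return f"sent to {target}"

async def main() -> None:
    targets = ["rank0", "dead", "rank1"]
    results = await asyncio.gather(*(send(t) for t in targets), return_exceptions=True)
    for target, result in zip(targets, results):
        if isinstance(result, Exception):
            print(f"{target}: expected failure during shutdown: {result!r}")
        else:
            print(f"{target}: {result}")

asyncio.run(main())
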
@@ -2074,6 +2187,32 @@ def get_args_parser() -> ArgumentParser:
         help="Do not raise an error if there is no Fault Tolerance pkg config provided, just use default settings.",
     )
 
+    #
+    # Progress tracking arguments
+    #
+
+    parser.add_argument(
+        "--ft-max-no-progress-restarts",
+        "--ft-max_no_progress_restarts",
+        type=int,
+        default=3,
+        dest="ft_max_no_progress_restarts",
+        help="Maximum consecutive restarts without progress before early termination. "
+        "Progress tracking is enabled when this value > 0. "
+        "Set to 0 or -1 to disable progress tracking. "
+        "Default: 3 (progress tracking enabled).",
+    )
+
+    parser.add_argument(
+        "--ft-min-progress-iterations",
+        "--ft-min_progress_iterations",
+        type=int,
+        default=200,
+        dest="ft_min_progress_iterations",
+        help="Minimum iterations required to consider a restart as making progress. "
+        "Default: 200.",
+    )
+
     #
     # Positional arguments.
     #
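
A hedged example of passing the new flags on the command line, assuming the package's ft_launcher entry point; the training script and any other arguments are placeholders:

# Terminate after 3 consecutive restarts that each advance fewer than 200 iterations:
ft_launcher --ft-max-no-progress-restarts=3 --ft-min-progress-iterations=200 train.py

# Disable progress tracking entirely:
ft_launcher --ft-max-no-progress-restarts=0 train.py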