Commit 9b71268

feat: Add progress-based early termination for in-job restarts
Implement centralized training progress tracking to automatically terminate jobs when restarts fail to make meaningful progress, preventing wasted compute on jobs stuck in restart loops.

Problem:
--------
The in-job restart logic would keep restarting failed jobs up to max_restarts regardless of whether the restarts were effective. Jobs that crash repeatedly waste compute time across many restart attempts before exhausting the restart budget.

Solution:
---------
- Immediate arming: rank monitors send iteration=1 on init to enable tracking
- Periodic updates: rank monitors send iteration updates every 30 seconds
- Early termination after 3 consecutive restarts with <200 iterations of progress
- Works for both fast-failing (seconds) and slow-failing (minutes/hours) jobs

Configuration:
--------------
--ft-max-no-progress-restarts=3     # Max restarts without progress (0 = disabled)
--ft-min-progress-iterations=200    # Min iterations to show progress
--ft-progress-update-interval=30.0  # Update frequency (seconds)
1 parent 9801d27 commit 9b71268

File tree: 6 files changed, +514 −67 lines

src/nvidia_resiliency_ext/fault_tolerance/config.py

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,15 @@ class FaultToleranceConfig:
     skip_section_response: bool = True
     use_infra_group_rank: bool = True
     numa_bind_strict: bool = False
+    # Progress tracking configuration (controlled by max_no_progress_restarts)
+    max_no_progress_restarts: int = 3
+    min_progress_iterations: int = 200
+    progress_update_interval: float = 30.0  # Seconds between sending progress updates to launcher
+
+    @property
+    def is_progress_tracking_enabled(self) -> bool:
+        """Check if progress tracking is enabled (controlled by max_no_progress_restarts > 0)."""
+        return self.max_no_progress_restarts > 0
 
     @staticmethod
     def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultToleranceConfig':
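For reference, a minimal sketch of how the new fields behave (assuming, as in the upstream config, that all other FaultToleranceConfig fields have defaults):

```python
from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

# Defaults from the diff above: progress tracking is enabled out of the box.
cfg = FaultToleranceConfig()
assert cfg.is_progress_tracking_enabled  # max_no_progress_restarts == 3 > 0

# Setting max_no_progress_restarts to 0 (or a negative value) disables tracking.
cfg_off = FaultToleranceConfig(max_no_progress_restarts=0)
assert not cfg_off.is_progress_tracking_enabled
```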

src/nvidia_resiliency_ext/fault_tolerance/data.py

Lines changed: 12 additions & 2 deletions
@@ -152,10 +152,18 @@ def __init__(self, authkey: Optional[bytes] = None):
 
 class InitMsg:
     """
-    Send (rank -> rank monitor) to initialize new session
+    Send (rank -> rank monitor) to initialize new session.
+
+    Attributes:
+        rank_info: Information about this rank
+        iteration: Current training iteration if available from workload framework.
+            If None, indicates that the workload cannot report iterations,
+            and progress tracking should remain disabled.
     """
 
-    pass
+    def __init__(self, rank_info=None, iteration: Optional[int] = None):
+        self.rank_info = rank_info
+        self.iteration = iteration
 
 
 class HeartbeatMsg:
@@ -188,10 +196,12 @@ def __init__(
         rank: int,
         section: str,
         action: SectionAction,
+        iteration: Optional[int] = None,
     ):
         self.rank = rank
         self.section = section
         self.action = action
+        self.iteration = iteration
 
 
 class UpdateConfigMsg:
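The rank-monitor client changes that actually emit these messages are among the six changed files but are not part of this excerpt. As a rough illustration of how the new `iteration` fields are meant to be populated (hypothetical call sites, not the real client API; `SectionAction.OPEN` and the `"step"` section name are assumptions):

```python
from nvidia_resiliency_ext.fault_tolerance.data import InitMsg, SectionAction, SectionMsg

# "Immediate arming": on session init the rank reports iteration=1 so the
# launcher can start tracking right away; iteration=None would leave it disabled.
init_msg = InitMsg(rank_info=None, iteration=1)

# Periodic updates: later section messages piggyback the current training iteration.
step_msg = SectionMsg(rank=0, section="step", action=SectionAction.OPEN, iteration=1234)
```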

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 171 additions & 34 deletions
@@ -25,6 +25,7 @@
 import socket
 import sys
 import tempfile
+import threading
 import time
 import uuid
 import warnings
@@ -69,11 +70,12 @@
     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
+from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     patched_method,
+    read_obj_from_ipc_stream,
     terminate_mp_processes,
-    write_obj_to_ipc_stream,
 )
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
 from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event
@@ -192,6 +194,15 @@ def _wrap_entrypoint_with_numactl(
 # https://github.com/pytorch/pytorch/blob/release/2.3/torch/distributed/elastic/agent/server/local_elastic_agent.py
 
 
+@dataclass
+class RankMonitorState:
+    """State for a single rank monitor process and its IPC connections."""
+    process: Any  # multiprocessing.Process
+    socket_path: str = ""
+    listener_thread: Optional[threading.Thread] = None
+    stop_event: Optional[threading.Event] = None
+
+
 class LocalElasticAgent(SimpleElasticAgent):
     """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.
 
@@ -317,8 +328,15 @@ def __init__(
         self._term_timeout = term_timeout
         self._workers_stop_timeout = workers_stop_timeout
         self._is_store_host = is_store_host
-        self._local_rank_to_rmon: Dict[int, Any] = dict()
+        # Rank monitor state (process, IPC connections, listener tasks) per local rank
+        self._rank_monitors: Dict[int, RankMonitorState] = dict()
         self._ft_cfg = fault_tol_cfg
+        # Centralized progress tracking (always instantiated, active only if configured)
+        self._progress_tracker = TrainingProgressTracker(
+            min_progress_iterations=fault_tol_cfg.min_progress_iterations,
+            max_no_progress_restarts=fault_tol_cfg.max_no_progress_restarts,
+        )
+        self._rank_iterations: Dict[int, int] = dict()  # Track max iteration per rank
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
         self._node_id = self._get_fq_hostname()
@@ -367,7 +385,7 @@ def _open_rendezvous_for_restart(self):
                 self._worker_group.group_rank if self._worker_group else "N/A"
             )
         except Exception as e:
-            logger.warning(f"Failed to open rendezvous: {e}")
+            logger.error(f"Failed to open rendezvous: {e}")
             # For legacy rendezvous, no action needed - it uses different mechanism
 
     def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
@@ -420,7 +438,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunResult:
                     rank=self._worker_group.group_rank,
                 )
 
-                if self._remaining_restarts > 0:
+                self._progress_tracker.analyze_previous_cycle()
+                should_terminate_early = self._progress_tracker.should_terminate_early()
+
+                if should_terminate_early:
+                    logger.error(
+                        "[%s] Progress tracker detected no progress across restarts. "
+                        "No more restarts will be attempted.",
+                        role
+                    )
+                elif self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
                         "%s/%s attempts left;"
@@ -434,14 +461,13 @@
                     # Open rendezvous before restarting (for barrier-based rendezvous)
                     self._open_rendezvous_for_restart()
                     self._restart_workers(self._worker_group)
-                else:
-                    self._stop_workers(self._worker_group)
-                    self._worker_group.state = WorkerState.FAILED
-                    # to preserve torchrun's behaviour, should not return WorkerState.UNHEALTHY.
-                    # we use WorkerState.UNHEALTHY to denote a worker group that is still
-                    # running but has some failed workers. torchrun does not use WorkerState.UNHEALTHY
-                    run_result = self._monitor_workers(self._worker_group)
-                    return run_result
+                    continue  # Continue monitoring after restart
+
+                # No more restarts (either exhausted or early termination)
+                self._stop_workers(self._worker_group)
+                self._worker_group.state = WorkerState.FAILED
+                run_result = self._monitor_workers(self._worker_group)
+                return run_result
             elif state == WorkerState.HEALTHY:
                 # Check for cluster-wide issues: unhealthy nodes or new nodes waiting
                 unhealthy_count = self._check_cluster_unhealthy_count()
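progress_tracker.py is one of the six changed files but is not included in this excerpt. A minimal sketch of the tracker interface the launcher relies on here (`update_iteration`, `analyze_previous_cycle`, `should_terminate_early`), following the semantics described in the commit message; the real implementation may differ:

```python
class TrainingProgressTracker:
    """Sketch: counts consecutive restart cycles that advance fewer than
    min_progress_iterations iterations; a non-positive max disables tracking."""

    def __init__(self, min_progress_iterations: int, max_no_progress_restarts: int):
        self.min_progress_iterations = min_progress_iterations
        self.max_no_progress_restarts = max_no_progress_restarts
        self._cycle_start_iteration = 0  # iteration when the current restart cycle began
        self._latest_iteration = 0       # most recent iteration reported by workers
        self._no_progress_cycles = 0     # consecutive cycles without sufficient progress

    def update_iteration(self, iteration: int) -> None:
        # Called by the launcher with the min iteration across ranks.
        self._latest_iteration = max(self._latest_iteration, iteration)

    def analyze_previous_cycle(self) -> None:
        # Called when a worker-group failure is observed, before deciding on a restart.
        progressed = self._latest_iteration - self._cycle_start_iteration
        if progressed >= self.min_progress_iterations:
            self._no_progress_cycles = 0
        else:
            self._no_progress_cycles += 1
        self._cycle_start_iteration = self._latest_iteration

    def should_terminate_early(self) -> bool:
        if self.max_no_progress_restarts <= 0:
            return False  # tracking disabled
        return self._no_progress_cycles >= self.max_no_progress_restarts
```

With the defaults, three consecutive failures that each advance fewer than 200 iterations would make `should_terminate_early()` return True, and the launcher stops the workers instead of spending the remaining restart budget.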
@@ -579,31 +605,132 @@ def get_rank_mon_socket_path(self, local_rank):
 
     def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
         fork_mp_ctx = torch.multiprocessing.get_context("fork")
+        new_monitors = []  # Track newly started monitors
+
         for worker_env in envs.values():
             # Start rank monitors if not already started
             # Each rank (re)connects to its rank monitor when it starts
             # Monitor of the local rank0 on the store hosting node is the restarter logger
             local_rank = int(worker_env['LOCAL_RANK'])
             is_restarter_logger = self._is_store_host and local_rank == 0
             rmon_ipc_socket = worker_env[FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR]
-            if local_rank not in self._local_rank_to_rmon:
-                self._local_rank_to_rmon[local_rank] = RankMonitorServer.run_in_subprocess(
+            if local_rank not in self._rank_monitors:
+                rmon_proc = RankMonitorServer.run_in_subprocess(
                     cfg=self._ft_cfg,
                     ipc_socket_path=rmon_ipc_socket,
                     is_restarter_logger=is_restarter_logger,
                     mp_ctx=fork_mp_ctx,
                     env=worker_env,
                 )
+                self._rank_monitors[local_rank] = RankMonitorState(process=rmon_proc)
+                new_monitors.append((local_rank, rmon_proc))
+
+        # Establish bidirectional IPC connections to new rank monitors
+        if new_monitors:
+            async def connect_all():
+                await asyncio.gather(
+                    *[self._connect_to_rank_monitor(lr, rmon) for lr, rmon in new_monitors]
+                )
+            asyncio.run(connect_all())
 
     def shutdown_rank_monitors(self):
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
+        # Stop listener threads and terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
+            # Signal listener thread to stop
+            if state.stop_event:
+                state.stop_event.set()
+
+            # Wait for listener thread to finish (will close connection gracefully)
+            if state.listener_thread and state.listener_thread.is_alive():
+                state.listener_thread.join(timeout=2.0)
+
+        # Terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
             with contextlib.suppress(Exception):
-                rmon_proc.terminate()
+                state.process.terminate()
             with contextlib.suppress(Exception):
-                rmon_proc.join()
+                state.process.join()
             with contextlib.suppress(Exception):
                 os.unlink(self.get_rank_mon_socket_path(local_rank))
 
+    async def _connect_to_rank_monitor(self, local_rank: int, rmon_proc) -> None:
+        """Start listener thread for rank monitor bidirectional IPC.
+
+        Note: This is called after rank_monitor_ready_event is set, which guarantees
+        the socket file already exists. The actual connection is created inside the
+        background thread's event loop to avoid event loop conflicts.
+        """
+        launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
+
+        state = self._rank_monitors[local_rank]
+        state.socket_path = launcher_to_rmon_socket
+        state.stop_event = threading.Event()
+
+        # Start listener thread (will create connection in its own event loop)
+        state.listener_thread = threading.Thread(
+            target=self._listen_to_rank_monitor_thread,
+            args=(local_rank, launcher_to_rmon_socket, state.stop_event),
+            daemon=True,
+            name=f"RankMonitor-{local_rank}-Listener"
+        )
+        state.listener_thread.start()
+
+    def _update_progress_iteration(self, local_rank: int, iteration: int):
+        """Update iteration for a specific rank and aggregate using MIN strategy."""
+        # Update this rank's max iteration
+        self._rank_iterations[local_rank] = max(
+            self._rank_iterations.get(local_rank, 0), iteration
+        )
+
+        # Use minimum across all ranks (most conservative - slowest rank determines progress)
+        min_iteration = min(self._rank_iterations.values()) if self._rank_iterations else 0
+        self._progress_tracker.update_iteration(min_iteration)
+
+    def _listen_to_rank_monitor_thread(self, local_rank: int, socket_path: str, stop_event: threading.Event) -> None:
+        """Listen for messages from rank monitor in a background thread.
+
+        This runs in a separate thread with its own event loop to receive messages
+        from the rank monitor server. The connection is created in this thread's
+        event loop to avoid cross-loop conflicts.
+        """
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async def listen_loop():
+            try:
+                # Create connection in THIS thread's event loop
+                reader, writer = await asyncio.open_unix_connection(socket_path)
+
+                try:
+                    while not stop_event.is_set():
+                        # Use wait_for with timeout to allow checking stop_event periodically
+                        try:
+                            msg = await asyncio.wait_for(read_obj_from_ipc_stream(reader), timeout=1.0)
+                            if isinstance(msg, dict) and msg.get("type") == "iteration_update":
+                                # Handle iteration update from rank monitor
+                                iteration = msg["iteration"]
+                                self._update_progress_iteration(local_rank, iteration)
+                            else:
+                                logger.debug(f"Received message from rank monitor {local_rank}: {msg}")
+                        except asyncio.TimeoutError:
+                            # Timeout is expected, just check stop_event and continue
+                            continue
+                finally:
+                    # Clean up connection
+                    writer.close()
+                    await writer.wait_closed()
+            except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
+                logger.debug(f"Rank monitor {local_rank} connection closed")
+            except Exception as e:
+                if not stop_event.is_set():
+                    logger.error(f"Error listening to rank monitor {local_rank}: {e}")
+
+        try:
+            loop.run_until_complete(listen_loop())
+        finally:
+            loop.close()
+
     def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
         watchdog_enabled = os.getenv(enable_watchdog_env_name)
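The rank-monitor side of this channel (rank_monitor_server.py is among the six changed files) is not shown in this excerpt. A rough sketch of what the sender could look like, assuming the rank monitor serves the `_ft_launcher{pid}_to_rmon.socket` path the launcher connects to, uses the `write_obj_to_ipc_stream` helper from `fault_tolerance.utils`, and pushes updates at the default 30-second `progress_update_interval`; `get_iteration` is a hypothetical callable:

```python
import asyncio

from nvidia_resiliency_ext.fault_tolerance.utils import write_obj_to_ipc_stream


async def serve_iteration_updates(socket_path: str, get_iteration, interval: float = 30.0):
    """Hypothetical rank-monitor-side loop: once the launcher connects to
    socket_path, periodically push the latest known training iteration."""

    async def on_launcher_connected(reader, writer):
        try:
            while True:
                iteration = get_iteration()  # latest iteration known to this rank monitor
                if iteration is not None:
                    # Dict shape expected by _listen_to_rank_monitor_thread above.
                    await write_obj_to_ipc_stream(
                        {"type": "iteration_update", "iteration": iteration}, writer
                    )
                await asyncio.sleep(interval)
        except (ConnectionResetError, BrokenPipeError):
            pass  # launcher went away; nothing more to send
        finally:
            writer.close()

    server = await asyncio.start_unix_server(on_launcher_connected, path=socket_path)
    async with server:
        await server.serve_forever()
```

On the launcher side, these dicts are consumed by `_listen_to_rank_monitor_thread` and folded into the tracker via `_update_progress_iteration`.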
@@ -682,23 +809,7 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None:
 
         logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.")
 
-        # Send close message to rank monitors
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
-            try:
-                launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
-                if os.path.exists(launcher_to_rmon_socket):
-                    async def send_close_msg():
-                        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
-                        try:
-                            await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
-                        finally:
-                            writer.close()
-                            await writer.wait_closed()
-
-                    asyncio.run(send_close_msg())
-            except Exception as e:
-                logger.warning(f"Failed to send close message to rank monitor {local_rank}: {e}")
-
+        # Rank monitors will detect worker shutdown when worker processes disconnect
         self._shutdown(timeout=self._workers_stop_timeout)
 
         # Record worker termination event after shutdown is complete
@@ -2074,6 +2185,32 @@ def get_args_parser() -> ArgumentParser:
         help="Do not raise an error if there is no Fault Tolerance pkg config provided, just use default settings.",
     )
 
+    #
+    # Progress tracking arguments
+    #
+
+    parser.add_argument(
+        "--ft-max-no-progress-restarts",
+        "--ft-max_no_progress_restarts",
+        type=int,
+        default=3,
+        dest="ft_max_no_progress_restarts",
+        help="Maximum consecutive restarts without progress before early termination. "
+        "Progress tracking is enabled when this value > 0. "
+        "Set to 0 or -1 to disable progress tracking. "
+        "Default: 3 (progress tracking enabled).",
+    )
+
+    parser.add_argument(
+        "--ft-min-progress-iterations",
+        "--ft-min_progress_iterations",
+        type=int,
+        default=200,
+        dest="ft_min_progress_iterations",
+        help="Minimum iterations required to consider a restart as making progress. "
+        "Default: 200.",
+    )
+
     #
     # Positional arguments.
     #
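The wiring from these CLI options into the fault tolerance config is not shown in this excerpt. Conceptually, it maps onto the `from_kwargs` helper visible in config.py above (a sketch; the actual plumbing in ft_launcher may differ, and the `--ft-progress-update-interval` parser argument listed in the commit message is not part of this excerpt):

```python
from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

# Hypothetical mapping of parsed launcher args onto config fields.
ft_cfg = FaultToleranceConfig.from_kwargs(
    max_no_progress_restarts=3,     # from --ft-max-no-progress-restarts
    min_progress_iterations=200,    # from --ft-min-progress-iterations
    progress_update_interval=30.0,  # from --ft-progress-update-interval
)
print(ft_cfg.is_progress_tracking_enabled)  # True
```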
