
Commit b6b92bd

feat: Add progress-based early termination for in-job restarts
Implement centralized training progress tracking to automatically terminate jobs when restarts fail to make meaningful progress, preventing wasted compute on jobs stuck in restart loops.

Problem:
--------
The in-job restart mechanism kept restarting failed jobs up to max_restarts regardless of whether the restarts were effective. Jobs that crash repeatedly waste compute time across many restart attempts before exhausting the restart budget.

Solution:
---------
- Immediate arming: rank monitors send iteration=1 on init to enable tracking
- Periodic updates: rank monitors send iteration updates every 30 seconds
- Early termination after 3 consecutive restarts with <200 iterations of progress
- Works for both fast-failing (seconds) and slow-failing (minutes/hours) jobs

Configuration:
--------------
--ft-max-no-progress-restarts=3    # Max restarts without progress (0 = disabled)
--ft-min-progress-iterations=200   # Min iterations to show progress
--ft-progress-update-interval=30.0 # Update frequency (seconds)
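For illustration, here is a simplified Python sketch of the termination rule described above. It is not the shipped TrainingProgressTracker (whose implementation lives in fault_tolerance/progress_tracker.py); it only mirrors its public method names. Each failure cycle compares the cluster iteration count against the value recorded at the previous cycle; fewer than min_progress_iterations of progress counts as a no-progress restart, and max_no_progress_restarts consecutive no-progress restarts trigger early termination.

# Simplified, illustrative re-implementation of the termination rule; the real
# logic lives in fault_tolerance/progress_tracker.py (TrainingProgressTracker).
class SimpleProgressTracker:
    def __init__(self, min_progress_iterations: int = 200, max_no_progress_restarts: int = 3):
        self.min_progress_iterations = min_progress_iterations
        self.max_no_progress_restarts = max_no_progress_restarts
        self._current_iteration = 0        # latest iteration reported for the cluster
        self._iteration_at_last_cycle = 0  # iteration recorded when the previous cycle was analyzed
        self._no_progress_restarts = 0     # consecutive restart cycles without progress

    def update_iteration(self, iteration: int) -> None:
        self._current_iteration = max(self._current_iteration, iteration)

    def analyze_previous_cycle(self) -> None:
        # Called when a worker-group failure is handled, before deciding to restart.
        progress = self._current_iteration - self._iteration_at_last_cycle
        if progress >= self.min_progress_iterations:
            self._no_progress_restarts = 0
        else:
            self._no_progress_restarts += 1
        self._iteration_at_last_cycle = self._current_iteration

    def should_terminate_early(self) -> bool:
        return (
            self.max_no_progress_restarts > 0
            and self._no_progress_restarts >= self.max_no_progress_restarts
        )

With the defaults, three failures in a row that each advance training by fewer than 200 iterations make should_terminate_early() return True, and the launcher stops restarting.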
1 parent 9801d27 commit b6b92bd

6 files changed: +518 −56 lines

src/nvidia_resiliency_ext/fault_tolerance/config.py

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,15 @@ class FaultToleranceConfig:
     skip_section_response: bool = True
     use_infra_group_rank: bool = True
     numa_bind_strict: bool = False
+    # Progress tracking configuration (controlled by max_no_progress_restarts)
+    max_no_progress_restarts: int = 3
+    min_progress_iterations: int = 200
+    progress_update_interval: float = 30.0  # Seconds between sending progress updates to launcher
+
+    @property
+    def is_progress_tracking_enabled(self) -> bool:
+        """Check if progress tracking is enabled (controlled by max_no_progress_restarts > 0)."""
+        return self.max_no_progress_restarts > 0
 
     @staticmethod
     def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultToleranceConfig':
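A minimal usage sketch of the new config fields, assuming the module path matches the file location and that from_kwargs fills unspecified fields with their dataclass defaults:

from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

# Progress tracking is gated on max_no_progress_restarts > 0.
cfg = FaultToleranceConfig.from_kwargs(
    max_no_progress_restarts=3,
    min_progress_iterations=200,
    progress_update_interval=30.0,
)
assert cfg.is_progress_tracking_enabled

# Setting max_no_progress_restarts to 0 disables the feature.
disabled = FaultToleranceConfig.from_kwargs(max_no_progress_restarts=0)
assert not disabled.is_progress_tracking_enabled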

src/nvidia_resiliency_ext/fault_tolerance/data.py

Lines changed: 12 additions & 2 deletions
@@ -152,10 +152,18 @@ def __init__(self, authkey: Optional[bytes] = None):
 
 class InitMsg:
     """
-    Send (rank -> rank monitor) to initialize new session
+    Send (rank -> rank monitor) to initialize new session.
+
+    Attributes:
+        rank_info: Information about this rank
+        iteration: Current training iteration if available from workload framework.
+            If None, indicates that the workload cannot report iterations,
+            and progress tracking should remain disabled.
     """
 
-    pass
+    def __init__(self, rank_info=None, iteration: Optional[int] = None):
+        self.rank_info = rank_info
+        self.iteration = iteration
 
 
 class HeartbeatMsg:
@@ -188,10 +196,12 @@ def __init__(
         rank: int,
         section: str,
         action: SectionAction,
+        iteration: Optional[int] = None,
     ):
         self.rank = rank
         self.section = section
         self.action = action
+        self.iteration = iteration
 
 
 class UpdateConfigMsg:
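A short sketch of how the extended InitMsg carries the iteration, assuming the module path matches the file location; rank_info would normally hold the rank's metadata and is left as None here purely for illustration:

from nvidia_resiliency_ext.fault_tolerance.data import InitMsg

# A workload that can report progress arms tracking immediately on (re)connect.
init_msg = InitMsg(rank_info=None, iteration=1)

# A workload that cannot report iterations leaves progress tracking disabled.
init_msg_untracked = InitMsg(rank_info=None, iteration=None)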

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 177 additions & 33 deletions
@@ -25,6 +25,7 @@
 import socket
 import sys
 import tempfile
+import threading
 import time
 import uuid
 import warnings
@@ -69,9 +70,11 @@
     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
+from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     patched_method,
+    read_obj_from_ipc_stream,
     terminate_mp_processes,
     write_obj_to_ipc_stream,
 )
@@ -192,6 +195,15 @@ def _wrap_entrypoint_with_numactl(
 # https://github.com/pytorch/pytorch/blob/release/2.3/torch/distributed/elastic/agent/server/local_elastic_agent.py
 
 
+@dataclass
+class RankMonitorState:
+    """State for a single rank monitor process and its IPC connections."""
+    process: Any  # multiprocessing.Process
+    socket_path: str = ""
+    listener_thread: Optional[threading.Thread] = None
+    stop_event: Optional[threading.Event] = None
+
+
 class LocalElasticAgent(SimpleElasticAgent):
     """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.
 
@@ -317,8 +329,15 @@ def __init__(
         self._term_timeout = term_timeout
         self._workers_stop_timeout = workers_stop_timeout
         self._is_store_host = is_store_host
-        self._local_rank_to_rmon: Dict[int, Any] = dict()
+        # Rank monitor state (process, IPC connections, listener tasks) per local rank
+        self._rank_monitors: Dict[int, RankMonitorState] = dict()
         self._ft_cfg = fault_tol_cfg
+        # Centralized progress tracking (always instantiated, active only if configured)
+        self._progress_tracker = TrainingProgressTracker(
+            min_progress_iterations=fault_tol_cfg.min_progress_iterations,
+            max_no_progress_restarts=fault_tol_cfg.max_no_progress_restarts,
+        )
+        self._rank_iterations: Dict[int, int] = dict()  # Track max iteration per rank
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
         self._node_id = self._get_fq_hostname()
@@ -367,7 +386,7 @@ def _open_rendezvous_for_restart(self):
                 self._worker_group.group_rank if self._worker_group else "N/A"
             )
         except Exception as e:
-            logger.warning(f"Failed to open rendezvous: {e}")
+            logger.error(f"Failed to open rendezvous: {e}")
             # For legacy rendezvous, no action needed - it uses different mechanism
 
     def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
@@ -420,7 +439,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                     rank=self._worker_group.group_rank,
                 )
 
-                if self._remaining_restarts > 0:
+                self._progress_tracker.analyze_previous_cycle()
+                should_terminate_early = self._progress_tracker.should_terminate_early()
+
+                if should_terminate_early:
+                    logger.error(
+                        "[%s] Progress tracker detected no progress across restarts. "
+                        "No more restarts will be attempted.",
+                        role
+                    )
+                elif self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
                         "%s/%s attempts left;"
@@ -434,14 +462,13 @@
                     # Open rendezvous before restarting (for barrier-based rendezvous)
                     self._open_rendezvous_for_restart()
                     self._restart_workers(self._worker_group)
-                else:
-                    self._stop_workers(self._worker_group)
-                    self._worker_group.state = WorkerState.FAILED
-                    # to preserve torchrun's behaviour, should not return WorkerState.UNHEALTHY.
-                    # we use WorkerState.UNHEALTHY to denote a worker group that is still
-                    # running but has some failed workers. torchrun does not use WorkerState.UNHEALTHY
-                    run_result = self._monitor_workers(self._worker_group)
-                    return run_result
+                    continue  # Continue monitoring after restart
+
+                # No more restarts (either exhausted or early termination)
+                self._stop_workers(self._worker_group)
+                self._worker_group.state = WorkerState.FAILED
+                run_result = self._monitor_workers(self._worker_group)
+                return run_result
             elif state == WorkerState.HEALTHY:
                 # Check for cluster-wide issues: unhealthy nodes or new nodes waiting
                 unhealthy_count = self._check_cluster_unhealthy_count()
@@ -579,31 +606,138 @@ get_rank_mon_socket_path(self, local_rank):
 
     def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
         fork_mp_ctx = torch.multiprocessing.get_context("fork")
+        new_monitors = []  # Track newly started monitors
+
         for worker_env in envs.values():
            # Start rank monitors if not already started
            # Each rank (re)connects to its rank monitor when it starts
            # Monitor of the local rank0 on the store hosting node is the restarter logger
            local_rank = int(worker_env['LOCAL_RANK'])
            is_restarter_logger = self._is_store_host and local_rank == 0
            rmon_ipc_socket = worker_env[FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR]
-            if local_rank not in self._local_rank_to_rmon:
-                self._local_rank_to_rmon[local_rank] = RankMonitorServer.run_in_subprocess(
+            if local_rank not in self._rank_monitors:
+                rmon_proc = RankMonitorServer.run_in_subprocess(
                     cfg=self._ft_cfg,
                     ipc_socket_path=rmon_ipc_socket,
                     is_restarter_logger=is_restarter_logger,
                     mp_ctx=fork_mp_ctx,
                     env=worker_env,
                 )
+                self._rank_monitors[local_rank] = RankMonitorState(process=rmon_proc)
+                new_monitors.append((local_rank, rmon_proc))
+
+        # Establish bidirectional IPC connections to new rank monitors
+        if new_monitors:
+            async def connect_all():
+                await asyncio.gather(
+                    *[self._connect_to_rank_monitor(lr, rmon) for lr, rmon in new_monitors]
+                )
+            asyncio.run(connect_all())
 
     def shutdown_rank_monitors(self):
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
+        # Stop listener threads and terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
+            # Signal listener thread to stop
+            if state.stop_event:
+                state.stop_event.set()
+
+            # Wait for listener thread to finish (will close connection gracefully)
+            if state.listener_thread and state.listener_thread.is_alive():
+                state.listener_thread.join(timeout=2.0)
+
+        # Terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
             with contextlib.suppress(Exception):
-                rmon_proc.terminate()
+                state.process.terminate()
             with contextlib.suppress(Exception):
-                rmon_proc.join()
+                state.process.join()
             with contextlib.suppress(Exception):
                 os.unlink(self.get_rank_mon_socket_path(local_rank))
 
+    async def _connect_to_rank_monitor(self, local_rank: int, rmon_proc) -> None:
+        """Start listener thread for rank monitor bidirectional IPC.
+
+        Note: This is called after rank_monitor_ready_event is set, which guarantees
+        the socket file already exists. The actual connection is created inside the
+        background thread's event loop to avoid event loop conflicts.
+        """
+        launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
+
+        state = self._rank_monitors[local_rank]
+        state.socket_path = launcher_to_rmon_socket
+        state.stop_event = threading.Event()
+
+        # Start listener thread (will create connection in its own event loop)
+        state.listener_thread = threading.Thread(
+            target=self._listen_to_rank_monitor_thread,
+            args=(local_rank, launcher_to_rmon_socket, state.stop_event),
+            daemon=True,
+            name=f"RankMonitor-{local_rank}-Listener"
+        )
+        state.listener_thread.start()
+
+    def _update_progress_iteration(self, local_rank: int, iteration: int):
+        """Update iteration for a specific rank and aggregate using MIN strategy."""
+        # Update this rank's max iteration
+        self._rank_iterations[local_rank] = max(
+            self._rank_iterations.get(local_rank, 0), iteration
+        )
+
+        # Use minimum across all ranks (most conservative - slowest rank determines progress)
+        min_iteration = min(self._rank_iterations.values()) if self._rank_iterations else 0
+        self._progress_tracker.update_iteration(min_iteration)
+
+        logger.debug(
+            f"Updated iteration for rank {local_rank}={iteration}, "
+            f"cluster min={min_iteration}, all_ranks={self._rank_iterations}"
+        )
+
+    def _listen_to_rank_monitor_thread(self, local_rank: int, socket_path: str, stop_event: threading.Event) -> None:
+        """Listen for messages from rank monitor in a background thread.
+
+        This runs in a separate thread with its own event loop to receive messages
+        from the rank monitor server. The connection is created in this thread's
+        event loop to avoid cross-loop conflicts.
+        """
+        # Create a new event loop for this thread
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async def listen_loop():
+            try:
+                # Create connection in THIS thread's event loop
+                reader, writer = await asyncio.open_unix_connection(socket_path)
+
+                try:
+                    while not stop_event.is_set():
+                        # Use wait_for with timeout to allow checking stop_event periodically
+                        try:
+                            msg = await asyncio.wait_for(read_obj_from_ipc_stream(reader), timeout=1.0)
+                            if isinstance(msg, dict) and msg.get("type") == "iteration_update":
+                                # Handle iteration update from rank monitor
+                                iteration = msg["iteration"]
+                                self._update_progress_iteration(local_rank, iteration)
+                                # Note: Don't log every iteration update - too chatty during training
+                            else:
+                                logger.debug(f"Received message from rank monitor {local_rank}: {msg}")
+                        except asyncio.TimeoutError:
+                            # Timeout is expected, just check stop_event and continue
+                            continue
+                finally:
+                    # Clean up connection
+                    writer.close()
+                    await writer.wait_closed()
+            except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
+                logger.debug(f"Rank monitor {local_rank} connection closed")
+            except Exception as e:
+                if not stop_event.is_set():
+                    logger.error(f"Error listening to rank monitor {local_rank}: {e}")
+
+        try:
+            loop.run_until_complete(listen_loop())
+        finally:
+            loop.close()
+
     def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
         watchdog_enabled = os.getenv(enable_watchdog_env_name)
@@ -682,23 +816,7 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None:
 
         logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.")
 
-        # Send close message to rank monitors
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
-            try:
-                launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
-                if os.path.exists(launcher_to_rmon_socket):
-                    async def send_close_msg():
-                        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
-                        try:
-                            await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
-                        finally:
-                            writer.close()
-                            await writer.wait_closed()
-
-                    asyncio.run(send_close_msg())
-            except Exception as e:
-                logger.warning(f"Failed to send close message to rank monitor {local_rank}: {e}")
-
+        # Rank monitors will detect worker shutdown when worker processes disconnect
         self._shutdown(timeout=self._workers_stop_timeout)
 
         # Record worker termination event after shutdown is complete
@@ -2074,6 +2192,32 @@ def get_args_parser() -> ArgumentParser:
         help="Do not raise an error if there is no Fault Tolerance pkg config provided, just use default settings.",
     )
 
+    #
+    # Progress tracking arguments
+    #
+
+    parser.add_argument(
+        "--ft-max-no-progress-restarts",
+        "--ft-max_no_progress_restarts",
+        type=int,
+        default=3,
+        dest="ft_max_no_progress_restarts",
+        help="Maximum consecutive restarts without progress before early termination. "
+        "Progress tracking is enabled when this value > 0. "
+        "Set to 0 or -1 to disable progress tracking. "
+        "Default: 3 (progress tracking enabled).",
+    )
+
+    parser.add_argument(
+        "--ft-min-progress-iterations",
+        "--ft-min_progress_iterations",
+        type=int,
+        default=200,
+        dest="ft_min_progress_iterations",
+        help="Minimum iterations required to consider a restart as making progress. "
+        "Default: 200.",
+    )
+
     #
     # Positional arguments.
     #
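For reference, the iteration_update messages handled by the listener thread above are plain dicts of the form {"type": "iteration_update", "iteration": N}, and _update_progress_iteration aggregates them conservatively: each rank's maximum reported iteration is kept, and the value fed to the progress tracker is the minimum across ranks. A standalone sketch of that aggregation (hypothetical helper, not part of this commit):

from typing import Dict

def aggregate_cluster_iteration(
    rank_iterations: Dict[int, int], local_rank: int, iteration: int
) -> int:
    """Record a rank's latest iteration and return the cluster-wide minimum
    (the slowest rank determines overall progress)."""
    rank_iterations[local_rank] = max(rank_iterations.get(local_rank, 0), iteration)
    return min(rank_iterations.values()) if rank_iterations else 0

# Example: rank 0 reports iteration 450 while rank 1 lags at 390 -> cluster progress is 390.
iters: Dict[int, int] = {}
aggregate_cluster_iteration(iters, 0, 450)
assert aggregate_cluster_iteration(iters, 1, 390) == 390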
