Commit e684393

feat: Add progress-based early termination for in-job restarts
Implement centralized training progress tracking to automatically terminate jobs when restarts fail to make meaningful progress, preventing wasted compute on jobs stuck in restart loops.

Problem:
--------
The in-job restart logic would continue restarting failed jobs up to max_restarts regardless of whether the restarts were effective. Jobs that crash repeatedly waste compute time across many restart attempts before exhausting the restart budget.

Solution:
---------
- Immediate arming: rank monitors send iteration=1 on init to enable tracking.
- Periodic updates: rank monitors send iteration updates every 30 seconds.
- Early termination after 3 consecutive restarts with <200 iterations of progress.
- Works for both fast-failing (seconds) and slow-failing (minutes/hours) jobs.

Configuration:
--------------
--ft-max-no-progress-restarts=3     # Max restarts without progress (0 = disabled)
--ft-min-progress-iterations=200    # Min iterations to show progress
--ft-progress-update-interval=30.0  # Update frequency (seconds)
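Editor's note: for illustration, a possible launch command using the new knobs. This is a sketch, not part of the commit: it assumes the package's ft_launcher entry point with its usual torchrun-style options, and train.py is a placeholder script. The first two --ft-* flags appear in the argparse hunk further below; --ft-progress-update-interval is taken from the commit message.

    # Stop the job after 3 consecutive restarts that each advance fewer than 200 iterations;
    # rank monitors report progress every 30 seconds. Set --ft-max-no-progress-restarts=0 to disable.
    ft_launcher --nproc-per-node=8 --max-restarts=16 \
        --ft-max-no-progress-restarts=3 \
        --ft-min-progress-iterations=200 \
        --ft-progress-update-interval=30.0 \
        train.py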
1 parent 9801d27 commit e684393

File tree: 6 files changed (+451, -40 lines)

src/nvidia_resiliency_ext/fault_tolerance/config.py

Lines changed: 9 additions & 0 deletions

@@ -85,6 +85,15 @@ class FaultToleranceConfig:
     skip_section_response: bool = True
     use_infra_group_rank: bool = True
     numa_bind_strict: bool = False
+    # Progress tracking configuration (controlled by max_no_progress_restarts)
+    max_no_progress_restarts: int = 3
+    min_progress_iterations: int = 200
+    progress_update_interval: float = 30.0  # Seconds between sending progress updates to launcher
+
+    @property
+    def is_progress_tracking_enabled(self) -> bool:
+        """Check if progress tracking is enabled (controlled by max_no_progress_restarts > 0)."""
+        return self.max_no_progress_restarts > 0

     @staticmethod
     def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultToleranceConfig':
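Editor's note: a short usage sketch of the new options (illustrative only; it assumes from_kwargs forwards these keyword arguments onto the dataclass fields, as its signature above suggests).

    from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

    # Build a config with the new progress-tracking knobs.
    cfg = FaultToleranceConfig.from_kwargs(
        max_no_progress_restarts=3,     # give up after 3 restarts without progress
        min_progress_iterations=200,    # a restart "made progress" once it passes 200 iterations
        progress_update_interval=30.0,  # rank monitors report progress every 30 seconds
    )

    # Tracking is active only when max_no_progress_restarts > 0;
    # setting it to 0 (or a negative value) disables the feature.
    assert cfg.is_progress_tracking_enabled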

src/nvidia_resiliency_ext/fault_tolerance/data.py

Lines changed: 2 additions & 0 deletions

@@ -188,10 +188,12 @@ def __init__(
         rank: int,
         section: str,
         action: SectionAction,
+        iteration: Optional[int] = None,
     ):
         self.rank = rank
         self.section = section
         self.action = action
+        self.iteration = iteration


 class UpdateConfigMsg:
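Editor's note: the hunk above shows only the __init__ of the patched message class (its name is outside the diff context). Making iteration optional keeps existing senders working while letting updated rank monitors attach the current training iteration. A tiny self-contained sketch, using a hypothetical stand-in class name:

    from enum import Enum, auto
    from typing import Optional

    class SectionAction(Enum):  # simplified stand-in for the real enum in data.py
        OPEN = auto()
        CLOSE = auto()

    class SectionMsgSketch:  # hypothetical name; the real class is the one patched above
        def __init__(self, rank: int, section: str, action: SectionAction,
                     iteration: Optional[int] = None):
            self.rank = rank
            self.section = section
            self.action = action
            self.iteration = iteration  # new optional field carrying the training iteration

    # Existing call sites that do not pass `iteration` keep working unchanged:
    msg = SectionMsgSketch(rank=0, section="step", action=SectionAction.OPEN)
    assert msg.iteration is None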

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 171 additions & 32 deletions

@@ -69,9 +69,11 @@
     FT_LAUNCHER_IPC_SOCKET_ENV_VAR,
     FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR,
 )
+from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker
 from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer
 from nvidia_resiliency_ext.fault_tolerance.utils import (
     patched_method,
+    read_obj_from_ipc_stream,
     terminate_mp_processes,
     write_obj_to_ipc_stream,
 )
@@ -192,6 +194,15 @@ def _wrap_entrypoint_with_numactl(
 # https://github.com/pytorch/pytorch/blob/release/2.3/torch/distributed/elastic/agent/server/local_elastic_agent.py


+@dataclass
+class RankMonitorState:
+    """State for a single rank monitor process and its IPC connections."""
+    process: Any  # multiprocessing.Process
+    reader: Optional[asyncio.StreamReader] = None
+    writer: Optional[asyncio.StreamWriter] = None
+    listener_task: Optional[asyncio.Task] = None
+
+
 class LocalElasticAgent(SimpleElasticAgent):
     """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.

@@ -317,8 +328,15 @@ def __init__(
         self._term_timeout = term_timeout
         self._workers_stop_timeout = workers_stop_timeout
         self._is_store_host = is_store_host
-        self._local_rank_to_rmon: Dict[int, Any] = dict()
+        # Rank monitor state (process, IPC connections, listener tasks) per local rank
+        self._rank_monitors: Dict[int, RankMonitorState] = dict()
         self._ft_cfg = fault_tol_cfg
+        # Centralized progress tracking (always instantiated, active only if configured)
+        self._progress_tracker = TrainingProgressTracker(
+            min_progress_iterations=fault_tol_cfg.min_progress_iterations,
+            max_no_progress_restarts=fault_tol_cfg.max_no_progress_restarts,
+        )
+        self._rank_iterations: Dict[int, int] = dict()  # Track max iteration per rank
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
         self._node_id = self._get_fq_hostname()
@@ -367,7 +385,7 @@ def _open_rendezvous_for_restart(self):
                 self._worker_group.group_rank if self._worker_group else "N/A"
             )
         except Exception as e:
-            logger.warning(f"Failed to open rendezvous: {e}")
+            logger.error(f"Failed to open rendezvous: {e}")
         # For legacy rendezvous, no action needed - it uses different mechanism

     def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
@@ -420,7 +438,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunResult:
                     rank=self._worker_group.group_rank,
                 )

-                if self._remaining_restarts > 0:
+                self._progress_tracker.analyze_previous_cycle()
+                should_terminate_early = self._progress_tracker.should_terminate_early()
+
+                if should_terminate_early:
+                    logger.error(
+                        "[%s] Progress tracker detected no progress across restarts. "
+                        "No more restarts will be attempted.",
+                        role
+                    )
+                elif self._remaining_restarts > 0:
                     logger.info(
                         "[%s] Worker group %s. "
                         "%s/%s attempts left;"
@@ -434,14 +461,13 @@
                     # Open rendezvous before restarting (for barrier-based rendezvous)
                     self._open_rendezvous_for_restart()
                     self._restart_workers(self._worker_group)
-                else:
-                    self._stop_workers(self._worker_group)
-                    self._worker_group.state = WorkerState.FAILED
-                    # to preserve torchrun's behaviour, should not return WorkerState.UNHEALTHY.
-                    # we use WorkerState.UNHEALTHY to denote a worker group that is still
-                    # running but has some failed workers. torchrun does not use WorkerState.UNHEALTHY
-                    run_result = self._monitor_workers(self._worker_group)
-                    return run_result
+                    continue  # Continue monitoring after restart
+
+                # No more restarts (either exhausted or early termination)
+                self._stop_workers(self._worker_group)
+                self._worker_group.state = WorkerState.FAILED
+                run_result = self._monitor_workers(self._worker_group)
+                return run_result
             elif state == WorkerState.HEALTHY:
                 # Check for cluster-wide issues: unhealthy nodes or new nodes waiting
                 unhealthy_count = self._check_cluster_unhealthy_count()
@@ -579,31 +605,113 @@ def get_rank_mon_socket_path(self, local_rank):

     def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
         fork_mp_ctx = torch.multiprocessing.get_context("fork")
+        new_monitors = []  # Track newly started monitors
+
         for worker_env in envs.values():
             # Start rank monitors if not already started
             # Each rank (re)connects to its rank monitor when it starts
             # Monitor of the local rank0 on the store hosting node is the restarter logger
             local_rank = int(worker_env['LOCAL_RANK'])
             is_restarter_logger = self._is_store_host and local_rank == 0
             rmon_ipc_socket = worker_env[FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR]
-            if local_rank not in self._local_rank_to_rmon:
-                self._local_rank_to_rmon[local_rank] = RankMonitorServer.run_in_subprocess(
+            if local_rank not in self._rank_monitors:
+                rmon_proc = RankMonitorServer.run_in_subprocess(
                     cfg=self._ft_cfg,
                     ipc_socket_path=rmon_ipc_socket,
                     is_restarter_logger=is_restarter_logger,
                     mp_ctx=fork_mp_ctx,
                     env=worker_env,
                 )
+                self._rank_monitors[local_rank] = RankMonitorState(process=rmon_proc)
+                new_monitors.append((local_rank, rmon_proc))
+
+        # Establish bidirectional IPC connections to new rank monitors
+        if new_monitors:
+            async def connect_all():
+                await asyncio.gather(
+                    *[self._connect_to_rank_monitor(lr, rmon) for lr, rmon in new_monitors]
+                )
+            asyncio.run(connect_all())

     def shutdown_rank_monitors(self):
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
+        # Stop listener tasks, close connections, and send shutdown messages
+        for local_rank, state in self._rank_monitors.items():
+            # Cancel listener task
+            if state.listener_task and not state.listener_task.done():
+                state.listener_task.cancel()
+
+            # Close connection with shutdown message
+            if state.writer:
+                try:
+                    async def send_shutdown():
+                        await write_obj_to_ipc_stream("shutdown", state.writer)
+                        state.writer.close()
+                        await state.writer.wait_closed()
+                    asyncio.run(send_shutdown())
+                except Exception as e:
+                    logger.debug(f"Error closing rank monitor connection for rank {local_rank}: {e}")
+
+        # Terminate rank monitor processes
+        for local_rank, state in self._rank_monitors.items():
             with contextlib.suppress(Exception):
-                rmon_proc.terminate()
+                state.process.terminate()
             with contextlib.suppress(Exception):
-                rmon_proc.join()
+                state.process.join()
             with contextlib.suppress(Exception):
                 os.unlink(self.get_rank_mon_socket_path(local_rank))

+    async def _connect_to_rank_monitor(self, local_rank: int, rmon_proc) -> None:
+        """Establish persistent connection to rank monitor for bidirectional IPC.
+
+        Note: This is called after rank_monitor_ready_event is set, which guarantees
+        the socket file already exists.
+        """
+        launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
+
+        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
+        state = self._rank_monitors[local_rank]
+        state.reader = reader
+        state.writer = writer
+        logger.debug(f"Connected to rank monitor {local_rank} at {launcher_to_rmon_socket}")
+
+        # Start listener task for this connection
+        state.listener_task = asyncio.create_task(self._listen_to_rank_monitor(local_rank, reader))
+
+    def _update_progress_iteration(self, local_rank: int, iteration: int):
+        """Update iteration for a specific rank and aggregate using MIN strategy."""
+        # Update this rank's max iteration
+        self._rank_iterations[local_rank] = max(
+            self._rank_iterations.get(local_rank, 0), iteration
+        )
+
+        # Use minimum across all ranks (most conservative - slowest rank determines progress)
+        min_iteration = min(self._rank_iterations.values()) if self._rank_iterations else 0
+        self._progress_tracker.update_iteration(min_iteration)
+
+        logger.debug(
+            f"Updated iteration for rank {local_rank}={iteration}, "
+            f"cluster min={min_iteration}, all_ranks={self._rank_iterations}"
+        )
+
+    async def _listen_to_rank_monitor(self, local_rank: int, reader) -> None:
+        """Listen for messages from rank monitor."""
+        try:
+            while True:
+                msg = await read_obj_from_ipc_stream(reader)
+                if isinstance(msg, dict) and msg.get("type") == "iteration_update":
+                    # Handle iteration update from rank monitor
+                    iteration = msg["iteration"]
+                    self._update_progress_iteration(local_rank, iteration)
+                    logger.debug(f"[Rank {local_rank}] Received iteration update: {iteration}")
+                else:
+                    logger.debug(f"Received message from rank monitor {local_rank}: {msg}")
+        except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError, EOFError):
+            logger.debug(f"Rank monitor {local_rank} connection closed")
+        except asyncio.CancelledError:
+            logger.debug(f"Listener for rank monitor {local_rank} cancelled")
+        except Exception as e:
+            logger.error(f"Error listening to rank monitor {local_rank}: {e}")
+
     def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
         enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
         watchdog_enabled = os.getenv(enable_watchdog_env_name)
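Editor's note: the _listen_to_rank_monitor() listener above accepts plain dict messages of the form {"type": "iteration_update", "iteration": N}. The sending side lives in rank_monitor_server.py, which is among the six changed files but not shown in this excerpt. A minimal sketch of what a sender could look like, reusing the write_obj_to_ipc_stream helper imported above (report_progress is a hypothetical helper; the 30-second cadence comes from the commit message):

    import asyncio

    from nvidia_resiliency_ext.fault_tolerance.utils import write_obj_to_ipc_stream

    async def report_progress(writer: asyncio.StreamWriter, iteration: int) -> None:
        # Payload shape matches what _listen_to_rank_monitor() expects on the launcher side.
        await write_obj_to_ipc_stream({"type": "iteration_update", "iteration": iteration}, writer)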
@@ -682,22 +790,27 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None:

         logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.")

-        # Send close message to rank monitors
-        for local_rank, rmon_proc in self._local_rank_to_rmon.items():
-            try:
-                launcher_to_rmon_socket = f"{tempfile.gettempdir()}/_ft_launcher{rmon_proc.pid}_to_rmon.socket"
-                if os.path.exists(launcher_to_rmon_socket):
-                    async def send_close_msg():
-                        reader, writer = await asyncio.open_unix_connection(launcher_to_rmon_socket)
-                        try:
-                            await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
-                        finally:
-                            writer.close()
-                            await writer.wait_closed()
-
-                    asyncio.run(send_close_msg())
-            except Exception as e:
-                logger.warning(f"Failed to send close message to rank monitor {local_rank}: {e}")
+        # Send close message to rank monitors through persistent connections
+        async def send_close_messages():
+            tasks = []
+            for local_rank, state in self._rank_monitors.items():
+                if state.writer:
+                    async def send_msg(writer, local_rank):
+                        await write_obj_to_ipc_stream("close_worker_ipc_connection", writer)
+                    tasks.append(send_msg(state.writer, local_rank))
+            if tasks:
+                # return_exceptions=True catches exceptions from send_msg, no need for try-except inside
+                results = await asyncio.gather(*tasks, return_exceptions=True)
+                for local_rank, result in zip([lr for lr, s in self._rank_monitors.items() if s.writer], results):
+                    if isinstance(result, Exception):
+                        # Connection errors during shutdown are expected (rank monitor may be dead)
+                        if isinstance(result, (ConnectionError, BrokenPipeError, OSError)):
+                            logger.debug(f"Rank monitor {local_rank} already disconnected: {result}")
+                        else:
+                            logger.warning(f"Unexpected error sending close message to rank monitor {local_rank}: {result}")
+
+        if self._rank_monitors:
+            asyncio.run(send_close_messages())

         self._shutdown(timeout=self._workers_stop_timeout)


@@ -2074,6 +2187,32 @@ def get_args_parser() -> ArgumentParser:
         help="Do not raise an error if there is no Fault Tolerance pkg config provided, just use default settings.",
     )

+    #
+    # Progress tracking arguments
+    #
+
+    parser.add_argument(
+        "--ft-max-no-progress-restarts",
+        "--ft-max_no_progress_restarts",
+        type=int,
+        default=3,
+        dest="ft_max_no_progress_restarts",
+        help="Maximum consecutive restarts without progress before early termination. "
+        "Progress tracking is enabled when this value > 0. "
+        "Set to 0 or -1 to disable progress tracking. "
+        "Default: 3 (progress tracking enabled).",
+    )
+
+    parser.add_argument(
+        "--ft-min-progress-iterations",
+        "--ft-min_progress_iterations",
+        type=int,
+        default=200,
+        dest="ft_min_progress_iterations",
+        help="Minimum iterations required to consider a restart as making progress. "
+        "Default: 200.",
+    )
+
     #
     # Positional arguments.
     #
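Editor's note: src/nvidia_resiliency_ext/fault_tolerance/progress_tracker.py (the new TrainingProgressTracker module) is among the six changed files but does not appear in this excerpt. Inferring from its call sites in launcher.py above (update_iteration, analyze_previous_cycle, should_terminate_early) and from the commit message, a minimal sketch of the intended logic might look like this; it illustrates the behaviour and is not the actual implementation:

    class TrainingProgressTrackerSketch:
        """Hypothetical stand-in for TrainingProgressTracker (the real module is not shown here)."""

        def __init__(self, min_progress_iterations: int, max_no_progress_restarts: int):
            self.min_progress_iterations = min_progress_iterations
            self.max_no_progress_restarts = max_no_progress_restarts
            self._cycle_start_iteration = 0  # cluster-min iteration when the current cycle began
            self._latest_iteration = 0       # most recent cluster-min iteration reported
            self._no_progress_restarts = 0   # consecutive restart cycles without progress

        def update_iteration(self, iteration: int) -> None:
            # Called by the launcher with the minimum iteration across all local ranks.
            self._latest_iteration = max(self._latest_iteration, iteration)

        def analyze_previous_cycle(self) -> None:
            # Called when a worker-group failure is observed, before deciding on a restart.
            progressed = (self._latest_iteration - self._cycle_start_iteration) >= self.min_progress_iterations
            self._no_progress_restarts = 0 if progressed else self._no_progress_restarts + 1
            self._cycle_start_iteration = self._latest_iteration

        def should_terminate_early(self) -> bool:
            # Disabled when max_no_progress_restarts <= 0.
            return (
                self.max_no_progress_restarts > 0
                and self._no_progress_restarts >= self.max_no_progress_restarts
            )

With the defaults above (max_no_progress_restarts=3, min_progress_iterations=200), the launcher keeps restarting as long as each restart cycle advances by at least 200 iterations, and gives up after three consecutive cycles that do not.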