diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index db31fe22..08b8962a 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -408,18 +408,22 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes rank=self._worker_group.group_rank, ) - logger.info( - "[%s] Detected cluster changes from group_rank=%s " - "(unhealthy_nodes=%s, nodes_waiting=%s); will restart worker group", - role, - group_rank, - unhealthy_count, - num_nodes_waiting, - ) - - # Note: The node that triggered the change (unhealthy or new) already opened - # the rendezvous, so we don't need to open it again here. - self._restart_workers(self._worker_group) + if self._remaining_restarts > 0: + logger.info( + "[%s] Detected cluster changes from group_rank=%s " + "(unhealthy_nodes=%s, nodes_waiting=%s); will restart worker group", + role, + group_rank, + unhealthy_count, + num_nodes_waiting, + ) + self._remaining_restarts -= 1 + # Note: The node that triggered the change (unhealthy or new) already opened + # the rendezvous, so we don't need to open it again here. + self._restart_workers(self._worker_group) + else: + self._stop_workers(self._worker_group) + return RunResult(state=WorkerState.FAILED) else: raise Exception(f"[{role}] Worker group in {state.name} state")