Skip to content

Commit 4db69c1

Browse files
authored
Merge pull request #1759 from NVIDIA-NeMo/ko3n1g/chore/test-ft
2 parents 1f83587 + 799e839 commit 4db69c1

File tree

3 files changed

+104
-4
lines changed

3 files changed

+104
-4
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import logging
2+
from dataclasses import dataclass
3+
from typing import Callable, List, Optional, Union
4+
5+
import nemo_run as run
6+
from nemo_run import Plugin
7+
8+
9+
logging.basicConfig(level=logging.DEBUG)
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class FaultTolerancePluginScriptArgs:
15+
"""Arguments for FaultTolerancePlugin to pass to run.Script."""
16+
17+
enable_ft_package: bool
18+
calc_ft_timeouts: bool
19+
20+
21+
def _default_fault_tolerance_converter(args: FaultTolerancePluginScriptArgs) -> List[str]:
22+
"""Default converter for FaultTolerancePlugin that generates hydra-style overrides."""
23+
return [
24+
f"ft.enable_ft_package={str(args.enable_ft_package).lower()}",
25+
f"ft.calc_ft_timeouts={str(args.calc_ft_timeouts).lower()}",
26+
]
27+
28+
29+
@dataclass(kw_only=True)
30+
class FaultTolerancePlugin(Plugin):
31+
"""
32+
A plugin for setting up fault tolerance configuration.
33+
This plugin enables workload hang detection, automatic calculation of timeouts used for hang detection,
34+
detection of rank(s) terminated due to an error and workload respawning in case of a failure.
35+
36+
37+
Args:
38+
enable_ft_package (bool): Enable the fault tolerance package. Default is True.
39+
calc_ft_timeouts (bool): Automatically compute timeouts. Default is True.
40+
num_in_job_restarts (int): Max number of restarts on failure, within the same job. Default is 3.
41+
num_job_retries_on_failure (int): Max number of new job restarts on failure. Default is 2.
42+
initial_rank_heartbeat_timeout (int): Timeouts are time intervals used by a rank monitor to detect
43+
that a rank is not alive. This is the max timeout for the initial heartbeat. Default is 1800.
44+
rank_heartbeat_timeout (int): This is the timeout for subsequent hearbeats after the initial heartbeat.
45+
Default is 300.
46+
script_args_converter_fn (Optional[Callable]): A function that takes FaultTolerancePluginScriptArgs
47+
and returns a list of CLI arguments. If not provided,
48+
uses the default hydra-style converter.
49+
50+
Note:
51+
This plugin is incompatible with NsysPlugin. Nsys profiling cannot be used when fault tolerance
52+
is enabled.
53+
"""
54+
55+
enable_ft_package: bool = True
56+
calc_ft_timeouts: bool = True
57+
num_in_job_restarts: int = 3
58+
num_job_retries_on_failure: int = 2
59+
initial_rank_heartbeat_timeout: int = 1800
60+
rank_heartbeat_timeout: int = 300
61+
script_args_converter_fn: Optional[Callable[[FaultTolerancePluginScriptArgs], List[str]]] = None
62+
63+
def setup(self, task: Union["run.Partial", "run.Script"], executor: "run.Executor"):
64+
"""Set up the fault tolerance plugin."""
65+
# Set up fault tolerance launcher for both task types
66+
executor.launcher = run.FaultTolerance(
67+
max_restarts=self.num_in_job_restarts,
68+
initial_rank_heartbeat_timeout=self.initial_rank_heartbeat_timeout,
69+
rank_heartbeat_timeout=self.rank_heartbeat_timeout,
70+
)
71+
executor.retries = self.num_job_retries_on_failure
72+
73+
if isinstance(task, run.Script):
74+
# For run.Script, append CLI overrides to the script arguments
75+
# Create args dataclass
76+
script_args = FaultTolerancePluginScriptArgs(
77+
enable_ft_package=self.enable_ft_package,
78+
calc_ft_timeouts=self.calc_ft_timeouts,
79+
)
80+
81+
# Use custom converter or default
82+
converter = self.script_args_converter_fn or _default_fault_tolerance_converter
83+
cli_overrides = converter(script_args)
84+
85+
task.args.extend(cli_overrides)
86+
logger.info(f"{self.__class__.__name__} added CLI overrides: {', '.join(cli_overrides)}")
87+
else:
88+
raise NotImplementedError("FaultTolerancePlugin is only supported for run.Script tasks")

scripts/performance/setup_experiment.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pathlib import Path
2323
from typing import Any, Dict, List, Optional
2424

25+
import nemo_run as run
2526
from nemo_run.config import get_nemorun_home
2627

2728

@@ -34,9 +35,6 @@
3435
from .utils.evaluate import calc_convergence_and_performance
3536
from .utils.executors import dgxc_executor, slurm_executor
3637

37-
import nemo_run as run
38-
39-
4038
try:
4139
import wandb
4240

@@ -46,8 +44,10 @@
4644

4745
try:
4846
from perf_plugins import NsysPlugin, PerfEnvPlugin
47+
from resiliency_plugins import FaultTolerancePlugin
4948
except (ImportError, ModuleNotFoundError):
5049
from .perf_plugins import NsysPlugin, PerfEnvPlugin
50+
from .resiliency_plugins import FaultTolerancePlugin
5151

5252
import logging
5353

@@ -327,6 +327,18 @@ def main(
327327
)
328328
)
329329

330+
if use_recipes:
331+
plugins.append(
332+
FaultTolerancePlugin(
333+
enable_ft_package=True,
334+
calc_ft_timeouts=True,
335+
num_in_job_restarts=10,
336+
num_job_retries_on_failure=10,
337+
initial_rank_heartbeat_timeout=1800,
338+
rank_heartbeat_timeout=300,
339+
)
340+
)
341+
330342
nemorun_script = run.Script(
331343
path=str(run_script_path),
332344
entrypoint="python",

scripts/performance/utils/executors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,6 @@ def dgxc_executor(
223223
else {}
224224
),
225225
env_vars=env_vars,
226-
launcher="ft",
226+
launcher="torchrun",
227227
)
228228
return executor

0 commit comments

Comments
 (0)