
Commit 04e2a79

Merge pull request #196 from hexinw-nvidia/tflops
feat: Add infrastructure rank support and optimize section monitoring
2 parents 258e007 + 4a4a173 commit 04e2a79

11 files changed (+553, -40 lines)


docs/source/fault_tolerance/usage_guide.rst

Lines changed: 29 additions & 0 deletions
@@ -63,6 +63,35 @@ The restart behavior depends on the ``--ft-restart-policy`` parameter, which sup
 falls below the minimum specified in ``--nnodes``. This allows for some worker failures to be handled
 without restarting remaining workers, e.g., with the :doc:`../inprocess/index`.
 For details on how ``min-healthy`` policy interacts with :doc:`../inprocess/index` see :doc:`integration/inprocess`.
+
+Rank assignment
+^^^^^^^^^^^^^^^
+
+The ``ft_launcher`` assigns ranks to workers during the rendezvous process.
+
+**Infrastructure-based assignment (default):**
+
+By default (``--ft-use-infra-group-rank=True``), rank assignments **always** come from the infrastructure:
+
+* The launcher first checks ``SLURM_PROCID`` (automatically set in SLURM environments)
+* If not available, it falls back to ``GROUP_RANK`` (set by ``ft_launcher`` itself)
+
+Infrastructure ranks are used for **every rendezvous**, including after failures/restarts. Previous
+rank assignments are ignored. This ensures consistency with the infrastructure's rank assignment,
+which is important for static deployments and proper resource allocation.
+
+.. note::
+   Hot spare/redundancy is **NOT supported** with ``use_infra_group_rank=True`` because dynamic
+   rendezvous cannot guarantee that lower infrastructure ranks will join as participants first.
+
+**Deterministic assignment (alternative):**
+
+Set ``--ft-use-infra-group-rank=False`` (or ``use_infra_group_rank: false`` in config) to use
+deterministic sorted assignment based on node descriptors. In this mode:
+
+* Previous rank assignments are preserved when possible
+* New workers fill gaps left by failed workers
+* Ranks are reassigned based on sorted node descriptors
 
 
 Hang detection
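
For reference, a minimal sketch of the environment-variable fallback described above. The helper name is illustrative only; the actual lookup is added to ``_add_to_participants`` in ``_ft_rendezvous.py`` further down in this commit.

import os

def resolve_infra_rank() -> int:
    # Prefer SLURM_PROCID (set by SLURM), then GROUP_RANK (set by ft_launcher).
    rank_str = os.getenv('SLURM_PROCID', os.getenv('GROUP_RANK', '-1'))
    rank = int(rank_str)
    if rank < 0:
        # Same failure mode as the commit: infra ranks requested but not provided.
        raise ValueError(
            "use_infra_group_rank is enabled but neither SLURM_PROCID nor "
            "GROUP_RANK is set; set one of them or disable use_infra_group_rank."
        )
    return rank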

examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@ fault_tolerance:
 initial_rank_heartbeat_timeout: null
 rank_heartbeat_timeout: null
 log_level: "DEBUG"
+# use_infra_group_rank: true # Default: Use infrastructure ranks (SLURM_PROCID/GROUP_RANK) on initial rendezvous

examples/fault_tolerance/fault_tol_cfg_sections.yaml

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ fault_tolerance:
 checkpoint: 30
 rank_out_of_section_timeout: 30
 log_level: "DEBUG"
+# use_infra_group_rank: true # Default: Use infrastructure ranks (SLURM_PROCID/GROUP_RANK) on initial rendezvous
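
A small sketch of loading one of these example configs through ``FaultToleranceConfig.from_yaml_file`` (the method appears in the config.py diff below). The temporary file is only for illustration; when ``use_infra_group_rank`` is omitted from the YAML, the dataclass default (True) applies.

import tempfile

from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

# Mirrors examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml, minus comments.
yaml_text = """
fault_tolerance:
  initial_rank_heartbeat_timeout: null
  rank_heartbeat_timeout: null
"""

with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:
    f.write(yaml_text)
    cfg_path = f.name

cfg = FaultToleranceConfig.from_yaml_file(cfg_path)
print(cfg.use_infra_group_rank)  # True unless overridden in the YAML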

src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py

Lines changed: 82 additions & 6 deletions
@@ -233,6 +233,10 @@ class RendezvousSettings:
             If set to True (default), nodes from the redundancy list and new arrivals are migrated
             to the wait list. If set to False, new arrivals will be moved to the redundancy list
             and will wait there until the next rendezvous round.
+        use_infra_group_rank:
+            Whether to use infrastructure group rank for rank assignment instead of sorted
+            participant-based assignment. If True, ranks are read from SLURM_PROCID (in SLURM
+            environments) or GROUP_RANK (set by launcher) environment variables.
     """
 
     run_id: str
@@ -242,6 +246,7 @@ class RendezvousSettings:
     keep_alive_interval: timedelta
     keep_alive_max_attempt: int
     upscaling_enabled: bool = True
+    use_infra_group_rank: bool = True
 
 
 @dataclass(eq=True, order=True, frozen=True)
@@ -789,8 +794,22 @@ def _add_to_participants(self) -> None:
         log.debug(f"Node {self._node} was not in the wait list.")
 
         # The ranks of the participants will be set once the rendezvous is
-        # complete.
-        state.participants[self._node] = 0
+        # complete. If use_infra_group_rank is enabled, store the infrastructure
+        # rank (SLURM_PROCID or GROUP_RANK) here; otherwise, use placeholder -1.
+        if self._settings.use_infra_group_rank:
+            # Try SLURM_PROCID first (set by SLURM), then fall back to GROUP_RANK (set by launcher)
+            infra_rank_str = os.getenv('SLURM_PROCID', os.getenv('GROUP_RANK', '-1'))
+            infra_rank = int(infra_rank_str)
+            if infra_rank < 0:
+                raise ValueError(
+                    "use_infra_group_rank is enabled but neither SLURM_PROCID nor GROUP_RANK "
+                    "environment variable is set. Please set one of these environment variables "
+                    "or disable use_infra_group_rank."
+                )
+            state.participants[self._node] = infra_rank
+            log.debug(f"Node {self._node} stored infrastructure rank {infra_rank} from environment")
+        else:
+            state.participants[self._node] = 0
 
         self._keep_alive()
 
@@ -874,16 +893,61 @@ def _remove_from_redundancy_list(self) -> None:
 
     @staticmethod
     def _assign_ranks(
-        participants: Dict[_NodeDesc, int], prev: Dict[_NodeDesc, int]
+        participants: Dict[_NodeDesc, int],
+        prev: Dict[_NodeDesc, int],
+        use_infra_group_rank: bool = False,
     ) -> Dict[_NodeDesc, int]:
-        # Assign ranks. Re-use assigment from the previous round as much as possible
+        """
+        Assign ranks to participants in the rendezvous.
+
+        Behavior depends on use_infra_group_rank:
+
+        1. If use_infra_group_rank=True:
+           - ALWAYS use infrastructure ranks directly from SLURM_PROCID or GROUP_RANK
+           - Previous assignments are ignored
+           - Validates that all ranks are in range [0, world_size) and unique
+           - Ensures consistency with infrastructure's rank assignment
+           - Note: Hot spare/redundancy is NOT supported in this mode as dynamic
+             rendezvous cannot guarantee lower ranks join as participants first
+
+        2. If use_infra_group_rank=False:
+           - Use deterministic assignment, preserving previous ranks when possible
+           - Fill gaps left by failed nodes with new participants
+
+        Args:
+            participants: Dict mapping node descriptors to infrastructure ranks
+            prev: Dict of previous rank assignments (empty on first rendezvous)
+            use_infra_group_rank: If True, always use infrastructure ranks
+
+        Returns:
+            Dict mapping node descriptors to assigned ranks
+        """
+        # If use_infra_group_rank is enabled, use the infrastructure ranks directly
+        if use_infra_group_rank:
+            # Validate that all participants have valid infrastructure ranks
+            for node, rank in participants.items():
+                if rank < 0 or rank >= len(participants):
+                    raise ValueError(
+                        f"Invalid infrastructure rank {rank} for node {node}. "
+                        f"Expected rank in range [0, {len(participants)})"
+                    )
+            # Check for duplicate ranks
+            ranks_set = set(participants.values())
+            if len(ranks_set) != len(participants):
+                raise ValueError(
+                    f"Duplicate infrastructure ranks detected in participants: {participants}"
+                )
+            log.debug(f"Using infrastructure ranks directly: {participants}")
+            return dict(participants)
+
+        # Default behavior: Assign ranks. Re-use assignment from the previous round as much as possible
         world_size = len(participants)
         sorted_keys = sorted(participants.keys())
         free_ranks = set(range(world_size))
         res = {}
         for p in sorted_keys:
             prev_rank = prev.get(p, -1)
-            if prev_rank >= 0 and prev_rank < world_size:
+            if prev_rank >= 0 and prev_rank < world_size and prev_rank in free_ranks:
                 # if this node can have the same rank, use it
                 res[p] = prev_rank
                 free_ranks.remove(prev_rank)
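
To make the two branches above concrete, here is a simplified, standalone rendering of the same assignment rules. Plain strings stand in for ``_NodeDesc``, validation is trimmed, and the gap-filling step (not visible in the hunk) is assumed to hand out the lowest free rank.

from typing import Dict

def assign_ranks_sketch(
    participants: Dict[str, int],
    prev: Dict[str, int],
    use_infra_group_rank: bool = False,
) -> Dict[str, int]:
    if use_infra_group_rank:
        # Infrastructure mode: values were captured from SLURM_PROCID/GROUP_RANK
        # when each node joined; previous assignments are ignored.
        return dict(participants)

    # Deterministic mode: keep previous ranks when still free, then fill gaps.
    world_size = len(participants)
    free_ranks = set(range(world_size))
    res: Dict[str, int] = {}
    pending = []
    for node in sorted(participants):
        prev_rank = prev.get(node, -1)
        if 0 <= prev_rank < world_size and prev_rank in free_ranks:
            res[node] = prev_rank
            free_ranks.remove(prev_rank)
        else:
            pending.append(node)
    for node in pending:
        # Assumed gap-filling strategy: lowest free rank first.
        res[node] = min(free_ranks)
        free_ranks.remove(res[node])
    return res

# Node "b" failed and "d" replaced it; "a" and "c" keep their ranks, "d" fills the gap.
prev = {"a": 0, "b": 1, "c": 2}
now = {"a": 0, "c": 0, "d": 0}  # placeholder values, as in the non-infra path
print(assign_ranks_sketch(now, prev))  # {'a': 0, 'c': 2, 'd': 1}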
@@ -920,7 +984,9 @@ def _mark_rendezvous_complete(self) -> None:
         state.wait_list.clear()
 
         # Will try to preserve node<->rank mapping
-        state.participants = self._assign_ranks(state.participants, self._prev_participants)
+        state.participants = self._assign_ranks(
+            state.participants, self._prev_participants, self._settings.use_infra_group_rank
+        )
 
         # Set initial worker states, assume all workers are healthy at the beginning
         state.worker_states = {n: WorkerState.HEALTHY for n in state.participants}
@@ -1156,6 +1222,7 @@ def from_backend(
         local_addr: Optional[str] = None,
         timeout: Optional[RendezvousTimeout] = None,
         upscaling_enabled: bool = True,
+        use_infra_group_rank: bool = False,
     ):
         """Create a new :py:class:`FtRendezvousHandler`.
 
@@ -1176,6 +1243,8 @@ def from_backend(
                 The timeout configuration of the rendezvous.
             upscaling_enabled:
                 Whether to enable upscaling of a completed rendezvous with redundant or new nodes.
+            use_infra_group_rank:
+                Whether to use infrastructure group rank for rank assignment.
         """
         # We associate each handler instance with a unique node descriptor.
         node = cls._node_desc_generator.generate(local_addr)
@@ -1188,6 +1257,7 @@ def from_backend(
             keep_alive_interval=timedelta(seconds=5),
             keep_alive_max_attempt=3,
             upscaling_enabled=upscaling_enabled,
+            use_infra_group_rank=use_infra_group_rank,
         )
 
         state_holder = _BackendRendezvousStateHolder(backend, settings)
@@ -1657,6 +1727,10 @@ def create_handler(
    |                   | :py:meth:`RendezvousHandler.shutdown`. Defaults to   |
    |                   | 30 seconds.                                           |
    +-------------------+-------------------------------------------------------+
+   | use_infra_group_  | Whether to always use infrastructure group rank for   |
+   | rank              | rank assignment. Previous assignments are ignored.    |
+   |                   | Hot spare/redundancy NOT supported. Defaults to True. |
+   +-------------------+-------------------------------------------------------+
    """
    try:
        timeout = RendezvousTimeout(
@@ -1667,6 +1741,7 @@ def create_handler(
 
         # torchrun default behaviour if not specified otherwise
         upscale_completed = params.config.get('upscaling_enabled', True)
+        use_infra_group_rank = params.config.get('use_infra_group_rank', True)
 
         return FtRendezvousHandler.from_backend(
             params.run_id,
@@ -1677,6 +1752,7 @@ def create_handler(
             params.local_addr,
             timeout,
             upscaling_enabled=upscale_completed,
+            use_infra_group_rank=use_infra_group_rank,
         )
     except Exception as e:
         construct_and_record_rdzv_event(

src/nvidia_resiliency_ext/fault_tolerance/config.py

Lines changed: 55 additions & 9 deletions
@@ -44,13 +44,23 @@ class FaultToleranceConfig:
     * `rank_termination_signal` signal used to terminate the rank when failure is detected.
     * `log_level` log level of fault tolerance components
     * `rank_section_timeouts` Mapping[str,float|None] timeouts for specific sections in user code.
+      Only sections listed here will send IPC messages to the monitor server and collect timing data.
+      Sections not in this mapping will have near-zero overhead (no IPC, no timing collection).
     * `rank_out_of_section_timeout` [float|None] the timeout used for implicit/default section,
       that spans code not wrapped in any other section.
     * `restart_check_interval` - interval between checks if restart is in progress, needed for layered restart protocol
-    * `enable_nic_monitor` - Enable NIC health monitoring in training.
+    * `enable_nic_monitor` - Enable NIC health monitoring in training. Default: False.
     * `pci_topo_file` - PCI topo file that describes GPU and NIC topology.
     * `link_down_path_template` - Template path for NIC link down files. Should contain '{dev_name}'
       placeholder which will be replaced with actual NIC device name.
+    * `skip_section_response` - If True, section and heartbeat messages are sent without waiting
+      for server response (unidirectional communication). This significantly reduces latency for
+      high-frequency operations. Server logs errors instead of sending them back.
+      Default: True (recommended for production). Set to False during development to catch errors immediately.
+    * `use_infra_group_rank` - If True, always use infrastructure group rank for rank assignment.
+      Reads from SLURM_PROCID (in SLURM environments) or GROUP_RANK (set by launcher). Previous
+      rank assignments are ignored to ensure consistency with infrastructure's rank assignment.
+      Note: Hot spare/redundancy is NOT supported with this setting. Default: True.
 
     If any timeout is None, it has no effect (as if it was +INF).
     All timeouts can be deduced and set during runtime.
@@ -66,9 +76,11 @@ class FaultToleranceConfig:
     rank_termination_signal: signal.Signals = signal.SIGKILL
     log_level: int = logging.INFO
     restart_check_interval: float = 60.0
-    enable_nic_monitor: bool = True
+    enable_nic_monitor: bool = False
     pci_topo_file: Optional[str] = None
     link_down_path_template: Optional[str] = None
+    skip_section_response: bool = True
+    use_infra_group_rank: bool = True
 
     @staticmethod
     def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultToleranceConfig':
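
A sketch of building a config with the new knobs via ``from_kwargs`` (shown just above). The section names are examples only, and it is assumed that fields not passed keep the dataclass defaults listed in this hunk.

from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

# Only the listed sections send IPC messages and collect timing; everything else
# falls under rank_out_of_section_timeout.
cfg = FaultToleranceConfig.from_kwargs(
    rank_section_timeouts={'setup': 600, 'step': 180, 'checkpoint': 420},
    rank_out_of_section_timeout=300,
    skip_section_response=True,   # default: fire-and-forget section/heartbeat messages
    use_infra_group_rank=True,    # default: ranks come from SLURM_PROCID/GROUP_RANK
)
print(cfg.use_infra_group_rank, cfg.enable_nic_monitor)  # True False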
@@ -121,11 +133,37 @@ def from_yaml_file(cfg_path: str, ignore_not_recognized: bool = True) -> 'FaultT
         else:
             raise ValueError(f"'fault_tolerance' section not found in config file {cfg_path}")
 
+    @staticmethod
+    def _parse_timeout_arg(timeout_arg: str) -> Optional[float]:
+        """
+        Parse a timeout CLI argument.
+        Timeout can be a float or 'None'/'null'/'' to represent None.
+
+        Args:
+            timeout_arg (str): The timeout value as a string
+
+        Returns:
+            Optional[float]: The parsed timeout value or None
+        """
+        timeout_arg = timeout_arg.strip()
+        if timeout_arg.lower() in ['none', 'null', '']:
+            return None
+        else:
+            return float(timeout_arg)
+
     @staticmethod
     def _parse_section_timeouts_arg(section_timeouts_arg: str) -> Mapping[str, Optional[float]]:
-        # Parse section timeouts CLI argument, expected format is:
-        # "section1:timeout1,section2:timeout2,..."
-        # Timeout can be float or 'None'/'null'/'' to represent None.
+        """
+        Parse section timeouts CLI argument.
+        Expected format: "section1:timeout1,section2:timeout2,..."
+        Timeout can be a float or 'None'/'null'/'' to represent None.
+
+        Args:
+            section_timeouts_arg (str): The section timeouts string
+
+        Returns:
+            Mapping[str, Optional[float]]: Dictionary mapping section names to timeout values
+        """
         section_timeouts_arg = section_timeouts_arg.strip()
         if not section_timeouts_arg:
             return {}
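
A quick illustration of the two parsers on CLI strings (the values themselves are made up; both helpers are the static methods defined above):

from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig

print(FaultToleranceConfig._parse_timeout_arg("30.5"))   # 30.5
print(FaultToleranceConfig._parse_timeout_arg("None"))   # None
print(FaultToleranceConfig._parse_section_timeouts_arg("setup:600,step:None,checkpoint:30"))
# {'setup': 600.0, 'step': None, 'checkpoint': 30.0}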
@@ -135,10 +173,7 @@ def _parse_section_timeouts_arg(section_timeouts_arg: str) -> Mapping[str, Optio
             section, timeout = st.split(":")
             section = section.strip()
             timeout = timeout.strip()
-            if timeout.lower() in ['none', 'null', '']:
-                res[section] = None
-            else:
-                res[section] = float(timeout)
+            res[section] = FaultToleranceConfig._parse_timeout_arg(timeout)
         return res
 
     @staticmethod
@@ -167,12 +202,23 @@ def from_args(args: argparse.Namespace):
 
         # Extract FT args from CLI
         cli_ft_args = {}
+        timeout_fields = [
+            'initial_rank_heartbeat_timeout',
+            'rank_heartbeat_timeout',
+            'rank_out_of_section_timeout',
+            'workload_check_interval',
+            'node_health_check_interval',
+            'safety_factor',
+            'restart_check_interval',
+        ]
         for field in fields(FaultToleranceConfig):
             cli_field_name = f"ft_{field.name}"
             val = getattr(args, cli_field_name, None)
             if val is not None:
                 if field.name == "rank_section_timeouts" and isinstance(val, str):
                     val = FaultToleranceConfig._parse_section_timeouts_arg(val)
+                elif field.name in timeout_fields and isinstance(val, str):
+                    val = FaultToleranceConfig._parse_timeout_arg(val)
                 cli_ft_args[field.name] = val
 
         # Update config with CLI args
