
Commit 4a02378

Merge branch 'main' into socket_mismatch
2 parents: 37aec27 + 4dce778

File tree

5 files changed: +315 -3 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ requires = ["poetry-core>=1.0.0", "pybind11", "setuptools", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.dependencies]
+nv-one-logger-core = ">=2.1.0"
 torch = ">=2.3.0"
 packaging = "*"
 python = ">=3.10"

src/nvidia_resiliency_ext/fault_tolerance/_ft_rendezvous.py

Lines changed: 13 additions & 0 deletions

@@ -60,6 +60,7 @@
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
 
 from ..shared_utils.health_check import GPUHealthCheck
+from ..shared_utils.profiling import ProfilingEvent, record_profiling_event
 from .data import WorkloadAction
 from .ipc_connector import IpcConnector
 from .launcher import FT_LAUNCHER_IPC_SOCKET, UnhealthyNodeException

@@ -1322,6 +1323,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg)
         log.info(msg)
 
+        # Record rendezvous start event
+        rendezvous_start_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_STARTED,
+            node_id=self._this_node,
+        )
+
         try:
             self._stop_heartbeats()

@@ -1362,6 +1369,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]:
         self._record(message=msg, rank=rank)
         log.info(msg)
 
+        # Record rendezvous completion event
+        rendezvous_completion_event_id = record_profiling_event(
+            ProfilingEvent.RENDEZVOUS_COMPLETED,
+            node_id=self._this_node,
+        )
+
         # Use RendezvousInfo if available (newer PyTorch versions >= 2.4.0)
         # Fall back to tuple format if RendezvousInfo is not supported
         if _RENDEZVOUS_INFO_AVAILABLE:
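The ProfilingEvent / record_profiling_event helpers imported above come from src/nvidia_resiliency_ext/shared_utils/profiling.py, presumably the fifth changed file, whose diff is not reproduced on this page. The following is a minimal sketch, assuming only the call surface visible in this commit (an event enum, optional node_id/rank keywords, a returned event id); the real module may instead forward events to nv-one-logger, the dependency added in pyproject.toml above:

import logging
import time
import uuid
from enum import Enum


logger = logging.getLogger(__name__)


class ProfilingEvent(Enum):
    # Members inferred from the call sites in this commit.
    RENDEZVOUS_STARTED = "rendezvous_started"
    RENDEZVOUS_COMPLETED = "rendezvous_completed"
    FAILURE_DETECTED = "failure_detected"
    WORKER_START_STARTED = "worker_start_started"
    WORKER_START_COMPLETED = "worker_start_completed"
    WORKER_TERMINATED = "worker_terminated"


def record_profiling_event(event: ProfilingEvent, node_id=None, rank=None) -> str:
    """Log a timestamped profiling event and return its id."""
    event_id = str(uuid.uuid4())
    logger.info(
        "profiling_event=%s id=%s node_id=%s rank=%s ts=%.6f",
        event.value, event_id, node_id, rank, time.time(),
    )
    return event_id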
src/nvidia_resiliency_ext/fault_tolerance/c10d_monkey_patch.py (new file)

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Monkey patch for PyTorch's c10d_rendezvous_backend to add use_libuv support.
+
+This patch modifies the _create_tcp_store function to accept and use the use_libuv
+parameter from RendezvousParameters, allowing users to control whether to use
+the libuv backend or the traditional socket backend for TCPStore.
+
+Usage:
+    from nvidia_resiliency_ext.fault_tolerance.c10d_monkey_patch import apply_c10d_patch
+    apply_c10d_patch()
+"""
+
+import logging
+
+from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
+
+logger = logging.getLogger(LogConfig.name)
+
+
+def _patched_create_tcp_store(params: "RendezvousParameters") -> "TCPStore":  # noqa: F821
+    """
+    Patched version of _create_tcp_store that supports the use_libuv parameter.
+
+    This function is identical to the original _create_tcp_store except it
+    extracts and uses the use_libuv parameter from RendezvousParameters.
+    """
+    import os
+    from datetime import timedelta
+    from typing import cast
+
+    from torch.distributed import TCPStore
+    from torch.distributed.elastic.events import NodeState, construct_and_record_rdzv_event
+    from torch.distributed.elastic.rendezvous.api import RendezvousConnectionError
+    from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import (
+        _matches_machine_hostname,
+        parse_rendezvous_endpoint,
+    )
+
+    # Default port for TCP store (29400) - defined locally for PyTorch 2.3.1 compatibility
+    DEFAULT_PORT = 29400
+    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT)
+
+    cfg_is_host = params.get_as_bool("is_host")
+    # If the user has explicitly specified whether our process should host
+    # the store, respect it.
+    if cfg_is_host is not None:
+        is_host = cfg_is_host
+    # Otherwise try to determine whether we are the host based on our hostname
+    # and IP address.
+    else:
+        is_host = _matches_machine_hostname(host)
+
+    # The timeout
+    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
+    if read_timeout <= 0:
+        raise ValueError("The read timeout must be a positive integer.")
+
+    # The use_libuv parameter - NEW FUNCTIONALITY
+    use_libuv = params.get_as_bool("use_libuv", True)
+
+    # In specific cases we attempt to instantiate the store twice. For details
+    # see the explanation in the except clause below.
+    for is_server in [is_host, False]:
+        try:
+            store = TCPStore(
+                host,
+                port,
+                is_master=is_server,
+                multi_tenant=True,
+                timeout=timedelta(seconds=read_timeout),
+                use_libuv=use_libuv,  # NEW PARAMETER
+            )
+
+            if is_server:
+                msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
+                construct_and_record_rdzv_event(
+                    run_id=params.run_id, message=msg, node_state=NodeState.INIT
+                )
+                logger.info(msg)
+
+            break
+        except (ValueError, RuntimeError, TimeoutError) as exc:
+            # If we heuristically inferred the value of is_host as True and our
+            # first attempt to instantiate the TCP store has failed, try it one
+            # more time with is_host set to False. As an edge case there can be
+            # more than one process that is part of the same rendezvous on this
+            # machine and only one of them will eventually host the store.
+
+            if not is_server or cfg_is_host is not None:
+                raise RendezvousConnectionError(
+                    "The connection to the C10d store has failed. See inner exception for details."
+                ) from exc
+
+    return store  # type: ignore[possibly-undefined]
+
+
+def apply_c10d_patch():
+    """
+    Apply the monkey patch to add use_libuv support to c10d_rendezvous_backend.
+
+    This function patches the _create_tcp_store function in the c10d_rendezvous_backend
+    module to support the use_libuv parameter.
+    """
+    try:
+        from torch.distributed.elastic.rendezvous import c10d_rendezvous_backend
+
+        # Apply the patch
+        c10d_rendezvous_backend._create_tcp_store = _patched_create_tcp_store
+
+        logger.info(
+            "Successfully applied c10d_rendezvous_backend monkey patch for use_libuv support"
+        )
+
+    except ImportError as e:
+        logger.error(f"Failed to import c10d_rendezvous_backend: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Failed to apply c10d monkey patch: {e}")
+        raise
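Per the module docstring, apply_c10d_patch() must run before the c10d backend builds its TCPStore. A minimal usage sketch under assumed values (endpoint, run_id, and node counts here are illustrative; the launcher below applies the patch automatically inside _register_ft_rdzv_handler):

from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import create_backend

from nvidia_resiliency_ext.fault_tolerance.c10d_monkey_patch import apply_c10d_patch

# Swap in the use_libuv-aware _create_tcp_store before any backend is created.
apply_c10d_patch()

# Extra keyword arguments become rendezvous config entries, so the patched
# function can read them via params.get_as_bool("use_libuv", True).
params = RendezvousParameters(
    backend="c10d",
    endpoint="localhost:29400",
    run_id="demo",
    min_nodes=1,
    max_nodes=1,
    use_libuv=False,
)
backend, store = create_backend(params)  # store is a non-libuv TCPStore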

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 50 additions & 3 deletions

@@ -76,6 +76,7 @@
     write_obj_to_ipc_stream,
 )
 from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger
+from nvidia_resiliency_ext.shared_utils.profiling import ProfilingEvent, record_profiling_event
 
 # Deprecation warning for FT_LAUNCHER_LOGLEVEL
 if os.getenv('FT_LAUNCHER_LOGLEVEL') is not None:

@@ -101,6 +102,10 @@ def _register_ft_rdzv_handler():
     from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import create_backend
 
     from ._ft_rendezvous import FtRendezvousHandler, create_handler
+    from .c10d_monkey_patch import apply_c10d_patch
+
+    # Apply monkey patch to add use_libuv support to c10d backend
+    apply_c10d_patch()
 
     def _create_ft_rdzv_handler(params: RendezvousParameters) -> FtRendezvousHandler:
         backend, store = create_backend(params)

@@ -138,7 +143,7 @@ class LocalElasticAgent(SimpleElasticAgent):
     python multiprocessing compatible. To pass multiprocessing data structures
     to the workers you may create the data structure in the same multiprocessing
     context as the specified ``start_method`` and pass it as a function argument.
-
+
     Note: If your training script uses the nvrx logger, make sure to call
     ``setup_logger()`` at the beginning of your training function to ensure
     the logger is properly set up in each subprocess.

@@ -179,12 +184,12 @@ def trainer(args) -> str:
            # Ensure nvrx logger is set up in this subprocess
            from nvidia_resiliency_ext.shared_utils.log_manager import setup_logger
            setup_logger()
-
+
            # Use the nvrx logger
            import logging
            logger = logging.getLogger(LogConfig.name)
            logger.info("Training started")
-
+
            return "do train"
 
        def main():

@@ -251,6 +256,7 @@ def __init__(
         self._ft_cfg = fault_tol_cfg
         self._children_pgids: Set[int] = set()
         self._restart_policy = restart_policy
+        self._node_id = self._get_fq_hostname()
 
 DEFAULT_ROLE = "default"  # FIXME
 
@@ -322,6 +328,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
+                # Record failure detection event
+                record_profiling_event(
+                    ProfilingEvent.FAILURE_DETECTED,
+                    node_id=self._rdzv_handler._this_node,
+                    rank=self._worker_group.group_rank,
+                )
+
                if self._remaining_restarts > 0:
                    logger.info(
                        "[%s] Worker group %s. "

@@ -347,6 +360,13 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
+                    # Record failure detection event
+                    record_profiling_event(
+                        ProfilingEvent.FAILURE_DETECTED,
+                        node_id=self._rdzv_handler._this_node,
+                        rank=self._worker_group.group_rank,
+                    )
+
                    logger.info(
                        "[%s] Detected %s "
                        "new nodes from group_rank=%s; "

@@ -587,6 +607,13 @@ async def send_close_msg():
 
         self._shutdown(timeout=self._workers_stop_timeout)
 
+        # Record worker termination event after shutdown is complete
+        record_profiling_event(
+            ProfilingEvent.WORKER_TERMINATED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
     # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
     # `torch.distributed.elastic.metrics.prof`.
     @prof

@@ -596,6 +623,13 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
         assert store is not None
         restart_count = spec.max_restarts - self._remaining_restarts
 
+        # Record worker start event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_STARTED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         use_agent_store = spec.rdzv_handler.use_agent_store
 
         args: Dict[int, Tuple] = {}

@@ -667,8 +701,16 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
 
         self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()}
 
+        # Record worker start completion event
+        record_profiling_event(
+            ProfilingEvent.WORKER_START_COMPLETED,
+            node_id=self._rdzv_handler._this_node,
+            rank=worker_group.group_rank,
+        )
+
         return self._pcontext.pids()
 
+
     def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30) -> None:
         if self._worker_watchdog is not None:
             self._worker_watchdog.stop()

@@ -1054,6 +1096,7 @@ def launch_agent(
         )
 
         logger.info(f"Agent .run() is OK. No failures in the result. {result=}")
+
         return result.return_values
     except UnhealthyNodeException as e:
         # do not shutdown rendezvous when an unhealthy node is leaving

@@ -1987,6 +2030,10 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
     rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)
 
+    # Add use_libuv=False for c10d backend
+    if args.rdzv_backend == 'c10d':
+        rdzv_configs['use_libuv'] = False
+
     if args.rdzv_backend == "static":
         rdzv_configs["rank"] = args.node_rank
