
Commit 03fa9f0

Merge branch 'main' into sbak/attr_module_pr
2 parents 186d394 + e0fa23e commit 03fa9f0

File tree

14 files changed (+243, -70 lines)

CONTRIBUTING.md

Lines changed: 13 additions & 0 deletions
@@ -83,6 +83,19 @@ git push -u origin <local-branch>:<remote-branch>
 
 4. With CI/CD process in place, the PR will be accepted and the corresponding issue closed only after adequate testing has been completed, manually, by the developer and NVRx engineer reviewing the code.
 
+#### Documentation Building
+
+When contributing documentation changes, ensure the documentation builds correctly. See the [docs CI workflow](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/.github/workflows/build_docs.yml) for up-to-date instructions:
+
+```bash
+pip install -U sphinx sphinx-rtd-theme sphinxcontrib-napoleon sphinx_copybutton lightning psutil defusedxml
+sphinx-build -b html docs/source public/
+
+# alternatively,
+cd docs
+make html
+```
+You can then view the locally built documentation under `public` directory or `docs/build/html` (e.g., `open public/index.html`). Ensure that all documentation changes are properly formatted and that the build completes without warnings or errors.
 
 #### Signing Your Work

docs/source/checkpointing/async/usage_guide.rst

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ Usage guide
 The :py:class:`nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue`
 provides application users with an interface to schedule :py:class:`nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest`,
 which defines checkpoint routine, its args/kwargs and finalization steps when the checkpoint routine is finished.
-This class is a singleton, implying each rank will have only one instance of this class.
 It is recommended to call the `close()` API on the `AsyncCallsQueue` at the end of training to ensure a clean shutdown of the process that manages async checkpointing.
 We also extend the API of `abort_nvrx_checkpoint()` to abort the async processes and cleanly restart the `AsyncCallsQueue` in case of any restarts of the training processes.
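
For orientation, here is a minimal sketch of the queue lifecycle the guide describes. The `AsyncCallsQueue` methods (`schedule_async_request`, `maybe_finalize_async_calls`, `close`) are named in the `core.py` docstring further down this commit; the `AsyncRequest` field names and the `blocking` keyword are illustrative assumptions, not the exact signatures.

```python
# Hedged sketch: AsyncRequest field names and the maybe_finalize_async_calls kwarg are
# assumptions for illustration; check the class definitions for the exact signatures.
from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest

def save_checkpoint(path):
    """Checkpoint routine executed asynchronously by the checkpoint worker."""
    ...

def finalize():
    """Finalization step run once the checkpoint routine has finished."""
    ...

queue = AsyncCallsQueue(persistent=True)
request = AsyncRequest(async_fn=save_checkpoint, async_fn_args=("/tmp/ckpt",), finalize_fns=[finalize])
queue.schedule_async_request(request)

# Periodically (e.g., once per training step) finalize any completed async checkpoints.
queue.maybe_finalize_async_calls(blocking=False)

# At the end of training, shut down the process that manages async checkpointing.
queue.close()
```

If the training processes themselves restart, `abort_nvrx_checkpoint()` (see `core.py` below) closes every live queue so the async checkpoint worker can be restarted cleanly.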

docs/source/index.rst

Lines changed: 6 additions & 6 deletions
@@ -10,12 +10,12 @@ nvidia-resiliency-ext
 Features
 --------
 
-* `Hang detection and automatic in-job restarting <https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/fault_tolerance/index.rst>`_
-* `In-process restarting <https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/inprocess/index.rst>`_
-* `Async checkpointing <https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/checkpointing/async/index.rst>`_
-* `Local checkpointing <https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/checkpointing/local/index.rst>`_
-* `Straggler (slower ranks) detection <https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/straggler_det/index.rst>`_
-* `Shared utilities and distributed logging <https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/shared_utils/index.rst>`_
+* `Hang detection and automatic in-job restarting <fault_tolerance/index.html>`_
+* `In-process restarting <inprocess/index.html>`_
+* `Async checkpointing <checkpointing/async/index.html>`_
+* `Local checkpointing <checkpointing/local/index.html>`_
+* `Straggler (slower ranks) detection <straggler_det/index.html>`_
+* `Shared utilities and distributed logging <shared_utils/index.html>`_
 
 .. toctree::
    :maxdepth: 3

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@ packaging = "*"
 python = ">=3.10"
 psutil = ">=6.0.0"
 pyyaml = "*"
-pynvml = ">=12.0.0"
 nvidia-ml-py = ">=12.570.86"
 defusedxml = "*"

src/nvidia_resiliency_ext/attribution/trace_analyzer/trace_collector.py

Lines changed: 5 additions & 4 deletions
@@ -13,8 +13,9 @@
 
 from nvidia_resiliency_ext.attribution.utils import capture_logs
 from nvidia_resiliency_ext.shared_utils.health_check import GPUHealthCheck, NicHealthCheck
+from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(LogConfig.name)
 
 
 class TraceCollector(ABC):
@@ -65,6 +66,7 @@ def __init__(
         self.stack_trace = None
         self.dump_fn = torch._C._distributed_c10d._dump_nccl_trace
         self.json = json
+        logger = logging.getLogger(LogConfig.name)
         logger.info(f"{self.rank} created TorchFRTraceCollector")
 
     def collect(self):
@@ -112,11 +114,10 @@ def get_health_check_results(local_rank: int):
     - Returns the bypassed output strings for GPU and NIC health checks
     """
     health_check_results = {}
-
-    with capture_logs() as stderr_gpu:
+    with capture_logs(LogConfig.name) as stderr_gpu:
        gpu_health_check = GPUHealthCheck(device_index=local_rank)
        gpu_health = gpu_health_check._perform_health_check()
-    with capture_logs() as stderr_nic:
+    with capture_logs(LogConfig.name) as stderr_nic:
        nic_health_check = NicHealthCheck()
        nic_health_check.set_nic_device(local_rank)
        nic_health = nic_health_check._perform_health_check()
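
For context on the logging change above (and the analogous edits in `rank_assignment.py` below): records are now routed to the shared NVRx logger named by `LogConfig.name` instead of per-module `__name__` loggers. A minimal sketch of the idea, assuming only that `LogConfig.name` is the string name of that shared logger; the stand-in class below is hypothetical.

```python
import logging

# Hypothetical stand-in for nvidia_resiliency_ext.shared_utils.log_manager.LogConfig,
# which (per this commit) exposes the name of the shared NVRx logger.
class LogConfig:
    name = "nvrx"

# A handler attached once to the shared logger...
shared = logging.getLogger(LogConfig.name)
shared.addHandler(logging.StreamHandler())
shared.setLevel(logging.INFO)

# ...receives records emitted anywhere via logging.getLogger(LogConfig.name).
# A logging.getLogger(__name__) logger, by contrast, is a separate per-module logger
# whose records reach this handler only if propagation happens through the root logger.
logging.getLogger(LogConfig.name).info("routed through the shared NVRx logger")
```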

src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py

Lines changed: 14 additions & 14 deletions
@@ -19,6 +19,7 @@
 """
 
 import logging
+import weakref
 from abc import ABC, abstractmethod
 from collections import deque
 from queue import Empty
@@ -130,14 +131,18 @@ def execute_finalize_fns(self, validate_matching_call_idx: bool = True) -> int:
         return self.call_idx
 
 
-# Singleton metaclass
-class Singleton(type):
-    _instances = {}
+class ObjectTracker(type):
+    def __init__(cls, name, bases, attrs):
+        super().__init__(name, bases, attrs)
+        cls._instances = weakref.WeakSet()
 
     def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        return cls._instances[cls]
+        instance = super().__call__(*args, **kwargs)
+        cls._instances.add(instance)
+        return instance
+
+    def get_instances(cls):
+        return list(cls._instances)
 
 
 class AsyncCaller(ABC):
@@ -558,15 +563,11 @@ class _ActiveAsyncRequest(NamedTuple):
     async_request: AsyncRequest
 
 
-class AsyncCallsQueue(metaclass=Singleton):
+class AsyncCallsQueue(metaclass=ObjectTracker):
     """Manages a queue of async calls.
 
     Allows adding a new async call with `schedule_async_request` and finalizing
     active calls with `maybe_finalize_async_calls`.
-
-    This class is a Singleton implying there will be only one instance of AsyncCallsQueue per rank.
-    Making this object a singleton avoids mis-use from users where they could potentially spin multiple async CP workers.
-    Making this object a singleton also enables simplification of process life-cycle management during CP aborts.
     """
 
     def __init__(self, persistent: bool = True):
@@ -667,8 +668,7 @@ def __del__(self):
 
 def abort_nvrx_checkpoint():
     """Abort NVRx Checkpoint Utility. This will close the AsyncCallsQueue that manages async checkpoints"""
-    # we have a singleton persistent worker in our async calls queue
     # close the async calls queue which will ensure a clean restart
     # of the CP async process in subsequent async save requests.
-    async_queue_singleton = AsyncCallsQueue(persistent=True)
-    async_queue_singleton.close(abort=True)
+    for async_queue in AsyncCallsQueue.get_instances():
+        async_queue.close(abort=True)
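
To illustrate the change above: the `ObjectTracker` metaclass (reproduced verbatim from the diff) replaces the singleton and instead keeps a weak reference to every live instance, which is what lets `abort_nvrx_checkpoint()` iterate `get_instances()` and close each queue. A standalone sketch of that behavior; the `Tracked` class is hypothetical and used only for the demonstration.

```python
import weakref

class ObjectTracker(type):
    """Metaclass that records every live instance of the classes it creates."""

    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        cls._instances = weakref.WeakSet()

    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        cls._instances.add(instance)
        return instance

    def get_instances(cls):
        return list(cls._instances)

# Hypothetical class used only for this illustration.
class Tracked(metaclass=ObjectTracker):
    pass

a, b = Tracked(), Tracked()
print(len(Tracked.get_instances()))  # 2 -- unlike the old Singleton, both instances exist
del a  # the WeakSet drops instances automatically once they are garbage collected
```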

src/nvidia_resiliency_ext/fault_tolerance/launcher.py

Lines changed: 1 addition & 0 deletions
@@ -478,6 +478,7 @@ def setup_rank_monitors(self, envs: Dict[int, Dict[str, str]]) -> None:
             ipc_socket_path=rmon_ipc_socket,
             is_restarter_logger=is_restarter_logger,
             mp_ctx=fork_mp_ctx,
+            env=worker_env,
         )
 
     def shutdown_rank_monitors(self):

src/nvidia_resiliency_ext/fault_tolerance/rank_monitor_server.py

Lines changed: 10 additions & 9 deletions
@@ -126,7 +126,8 @@ def __init__(
         cfg: FaultToleranceConfig,
         ipc_socket_path: str,
         rank_monitor_ready_event,
-        logger: RankMonitorLogger,
+        logger: logging.Logger,
+        is_restarter_logger: bool,
     ):
         """
         Initializes the RankMonitorServer object.
@@ -151,7 +152,10 @@ def __init__(
         self.connection_lock = asyncio.Lock()
         self.rank_monitor_ready_event = rank_monitor_ready_event
         self.logger = logger
-        self.state_machine = RankMonitorStateMachine(logger)
+        self.rmlogger = RankMonitorLogger(
+            level=logger.level, is_restarter_logger=is_restarter_logger
+        )
+        self.state_machine = RankMonitorStateMachine(self.rmlogger)
         self._periodic_restart_task = None
         self.health_checker = GPUHealthCheck(
             interval=self.cfg.node_health_check_interval, on_failure=self._handle_unhealthy_node
@@ -264,7 +268,7 @@ async def _handle_init_msg(self, msg, writer):
         # Update NIC health checker on the rank to monitor.
         if self.nic_health_checker is not None:
             self.nic_health_checker.set_nic_device(local_rank=self.rank_info.local_rank)
-        self.logger.set_connected_rank(msg.rank_info.global_rank)
+        self.rmlogger.set_connected_rank(msg.rank_info.global_rank)
         await write_obj_to_ipc_stream(OkMsg(cfg=self.cfg), writer)
 
     async def _handle_heartbeat_msg(self, msg, writer):
@@ -318,7 +322,7 @@ def _handle_ipc_connection_lost(self):
                 f"Section(s) {open_section_names} were still open. you can use`.end_all_sections` to avoid this warning"
             )
         self.open_sections.clear()
-        self.logger.set_connected_rank(None)
+        self.rmlogger.set_connected_rank(None)
         if self.connection_lock.locked():
             self.connection_lock.release()
 
@@ -535,18 +539,15 @@ def run(
 
         try:
             setup_logger(force_reset=True, node_local_tmp_prefix="rankmonsvr")
-            rmlogger = RankMonitorLogger(
-                level=cfg.log_level, is_restarter_logger=is_restarter_logger
-            )
-
             logger = logging.getLogger(LogConfig.name)
 
             logger.debug(f"Starting RankMonitorServer... PID={os.getpid()}")
             inst = RankMonitorServer(
                 cfg,
                 ipc_socket_path,
                 rank_monitor_ready_event,
-                rmlogger,
+                logger,
+                is_restarter_logger,
             )
             asyncio.run(inst._rank_monitor_loop())
             logger.debug("Leaving RankMonitorServer process")

src/nvidia_resiliency_ext/inprocess/rank_assignment.py

Lines changed: 28 additions & 13 deletions
@@ -26,6 +26,8 @@
 import warnings
 from typing import Callable, Optional, Union
 
+from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig
+
 from . import exception, utils
 from .state import Mode, State
 from .store import StoreMixin
@@ -177,7 +179,8 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
             active_rank = None
             # Log deactivation if transitioning from ACTIVE to INACTIVE
             if state.mode == Mode.ACTIVE:
-                log = logging.getLogger(__name__)
+                log = logging.getLogger(LogConfig.name)
+
                 log.info(
                     f"[In-process] Rank deactivated (rank={state.rank}) due to max active world size limit ({active_world_size})"
                 )
@@ -224,7 +227,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
             active_rank = None
             # Log deactivation if transitioning from ACTIVE to INACTIVE
             if state.mode == Mode.ACTIVE:
-                log = logging.getLogger(__name__)
+                log = logging.getLogger(LogConfig.name)
                 log.info(
                     f"[In-process] Rank deactivated (rank={state.rank}) due to divisibility requirement (active_world_size={active_world_size}, divisor={divisor})"
                 )
@@ -349,7 +352,7 @@ def __repr__(self):
         return f'{type(self).__name__}({self.name=})'
 
 
-def bounded_activate(node, counter, path=None):
+def bounded_activate(node, counter, path=None, current_state=None):
     if path is None:
         path = []
 
@@ -361,17 +364,29 @@ def bounded_activate(node, counter, path=None):
             for ascendant in path
         )
     ):
+        # Log activation if this is the current rank
+        if current_state and current_state.initial_rank == node.state.initial_rank:
+            log = logging.getLogger(LogConfig.name)
+            log.info(
+                f"[In-process] Rank activated (initial_rank={node.state.initial_rank}, active_rank={counter}) in topology tree"
+            )
         node.activate(counter)
         counter += 1
        for ascendant in path:
            ascendant.active_count += 1
     else:
+        # Log deactivation if this is the current rank
+        if current_state and current_state.initial_rank == node.state.initial_rank:
+            log = logging.getLogger(LogConfig.name)
+            log.info(
+                f"[In-process] Rank deactivated (initial_rank={node.state.initial_rank}) due to max_ranks constraint in topology layer"
+            )
         node.deactivate()
 
     path.append(node)
 
     for child in node.children.values():
-        counter = bounded_activate(child, counter, path)
+        counter = bounded_activate(child, counter, path, current_state)
     path.pop()
     return counter
 
@@ -574,7 +589,7 @@ def build_tree(self, state, store):
     def replace_with_inactive(self, terminated_active_ranks):
         replaced_terminate_active_ranks = set()
 
-        log = logging.getLogger(__name__)
+        log = logging.getLogger(LogConfig.name)
 
         for terminated_active_rank in sorted(terminated_active_ranks):
             terminated_active_node = self.rank_map[terminated_active_rank]
@@ -625,7 +640,7 @@ def replace_with_backfill(self, unhandled_terminations):
             key=lambda node: node.state.active_rank,
         )
 
-        log = logging.getLogger(__name__)
+        log = logging.getLogger(LogConfig.name)
         for backfill_node, terminated_node in itertools.zip_longest(
             reversed(largest_active_nodes),
             terminated_nodes,
@@ -647,7 +662,7 @@ def replace_with_backfill(self, unhandled_terminations):
 
     def shift_ranks(self, replaced_active, unhandled_terminations):
         sorted_replaced_active = sorted(replaced_active)
-        log = logging.getLogger(__name__)
+        log = logging.getLogger(LogConfig.name)
 
         for n in self.rank_map.values():
             n.state.active_world_size -= len(unhandled_terminations)
@@ -672,7 +687,7 @@ def filter_active_world_size(self):
         new_active_world_size = self.world_size_filter(active_world_size)
         assert new_active_world_size <= active_world_size
 
-        log = logging.getLogger(__name__)
+        log = logging.getLogger(LogConfig.name)
         for leaf in self.tree.iter_leaves():
             leaf.state.active_world_size = new_active_world_size
             if leaf.state.mode == Mode.ACTIVE and leaf.state.active_rank >= new_active_world_size:
@@ -722,7 +737,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
         if self.tree is None:
             self.build_tree(state, store)
 
-        active_world_size = bounded_activate(self.tree, 0)
+        active_world_size = bounded_activate(self.tree, 0, None, self.current_state)
         for node in self.rank_map.values():
             node.state.active_world_size = active_world_size
 
@@ -738,7 +753,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
             rank for rank in terminated_ranks if self.rank_map[rank].state.mode == Mode.ACTIVE
         )
 
-        log = logging.getLogger(__name__)
+        log = logging.getLogger(LogConfig.name)
         for terminated_rank in terminated_ranks:
             # If this rank is being terminated, log it
             if self.current_state.initial_rank == self.rank_map[terminated_rank].state.initial_rank:
@@ -808,7 +823,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
             terminated_ranks = utils.format_rank_set(terminated_ranks)
             raise RankDiscarded(f'{rank=} {terminated_ranks=}')
         elif rank >= world_size:
-            log = logging.getLogger(__name__)
+            log = logging.getLogger(LogConfig.name)
             old_rank = rank
             rank = ordered_terminated_ranks[rank - world_size]
             log.info(
@@ -869,7 +884,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
         old_rank = rank
         rank = rank - sum(rank > terminated_rank for terminated_rank in terminated_ranks)
         if old_rank != rank:
-            log = logging.getLogger(__name__)
+            log = logging.getLogger(LogConfig.name)
             log.info(f"[In-process] Rank shifted (rank changed from {old_rank} to {rank})")
 
         state = dataclasses.replace(
@@ -982,7 +997,7 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx:
 
         group_count = int(store.get(prefixed_key))
         if not self.condition(group_count):
-            log = logging.getLogger(__name__)
+            log = logging.getLogger(LogConfig.name)
             log.info(
                 f"[In-process] Rank marked for termination (rank={rank}, group_key={key}, group_count={group_count}) due to failed group condition"
             )

src/nvidia_resiliency_ext/inprocess/wrap.py

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ def __init__(
         self.finalize = finalize
         self.health_check = health_check
 
+        setup_logger(node_local_tmp_prefix="wrapper")
         # Construct internal restart_health_check by chaining user's health_check with GPU and NVL checks
         self._construct_restart_health_check()
 