Skip to content

Commit 384af42

Browse files
authored
Merge pull request NVIDIA#197 from NVIDIA/abasant/nvrx_153
NVRx Stability - Multithread file IO to simplify error prop and shutdown cleanup logic
2 parents e26efe3 + 49fdbb3 commit 384af42

File tree

3 files changed

+425
-66
lines changed

3 files changed

+425
-66
lines changed

src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
This module provides an async utilities which allow to start
1818
a checkpoint save process in the background.
1919
"""
20-
20+
import gc
2121
import logging
2222
import weakref
2323
from abc import ABC, abstractmethod
@@ -345,7 +345,7 @@ class PersistentAsyncCaller(AsyncCaller):
345345
Starts process asynchronously and allows checking if all processes on all ranks are done.
346346
"""
347347

348-
def __init__(self):
348+
def __init__(self, is_daemon: bool = False):
349349
self.process: mp.Process = None
350350
self.start_time: Optional[float] = None
351351
ctx = mp.get_context('spawn')
@@ -359,6 +359,12 @@ def __init__(self):
359359
self.cur_item: int = None
360360
self.cur_idx: int = -1
361361
self.rank: int = None
362+
# When background_worker_is_daemon flag is True, the async background
363+
# worker is spawned as a daemon making async worker shutdown cleaner.
364+
# The restriction of spawning the async worker as a daemon is that
365+
# the FileWriter performing the FileIO in the background process cannot
366+
# be parallelized with multi-processing.
367+
self.background_worker_is_daemon = is_daemon
362368

363369
def schedule_async_call(self, async_req: AsyncRequest) -> None:
364370
"""Put `AsyncRequest` to the Persistent Async Caller
@@ -385,15 +391,21 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None:
385391
if self.process is None:
386392
ctx = mp.get_context('spawn')
387393
logger.info(f"PersistentAsyncCaller: {self.rank}, Starting Async Caller")
394+
if self.background_worker_is_daemon:
395+
async_loop_target = PersistentAsyncCaller.async_loop_for_daemon_worker
396+
else:
397+
async_loop_target = PersistentAsyncCaller.async_loop
398+
388399
self.process: mp.Process = ctx.Process(
389-
target=PersistentAsyncCaller.async_loop,
400+
target=async_loop_target,
390401
args=(
391402
self.rank,
392403
self.queue,
393404
self.preload_q,
394405
self.comp_q,
395406
logger.getEffectiveLevel(),
396407
),
408+
daemon=self.background_worker_is_daemon,
397409
)
398410
self.process.start()
399411
logger.debug(f"PersistentAsyncCaller: {self.rank}, Started Async Caller {self.process}")
@@ -496,8 +508,7 @@ def _debug_is_async_process_running(self):
496508
return self.process.is_alive()
497509

498510
@staticmethod
499-
@_disable_gc()
500-
def async_loop(
511+
def async_process_target(
501512
rank: int,
502513
queue: mp.JoinableQueue,
503514
preload_q: mp.JoinableQueue,
@@ -547,9 +558,39 @@ def async_loop(
547558
logger.debug(f"{rank} has completed saving {item.call_idx}")
548559
comp_q.put(item.call_idx)
549560
queue.task_done()
550-
561+
del async_fn_args
562+
del item
563+
gc.collect()
551564
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated")
552565

566+
@staticmethod
567+
@_disable_gc()
568+
def async_loop(
569+
rank: int,
570+
queue: mp.JoinableQueue,
571+
preload_q: mp.JoinableQueue,
572+
comp_q: mp.Queue,
573+
log_level: int = logging.INFO,
574+
):
575+
"""
576+
Main function for the persistent checkpoint worker called by a non daemon async process.
577+
In this loop, child processes may be created (For example: to parallelize File IO)
578+
"""
579+
PersistentAsyncCaller.async_process_target(rank, queue, preload_q, comp_q, log_level)
580+
581+
@staticmethod
582+
def async_loop_for_daemon_worker(
583+
rank: int,
584+
queue: mp.JoinableQueue,
585+
preload_q: mp.JoinableQueue,
586+
comp_q: mp.Queue,
587+
log_level: int = logging.INFO,
588+
):
589+
"""
590+
Main function for the persistent checkpoint worker called by a daemon async process
591+
"""
592+
PersistentAsyncCaller.async_process_target(rank, queue, preload_q, comp_q, log_level)
593+
553594

554595
class _ActiveAsyncRequest(NamedTuple):
555596
"""Helper to represent an active async call.
@@ -573,18 +614,19 @@ class AsyncCallsQueue(metaclass=ObjectTracker):
573614
active calls with `maybe_finalize_async_calls`.
574615
"""
575616

576-
def __init__(self, persistent: bool = True):
617+
def __init__(self, persistent: bool = True, is_daemon: bool = False):
577618
self.async_calls: deque[_ActiveAsyncRequest] = deque([])
578619
self.call_idx: int = -1
579620
self.persistent: bool = persistent
621+
self.is_daemon: bool = is_daemon
580622
self.persistent_caller: AsyncCaller = None
581623

582624
def _get_async_caller(self):
583625
if not self.persistent:
584626
logger.warning("The TemporalAsyncCaller will be deprecated soon. ")
585627
return TemporalAsyncCaller()
586628
if self.persistent_caller is None:
587-
self.persistent_caller = PersistentAsyncCaller()
629+
self.persistent_caller = PersistentAsyncCaller(is_daemon=self.is_daemon)
588630
return self.persistent_caller
589631

590632
def schedule_async_request(self, async_request: AsyncRequest) -> int:

0 commit comments

Comments
 (0)