1717This module provides async utilities that allow starting
1818a checkpoint save process in the background.
1919"""
20-
20+ import gc
2121import logging
2222import weakref
2323from abc import ABC , abstractmethod
@@ -345,7 +345,7 @@ class PersistentAsyncCaller(AsyncCaller):
345345 Starts process asynchronously and allows checking if all processes on all ranks are done.
346346 """
347347
348- def __init__ (self ):
348+ def __init__ (self , is_daemon : bool = False ):
349349 self .process : mp .Process = None
350350 self .start_time : Optional [float ] = None
351351 ctx = mp .get_context ('spawn' )
@@ -359,6 +359,12 @@ def __init__(self):
359359 self .cur_item : int = None
360360 self .cur_idx : int = - 1
361361 self .rank : int = None
362+ # When background_worker_is_daemon flag is True, the async background
363+ # worker is spawned as a daemon making async worker shutdown cleaner.
364+ # The restriction of spawning the async worker as a daemon is that
365+ # the FileWriter performing the FileIO in the background process cannot
366+ # be parallelized with multi-processing.
367+ self .background_worker_is_daemon = is_daemon
362368
363369 def schedule_async_call (self , async_req : AsyncRequest ) -> None :
364370 """Put `AsyncRequest` to the Persistent Async Caller
@@ -385,15 +391,21 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None:
385391 if self .process is None :
386392 ctx = mp .get_context ('spawn' )
387393 logger .info (f"PersistentAsyncCaller: { self .rank } , Starting Async Caller" )
394+ if self .background_worker_is_daemon :
395+ async_loop_target = PersistentAsyncCaller .async_loop_for_daemon_worker
396+ else :
397+ async_loop_target = PersistentAsyncCaller .async_loop
398+
388399 self .process : mp .Process = ctx .Process (
389- target = PersistentAsyncCaller . async_loop ,
400+ target = async_loop_target ,
390401 args = (
391402 self .rank ,
392403 self .queue ,
393404 self .preload_q ,
394405 self .comp_q ,
395406 logger .getEffectiveLevel (),
396407 ),
408+ daemon = self .background_worker_is_daemon ,
397409 )
398410 self .process .start ()
399411 logger .debug (f"PersistentAsyncCaller: { self .rank } , Started Async Caller { self .process } " )
@@ -496,8 +508,7 @@ def _debug_is_async_process_running(self):
496508 return self .process .is_alive ()
497509
498510 @staticmethod
499- @_disable_gc ()
500- def async_loop (
511+ def async_process_target (
501512 rank : int ,
502513 queue : mp .JoinableQueue ,
503514 preload_q : mp .JoinableQueue ,
@@ -547,9 +558,39 @@ def async_loop(
547558 logger .debug (f"{ rank } has completed saving { item .call_idx } " )
548559 comp_q .put (item .call_idx )
549560 queue .task_done ()
550-
561+ del async_fn_args
562+ del item
563+ gc .collect ()
551564 logger .info (f"PersistentAsyncCaller: persistent ckpt worker for { rank } has terminated" )
552565
@staticmethod
@_disable_gc()
def async_loop(
    rank: int,
    queue: mp.JoinableQueue,
    preload_q: mp.JoinableQueue,
    comp_q: mp.Queue,
    log_level: int = logging.INFO,
):
    """Entry point of the persistent checkpoint worker for a non-daemon process.

    This is the process target selected when the async worker is not spawned
    as a daemon. It runs under the `_disable_gc()` decorator and hands all of
    its arguments, unchanged and in order, to
    `PersistentAsyncCaller.async_process_target`. Child processes may be
    created inside this loop (for example, to parallelize file IO).

    Args:
        rank: Rank value forwarded to the worker (used in its log messages).
        queue: Joinable queue the worker takes its work items from.
        preload_q: Second joinable queue forwarded to the worker.
            NOTE(review): its exact role is not visible from this wrapper.
        comp_q: Queue on which the worker reports the `call_idx` of each
            completed item.
        log_level: Logging level forwarded to the worker; the parent passes
            its own logger's effective level.
    """
    # Pure delegation: the worker body lives in `async_process_target`.
    worker_args = (rank, queue, preload_q, comp_q, log_level)
    PersistentAsyncCaller.async_process_target(*worker_args)
580+
@staticmethod
def async_loop_for_daemon_worker(
    rank: int,
    queue: mp.JoinableQueue,
    preload_q: mp.JoinableQueue,
    comp_q: mp.Queue,
    log_level: int = logging.INFO,
):
    """Entry point of the persistent checkpoint worker for a daemon process.

    This is the process target selected when the async worker is spawned as
    a daemon. It hands all of its arguments, unchanged and in order, to
    `PersistentAsyncCaller.async_process_target`.

    NOTE(review): unlike `async_loop`, this entry point is not wrapped in
    `_disable_gc()` -- confirm that the difference is intentional.

    Args:
        rank: Rank value forwarded to the worker (used in its log messages).
        queue: Joinable queue the worker takes its work items from.
        preload_q: Second joinable queue forwarded to the worker.
            NOTE(review): its exact role is not visible from this wrapper.
        comp_q: Queue on which the worker reports the `call_idx` of each
            completed item.
        log_level: Logging level forwarded to the worker; the parent passes
            its own logger's effective level.
    """
    # Pure delegation: the worker body lives in `async_process_target`.
    worker_args = (rank, queue, preload_q, comp_q, log_level)
    PersistentAsyncCaller.async_process_target(*worker_args)
593+
553594
554595class _ActiveAsyncRequest (NamedTuple ):
555596 """Helper to represent an active async call.
@@ -573,18 +614,19 @@ class AsyncCallsQueue(metaclass=ObjectTracker):
573614 active calls with `maybe_finalize_async_calls`.
574615 """
575616
def __init__(self, persistent: bool = True, is_daemon: bool = False):
    """Set up an empty queue of async calls.

    Args:
        persistent: When True, `_get_async_caller` lazily creates and then
            reuses a single `PersistentAsyncCaller`; when False it returns
            a new `TemporalAsyncCaller` on every call.
        is_daemon: Forwarded to `PersistentAsyncCaller` when the persistent
            caller is created.
    """
    # Caller configuration.
    self.persistent: bool = persistent
    self.is_daemon: bool = is_daemon
    # The persistent caller is created on first use, not here.
    self.persistent_caller: Optional[AsyncCaller] = None
    # Bookkeeping for scheduled calls; -1 means nothing has been scheduled.
    self.call_idx: int = -1
    self.async_calls: deque[_ActiveAsyncRequest] = deque()
581623
def _get_async_caller(self):
    """Return the async caller matching this queue's configuration.

    In persistent mode a single `PersistentAsyncCaller` is created on first
    use (with this queue's `is_daemon` setting) and returned on every later
    call. Otherwise a deprecation warning is logged and a fresh
    `TemporalAsyncCaller` is returned each time.
    """
    if self.persistent:
        # Lazily build the shared persistent caller exactly once.
        if self.persistent_caller is None:
            self.persistent_caller = PersistentAsyncCaller(is_daemon=self.is_daemon)
        return self.persistent_caller

    logger.warning("The TemporalAsyncCaller will be deprecated soon. ")
    return TemporalAsyncCaller()
589631
590632 def schedule_async_request (self , async_request : AsyncRequest ) -> int :
0 commit comments