
Commit 3ac3862

Merge pull request #169 from NVIDIA/abasant/fix_5498779
Allow multiple AsyncCallsQueue in a process to enable different behav…
2 parents: afc4267 + 8ef20c3

3 files changed: +104 -16 lines changed

docs/source/checkpointing/async/usage_guide.rst

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ Usage guide
 The :py:class:`nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncCallsQueue`
 provides application users with an interface to schedule :py:class:`nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest`,
 which defines checkpoint routine, its args/kwargs and finalization steps when the checkpoint routine is finished.
-This class is a singleton, implying each rank will have only one instance of this class.
 It is recommended to call the `close()` API on the `AsyncCallsQueue` at the end of training to ensure a clean shutdown of the process that manages async checkpointing.
 We also extend the API of `abort_nvrx_checkpoint()` to abort the async processes and cleanly restart the `AsyncCallsQueue` in case of any restarts of the training processes.
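For context, the documented flow looks roughly like the sketch below. This is a hedged illustration only: the `AsyncRequest` field names (`async_fn`, `async_fn_args`, `finalize_fns`) and the toy `save_fn` are assumptions, not taken from this diff; the queue, finalize, close, and abort calls are the ones named in the guide and in `core.py`.

    # Hedged sketch of the documented AsyncCallsQueue usage; the AsyncRequest
    # field names below are assumed and may not match the actual constructor.
    import torch

    from nvidia_resiliency_ext.checkpointing.async_ckpt.core import (
        AsyncCallsQueue,
        AsyncRequest,
        abort_nvrx_checkpoint,
    )

    def save_fn(state_dict, path):
        # Checkpoint routine executed by the async worker (toy example).
        torch.save(state_dict, path)

    queue = AsyncCallsQueue(persistent=True)
    request = AsyncRequest(
        async_fn=save_fn,                                        # assumed field name
        async_fn_args=({"w": torch.zeros(4)}, "/tmp/ckpt.pt"),   # assumed field name
        finalize_fns=[],                                         # assumed field name
    )
    queue.schedule_async_request(request)
    queue.maybe_finalize_async_calls(blocking=True, no_dist=True)
    queue.close()  # recommended at end of training for a clean shutdown

    # On an in-process restart, abort_nvrx_checkpoint() aborts the async workers
    # of every live queue so checkpointing can resume cleanly afterwards.
    abort_nvrx_checkpoint()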

src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py

Lines changed: 14 additions & 14 deletions
@@ -19,6 +19,7 @@
 """

 import logging
+import weakref
 from abc import ABC, abstractmethod
 from collections import deque
 from queue import Empty
@@ -130,14 +131,18 @@ def execute_finalize_fns(self, validate_matching_call_idx: bool = True) -> int:
         return self.call_idx


-# Singleton metaclass
-class Singleton(type):
-    _instances = {}
+class ObjectTracker(type):
+    def __init__(cls, name, bases, attrs):
+        super().__init__(name, bases, attrs)
+        cls._instances = weakref.WeakSet()

     def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        return cls._instances[cls]
+        instance = super().__call__(*args, **kwargs)
+        cls._instances.add(instance)
+        return instance
+
+    def get_instances(cls):
+        return list(cls._instances)


 class AsyncCaller(ABC):
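A standalone sketch (nothing beyond what the hunk above already shows) of how the replacement metaclass behaves: every construction goes through `__call__`, is added to a per-class `weakref.WeakSet`, and `get_instances()` only returns objects that are still alive, so queues the application has already dropped are not resurrected later.

    import gc
    import weakref

    class ObjectTracker(type):
        def __init__(cls, name, bases, attrs):
            super().__init__(name, bases, attrs)
            cls._instances = weakref.WeakSet()  # one weak set per tracked class

        def __call__(cls, *args, **kwargs):
            instance = super().__call__(*args, **kwargs)  # construct as usual
            cls._instances.add(instance)                  # remember it weakly
            return instance

        def get_instances(cls):
            return list(cls._instances)

    class TrackedQueue(metaclass=ObjectTracker):
        pass

    a, b = TrackedQueue(), TrackedQueue()
    assert len(TrackedQueue.get_instances()) == 2  # both live instances tracked
    del a
    gc.collect()
    assert len(TrackedQueue.get_instances()) == 1  # dead instances drop out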
@@ -558,15 +563,11 @@ class _ActiveAsyncRequest(NamedTuple):
     async_request: AsyncRequest


-class AsyncCallsQueue(metaclass=Singleton):
+class AsyncCallsQueue(metaclass=ObjectTracker):
     """Manages a queue of async calls.

     Allows adding a new async call with `schedule_async_request` and finalizing
     active calls with `maybe_finalize_async_calls`.
-
-    This class is a Singleton implying there will be only one instance of AsyncCallsQueue per rank.
-    Making this object a singleton avoids mis-use from users where they could potentially spin multiple async CP workers.
-    Making this object a singleton also enables simplification of process life-cycle management during CP aborts.
     """

     def __init__(self, persistent: bool = True):
@@ -667,8 +668,7 @@ def __del__(self):

 def abort_nvrx_checkpoint():
     """Abort NVRx Checkpoint Utility. This will close the AsyncCallsQueue that manages async checkpoints"""
-    # we have a singleton persistent worker in our async calls queue
     # close the async calls queue which will ensure a clean restart
     # of the CP async process in subsequent async save requests.
-    async_queue_singleton = AsyncCallsQueue(persistent=True)
-    async_queue_singleton.close(abort=True)
+    for async_queue in AsyncCallsQueue.get_instances():
+        async_queue.close(abort=True)
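Net effect of the two changes in this file, sketched under the assumption that constructing `AsyncCallsQueue` is cheap until a save is actually scheduled: callers may now hold several independent queues, and a single `abort_nvrx_checkpoint()` call sweeps every instance that is still alive, instead of reaching only the former singleton.

    from nvidia_resiliency_ext.checkpointing.async_ckpt.core import (
        AsyncCallsQueue,
        abort_nvrx_checkpoint,
    )

    # Previously the second constructor call returned the same singleton object;
    # now two distinct queues exist and both are tracked by the metaclass.
    sharded_queue = AsyncCallsQueue()   # e.g. for the distributed sharded save
    rank0_queue = AsyncCallsQueue()     # e.g. for a rank-0-only save
    assert sharded_queue is not rank0_queue
    assert len(AsyncCallsQueue.get_instances()) >= 2

    # On a training restart, a single call aborts and closes every live queue.
    abort_nvrx_checkpoint()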

tests/checkpointing/unit/test_async_writer.py

Lines changed: 90 additions & 1 deletion
@@ -31,12 +31,17 @@
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

-from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, AsyncRequest
+from nvidia_resiliency_ext.checkpointing.async_ckpt.core import (
+    AsyncCallsQueue,
+    AsyncRequest,
+    abort_nvrx_checkpoint,
+)
 from nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async import FileSystemWriterAsync
 from nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver import (
     save_state_dict_async_finalize,
     save_state_dict_async_plan,
 )
+from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint
 from nvidia_resiliency_ext.checkpointing.utils import diff
 from tests.checkpointing.unit import TempNamedDir
 from tests.checkpointing.unit.test_utilities import Model, Utils
@@ -92,6 +97,10 @@ def sync_save_checkpoint(self, checkpoint_dir, state_dict, planner):
             planner=planner,
         )

+    def async_save_checkpoint_on_rank0(self, checkpoint_dir, state_dict, torch_ckpt_impl):
+        if torch.distributed.get_rank() == 0:
+            torch_ckpt_impl.async_save(state_dict, checkpoint_dir / 'test')
+
     def load_checkpoint(self, checkpoint_dir, state_dict):
         """Loads a checkpoint into the given state_dict."""
         load(
@@ -219,3 +228,83 @@ def test_cached_metadata(self, tmp_path_dist_ckpt, async_queue):
             ), f'{field.name} is different in metadata from non-cached, cached metadata impls'
         ckpt_dir.cleanup()
         async_queue.close()
+
+    def test_async_cp_with_multiple_queue_and_abort(self, tmp_path_dist_ckpt):
+        """
+        Verifies that the async checkpointing backend can be used with multiple async queues.
+        For example, a user may want to save 2 checkpoints, i.e. one sharded state and one only on rank-0.
+        Verify the abort CP functionality and the ability to resume after an abort operation.
+        """
+        Utils.initialize_distributed()
+        model = FSDP(Model((1024, 1024), 8))
+        async_queue_dist = AsyncCallsQueue()
+        ckpt_impl = TorchAsyncCheckpoint(persistent_queue=True)
+        with (
+            TempNamedDir(
+                tmp_path_dist_ckpt / 'async_checkpoint_dist', sync=True
+            ) as async_ckpt_dir_dist,
+            TempNamedDir(
+                tmp_path_dist_ckpt / 'async_checkpoint_no_dist', sync=True
+            ) as async_ckpt_dir_no_dist,
+        ):
+            state_dict = model.state_dict()
+            planner = DefaultSavePlanner()
+
+            # Perform async saves for both dist CP and non-dist CP use cases.
+            self.async_save_checkpoint(async_ckpt_dir_dist, state_dict, planner, async_queue_dist)
+            self.async_save_checkpoint_on_rank0(async_ckpt_dir_no_dist, state_dict, ckpt_impl)
+            async_queue_dist.maybe_finalize_async_calls(blocking=True, no_dist=False)
+            ckpt_impl.finalize_async_save(blocking=True, no_dist=True)
+
+            # Abort the CP workers to mock the action of in-process restarts
+            abort_nvrx_checkpoint()
+
+            # Validate state of the async CP workers after the abort operation
+            async_calls_queue_no_dist = ckpt_impl._get_async_calls_queue()
+            assert (
+                async_calls_queue_no_dist is not None
+            ), "We expect a valid state of AsyncCallsQueue"
+            async_process_no_dist = async_calls_queue_no_dist._get_async_caller()
+            if async_process_no_dist is not None:
+                assert (
+                    async_process_no_dist._debug_is_async_process_running() is False
+                ), "After abort async process must stop"
+
+            async_process_dist = async_queue_dist._get_async_caller()
+            if async_process_dist is not None:
+                assert (
+                    async_process_dist._debug_is_async_process_running() is False
+                ), "After abort async process must stop"
+
+            # Perform async saves for both dist CP and non-dist CP use cases.
+            # Validate that operations seamlessly resume after an abort operation
+            self.async_save_checkpoint(async_ckpt_dir_dist, state_dict, planner, async_queue_dist)
+            self.async_save_checkpoint_on_rank0(async_ckpt_dir_no_dist, state_dict, ckpt_impl)
+            async_queue_dist.maybe_finalize_async_calls(blocking=True, no_dist=False)
+            ckpt_impl.finalize_async_save(blocking=True, no_dist=True)
+
+            # Validate state of the async CP workers after the resume operation
+            async_calls_queue_no_dist = ckpt_impl._get_async_calls_queue()
+            assert (
+                async_calls_queue_no_dist is not None
+            ), "We expect a valid state of AsyncCallsQueue object in TorchAsyncCheckpoint after a CP event"
+            async_process_no_dist = async_calls_queue_no_dist._get_async_caller()
+            # For the non-dist CP use case, only rank-0 is expected to trigger an async process
+            if torch.distributed.get_rank() == 0:
+                assert (
+                    async_process_no_dist is not None
+                ), "We expect a valid state of AsyncCaller after a CP event"
+                assert (
+                    async_process_no_dist._debug_is_async_process_running() is True
+                ), "After resume, we expect async process to be running on rank 0 for non dist async save"
+
+            async_process_dist = async_queue_dist._get_async_caller()
+            assert (
+                async_process_dist is not None
+            ), "We expect a valid state of AsyncCaller after a CP event"
+            assert (
+                async_process_dist._debug_is_async_process_running() is True
+            ), "After resume, we expect async process to be running on all ranks for dist async save"
+
+            async_queue_dist.close()
+            ckpt_impl.close()
