Skip to content

Commit e26efe3

Browse files
authored
Merge pull request #199 from NVIDIA/abasant/persistent_worker_del_after_abort
Training exit after Inprocess abort should ensure clean shutdown of PersistentAsync worker
2 parents f614f56 + 43c620f commit e26efe3

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -655,7 +655,9 @@ def close(self, abort=False):
655655
abort (bool, optional): Defaults to False. Needs to be manually set to True when
656656
the checkpoint async process needs to be aborted.
657657
"""
658-
if not abort:
658+
# For a clean shutdown scenario with valid async processes running,
659+
# finalize all pending async calls
660+
if not abort and (self.persistent is False or self.persistent_caller is not None):
659661
self.maybe_finalize_async_calls(blocking=True)
660662
if self.persistent and self.persistent_caller:
661663
self.persistent_caller.close(abort=abort)

tests/checkpointing/unit/test_async_writer.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -308,3 +308,34 @@ def test_async_cp_with_multiple_queue_and_abort(self, tmp_path_dist_ckpt):
308308

309309
async_queue_dist.close()
310310
ckpt_impl.close()
311+
312+
def test_async_cp_with_multiple_queue_and_abort_followed_by_delete(self, tmp_path_dist_ckpt):
313+
"""
314+
Test that persistent async CP worker shuts down cleanly after an abort operation.
315+
This test mocks the behavior of training exiting after an abort triggered by an inprocess restart.
316+
"""
317+
Utils.initialize_distributed()
318+
model = FSDP(Model((1024, 1024), 8))
319+
async_queue_dist = AsyncCallsQueue(persistent=True)
320+
with (
321+
TempNamedDir(
322+
tmp_path_dist_ckpt / 'async_checkpoint_dist', sync=True
323+
) as async_ckpt_dir_dist,
324+
):
325+
state_dict = model.state_dict()
326+
planner = DefaultSavePlanner()
327+
328+
try:
329+
# Raise an exception in training process right after async CP request is submitted
330+
with pytest.raises(RuntimeError) as exc_info:
331+
self.async_save_checkpoint(
332+
async_ckpt_dir_dist, state_dict, planner, async_queue_dist
333+
)
334+
raise RuntimeError("Fake exception to mock training process exception")
335+
async_queue_dist.maybe_finalize_async_calls(blocking=True, no_dist=False)
336+
finally:
337+
# Mock the abort operation that an inprocess restart triggers when an exception occurs.
338+
# Abort the CP workers to mock the action of inprocess restarts
339+
abort_nvrx_checkpoint()
340+
# Mock the training loop exit, which would invoke __del__ on the async queue object
341+
async_queue_dist.__del__()

0 commit comments

Comments
 (0)