diff --git a/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py b/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py index e06ea618..620f7114 100644 --- a/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py +++ b/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py @@ -49,7 +49,7 @@ def async_save(self, state_dict, *args, **kwargs): preloaded_sd = preload_tensors(state_dict) torch.cuda.synchronize() async_request = AsyncRequest( - TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs + TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs or {} ) self._async_calls_queue.schedule_async_request(async_request) diff --git a/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py b/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py index a203704a..1694c488 100644 --- a/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py +++ b/src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py @@ -307,7 +307,7 @@ def finalize_fn(): # we must wait for D2H to complete before returning control to the training with debug_time("ckpt_D2H_synchronize", logger): torch.cuda.synchronize() - return AsyncRequest(save_fn, save_args, [finalize_fn]) + return AsyncRequest(save_fn, save_args, [finalize_fn], async_fn_kwargs={}) assert not is_async save_fn(*save_args) diff --git a/tests/checkpointing/unit/test_async_writer.py b/tests/checkpointing/unit/test_async_writer.py index 28fb924e..53853f79 100644 --- a/tests/checkpointing/unit/test_async_writer.py +++ b/tests/checkpointing/unit/test_async_writer.py @@ -65,7 +65,9 @@ def finalize_fn(): """Finalizes async checkpointing and synchronizes processes.""" save_state_dict_async_finalize(*save_state_dict_ret) - return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn) + return AsyncRequest( + save_fn, save_args, [finalize_fn], preload_fn=preload_fn, async_fn_kwargs={} + ) def async_save_checkpoint( self, diff --git a/tests/checkpointing/unit/test_async_writer_msc.py b/tests/checkpointing/unit/test_async_writer_msc.py index 54ae3fe8..0143d63f 100644 --- a/tests/checkpointing/unit/test_async_writer_msc.py +++ b/tests/checkpointing/unit/test_async_writer_msc.py @@ -46,7 +46,9 @@ def finalize_fn(): save_state_dict_async_finalize(*save_state_dict_ret) torch.distributed.barrier() - return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn) + return AsyncRequest( + save_fn, save_args, [finalize_fn], preload_fn=preload_fn, async_fn_kwargs={} + ) def async_save_checkpoint( self, checkpoint_dir, state_dict, planner, async_queue, thread_count=1