Skip to content

Commit c26242d

Browse files
Merge pull request NVIDIA#204 from herman-ai/herman/nvbugs_5596118
Add explicit async_fn_kwargs.
2 parents 3010430 + 076f2d0 commit c26242d

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def async_save(self, state_dict, *args, **kwargs):
4949
preloaded_sd = preload_tensors(state_dict)
5050
torch.cuda.synchronize()
5151
async_request = AsyncRequest(
52-
TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs
52+
TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs or {}
5353
)
5454
self._async_calls_queue.schedule_async_request(async_request)
5555

src/nvidia_resiliency_ext/checkpointing/local/ckpt_managers/base_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def finalize_fn():
307307
# we must wait for D2H to complete before returning control to the training
308308
with debug_time("ckpt_D2H_synchronize", logger):
309309
torch.cuda.synchronize()
310-
return AsyncRequest(save_fn, save_args, [finalize_fn])
310+
return AsyncRequest(save_fn, save_args, [finalize_fn], async_fn_kwargs={})
311311

312312
assert not is_async
313313
save_fn(*save_args)

tests/checkpointing/unit/test_async_writer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def finalize_fn():
6565
"""Finalizes async checkpointing and synchronizes processes."""
6666
save_state_dict_async_finalize(*save_state_dict_ret)
6767

68-
return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
68+
return AsyncRequest(
69+
save_fn, save_args, [finalize_fn], preload_fn=preload_fn, async_fn_kwargs={}
70+
)
6971

7072
def async_save_checkpoint(
7173
self,

tests/checkpointing/unit/test_async_writer_msc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def finalize_fn():
4646
save_state_dict_async_finalize(*save_state_dict_ret)
4747
torch.distributed.barrier()
4848

49-
return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
49+
return AsyncRequest(
50+
save_fn, save_args, [finalize_fn], preload_fn=preload_fn, async_fn_kwargs={}
51+
)
5052

5153
def async_save_checkpoint(
5254
self, checkpoint_dir, state_dict, planner, async_queue, thread_count=1

0 commit comments

Comments
 (0)