Skip to content

Commit a7909b1

Browse files
authored
Merge branch 'main' into sbak/fr_attr_pr_squashed
2 parents 6fd0e8e + 3850483 commit a7909b1

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

examples/checkpointing/async_writer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def finalize_fn():
117117
save_state_dict_async_finalize(*save_state_dict_ret)
118118
dist.barrier()
119119

120-
return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
120+
return AsyncRequest(
121+
save_fn, save_args, [finalize_fn], async_fn_kwargs={}, preload_fn=preload_fn
122+
)
121123

122124

123125
def save_checkpoint(checkpoint_dir, async_queue, model, thread_count, enable_msc):

src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class AsyncRequest(NamedTuple):
5555
async_fn: Optional[Callable]
5656
async_fn_args: Tuple
5757
finalize_fns: List[Callable]
58-
async_fn_kwargs: Dict = {}
58+
async_fn_kwargs: Optional[Dict] = None
5959
preload_fn: Callable = None
6060
is_frozen: bool = False
6161
call_idx: int = 0
@@ -86,7 +86,8 @@ def execute_sync(self) -> None:
8686
async_fn_args[1] = self.preload_fn()
8787
# persist the state
8888
if self.async_fn is not None:
89-
self.async_fn(*async_fn_args, **self.async_fn_kwargs)
89+
async_fn_kwargs = dict(self.async_fn_kwargs or {})
90+
self.async_fn(*async_fn_args, **async_fn_kwargs)
9091
# This utility implements a sync cp save. Hence the barrier.
9192
torch.distributed.barrier()
9293
# Finalize the CP state
@@ -262,8 +263,9 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None:
262263

263264
ctx = mp.get_context('fork')
264265
self.start_time = time()
266+
async_fn_kwargs = dict(async_req.async_fn_kwargs or {})
265267
self.process = ctx.Process(
266-
target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs
268+
target=async_req.async_fn, args=async_fn_args, kwargs=async_fn_kwargs
267269
)
268270
self.process.start()
269271
init_time = time()
@@ -540,7 +542,8 @@ def async_loop(
540542
async_fn_args[1] = item.preload_fn()
541543
logger.debug(f"{rank} has completed D2H of {call_idx}")
542544
preload_q.task_done()
543-
item.async_fn(*async_fn_args, **item.async_fn_kwargs)
545+
async_fn_kwargs = dict(item.async_fn_kwargs or {})
546+
item.async_fn(*async_fn_args, **async_fn_kwargs)
544547
logger.debug(f"{rank} has completed saving {item.call_idx}")
545548
comp_q.put(item.call_idx)
546549
queue.task_done()

0 commit comments

Comments
 (0)