@@ -55,7 +55,7 @@ class AsyncRequest(NamedTuple):
5555 async_fn : Optional [Callable ]
5656 async_fn_args : Tuple
5757 finalize_fns : List [Callable ]
58- async_fn_kwargs : Dict = {}
58+ async_fn_kwargs : Optional [ Dict ] = None
5959 preload_fn : Callable = None
6060 is_frozen : bool = False
6161 call_idx : int = 0
@@ -86,7 +86,8 @@ def execute_sync(self) -> None:
8686 async_fn_args [1 ] = self .preload_fn ()
8787 # persist the state
8888 if self .async_fn is not None :
89- self .async_fn (* async_fn_args , ** self .async_fn_kwargs )
89+ async_fn_kwargs = dict (self .async_fn_kwargs or {})
90+ self .async_fn (* async_fn_args , ** async_fn_kwargs )
9091 # This utility implements a sync cp save. Hence the barrier.
9192 torch .distributed .barrier ()
9293 # Finalize the CP state
@@ -262,8 +263,9 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None:
262263
263264 ctx = mp .get_context ('fork' )
264265 self .start_time = time ()
266+ async_fn_kwargs = dict (async_req .async_fn_kwargs or {})
265267 self .process = ctx .Process (
266- target = async_req .async_fn , args = async_fn_args , kwargs = async_req . async_fn_kwargs
268+ target = async_req .async_fn , args = async_fn_args , kwargs = async_fn_kwargs
267269 )
268270 self .process .start ()
269271 init_time = time ()
@@ -540,7 +542,8 @@ def async_loop(
540542 async_fn_args [1 ] = item .preload_fn ()
541543 logger .debug (f"{ rank } has completed D2H of { call_idx } " )
542544 preload_q .task_done ()
543- item .async_fn (* async_fn_args , ** item .async_fn_kwargs )
545+ async_fn_kwargs = dict (item .async_fn_kwargs or {})
546+ item .async_fn (* async_fn_args , ** async_fn_kwargs )
544547 logger .debug (f"{ rank } has completed saving { item .call_idx } " )
545548 comp_q .put (item .call_idx )
546549 queue .task_done ()
0 commit comments