@@ -107,6 +107,7 @@ def __init__(self,
107
107
save_interval : int ,
108
108
max_to_keep : Optional [int ] = 0 ,
109
109
max_pending_async : Optional [int ] = 1 ,
110
+ num_of_threads : Optional [int ] = 1 ,
110
111
process_group : dist .ProcessGroup = None ,
111
112
chkpt_on_preemption : bool = True ):
112
113
"""
@@ -127,6 +128,8 @@ def __init__(self,
127
128
slow down the active checkpoint.
128
129
Default: 1, which only allows a single async checkpoint to be
129
130
pending at a time.
131
+ num_of_threads: Number of concurrent threads for writing checkpoint to
132
+ file system. Default: 1.
130
133
process_group: The process group to use when coordinating the checkpoint.
131
134
Default: None, in which case a subgroup of the default process
132
135
group will be created.
@@ -142,6 +145,7 @@ def __init__(self,
142
145
self .base_path = os .path .join (path , '' ) # Ensure the base path ends in '/'
143
146
self .save_interval = save_interval
144
147
self .max_to_keep = max_to_keep
148
+ self .num_of_threads = num_of_threads
145
149
self .chkpt_on_preemption = chkpt_on_preemption
146
150
147
151
# Create a new group if none is provided
@@ -226,6 +230,7 @@ def _save(self, step, state_dict):
226
230
state_dict = state_dict ,
227
231
storage_writer = FsspecWriter (
228
232
path ,
233
+ thread_count = self .num_of_threads ,
229
234
per_thread_copy_ahead = 0 ,
230
235
),
231
236
planner = xc .SPMDSavePlanner (),
0 commit comments