Skip to content

Commit 6c740d3

Browse files
committed
configurable number of thread for dcp
1 parent aa44bde commit 6c740d3

File tree

1 file changed

+5
-0
lines changed
  • torch_xla/experimental/distributed_checkpoint

1 file changed

+5
-0
lines changed

torch_xla/experimental/distributed_checkpoint/manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(self,
107107
save_interval: int,
108108
max_to_keep: Optional[int] = 0,
109109
max_pending_async: Optional[int] = 1,
110+
num_of_threads: Optional[int] = 1,
110111
process_group: dist.ProcessGroup = None,
111112
chkpt_on_preemption: bool = True):
112113
"""
@@ -127,6 +128,8 @@ def __init__(self,
127128
slow down the active checkpoint.
128129
Default: 1, which only allows a single async checkpoint to be
129130
pending at a time.
131+
number_of_threads: Number of concurrent threads for writing checkpoint to
132+
file system.
130133
process_group: The process group to use when coordinating the checkpoint.
131134
Default: None, in which case a subgroup of the default process
132135
group will be created.
@@ -142,6 +145,7 @@ def __init__(self,
142145
self.base_path = os.path.join(path, '') # Ensure the base path ends in '/'
143146
self.save_interval = save_interval
144147
self.max_to_keep = max_to_keep
148+
self.num_of_threads = num_of_threads
145149
self.chkpt_on_preemption = chkpt_on_preemption
146150

147151
# Create a new group if none is provided
@@ -226,6 +230,7 @@ def _save(self, step, state_dict):
226230
state_dict=state_dict,
227231
storage_writer=FsspecWriter(
228232
path,
233+
thread_count=self.num_of_threads,
229234
per_thread_copy_ahead=0,
230235
),
231236
planner=xc.SPMDSavePlanner(),

0 commit comments

Comments
 (0)