
Commit b8fb9c3

adammoody and tjruwase authored
parallelize writing of layer checkpoint files across data parallel instances (#1419)
* parallelize layer checkpoints across data parallel groups
* use partition_uniform to determine start/end index values
* formatting fix
* config: add option for parallel write of layer checkpoints in pipeline stage
* yapf fixes
* enable parallel layer write according to config param
* avoid extraneous makedir when rank 0 writes all layers

Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 99fde3b commit b8fb9c3

File tree

5 files changed: +61 -12 lines

deepspeed/runtime/config.py (+15)

@@ -669,6 +669,19 @@ def get_checkpoint_tag_validation_mode(checkpoint_params):
         )
 
 
+def get_checkpoint_parallel_write_pipeline(checkpoint_params):
+    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
+    par_write_pipeline = par_write_params.get(
+        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
+        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
+    if par_write_pipeline in [True, False]:
+        return par_write_pipeline
+    else:
+        raise DeepSpeedConfigError(
+            "checkpoint::parallel_write::pipeline_stage "
+            f"value of '{par_write_pipeline}' is invalid, expecting: true or false")
+
+
 def get_dataloader_drop_last(param_dict):
     return get_scalar_param(param_dict,
                             DATALOADER_DROP_LAST,

@@ -887,6 +900,8 @@ def _initialize_params(self, param_dict):
         self.load_universal_checkpoint = checkpoint_params.get(
            LOAD_UNIVERSAL_CHECKPOINT,
            LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)
+        par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
+        self.checkpoint_parallel_write_pipeline = par_write_pipe
 
         self.aio_config = get_aio_config(param_dict)
 
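For reference, the shape of the checkpoint config that this new getter parses can be sketched as follows. The import path reflects where the function is added in this commit; the sample values are illustrative, not taken from elsewhere in the codebase.

# Hedged sketch: exercising get_checkpoint_parallel_write_pipeline() on a sample
# "checkpoint" sub-dict. Requires a DeepSpeed build that contains this commit.
from deepspeed.runtime.config import get_checkpoint_parallel_write_pipeline

checkpoint_params = {
    "tag_validation": "Warn",
    "load_universal": False,
    "parallel_write": {
        "pipeline_stage": True   # default is False: data parallel rank 0 writes all layers
    }
}

print(get_checkpoint_parallel_write_pipeline(checkpoint_params))  # True
print(get_checkpoint_parallel_write_pipeline({}))                 # False (the default)
# Any non-boolean value raises DeepSpeedConfigError per the validation above.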

deepspeed/runtime/constants.py (+7)

@@ -367,6 +367,9 @@ class ValidationMode:
 # "checkpoint": {
 #   tag_validation=["Ignore"|"Warn"|"Fail"]
 #   load_universal=false
+#   parallel_write: {
+#     pipeline_stage: [True|False]
+#   }
 # }
 CHECKPOINT = "checkpoint"
 CHECKPOINT_TAG_VALIDATION = "tag_validation"

@@ -380,6 +383,10 @@ class ValidationMode:
 LOAD_UNIVERSAL_CHECKPOINT = "load_universal"
 LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False
 
+CHECKPOINT_PARALLEL_WRITE = "parallel_write"
+CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
+CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False
+
 #########################################
 # Drop the last incomplete Batch
 # #########################################
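Putting the new keys together, a user-facing config that enables the option might look like the sketch below. Everything outside the "checkpoint" block (the toy PipelineModule, batch size, and the deepspeed.initialize() call) is an illustrative assumption, and the snippet is meant to be run under the deepspeed launcher rather than standalone.

# Illustrative only: wiring checkpoint.parallel_write.pipeline_stage into a run.
import torch
import deepspeed
from deepspeed.pipe import PipelineModule

ds_config = {
    "train_batch_size": 16,
    "checkpoint": {
        "tag_validation": "Warn",
        "parallel_write": {
            "pipeline_stage": True   # spread layer-file writes across data parallel ranks
        }
    }
}

# Toy 4-layer pipeline split over 2 stages; constructing a PipelineModule assumes
# torch.distributed has already been set up by the deepspeed launcher.
net = PipelineModule(layers=[torch.nn.Linear(8, 8) for _ in range(4)], num_stages=2)

engine, _, _, _ = deepspeed.initialize(model=net,
                                       model_parameters=net.parameters(),
                                       config=ds_config)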

deepspeed/runtime/engine.py (+17 -7)

@@ -2924,7 +2924,11 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_moe_checkpoint(save_dir, tag, client_state=client_state)
 
-        if self.save_non_zero_checkpoint:
+        # We distribute the task of saving layer checkpoint files among
+        # data parallel instances, so all procs should call _save_checkpoint.
+        # All procs then call module_state_dict(), but only procs of data
+        # parallel rank 0 save the general model params.
+        if not self.has_moe_layers:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
 

@@ -3091,12 +3095,18 @@ def _create_zero_checkpoint_files(self, save_dir, tag):
     def _save_checkpoint(self, save_dir, tag, client_state={}):
 
         save_path = self._get_ckpt_name(save_dir, tag)
+
+        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
+
         # A hack to save the checkpointing directory. Pipeline parallelism overrides
         # module_state_dict() and uses this path to save the model. module_state_dict()
-        # then instead just returns None.
+        # then instead just returns None. The module_state_dict() implementation in
+        # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
         self._curr_ckpt_path = os.path.join(save_dir, tag)
-        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
-        state = dict(module=self.module_state_dict(),
+        module = self.module_state_dict()
+        self._curr_ckpt_path = None
+
+        state = dict(module=module,
                     buffer_names=self._get_buffer_names(),
                     optimizer=self.optimizer.state_dict()
                     if self.optimizer and not zero_optimizer_state else None,

@@ -3114,9 +3124,9 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
                     ds_version=version)
         state.update(client_state)
 
-        log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
-        self.checkpoint_engine.save(state, save_path)
-        self._curr_save_path = None
+        if self.save_non_zero_checkpoint:
+            log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
+            self.checkpoint_engine.save(state, save_path)
 
     def _get_buffer_names(self):
         buffer_names = []
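A practical consequence of the reorganized save path, sketched below: save_checkpoint() must now be called on every rank, since layer files are distributed across data parallel ranks and only the general model params remain gated behind data parallel rank 0. The engine, data iterator, and step/tag naming in this sketch are assumptions for illustration.

# Hedged sketch of a pipeline training loop where checkpointing is collective.
def train_and_checkpoint(engine, data_iter, save_dir, total_steps=1000, save_interval=100):
    for step in range(1, total_steps + 1):
        engine.train_batch(data_iter=data_iter)          # one pipeline training step
        if step % save_interval == 0:
            # Called on ALL ranks: with parallel_write.pipeline_stage enabled, each
            # data parallel rank writes its share of the per-layer files, so an
            # `if rank == 0` guard here would leave part of the checkpoint unwritten.
            engine.save_checkpoint(save_dir, tag=f"global_step{step}")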

deepspeed/runtime/pipe/engine.py (+2)

@@ -182,6 +182,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
         self.module.activation_checkpoint_interval = self._config.pipeline[
             'activation_checkpoint_interval']
 
+        self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
+
         if self.is_last_stage():
             self.loss_model = self.module.loss_fn

deepspeed/runtime/pipe/module.py (+20 -5)

@@ -562,13 +562,28 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
         return ckpt_files
 
     def save_state_dict(self, save_dir, checkpoint_engine):
-        if self._grid.data_parallel_id != 0:
-            return
+        # Processes having the same model parallel rank on different data parallel instances
+        # have identical layer weights. We can distribute the task of saving the layer weights
+        # among the data parallel ranks. For example, if a pipeline stage has 9 layers and
+        # if there are 2 data parallel instances, rank 0 will save the first 5 layers and
+        # rank 1 will save the last 4.
+        dp_rank = self._grid.data_parallel_id
+        dp_size = self._grid.data_parallel_size
+        num_layers = len(self.forward_funcs)
+        if self.checkpoint_parallel_write_pipeline:
+            # spread layers evenly across data parallel ranks
+            offsets = ds_utils.partition_uniform(num_layers, dp_size)
+            start, end = offsets[dp_rank], offsets[dp_rank + 1]
+        else:
+            # data parallel rank 0 writes all layers
+            if dp_rank != 0:
+                return
+            start, end = 0, num_layers
+        layer_list = self.forward_funcs[start:end]
 
         os.makedirs(save_dir, exist_ok=True)
-        layer_offset = self._local_start
-        for idx, layer in enumerate(self.forward_funcs):
-            model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
+        for idx, layer in enumerate(layer_list):
+            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
                 continue
             # We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save()
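The worked example in the new comment (9 layers over 2 data parallel instances giving 5 and 4 layers) can be illustrated with the standalone sketch below. balanced_offsets is a local stand-in rather than DeepSpeed's partition_uniform, so the exact split boundaries may differ; the point is only how an offsets list of length dp_size + 1 yields each rank's (start, end) slice.

# Local illustration of the offsets -> (start, end) slicing used in save_state_dict().
def balanced_offsets(num_items, num_parts):
    # num_parts + 1 boundaries; earlier parts absorb one extra item each until
    # the remainder is used up
    base, remainder = divmod(num_items, num_parts)
    offsets = [0]
    for part in range(num_parts):
        offsets.append(offsets[-1] + base + (1 if part < remainder else 0))
    return offsets

num_layers, dp_size = 9, 2
offsets = balanced_offsets(num_layers, dp_size)          # [0, 5, 9]
for dp_rank in range(dp_size):
    start, end = offsets[dp_rank], offsets[dp_rank + 1]
    print(f"dp_rank {dp_rank} writes layers [{start}, {end})")
# dp_rank 0 writes layers [0, 5)
# dp_rank 1 writes layers [5, 9)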
