add checkpoint #945

Open · wants to merge 14 commits into base: master
22 changes: 11 additions & 11 deletions training/DeepSpeed-Domino/README.md
@@ -6,7 +6,7 @@ pip install -r requirements.txt
```

## Prepare the Dataset
Follow the instructions from [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset.
Follow the instructions from [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset.

## Execute Domino Training

@@ -38,16 +38,16 @@ The output should look like this:

```
training ...
iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152
iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988
iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736
iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979
iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377
iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254
iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691
iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165
iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684
iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998
[2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73015 exits successfully.
[2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73017 exits successfully.
[2024-11-04 15:32:30,919] [INFO] [launch.py:351:main] Process 73014 exits successfully.
19 changes: 18 additions & 1 deletion training/DeepSpeed-Domino/domino/arguments.py
@@ -204,9 +204,26 @@ def parse_args():
'validation set.')
parser.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
parser.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
parser.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
parser.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
parser.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')
parser.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
parser.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
parser.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
parser.add_argument('--exit-on-missing-checkpoint', action='store_true',
help="If '--load' is set, but checkpoint is not found "
"(e.g., path typo), then exit instead of random "
"initialization.")

args = parser.parse_args()

args.rank = int(os.getenv('RANK', '0'))
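For reference, the block below is a minimal, standalone sketch (not part of the PR) of how the new checkpoint flags parse. It mirrors only the options added above; the real parser in domino/arguments.py defines many more arguments, and the sample values (paths, interval) are hypothetical.

```python
# Minimal sketch: mirrors only the checkpoint-related options added above.
# Sample values below are hypothetical; dest names use underscores.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save', type=str, default=None)
parser.add_argument('--save-interval', type=int, default=None)
parser.add_argument('--load', type=str, default=None)
parser.add_argument('--no-load-optim', action='store_true', default=None)
parser.add_argument('--no-load-rng', action='store_true', default=None)
parser.add_argument('--exit-on-missing-checkpoint', action='store_true')

args = parser.parse_args([
    '--save', './checkpoints/gpt-domino',    # where periodic checkpoints go
    '--save-interval', '100',                # save every 100 iterations
    '--load', './checkpoints/gpt-domino',    # resume from the same directory
])
print(args.save, args.save_interval, args.load)
```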
4 changes: 4 additions & 0 deletions training/DeepSpeed-Domino/domino/initialize.py
@@ -14,6 +14,7 @@
from domino.modules.fused_bias_gelu import bias_gelu

from megatron import fused_kernels
import deepspeed


def initialize_domino():
@@ -37,6 +38,9 @@ def initialize_domino():
world_size=args.world_size,
rank=args.rank
)

deepspeed.init_distributed()

mpu.initialize_model_parallel(args.tensor_model_parallel_size)
seed_ = args.seed
data_parallel_random_init = False
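As a side note, `deepspeed.init_distributed()` is expected to detect an already-initialized `torch.distributed` process group and skip re-initialization, which is why calling it after `torch.distributed.init_process_group(...)` should be safe here. A single-process sketch of that ordering (an assumption for illustration only; the `gloo` backend and environment values are placeholders so it runs without GPUs):

```python
# Single-process sketch of the initialization order used above; the gloo
# backend and the env values are placeholders so this runs on CPU.
import os
import torch
import deepspeed

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')

torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
deepspeed.init_distributed(dist_backend='gloo')  # attaches to the existing group
print(torch.distributed.get_rank(), torch.distributed.get_world_size())
```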
154 changes: 153 additions & 1 deletion training/DeepSpeed-Domino/domino/language_model.py
@@ -85,6 +85,71 @@ def forward(self, input_ids, position_ids):
return combined_embeds


def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load."""

state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.add_position_embedding:
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)

return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)

# Position embedding.
if self.add_position_embedding:
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)

# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_,
strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)


class RotaryEmbedding(nn.Module):
def __init__(self, dim, seq_len_interpolation_factor=None):
super().__init__()
@@ -190,4 +255,91 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
encoder_output = encoder_output_t

return encoder_output


def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load."""

state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.add_encoder:
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.untie_embeddings_and_output_weights:
state_dict_[self._output_layer_key] \
= self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)

if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)

return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

# Embedding.
if self.pre_process:
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)

# Encoder.
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]

# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention

self.encoder.load_state_dict(state_dict_, strict=strict)

# Pooler.
if self.post_process:
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
if self.untie_embeddings_and_output_weights:
assert 'output_layer' in state_dict, \
'could not find data for output_layer in the checkpoint'
self.output_layer.load_state_dict(state_dict[self._output_layer_key],
strict=strict)
# Decoder.
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict)
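The methods above follow the same pattern as the embedding-level hooks earlier in this file: save nests each sub-module's state under its own key, and load unpacks those keys (with prefix-matching fallbacks for older checkpoints). Below is a toy, self-contained sketch of that round trip; the `TinyEmbedding` class and its key name are hypothetical stand-ins, not Domino code.

```python
# Toy sketch of the keyed save/load pattern used by the added methods.
import torch
import torch.nn as nn

class TinyEmbedding(nn.Module):
    _word_key = 'word_embeddings'  # hypothetical key, mirroring _word_embeddings_key

    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(16, 8)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # Nest the sub-module's state under its own key.
        return {self._word_key: self.word_embeddings.state_dict(
            prefix=prefix, keep_vars=keep_vars)}

    def load_state_dict(self, state_dict, strict=True):
        # Unpack the same key when restoring.
        self.word_embeddings.load_state_dict(state_dict[self._word_key],
                                             strict=strict)

src, dst = TinyEmbedding(), TinyEmbedding()
ckpt = src.state_dict_for_save_checkpoint()   # {'word_embeddings': {...}}
dst.load_state_dict(ckpt)
assert torch.equal(src.word_embeddings.weight, dst.word_embeddings.weight)
```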
18 changes: 17 additions & 1 deletion training/DeepSpeed-Domino/domino/modules/module.py
@@ -25,6 +25,14 @@ def __init__(self, config=None, share_embeddings_and_output_weights=True):
self.config = config
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights


def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints.
"""

return self.state_dict(prefix=prefix, keep_vars=keep_vars)

def initialize_word_embeddings(self):
self.share_embeddings_and_output_weights = True
return
@@ -74,7 +82,8 @@ def float_conversion(val):
return conversion_helper(val, float_conversion)


class Float16Module(torch.nn.Module):
# class Float16Module(torch.nn.Module):
class Float16Module(DominoModule):

def __init__(self, module, args):
super(Float16Module, self).__init__()
@@ -91,3 +100,10 @@ def forward(self, *inputs, **kwargs):
outputs = float16_to_fp32(outputs)
return outputs


def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
""" Retrieve state_dict from the module being wrapped."""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
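One design note on the delegation above: without it, a checkpoint written from the fp16 wrapper would carry an extra `module.` prefix on every key and would not load back into an unwrapped model. A toy sketch of the difference, using a hypothetical `Wrapped` class rather than the real `Float16Module`:

```python
# Toy sketch: delegating the checkpoint hook keeps keys free of the
# wrapper's 'module.' prefix, so an unwrapped model can load them.
import torch
import torch.nn as nn

class Wrapped(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        # Delegate so keys match the bare module's layout.
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

bare = nn.Linear(4, 4)
wrapper = Wrapped(nn.Linear(4, 4))
print(list(wrapper.state_dict().keys()))                      # ['module.weight', 'module.bias']
print(list(wrapper.state_dict_for_save_checkpoint().keys()))  # ['weight', 'bias']
bare.load_state_dict(wrapper.state_dict_for_save_checkpoint())  # loads cleanly
```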
13 changes: 12 additions & 1 deletion training/DeepSpeed-Domino/domino/training.py
@@ -17,6 +17,9 @@
from domino.tensor_parallel.data import broadcast_data


from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint

def is_rank_0():
# if torch.cuda.current_device() == 0:
if torch.distributed.get_rank() == 0:
@@ -109,7 +112,10 @@ def setup_model_and_optimizer(base_model,
optimizer = get_megatron_optimizer(models, no_wd_decay_cond, scale_lr_cond)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

args.iteration = 0
if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
else:
args.iteration = 0

return model, optimizer, opt_param_scheduler

@@ -297,6 +303,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
config)

iteration += 1

if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * get_num_microbatches()

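Putting the pieces together, the control flow added to `setup_model_and_optimizer()` and `train()` amounts to: resume from `--load` when it is set (using the returned iteration as the starting point), then write a checkpoint every `--save-interval` iterations into `--save`. The block below is a toy, framework-free sketch of that flow; the paths, interval, and `torch.save`/`torch.load` round trip are stand-ins for the real `megatron.checkpointing.load_checkpoint` / `save_checkpoint` helpers.

```python
# Toy sketch of the resume/save flow wired into training.py; paths and
# the torch.save round trip stand in for Megatron's checkpointing helpers.
import os
import torch
import torch.nn as nn

save_dir, load_dir, save_interval = './ckpt_demo', './ckpt_demo', 5
ckpt_path = os.path.join(save_dir, 'latest.pt')

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Mirrors: args.iteration = load_checkpoint(...) if args.load else 0
if load_dir is not None and os.path.exists(ckpt_path):
    state = torch.load(ckpt_path)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    iteration = state['iteration']
else:
    iteration = 0

while iteration < 20:
    optimizer.zero_grad()
    model(torch.randn(4, 8)).pow(2).mean().backward()
    optimizer.step()
    iteration += 1
    # Mirrors: if args.save and args.save_interval and iteration % args.save_interval == 0
    if save_dir and save_interval and iteration % save_interval == 0:
        os.makedirs(save_dir, exist_ok=True)
        torch.save({'iteration': iteration,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()}, ckpt_path)
```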