diff --git a/training/DeepSpeed-Domino/README.md b/training/DeepSpeed-Domino/README.md index 92f6d1ecc..3c1f4040b 100644 --- a/training/DeepSpeed-Domino/README.md +++ b/training/DeepSpeed-Domino/README.md @@ -6,7 +6,7 @@ pip install -r requirements.txt ``` ## Prepare the Dataset -Follow the instructions from [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset. +Follow the instructions from [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset. ## Execute Domino Training @@ -38,16 +38,16 @@ The output should look like this: ``` training ... -iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152 -iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988 -iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736 -iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979 -iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377 -iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254 -iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691 -iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165 -iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684 -iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998 +iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152 +iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988 +iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736 +iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979 +iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377 +iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254 +iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691 +iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165 +iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684 +iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998 [2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73015 exits successfully. [2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73017 exits successfully. [2024-11-04 15:32:30,919] [INFO] [launch.py:351:main] Process 73014 exits successfully. diff --git a/training/DeepSpeed-Domino/domino/arguments.py b/training/DeepSpeed-Domino/domino/arguments.py index 8bc59223a..7c1938e0b 100644 --- a/training/DeepSpeed-Domino/domino/arguments.py +++ b/training/DeepSpeed-Domino/domino/arguments.py @@ -204,9 +204,26 @@ def parse_args(): 'validation set.') parser.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') + parser.add_argument('--save', type=str, default=None, + help='Output directory to save checkpoints to.') + parser.add_argument('--no-save-optim', action='store_true', default=None, + help='Do not save current optimizer.') + parser.add_argument('--no-save-rng', action='store_true', default=None, + help='Do not save current rng state.') + parser.add_argument('--save-interval', type=int, default=None, + help='Number of iterations between checkpoint saves.') + parser.add_argument('--load', type=str, default=None, + help='Directory containing a model checkpoint.') + parser.add_argument('--no-load-optim', action='store_true', default=None, + help='Do not load optimizer when loading checkpoint.') + parser.add_argument('--no-load-rng', action='store_true', default=None, + help='Do not load rng state when loading checkpoint.') + parser.add_argument('--exit-on-missing-checkpoint', action='store_true', + help="If '--load' is set, but checkpoint is not found " + "(e.g., path typo), then exit instead of random " + "initialization.") parser.add_argument('--save-interval', type=int, default=None, help='Number of iterations between checkpoint saves.') - args = parser.parse_args() args.rank = int(os.getenv('RANK', '0')) diff --git a/training/DeepSpeed-Domino/domino/initialize.py b/training/DeepSpeed-Domino/domino/initialize.py index 36e0fa1bc..51f2213d2 100644 --- a/training/DeepSpeed-Domino/domino/initialize.py +++ b/training/DeepSpeed-Domino/domino/initialize.py @@ -14,6 +14,7 @@ from domino.modules.fused_bias_gelu import bias_gelu from megatron import fused_kernels +import deepspeed def initialize_domino(): @@ -37,6 +38,9 @@ def initialize_domino(): world_size=args.world_size, rank=args.rank ) + + deepspeed.init_distributed() + mpu.initialize_model_parallel(args.tensor_model_parallel_size) seed_ = args.seed data_parallel_random_init = False diff --git a/training/DeepSpeed-Domino/domino/language_model.py b/training/DeepSpeed-Domino/domino/language_model.py index 2cfb2f9fd..5cbee692f 100644 --- a/training/DeepSpeed-Domino/domino/language_model.py +++ b/training/DeepSpeed-Domino/domino/language_model.py @@ -85,6 +85,71 @@ def forward(self, input_ids, position_ids): return combined_embeds + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.add_position_embedding: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self.add_position_embedding: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, + strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', flush=True) + + class RotaryEmbedding(nn.Module): def __init__(self, dim, seq_len_interpolation_factor=None): super().__init__() @@ -190,4 +255,91 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, encoder_output = encoder_output_t return encoder_output - \ No newline at end of file + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.add_encoder: + state_dict_[self._encoder_key] \ + = self.encoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] \ + = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.untie_embeddings_and_output_weights: + state_dict_[self._output_layer_key] \ + = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_decoder: + state_dict_[self._decoder_key] \ + = self.decoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if self.add_encoder: + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", + ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Pooler. + if self.post_process: + if self.add_pooler: + assert 'pooler' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], + strict=strict) + if self.untie_embeddings_and_output_weights: + assert 'output_layer' in state_dict, \ + 'could not find data for output_layer in the checkpoint' + self.output_layer.load_state_dict(state_dict[self._output_layer_key], + strict=strict) + # Decoder. + if self.add_decoder: + assert 'decoder' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.decoder.load_state_dict(state_dict[self._decoder_key], + strict=strict) diff --git a/training/DeepSpeed-Domino/domino/modules/module.py b/training/DeepSpeed-Domino/domino/modules/module.py index b89bbc21f..0f42ca764 100644 --- a/training/DeepSpeed-Domino/domino/modules/module.py +++ b/training/DeepSpeed-Domino/domino/modules/module.py @@ -25,6 +25,14 @@ def __init__(self, config=None, share_embeddings_and_output_weights=True): self.config = config self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """Use this function to override the state dict for + saving checkpoints. + """ + + return self.state_dict(prefix=prefix, keep_vars=keep_vars) + def initialize_word_embeddings(self): self.share_embeddings_and_output_weights = True return @@ -74,7 +82,8 @@ def float_conversion(val): return conversion_helper(val, float_conversion) -class Float16Module(torch.nn.Module): +# class Float16Module(torch.nn.Module): +class Float16Module(DominoModule): def __init__(self, module, args): super(Float16Module, self).__init__() @@ -91,3 +100,10 @@ def forward(self, *inputs, **kwargs): outputs = float16_to_fp32(outputs) return outputs + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """ Retrieve state_dict from the module being wrapped.""" + return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) diff --git a/training/DeepSpeed-Domino/domino/training.py b/training/DeepSpeed-Domino/domino/training.py index 59e253fcf..142a30cf7 100644 --- a/training/DeepSpeed-Domino/domino/training.py +++ b/training/DeepSpeed-Domino/domino/training.py @@ -17,6 +17,9 @@ from domino.tensor_parallel.data import broadcast_data +from megatron.checkpointing import load_checkpoint +from megatron.checkpointing import save_checkpoint + def is_rank_0(): # if torch.cuda.current_device() == 0: if torch.distributed.get_rank() == 0: @@ -109,7 +112,10 @@ def setup_model_and_optimizer(base_model, optimizer = get_megatron_optimizer(models, no_wd_decay_cond, scale_lr_cond) opt_param_scheduler = get_optimizer_param_scheduler(optimizer) - args.iteration = 0 + if args.load is not None: + args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) + else: + args.iteration = 0 return model, optimizer, opt_param_scheduler @@ -297,6 +303,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, config) iteration += 1 + + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.micro_batch_size * get_num_microbatches() diff --git a/training/DeepSpeed-Domino/megatron/checkpointing.py b/training/DeepSpeed-Domino/megatron/checkpointing.py index e88b58513..5c174624f 100644 --- a/training/DeepSpeed-Domino/megatron/checkpointing.py +++ b/training/DeepSpeed-Domino/megatron/checkpointing.py @@ -9,12 +9,11 @@ import torch -from megatron import update_num_microbatches -from megatron.core import mpu, tensor_parallel -from .global_vars import get_args -from .utils import (unwrap_model, - print_rank_0) - +# from megatron import update_num_microbatches +import domino.parallel_state as mpu +from domino.tensor_parallel.random import get_cuda_rng_tracker +from domino.arguments import get_args +from domino.utils import unwrap_model, print_rank_0 _CHECKPOINT_VERSION = None @@ -194,7 +193,7 @@ def get_rng_state(): 'np_rng_state': np.random.get_state(), 'torch_rng_state': torch.get_rng_state(), 'cuda_rng_state': torch.cuda.get_rng_state(), - 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()} + 'rng_tracker_states': get_cuda_rng_tracker().get_states()} rng_state_list = None if torch.distributed.is_initialized() and \ @@ -218,6 +217,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): # Only rank zero of the data parallel writes to the disk. model = unwrap_model(model) + model_module = model.module + model = [model_module] print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( iteration, args.save)) @@ -241,7 +242,10 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): # Arguments, iteration, and model. state_dict = {} - state_dict['args'] = args + t_args = args + t_args.init_method = None + t_args.output_layer_init_method = None + state_dict['args'] = t_args state_dict['checkpoint_version'] = 3.0 state_dict['iteration'] = iteration if len(model) == 1: @@ -503,6 +507,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri load_dir = getattr(args, load_arg) model = unwrap_model(model) + model_module = model.module + model = [model_module] state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False) @@ -522,6 +528,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri set_checkpoint_version(state_dict.get('checkpoint_version', 0)) # Set iteration. + args.finetune = False if args.finetune or release: iteration = 0 else: @@ -544,7 +551,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri check_checkpoint_args(checkpoint_args) args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0) - update_num_microbatches(consumed_samples=args.consumed_train_samples) + # update_num_microbatches(consumed_samples=args.consumed_train_samples) args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0) else: @@ -614,7 +621,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri # Check for empty states array if not rng_state['rng_tracker_states']: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( + get_cuda_rng_tracker().set_states( rng_state['rng_tracker_states']) else: # backward compatability random.setstate(state_dict['random_rng_state']) @@ -624,7 +631,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri # Check for empty states array if not state_dict['rng_tracker_states']: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( + get_cuda_rng_tracker().set_states( state_dict['rng_tracker_states']) except KeyError: print_rank_0('Unable to load rng state from checkpoint {}. ' diff --git a/training/DeepSpeed-Domino/megatron/initialize.py b/training/DeepSpeed-Domino/megatron/initialize.py index 367ba85cb..b8aa6d3e8 100644 --- a/training/DeepSpeed-Domino/megatron/initialize.py +++ b/training/DeepSpeed-Domino/megatron/initialize.py @@ -16,7 +16,7 @@ from megatron import get_tensorboard_writer from megatron.core import mpu, tensor_parallel from megatron.arguments import parse_args, validate_args -from megatron.checkpointing import load_args_from_checkpoint +# from megatron.checkpointing import load_args_from_checkpoint from megatron.global_vars import set_global_variables from megatron.model.transformer import bias_dropout_add_fused_train from megatron.model.fused_bias_gelu import bias_gelu