diff --git a/training/DeepSpeed-Domino/README.md b/training/DeepSpeed-Domino/README.md index 92f6d1ecc..31792e01d 100644 --- a/training/DeepSpeed-Domino/README.md +++ b/training/DeepSpeed-Domino/README.md @@ -5,7 +5,6 @@ pip install -r requirements.txt ``` -## Prepare the Dataset Follow the instructions from [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset. ## Execute Domino Training @@ -38,16 +37,16 @@ The output should look like this: ``` training ... -iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152 -iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988 -iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736 -iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979 -iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377 -iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254 -iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691 -iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165 -iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684 -iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998 +iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152 +iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988 +iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736 +iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979 +iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377 +iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254 +iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691 +iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165 +iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684 +iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998 [2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73015 exits successfully. [2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73017 exits successfully. [2024-11-04 15:32:30,919] [INFO] [launch.py:351:main] Process 73014 exits successfully. diff --git a/training/DeepSpeed-Domino/domino/arguments.py b/training/DeepSpeed-Domino/domino/arguments.py index 8bc59223a..79ee62f41 100644 --- a/training/DeepSpeed-Domino/domino/arguments.py +++ b/training/DeepSpeed-Domino/domino/arguments.py @@ -92,10 +92,25 @@ def parse_args(): parser.add_argument('--position-embedding-type', type=str, default='learned_absolute', choices=['learned_absolute', 'rope'], help='Position embedding type.') + parser.add_argument('--use-rotary-position-embeddings', action='store_true', + help='Use rotary positional embeddings or not. 
' + 'Deprecated: use --position-embedding-type') + parser.add_argument('--rotary-base', type=int, default=10000, + help='Base to use for rotary positional embeddings, default 10000') parser.add_argument('--rotary-percent', type=float, default=1.0, help='Percent of rotary dimension to use, default 100%') + parser.add_argument('--rotary-interleaved', action='store_true', + help='Use interleaved rotary embedding.') parser.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None, help='Sequence length interpolation factor for rotary embeddings.') + parser.add_argument('--use-rope-scaling', action='store_true', + help='Apply rope scaling as used in llama3.1') + parser.add_argument('--disable-bias-linear', action='store_false', + help='Disable bias in the linear layers', + dest='add_bias_linear') + parser.add_argument('--group-query-attention', action='store_true', + help='Use group-query attention.') + parser.add_argument('--num-query-groups', type=int, default=1) parser.add_argument('--hidden-dropout', type=float, default=0.1, help='Dropout probability for hidden state transformer.') parser.add_argument('--attention-dropout', type=float, default=0.1, @@ -180,8 +195,11 @@ def parse_args(): 'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer', + 'HuggingFaceTokenizer', 'NullTokenizer'], help='What type of tokenizer to use.') + parser.add_argument('--tokenizer-model', type=str, default=None, + help='Sentencepiece tokenizer model.') parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128, help='Pad the vocab size to be divisible by this value.' 'This is added for computational efficieny reasons.') @@ -343,6 +361,12 @@ class TransformerConfig(): gated_linear_unit: bool = False activation_func: Callable = F.gelu bias_gelu_fusion = False + kv_channels: int = None + rotary_interleaved: bool = False + normalization: str = 'LayerNorm' + group_query_attention: bool = False + num_query_groups: int = 1 + seq_length: int = 2048 # initialization init_method: Callable = None diff --git a/training/DeepSpeed-Domino/domino/language_model.py b/training/DeepSpeed-Domino/domino/language_model.py index 2cfb2f9fd..80a78a077 100644 --- a/training/DeepSpeed-Domino/domino/language_model.py +++ b/training/DeepSpeed-Domino/domino/language_model.py @@ -1,6 +1,8 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
# This file is adapted from language_model.py in Megatron-LM +from typing import Literal, Optional + import torch from torch import einsum, nn from domino.arguments import get_args @@ -14,6 +16,9 @@ from domino.tensor_parallel.partition import _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes from domino.tensor_parallel.partition import ColumnParallelLinear, RowParallelLinearNoComm +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.model.utils import get_norm + from deepspeed.runtime.domino.transformer import DominoTransformer def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, @@ -45,12 +50,18 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, def get_language_model(config, num_tokentypes, encoder_attn_mask_type, pre_process=True, post_process=True): + args = get_args() language_model = TransformerLanguageModel( config, encoder_attn_mask_type, num_tokentypes=num_tokentypes, pre_process=pre_process, - post_process=post_process + post_process=post_process, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + seq_len_interpolation_factor = args.rotary_seq_len_interpolation_factor, ) return language_model @@ -85,38 +96,18 @@ def forward(self, input_ids, position_ids): return combined_embeds -class RotaryEmbedding(nn.Module): - def __init__(self, dim, seq_len_interpolation_factor=None): - super().__init__() - self.seq_len_interpolation_factor = seq_len_interpolation_factor - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - def forward(self, max_seq_len, offset=0): - seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset - if self.seq_len_interpolation_factor is not None: - seq = seq.type_as(self.inv_freq) - seq *= 1 / self.seq_len_interpolation_factor - freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) - # first part even vector components, second part odd vector components, - # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) - # emb [seq_length, .., dim] - return emb[:, None, None, :] - - # def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - # state_dict.pop(f'{prefix}inv_freq', None) - # return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - - class TransformerLanguageModel(DominoModule): def __init__(self, config, encoder_attn_mask_type, num_tokentypes=0, pre_process=True, - post_process=True): - + post_process=True, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + rope_scaling: bool = False, + seq_len_interpolation_factor: Optional[float] = None,): args = get_args() super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=True) @@ -127,6 +118,11 @@ def __init__(self, self.init_method = config.init_method self.encoder_attn_mask_type = encoder_attn_mask_type self.encoder_hidden_state = None + self.position_embedding_type = position_embedding_type + self.rotary_percent = rotary_percent + self.rotary_base = rotary_base + self.rotary_scaling = rope_scaling + self.seq_length = config.seq_length if self.pre_process: self.embedding = Embedding(self.hidden_size, @@ -138,19 +134,18 @@ def __init__(self, self.use_rotary_position_embeddings = \ 
args.position_embedding_type == 'rope' if self.use_rotary_position_embeddings: - self.seq_length = args.seq_length - rotary_dim = args.hidden_size // args.num_attention_heads \ - if args.kv_channels is None else args.kv_channels - if args.rotary_percent < 1.0: - rotary_dim = int(rotary_dim * args.rotary_percent) self.rotary_pos_emb = RotaryEmbedding( - rotary_dim, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + kv_channels=config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + rope_scaling=rope_scaling, ) self.encoder = DominoTransformer( config, ModelType.encoder_or_decoder, mpu, - fused_layer_norm, _initialize_affine_weight_gpu, + get_norm, _initialize_affine_weight_gpu, ColumnParallelLinear, RowParallelLinearNoComm, apply_rotary_pos_emb, bias_dropout_add_fused_train, bias_dropout_add_fused_inference, self_attn_mask_type=self.encoder_attn_mask_type, diff --git a/training/DeepSpeed-Domino/megatron/core/datasets/megatron_tokenizer.py b/training/DeepSpeed-Domino/megatron/core/datasets/megatron_tokenizer.py new file mode 100644 index 000000000..efc96992b --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/datasets/megatron_tokenizer.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import json +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any +import numpy +class MegatronTokenizer(ABC): + """Abstract class for tokenizer + Absent a config or class-specific tracking of which objects are uniquely identifying, we must + include all key word arguments as unique identifiers + Args: + tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes + tokenizer_options (Dict[str, Any]): All tokenizer options + """ + def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any): + self.unique_identifiers = OrderedDict() + self.unique_identifiers["class"] = type(self).__name__ + self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths) + for option in tokenizer_options: + self.unique_identifiers[option] = str(tokenizer_options[option]) + self.unique_description = json.dumps(self.unique_identifiers, indent=4) + super().__init__() + @abstractmethod + def tokenize(self, text: str) -> numpy.ndarray: + """Convert text to embedding ids + Args: + text (str): The text to convert + Returns: + numpy.ndarray: The converted embedding ids + """ + pass + def detokenize(self, ids: numpy.ndarray) -> str: + """Convert embedding ids to text + Args: + ids (numpy.ndarray): The ids to convert + Returns: + str: The converted text + Raises: + NotImplementedError: Non-abstract, optional method + """ + raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__)) + def offsets(self, ids: list[int], text: str) -> list[int]: + """Convert embedding ids to text offsets + Args: + ids (list[int]): The ids to convert + text (str): The text to convert + Returns: + list[int]: The converted offsets + Raises: + NotImplementedError: Non-abstract, optional method + """ + raise NotImplementedError("{} has no method 'offsets'".format(type(self).__name__)) + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token""" + pass + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token""" + pass + @property + @abstractmethod + def vocab_size(self): + """The vocabulary 
size""" + pass + @property + def cls(self): + """The CLS token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__)) + @property + def sep(self): + """The SEP token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__)) + @property + def pad(self): + """The PAD token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__)) + @property + def eod(self): + """The EOD token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__)) + @property + def bos(self): + """The BOS token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__)) + @property + def eos(self): + """The EOS token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__)) + @property + def mask(self): + """The MASK token id + Raises: + NotImplementedError: Non-abstract, optional attribute + """ + raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__)) \ No newline at end of file diff --git a/training/DeepSpeed-Domino/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/training/DeepSpeed-Domino/megatron/core/models/common/embeddings/rotary_pos_embedding.py new file mode 100644 index 000000000..bab603e84 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -0,0 +1,195 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.transformer.transformer_block import TransformerBlock + from megatron.core.inference_params import InferenceParams + from megatron.core.packed_seq_params import PackedSeqParams +import logging +import math +from functools import lru_cache +import torch +from torch import Tensor, nn +from megatron.core import parallel_state +from megatron.core.models.common.rope_utils import ( # for backward compatibility; pylint: disable=unused-import + _apply_rotary_pos_emb_bshd, + _apply_rotary_pos_emb_thd, + _rotate_half, + apply_rotary_pos_emb, + get_pos_emb_on_this_cp_rank, +) +logger = logging.getLogger(__name__) +__all__ = ['RotaryEmbedding'] +class RotaryEmbedding(nn.Module): + """Rotary Embedding for language model. + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained + from transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position + embeddings. + rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. + Defaults to False. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE + for longer sequences. The value must be a float larger than 1.0. Defaults to None + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to + 10000. 
+ rope_scaling (bool, optional): Apply rope scaling as used in llama 3.1 + use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly + on the GPU. Defaults to False + """ + def __init__( + self, + kv_channels: int, + rotary_percent: float, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: float = None, + rotary_base: int = 10000, + rope_scaling: bool = False, + use_cpu_initialization: bool = False, + ) -> None: + super().__init__() + dim = kv_channels + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + self.rotary_interleaved = rotary_interleaved + self.seq_len_interpolation_factor = seq_len_interpolation_factor + device = 'cpu' if use_cpu_initialization else torch.cuda.current_device() + self.inv_freq = 1.0 / ( + rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + if rope_scaling: + self.inv_freq = self._apply_scaling(self.inv_freq) + def _apply_scaling( + self, + freqs, + factor=8, + low_freq_factor=1, + high_freq_factor=4, + original_max_position_embeddings=8192, + ): + # This implementation is adapted from: + # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343 + factor = factor # `8` in the original implementation + low_freq_factor = low_freq_factor # `1` in the original implementation + high_freq_factor = high_freq_factor # `4` in the original implementation + old_context_len = original_max_position_embeddings # `8192` in the original implementation + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + wavelen = 2 * math.pi / freqs + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + return inv_freq_llama + def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Generates matrix of frequencies based on positions in the sequence, + used to create positional encodings""" + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor + freqs = torch.outer(seq, self.inv_freq) # [seq len, dim] + return freqs + def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): + """Cosine and sine values for RoPE are precomputed for all positions up to the maximum + sequence length""" + freqs = self.get_freqs_non_repeated(max_seq_len, offset) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return cos, sin + @lru_cache(maxsize=32) + def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: + """Forward pass of RoPE embedding. + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. + Returns: + Tensor: Embeddings after applying RoPE. 
+ """ + if self.inv_freq.device.type == 'cpu': + # move `inv_freq` to GPU once at the first micro-batch forward pass + self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device()) + freqs = self.get_freqs_non_repeated(max_seq_len, offset) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + if not self.rotary_interleaved: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( + freqs.shape[0], -1 + ) + # emb [seq_length, .., dim] + emb = emb[:, None, None, :] + # if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq: + # # slice rotary_pos_emb along sequence dimension and select the parition of the current + # # CP rank + # emb = get_pos_emb_on_this_cp_rank(emb, 0) + return emb + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + state_dict.pop(f'{prefix}inv_freq', None) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + def get_rotary_seq_len( + self, + inference_params: InferenceParams, + transformer: TransformerBlock, + transformer_input: Tensor, + transformer_config: TransformerConfig, + packed_seq_params: PackedSeqParams, + ) -> float: + """Function to get the rotary sequence length. + Args: + inference_params : Used during Inference time + transformer (TransformerBlock): The transformer block (decoder/encoder) used + by the model + transformer_input (Tensor): Input tensor to the transformer + transformer_config (TransformerConfig): Transformer config used by the model + packed_seq_params (PackedSeqParams): Packed sequence params + Returns: + float: The rotary sequence length + """ + if packed_seq_params is not None: + # max_seqlen are the max sequence length in the packed sequence before being divived + # by the tp and cp size. + return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) + elif inference_params is not None: + rotary_seq_len = inference_params.max_sequence_length + else: + if transformer.input_tensor is not None: + rotary_seq_len = transformer.input_tensor.size(0) + else: + rotary_seq_len = transformer_input.size(0) + if transformer_config.sequence_parallel: + rotary_seq_len *= transformer_config.tensor_model_parallel_size + rotary_seq_len *= transformer_config.context_parallel_size + return rotary_seq_len +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) +def apply_rotary_pos_emb(t, freqs): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin()) + return torch.cat((t, t_pass), dim=-1) \ No newline at end of file diff --git a/training/DeepSpeed-Domino/megatron/core/models/common/rope_utils.py b/training/DeepSpeed-Domino/megatron/core/models/common/rope_utils.py new file mode 100644 index 000000000..5ca4085dc --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/models/common/rope_utils.py @@ -0,0 +1,195 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+from __future__ import annotations +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig +import logging +import torch +from torch import Tensor +from megatron.core import parallel_state +logger = logging.getLogger(__name__) +try: + from megatron.core.extensions.transformer_engine import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, + ) + HAVE_APPLY_ROPE_FUSION = True +except ImportError: + try: + from apex.transformer.functional import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, + ) + HAVE_APPLY_ROPE_FUSION = True + except ImportError: + HAVE_APPLY_ROPE_FUSION = False +try: + from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_flash +except ImportError: + apply_rotary_emb_flash = None +__all__ = ['apply_rotary_emb_flash'] +def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor: + """Get the position embedding on the current context parallel rank. + Args: + pos_emb (Tensor): Positional embedding tensor + seq_dim (int): Sequence dimension + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + pos_emb = pos_emb.view( + *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :] + ) + pos_emb = pos_emb.index_select(seq_dim, cp_idx) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) + return pos_emb +def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor: + """Change sign so the last dimension becomes [-odd, +even] + Args: + x (Tensor): Input tensor + Returns: + Tensor: Tensor rotated half + """ + if not rotary_interleaved: + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x_new = torch.stack((-x2, x1), dim=-1) + return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1) +def _apply_rotary_pos_emb_bshd( + t: Tensor, + freqs: Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, +) -> Tensor: + """Apply rotary positional embedding to input tensor T. + check https://kexue.fm/archives/8265 for detailed formulas + Args: + t (Tensor): Input tensor T is of shape [seq_length, ... 
, dim] + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] + Returns: + Tensor: The input tensor after applying RoPE + """ + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + if multi_latent_attention: + x1 = t[..., 0::2] + x2 = t[..., 1::2] + t = torch.cat((x1, x2), dim=-1) + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + cos_ = (torch.cos(freqs) * mscale).to(t.dtype) + sin_ = (torch.sin(freqs) * mscale).to(t.dtype) + t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_) + return torch.cat((t, t_pass), dim=-1) +def _get_thd_freqs_on_this_cp_rank(cp_rank: int, cp_size: int, x: Tensor, freqs: Tensor) -> Tensor: + if cp_size > 1: + cp_seg = x.size(0) // 2 + full_seqlen = cp_size * x.size(0) + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + else: + return freqs[: x.size(0)] +def _apply_rotary_pos_emb_thd( + t: Tensor, + cu_seqlens: Tensor, + freqs: Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cu_seqlens = cu_seqlens // cp_size + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return torch.cat( + [ + _apply_rotary_pos_emb_bshd( + x.unsqueeze(1), + _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs), + rotary_interleaved=rotary_interleaved, + multi_latent_attention=multi_latent_attention, + mscale=mscale, + ) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) +def apply_rotary_pos_emb( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, + mscale: float = 1.0, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + fused/unfused kernels, or bshd (conventional) / thd (packed seq) format + """ + if config.apply_rope_fusion: + return fused_apply_rotary_pos_emb(t, freqs) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + else: + return _apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) +def apply_rotary_pos_emb_with_cos_sin( + t: Tensor, cos: Tensor, sin: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """ + This function applies rotary positional embedding to the target tensor t + using precomputed cos and sin of size (seq_len, d_rot / 2) + """ + cos = cos.to(t.dtype) + sin = sin.to(t.dtype) + if apply_rotary_emb_flash is None: + # Combine cos and sin into freqs + freqs = torch.stack([cos, sin], dim=-1).flatten(start_dim=-2) + # Expand freqs to match t's shape + while freqs.dim() < 
t.dim(): + freqs = freqs.unsqueeze(1) + freqs = freqs.expand(t.shape[:-1] + (-1,)) + y = _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=rotary_interleaved, + multi_latent_attention=False, + mscale=1.0, + ) + else: + # Use Flash Attention's optimized kernel for rotary embedding + t = t.permute(1, 0, 2, 3) + y = apply_rotary_emb_flash(t, cos, sin, rotary_interleaved) + y = y.permute(1, 0, 2, 3) + return y \ No newline at end of file diff --git a/training/DeepSpeed-Domino/megatron/model/__init__.py b/training/DeepSpeed-Domino/megatron/model/__init__.py index f5025bf25..be9f3514d 100644 --- a/training/DeepSpeed-Domino/megatron/model/__init__.py +++ b/training/DeepSpeed-Domino/megatron/model/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from .rms_norm import RMSNorm from .distributed import DistributedDataParallel from .bert_model import BertModel diff --git a/training/DeepSpeed-Domino/megatron/model/language_model.py b/training/DeepSpeed-Domino/megatron/model/language_model.py index 85b5dc5cb..731b4d012 100644 --- a/training/DeepSpeed-Domino/megatron/model/language_model.py +++ b/training/DeepSpeed-Domino/megatron/model/language_model.py @@ -8,7 +8,7 @@ from megatron import get_args from megatron.core import mpu, tensor_parallel from megatron.core.enums import ModelType -from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from .enums import AttnMaskType, LayerType from .module import MegatronModule diff --git a/training/DeepSpeed-Domino/megatron/model/rms_norm.py b/training/DeepSpeed-Domino/megatron/model/rms_norm.py new file mode 100644 index 000000000..62252c6bb --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/rms_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import torch +from torch import nn +class RMSNorm(torch.nn.Module): + def __init__(self, + dim: int, + eps: float = 1e-6, + sequence_parallel: bool = False, + config: dict = None): + """RMS Normaliation module + Args: + dim (int): The width of input, i.e. hidden size + eps (float): epsilon to use for the norm, default to 1e-6 + sequence_parallel (bool): Set to true if sequence parallelism is being used, + this marks the weights as needing to be allreduced. 
+ """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + setattr(self.weight, 'sequence_parallel', sequence_parallel) + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight \ No newline at end of file diff --git a/training/DeepSpeed-Domino/megatron/model/transformer.py b/training/DeepSpeed-Domino/megatron/model/transformer.py index a7898156f..9573e7190 100644 --- a/training/DeepSpeed-Domino/megatron/model/transformer.py +++ b/training/DeepSpeed-Domino/megatron/model/transformer.py @@ -16,7 +16,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl -from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu try: diff --git a/training/DeepSpeed-Domino/megatron/model/utils.py b/training/DeepSpeed-Domino/megatron/model/utils.py index cf3727c02..f0080db6d 100644 --- a/training/DeepSpeed-Domino/megatron/model/utils.py +++ b/training/DeepSpeed-Domino/megatron/model/utils.py @@ -7,6 +7,7 @@ import torch from megatron import get_args +from megatron.model import LayerNorm, RMSNorm def init_method_normal(sigma): """Init method based on N(0, sigma).""" @@ -52,3 +53,15 @@ def openai_gelu(x): @torch.jit.script def erf_gelu(x): return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) + +def get_norm(config): + if config.normalization == "LayerNorm": + return LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=not config.persist_layer_norm) + elif config.normalization == "RMSNorm": + return RMSNorm(dim=config.hidden_size, + eps=config.layernorm_epsilon) + else: + raise Exception(f"unsupported norm type '{config.normalization}'.") diff --git a/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py b/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py index 79dab75a0..5f813e615 100644 --- a/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py +++ b/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py @@ -5,11 +5,13 @@ from abc import ABC from abc import abstractmethod +from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer + from .bert_tokenization import FullTokenizer as FullBertTokenizer from .gpt2_tokenization import GPT2Tokenizer -def build_tokenizer(args): +def build_tokenizer(args, **kwargs): """Initialize tokenizer.""" if args.rank == 0: print('> building {} tokenizer ...'.format(args.tokenizer_type), @@ -36,6 +38,8 @@ def build_tokenizer(args): elif args.tokenizer_type == 'GPTSentencePieceTokenizer': assert args.tokenizer_model is not None tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model) + elif args.tokenizer_type == 'HuggingFaceTokenizer': + tokenizer = _HuggingFaceTokenizer(args.tokenizer_model, **kwargs) elif args.tokenizer_type == 'NullTokenizer': assert args.vocab_size is not None tokenizer = _NullTokenizer(args.vocab_size) @@ -66,6 +70,54 @@ def _vocab_size_with_padding(orig_vocab_size, args): return after +class _HuggingFaceTokenizer(MegatronTokenizer): + def __init__(self, pretrained_model_name_or_path, **kwargs): + super().__init__(pretrained_model_name_or_path, **kwargs) + 
try: + import transformers + except ImportError: + raise EnvironmentError( + f"The transformers library must be installed to use huggingface_tokenizer_provider" + ) + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + ) + self._vocab = self._tokenizer.get_vocab() + self._inv_vocab = {token_id: token for token, token_id in self._vocab.items()} + @property + def vocab_size(self): + return len(self._tokenizer) + @property + def vocab(self): + """Dictionary from vocab text token to id token.""" + return self._vocab + @property + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + return self._inv_vocab + @property + def decoder(self): + return self._inv_vocab + def tokenize(self, text, **kwargs): + return self._tokenizer(text, **kwargs).input_ids + def detokenize(self, token_ids, **kwargs): + return self._tokenizer.decode(token_ids, **kwargs) + def offsets(self, ids: list[int], text: str) -> list[int]: + retok_ids: "transformers.BatchEncoding" = self._tokenizer(text) + offsets, next_start_idx = [], 0 + for i in range(len(ids)): + span = retok_ids.token_to_chars(i) + if span is not None: + offsets.append(span.start) + next_start_idx = span.end + else: + offsets.append(next_start_idx) + return offsets + @property + def eod(self): + return self._tokenizer.eos_token_id + class AbstractTokenizer(ABC): """Abstract class for tokenizer.""" diff --git a/training/DeepSpeed-Domino/pretrain_llama3_8b.sh b/training/DeepSpeed-Domino/pretrain_llama3_8b.sh new file mode 100644 index 000000000..113fb1a2b --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_llama3_8b.sh @@ -0,0 +1,139 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+# This file is adapted from pretrain_llama.sh in Megatron-LM + +#!/bin/bash + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CUDA_VISIBLE_DEVICES=0,2 +GPUS_PER_NODE=2 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=/workspace/dataset/checkpoint +TOKENIZER_PATH=/workspace/model/Llama-3.1-8B +rm -rf $CHECKPOINT_PATH/* +rm -rf ./wandb/* +VOCAB_FILE="/workspace/dataset/gpt2-vocab.json" +MERGE_FILE="/workspace/dataset/gpt2-merges.txt" +DATA_PATH="/workspace/dataset/BookCorpusDataset_text_document" + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +LLAMA_ARGS=" + --disable-bias-linear \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_PATH} \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 500000 \ + --use-rope-scaling \ + --use-rotary-position-embeddings \ + --swiglu \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 14336 \ + --num-attention-heads 32 \ + --max-position-embeddings 131072 \ + --seq-length 2048 \ + + --micro-batch-size 4 \ + --global-batch-size 8 \ + --lr 0.00015 \ + --train-iters 80 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 \ + --fp16 \ + --tensor-model-parallel-size $WORLD_SIZE \ + --seed 3407 \ +" +# llama3.1 70B +# LLAMA_ARGS=" +# --disable-bias-linear \ +# --tokenizer-type HuggingFaceTokenizer \ +# --tokenizer-model ${TOKENIZER_PATH} \ +# --transformer-impl local \ +# --normalization RMSNorm \ +# --group-query-attention \ +# --num-query-groups 8 \ +# --no-masked-softmax-fusion \ +# --attention-softmax-in-fp32 \ +# --attention-dropout 0.0 \ +# --hidden-dropout 0.0 \ +# --untie-embeddings-and-output-weights \ +# --position-embedding-type rope \ +# --rotary-percent 1.0 \ +# --rotary-base 500000 \ +# --use-rope-scaling \ +# --use-rotary-position-embeddings \ +# --swiglu \ +# --tensor-model-parallel-size 1 \ +# --pipeline-model-parallel-size 1 \ +# --num-layers 80 \ +# --hidden-size 8192 \ +# --ffn-hidden-size 28672 \ +# --num-attention-heads 64 \ +# --max-position-embeddings 131072 \ +# --seq-length 8192 \ + +# --micro-batch-size 2 \ +# --global-batch-size 8 \ +# --lr 0.00015 \ +# --train-iters 80 \ +# --lr-decay-iters 320000 \ +# --lr-decay-style cosine \ +# --min-lr 1.0e-5 \ +# --weight-decay 1e-2 \ +# --lr-warmup-fraction .01 \ +# --clip-grad 1.0 \ +# --no-gradient-accumulation-fusion \ +# --fp16 \ +# --tensor-model-parallel-size $WORLD_SIZE \ +# --seed 3407 \ +# --causal-lm +# " + +DATA_ARGS=" + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 1 +" + +cmd="deepspeed --num_gpus $WORLD_SIZE \ + pretrain_llama.py \ + $LLAMA_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS + " +echo $cmd +eval $cmd \ No newline at end of file
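
For reference, below is a minimal, dependency-free sketch of the non-interleaved RoPE path this patch wires in through `RotaryEmbedding` and `_apply_rotary_pos_emb_bshd`: frequencies built from `inv_freq`, duplicated across the even/odd halves, then combined with `_rotate_half`. The helper names (`rope_freqs`, `apply_rope`, `rotate_half`) are illustrative only and not part of the patch; the sketch assumes a `[seq, batch, head, dim]` layout, `rotary_percent == 1.0`, and no interleaving or Llama 3.1 scaling.

```python
# Standalone sketch (not part of the patch): the non-interleaved RoPE math used by
# rotary_pos_embedding.py / rope_utils.py, reimplemented with plain torch for illustration.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Last dimension becomes [-odd_half, +even_half], mirroring _rotate_half(..., rotary_interleaved=False).
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_freqs(seq_len: int, dim: int, base: int = 10000) -> torch.Tensor:
    # Inverse frequencies, as in RotaryEmbedding.__init__ (without rope_scaling).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    seq = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(seq, inv_freq)           # [seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)      # [seq, dim]: even components, then odd components
    return emb[:, None, None, :]                 # [seq, 1, 1, dim] so it broadcasts over batch and heads


def apply_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rot_dim = freqs.shape[-1]
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]   # t_pass is empty when rotary_percent == 1.0
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim=-1)


if __name__ == "__main__":
    seq, batch, heads, head_dim = 8, 2, 4, 16
    q = torch.randn(seq, batch, heads, head_dim)
    q_rot = apply_rope(q, rope_freqs(seq, head_dim))
    # RoPE is a pure per-pair rotation, so per-position vector norms are preserved.
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
    print(q_rot.shape)  # torch.Size([8, 2, 4, 16])
```

The norm-preservation assertion is a quick sanity check for any integration of this path: each pair of components is rotated by the same angle, so only relative positions, not magnitudes, are encoded.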