From 20d557031994e6c3fef7d0c3146ee80ecd9f8624 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Tue, 21 Mar 2023 14:32:35 +0000 Subject: [PATCH 1/9] feat: update submodule --- configs/local_setup.yml | 4 +- megatron/checkpointing.py | 1 + megatron/model/transformer.py | 168 ++++++++++++++++++++++++++- megatron/mpu/__init__.py | 1 + megatron/mpu/layers.py | 65 +++++++++++ megatron/neox_arguments/neox_args.py | 16 +++ megatron/training.py | 14 +++ 7 files changed, 266 insertions(+), 3 deletions(-) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index 99b3bdfd6..cdb96fb4d 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -26,5 +26,7 @@ "log-dir": "logs", "use_wandb": True, "wandb_host": "https://api.wandb.ai", - "wandb_project": "neox" + "wandb_project": "neox", + "num_gpus": 1, + "ia3_prompt_tuning": True } diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 9502f5b32..4af203763 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -241,6 +241,7 @@ def load_checkpoint( load_optimizer_states=load_optim_and_scheduler, load_lr_scheduler_states=load_optim_and_scheduler, tag=tag, + load_module_strict=False ) if checkpoint_name is None: diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e753a3532..55f046ce8 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -18,6 +18,7 @@ """Transformer.""" import math +import sys import torch import torch.nn.functional as F import torch.nn as nn @@ -93,7 +94,9 @@ def __init__( if self.activation_type == "geglu" else ff_mult * neox_args.hidden_size ) - self.dense_h_to_4h = mpu.ColumnParallelLinear( + mlp_column_parallel_cls = getattr(mpu, neox_args.mlp_column_parallel_cls) + + self.dense_h_to_4h = mlp_column_parallel_cls( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=ff_dim, @@ -590,6 +593,166 @@ def forward(self, hidden_states, attention_mask, layer_past=None): return output, bias +class ParallelSelfAttentionIA3(ParallelSelfAttention): + def __init__( + self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + parallel_output=False, + ): + super().__init__( + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=rpe, + rotary=rotary, + use_cache=use_cache, + parallel_output=parallel_output, + ) + self.l_k = self._create_ia3_parameter(neox_args) + self.l_v = self._create_ia3_parameter(neox_args) + + def _create_ia3_parameter(self, neox_args): + if neox_args.use_cpu_initialization: + param = torch.nn.Parameter( + torch.empty( + self.hidden_size_per_partition, dtype=neox_args.params_dtype + ) + ) + else: + param = torch.nn.Parameter( + torch.empty( + self.hidden_size_per_partition, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + param.model_parallel = True + param.partition_dim = 0 + #param.stride = stride + # Always initialize to ones. 
+ with torch.no_grad(): + torch.nn.init.ones_(param) + return param + + def forward(self, hidden_states, attention_mask, layer_past=None): + + # hidden_states: [sq, b, h] + + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim( + mixed_x_layer, 3 + ) + # Apply IA3 rescaling to keys & values: + def _apply_ia3_rescaling(layer, scale_vector): + layer_size = layer.shape + layer = layer.reshape(layer_size[0], layer_size[1], -1) + layer *= scale_vector + return layer.reshape(layer_size) + + key_layer = _apply_ia3_rescaling(key_layer, self.l_k) + value_layer = _apply_ia3_rescaling(value_layer, self.l_v) + + if exists(self.rotary_emb): + if exists(self.rotary_ndims): + # partial rotary + query_rot, query_pass = ( + query_layer[..., : self.rotary_ndims], + query_layer[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_layer[..., : self.rotary_ndims], + key_layer[..., self.rotary_ndims :], + ) + else: + # full rotary + query_rot, key_rot = query_layer, key_layer + apply_rotary_fn = ( + apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb + ) + + seq_len = key_layer.shape[0] + offset = 0 + if exists(layer_past) and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_rotary_fn( + query_rot, key_rot, cos, sin, offset=offset + ) + + if exists(self.rotary_ndims): + query_layer = torch.cat((query_layer, query_pass), dim=-1) + key_layer = torch.cat((key_layer, key_pass), dim=-1) + + # ================================== + # Cache key and value for inference + # ================================== + + if exists(layer_past) and layer_past.numel() > 0: + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) + + if self.use_cache: + present = torch.stack((key_layer, value_layer)) + + if self.use_flash_attention: + context_layer = self.flash_attention(query_layer, key_layer, value_layer) + elif not self.sparse: + context_layer = self.attention( + query_layer, key_layer, value_layer, layer_past, attention_mask + ) + else: + context_layer = self.sparse_attention( + query_layer, key_layer, value_layer, attention_mask + ) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) + context_layer = context_layer.view(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + if self.use_cache: + output = [output, present] + + return output, bias + + class ParallelTransformerLayer(nn.Module): """A single transformer layer. 
@@ -625,9 +788,10 @@ def __init__( if self.gpt_j_residual: self.reduce = mpu.mappings.reduce_from_model_parallel_region + self_attention_cls = getattr(sys.modules[__name__], neox_args.self_attention_cls) # Self attention. - self.attention = ParallelSelfAttention( + self.attention = self_attention_cls( neox_args=neox_args, attention_mask_func=attention_mask_func, init_method=init_method, diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 611d2adbf..60866b8cf 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -36,6 +36,7 @@ from .initialize import model_parallel_is_initialized from .layers import ColumnParallelLinear +from .layers import ColumnParallelLinearIA3 from .layers import RowParallelLinear from .layers import VocabParallelEmbedding from .layers import ParallelRelativePositionBias diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 92edbd6eb..bb8e01100 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -564,6 +564,71 @@ def forward(self, input_): return output, output_bias +class ColumnParallelLinearIA3(ColumnParallelLinear): + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + mup_rescale_parameters=False, + ): + super().__init__( + neox_args, + input_size, + output_size, + bias=bias, + gather_output=gather_output, + init_method=init_method, + stride=stride, + keep_master_weight_for_test=keep_master_weight_for_test, + skip_bias_add=skip_bias_add, + mup_rescale_parameters=mup_rescale_parameters + ) + if neox_args.use_cpu_initialization: + self.l_ff = Parameter( + torch.empty( + self.output_size_per_partition, dtype=neox_args.params_dtype + ) + ) + else: + self.l_ff = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + self.l_ff.model_parallel = True + self.l_ff.partition_dim = 0 + self.l_ff.stride = stride + # Always initialize l_ff to ones. + with torch.no_grad(): + torch.nn.init.ones_(self.l_ff) + + def forward(self, input_): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) + # Matrix multiply. + + bias = self.bias if not self.skip_bias_add else None + output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel *= self.l_ff # apply IA3 rescaling + if self.gather_output: + # All-gather across the partitions. + output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + class RowParallelLinear(torch.nn.Module): """Linear layer with row parallelism. 
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 95e6b6b8e..5aa823de8 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -355,11 +355,27 @@ class NeoXArgsModel(NeoXArgsTemplate): """ output_layer_parallelism: Literal["row", "column"] = "row" + ia3_prompt_tuning: bool = False + """ + Run IA3 prompt tuning based off: + Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning + https://arxiv.org/pdf/2205.05638.pdf + """ """ Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ + self_attention_cls: str = "ParallelSelfAttention" + """ + Default class to use for self attention + """ + + mlp_column_parallel_cls: str = "ColumnParallelLinear" + """ + Default class to use for linear column layer parallelism + """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): diff --git a/megatron/training.py b/megatron/training.py index 6ebbe780d..a2ed28d1a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -385,6 +385,10 @@ def get_model(neox_args, use_cache=False): # If mup isn't being used anyways, this has no effect. old_use_mup = neox_args.use_mup neox_args.use_mup = False + if neox_args.ia3_prompt_tuning: + neox_args.mlp_column_parallel_cls = "ColumnParallelLinearIA3" + neox_args.self_attention_cls = "ParallelSelfAttentionIA3" + model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, @@ -412,6 +416,16 @@ def get_model(neox_args, use_cache=False): for name, param in model.named_parameters(): if not "soft_embedding" in name: param.requires_grad = False + elif neox_args.ia3_prompt_tuning: + layers_to_train = ["l_ff", "l_k", "l_v"] + for name, param in model.named_parameters(): + if not any([x in name for x in layers_to_train]): + param.requires_grad = False + + trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + print(f"Number of trainable parameters: {trainable_params}") if not neox_args.is_pipe_parallel: # Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training From 3c403e08e9861355ecda557e20392da50538b507 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Tue, 21 Mar 2023 15:14:47 +0000 Subject: [PATCH 2/9] feat: add no weight decay parameters --- megatron/model/utils.py | 4 ++-- megatron/neox_arguments/neox_args.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 3e9940c1e..77a6ddb26 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -48,14 +48,14 @@ def get_params_for_weight_decay_optimization(module, neox_args): [ p for n, p in list(module_._parameters.items()) - if p is not None and n != "bias" + if p is not None and n not in neox_args.no_weight_decay_params ] ) no_weight_decay_params["params"].extend( [ p for n, p in list(module_._parameters.items()) - if p is not None and n == "bias" + if p is not None and n in neox_args.no_weight_decay_params ] ) if neox_args.weight_decay == 0.0: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 5aa823de8..a74e7e01d 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -13,7 +13,7 @@ # limitations under the License. 
import subprocess -from dataclasses import dataclass +from dataclasses import dataclass, field try: from .template import NeoXArgsTemplate @@ -376,6 +376,8 @@ class NeoXArgsModel(NeoXArgsTemplate): Default class to use for linear column layer parallelism """ + no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"]) + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): From 661c780c2f138124d0017ab110c20b040e4e4a29 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Tue, 21 Mar 2023 15:45:29 +0000 Subject: [PATCH 3/9] feat: add flag for checkpointing --- megatron/checkpointing.py | 4 +++- megatron/neox_arguments/neox_args.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 4af203763..d70880096 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -227,6 +227,8 @@ def load_checkpoint( ): """Load a model checkpoint and return the iteration.""" if neox_args.deepspeed: + if neox_args.ia3_prompt_tuning: + neox_args.load_module_strict = False load_optim_and_scheduler = ( not neox_args.no_load_optim ) # TODO: These should be configured by separate args @@ -241,7 +243,7 @@ def load_checkpoint( load_optimizer_states=load_optim_and_scheduler, load_lr_scheduler_states=load_optim_and_scheduler, tag=tag, - load_module_strict=False + load_module_strict=neox_args.load_module_strict ) if checkpoint_name is None: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index a74e7e01d..ba3f9d1c0 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -377,7 +377,14 @@ class NeoXArgsModel(NeoXArgsTemplate): """ no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"]) + """ + Which parameters we won't apply weight decay to + """ + load_module_strict: bool = True + """ + Whether to strictly enforce that the keys in state_dict of module & checkpoint match. 
+ """ @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): From e825de0f631130e9051a6d12bab054ed6af472d8 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Thu, 23 Mar 2023 11:53:19 +0000 Subject: [PATCH 4/9] feat: rename ia3_prompt_tuning -> ia3_tuning --- configs/local_setup.yml | 4 ++-- megatron/checkpointing.py | 2 +- megatron/neox_arguments/neox_args.py | 4 ++-- megatron/training.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index cdb96fb4d..4f44f6eac 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -27,6 +27,6 @@ "use_wandb": True, "wandb_host": "https://api.wandb.ai", "wandb_project": "neox", - "num_gpus": 1, - "ia3_prompt_tuning": True + #"num_gpus": 4, + "ia3_tuning": True } diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index d70880096..1661c962d 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -227,7 +227,7 @@ def load_checkpoint( ): """Load a model checkpoint and return the iteration.""" if neox_args.deepspeed: - if neox_args.ia3_prompt_tuning: + if neox_args.ia3_tuning: neox_args.load_module_strict = False load_optim_and_scheduler = ( not neox_args.no_load_optim diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index ba3f9d1c0..e0b81ec6e 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -355,9 +355,9 @@ class NeoXArgsModel(NeoXArgsTemplate): """ output_layer_parallelism: Literal["row", "column"] = "row" - ia3_prompt_tuning: bool = False + ia3_tuning: bool = False """ - Run IA3 prompt tuning based off: + Run IA3 tuning based off: Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning https://arxiv.org/pdf/2205.05638.pdf """ diff --git a/megatron/training.py b/megatron/training.py index a2ed28d1a..dfbeefb1b 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -385,7 +385,7 @@ def get_model(neox_args, use_cache=False): # If mup isn't being used anyways, this has no effect. old_use_mup = neox_args.use_mup neox_args.use_mup = False - if neox_args.ia3_prompt_tuning: + if neox_args.ia3_tuning: neox_args.mlp_column_parallel_cls = "ColumnParallelLinearIA3" neox_args.self_attention_cls = "ParallelSelfAttentionIA3" @@ -416,7 +416,7 @@ def get_model(neox_args, use_cache=False): for name, param in model.named_parameters(): if not "soft_embedding" in name: param.requires_grad = False - elif neox_args.ia3_prompt_tuning: + elif neox_args.ia3_tuning: layers_to_train = ["l_ff", "l_k", "l_v"] for name, param in model.named_parameters(): if not any([x in name for x in layers_to_train]): From 7d4b726ab98905085faacd60eaf7a721eabcd4c7 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Thu, 23 Mar 2023 11:59:28 +0000 Subject: [PATCH 5/9] feat: remove stride parameter --- megatron/model/transformer.py | 1 - megatron/mpu/layers.py | 1 - 2 files changed, 2 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 55f046ce8..542242429 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -637,7 +637,6 @@ def _create_ia3_parameter(self, neox_args): ) param.model_parallel = True param.partition_dim = 0 - #param.stride = stride # Always initialize to ones. 
         with torch.no_grad():
             torch.nn.init.ones_(param)
         return param
diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py
index bb8e01100..fdd361ec4 100644
--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -606,7 +606,6 @@ def __init__(
             )
         self.l_ff.model_parallel = True
         self.l_ff.partition_dim = 0
-        self.l_ff.stride = stride
         # Always initialize l_ff to ones.
         with torch.no_grad():
             torch.nn.init.ones_(self.l_ff)

From 65d64ad5598628e1dffb29e0be300aba30b6ddfc Mon Sep 17 00:00:00 2001
From: Mark Worrall
Date: Thu, 23 Mar 2023 13:36:52 +0000
Subject: [PATCH 6/9] feat: add comments to ia3 rescaling function

---
 megatron/model/.transformer.py.swp | Bin 0 -> 53248 bytes
 megatron/model/transformer.py      |   9 ++++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)
 create mode 100644 megatron/model/.transformer.py.swp

diff --git a/megatron/model/.transformer.py.swp b/megatron/model/.transformer.py.swp
new file mode 100644
index 0000000000000000000000000000000000000000..3fb571738ba94cbfdadbecd48a703fd139cbfc01
GIT binary patch
literal 53248
[53248 bytes of base85-encoded binary data (vim swap file) omitted]

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 542242429..ccafea613 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -664,8 +664,15 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
         (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
             mixed_x_layer, 3
         )
-        # Apply IA3 rescaling to keys & values:
+
         def _apply_ia3_rescaling(layer, scale_vector):
+            """Apply IA3 rescaling:
+
+            Reshapes: [sq, b, np, hn] -> [sq, b, np * hn] to perform
+            rescaling and then back to [sq, b, np, hn].
+ + Note: np * hn == h/p == self.hidden_size_per_partition + """ layer_size = layer.shape layer = layer.reshape(layer_size[0], layer_size[1], -1) layer *= scale_vector From 221b6dd30e1668aef04fab8287b6195a2ac6abab Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Thu, 23 Mar 2023 14:14:25 +0000 Subject: [PATCH 7/9] feat: move IA3 scaling after MLP non-linearity --- megatron/model/transformer.py | 116 ++++++++++++++++++++------- megatron/mpu/__init__.py | 1 - megatron/mpu/layers.py | 64 --------------- megatron/neox_arguments/neox_args.py | 4 +- megatron/training.py | 2 +- 5 files changed, 90 insertions(+), 97 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index ccafea613..1e8531693 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -89,22 +89,21 @@ def __init__( # auto scale so geglu has equal parameters ff_mult = 4 * 2 / 3 if self.activation_type == "geglu" else 4 - ff_dim = ( + self.ff_dim = ( int(ff_mult * neox_args.hidden_size) * 2 if self.activation_type == "geglu" else ff_mult * neox_args.hidden_size ) - mlp_column_parallel_cls = getattr(mpu, neox_args.mlp_column_parallel_cls) - self.dense_h_to_4h = mlp_column_parallel_cls( + self.dense_h_to_4h = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, - output_size=ff_dim, + output_size=self.ff_dim, gather_output=False, init_method=init_method, skip_bias_add=True, ) - ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim + ff_dim_in = self.ff_dim // 2 if self.activation_type == "geglu" else self.ff_dim # Project back to h. self.dense_4h_to_h = mpu.RowParallelLinear( neox_args=neox_args, @@ -137,6 +136,56 @@ def forward(self, hidden_states): return output, output_bias +class ParallelMLPIA3(ParallelMLP): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. 
+ + Applies IA3 rescaling of each column after non-linearity: + https://arxiv.org/pdf/2205.05638.pdf + """ + + def __init__( + self, neox_args, init_method, output_layer_init_method, parallel_output=False + ): + super().__init__( + neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=parallel_output + ) + + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(self.ff_dim, world_size) # 4hp + self.l_ff = create_ia3_parameter(self.hidden_size_per_partition, neox_args) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if ( + self.activation_type == "gelu" and self.bias_gelu_fusion + ) or self.activation_type == "geglu": + intermediate_parallel = self.activation_func( + intermediate_parallel, bias_parallel + ) + else: + intermediate_parallel = self.activation_func( + intermediate_parallel + bias_parallel + ) + + # Apply IA3 rescaling: + intermediate_parallel *= self.l_ff + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + + class ParallelLinear(nn.Module): """ A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size @@ -594,6 +643,9 @@ def forward(self, hidden_states, attention_mask, layer_past=None): class ParallelSelfAttentionIA3(ParallelSelfAttention): + """Applies IA3 rescaling to key and query vectors per: + https://arxiv.org/pdf/2205.05638.pdf + """ def __init__( self, neox_args, @@ -617,30 +669,9 @@ def __init__( use_cache=use_cache, parallel_output=parallel_output, ) - self.l_k = self._create_ia3_parameter(neox_args) - self.l_v = self._create_ia3_parameter(neox_args) + self.l_k = create_ia3_parameter(self.hidden_size_per_partition, neox_args) + self.l_v = create_ia3_parameter(self.hidden_size_per_partition, neox_args) - def _create_ia3_parameter(self, neox_args): - if neox_args.use_cpu_initialization: - param = torch.nn.Parameter( - torch.empty( - self.hidden_size_per_partition, dtype=neox_args.params_dtype - ) - ) - else: - param = torch.nn.Parameter( - torch.empty( - self.hidden_size_per_partition, - device=torch.cuda.current_device(), - dtype=neox_args.params_dtype, - ) - ) - param.model_parallel = True - param.partition_dim = 0 - # Always initialize to ones. 
- with torch.no_grad(): - torch.nn.init.ones_(param) - return param def forward(self, hidden_states, attention_mask, layer_past=None): @@ -815,7 +846,8 @@ def __init__( self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) # MLP - self.mlp = ParallelMLP( + parallel_mlp_cls = getattr(sys.modules[__name__], neox_args.parallel_mlp_cls) + self.mlp = parallel_mlp_cls( neox_args=neox_args, init_method=init_method, output_layer_init_method=output_layer_init_method, @@ -974,3 +1006,29 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non return logits_parallel return mpu.gather_from_model_parallel_region(logits_parallel) + + +def create_ia3_parameter(param_size, neox_args): + """Create a parameter vector for use in IA3 scaling, per: + https://arxiv.org/pdf/2205.05638.pdf + """ + if neox_args.use_cpu_initialization: + param = torch.nn.Parameter( + torch.empty( + param_size, dtype=neox_args.params_dtype + ) + ) + else: + param = torch.nn.Parameter( + torch.empty( + param_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + param.model_parallel = True + param.partition_dim = 0 + # Always initialize to ones. + with torch.no_grad(): + torch.nn.init.ones_(param) + return param diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 60866b8cf..611d2adbf 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -36,7 +36,6 @@ from .initialize import model_parallel_is_initialized from .layers import ColumnParallelLinear -from .layers import ColumnParallelLinearIA3 from .layers import RowParallelLinear from .layers import VocabParallelEmbedding from .layers import ParallelRelativePositionBias diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index fdd361ec4..92edbd6eb 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -564,70 +564,6 @@ def forward(self, input_): return output, output_bias -class ColumnParallelLinearIA3(ColumnParallelLinear): - def __init__( - self, - neox_args, - input_size, - output_size, - bias=True, - gather_output=True, - init_method=init.xavier_normal_, - stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - mup_rescale_parameters=False, - ): - super().__init__( - neox_args, - input_size, - output_size, - bias=bias, - gather_output=gather_output, - init_method=init_method, - stride=stride, - keep_master_weight_for_test=keep_master_weight_for_test, - skip_bias_add=skip_bias_add, - mup_rescale_parameters=mup_rescale_parameters - ) - if neox_args.use_cpu_initialization: - self.l_ff = Parameter( - torch.empty( - self.output_size_per_partition, dtype=neox_args.params_dtype - ) - ) - else: - self.l_ff = Parameter( - torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=neox_args.params_dtype, - ) - ) - self.l_ff.model_parallel = True - self.l_ff.partition_dim = 0 - # Always initialize l_ff to ones. - with torch.no_grad(): - torch.nn.init.ones_(self.l_ff) - - def forward(self, input_): - if self.use_mup and self.mup_rescale_parameters: - input_ /= self.width_mult() - # Set up backprop all-reduce. - input_parallel = copy_to_model_parallel_region(input_) - # Matrix multiply. - - bias = self.bias if not self.skip_bias_add else None - output_parallel = F.linear(input_parallel, self.weight, bias) - output_parallel *= self.l_ff # apply IA3 rescaling - if self.gather_output: - # All-gather across the partitions. 
- output = gather_from_model_parallel_region(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - class RowParallelLinear(torch.nn.Module): """Linear layer with row parallelism. diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index e0b81ec6e..65d5c3788 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -371,9 +371,9 @@ class NeoXArgsModel(NeoXArgsTemplate): Default class to use for self attention """ - mlp_column_parallel_cls: str = "ColumnParallelLinear" + parallel_mlp_cls: str = "ParallelMLP" """ - Default class to use for linear column layer parallelism + Default class to use for linear MLP parallelism """ no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"]) diff --git a/megatron/training.py b/megatron/training.py index dfbeefb1b..8d9b32203 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -386,7 +386,7 @@ def get_model(neox_args, use_cache=False): old_use_mup = neox_args.use_mup neox_args.use_mup = False if neox_args.ia3_tuning: - neox_args.mlp_column_parallel_cls = "ColumnParallelLinearIA3" + neox_args.parallel_mlp_cls = "ParallelMLPIA3" neox_args.self_attention_cls = "ParallelSelfAttentionIA3" model = GPT2ModelPipe( From 467231bd708a07930f0f44e870e97a786d3096e1 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Thu, 23 Mar 2023 14:44:40 +0000 Subject: [PATCH 8/9] feat: update defaults --- configs/local_setup.yml | 3 +-- megatron/training.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/local_setup.yml b/configs/local_setup.yml index 4f44f6eac..b17d5f024 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -27,6 +27,5 @@ "use_wandb": True, "wandb_host": "https://api.wandb.ai", "wandb_project": "neox", - #"num_gpus": 4, - "ia3_tuning": True + "ia3_tuning": False } diff --git a/megatron/training.py b/megatron/training.py index 8d9b32203..26494520c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -425,6 +425,7 @@ def get_model(neox_args, use_cache=False): trainable_params = sum( p.numel() for p in model.parameters() if p.requires_grad ) + # Note: only for current partition print(f"Number of trainable parameters: {trainable_params}") if not neox_args.is_pipe_parallel: From ceeda69aaf05d1db8dfa3d12a62b0b93e1149182 Mon Sep 17 00:00:00 2001 From: Mark Worrall Date: Thu, 23 Mar 2023 14:49:08 +0000 Subject: [PATCH 9/9] feat: update comment --- megatron/training.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index 26494520c..5cf852608 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -425,8 +425,7 @@ def get_model(neox_args, use_cache=False): trainable_params = sum( p.numel() for p in model.parameters() if p.requires_grad ) - # Note: only for current partition - print(f"Number of trainable parameters: {trainable_params}") + print(f"Number of trainable parameters (current partition): {trainable_params}") if not neox_args.is_pipe_parallel: # Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training
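
Note (not part of any patch above): the series adds three learned vectors per transformer layer -- l_k and l_v, multiplied into the attention keys and values in ParallelSelfAttentionIA3, and l_ff, multiplied into the MLP hidden activations after the non-linearity in ParallelMLPIA3 -- each initialized to ones so the pretrained model is unchanged until tuning starts. A minimal, framework-free sketch of that rescaling follows; the class names IA3Attention and IA3MLP are illustrative only, and model parallelism, multi-head splitting, rotary embeddings and dropout are omitted.

# Illustrative sketch only -- not part of the patches above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class IA3Attention(nn.Module):
    """Toy single-head attention with IA3 rescaling of keys and values."""

    def __init__(self, hidden_size):
        super().__init__()
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        # The only new trainable parameters; ones => identity until tuned.
        self.l_k = nn.Parameter(torch.ones(hidden_size))
        self.l_v = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):  # x: [batch, seq, hidden]
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        k = k * self.l_k  # elementwise rescaling of keys
        v = v * self.l_v  # elementwise rescaling of values
        scores = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)
        return torch.softmax(scores, dim=-1) @ v


class IA3MLP(nn.Module):
    """Toy MLP with l_ff applied after the non-linearity, as in ParallelMLPIA3."""

    def __init__(self, hidden_size, ff_size):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(hidden_size, ff_size)
        self.dense_4h_to_h = nn.Linear(ff_size, hidden_size)
        self.l_ff = nn.Parameter(torch.ones(ff_size))

    def forward(self, x):
        h = F.gelu(self.dense_h_to_4h(x))
        h = h * self.l_ff  # rescale each hidden channel
        return self.dense_4h_to_h(h)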
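
On the training side, get_model freezes every parameter whose name does not contain l_ff, l_k or l_v when ia3_tuning is set, and no_weight_decay_params keeps those vectors (along with biases) out of the weight-decay group. A rough standalone equivalent for a generic torch module is sketched below; the helper name mark_ia3_trainable is made up for this sketch and is not NeoX API.

# Illustrative sketch only -- helper name is an assumption.
import torch.nn as nn


def mark_ia3_trainable(model: nn.Module, trainable_substrings=("l_ff", "l_k", "l_v")):
    """Freeze everything except the IA3 scaling vectors and report the count."""
    for name, param in model.named_parameters():
        param.requires_grad = any(s in name for s in trainable_substrings)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Number of trainable parameters: {trainable} / {total}")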