# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# This file is adapted from language_model.py in Megatron-LM

+from typing import Literal, Optional
+
import torch
from torch import einsum, nn
from domino.arguments import get_args
from domino.tensor_parallel.partition import _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes
from domino.tensor_parallel.partition import ColumnParallelLinear, RowParallelLinearNoComm

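+# RotaryEmbedding comes from Megatron-core; get_norm (Megatron-LM) builds the configured
+# normalization layer (e.g. LayerNorm or RMSNorm).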
+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.model.utils import get_norm
+
from deepspeed.runtime.domino.transformer import DominoTransformer


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -45,12 +50,18 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(config, num_tokentypes,
                       encoder_attn_mask_type,
                       pre_process=True, post_process=True):
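+    # Rotary/position-embedding options are read from the global args and forwarded to the model.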
+    args = get_args()
    language_model = TransformerLanguageModel(
        config,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        pre_process=pre_process,
-        post_process=post_process
+        post_process=post_process,
+        position_embedding_type=args.position_embedding_type,
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base,
+        rope_scaling=args.use_rope_scaling,
+        seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
    )

    return language_model
@@ -85,37 +96,18 @@ def forward(self, input_ids, position_ids):
        return combined_embeds


-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim, seq_len_interpolation_factor=None):
-        super().__init__()
-        self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq, persistent=False)
-
-    def forward(self, max_seq_len, offset=0):
-        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
-        if self.seq_len_interpolation_factor is not None:
-            seq = seq.type_as(self.inv_freq)
-            seq *= 1 / self.seq_len_interpolation_factor
-        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
-        # first part even vector components, second part odd vector components,
-        # 2 * dim in dimension size
-        emb = torch.cat((freqs, freqs), dim=-1)
-        # emb [seq_length, .., dim]
-        return emb[:, None, None, :]
-
-    # def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-    #     state_dict.pop(f'{prefix}inv_freq', None)
-    #     return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-
class TransformerLanguageModel(DominoModule):
    def __init__(self,
                 config,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 pre_process=True,
-                 post_process=True):
+                 post_process=True,
+                 position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
+                 rotary_percent: float = 1.0,
+                 rotary_base: int = 10000,
+                 rope_scaling: bool = False,
+                 seq_len_interpolation_factor: Optional[float] = None,):

        args = get_args()
        super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=True)
@@ -127,6 +119,11 @@ def __init__(self,
        self.init_method = config.init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.encoder_hidden_state = None
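+        # Record the rotary/positional-embedding settings (populated from the args in get_language_model).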
+        self.position_embedding_type = position_embedding_type
+        self.rotary_percent = rotary_percent
+        self.rotary_base = rotary_base
+        self.rotary_scaling = rope_scaling
+        self.seq_length = config.seq_length

        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
@@ -138,19 +135,18 @@ def __init__(self,
        self.use_rotary_position_embeddings = \
            args.position_embedding_type == 'rope'
        if self.use_rotary_position_embeddings:
-            self.seq_length = args.seq_length
-            rotary_dim = args.hidden_size // args.num_attention_heads \
-                if args.kv_channels is None else args.kv_channels
-            if args.rotary_percent < 1.0:
-                rotary_dim = int(rotary_dim * args.rotary_percent)
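+            # Megatron-core's RotaryEmbedding derives the rotary dimension from
+            # kv_channels and rotary_percent itself.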
            self.rotary_pos_emb = RotaryEmbedding(
-                rotary_dim,
-                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
+                kv_channels=config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                rotary_base=rotary_base,
+                rope_scaling=rope_scaling,
            )

        self.encoder = DominoTransformer(
            config, ModelType.encoder_or_decoder, mpu,
-            fused_layer_norm, _initialize_affine_weight_gpu,
+            get_norm, _initialize_affine_weight_gpu,
            ColumnParallelLinear, RowParallelLinearNoComm, apply_rotary_pos_emb,
            bias_dropout_add_fused_train, bias_dropout_add_fused_inference,
            self_attn_mask_type=self.encoder_attn_mask_type,