@@ -23,23 +23,39 @@ class ModelArgs:
23
23
block_size : int = 2048
24
24
vocab_size : int = 32000
25
25
n_layer : int = 32
26
- n_head : int = 32
26
+ # n_head in gpt-fast
27
+ n_heads : int = 32
27
28
dim : int = 4096
28
- intermediate_size : int = None
29
+ # hidden dim is intermediate_size in gpt-fast
30
+ hidden_dim : int = None
29
31
n_local_heads : int = - 1
30
32
head_dim : int = 64
31
33
rope_base : float = 10000
32
34
norm_eps : float = 1e-5
33
-
35
+ multiple_of : int = 256
36
+ ffn_dim_multiplier : float = None
37
+
34
38
def __post_init__ (self ):
35
39
if self .n_local_heads == - 1 :
36
- self .n_local_heads = self .n_head
37
- if self .intermediate_size is None :
38
- hidden_dim = 4 * self .dim
39
- n_hidden = int (2 * hidden_dim / 3 )
40
- self .intermediate_size = find_multiple (n_hidden , 256 )
41
- self .head_dim = self .dim // self .n_head
40
+ self .n_local_heads = self .n_heads
41
+ if self .hidden_dim is None :
42
+ # If hidden_dim is not explicitly set in the ModelArgs,
43
+ # then calculate implicitly based on dim and
44
+ # also multiple of `args.multiple_of`
45
+ multiple_of = self .multiple_of
46
+ hidden_dim = 4 * self .dim
47
+ hidden_dim = int (2 * hidden_dim / 3 )
48
+ if self .ffn_dim_multiplier is not None :
49
+ hidden_dim = int (self .ffn_dim_multiplier * hidden_dim )
50
+ self .hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1 ) // multiple_of )
51
+ self .head_dim = self .dim // self .n_heads
42
52
53
+ @classmethod
54
+ def from_params (cls , params_path : str ):
55
+ with open (params_path , "r" ) as f :
56
+ params = json .loads (f .read ())
57
+ return cls (** params )
58
+
43
59
@classmethod
44
60
def from_name (cls , name : str ):
45
61
print (f"name { name } " )
@@ -69,31 +85,31 @@ def from_name(cls, name: str):
69
85
"CodeLlama-7b-Python-hf" : dict (
70
86
block_size = 16384 , vocab_size = 32000 , n_layer = 32 , dim = 4096 , rope_base = 1000000
71
87
),
72
- "7B" : dict (n_layer = 32 , n_head = 32 , dim = 4096 ),
73
- "13B" : dict (n_layer = 40 , n_head = 40 , dim = 5120 ),
74
- "30B" : dict (n_layer = 60 , n_head = 52 , dim = 6656 ),
88
+ "7B" : dict (n_layer = 32 , n_heads = 32 , dim = 4096 ),
89
+ "13B" : dict (n_layer = 40 , n_heads = 40 , dim = 5120 ),
90
+ "30B" : dict (n_layer = 60 , n_heads = 52 , dim = 6656 ),
75
91
"34B" : dict (
76
92
n_layer = 48 ,
77
- n_head = 64 ,
93
+ n_heads = 64 ,
78
94
dim = 8192 ,
79
95
vocab_size = 32000 ,
80
96
n_local_heads = 8 ,
81
- intermediate_size = 22016 ,
97
+ hidden_dim = 22016 ,
82
98
rope_base = 1000000 ,
83
99
), # CodeLlama-34B-Python-hf
84
100
"70B" : dict (
85
- n_layer = 80 , n_head = 64 , dim = 8192 , n_local_heads = 8 , intermediate_size = 28672
101
+ n_layer = 80 , n_heads = 64 , dim = 8192 , n_local_heads = 8 , hidden_dim = 28672
86
102
),
87
103
"Mistral-7B" : dict (
88
104
n_layer = 32 ,
89
- n_head = 32 ,
105
+ n_heads = 32 ,
90
106
n_local_heads = 8 ,
91
107
dim = 4096 ,
92
- intermediate_size = 14336 ,
108
+ hidden_dim = 14336 ,
93
109
vocab_size = 32000 ,
94
110
),
95
- "stories15M" : dict (n_layer = 6 , n_head = 6 , dim = 288 ),
96
- "stories110M" : dict (n_layer = 12 , n_head = 12 , dim = 768 ),
111
+ "stories15M" : dict (n_layer = 6 , n_heads = 6 , dim = 288 ),
112
+ "stories110M" : dict (n_layer = 12 , n_heads = 12 , dim = 768 ),
97
113
}
98
114
99
115
@@ -140,7 +156,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
140
156
and self .max_batch_size >= max_batch_size
141
157
):
142
158
return
143
- head_dim = self .config .dim // self .config .n_head
159
+ head_dim = self .config .dim // self .config .n_heads
144
160
max_seq_length = find_multiple (max_seq_length , 8 )
145
161
self .max_seq_length = max_seq_length
146
162
self .max_batch_size = max_batch_size
@@ -150,8 +166,8 @@ def setup_caches(self, max_batch_size, max_seq_length):
150
166
)
151
167
152
168
freqs_cis = precompute_freqs_cis (
153
- self .config .block_size ,
154
- self .config .dim // self . config . n_head ,
169
+ self .config .dim // self . config . n_heads ,
170
+ self .config .block_size * 2 ,
155
171
self .config .rope_base ,
156
172
)
157
173
self .register_buffer ("freqs_cis" , freqs_cis , persistent = True )
@@ -182,6 +198,10 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
182
198
def from_name (cls , name : str ):
183
199
return cls (ModelArgs .from_name (name ))
184
200
201
+ @classmethod
202
+ def from_params (cls , params_path : str ):
203
+ return cls (ModelArgs .from_params (params_path ))
204
+
185
205
186
206
class TransformerBlock (nn .Module ):
187
207
def __init__ (self , config : ModelArgs ) -> None :
@@ -202,19 +222,19 @@ def forward(
202
222
class Attention (nn .Module ):
203
223
def __init__ (self , config : ModelArgs ):
204
224
super ().__init__ ()
205
- assert config .dim % config .n_head == 0
225
+ assert config .dim % config .n_heads == 0
206
226
207
227
# key, query, value projections for all heads, but in a batch
208
- # total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
228
+ # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
209
229
# self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
210
- self .wq = nn .Linear (config .dim , config .n_head * config .head_dim , bias = False )
230
+ self .wq = nn .Linear (config .dim , config .n_heads * config .head_dim , bias = False )
211
231
self .wk = nn .Linear (config .dim , config .n_local_heads * config .head_dim , bias = False )
212
232
self .wv = nn .Linear (config .dim , config .n_local_heads * config .head_dim , bias = False )
213
233
214
234
self .wo = nn .Linear (config .dim , config .dim , bias = False )
215
235
self .kv_cache = None
216
236
217
- self .n_head = config .n_head
237
+ self .n_heads = config .n_heads
218
238
self .head_dim = config .head_dim
219
239
self .n_local_heads = config .n_local_heads
220
240
self .dim = config .dim
@@ -243,7 +263,7 @@ def forward(
243
263
# kv_size = self.n_local_heads * self.head_dim
244
264
# q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
245
265
246
- q = q .view (bsz , seqlen , self .n_head , self .head_dim )
266
+ q = q .view (bsz , seqlen , self .n_heads , self .head_dim )
247
267
k = k .view (bsz , seqlen , self .n_local_heads , self .head_dim )
248
268
v = v .view (bsz , seqlen , self .n_local_heads , self .head_dim )
249
269
@@ -255,8 +275,8 @@ def forward(
255
275
if self .kv_cache is not None :
256
276
k , v = self .kv_cache .update (input_pos , k , v )
257
277
258
- k = k .repeat_interleave (self .n_head // self .n_local_heads , dim = 1 )
259
- v = v .repeat_interleave (self .n_head // self .n_local_heads , dim = 1 )
278
+ k = k .repeat_interleave (self .n_heads // self .n_local_heads , dim = 1 )
279
+ v = v .repeat_interleave (self .n_heads // self .n_local_heads , dim = 1 )
260
280
y = F .scaled_dot_product_attention (q , k , v , attn_mask = mask , dropout_p = 0.0 )
261
281
262
282
y = y .transpose (1 , 2 ).contiguous ().view (bsz , seqlen , self .dim )
@@ -268,9 +288,9 @@ def forward(
268
288
class FeedForward (nn .Module ):
269
289
def __init__ (self , config : ModelArgs ) -> None :
270
290
super ().__init__ ()
271
- self .w1 = nn .Linear (config .dim , config .intermediate_size , bias = False )
272
- self .w3 = nn .Linear (config .dim , config .intermediate_size , bias = False )
273
- self .w2 = nn .Linear (config .intermediate_size , config .dim , bias = False )
291
+ self .w1 = nn .Linear (config .dim , config .hidden_dim , bias = False )
292
+ self .w2 = nn .Linear (config .hidden_dim , config .dim , bias = False )
293
+ self .w3 = nn .Linear (config .dim , config .hidden_dim , bias = False )
274
294
275
295
def forward (self , x : Tensor ) -> Tensor :
276
296
return self .w2 (F .silu (self .w1 (x )) * self .w3 (x ))
@@ -289,8 +309,8 @@ def forward(self, x: Tensor) -> Tensor:
289
309
output = self ._norm (x .float ()).type_as (x )
290
310
return output * self .weight
291
311
292
-
293
- def precompute_freqs_cis (seq_len : int , n_elem : int , base : int = 10000 ) -> Tensor :
312
+ # transposed first two arguments to align with model in ET
313
+ def precompute_freqs_cis (n_elem : int , seq_len : int , base : int = 10000 ) -> Tensor :
294
314
freqs = 1.0 / (
295
315
base ** (torch .arange (0 , n_elem , 2 )[: (n_elem // 2 )].float () / n_elem )
296
316
)
0 commit comments