# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
+
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader)
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip


class LlamaMLP(nn.Module):
-
    def __init__(
        self,
        hidden_size: int,
@@ -60,16 +69,22 @@ def __init__(
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
            bias=False,
-            linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+            linear_method=linear_method,
+        )
        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
@@ -80,7 +95,6 @@ def forward(self, x):


class LlamaAttention(nn.Module):
-
    def __init__(
        self,
        hidden_size: int,
@@ -147,11 +161,13 @@ def __init__(
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=sliding_window,
+        )

    def forward(
        self,
@@ -163,14 +179,12 @@ def forward(
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
-
    def __init__(
        self,
        config: LlamaConfig,
@@ -180,18 +194,21 @@ def __init__(
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        max_position_embeddings = getattr(
+            config, "max_position_embeddings", 8192
+        )
        sliding_window = getattr(config, "sliding_window", None)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
-            config, "bias", False)
+            config, "bias", False
+        )
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
-            num_kv_heads=getattr(config, "num_key_value_heads",
-                                 config.num_attention_heads),
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
@@ -205,10 +222,12 @@ def __init__(
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

    def forward(
        self,
@@ -224,7 +243,8 @@ def forward(
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+                hidden_states, residual
+            )
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
@@ -234,13 +254,13 @@ def forward(

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+            hidden_states, residual
+        )
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
-
    def __init__(
        self,
        config: LlamaConfig,
@@ -250,19 +270,24 @@ def __init__(
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
-        self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, linear_method)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(config, linear_method)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -316,11 +341,8 @@ class LlamaForCausalLM(nn.Module):
        "embed_tokens",
        "lm_head",
    ]
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
+    embedding_modules = {}
+    embedding_padding_modules = []

    def __init__(
        self,
@@ -342,12 +364,14 @@ def __init__(
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
+            if not lora_config
+            else lora_config.lora_vocab_padding_size,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+        self.logits_processor = LogitsProcessor(
+            self.unpadded_vocab_size, config.vocab_size, logit_scale
+        )
        self.sampler = Sampler()

    def forward(
@@ -357,14 +381,17 @@ def forward(
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, attn_metadata
+        )
        return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(
+        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
+    ) -> torch.Tensor:
+        logits = self.logits_processor(
+            self.lm_head.weight, hidden_states, sampling_metadata
+        )
        return logits

    def sample(
@@ -388,12 +415,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+            if (
+                "rotary_emb.cos_cached" in name
+                or "rotary_emb.sin_cached" in name
+            ):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
@@ -409,8 +438,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
@@ -420,9 +450,12 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
@@ -434,5 +467,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.kv_scale = scaling_factor
            else:
-                raise RuntimeError("Self attention has no KV cache scaling "
-                                   "factor attribute!")
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling "
+                    "factor attribute!"
+                )