@@ -1,6 +1,5 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -10,7 +9,6 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -39,13 +37,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
         return contextualized_states
 
 
-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class MambaDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -252,7 +212,6 @@ def __init__(self,
         self.config = config
         self.mixer = MambaMixer(config, layer_idx)
 
-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ def forward(
 
         hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
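With `feed_forward` and its pre-norm gone, the layer's forward path reduces to a single norm-then-mixer residual block. A standalone sketch of that pre-norm pattern, using a toy linear stand-in for `MambaMixer` (not vLLM's implementation; `nn.RMSNorm` requires PyTorch >= 2.4):

```python
import torch
import torch.nn as nn

class ToyMambaLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.RMSNorm(hidden_size)               # PyTorch >= 2.4
        self.mixer = nn.Linear(hidden_size, hidden_size)  # stand-in for MambaMixer

    def forward(self, hidden_states, residual=None):
        # Add-norm convention: the returned residual is the pre-normalization
        # sum, which the next layer folds back in.
        if residual is None:
            residual = hidden_states
        else:
            residual = residual + hidden_states
        hidden_states = self.norm(residual)
        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual

layer = ToyMambaLayer(16)
x = torch.randn(2, 3, 16)
h, r = layer(x)       # first layer: residual starts as the input
h, r = layer(h, r)    # subsequent layers carry the residual alongside
```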
@@ -319,7 +274,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
@@ -346,26 +300,6 @@ def forward(
 
 
 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
@@ -416,8 +350,8 @@ def forward(self,
         mamba_cache_tensors = self.mamba_cache.current_run_tensors(
             input_ids, attn_metadata, **kwargs)
 
-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_tensors[0],
                                       mamba_cache_tensors[1])
 
         return hidden_states
@@ -457,43 +391,16 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if "A_log" in name:
                 name = name.replace("A_log", "A")
 
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
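The simplified `load_weights` above no longer needs attention-era name remapping: it renames `A_log` to `A`, skips stray GPTQ bias tensors, and hands every weight to the parameter's own `weight_loader`, falling back to a plain copy. A self-contained sketch of that dispatch pattern, using a bare `nn.Linear` and a fake checkpoint rather than vLLM's loaders (all names here are illustrative):

```python
import torch
import torch.nn as nn

def default_weight_loader(param, loaded_weight):
    # Mirrors the fallback behavior: a straight copy into the parameter.
    assert param.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)

model = nn.Linear(4, 4)
params_dict = dict(model.named_parameters())
checkpoint = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

for name, loaded_weight in checkpoint.items():
    # Skip extra biases that have no matching parameter (GPTQ checkpoints).
    if name.endswith(".bias") and name not in params_dict:
        continue
    param = params_dict[name]
    # Parameters may carry a custom shard-aware loader; fall back to a copy.
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)
```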