Commit 169b530

[Bugfix] Clean up some cruft in mamba.py (vllm-project#9343)
1 parent f0fe4fe commit 169b530

2 files changed: +11, -104 lines changed

docs/source/models/supported_models.rst (+1, -1)
@@ -155,7 +155,7 @@ Text Generation
   * - :code:`MambaForCausalLM`
     - Mamba
     - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-    - ✅︎
+    -
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
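
Note: this clears the LoRA checkmark (✅︎) from the MambaForCausalLM row of the supported-models table. The checkmark appears to have been cruft: as the mamba.py hunks below show, the model's LoRA attributes referenced submodules that do not exist in Mamba.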

vllm/model_executor/models/mamba.py (+10, -103)
@@ -1,6 +1,5 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple

 import torch
@@ -10,7 +9,6 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -39,13 +37,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -209,37 +200,6 @@ def forward(self, hidden_states: torch.Tensor,
         return contextualized_states


-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class MambaDecoderLayer(nn.Module):

     def __init__(self,
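
Removing MambaMLP (and the SiluAndMul import above) is arguably the bugfix part of this commit, not just clutter removal: the reference Mamba architecture has no interleaved Transformer-style MLP, and state-spaces checkpoints presumably carry no weights for one, so the extra feed-forward block (deleted from the decoder layer in the next two hunks) would have run with uninitialized projections. It looks like a leftover from an attention-plus-MLP model this file was adapted from.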
@@ -252,7 +212,6 @@ def __init__(self,
         self.config = config
         self.mixer = MambaMixer(config, layer_idx)

-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.layer_norm_epsilon)
@@ -274,10 +233,6 @@ def forward(

         hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


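With the feed-forward path gone, the decoder layer reduces to norm → mixer. A minimal sketch of the resulting forward; the signature and the residual handling ahead of the mixer call are inferred from vLLM's usual pre-norm pattern, not shown in this diff:

    def forward(self, hidden_states, attn_metadata, residual, conv_state,
                ssm_state):
        # Sketch: norm/residual handling inferred, not shown in the diff.
        # vLLM's RMSNorm fuses the residual add when one is passed in.
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)
        # The mixer (conv1d + selective SSM + gating) is the whole layer now.
        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
                                   ssm_state)
        return hidden_states, residual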
@@ -319,7 +274,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         conv_state: torch.Tensor,
         ssm_state: torch.Tensor,
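
MambaForCausalLM is declared IsAttentionFree, so the backbone has no attention layers and nothing to read from or write to kv_caches; the parameter was vestigial. The call site is updated to match two hunks below.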
@@ -346,26 +300,6 @@ def forward(


 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]

     def __init__(
         self,
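
None of these class attributes did anything here: packed_modules_mapping and supported_lora_modules name qkv_proj/o_proj submodules that do not exist in this model (the mixer's projections are in_proj, x_proj, dt_proj, and out_proj), so they appear to be copy-paste from an attention-based model. Their removal is what the docs change above reflects.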
@@ -416,8 +350,8 @@ def forward(self,
         mamba_cache_tensors = self.mamba_cache.current_run_tensors(
             input_ids, attn_metadata, **kwargs)

-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_tensors[0],
                                       mamba_cache_tensors[1])

         return hidden_states
@@ -457,43 +391,16 @@ def sample(
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if "A_log" in name:
                 name = name.replace("A_log", "A")

-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
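
With the attention-specific stacking and renaming gone, the loader reduces to the straight-through version below, assembled from this hunk's added lines. The only remaining checkpoint-to-vLLM mapping is renaming the checkpoint's A_log tensor to the A parameter registered by the mixer; the GPTQ bias skip stays because quantized checkpoints can carry bias tensors the model does not define:

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # HF checkpoints store the state matrix as "A_log";
            # vLLM's mixer registers the parameter as "A".
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)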
