Commit a71bcf4: limit LoRA targets

1 parent: cd2f63f

1 file changed (+103, -68):

vllm/model_executor/models/llama.py
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
+
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -29,28 +30,36 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader)
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.utils import is_hip
 
 
 class LlamaMLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -60,16 +69,22 @@ def __init__(
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
@@ -80,7 +95,6 @@ def forward(self, x):
 
 
 class LlamaAttention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -147,11 +161,13 @@ def __init__(
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=sliding_window,
+        )
 
     def forward(
         self,
@@ -163,14 +179,12 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class LlamaDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: LlamaConfig,
@@ -180,18 +194,21 @@ def __init__(
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        max_position_embeddings = getattr(
+            config, "max_position_embeddings", 8192
+        )
         sliding_window = getattr(config, "sliding_window", None)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
-            config, "bias", False)
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=getattr(config, "num_key_value_heads",
-                                 config.num_attention_heads),
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
@@ -205,10 +222,12 @@ def __init__(
             hidden_act=config.hidden_act,
             linear_method=linear_method,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     def forward(
         self,
@@ -224,7 +243,8 @@ def forward(
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+                hidden_states, residual
+            )
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -234,13 +254,13 @@ def forward(
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+            hidden_states, residual
+        )
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
 class LlamaModel(nn.Module):
-
     def __init__(
         self,
         config: LlamaConfig,
@@ -250,19 +270,24 @@ def __init__(
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, linear_method)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(config, linear_method)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -316,11 +341,8 @@ class LlamaForCausalLM(nn.Module):
         "embed_tokens",
         "lm_head",
     ]
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
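Note: this hunk is the substantive part of the commit; the remaining hunks are formatting-only. Emptying `embedding_modules` and `embedding_padding_modules` means any LoRA code that keys off these class attributes no longer sees `embed_tokens` or `lm_head` as embedding-style targets, while the context above shows `supported_lora_modules` still listing both. A minimal sketch of the observable effect, using only the class attributes shown in this hunk (the loop and prints are illustrative, not vLLM internals):

```python
# Illustrative check of the class-level LoRA metadata touched by this commit.
from vllm.model_executor.models.llama import LlamaForCausalLM

for module_name in ("embed_tokens", "lm_head"):
    is_embedding_target = module_name in LlamaForCausalLM.embedding_modules
    needs_vocab_padding = module_name in LlamaForCausalLM.embedding_padding_modules
    print(module_name, is_embedding_target, needs_vocab_padding)

# Before this commit: is_embedding_target is True for both modules and
# needs_vocab_padding is True for lm_head. After it, both attributes are
# empty, so every check above prints False.
```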
@@ -342,12 +364,14 @@ def __init__(
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
             # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
+            if not lora_config
+            else lora_config.lora_vocab_padding_size,
         )
 
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+        self.logits_processor = LogitsProcessor(
+            self.unpadded_vocab_size, config.vocab_size, logit_scale
+        )
         self.sampler = Sampler()
 
     def forward(
@@ -357,14 +381,17 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, attn_metadata
+        )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(
+        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
+    ) -> torch.Tensor:
+        logits = self.logits_processor(
+            self.lm_head.weight, hidden_states, sampling_metadata
+        )
         return logits
 
     def sample(
@@ -388,12 +415,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+            if (
+                "rotary_emb.cos_cached" in name
+                or "rotary_emb.sin_cached" in name
+            ):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -409,8 +438,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
                 weight_loader(param, loaded_weight)
 
     # If this function is called, it should always initialize KV cache scale
@@ -420,9 +450,12 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
             layer_self_attn = self.model.layers[layer_idx].self_attn
 
             if is_hip():
@@ -434,5 +467,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
             if hasattr(layer_self_attn, "kv_scale"):
                 layer_self_attn.kv_scale = scaling_factor
             else:
-                raise RuntimeError("Self attention has no KV cache scaling "
-                                   "factor attribute!")
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling "
+                    "factor attribute!"
+                )
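For context (not part of this diff), the public multi-LoRA path is unchanged: adapters are still attached per request via `enable_lora` and a `LoRARequest`. A rough usage sketch based on the standard vLLM LoRA example; the model name, adapter path, and prompt below are placeholders:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholders: any Llama-family checkpoint plus a local LoRA adapter directory.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

outputs = llm.generate(
    ["Give me a short introduction to LoRA."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest(adapter name, integer adapter id, local adapter path)
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```

With `embedding_modules` emptied in the hunk above, adapters that carry extra input/output embedding tensors presumably no longer map onto `embed_tokens`/`lm_head` through this metadata; `supported_lora_modules` itself is left untouched by this commit.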

0 commit comments