Skip to content

Commit 1ec47da

Browse files
committed
Align SwissAI to v0.9.0.1
1 parent 7ee27da commit 1ec47da

1 file changed

Lines changed: 90 additions & 52 deletions

File tree

vllm/model_executor/models/swissai.py

Lines changed: 90 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
# See the License for the specific language governing permissions and
2323
# limitations under the License.
2424
"""Inference-only SwissAI model compatible with HuggingFace weights."""
25-
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
25+
from collections.abc import Iterable
26+
from typing import Any, Optional, Union
2627

2728
import torch
2829
from torch import nn
2930
from transformers import SwissAIConfig
3031

31-
from vllm.attention import Attention
32+
from vllm.attention import Attention, AttentionType
3233
from vllm.compilation.decorators import support_torch_compile
3334
from vllm.config import CacheConfig, VllmConfig
3435
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -40,7 +41,6 @@
4041
from vllm.model_executor.layers.logits_processor import LogitsProcessor
4142
from vllm.model_executor.layers.quantization import QuantizationConfig
4243
from vllm.model_executor.layers.rotary_embedding import get_rope
43-
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
4444
from vllm.model_executor.layers.vocab_parallel_embedding import (
4545
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
4646
from vllm.model_executor.model_loader.weight_utils import (
@@ -95,19 +95,22 @@ def forward(self, x):
9595

9696
class SwissAIAttention(nn.Module):
9797

98-
def __init__(self,
99-
config: SwissAIConfig,
100-
hidden_size: int,
101-
num_heads: int,
102-
num_kv_heads: int,
103-
rope_theta: float = 10000,
104-
rope_scaling: Optional[Dict[str, Any]] = None,
105-
max_position_embeddings: int = 8192,
106-
quant_config: Optional[QuantizationConfig] = None,
107-
bias: bool = False,
108-
bias_o_proj: bool = False,
109-
cache_config: Optional[CacheConfig] = None,
110-
prefix: str = "") -> None:
98+
def __init__(
99+
self,
100+
config: SwissAIConfig,
101+
hidden_size: int,
102+
num_heads: int,
103+
num_kv_heads: int,
104+
rope_theta: float = 10000,
105+
rope_scaling: Optional[dict[str, Any]] = None,
106+
max_position_embeddings: int = 8192,
107+
quant_config: Optional[QuantizationConfig] = None,
108+
bias: bool = False,
109+
bias_o_proj: bool = False,
110+
cache_config: Optional[CacheConfig] = None,
111+
prefix: str = "",
112+
attn_type: str = AttentionType.DECODER,
113+
) -> None:
111114
super().__init__()
112115
layer_idx = extract_layer_index(prefix)
113116
self.hidden_size = hidden_size
@@ -129,8 +132,8 @@ def __init__(self,
129132
self.head_dim = getattr(config, "head_dim",
130133
self.hidden_size // self.total_num_heads)
131134
# Phi models introduced a partial_rotary_factor parameter in the config
132-
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
133-
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
135+
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
136+
1)
134137
self.q_size = self.num_heads * self.head_dim
135138
self.kv_size = self.num_kv_heads * self.head_dim
136139
self.scaling = self.head_dim**-0.5
@@ -155,19 +158,9 @@ def __init__(self,
155158
prefix=f"{prefix}.o_proj",
156159
)
157160

158-
is_neox_style = True
159-
is_gguf = quant_config and quant_config.get_name() == "gguf"
160-
if is_gguf and config.model_type == "swissai":
161-
is_neox_style = False
162-
163-
self.rotary_emb = get_rope(
164-
self.head_dim,
165-
rotary_dim=self.rotary_dim,
166-
max_position=max_position_embeddings,
167-
base=rope_theta,
168-
rope_scaling=rope_scaling,
169-
is_neox_style=is_neox_style,
170-
)
161+
self._init_rotary_emb(config,
162+
rope_scaling=rope_scaling,
163+
quant_config=quant_config)
171164

172165
if hasattr(config, "interleaved_sliding_window"):
173166
interleaved_sliding_window = config.interleaved_sliding_window
@@ -190,6 +183,7 @@ def __init__(self,
190183
cache_config=cache_config,
191184
quant_config=quant_config,
192185
per_layer_sliding_window=sliding_window,
186+
attn_type=attn_type,
193187
prefix=f"{prefix}.attn",
194188
)
195189

@@ -212,6 +206,24 @@ def forward(
212206
output, _ = self.o_proj(attn_output)
213207
return output
214208

209+
def _init_rotary_emb(self, config: SwissAIConfig,
                     rope_scaling: Optional[dict[str, Any]],
                     quant_config: Optional[QuantizationConfig]) -> None:
    """Create the rotary position embedding and store it on ``self.rotary_emb``.

    Reads ``self.head_dim``, ``self.partial_rotary_factor``,
    ``self.max_position_embeddings`` and ``self.rope_theta``; these are
    assumed to be assigned earlier in ``__init__`` (the last two are not
    visible in this hunk -- confirm before reordering the constructor).

    Args:
        config: HF model config; only ``model_type`` is read here.
        rope_scaling: Optional RoPE scaling dict, forwarded unchanged to
            ``get_rope``.
        quant_config: Optional quantization config; only used to detect
            GGUF checkpoints via ``get_name()``.
    """
    # GGUF checkpoints of this model type switch off the NeoX-style layout;
    # every other combination keeps it on.
    is_gguf = quant_config and quant_config.get_name() == "gguf"
    is_neox_style = not (is_gguf and config.model_type == "swissai")

    self.rotary_emb = get_rope(
        self.head_dim,
        # Pass the full head dimension here. ``get_rope`` itself scales
        # ``rotary_dim`` by ``partial_rotary_factor`` (vLLM v0.9.x), so
        # pre-multiplying would apply the factor twice whenever it is < 1.
        rotary_dim=self.head_dim,
        max_position=self.max_position_embeddings,
        base=self.rope_theta,
        rope_scaling=rope_scaling,
        is_neox_style=is_neox_style,
        partial_rotary_factor=self.partial_rotary_factor,
    )
226+
215227

216228
class SwissAIDecoderLayer(nn.Module):
217229

@@ -241,6 +253,15 @@ def __init__(
241253
if hasattr(config, 'qkv_bias'):
242254
attention_bias = config.qkv_bias
243255

256+
# By default, SwissAI uses causal attention as it is a decoder-only model.
257+
# You can override the HF config with `is_causal=False` to enable
258+
# bidirectional attention, which is used in some embedding models
259+
# (e.g. parasail-ai/GritLM-7B-vllm)
260+
if getattr(config, "is_causal", True):
261+
attn_type = AttentionType.DECODER
262+
else:
263+
attn_type = AttentionType.ENCODER_ONLY
264+
244265
self.self_attn = SwissAIAttention(
245266
config=config,
246267
hidden_size=self.hidden_size,
@@ -255,6 +276,7 @@ def __init__(
255276
bias_o_proj=bias_o_proj,
256277
cache_config=cache_config,
257278
prefix=f"{prefix}.self_attn",
279+
attn_type=attn_type,
258280
)
259281
self.mlp = SwissAIMLP(
260282
hidden_size=self.hidden_size,
@@ -274,7 +296,7 @@ def forward(
274296
positions: torch.Tensor,
275297
hidden_states: torch.Tensor,
276298
residual: Optional[torch.Tensor],
277-
) -> Tuple[torch.Tensor, torch.Tensor]:
299+
) -> tuple[torch.Tensor, torch.Tensor]:
278300
# Self Attention
279301
if residual is None:
280302
residual = hidden_states
@@ -299,7 +321,7 @@ def __init__(self,
299321
*,
300322
vllm_config: VllmConfig,
301323
prefix: str = "",
302-
layer_type: Type[SwissAIDecoderLayer] = SwissAIDecoderLayer):
324+
layer_type: type[nn.Module] = SwissAIDecoderLayer):
303325
super().__init__()
304326

305327
config = vllm_config.model_config.hf_config
@@ -313,7 +335,7 @@ def __init__(self,
313335
(lora_config.max_loras or 1)) if lora_config else 0
314336
self.vocab_size = config.vocab_size + lora_vocab
315337
self.org_vocab_size = config.vocab_size
316-
if not torch.cuda.is_available() or get_pp_group().is_first_rank or (config.tie_word_embeddings
338+
if get_pp_group().is_first_rank or (config.tie_word_embeddings
317339
and get_pp_group().is_last_rank):
318340
self.embed_tokens = VocabParallelEmbedding(
319341
self.vocab_size,
@@ -336,6 +358,8 @@ def __init__(self,
336358
else:
337359
self.norm = PPMissingLayer()
338360

361+
self.aux_hidden_state_layers: tuple[int] = tuple()
362+
339363
self.make_empty_intermediate_tensors = (
340364
make_empty_intermediate_tensors_factory(
341365
["hidden_states", "residual"], config.hidden_size))
@@ -349,7 +373,8 @@ def forward(
349373
positions: torch.Tensor,
350374
intermediate_tensors: Optional[IntermediateTensors],
351375
inputs_embeds: Optional[torch.Tensor] = None,
352-
) -> Union[torch.Tensor, IntermediateTensors]:
376+
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
377+
list[torch.Tensor]]]:
353378
if get_pp_group().is_first_rank:
354379
if inputs_embeds is not None:
355380
hidden_states = inputs_embeds
@@ -361,7 +386,11 @@ def forward(
361386
hidden_states = intermediate_tensors["hidden_states"]
362387
residual = intermediate_tensors["residual"]
363388

364-
for layer in self.layers[self.start_layer:self.end_layer]:
389+
aux_hidden_states = []
390+
for idx, layer in enumerate(
391+
self.layers[self.start_layer:self.end_layer]):
392+
if idx in self.aux_hidden_state_layers:
393+
aux_hidden_states.append(hidden_states + residual)
365394
hidden_states, residual = layer(positions, hidden_states, residual)
366395

367396
if not get_pp_group().is_last_rank:
@@ -371,18 +400,21 @@ def forward(
371400
})
372401

373402
hidden_states, _ = self.norm(hidden_states, residual)
403+
404+
if len(aux_hidden_states) > 0:
405+
return hidden_states, aux_hidden_states
374406
return hidden_states
375407

376-
def load_weights(self, weights: Iterable[Tuple[str,
377-
torch.Tensor]]) -> Set[str]:
408+
def load_weights(self, weights: Iterable[tuple[str,
409+
torch.Tensor]]) -> set[str]:
378410
stacked_params_mapping = [
379411
# (param_name, shard_name, shard_id)
380412
(".qkv_proj", ".q_proj", "q"),
381413
(".qkv_proj", ".k_proj", "k"),
382414
(".qkv_proj", ".v_proj", "v"),
383415
]
384416
params_dict = dict(self.named_parameters())
385-
loaded_params: Set[str] = set()
417+
loaded_params: set[str] = set()
386418
for name, loaded_weight in weights:
387419
if "rotary_emb.inv_freq" in name:
388420
continue
@@ -450,7 +482,11 @@ class SwissAIForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
450482
}
451483
embedding_padding_modules = ["lm_head"]
452484

453-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
485+
def __init__(self,
486+
*,
487+
vllm_config: VllmConfig,
488+
prefix: str = "",
489+
layer_type: type[nn.Module] = SwissAIDecoderLayer):
454490
super().__init__()
455491
config = vllm_config.model_config.hf_config
456492
quant_config = vllm_config.quant_config
@@ -459,7 +495,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
459495
self.lora_config = lora_config
460496

461497
self.model = self._init_model(vllm_config=vllm_config,
462-
prefix=maybe_prefix(prefix, "model"))
498+
prefix=maybe_prefix(prefix, "model"),
499+
layer_type=layer_type)
463500

464501
if get_pp_group().is_last_rank:
465502
self.unpadded_vocab_size = config.vocab_size
@@ -489,13 +526,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
489526
else:
490527
self.lm_head = PPMissingLayer()
491528

492-
self.sampler = get_sampler()
493-
494529
self.make_empty_intermediate_tensors = (
495530
self.model.make_empty_intermediate_tensors)
496531

497-
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
498-
return SwissAIModel(vllm_config=vllm_config, prefix=prefix)
532+
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Select the layers whose inputs are returned as auxiliary hidden states.

    The value is stored on ``self.model.aux_hidden_state_layers`` without
    copying or validation.

    Args:
        layers: Layer indices to record. Annotated as a variable-length
            tuple (``tuple[int, ...]``); ``tuple[int]`` would mean a
            tuple of exactly one element.
    """
    self.model.aux_hidden_state_layers = layers
534+
535+
def _init_model(self,
                vllm_config: VllmConfig,
                prefix: str = "",
                layer_type: type[nn.Module] = SwissAIDecoderLayer):
    """Construct the backbone ``SwissAIModel`` for this causal-LM wrapper.

    Kept as a separate hook so the backbone construction lives in one
    place; all three arguments are forwarded unchanged as keyword
    arguments.
    """
    backbone_kwargs = {
        "vllm_config": vllm_config,
        "prefix": prefix,
        "layer_type": layer_type,
    }
    return SwissAIModel(**backbone_kwargs)
499542

500543
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
501544
return self.model.get_input_embeddings(input_ids)
@@ -520,13 +563,8 @@ def compute_logits(
520563
sampling_metadata)
521564
return logits
522565

523-
def sample(self, logits: torch.Tensor,
524-
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
525-
next_tokens = self.sampler(logits, sampling_metadata)
526-
return next_tokens
527-
528-
def load_weights(self, weights: Iterable[Tuple[str,
529-
torch.Tensor]]) -> Set[str]:
566+
def load_weights(self, weights: Iterable[tuple[str,
567+
torch.Tensor]]) -> set[str]:
530568
loader = AutoWeightsLoader(
531569
self,
532570
skip_prefixes=(["lm_head."]

0 commit comments

Comments
 (0)