2222# See the License for the specific language governing permissions and
2323# limitations under the License.
2424"""Inference-only SwissAI model compatible with HuggingFace weights."""
25- from typing import Any , Dict , Iterable , Optional , Set , Tuple , Type , Union
25+ from collections .abc import Iterable
26+ from typing import Any , Optional , Union
2627
2728import torch
2829from torch import nn
2930from transformers import SwissAIConfig
3031
31- from vllm .attention import Attention
32+ from vllm .attention import Attention , AttentionType
3233from vllm .compilation .decorators import support_torch_compile
3334from vllm .config import CacheConfig , VllmConfig
3435from vllm .distributed import get_pp_group , get_tensor_model_parallel_world_size
4041from vllm .model_executor .layers .logits_processor import LogitsProcessor
4142from vllm .model_executor .layers .quantization import QuantizationConfig
4243from vllm .model_executor .layers .rotary_embedding import get_rope
43- from vllm .model_executor .layers .sampler import SamplerOutput , get_sampler
4444from vllm .model_executor .layers .vocab_parallel_embedding import (
4545 DEFAULT_VOCAB_PADDING_SIZE , ParallelLMHead , VocabParallelEmbedding )
4646from vllm .model_executor .model_loader .weight_utils import (
@@ -95,19 +95,22 @@ def forward(self, x):
9595
9696class SwissAIAttention (nn .Module ):
9797
98- def __init__ (self ,
99- config : SwissAIConfig ,
100- hidden_size : int ,
101- num_heads : int ,
102- num_kv_heads : int ,
103- rope_theta : float = 10000 ,
104- rope_scaling : Optional [Dict [str , Any ]] = None ,
105- max_position_embeddings : int = 8192 ,
106- quant_config : Optional [QuantizationConfig ] = None ,
107- bias : bool = False ,
108- bias_o_proj : bool = False ,
109- cache_config : Optional [CacheConfig ] = None ,
110- prefix : str = "" ) -> None :
98+ def __init__ (
99+ self ,
100+ config : SwissAIConfig ,
101+ hidden_size : int ,
102+ num_heads : int ,
103+ num_kv_heads : int ,
104+ rope_theta : float = 10000 ,
105+ rope_scaling : Optional [dict [str , Any ]] = None ,
106+ max_position_embeddings : int = 8192 ,
107+ quant_config : Optional [QuantizationConfig ] = None ,
108+ bias : bool = False ,
109+ bias_o_proj : bool = False ,
110+ cache_config : Optional [CacheConfig ] = None ,
111+ prefix : str = "" ,
112+ attn_type : str = AttentionType .DECODER ,
113+ ) -> None :
111114 super ().__init__ ()
112115 layer_idx = extract_layer_index (prefix )
113116 self .hidden_size = hidden_size
@@ -129,8 +132,8 @@ def __init__(self,
129132 self .head_dim = getattr (config , "head_dim" ,
130133 self .hidden_size // self .total_num_heads )
131134 # Phi models introduced a partial_rotary_factor parameter in the config
132- partial_rotary_factor = getattr (config , "partial_rotary_factor" , 1 )
133- self . rotary_dim = int ( partial_rotary_factor * self . head_dim )
135+ self . partial_rotary_factor = getattr (config , "partial_rotary_factor" ,
136+ 1 )
134137 self .q_size = self .num_heads * self .head_dim
135138 self .kv_size = self .num_kv_heads * self .head_dim
136139 self .scaling = self .head_dim ** - 0.5
@@ -155,19 +158,9 @@ def __init__(self,
155158 prefix = f"{ prefix } .o_proj" ,
156159 )
157160
158- is_neox_style = True
159- is_gguf = quant_config and quant_config .get_name () == "gguf"
160- if is_gguf and config .model_type == "swissai" :
161- is_neox_style = False
162-
163- self .rotary_emb = get_rope (
164- self .head_dim ,
165- rotary_dim = self .rotary_dim ,
166- max_position = max_position_embeddings ,
167- base = rope_theta ,
168- rope_scaling = rope_scaling ,
169- is_neox_style = is_neox_style ,
170- )
161+ self ._init_rotary_emb (config ,
162+ rope_scaling = rope_scaling ,
163+ quant_config = quant_config )
171164
172165 if hasattr (config , "interleaved_sliding_window" ):
173166 interleaved_sliding_window = config .interleaved_sliding_window
@@ -190,6 +183,7 @@ def __init__(self,
190183 cache_config = cache_config ,
191184 quant_config = quant_config ,
192185 per_layer_sliding_window = sliding_window ,
186+ attn_type = attn_type ,
193187 prefix = f"{ prefix } .attn" ,
194188 )
195189
@@ -212,6 +206,24 @@ def forward(
212206 output , _ = self .o_proj (attn_output )
213207 return output
214208
209+ def _init_rotary_emb (self , config : SwissAIConfig ,
210+ rope_scaling : Optional [dict [str , Any ]],
211+ quant_config : Optional [QuantizationConfig ]) -> None :
212+ is_neox_style = True
213+ is_gguf = quant_config and quant_config .get_name () == "gguf"
214+ if is_gguf and config .model_type == "swissai" :
215+ is_neox_style = False
216+
217+ self .rotary_emb = get_rope (
218+ self .head_dim ,
219+ rotary_dim = int (self .partial_rotary_factor * self .head_dim ),
220+ max_position = self .max_position_embeddings ,
221+ base = self .rope_theta ,
222+ rope_scaling = rope_scaling ,
223+ is_neox_style = is_neox_style ,
224+ partial_rotary_factor = self .partial_rotary_factor ,
225+ )
226+
215227
216228class SwissAIDecoderLayer (nn .Module ):
217229
@@ -241,6 +253,15 @@ def __init__(
241253 if hasattr (config , 'qkv_bias' ):
242254 attention_bias = config .qkv_bias
243255
256+ # By default, SwissAI uses causal attention as it is a decoder-only model.
257+ # You can override the HF config with `is_causal=False` to enable
258+ # bidirectional attention, which is used in some embedding models
259+ # (e.g. parasail-ai/GritLM-7B-vllm)
260+ if getattr (config , "is_causal" , True ):
261+ attn_type = AttentionType .DECODER
262+ else :
263+ attn_type = AttentionType .ENCODER_ONLY
264+
244265 self .self_attn = SwissAIAttention (
245266 config = config ,
246267 hidden_size = self .hidden_size ,
@@ -255,6 +276,7 @@ def __init__(
255276 bias_o_proj = bias_o_proj ,
256277 cache_config = cache_config ,
257278 prefix = f"{ prefix } .self_attn" ,
279+ attn_type = attn_type ,
258280 )
259281 self .mlp = SwissAIMLP (
260282 hidden_size = self .hidden_size ,
@@ -274,7 +296,7 @@ def forward(
274296 positions : torch .Tensor ,
275297 hidden_states : torch .Tensor ,
276298 residual : Optional [torch .Tensor ],
277- ) -> Tuple [torch .Tensor , torch .Tensor ]:
299+ ) -> tuple [torch .Tensor , torch .Tensor ]:
278300 # Self Attention
279301 if residual is None :
280302 residual = hidden_states
@@ -299,7 +321,7 @@ def __init__(self,
299321 * ,
300322 vllm_config : VllmConfig ,
301323 prefix : str = "" ,
302- layer_type : Type [ SwissAIDecoderLayer ] = SwissAIDecoderLayer ):
324+ layer_type : type [ nn . Module ] = SwissAIDecoderLayer ):
303325 super ().__init__ ()
304326
305327 config = vllm_config .model_config .hf_config
@@ -313,7 +335,7 @@ def __init__(self,
313335 (lora_config .max_loras or 1 )) if lora_config else 0
314336 self .vocab_size = config .vocab_size + lora_vocab
315337 self .org_vocab_size = config .vocab_size
316- if not torch . cuda . is_available () or get_pp_group ().is_first_rank or (config .tie_word_embeddings
338+ if get_pp_group ().is_first_rank or (config .tie_word_embeddings
317339 and get_pp_group ().is_last_rank ):
318340 self .embed_tokens = VocabParallelEmbedding (
319341 self .vocab_size ,
@@ -336,6 +358,8 @@ def __init__(self,
336358 else :
337359 self .norm = PPMissingLayer ()
338360
361+ self .aux_hidden_state_layers : tuple [int ] = tuple ()
362+
339363 self .make_empty_intermediate_tensors = (
340364 make_empty_intermediate_tensors_factory (
341365 ["hidden_states" , "residual" ], config .hidden_size ))
@@ -349,7 +373,8 @@ def forward(
349373 positions : torch .Tensor ,
350374 intermediate_tensors : Optional [IntermediateTensors ],
351375 inputs_embeds : Optional [torch .Tensor ] = None ,
352- ) -> Union [torch .Tensor , IntermediateTensors ]:
376+ ) -> Union [torch .Tensor , IntermediateTensors , tuple [torch .Tensor ,
377+ list [torch .Tensor ]]]:
353378 if get_pp_group ().is_first_rank :
354379 if inputs_embeds is not None :
355380 hidden_states = inputs_embeds
@@ -361,7 +386,11 @@ def forward(
361386 hidden_states = intermediate_tensors ["hidden_states" ]
362387 residual = intermediate_tensors ["residual" ]
363388
364- for layer in self .layers [self .start_layer :self .end_layer ]:
389+ aux_hidden_states = []
390+ for idx , layer in enumerate (
391+ self .layers [self .start_layer :self .end_layer ]):
392+ if idx in self .aux_hidden_state_layers :
393+ aux_hidden_states .append (hidden_states + residual )
365394 hidden_states , residual = layer (positions , hidden_states , residual )
366395
367396 if not get_pp_group ().is_last_rank :
@@ -371,18 +400,21 @@ def forward(
371400 })
372401
373402 hidden_states , _ = self .norm (hidden_states , residual )
403+
404+ if len (aux_hidden_states ) > 0 :
405+ return hidden_states , aux_hidden_states
374406 return hidden_states
375407
376- def load_weights (self , weights : Iterable [Tuple [str ,
377- torch .Tensor ]]) -> Set [str ]:
408+ def load_weights (self , weights : Iterable [tuple [str ,
409+ torch .Tensor ]]) -> set [str ]:
378410 stacked_params_mapping = [
379411 # (param_name, shard_name, shard_id)
380412 (".qkv_proj" , ".q_proj" , "q" ),
381413 (".qkv_proj" , ".k_proj" , "k" ),
382414 (".qkv_proj" , ".v_proj" , "v" ),
383415 ]
384416 params_dict = dict (self .named_parameters ())
385- loaded_params : Set [str ] = set ()
417+ loaded_params : set [str ] = set ()
386418 for name , loaded_weight in weights :
387419 if "rotary_emb.inv_freq" in name :
388420 continue
@@ -450,7 +482,11 @@ class SwissAIForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
450482 }
451483 embedding_padding_modules = ["lm_head" ]
452484
453- def __init__ (self , * , vllm_config : VllmConfig , prefix : str = "" ):
485+ def __init__ (self ,
486+ * ,
487+ vllm_config : VllmConfig ,
488+ prefix : str = "" ,
489+ layer_type : type [nn .Module ] = SwissAIDecoderLayer ):
454490 super ().__init__ ()
455491 config = vllm_config .model_config .hf_config
456492 quant_config = vllm_config .quant_config
@@ -459,7 +495,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
459495 self .lora_config = lora_config
460496
461497 self .model = self ._init_model (vllm_config = vllm_config ,
462- prefix = maybe_prefix (prefix , "model" ))
498+ prefix = maybe_prefix (prefix , "model" ),
499+ layer_type = layer_type )
463500
464501 if get_pp_group ().is_last_rank :
465502 self .unpadded_vocab_size = config .vocab_size
@@ -489,13 +526,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
489526 else :
490527 self .lm_head = PPMissingLayer ()
491528
492- self .sampler = get_sampler ()
493-
494529 self .make_empty_intermediate_tensors = (
495530 self .model .make_empty_intermediate_tensors )
496531
497- def _init_model (self , vllm_config : VllmConfig , prefix : str = "" ):
498- return SwissAIModel (vllm_config = vllm_config , prefix = prefix )
532+ def set_aux_hidden_state_layers (self , layers : tuple [int ]) -> None :
533+ self .model .aux_hidden_state_layers = layers
534+
535+ def _init_model (self ,
536+ vllm_config : VllmConfig ,
537+ prefix : str = "" ,
538+ layer_type : type [nn .Module ] = SwissAIDecoderLayer ):
539+ return SwissAIModel (vllm_config = vllm_config ,
540+ prefix = prefix ,
541+ layer_type = layer_type )
499542
500543 def get_input_embeddings (self , input_ids : torch .Tensor ) -> torch .Tensor :
501544 return self .model .get_input_embeddings (input_ids )
@@ -520,13 +563,8 @@ def compute_logits(
520563 sampling_metadata )
521564 return logits
522565
523- def sample (self , logits : torch .Tensor ,
524- sampling_metadata : SamplingMetadata ) -> Optional [SamplerOutput ]:
525- next_tokens = self .sampler (logits , sampling_metadata )
526- return next_tokens
527-
528- def load_weights (self , weights : Iterable [Tuple [str ,
529- torch .Tensor ]]) -> Set [str ]:
566+ def load_weights (self , weights : Iterable [tuple [str ,
567+ torch .Tensor ]]) -> set [str ]:
530568 loader = AutoWeightsLoader (
531569 self ,
532570 skip_prefixes = (["lm_head." ]
0 commit comments