@@ -166,7 +166,7 @@ def __init__(
166166 orig_k_layernorm = submodules .k_layernorm
167167 config .kv_channels = self .head_dim
168168 config .num_query_groups = self .num_key_value_heads
169- if self . is_sliding and config . window_size is None :
169+ if text_config . use_bidirectional_attention == 'vision' :
170170 kwargs ['attn_mask_type' ] = AttnMaskType .arbitrary
171171 if self .is_kv_shared_layer :
172172 submodules .k_layernorm = IdentityOp
@@ -291,10 +291,10 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu
291291 packed_seq_params = kwargs .get ('packed_seq_params' )
292292 attention_bias = kwargs .get ('attention_bias' )
293293 mixed_qkv , _ = self .linear_qkv (hidden_states )
294- if getattr (self , 'world_size' , None ) is not None and self .config . num_query_groups < self .world_size :
294+ if getattr (self , 'world_size' , None ) is not None and self .num_key_value_heads < self .world_size :
295295 mixed_qkv = all_gather_last_dim_from_tensor_parallel_region (mixed_qkv )
296- idx = get_tensor_model_parallel_rank () // (self .world_size // self .config . num_query_groups )
297- size = mixed_qkv .size ()[- 1 ] // self .config . num_query_groups
296+ idx = get_tensor_model_parallel_rank () // (self .world_size // self .num_key_value_heads )
297+ size = mixed_qkv .size ()[- 1 ] // self .num_key_value_heads
298298 mixed_qkv = mixed_qkv [:, :, idx * size :(idx + 1 ) * size ]
299299
300300 thd_format = packed_seq_params is not None and packed_seq_params .qkv_format == 'thd'
@@ -327,9 +327,9 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu
327327 value = value .squeeze (1 )
328328 # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
329329 query = query .reshape (query .size (0 ), query .size (1 ), - 1 , self .hidden_size_per_attention_head )
330- if getattr (self , 'world_size' , None ) is not None and self .config . num_query_groups < self .world_size :
331- idx = get_tensor_model_parallel_rank () % (self .world_size // self .config . num_query_groups )
332- size = query .shape [2 ] // (self .world_size // self .config . num_query_groups )
330+ if getattr (self , 'world_size' , None ) is not None and self .num_key_value_heads < self .world_size :
331+ idx = get_tensor_model_parallel_rank () % (self .world_size // self .num_key_value_heads )
332+ size = query .shape [2 ] // (self .world_size // self .num_key_value_heads )
333333 query = query [:, :, idx * size :(idx + 1 ) * size , :]
334334 query = self .q_layernorm (query )
335335 if isinstance (rotary_pos_emb , torch .Tensor ):
@@ -609,23 +609,24 @@ def forward(self, *args, **kwargs):
609609 extra_block_kwargs ['shared_kv_states' ] = shared_kv_states
610610 kwargs ['extra_block_kwargs' ] = extra_block_kwargs
611611 attention_mask = kwargs .get ('attention_mask' )
612- kwargs [ ' attention_mask' ] = {'sliding_attention' : attention_mask , 'full_attention' : attention_mask }
612+ attention_mask = {'sliding_attention' : attention_mask , 'full_attention' : attention_mask }
613613 if self .text_config .use_bidirectional_attention == 'vision' :
614- kwargs [ 'attention_mask' ][ 'sliding_attention' ] = self ._create_sliding_attention_mask (
615- attention_mask , mm_token_type_ids )
614+ self ._update_attention_mask ( attention_mask , mm_token_type_ids )
615+ kwargs [ 'attention_mask' ] = attention_mask
616616 hidden_states = super ().forward (* args , ** kwargs )
617617 if self .hidden_size_per_layer_input and not self .post_process :
618618 hidden_states = self ._pack_pp_output (hidden_states , per_layer_inputs , shared_kv_states )
619619 return hidden_states
620620
621- def _create_sliding_attention_mask (self , attention_mask , mm_token_type_ids ):
621+ def _update_attention_mask (self , attention_mask , mm_token_type_ids ):
622+ sliding_attention = attention_mask ['sliding_attention' ]
623+ full_attention = attention_mask ['full_attention' ]
624+ # sliding
622625 window_size = self .text_config .sliding_window - 1
623- seq_len = attention_mask .shape [- 1 ]
624-
625- window_mask = torch .ones (seq_len , seq_len , dtype = torch .bool , device = attention_mask .device )
626+ seq_len = sliding_attention .shape [- 1 ]
627+ window_mask = torch .ones (seq_len , seq_len , dtype = torch .bool , device = sliding_attention .device )
626628 window_mask = ~ torch .triu (window_mask , diagonal = - window_size )
627-
628- attention_mask = attention_mask | window_mask
629+ sliding_attention = sliding_attention | window_mask
629630 if mm_token_type_ids is not None :
630631 is_vision = mm_token_type_ids > 0
631632 is_prev_vision = torch .roll (is_vision , shifts = 1 , dims = - 1 )
@@ -635,8 +636,10 @@ def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids):
635636 q_group = vision_group_ids .unsqueeze (1 ).unsqueeze (- 1 )
636637 k_group = vision_group_ids .unsqueeze (1 ).unsqueeze (- 2 )
637638 same_vision_group = (q_group == k_group ) & (q_group >= 0 ) & (k_group >= 0 )
638- attention_mask = attention_mask & ~ same_vision_group
639- return attention_mask
639+ sliding_attention = sliding_attention & ~ same_vision_group
640+ full_attention = full_attention & ~ same_vision_group
641+ attention_mask ['sliding_attention' ] = sliding_attention
642+ attention_mask ['full_attention' ] = full_attention
640643
641644 def _pack_pp_output (self , hidden_states , per_layer_inputs , shared_kv_states ):
642645 per_layer_inputs = per_layer_inputs .view (* hidden_states .shape [:2 ], - 1 )
@@ -854,3 +857,62 @@ def _replace_router(self, transformer_layer_spec, mlp_key='experts_mlp'):
854857 visual_cls = Gemma4Vit ,
855858 loader = Gemma4Loader ,
856859 ))
860+
861+
862+ class Gemma4UnifiedVit (Gemma4Vit ):
863+ module_mapping = {
864+ 'model.embed_vision' : 'embed_vision' ,
865+ 'model.embed_audio' : 'embed_audio' ,
866+ }
867+ _vision_tower = []
868+
869+ def prepare_model (self , hf_config : PretrainedConfig ):
870+ from transformers .models .gemma4_unified .modeling_gemma4_unified import (Gemma4UnifiedModel ,
871+ Gemma4UnifiedMultimodalEmbedder ,
872+ Gemma4UnifiedVisionEmbedder )
873+ dtype = hf_config .torch_dtype
874+ self .embed_vision = (
875+ Gemma4UnifiedVisionEmbedder (hf_config .vision_config , hf_config .text_config ).to (dtype )
876+ if hf_config .vision_config is not None else None )
877+
878+ self .embed_audio = (
879+ Gemma4UnifiedMultimodalEmbedder (hf_config .audio_config , hf_config .text_config ).to (dtype )
880+ if hf_config .audio_config is not None else None )
881+ self .register_buffer ('embed_scale' , torch .tensor (hf_config .hidden_size ** 0.5 ).to (dtype ), persistent = False )
882+ self .model_cls = Gemma4UnifiedModel
883+
884+
885+ class Gemma4UnifiedBridge (Gemma4Bridge ):
886+
887+ def _convert_hf_state_dict (self , hf_state_dict , to_mcore ):
888+ res = super ()._convert_hf_state_dict (hf_state_dict , to_mcore )
889+ new_state_dict = {}
890+ if to_mcore :
891+ for k , v in res .items ():
892+ if k .startswith ('model.embed_vision.embedding_projection.' ):
893+ new_state_dict ['model.embed_vision.multimodal_embedder.embedding_projection.'
894+ + k [len ('model.embed_vision.embedding_projection.' ):]] = v
895+ elif k .startswith ('model.vision_embedder.' ):
896+ new_state_dict ['model.embed_vision.' + k [len ('model.vision_embedder.' ):]] = v
897+ else :
898+ new_state_dict [k ] = v
899+ else :
900+ for k , v in res .items ():
901+ if k .startswith ('model.embed_vision.multimodal_embedder.' ):
902+ new_state_dict ['model.embed_vision.' + k [len ('model.embed_vision.multimodal_embedder.' ):]] = v
903+ elif k .startswith ('model.embed_vision.' ):
904+ new_state_dict ['model.vision_embedder.' + k [len ('model.embed_vision.' ):]] = v
905+ else :
906+ new_state_dict [k ] = v
907+ res = new_state_dict
908+ return res
909+
910+
911+ register_model (
912+ ModelMeta (
913+ ModelType .gemma4_unified ,
914+ ['gemma4_unified' ],
915+ bridge_cls = Gemma4UnifiedBridge ,
916+ visual_cls = Gemma4UnifiedVit ,
917+ loader = Gemma4Loader ,
918+ ))
0 commit comments