Skip to content

Commit 5b27c58

Browse files
authored
[model] support gemma4_unified (#108)
1 parent d62ca16 commit 5b27c58

7 files changed

Lines changed: 86 additions & 23 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ The following is the list of models supported by MCore-Bridge:
145145
| Series | model_type |
146146
| -------- | ------------------------------------------------------------ |
147147
| Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni<br />qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr<br />qwen3_5, qwen3_5_moe |
148-
| Gemma | gemma4 |
148+
| Gemma | gemma4, gemma4_unified |
149149
| GLM | glm4v, glm4v_moe |
150150
| Kimi | kimi_vl |
151151
| InternVL | internvl_chat, internvl |

README_zh.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ uv pip install -e . --torch-backend=auto
142142
| 系列 | model_type |
143143
| -------- | ------------------------------------------------------------ |
144144
| Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni<br />qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr<br />qwen3_5, qwen3_5_moe |
145-
| Gemma | gemma4 |
145+
| Gemma | gemma4, gemma4_unified |
146146
| GLM | glm4v, glm4v_moe |
147147
| Kimi | kimi_vl |
148148
| InternVL | internvl_chat, internvl |

src/mcore_bridge/config/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
190190
n_shared_experts = res.pop('n_shared_experts')
191191
elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}:
192192
res['rotary_interleaved'] = True
193-
elif hf_model_type in {'gemma4'}:
193+
elif hf_model_type in {'gemma4', 'gemma4_unified'}:
194194
res['qk_layernorm'] = True
195195
res['window_size'] = f'{window_size - 1},0'
196196
window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types])

src/mcore_bridge/model/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class MLLMModelType:
3232
kimi_vl = 'kimi_vl'
3333
llama4 = 'llama4'
3434
gemma4 = 'gemma4'
35+
gemma4_unified = 'gemma4_unified'
3536

3637
kimi_k25 = 'kimi_k25'
3738

src/mcore_bridge/model/gpts/deepseek_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
391391
k = k[:-len('.scale')] + '.weight_scale_inv'
392392
new_res[k] = v
393393
res = new_res
394-
elif not to_mcore:
394+
else:
395395
res = self._remove_prefix(res, 'model.')
396396
new_res = {}
397397
for k, v in res.items():

src/mcore_bridge/model/gpts/qwen3_emb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
1111
res = super()._convert_hf_state_dict(hf_state_dict, to_mcore)
1212
if to_mcore:
1313
res = self._add_prefix(res, 'model.')
14-
elif not to_mcore:
14+
else:
1515
res = self._remove_prefix(res, 'model.')
1616
return res
1717

src/mcore_bridge/model/mm_gpts/gemma4.py

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __init__(
166166
orig_k_layernorm = submodules.k_layernorm
167167
config.kv_channels = self.head_dim
168168
config.num_query_groups = self.num_key_value_heads
169-
if self.is_sliding and config.window_size is None:
169+
if text_config.use_bidirectional_attention == 'vision':
170170
kwargs['attn_mask_type'] = AttnMaskType.arbitrary
171171
if self.is_kv_shared_layer:
172172
submodules.k_layernorm = IdentityOp
@@ -291,10 +291,10 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu
291291
packed_seq_params = kwargs.get('packed_seq_params')
292292
attention_bias = kwargs.get('attention_bias')
293293
mixed_qkv, _ = self.linear_qkv(hidden_states)
294-
if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size:
294+
if getattr(self, 'world_size', None) is not None and self.num_key_value_heads < self.world_size:
295295
mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(mixed_qkv)
296-
idx = get_tensor_model_parallel_rank() // (self.world_size // self.config.num_query_groups)
297-
size = mixed_qkv.size()[-1] // self.config.num_query_groups
296+
idx = get_tensor_model_parallel_rank() // (self.world_size // self.num_key_value_heads)
297+
size = mixed_qkv.size()[-1] // self.num_key_value_heads
298298
mixed_qkv = mixed_qkv[:, :, idx * size:(idx + 1) * size]
299299

300300
thd_format = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -327,9 +327,9 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu
327327
value = value.squeeze(1)
328328
# Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
329329
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
330-
if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size:
331-
idx = get_tensor_model_parallel_rank() % (self.world_size // self.config.num_query_groups)
332-
size = query.shape[2] // (self.world_size // self.config.num_query_groups)
330+
if getattr(self, 'world_size', None) is not None and self.num_key_value_heads < self.world_size:
331+
idx = get_tensor_model_parallel_rank() % (self.world_size // self.num_key_value_heads)
332+
size = query.shape[2] // (self.world_size // self.num_key_value_heads)
333333
query = query[:, :, idx * size:(idx + 1) * size, :]
334334
query = self.q_layernorm(query)
335335
if isinstance(rotary_pos_emb, torch.Tensor):
@@ -609,23 +609,24 @@ def forward(self, *args, **kwargs):
609609
extra_block_kwargs['shared_kv_states'] = shared_kv_states
610610
kwargs['extra_block_kwargs'] = extra_block_kwargs
611611
attention_mask = kwargs.get('attention_mask')
612-
kwargs['attention_mask'] = {'sliding_attention': attention_mask, 'full_attention': attention_mask}
612+
attention_mask = {'sliding_attention': attention_mask, 'full_attention': attention_mask}
613613
if self.text_config.use_bidirectional_attention == 'vision':
614-
kwargs['attention_mask']['sliding_attention'] = self._create_sliding_attention_mask(
615-
attention_mask, mm_token_type_ids)
614+
self._update_attention_mask(attention_mask, mm_token_type_ids)
615+
kwargs['attention_mask'] = attention_mask
616616
hidden_states = super().forward(*args, **kwargs)
617617
if self.hidden_size_per_layer_input and not self.post_process:
618618
hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states)
619619
return hidden_states
620620

621-
def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids):
621+
def _update_attention_mask(self, attention_mask, mm_token_type_ids):
622+
sliding_attention = attention_mask['sliding_attention']
623+
full_attention = attention_mask['full_attention']
624+
# sliding
622625
window_size = self.text_config.sliding_window - 1
623-
seq_len = attention_mask.shape[-1]
624-
625-
window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)
626+
seq_len = sliding_attention.shape[-1]
627+
window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=sliding_attention.device)
626628
window_mask = ~torch.triu(window_mask, diagonal=-window_size)
627-
628-
attention_mask = attention_mask | window_mask
629+
sliding_attention = sliding_attention | window_mask
629630
if mm_token_type_ids is not None:
630631
is_vision = mm_token_type_ids > 0
631632
is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
@@ -635,8 +636,10 @@ def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids):
635636
q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1)
636637
k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2)
637638
same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0)
638-
attention_mask = attention_mask & ~same_vision_group
639-
return attention_mask
639+
sliding_attention = sliding_attention & ~same_vision_group
640+
full_attention = full_attention & ~same_vision_group
641+
attention_mask['sliding_attention'] = sliding_attention
642+
attention_mask['full_attention'] = full_attention
640643

641644
def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states):
642645
per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1)
@@ -854,3 +857,62 @@ def _replace_router(self, transformer_layer_spec, mlp_key='experts_mlp'):
854857
visual_cls=Gemma4Vit,
855858
loader=Gemma4Loader,
856859
))
860+
861+
862+
class Gemma4UnifiedVit(Gemma4Vit):
863+
module_mapping = {
864+
'model.embed_vision': 'embed_vision',
865+
'model.embed_audio': 'embed_audio',
866+
}
867+
_vision_tower = []
868+
869+
def prepare_model(self, hf_config: PretrainedConfig):
870+
from transformers.models.gemma4_unified.modeling_gemma4_unified import (Gemma4UnifiedModel,
871+
Gemma4UnifiedMultimodalEmbedder,
872+
Gemma4UnifiedVisionEmbedder)
873+
dtype = hf_config.torch_dtype
874+
self.embed_vision = (
875+
Gemma4UnifiedVisionEmbedder(hf_config.vision_config, hf_config.text_config).to(dtype)
876+
if hf_config.vision_config is not None else None)
877+
878+
self.embed_audio = (
879+
Gemma4UnifiedMultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
880+
if hf_config.audio_config is not None else None)
881+
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
882+
self.model_cls = Gemma4UnifiedModel
883+
884+
885+
class Gemma4UnifiedBridge(Gemma4Bridge):
886+
887+
def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
888+
res = super()._convert_hf_state_dict(hf_state_dict, to_mcore)
889+
new_state_dict = {}
890+
if to_mcore:
891+
for k, v in res.items():
892+
if k.startswith('model.embed_vision.embedding_projection.'):
893+
new_state_dict['model.embed_vision.multimodal_embedder.embedding_projection.'
894+
+ k[len('model.embed_vision.embedding_projection.'):]] = v
895+
elif k.startswith('model.vision_embedder.'):
896+
new_state_dict['model.embed_vision.' + k[len('model.vision_embedder.'):]] = v
897+
else:
898+
new_state_dict[k] = v
899+
else:
900+
for k, v in res.items():
901+
if k.startswith('model.embed_vision.multimodal_embedder.'):
902+
new_state_dict['model.embed_vision.' + k[len('model.embed_vision.multimodal_embedder.'):]] = v
903+
elif k.startswith('model.embed_vision.'):
904+
new_state_dict['model.vision_embedder.' + k[len('model.embed_vision.'):]] = v
905+
else:
906+
new_state_dict[k] = v
907+
res = new_state_dict
908+
return res
909+
910+
911+
register_model(
912+
ModelMeta(
913+
ModelType.gemma4_unified,
914+
['gemma4_unified'],
915+
bridge_cls=Gemma4UnifiedBridge,
916+
visual_cls=Gemma4UnifiedVit,
917+
loader=Gemma4Loader,
918+
))

0 commit comments

Comments
 (0)