Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ modelscope
peft>=0.11,<0.20
safetensors
tqdm
transformers>=4.33,<5.10.0
transformers>=4.33,<5.11.0
2 changes: 1 addition & 1 deletion src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,7 @@ def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore
else:
hf_state_dict = {}
self._set_word_embeddings(mg_model, hf_state_dict, to_mcore)
if self.is_multimodal:
if self.is_multimodal and not self.config.language_model_only:
for prefix, mg_prefix in self.module_mapping.items():
mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}')
hf_state_dict.update(self._set_module(mg_module, hf_state_dict, f'{hf_prefix}{prefix}.', to_mcore))
Expand Down
3 changes: 2 additions & 1 deletion src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class ModelConfig(TransformerConfig):
mtp_shared_weights: bool = False

# visual
language_model_only: bool = False
hf_config: Optional[PretrainedConfig] = None
vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2'

Expand Down Expand Up @@ -343,7 +344,7 @@ def __post_init__(self):
self.mcore_model_type = get_mcore_model_type(self.hf_model_type)
self.model_meta = get_model_meta(self.mcore_model_type)
self.is_multimodal = self.model_meta.visual_cls is not None
if self.is_multimodal:
if self.is_multimodal and not self.language_model_only:
self.test_mm_type = getattr(self.model_meta.visual_cls, 'test_mm_type', 'image')
else:
self.test_mm_type = 'text'
Expand Down
5 changes: 4 additions & 1 deletion src/mcore_bridge/model/mm_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def forward(_self, input_):
_self.reduce_scatter_embeddings = reduce_scatter_embeddings
packed_seq_params = kwargs.get('packed_seq_params')
if self.visual is not None:
res = self.visual.get_inputs_embeds(res, **kwargs)
if self.config.language_model_only:
res = self.visual.get_inputs_embeds_language_model(res, **kwargs)
else:
res = self.visual.get_inputs_embeds(res, **kwargs)
kwargs.clear()
if isinstance(res, dict):
# compat dict
Expand Down
22 changes: 17 additions & 5 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,38 @@ def prepare_model(self, hf_config: PretrainedConfig):
self.embed_audio = (
Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
if hf_config.audio_config is not None else None)
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
self.prepare_language_model(hf_config)
self.model_cls = Gemma4Model

def get_inputs_embeds(self, inputs_embeds, **kwargs):
input_ids = kwargs.get('input_ids')
def prepare_language_model(self, hf_config: PretrainedConfig):
self.register_buffer(
'embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(hf_config.torch_dtype), persistent=False)

def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype)
Comment thread
Jintao-Huang marked this conversation as resolved.
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': kwargs.get('input_ids')}

def get_inputs_embeds(self, inputs_embeds, **kwargs):
hf_config = self.hf_config
res = self.get_inputs_embeds_language_model(inputs_embeds, **kwargs)

pixel_values = kwargs.get('pixel_values')
pixel_values_videos = kwargs.get('pixel_values_videos')
input_features = kwargs.get('input_features')
input_features_mask = kwargs.get('input_features_mask')
image_position_ids = kwargs.get('image_position_ids')
video_position_ids = kwargs.get('video_position_ids')

input_ids = kwargs.get('input_ids')
image_mask = input_ids == hf_config.image_token_id
video_mask = input_ids == hf_config.video_token_id
audio_mask = input_ids == hf_config.audio_token_id
multimodal_mask = image_mask | video_mask | audio_mask
llm_input_ids = input_ids.clone()
llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id

inputs_embeds = res['inputs_embeds']

if pixel_values is not None:
with self.patch_hf_config():
image_features = self.model_cls.get_image_features(
Expand All @@ -121,7 +131,9 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features)
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids}
res['inputs_embeds'] = inputs_embeds
res['llm_input_ids'] = llm_input_ids
return res


class Gemma4SelfAttention(SelfAttention):
Expand Down Expand Up @@ -878,7 +890,7 @@ def prepare_model(self, hf_config: PretrainedConfig):
self.embed_audio = (
Gemma4UnifiedMultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
if hf_config.audio_config is not None else None)
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
self.prepare_language_model(hf_config)
self.model_cls = Gemma4UnifiedModel


Expand Down
12 changes: 11 additions & 1 deletion src/mcore_bridge/model/mm_gpts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,20 @@ def __init__(self, config: ModelConfig):
self.hf_config = hf_config
self.prepare_attn_impl()
with patch_get_dynamic_module():
self.prepare_model(hf_config)
if config.language_model_only:
self.prepare_language_model(hf_config)
else:
self.prepare_model(hf_config)

self.to(device='cuda')

@abstractmethod
def prepare_model(self, hf_config: PretrainedConfig):
pass

def prepare_language_model(self, hf_config: PretrainedConfig):
pass

def prepare_attn_impl(self):
vit_attn_impl = self.config.vit_attn_impl or 'flash_attention_2'
if self.config.attention_backend.name == 'flash':
Expand All @@ -68,6 +75,9 @@ def prepare_attn_impl(self):
def get_inputs_embeds(self, inputs_embeds, **kwargs):
pass

def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
return inputs_embeds

@staticmethod
def _get_vision_config(hf_config):
for k in ['vision_config', 'vit_config']:
Expand Down
Loading