Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ modelscope
peft>=0.11,<0.20
safetensors
tqdm
transformers>=4.33,<5.10.0
transformers>=4.33,<5.11.0
2 changes: 1 addition & 1 deletion src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,7 @@ def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore
else:
hf_state_dict = {}
self._set_word_embeddings(mg_model, hf_state_dict, to_mcore)
if self.is_multimodal:
if self.is_multimodal and not self.config.language_model_only:
for prefix, mg_prefix in self.module_mapping.items():
mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}')
hf_state_dict.update(self._set_module(mg_module, hf_state_dict, f'{hf_prefix}{prefix}.', to_mcore))
Expand Down
3 changes: 2 additions & 1 deletion src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class ModelConfig(TransformerConfig):
mtp_shared_weights: bool = False

# visual
language_model_only: bool = False
hf_config: Optional[PretrainedConfig] = None
vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2'

Expand Down Expand Up @@ -343,7 +344,7 @@ def __post_init__(self):
self.mcore_model_type = get_mcore_model_type(self.hf_model_type)
self.model_meta = get_model_meta(self.mcore_model_type)
self.is_multimodal = self.model_meta.visual_cls is not None
if self.is_multimodal:
if self.is_multimodal and not self.language_model_only:
self.test_mm_type = getattr(self.model_meta.visual_cls, 'test_mm_type', 'image')
else:
self.test_mm_type = 'text'
Expand Down
15 changes: 9 additions & 6 deletions src/mcore_bridge/model/mm_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ def forward(_self, input_):
packed_seq_params = kwargs.get('packed_seq_params')
if self.visual is not None:
res = self.visual.get_inputs_embeds(res, **kwargs)
kwargs.clear()
if isinstance(res, dict):
# compat dict
inputs_embeds = res.pop('inputs_embeds')
kwargs.update(res)
res = inputs_embeds
else:
assert self.config.language_model_only
res = self.visual.get_inputs_embeds_language_model(res, **kwargs)
Comment thread
Jintao-Huang marked this conversation as resolved.
Outdated
kwargs.clear()
if isinstance(res, dict):
# compat dict
inputs_embeds = res.pop('inputs_embeds')
kwargs.update(res)
res = inputs_embeds
if self.config.context_parallel_size > 1:
res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
if reduce_scatter_embeddings:
Expand Down
36 changes: 26 additions & 10 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,42 @@ def prepare_model(self, hf_config: PretrainedConfig):
self.embed_audio = (
Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
if hf_config.audio_config is not None else None)
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
self.prepare_language_model(hf_config)
self.model_cls = Gemma4Model

def get_inputs_embeds(self, inputs_embeds, **kwargs):
def prepare_language_model(self, hf_config: PretrainedConfig):
self.register_buffer(
'embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(hf_config.torch_dtype), persistent=False)

def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
input_ids = kwargs.get('input_ids')
hf_config = self.hf_config
inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype)
Comment thread
Jintao-Huang marked this conversation as resolved.

image_mask = input_ids == hf_config.image_token_id
video_mask = input_ids == hf_config.video_token_id
audio_mask = input_ids == hf_config.audio_token_id
multimodal_mask = image_mask | video_mask | audio_mask
llm_input_ids = input_ids.clone()
llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids}

def get_inputs_embeds(self, inputs_embeds, **kwargs):
hf_config = self.hf_config
res = self.get_inputs_embeds_language_model(inputs_embeds, **kwargs)
input_ids = kwargs.get('input_ids')
image_mask = input_ids == hf_config.image_token_id
video_mask = input_ids == hf_config.video_token_id
audio_mask = input_ids == hf_config.audio_token_id

pixel_values = kwargs.get('pixel_values')
pixel_values_videos = kwargs.get('pixel_values_videos')
input_features = kwargs.get('input_features')
input_features_mask = kwargs.get('input_features_mask')
image_position_ids = kwargs.get('image_position_ids')
video_position_ids = kwargs.get('video_position_ids')

image_mask = input_ids == hf_config.image_token_id
video_mask = input_ids == hf_config.video_token_id
audio_mask = input_ids == hf_config.audio_token_id
multimodal_mask = image_mask | video_mask | audio_mask
llm_input_ids = input_ids.clone()
llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id
inputs_embeds = res['inputs_embeds']

if pixel_values is not None:
with self.patch_hf_config():
Expand All @@ -121,7 +136,8 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features)
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids}
res['inputs_embeds'] = inputs_embeds
return res


class Gemma4SelfAttention(SelfAttention):
Expand Down Expand Up @@ -878,7 +894,7 @@ def prepare_model(self, hf_config: PretrainedConfig):
self.embed_audio = (
Gemma4UnifiedMultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
if hf_config.audio_config is not None else None)
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
self.prepare_language_model(hf_config)
self.model_cls = Gemma4UnifiedModel


Expand Down
12 changes: 11 additions & 1 deletion src/mcore_bridge/model/mm_gpts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,20 @@ def __init__(self, config: ModelConfig):
self.hf_config = hf_config
self.prepare_attn_impl()
with patch_get_dynamic_module():
self.prepare_model(hf_config)
if config.language_model_only:
self.prepare_language_model(hf_config)
else:
self.prepare_model(hf_config)

self.to(device='cuda')

@abstractmethod
def prepare_model(self, hf_config: PretrainedConfig):
pass

def prepare_language_model(self, hf_config: PretrainedConfig):
pass

def prepare_attn_impl(self):
vit_attn_impl = self.config.vit_attn_impl or 'flash_attention_2'
if self.config.attention_backend.name == 'flash':
Expand All @@ -68,6 +75,9 @@ def prepare_attn_impl(self):
def get_inputs_embeds(self, inputs_embeds, **kwargs):
pass

def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
return inputs_embeds

@staticmethod
def _get_vision_config(hf_config):
for k in ['vision_config', 'vit_config']:
Expand Down
Loading