Skip to content

Commit 47e1630

Browse files
authored
support language_model_only (#112)
1 parent e167b3a commit 47e1630

6 files changed

Lines changed: 36 additions & 10 deletions

File tree

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,4 +3,4 @@ modelscope
33
peft>=0.11,<0.20
44
safetensors
55
tqdm
6-
transformers>=4.33,<5.10.0
6+
transformers>=4.33,<5.11.0

src/mcore_bridge/bridge/gpt_bridge.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1677,7 +1677,7 @@ def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore
16771677
else:
16781678
hf_state_dict = {}
16791679
self._set_word_embeddings(mg_model, hf_state_dict, to_mcore)
1680-
if self.is_multimodal:
1680+
if self.is_multimodal and not self.config.language_model_only:
16811681
for prefix, mg_prefix in self.module_mapping.items():
16821682
mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}')
16831683
hf_state_dict.update(self._set_module(mg_module, hf_state_dict, f'{hf_prefix}{prefix}.', to_mcore))

src/mcore_bridge/config/model_config.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -222,6 +222,7 @@ class ModelConfig(TransformerConfig):
222222
mtp_shared_weights: bool = False
223223

224224
# visual
225+
language_model_only: bool = False
225226
hf_config: Optional[PretrainedConfig] = None
226227
vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2'
227228

@@ -343,7 +344,7 @@ def __post_init__(self):
343344
self.mcore_model_type = get_mcore_model_type(self.hf_model_type)
344345
self.model_meta = get_model_meta(self.mcore_model_type)
345346
self.is_multimodal = self.model_meta.visual_cls is not None
346-
if self.is_multimodal:
347+
if self.is_multimodal and not self.language_model_only:
347348
self.test_mm_type = getattr(self.model_meta.visual_cls, 'test_mm_type', 'image')
348349
else:
349350
self.test_mm_type = 'text'

src/mcore_bridge/model/mm_gpt_model.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,10 @@ def forward(_self, input_):
4747
_self.reduce_scatter_embeddings = reduce_scatter_embeddings
4848
packed_seq_params = kwargs.get('packed_seq_params')
4949
if self.visual is not None:
50-
res = self.visual.get_inputs_embeds(res, **kwargs)
50+
if self.config.language_model_only:
51+
res = self.visual.get_inputs_embeds_language_model(res, **kwargs)
52+
else:
53+
res = self.visual.get_inputs_embeds(res, **kwargs)
5154
kwargs.clear()
5255
if isinstance(res, dict):
5356
# compat dict

src/mcore_bridge/model/mm_gpts/gemma4.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -74,28 +74,38 @@ def prepare_model(self, hf_config: PretrainedConfig):
7474
self.embed_audio = (
7575
Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
7676
if hf_config.audio_config is not None else None)
77-
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
77+
self.prepare_language_model(hf_config)
7878
self.model_cls = Gemma4Model
7979

80-
def get_inputs_embeds(self, inputs_embeds, **kwargs):
81-
input_ids = kwargs.get('input_ids')
80+
def prepare_language_model(self, hf_config: PretrainedConfig):
81+
self.register_buffer(
82+
'embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(hf_config.torch_dtype), persistent=False)
83+
84+
def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
8285
inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype)
86+
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': kwargs.get('input_ids')}
8387

88+
def get_inputs_embeds(self, inputs_embeds, **kwargs):
8489
hf_config = self.hf_config
90+
res = self.get_inputs_embeds_language_model(inputs_embeds, **kwargs)
91+
8592
pixel_values = kwargs.get('pixel_values')
8693
pixel_values_videos = kwargs.get('pixel_values_videos')
8794
input_features = kwargs.get('input_features')
8895
input_features_mask = kwargs.get('input_features_mask')
8996
image_position_ids = kwargs.get('image_position_ids')
9097
video_position_ids = kwargs.get('video_position_ids')
9198

99+
input_ids = kwargs.get('input_ids')
92100
image_mask = input_ids == hf_config.image_token_id
93101
video_mask = input_ids == hf_config.video_token_id
94102
audio_mask = input_ids == hf_config.audio_token_id
95103
multimodal_mask = image_mask | video_mask | audio_mask
96104
llm_input_ids = input_ids.clone()
97105
llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id
98106

107+
inputs_embeds = res['inputs_embeds']
108+
99109
if pixel_values is not None:
100110
with self.patch_hf_config():
101111
image_features = self.model_cls.get_image_features(
@@ -121,7 +131,9 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
121131
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
122132
audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
123133
inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features)
124-
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids}
134+
res['inputs_embeds'] = inputs_embeds
135+
res['llm_input_ids'] = llm_input_ids
136+
return res
125137

126138

127139
class Gemma4SelfAttention(SelfAttention):
@@ -878,7 +890,7 @@ def prepare_model(self, hf_config: PretrainedConfig):
878890
self.embed_audio = (
879891
Gemma4UnifiedMultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
880892
if hf_config.audio_config is not None else None)
881-
self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
893+
self.prepare_language_model(hf_config)
882894
self.model_cls = Gemma4UnifiedModel
883895

884896

src/mcore_bridge/model/mm_gpts/utils.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,13 +52,20 @@ def __init__(self, config: ModelConfig):
5252
self.hf_config = hf_config
5353
self.prepare_attn_impl()
5454
with patch_get_dynamic_module():
55-
self.prepare_model(hf_config)
55+
if config.language_model_only:
56+
self.prepare_language_model(hf_config)
57+
else:
58+
self.prepare_model(hf_config)
59+
5660
self.to(device='cuda')
5761

5862
@abstractmethod
5963
def prepare_model(self, hf_config: PretrainedConfig):
6064
pass
6165

66+
def prepare_language_model(self, hf_config: PretrainedConfig):
67+
pass
68+
6269
def prepare_attn_impl(self):
6370
vit_attn_impl = self.config.vit_attn_impl or 'flash_attention_2'
6471
if self.config.attention_backend.name == 'flash':
@@ -68,6 +75,9 @@ def prepare_attn_impl(self):
6875
def get_inputs_embeds(self, inputs_embeds, **kwargs):
6976
pass
7077

78+
def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
79+
return inputs_embeds
80+
7181
@staticmethod
7282
def _get_vision_config(hf_config):
7383
for k in ['vision_config', 'vit_config']:

0 commit comments

Comments
 (0)