diff --git a/requirements.txt b/requirements.txt index d947546..a664f53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ modelscope peft>=0.11,<0.20 safetensors tqdm -transformers>=4.33,<5.10.0 +transformers>=4.33,<5.11.0 diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index f3111ff..b77ee4c 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -1677,7 +1677,7 @@ def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore else: hf_state_dict = {} self._set_word_embeddings(mg_model, hf_state_dict, to_mcore) - if self.is_multimodal: + if self.is_multimodal and not self.config.language_model_only: for prefix, mg_prefix in self.module_mapping.items(): mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}') hf_state_dict.update(self._set_module(mg_module, hf_state_dict, f'{hf_prefix}{prefix}.', to_mcore)) diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 79f0aee..257f16b 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -222,6 +222,7 @@ class ModelConfig(TransformerConfig): mtp_shared_weights: bool = False # visual + language_model_only: bool = False hf_config: Optional[PretrainedConfig] = None vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2' @@ -343,7 +344,7 @@ def __post_init__(self): self.mcore_model_type = get_mcore_model_type(self.hf_model_type) self.model_meta = get_model_meta(self.mcore_model_type) self.is_multimodal = self.model_meta.visual_cls is not None - if self.is_multimodal: + if self.is_multimodal and not self.language_model_only: self.test_mm_type = getattr(self.model_meta.visual_cls, 'test_mm_type', 'image') else: self.test_mm_type = 'text' diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py index 2a80c04..b535bfb 100644 --- a/src/mcore_bridge/model/mm_gpt_model.py +++ b/src/mcore_bridge/model/mm_gpt_model.py @@ -47,7 +47,10 @@ def forward(_self, input_): _self.reduce_scatter_embeddings = reduce_scatter_embeddings packed_seq_params = kwargs.get('packed_seq_params') if self.visual is not None: - res = self.visual.get_inputs_embeds(res, **kwargs) + if self.config.language_model_only: + res = self.visual.get_inputs_embeds_language_model(res, **kwargs) + else: + res = self.visual.get_inputs_embeds(res, **kwargs) kwargs.clear() if isinstance(res, dict): # compat dict diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 5c6bd91..94b6e12 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -74,14 +74,21 @@ def prepare_model(self, hf_config: PretrainedConfig): self.embed_audio = ( Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype) if hf_config.audio_config is not None else None) - self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False) + self.prepare_language_model(hf_config) self.model_cls = Gemma4Model - def get_inputs_embeds(self, inputs_embeds, **kwargs): - input_ids = kwargs.get('input_ids') + def prepare_language_model(self, hf_config: PretrainedConfig): + self.register_buffer( + 'embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(hf_config.torch_dtype), persistent=False) + + def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs): inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype) + return {'inputs_embeds': inputs_embeds, 'llm_input_ids': kwargs.get('input_ids')} + def get_inputs_embeds(self, inputs_embeds, **kwargs): hf_config = self.hf_config + res = self.get_inputs_embeds_language_model(inputs_embeds, **kwargs) + pixel_values = kwargs.get('pixel_values') pixel_values_videos = kwargs.get('pixel_values_videos') input_features = kwargs.get('input_features') @@ -89,6 +96,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): image_position_ids = kwargs.get('image_position_ids') video_position_ids = kwargs.get('video_position_ids') + input_ids = kwargs.get('input_ids') image_mask = input_ids == hf_config.image_token_id video_mask = input_ids == hf_config.video_token_id audio_mask = input_ids == hf_config.audio_token_id @@ -96,6 +104,8 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): llm_input_ids = input_ids.clone() llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id + inputs_embeds = res['inputs_embeds'] + if pixel_values is not None: with self.patch_hf_config(): image_features = self.model_cls.get_image_features( @@ -121,7 +131,9 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) - return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids} + res['inputs_embeds'] = inputs_embeds + res['llm_input_ids'] = llm_input_ids + return res class Gemma4SelfAttention(SelfAttention): @@ -878,7 +890,7 @@ def prepare_model(self, hf_config: PretrainedConfig): self.embed_audio = ( Gemma4UnifiedMultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype) if hf_config.audio_config is not None else None) - self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False) + self.prepare_language_model(hf_config) self.model_cls = Gemma4UnifiedModel diff --git a/src/mcore_bridge/model/mm_gpts/utils.py b/src/mcore_bridge/model/mm_gpts/utils.py index f333ced..b610500 100644 --- a/src/mcore_bridge/model/mm_gpts/utils.py +++ b/src/mcore_bridge/model/mm_gpts/utils.py @@ -52,13 +52,20 @@ def __init__(self, config: ModelConfig): self.hf_config = hf_config self.prepare_attn_impl() with patch_get_dynamic_module(): - self.prepare_model(hf_config) + if config.language_model_only: + self.prepare_language_model(hf_config) + else: + self.prepare_model(hf_config) + self.to(device='cuda') @abstractmethod def prepare_model(self, hf_config: PretrainedConfig): pass + def prepare_language_model(self, hf_config: PretrainedConfig): + pass + def prepare_attn_impl(self): vit_attn_impl = self.config.vit_attn_impl or 'flash_attention_2' if self.config.attention_backend.name == 'flash': @@ -68,6 +75,9 @@ def prepare_attn_impl(self): def get_inputs_embeds(self, inputs_embeds, **kwargs): pass + def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs): + return inputs_embeds + @staticmethod def _get_vision_config(hf_config): for k in ['vision_config', 'vit_config']: