@@ -74,28 +74,38 @@ def prepare_model(self, hf_config: PretrainedConfig):
7474 self .embed_audio = (
7575 Gemma4MultimodalEmbedder (hf_config .audio_config , hf_config .text_config ).to (dtype )
7676 if hf_config .audio_config is not None else None )
77- self .register_buffer ( 'embed_scale' , torch . tensor ( hf_config . hidden_size ** 0.5 ). to ( dtype ), persistent = False )
77+ self .prepare_language_model ( hf_config )
7878 self .model_cls = Gemma4Model
7979
80- def get_inputs_embeds (self , inputs_embeds , ** kwargs ):
81- input_ids = kwargs .get ('input_ids' )
def prepare_language_model(self, hf_config: PretrainedConfig):
    """Register the ``embed_scale`` buffer used to scale token embeddings.

    The value is ``hf_config.hidden_size ** 0.5``, cast to
    ``hf_config.torch_dtype``. The buffer is registered as non-persistent.

    Args:
        hf_config: Configuration object providing ``hidden_size`` and
            ``torch_dtype``.
    """
    scale = torch.tensor(hf_config.hidden_size**0.5)
    scale = scale.to(hf_config.torch_dtype)
    self.register_buffer('embed_scale', scale, persistent=False)
83+
def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
    """Scale the input embeddings by ``self.embed_scale``.

    Args:
        inputs_embeds: Embedding tensor to scale.
        **kwargs: Only ``input_ids`` is read here; it is passed through
            unchanged (``None`` when absent).

    Returns:
        A dict with ``'inputs_embeds'`` (the scaled embeddings) and
        ``'llm_input_ids'`` (the ``input_ids`` keyword argument).
    """
    # Cast the scale to the embeddings' dtype before multiplying.
    scale = self.embed_scale.to(inputs_embeds.dtype)
    scaled_embeds = inputs_embeds * scale
    return {'inputs_embeds': scaled_embeds, 'llm_input_ids': kwargs.get('input_ids')}
8387
88+ def get_inputs_embeds (self , inputs_embeds , ** kwargs ):
8489 hf_config = self .hf_config
90+ res = self .get_inputs_embeds_language_model (inputs_embeds , ** kwargs )
91+
8592 pixel_values = kwargs .get ('pixel_values' )
8693 pixel_values_videos = kwargs .get ('pixel_values_videos' )
8794 input_features = kwargs .get ('input_features' )
8895 input_features_mask = kwargs .get ('input_features_mask' )
8996 image_position_ids = kwargs .get ('image_position_ids' )
9097 video_position_ids = kwargs .get ('video_position_ids' )
9198
99+ input_ids = kwargs .get ('input_ids' )
92100 image_mask = input_ids == hf_config .image_token_id
93101 video_mask = input_ids == hf_config .video_token_id
94102 audio_mask = input_ids == hf_config .audio_token_id
95103 multimodal_mask = image_mask | video_mask | audio_mask
96104 llm_input_ids = input_ids .clone ()
97105 llm_input_ids [multimodal_mask ] = hf_config .text_config .pad_token_id
98106
107+ inputs_embeds = res ['inputs_embeds' ]
108+
99109 if pixel_values is not None :
100110 with self .patch_hf_config ():
101111 image_features = self .model_cls .get_image_features (
@@ -121,7 +131,9 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
121131 audio_features = audio_features .to (inputs_embeds .device , inputs_embeds .dtype )
122132 audio_mask_e = audio_mask .unsqueeze (- 1 ).expand_as (inputs_embeds ).to (inputs_embeds .device )
123133 inputs_embeds = inputs_embeds .masked_scatter (audio_mask_e , audio_features )
124- return {'inputs_embeds' : inputs_embeds , 'llm_input_ids' : llm_input_ids }
134+ res ['inputs_embeds' ] = inputs_embeds
135+ res ['llm_input_ids' ] = llm_input_ids
136+ return res
125137
126138
127139class Gemma4SelfAttention (SelfAttention ):
@@ -878,7 +890,7 @@ def prepare_model(self, hf_config: PretrainedConfig):
878890 self .embed_audio = (
879891 Gemma4UnifiedMultimodalEmbedder (hf_config .audio_config , hf_config .text_config ).to (dtype )
880892 if hf_config .audio_config is not None else None )
881- self .register_buffer ( 'embed_scale' , torch . tensor ( hf_config . hidden_size ** 0.5 ). to ( dtype ), persistent = False )
893+ self .prepare_language_model ( hf_config )
882894 self .model_cls = Gemma4UnifiedModel
883895
884896
0 commit comments