Skip to content

Commit 50eb37a

Browse files
committed
fix
1 parent 775697a commit 50eb37a

1 file changed

Lines changed: 1 addition & 10 deletions

File tree

src/mcore_bridge/model/mm_gpts/gemma4.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,8 @@ def prepare_language_model(self, hf_config: PretrainedConfig):
8282
'embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(hf_config.torch_dtype), persistent=False)
8383

8484
def get_inputs_embeds_language_model(self, inputs_embeds, **kwargs):
85-
input_ids = kwargs.get('input_ids')
86-
hf_config = self.hf_config
8785
inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype)
88-
89-
image_mask = input_ids == hf_config.image_token_id
90-
video_mask = input_ids == hf_config.video_token_id
91-
audio_mask = input_ids == hf_config.audio_token_id
92-
multimodal_mask = image_mask | video_mask | audio_mask
93-
llm_input_ids = input_ids.clone()
94-
llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id
95-
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids}
86+
return {'inputs_embeds': inputs_embeds, 'llm_input_ids': kwargs.get('input_ids')}
9687

9788
def get_inputs_embeds(self, inputs_embeds, **kwargs):
9889
hf_config = self.hf_config

0 commit comments

Comments
 (0)