Skip to content

Commit 72352fe

Browse files
committed
fix
1 parent 0834563 commit 72352fe

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

src/mcore_bridge/model/mm_gpt_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ def forward(_self, input_):
4646
res = origin_forward(_self, input_)
4747
_self.reduce_scatter_embeddings = reduce_scatter_embeddings
4848
packed_seq_params = kwargs.get('packed_seq_params')
49-
if self.config.language_model_only:
50-
res = self.visual.get_inputs_embeds_language_model(res, **kwargs)
51-
else:
52-
res = self.visual.get_inputs_embeds(res, **kwargs)
49+
if self.visual is not None:
50+
if self.config.language_model_only:
51+
res = self.visual.get_inputs_embeds_language_model(res, **kwargs)
52+
else:
53+
res = self.visual.get_inputs_embeds(res, **kwargs)
5354
kwargs.clear()
5455
if isinstance(res, dict):
5556
# compat dict

src/mcore_bridge/model/mm_gpts/gemma4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
132132
audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
133133
inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features)
134134
res['inputs_embeds'] = inputs_embeds
135-
res['llm_input_ids'] = input_ids
135+
res['llm_input_ids'] = llm_input_ids
136136
return res
137137

138138

0 commit comments

Comments
 (0)