Skip to content

Commit 8fd0cdb

Browse files
committed
fix gemma3
1 parent 5e5bdfc commit 8fd0cdb

File tree

5 files changed

+31
-26
lines changed

5 files changed

+31
-26
lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4136,8 +4136,6 @@ def __init__(
41364136
@register_in_tasks_manager("gemma3", *["image-text-to-text"], library_name="transformers")
41374137
class Gemma3OpenVINOConfig(BaseVLMOpenVINOConfig):
41384138
MIN_TRANSFORMERS_VERSION = "4.50.0"
4139-
# TODO (@echarlaix): add v5 support
4140-
MAX_TRANSFORMERS_VERSION = "4.57.6"
41414139

41424140
def __init__(
41434141
self,

optimum/exporters/openvino/model_patcher.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4657,22 +4657,29 @@ def __init__(
46574657
model: "PreTrainedModel",
46584658
model_kwargs: Dict[str, Any],
46594659
):
4660-
model.__orig_forward = model.forward
4661-
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
4662-
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
4663-
if (
4664-
hasattr(model, "model")
4665-
and hasattr(model.model, "get_image_features")
4666-
and is_transformers_version("<", "5")
4667-
):
4668-
model.forward = model.model.get_image_features
4669-
else:
4670-
model.forward = model.get_image_features
46714660
super().__init__(config, model, model_kwargs)
46724661

4673-
def __exit__(self, exc_type, exc_value, traceback):
4674-
super().__exit__(exc_type, exc_value, traceback)
4675-
self._model.forward = self._model.__orig_forward
4662+
@functools.wraps(self.orig_forward)
4663+
def patched_forward(*args, **kwargs):
4664+
# Adapted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
4665+
# Adapted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
4666+
if (
4667+
hasattr(self._model, "model")
4668+
and hasattr(self._model.model, "get_image_features")
4669+
and is_transformers_version("<", "5")
4670+
):
4671+
get_image_features = self._model.model.get_image_features
4672+
else:
4673+
get_image_features = self._model.get_image_features
4674+
4675+
outputs = get_image_features(*args, **kwargs)
4676+
4677+
if is_transformers_version(">=", "5"):
4678+
outputs = BaseModelOutputWithPooling(pooler_output=outputs.pooler_output)
4679+
4680+
return outputs
4681+
4682+
self.patched_forward = patched_forward
46764683

46774684

46784685
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147

tests/openvino/test_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
115115
if is_transformers_version(">", "4.47"):
116116
SUPPORTED_ARCHITECTURES += ("olmo2",)
117117

118-
if is_transformers_version(">", "4.49"):
118+
if is_transformers_version(">=", "4.50"):
119119
SUPPORTED_ARCHITECTURES += ("gemma3_text",)
120120

121121
if is_transformers_version(">=", "4.51.0"):

tests/openvino/test_genai.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class LLMPipelineTestCase(unittest.TestCase):
7676
SUPPORTED_ARCHITECTURES += ("qwen",)
7777
if is_transformers_version("<", "5"):
7878
SUPPORTED_ARCHITECTURES += ("phimoe",)
79-
if is_transformers_version(">=", "4.49") and is_transformers_version("<", "5"):
79+
if is_transformers_version(">=", "4.50"):
8080
SUPPORTED_ARCHITECTURES += ("gemma3_text",)
8181
if is_transformers_version(">=", "4.51.0"):
8282
SUPPORTED_ARCHITECTURES += ("qwen3", "qwen3_moe")
@@ -224,8 +224,7 @@ class VLMPipelineTestCase(unittest.TestCase):
224224
SUPPORTED_ARCHITECTURES += ("qwen2_5_vl",)
225225
if is_transformers_version("<", "4.54.0"):
226226
SUPPORTED_ARCHITECTURES += ("phi4mm",)
227-
# TODO: add fix for v5 and update MAX_TRANSFORMERS_VERSION accordingly
228-
if is_transformers_version(">=", "4.49") and is_transformers_version("<", "5"):
227+
if is_transformers_version(">=", "4.50"):
229228
SUPPORTED_ARCHITECTURES += ("gemma3",)
230229
if is_transformers_version("<", "5"):
231230
SUPPORTED_ARCHITECTURES += ("llava", "llava_next_video")

tests/openvino/test_seq2seq.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -581,9 +581,11 @@ class OVModelForVisualCausalLMIntegrationTest(OVSeq2SeqTestMixin):
581581
SUPPORTED_ARCHITECTURES += ["phi4mm"]
582582
SUPPORT_AUDIO.append("phi4mm")
583583

584-
# TODO: add fix for v5 and update MAX_TRANSFORMERS_VERSION accordingly
585-
if is_transformers_version(">", "4.49") and is_transformers_version("<", "5"):
586-
SUPPORTED_ARCHITECTURES += ["gemma3", "smolvlm"]
584+
if is_transformers_version(">=", "4.50"):
585+
SUPPORTED_ARCHITECTURES += ["gemma3"]
586+
# TODO: add fix for v5 and update MAX_TRANSFORMERS_VERSION accordingly
587+
if is_transformers_version("<", "5"):
588+
SUPPORTED_ARCHITECTURES += ["smolvlm"]
587589

588590
# TODO: add fix for v5 and update MAX_TRANSFORMERS_VERSION accordingly
589591
if is_transformers_version(">=", "4.51") and is_transformers_version("<", "5"):
@@ -614,7 +616,6 @@ class OVModelForVisualCausalLMIntegrationTest(OVSeq2SeqTestMixin):
614616
"llama4",
615617
"llava_next_video",
616618
"phi4_multimodal",
617-
"gemma3",
618619
"smolvlm",
619620
}
620621
REMOTE_CODE_MODELS = ["internvl_chat", "minicpmv", "minicpmo", "llava-qwen2", "phi3_v", "maira2", "phi4mm"]
@@ -783,9 +784,9 @@ def test_compare_to_transformers(self, model_arch):
783784
set_seed(SEED)
784785

785786
additional_inputs = {}
786-
# gemma3 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache,
787+
# gemma3 does not support dynamic cache until v4.53, we cannot compare dynamic cache result vs hybrid cache,
787788
# align cache representation in torch model
788-
if model_arch == "gemma3":
789+
if model_arch == "gemma3" and is_transformers_version("<", "4.53.0"):
789790
patch_update_causal_mask(
790791
transformers_model if is_transformers_version("<", "4.52.0") else transformers_model.language_model,
791792
"4.43.0",

0 commit comments

Comments (0)