Skip to content

Commit 146f9e4

Browse files
committed
fix(gemma3): load Gemma3ForConditionalGeneration to restore .generate() and remove AutoModel fallback
1 parent 52fc5fe commit 146f9e4

1 file changed

Lines changed: 3 additions & 9 deletions

File tree

lmms_eval/models/simple/gemma3.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from loguru import logger as eval_logger
1010
from PIL import Image
1111
from tqdm import tqdm
12-
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
12+
from transformers import AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
1313

1414
from lmms_eval import utils
1515
from lmms_eval.api.instance import Instance
@@ -71,14 +71,8 @@ def __init__(
7171
if attn_implementation is not None:
7272
model_kwargs["attn_implementation"] = attn_implementation
7373

74-
# Try to load with AutoModelForVision2Seq which handles various vision-language models
75-
try:
76-
self._model = AutoModelForVision2Seq.from_pretrained(pretrained, **model_kwargs).eval()
77-
except Exception:
78-
# Fallback to a more generic approach if specific model class not found
79-
from transformers import AutoModel
80-
81-
self._model = AutoModel.from_pretrained(pretrained, **model_kwargs).eval()
74+
# Minimal, generation-capable loader: use the dedicated Gemma3 class
75+
self._model = Gemma3ForConditionalGeneration.from_pretrained(pretrained, **model_kwargs).eval()
8276
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
8377
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
8478

0 commit comments

Comments
 (0)