Skip to content

Commit 0ec2fc9

Browse files
RadhaGulhane13Luodiancoderabbitai[bot]
authored
[Feature] Support for Gemma-3 Models (#821)
* Add Gemma3 model * Rebase and refactor Signed-off-by: Radha Gulhane <radha.gulhane@zoom.us> * format refactor * Add evaluation script * Apply suggestion from @coderabbitai[bot] Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * feat(gemma3): Enhance Gemma3 model with AutoModel support and video handling * fix(gemma3): handle import error for specific model class --------- Signed-off-by: Radha Gulhane <radha.gulhane@zoom.us> Co-authored-by: Li Bo <drluodian@gmail.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent a18ab60 commit 0ec2fc9

3 files changed

Lines changed: 358 additions & 0 deletions

File tree

examples/models/gemma3.sh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
# Run and exactly reproduce gemma3 results!
3+
# mme as an example
4+
5+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
6+
MAIN_PORT="${MAIN_PORT:-12345}"
7+
MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
8+
TASKS="${TASKS:-mmmu_val,ai2d,mathvista_testmini}"
9+
BATCH_SIZE="${BATCH_SIZE:-1}"
10+
OUTPUT_PATH="${OUTPUT_PATH:-./logs/}"
11+
12+
accelerate launch --num_processes "${NUM_PROCESSES}" --main_process_port "${MAIN_PORT}" -m lmms_eval \
13+
--model gemma3 \
14+
--model_args "pretrained=${MODEL_ID}" \
15+
--tasks "${TASKS}" \
16+
--batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}"

lmms_eval/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"fuyu": "Fuyu",
2626
"gemini_api": "GeminiAPI",
2727
"gpt4o_audio": "GPT4OAudio",
28+
"gemma3": "Gemma3",
2829
"gpt4v": "GPT4V",
2930
"idefics2": "Idefics2",
3031
"instructblip": "InstructBLIP",

lmms_eval/models/simple/gemma3.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
import base64
2+
import os
3+
import warnings
4+
from io import BytesIO
5+
from typing import Dict, List, Optional, Tuple, Union
6+
7+
import torch
8+
from accelerate import Accelerator, DistributedType
9+
from loguru import logger as eval_logger
10+
from PIL import Image
11+
from tqdm import tqdm
12+
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
13+
14+
from lmms_eval import utils
15+
from lmms_eval.api.instance import Instance
16+
from lmms_eval.api.model import lmms
17+
from lmms_eval.api.registry import register_model
18+
19+
warnings.simplefilter("ignore", category=DeprecationWarning)
20+
warnings.filterwarnings("ignore")
21+
22+
# Constants for default pixel values
23+
DEFAULT_MIN_PIXELS = 256 * 28 * 28
24+
DEFAULT_MAX_PIXELS = 1605632
25+
DEFAULT_MAX_FRAMES = 32
26+
27+
28+
@register_model("gemma3")
29+
class Gemma3(lmms):
30+
"""
31+
Gemma3 Model
32+
https://huggingface.co/google/gemma-3-27b-it
33+
"""
34+
35+
def __init__(
36+
self,
37+
pretrained: str = "google/gemma-3-27b-it",
38+
device: Optional[str] = "cuda",
39+
device_map: Optional[str] = "auto",
40+
batch_size: Optional[Union[int, str]] = 1,
41+
trust_remote_code: Optional[bool] = True,
42+
use_cache=True,
43+
attn_implementation: Optional[str] = None,
44+
min_pixels: int = DEFAULT_MIN_PIXELS,
45+
max_pixels: int = DEFAULT_MAX_PIXELS,
46+
max_num_frames: int = DEFAULT_MAX_FRAMES,
47+
interleave_visuals: Optional[bool] = False,
48+
system_prompt: Optional[str] = "You are a helpful assistant.",
49+
reasoning_prompt: Optional[str] = None,
50+
**kwargs,
51+
) -> None:
52+
super().__init__()
53+
# Do not use kwargs for now
54+
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
55+
56+
accelerator = Accelerator()
57+
if accelerator.num_processes > 1:
58+
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
59+
self.device_map = f"cuda:{accelerator.local_process_index}"
60+
else:
61+
self._device = torch.device(device)
62+
self.device_map = device_map if device_map else device
63+
64+
# Prepare model loading arguments
65+
model_kwargs = {
66+
"torch_dtype": torch.bfloat16,
67+
"device_map": self.device_map,
68+
}
69+
70+
# Add attention implementation if specified
71+
if attn_implementation is not None:
72+
model_kwargs["attn_implementation"] = attn_implementation
73+
74+
# Try to load with AutoModelForVision2Seq which handles various vision-language models
75+
try:
76+
self._model = AutoModelForVision2Seq.from_pretrained(pretrained, **model_kwargs).eval()
77+
except Exception:
78+
# Fallback to a more generic approach if specific model class not found
79+
from transformers import AutoModel
80+
81+
self._model = AutoModel.from_pretrained(pretrained, **model_kwargs).eval()
82+
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
83+
self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
84+
85+
self._config = self._model.config
86+
self._max_length = kwargs.get("max_length", 2048)
87+
self._model.tie_weights()
88+
self.batch_size_per_gpu = int(batch_size)
89+
self.use_cache = use_cache
90+
self.system_prompt = system_prompt
91+
self.interleave_visuals = interleave_visuals
92+
93+
self.max_pixels = max_pixels
94+
self.min_pixels = min_pixels
95+
self.max_num_frames = max_num_frames
96+
97+
if reasoning_prompt:
98+
self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
99+
else:
100+
self.reasoning_prompt = None
101+
102+
if accelerator.num_processes > 1:
103+
assert accelerator.distributed_type in [
104+
DistributedType.FSDP,
105+
DistributedType.MULTI_GPU,
106+
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
107+
if accelerator.distributed_type == DistributedType.FSDP:
108+
self._model = accelerator.prepare(self.model)
109+
else:
110+
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
111+
self.accelerator = accelerator
112+
if self.accelerator.is_local_main_process:
113+
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
114+
self._rank = self.accelerator.local_process_index
115+
self._world_size = self.accelerator.num_processes
116+
else:
117+
self.model.to(self._device)
118+
self._rank = 0
119+
self._world_size = 1
120+
self.model.eval()
121+
122+
@property
123+
def config(self):
124+
# return the associated transformers.AutoConfig for the given pretrained model.
125+
return self._config
126+
127+
@property
128+
def tokenizer(self):
129+
return self._tokenizer
130+
131+
@property
132+
def model(self):
133+
# returns the model, unwrapping it if using Accelerate
134+
if hasattr(self, "accelerator"):
135+
return self.accelerator.unwrap_model(self._model)
136+
else:
137+
return self._model
138+
139+
@property
140+
def eot_token_id(self):
141+
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
142+
# return self.tokenizer.eod_id
143+
return self.tokenizer.eos_token_id
144+
145+
@property
146+
def max_length(self):
147+
return self._max_length
148+
149+
@property
150+
def batch_size(self):
151+
return self.batch_size_per_gpu
152+
153+
@property
154+
def device(self):
155+
return self._device
156+
157+
@property
158+
def rank(self):
159+
return self._rank
160+
161+
@property
162+
def world_size(self):
163+
return self._world_size
164+
165+
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
166+
raise NotImplementedError("Not implemented for Gemma3.")
167+
168+
def flatten(self, input: List[List]) -> List:
169+
"""Flatten a nested list into a single list.
170+
171+
Args:
172+
input: A nested list structure
173+
174+
Returns:
175+
A flattened single-level list
176+
"""
177+
new_list = []
178+
for i in input:
179+
for j in i:
180+
new_list.append(j)
181+
return new_list
182+
183+
def generate_until(self, requests: List[Instance]) -> List[str]:
184+
"""Generate text completions for given requests.
185+
186+
Args:
187+
requests: List of Instance objects containing generation requests
188+
189+
Returns:
190+
List of generated text responses
191+
"""
192+
res = []
193+
194+
def _collate(x):
195+
# the negative sign on len(toks) sorts descending - this has a few advantages:
196+
# - time estimates will always be over not underestimates, which is more useful for planning
197+
# - to know the size of a batch when going through the list, you know the first one is always the batch
198+
# padded context length. this is useful to simplify the batching logic and more importantly to make
199+
# automatic adaptive batches much much easier to implement
200+
# - any OOMs will happen right away rather than near the end
201+
toks = self.tokenizer.encode(x[0])
202+
return -len(toks), x[0]
203+
204+
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
205+
# we group requests by their generation_kwargs,
206+
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
207+
# in the same batch.
208+
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
209+
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
210+
for chunk in chunks:
211+
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
212+
task = task[0]
213+
split = split[0]
214+
visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
215+
gen_kwargs = all_gen_kwargs[0]
216+
217+
# Set default until or update values from gen_kwargs if present
218+
until = gen_kwargs.get("until", [self.tokenizer.decode(self.eot_token_id)])
219+
220+
if isinstance(until, str):
221+
until = [until]
222+
elif not isinstance(until, list):
223+
raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}")
224+
225+
# Avoid using '\n\n' as a stopper to prevent truncation, which can lead to incorrect results
226+
until = [item for item in until if item != "\n\n"]
227+
228+
if isinstance(contexts, tuple):
229+
contexts = list(contexts)
230+
231+
for i in range(len(contexts)):
232+
if "<image>" in contexts[i]:
233+
contexts[i] = contexts[i].replace("<image>", "")
234+
235+
batched_messages = []
236+
for i, context in enumerate(contexts):
237+
if "<image>" in context:
238+
context = context.replace("<image>", "")
239+
240+
message = [{"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}]
241+
242+
if self.reasoning_prompt:
243+
context = context.strip() + self.reasoning_prompt
244+
contexts[i] = context
245+
246+
processed_visuals = []
247+
for visual in visual_list[i]:
248+
try:
249+
if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file
250+
if not os.path.exists(visual):
251+
eval_logger.warning(f"Video file not found: {visual}")
252+
continue
253+
processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
254+
elif isinstance(visual, Image.Image): # Handle both single and multiple images
255+
base64_image = visual.convert("RGB")
256+
buffer = BytesIO()
257+
base64_image.save(buffer, format="JPEG")
258+
base64_bytes = base64.b64encode(buffer.getvalue())
259+
base64_string = base64_bytes.decode("utf-8")
260+
processed_visuals.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}", "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
261+
except Exception as e:
262+
eval_logger.error(f"Failed to process visual: {e}")
263+
continue
264+
265+
message.append(
266+
{
267+
"role": "user",
268+
"content": processed_visuals + [{"type": "text", "text": context}],
269+
}
270+
)
271+
272+
batched_messages.append(message)
273+
274+
inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to(
275+
self.model.device, dtype=torch.bfloat16
276+
)
277+
278+
if self.device_map == "auto":
279+
inputs = inputs.to("cuda")
280+
else:
281+
inputs = inputs.to(self.device)
282+
283+
# Set default generation kwargs
284+
default_gen_kwargs = {
285+
"max_new_tokens": 128,
286+
"temperature": 0.0, # Set to 0 for greedy default
287+
"top_p": None,
288+
"num_beams": 1,
289+
}
290+
# Update with provided kwargs
291+
current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}
292+
293+
if current_gen_kwargs["temperature"] > 0:
294+
current_gen_kwargs["do_sample"] = True
295+
else:
296+
current_gen_kwargs["do_sample"] = False
297+
current_gen_kwargs["temperature"] = None
298+
current_gen_kwargs["top_p"] = None
299+
300+
cont = self.model.generate(
301+
**inputs,
302+
do_sample=current_gen_kwargs["do_sample"],
303+
temperature=current_gen_kwargs["temperature"],
304+
top_p=current_gen_kwargs["top_p"],
305+
num_beams=current_gen_kwargs["num_beams"],
306+
max_new_tokens=current_gen_kwargs["max_new_tokens"],
307+
use_cache=self.use_cache,
308+
)
309+
310+
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)]
311+
answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
312+
for i, ans in enumerate(answers):
313+
# print(f"Raw answer {i}: {ans}")
314+
for term in until:
315+
if len(term) > 0:
316+
ans = ans.split(term)[0]
317+
answers[i] = ans
318+
319+
for ans, context in zip(answers, contexts):
320+
res.append(ans)
321+
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
322+
pbar.update(1)
323+
# reorder this group of results back to original unsorted form
324+
res = re_ords.get_original(res)
325+
326+
pbar.close()
327+
return res
328+
329+
def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
330+
"""Generate text in a multi-round conversation format.
331+
332+
Args:
333+
requests: List of Instance objects for multi-round generation
334+
335+
Returns:
336+
List of generated responses
337+
338+
Raises:
339+
NotImplementedError: This method is not yet implemented
340+
"""
341+
raise NotImplementedError("TODO: Implement multi-round generation")

0 commit comments

Comments
 (0)