Open
Labels
bug: Something isn't working
Description
⚙️ Your current environment
The output of python collect_env.py
### Environment Information ###
Operating System: `Linux-5.10.0-35-amd64-x86_64-with-glibc2.31`
Python Version: `3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0]`
llm-compressor Version: `0.6.0.1`
compressed-tensors Version: `0.10.2`
transformers Version: `4.52.4`
torch Version: `2.7.0`
CUDA Devices: `['NVIDIA RTX A6000', 'NVIDIA RTX A6000', 'NVIDIA RTX A6000', 'NVIDIA RTX A6000']`
AMD Devices: `None`
🐛 Describe the bug
Using the example script to quantize (Med)Gemma3 with GPTQModifier fails during saving with the following error:
    model.save_pretrained(SAVE_DIR, save_compressed=True)
  File "/home/jnn/miniconda3/envs/medgem/lib/python3.12/site-packages/llmcompressor/transformers/sparsification/compressed_tensors_utils.py", line 107, in save_pretrained_wrapper
    original_save_pretrained.__get__(model, model_class)(
  File "/home/jnn/miniconda3/envs/medgem/lib/python3.12/site-packages/transformers/modeling_utils.py", line 3705, in save_pretrained
    module = module_map[module_name]
             ~~~~~~~~~~^^^^^^^^^^^^^
KeyError: 'vision_tower.vision_model.embeddings.patch_embedding.weight'
I may be doing something wrong, or there may be a discrepancy between Gemma3 and MedGemma that I'm unaware of, but I am suspicious because the same script works on the smaller model.
The quantized model itself works and outputs the expected description of the COCO kitten image.
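One way to narrow this down might be to check whether the key from the `KeyError` exists in the model under a different dotted prefix. The snippet below is only a diagnostic sketch: `find_renamed_keys` is a hypothetical helper (not part of transformers or llm-compressor), and the example key names other than the one from the traceback are assumptions made for illustration.

```python
def find_renamed_keys(missing_key, known_keys):
    """Return known keys that match missing_key up to an extra or missing dotted prefix."""
    matches = []
    for key in known_keys:
        if key == missing_key:
            continue  # an exact match is not a rename
        if key.endswith("." + missing_key) or missing_key.endswith("." + key):
            matches.append(key)
    return matches


# Key taken from the traceback above.
missing = "vision_tower.vision_model.embeddings.patch_embedding.weight"

# Hypothetical key names for illustration only; on a real model these
# would come from model.state_dict().keys().
example_keys = [
    "model.vision_tower.vision_model.embeddings.patch_embedding.weight",
    "model.language_model.embed_tokens.weight",
    "lm_head.weight",
]

print(find_renamed_keys(missing, example_keys))
```

If a call such as `find_renamed_keys(missing, model.state_dict().keys())` returns a key that differs only by a leading `model.`, that would suggest (but not prove) a naming mismatch between the keys being saved and the keys used for the lookup.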
🛠️ Steps to reproduce
The exact script I used follows.
import os
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation
# Load model.
# model_id = "google/medgemma-4b-it"
model_id = "google/medgemma-27b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=[
            "lm_head",
            "re:model\.vision_tower.*",
            "re:model\.multi_modal_projector.*",
        ],
    ),
]
# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)
# Note: compile is disabled: https://github.com/huggingface/transformers/issues/38333
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
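For anyone reproducing this, it may help to see which module names the `ignore` patterns in the recipe would match. The sketch below is a simplified approximation, not llm-compressor's actual matching code: it assumes that entries prefixed with `re:` are regular expressions anchored at the start of the module name, and that other entries must equal the name exactly. `is_ignored` and the module names passed to it are hypothetical.

```python
import re


def is_ignored(module_name, ignore):
    """Approximate check of a module name against a list of ignore patterns."""
    for pattern in ignore:
        if pattern.startswith("re:"):
            # Assumption: "re:" patterns are matched from the start of the name.
            if re.match(pattern[3:], module_name):
                return True
        elif pattern == module_name:
            return True
    return False


# Same pattern values as in the recipe, written as raw strings.
ignore = [
    "lm_head",
    r"re:model\.vision_tower.*",
    r"re:model\.multi_modal_projector.*",
]

# Hypothetical module names, with and without the leading "model." prefix.
for name in [
    "model.vision_tower.vision_model.embeddings.patch_embedding",
    "vision_tower.vision_model.embeddings.patch_embedding",
    "lm_head",
]:
    print(name, is_ignored(name, ignore))
```

Under these assumptions, a name that begins with `vision_tower.` (as in the `KeyError`) does not match `model\.vision_tower.*`, while the `model.`-prefixed name does. Whether this prefix difference is related to the save failure is not established here.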