
Commit 1fb9697

[Examples] Add Gemma 4 E4B NVFP4A16 quantization example
Add NVFP4A16 weight-only quantization example for google/gemma-4-E4B-it.

Includes a Dockerfile, since Gemma 4 requires transformers from git main, which is newer than the version currently pinned by llmcompressor. The ignore list skips the vision_tower, audio_tower, embed_vision, and embed_audio modules, which are specific to Gemma 4's multimodal architecture. Uses AutoModelForImageTextToText and AutoProcessor, as required by the Gemma 4 model class.

Tested end-to-end: quantization, sample generation, and model saving all complete successfully.

Signed-off-by: Ziming <frankziming26@outlook.com>
1 parent 026c917 commit 1fb9697

File tree

2 files changed: +69 −0 lines changed

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@

```dockerfile
FROM nvcr.io/nvidia/pytorch:25.04-py3

WORKDIR /workspace

# Install llmcompressor and upgrade transformers for Gemma 4 support.
# Gemma 4 (model_type: gemma4) requires transformers from git main, which is newer
# than the version currently pinned by llmcompressor.
#
# Step 1: Install llmcompressor (keeps NVIDIA constraint file so torch/cuda stay).
# Step 2: Force-upgrade transformers, huggingface_hub, regex (bypass constraints).
RUN pip install --no-deps git+https://github.com/vllm-project/llm-compressor.git \
        "compressed-tensors>=0.14.1a2" loguru "datasets>=4.0.0" accelerate \
        "auto-round>=0.10.2" nvidia-ml-py && \
    PIP_CONSTRAINT="" pip install multiprocess dill xxhash fsspec && \
    PIP_CONSTRAINT="" pip install --force-reinstall \
        git+https://github.com/huggingface/transformers.git \
        "huggingface_hub>=1.5.0" \
        "regex>=2025.10.22" \
        tokenizers safetensors && \
    pip install "numpy<2"

COPY gemma4_example.py .

ENTRYPOINT ["python", "gemma4_example.py"]
```
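The diff does not show how the image is invoked; a minimal usage sketch follows. The image tag `gemma4-nvfp4a16` and the volume mount are assumptions, not part of the commit; a GPU host with Docker plus a Hugging Face token for the gated model are required, and the mount keeps the saved checkpoint after the container exits.

```shell
# Assumed tag; build from the directory holding the Dockerfile and gemma4_example.py
docker build -t gemma4-nvfp4a16 .

# Run on a GPU host. HF_TOKEN authenticates against the gated Gemma repo; the
# mounted path matches SAVE_DIR in the script so the quantized model persists.
docker run --gpus all -e HF_TOKEN \
    -v "$PWD/gemma-4-E4B-it-NVFP4A16:/workspace/gemma-4-E4B-it-NVFP4A16" \
    gemma4-nvfp4a16
```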
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
1+
from compressed_tensors.offload import dispatch_model
2+
from transformers import AutoModelForImageTextToText, AutoProcessor
3+
4+
from llmcompressor import oneshot
5+
from llmcompressor.modifiers.quantization import QuantizationModifier
6+
7+
# Load model.
8+
MODEL_ID = "google/gemma-4-E4B-it"
9+
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype="auto")
10+
processor = AutoProcessor.from_pretrained(MODEL_ID)
11+
12+
# Configure the quantization algorithm and scheme.
13+
# In this case, we:
14+
# * quantize the weights to fp4 with per group 16 via ptq
15+
# * skip the vision encoder, audio encoder, embedding projections, and lm_head
16+
recipe = QuantizationModifier(
17+
targets="Linear",
18+
scheme="NVFP4A16",
19+
ignore=[
20+
"lm_head",
21+
"re:.*vision_tower.*",
22+
"re:.*audio_tower.*",
23+
"re:.*embed_vision.*",
24+
"re:.*embed_audio.*",
25+
],
26+
)
27+
28+
# Apply quantization.
29+
oneshot(model=model, recipe=recipe)
30+
31+
print("\n\n========== SAMPLE GENERATION ==============")
32+
dispatch_model(model)
33+
messages = [
34+
{"role": "user", "content": "Hello my name is"},
35+
]
36+
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
37+
inputs = processor(text=text, return_tensors="pt").to(model.device)
38+
output = model.generate(**inputs, max_new_tokens=100)
39+
print(processor.decode(output[0], skip_special_tokens=True))
40+
print("==========================================\n\n")
41+
42+
# Save to disk in compressed-tensors format.
43+
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16"
44+
model.save_pretrained(SAVE_DIR, save_compressed=True)
45+
processor.save_pretrained(SAVE_DIR)
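The `re:`-prefixed entries in the ignore list are regular expressions matched against module names. A standalone sketch of the resulting skip/quantize split, plus the SAVE_DIR naming from the script: the module names below are illustrative, not taken from the real Gemma checkpoint, and because of the `.*` wrappers it does not matter here whether llmcompressor anchors or searches the pattern.

```python
import re

# Regex bodies from the recipe's ignore list (the "re:" prefix stripped).
ignore_patterns = [
    r".*vision_tower.*",
    r".*audio_tower.*",
    r".*embed_vision.*",
    r".*embed_audio.*",
]

def is_ignored(module_name: str) -> bool:
    """True if the module is skipped by the recipe's ignore list."""
    return module_name == "lm_head" or any(
        re.match(p, module_name) for p in ignore_patterns
    )

# Multimodal towers and projections are skipped; language-model Linears are quantized.
assert is_ignored("model.vision_tower.blocks.0.attn.q_proj")
assert is_ignored("model.embed_audio.projection")
assert is_ignored("lm_head")
assert not is_ignored("model.language_model.layers.0.mlp.gate_proj")

# SAVE_DIR derivation: keep only the repo name, append the scheme suffix.
MODEL_ID = "google/gemma-4-E4B-it"
save_dir = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16"
print(save_dir)  # gemma-4-E4B-it-NVFP4A16
```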
