Skip to content

Commit dc328be

Browse files
author
d.savchenkov
committed
[quantization] Add Gemma4VisionPooler PTQ wrapper with export support
Implements quantization wrapper for Gemma4 vision pooler. TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent 5555056 commit dc328be

9 files changed

Lines changed: 1343 additions & 14 deletions

File tree

test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py

Lines changed: 497 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Smoke tests for Gemma4 vision pooler prepare-calibrate-convert flow."""
16+
17+
import copy
18+
import os
19+
import unittest
20+
21+
import torch
22+
23+
from tico.quantization import convert, prepare
24+
from tico.quantization.config.ptq import PTQConfig
25+
from tico.quantization.wrapq.mode import Mode
26+
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
27+
28+
29+
IS_INTERNAL_TEST = os.environ.get("RUN_INTERNAL_TESTS", "0") == "1"
30+
_SKIP_MSG = "required transformers Gemma4 modules are not installed"
31+
32+
33+
def _has_gemma4() -> bool:
34+
"""Return whether the installed transformers package provides Gemma4 vision."""
35+
try:
36+
from transformers.models.gemma4.configuration_gemma4 import ( # noqa: F401
37+
Gemma4VisionConfig,
38+
)
39+
from transformers.models.gemma4.modeling_gemma4 import ( # noqa: F401
40+
Gemma4VisionPooler,
41+
)
42+
except Exception:
43+
return False
44+
return True
45+
46+
47+
def _make_vision_config():
    """Create a tiny Gemma4 vision config for synthetic smoke tests.

    Returns:
        A ``Gemma4VisionConfig`` with small dimensions (hidden size 32, one
        hidden layer, four attention heads) whose ``_attn_implementation``
        attribute is set to ``"eager"``.
    """
    from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig

    cfg = Gemma4VisionConfig(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=8,
        attention_dropout=0.0,
        max_position_embeddings=128,
        rms_norm_eps=1e-6,
        use_clipped_linears=False,
        rope_parameters={"rope_type": "default", "rope_theta": 100.0},
    )
    # A single assignment covers both the "attribute missing" and the
    # "attribute present" case, so no hasattr() branching is needed.
    cfg._attn_implementation = "eager"
    return cfg
69+
70+
71+
def _pixel_position_ids(batch_size: int, seq_len: int) -> torch.Tensor:
72+
"""Create deterministic 2-D pixel position ids for a tiny patch sequence."""
73+
side = int(seq_len**0.5)
74+
coords = torch.arange(seq_len)
75+
xy = torch.stack((coords % side, coords // side), dim=-1)
76+
return xy.unsqueeze(0).expand(batch_size, -1, -1).long()
77+
78+
79+
def _padding_positions(batch_size: int, seq_len: int) -> torch.Tensor:
80+
"""Create an all-False padding mask (no padding)."""
81+
return torch.zeros(batch_size, seq_len, dtype=torch.bool)
82+
83+
84+
@unittest.skipIf(
    not IS_INTERNAL_TEST,
    "Internal smoke test — set RUN_INTERNAL_TESTS=1 to enable it.",
)
@unittest.skipUnless(_has_gemma4(), _SKIP_MSG)
class TestGemma4VisionPoolerSmoke(unittest.TestCase):
    """Smoke-test the Gemma4 vision pooler wrapper: FP parity, PTQ flow, export."""

    def setUp(self):
        """Seed torch and build a tiny pooler plus an independent reference copy."""
        torch.manual_seed(2026)
        from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionPooler

        self.cfg = _make_vision_config()
        self.fp_pooler = Gemma4VisionPooler(self.cfg).eval()
        # The reference is a deep copy of the pooler handed to the wrappers.
        self.fp_ref = copy.deepcopy(self.fp_pooler).eval()
        # 16 patches pooled down to 4 outputs gives k = sqrt(16 / 4) = 2.
        self.seq_len = 16
        self.output_length = 4

    def _sample(self):
        """Build the keyword arguments for one synthetic pooler call."""
        batch_size = 1
        hidden_states = torch.randn(batch_size, self.seq_len, self.cfg.hidden_size)
        position_ids = _pixel_position_ids(batch_size, self.seq_len)
        padding = _padding_positions(batch_size, self.seq_len)
        return {
            "hidden_states": hidden_states,
            "pixel_position_ids": position_ids,
            "padding_positions": padding,
            "output_length": self.output_length,
        }

    def _calibrate(self, prepared, num_samples=3):
        """Feed ``num_samples`` fresh synthetic samples through ``prepared``."""
        with torch.no_grad():
            for _ in range(num_samples):
                prepared(**self._sample())

    def test_no_quant_vision_pooler_matches_reference(self):
        """Before quantization the wrapper must reproduce the FP module output."""
        from tico.quantization.wrapq.wrappers.gemma4.quant_vision_pooler import (
            QuantGemma4VisionPooler,
        )

        wrapper = QuantGemma4VisionPooler(self.fp_pooler, qcfg=PTQConfig()).eval()
        inputs = self._sample()

        with torch.no_grad():
            quant_out = wrapper(**inputs)
            fp_out = self.fp_ref(**inputs)

        # Both modules return (pooled_features, updated_padding).
        for result in (quant_out, fp_out):
            self.assertIsInstance(result, tuple)
        quant_features, fp_features = quant_out[0], fp_out[0]
        self.assertEqual(quant_features.shape, fp_features.shape)
        features_match = torch.allclose(
            quant_features, fp_features, atol=1e-5, rtol=1e-5
        )
        self.assertTrue(features_match)
        self.assertTrue(torch.equal(quant_out[1], fp_out[1]))

    def test_prepare_convert_vision_pooler_flow(self):
        """Run prepare -> calibrate -> convert and sanity-check the output."""
        from tico.quantization.wrapq.wrappers.gemma4.quant_vision_pooler import (
            QuantGemma4VisionPooler,
        )

        prepared = prepare(self.fp_pooler, PTQConfig())
        self.assertIsInstance(prepared, PTQWrapper)
        self.assertIsInstance(prepared.wrapped, QuantGemma4VisionPooler)

        self._calibrate(prepared)

        quantized = convert(prepared)
        self.assertIs(quantized._mode, Mode.QUANT)

        inputs = self._sample()
        with torch.no_grad():
            quant_out = quantized(**inputs)
            fp_out = self.fp_ref(**inputs)

        self.assertIsInstance(quant_out, tuple)
        self.assertEqual(quant_out[0].shape, fp_out[0].shape)
        self.assertTrue(torch.isfinite(quant_out[0]).all())

    def test_as_export_module_flow(self):
        """Check that ``as_export_module`` yields a runnable prefill adapter."""
        from tico.quantization.wrapq.wrappers.gemma4.export_adapters import (
            Gemma4VisionPoolerPrefillExportAdapter,
        )

        prepared = prepare(self.fp_pooler, PTQConfig())
        self._calibrate(prepared)
        quantized = convert(prepared)

        position_ids = _pixel_position_ids(1, self.seq_len)
        adapter = quantized.wrapped.as_export_module(
            output_length=self.output_length,
            pixel_position_ids=position_ids,
        )
        self.assertIsInstance(adapter, Gemma4VisionPoolerPrefillExportAdapter)

        # The adapter is called without output_length, which was supplied
        # to as_export_module above.
        sample = self._sample()
        adapter_kwargs = {
            key: sample[key]
            for key in ("hidden_states", "pixel_position_ids", "padding_positions")
        }
        with torch.no_grad():
            out = adapter(**adapter_kwargs)

        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), 2)
196+
197+
198+
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()

tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,26 @@ def _vision_position_ids(batch_size: int, seq_len: int) -> torch.Tensor:
583583
return xy.unsqueeze(0).expand(batch_size, -1, -1).long()
584584

585585

586+
def _pixel_position_ids(batch_size: int, seq_len: int) -> torch.Tensor:
587+
"""Create deterministic 2-D pixel position ids for a tiny patch sequence.
588+
589+
The pooler requires ``pixel_position_ids`` with shape ``(B, S, 2)`` where
590+
the last dimension encodes ``(x, y)`` patch coordinates. We build a
591+
simple square grid layout that is compatible with the ``output_length``
592+
used in pooler tests: ``seq_len = output_length * k^2`` where ``k`` is
593+
the pooling factor.
594+
"""
595+
side = int(seq_len**0.5)
596+
coords = torch.arange(seq_len)
597+
xy = torch.stack((coords % side, coords // side), dim=-1)
598+
return xy.unsqueeze(0).expand(batch_size, -1, -1).long()
599+
600+
601+
def _padding_positions(batch_size: int, seq_len: int) -> torch.Tensor:
602+
"""Create an all-False padding mask (no padding)."""
603+
return torch.zeros(batch_size, seq_len, dtype=torch.bool)
604+
605+
586606
class Gemma4VisionAttentionCase(Gemma4BaseCase):
587607
"""Smoke case for one tiny Gemma4 vision attention module."""
588608

@@ -693,6 +713,108 @@ def eval_input(
693713
return self._sample()
694714

695715

716+
class Gemma4VisionPoolerCase(Gemma4BaseCase):
    """Wrapper smoke case covering a single tiny Gemma4 vision pooler."""

    name = "gemma4_vision_pooler"
    description = "Quantize one tiny Gemma4 vision pooler module."
    tags = ("gemma4", "e2b", "vision", "pooler")
    max_mean_abs_diff = 2.0
    # 16 patches pooled down to 4 outputs gives k = sqrt(16 / 4) = 2.
    seq_len = 16
    output_length = 4

    def build(self, cfg: Mapping[str, Any]) -> tuple[torch.nn.Module, torch.nn.Module]:
        """Construct the tiny pooler under test together with a cloned reference."""
        from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionPooler

        torch.manual_seed(123)
        self.vision_cfg = _make_vision_config()
        pooler = Gemma4VisionPooler(self.vision_cfg).eval()
        reference = clone_module(pooler)
        return pooler, reference

    def _sample(self) -> ForwardInput:
        """Build one synthetic keyword-only input for the pooler."""
        batch_size = 1
        kwargs = {
            "hidden_states": torch.randn(
                batch_size, self.seq_len, self.vision_cfg.hidden_size
            ),
            "pixel_position_ids": _pixel_position_ids(batch_size, self.seq_len),
            "padding_positions": _padding_positions(batch_size, self.seq_len),
            "output_length": self.output_length,
        }
        return ForwardInput((), kwargs)

    def forward(self, module: torch.nn.Module, sample: ForwardInput) -> Any:
        """Call the pooler on a cloned sample and return the pooled features."""
        private_sample = _clone_forward_input(sample)
        result = module(*private_sample.args, **dict(private_sample.kwargs))
        # Only the first tuple element (pooled features) is compared.
        if isinstance(result, tuple):
            return result[0]
        return result

    def reference_forward(
        self, reference: torch.nn.Module, sample: ForwardInput
    ) -> Any:
        """Call the original pooler on a cloned sample and return its pooled features."""
        private_sample = _clone_forward_input(sample)
        result = reference(*private_sample.args, **dict(private_sample.kwargs))
        if not isinstance(result, tuple):
            return result
        return result[0]

    def calibration_inputs(
        self,
        prepared: torch.nn.Module,
        cfg: Mapping[str, Any],
    ) -> list[ForwardInput]:
        """Return three freshly generated calibration samples."""
        samples = []
        for _ in range(3):
            samples.append(self._sample())
        return samples

    def eval_input(
        self,
        prepared: torch.nn.Module,
        cfg: Mapping[str, Any],
    ) -> ForwardInput:
        """Return one freshly generated evaluation sample."""
        return self._sample()

    def export_module(
        self, quantized: torch.nn.Module, cfg: Mapping[str, Any]
    ) -> torch.nn.Module:
        """Return the prefill export adapter for the wrapped pooler.

        ``pixel_position_ids`` and a fixed ``output_length`` are handed to
        ``as_export_module`` so the adapter can precompute the pooling weight
        matrix and output mask at construction time, replacing the dynamic
        ``F.one_hot`` and ``torch.div`` operations with a static ``matmul``.
        If the wrapped object has no ``as_export_module``, ``quantized`` is
        returned unchanged.
        """
        wrapped = getattr(quantized, "wrapped", quantized)
        if not hasattr(wrapped, "as_export_module"):
            return quantized

        position_ids = _pixel_position_ids(1, self.seq_len)
        adapter = wrapped.as_export_module(
            mode="prefill",
            output_length=self.output_length,
            pixel_position_ids=position_ids,
        )
        return adapter.eval()

    def export_input(
        self, eval_sample: ForwardInput, cfg: Mapping[str, Any]
    ) -> ForwardInput:
        """Convert an evaluation sample into positional inputs for the adapter.

        ``output_length`` is a construction-time constant of the export
        adapter, so it is left out of the forward arguments.
        """
        kwargs = dict(_clone_forward_input(eval_sample).kwargs)
        positional = tuple(
            kwargs[key]
            for key in ("hidden_states", "pixel_position_ids", "padding_positions")
        )
        return ForwardInput(positional, {})
816+
817+
696818
GEMMA4_CASES = (
697819
Gemma4TextMLPCase(),
698820
Gemma4TextAttentionCase(),
@@ -704,4 +826,5 @@ def eval_input(
704826
Gemma4TextDecoderLayerSharedKVCase(),
705827
Gemma4VisionAttentionCase(),
706828
Gemma4VisionEncoderLayerCase(),
829+
Gemma4VisionPoolerCase(),
707830
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# DO NOT REMOVE THIS FILE

0 commit comments

Comments
 (0)