[quantization] Add Gemma4VisionPooler PTQ wrapper with export support#787
302b968 to
58198ea
Compare
| return wrapped.as_export_module( | ||
| "prefill", | ||
| output_length=self.output_length, | ||
| pixel_position_ids=pixel_pos_ids, | ||
| ).eval() |
There was a problem hiding this comment.
The "prefill" string is passed as a positional argument, which causes the test to fail:
python -m tico.quantization.examples.inspect --config tico/quantization/examples/configs/wrapper_smoke.yaml --mode wrapper-smoke --case gemma4_vision_pooler --export circle --output-dir ./out/wrapper_smoke
Output:
- Circle export failed: QuantGemma4VisionPooler.as_export_module() takes 1 positional argument but 2 positional arguments (and 2 keyword-only arguments) were given
There was a problem hiding this comment.
👍 Fixed, thank you for catching that!
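For readers following the thread, the failure above is ordinary Python behaviour for keyword-only parameters. Below is a minimal sketch, assuming a hypothetical `PoolerSketch` class that only mimics the signature shown in the diff; it is not the real `QuantGemma4VisionPooler`.

```python
class PoolerSketch:
    """Hypothetical stand-in for the wrapper; only the signature matters here."""

    def as_export_module(self, *, output_length, pixel_position_ids):
        # Everything after the bare `*` is keyword-only.
        return ("adapter", output_length, pixel_position_ids)


wrapper = PoolerSketch()

# Passing "prefill" positionally fails in the same way as the reported error.
try:
    wrapper.as_export_module("prefill", output_length=4, pixel_position_ids=[])
except TypeError as exc:
    error_message = str(exc)

# The keyword-only call succeeds.
adapter = wrapper.as_export_module(output_length=4, pixel_position_ids=[])
```

Either dropping the positional string or adding an explicit `mode` parameter makes the call valid. Which of the two the final code uses is not shown in this thread.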
| Input contract: | ||
| ``hidden_states`` has shape ``(1, S, D)`` where ``S`` is the fixed | ||
| vision encoder sequence length. | ||
| ``pixel_position_ids`` has shape ``(1, S, 2)`` — pre-computed on CPU. | ||
| ``padding_positions`` has shape ``(1, S)`` — pre-computed on CPU. |
There was a problem hiding this comment.
If I understand correctly, this adapter is for "prefill" mode. Could you consider renaming it to "Gemma4VisionPoolerPrefillExportAdapter"?
| *, | ||
| output_length: int, | ||
| pixel_position_ids: torch.Tensor, | ||
| ) -> nn.Module: |
There was a problem hiding this comment.
Also, as far as I know, QuantGemma4VisionPooler won't work in the decode phase.
Maybe we could consider changes like these:
def as_export_module(
    self,
    *,
    mode: ExportMode = "prefill",
    output_length: int,
    pixel_position_ids: torch.Tensor,
) -> nn.Module:
    """Return a static export adapter for the requested execution mode."""
    if mode != "prefill":
        raise ValueError(
            f"Unsupported Gemma4 VisionPooler export mode: {mode!r}"
        )
    ...
    return Gemma4VisionPoolerPrefillExportAdapter(self)  # renamed adapter
a3ef615 to
2e7ec45
Compare
Sorry for the late review; I was in a training session. I'll review this PR soon.
| for obs in self._all_observers(): | ||
| assert obs.has_qparams |
There was a problem hiding this comment.
For the mx types, we can add an if condition here since MXObserver doesn't have has_qparams.
| for obs in self._all_observers(): | |
| assert obs.has_qparams | |
| for obs in self._all_observers(): | |
| if isinstance(obs, AffineObserverBase): | |
| assert obs.has_qparams |
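The suggestion above guards the `has_qparams` check by observer type. Below is a minimal sketch of that guard, assuming two hypothetical stand-in classes (`AffineObserverSketch`, `MXObserverSketch`) rather than the real TICO observers.

```python
class AffineObserverSketch:
    """Hypothetical affine observer: exposes ``has_qparams``."""

    def __init__(self, name, has_qparams):
        self.name = name
        self.has_qparams = has_qparams


class MXObserverSketch:
    """Hypothetical MX-style observer: deliberately has no ``has_qparams``."""

    def __init__(self, name):
        self.name = name


def observers_missing_qparams(observers):
    """Return the names of affine observers that have no qparams yet."""
    return [
        obs.name
        for obs in observers
        if isinstance(obs, AffineObserverSketch) and not obs.has_qparams
    ]


observers = [
    AffineObserverSketch("obs_act_in", has_qparams=True),
    AffineObserverSketch("obs_pool_out", has_qparams=False),
    MXObserverSketch("obs_pool_weight"),
]
missing = observers_missing_qparams(observers)
```

Returning the offending names instead of asserting is a choice made for this sketch only. It lets the caller raise a single error that lists every uncalibrated observer.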
| return pooled, updated_padding | ||
| def as_export_module( |
There was a problem hiding this comment.
(Optional) I think the static pooling profile should be owned by the export adapter rather than mutating the original wrapper. The current implementation does this in as_export_module().
This makes as_export_module() different from most other wrappers. For most export adapters, we only adapt the input/output contract and keep using the existing module state. In this pooler, however, as_export_module() also specializes the wrapper for a particular static profile by materializing pool_weights and pool_mask from pixel_position_ids / output_length.
There are two issues with that:
- The static profile is stored on the original QuantGemma4VisionPooler, not on the returned adapter. If as_export_module() is called again with another profile, the later call overwrites self.pool_weights / self.pool_mask, so an already-created adapter can silently start using the new profile.
- collect(self.pool_weights) after freeze_qparams() is likely not doing what it looks like. freeze_qparams() disables observers, and ObserverBase.collect() returns immediately when enabled == False. So the newly-created buffer is not actually collected into the observer stats before compute_qparams().
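The first issue can be illustrated with a small sketch. All class names here are hypothetical, and plain lists stand in for the `pool_weights` tensors. It contrasts an adapter that reads profile state from the shared wrapper with one that owns its own copy.

```python
class WrapperSketch:
    """Hypothetical wrapper that can hand out two kinds of adapters."""

    def __init__(self):
        self.pool_weights = None

    def as_export_module_shared(self, pool_weights):
        # Stores the profile on the wrapper itself (the pattern under review).
        self.pool_weights = pool_weights
        return SharedStateAdapter(self)

    def as_export_module_owned(self, pool_weights):
        # Hands a private copy of the profile to the adapter.
        return OwningAdapter(self, list(pool_weights))


class SharedStateAdapter:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def forward(self):
        return self.wrapped.pool_weights


class OwningAdapter:
    def __init__(self, wrapped, pool_weights):
        self.wrapped = wrapped
        self.pool_weights = pool_weights

    def forward(self):
        return self.pool_weights


wrapper = WrapperSketch()

shared_first = wrapper.as_export_module_shared([0.25, 0.25, 0.25, 0.25])
shared_second = wrapper.as_export_module_shared([0.5, 0.5])

owned_first = wrapper.as_export_module_owned([0.25, 0.25, 0.25, 0.25])
owned_second = wrapper.as_export_module_owned([0.5, 0.5])
```

In this sketch `shared_first.forward()` returns the second profile after the second call, while `owned_first.forward()` still returns the profile it was created with.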
I think a cleaner design is to make the export adapter own the profile-specific tensors and a profile-local observer for pool_weights.
Something like:
@staticmethod
def _make_profile_observer(
    observer: ObserverBase,
    value: torch.Tensor,
) -> ObserverBase:
    """
    Return an observer calibrated only for a static export-profile tensor.
    """
    profile_observer = copy.deepcopy(observer)
    profile_observer.enabled = True
    profile_observer.reset()
    profile_observer.collect(value.detach())
    profile_observer.compute_qparams()
    profile_observer.enabled = False
    return profile_observer

def as_export_module(
    self,
    mode: ExportMode = "prefill",
    *,
    output_length: int,
    pixel_position_ids: torch.Tensor,
) -> nn.Module:
    if mode != "prefill":
        raise ValueError(f"Unsupported Gemma4 VisionPooler export mode: {mode!r}")
    if self._mode is not Mode.QUANT:
        raise RuntimeError(
            "Gemma4VisionPooler can be exported only after freeze_qparams()."
        )
    # Affine observers expose ``has_qparams``; non-affine observers such as
    # MX may not. Do not assume the attribute exists for every observer.
    for obs in self._all_observers():
        if hasattr(obs, "has_qparams") and not obs.has_qparams:
            raise RuntimeError(f"Observer {obs.name!r} does not have qparams.")
    from tico.quantization.wrapq.wrappers.gemma4.export_adapters import (
        Gemma4VisionPoolerPrefillExportAdapter,
    )
    pool_weights, pool_mask = self._build_pool_weights(
        seq_len=pixel_position_ids.shape[1],
        output_length=output_length,
        pixel_position_ids=pixel_position_ids,
    )
    pool_weight_observer = self._make_profile_observer(
        self.obs_pool_weight,
        pool_weights,
    )
    return Gemma4VisionPoolerPrefillExportAdapter(
        self,
        pool_weights=pool_weights,
        pool_mask=pool_mask,
        pool_weight_observer=pool_weight_observer,
    )
And the adapter can own the static buffers:
class Gemma4VisionPoolerPrefillExportAdapter(nn.Module):
    def __init__(
        self,
        wrapped_pooler: nn.Module,
        *,
        pool_weights: torch.Tensor,
        pool_mask: torch.Tensor,
        pool_weight_observer: ObserverBase,
    ):
        super().__init__()
        self.wrapped_pooler = wrapped_pooler
        self.pool_weight_observer = pool_weight_observer
        self.register_buffer("pool_weights", pool_weights.detach().clone())
        self.register_buffer("pool_mask", pool_mask.detach().clone())

    def forward(
        self,
        hidden_states: torch.Tensor,
        padding_positions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.wrapped_pooler.forward_export(
            hidden_states,
            padding_positions,
            pool_weights=self.pool_weights,
            pool_mask=self.pool_mask,
            pool_weight_observer=self.pool_weight_observer,
        )
Then forward_export() can use the adapter-provided profile state:
def forward_export(
    self,
    hidden_states: torch.Tensor,
    padding_positions: torch.Tensor,
    *,
    pool_weights: torch.Tensor,
    pool_mask: torch.Tensor,
    pool_weight_observer: ObserverBase,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the decomposed static pooler forward.
    The export adapter owns ``pool_weights``, ``pool_mask``, and the
    profile-local observer for ``pool_weights``. This keeps each static
    export profile independent from other adapters created from the same
    quantized wrapper.
    """
    input_dtype = hidden_states.dtype
    hidden_states = self._fq(hidden_states, self.obs_act_in)
    # Step 1: Zero out padding positions.
    hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0)
    # Step 2: Fake-quantize input.
    hidden_states = self._fq(hidden_states, self.obs_pool_in)
    # Step 3: Fake-quantize static pool weights with the profile-local observer.
    # Outside QUANT mode the weights are used as-is so the name is always bound.
    pool_weights_q = pool_weights
    if self._mode is Mode.QUANT:
        pool_weights_q = pool_weight_observer.fake_quant(pool_weights)
    # Step 4: Spatial pooling.
    pooled = pool_weights_q @ hidden_states.float()
    pooled = pooled.to(input_dtype)
    # Step 5: Fake-quantize matmul output.
    pooled = self._fq(pooled, self.obs_pool_matmul_out)
    # Step 6: Scale by sqrt(hidden_size).
    # Outside QUANT mode the unquantized scale is used.
    root_hidden_size = self.root_hidden_size
    if self._mode is Mode.QUANT:
        root_hidden_size = self.obs_root_hidden_size.fake_quant(self.root_hidden_size)
    pooled = pooled * root_hidden_size
    # Step 7: Fake-quantize final output.
    pooled = self._fq(pooled, self.obs_pool_out)
    updated_padding = pool_mask.expand(pooled.shape[0], -1)
    return pooled, updated_padding
There was a problem hiding this comment.
The static profile is stored on the original QuantGemma4VisionPooler, not on the returned adapter. If as_export_module() is called again with another profile, the later call overwrites self.pool_weights / self.pool_mask, so an already-created adapter can silently start using the new profile.
Yes, that's true. And yet, I think, there's something to discuss.
Let me try to imagine a situation in which we call as_export_module() multiple times. In such a situation we'll need to:
- Calibrate the model once.
- Export the model (call as_export_module and store the returned adapter).
- Calibrate the model for the 2nd time.
- Export the model again (call as_export_module again and store the adapter).
- Convert the 1st export adapter to Circle.
- Convert the 2nd adapter to Circle.
Doesn't this scenario look a little exotic?
Let's take a look at the export adapters defined in tico/quantization/wrapq/wrappers/gemma4/export_adapters.py:
- Gemma4TokenEmbeddingExportAdapter
- Gemma4VisionPrefillExportAdapter
- Gemma4VisionEncoderLayerPrefillExportAdapter
- Gemma4TextDecoderLayerPrefillExportAdapter
- Gemma4TextDecoderLayerDecodeExportAdapter
In their forward methods, they call their wrapped modules via self.wrapped(...). So, I guess, the above scenario applies equally to them. If you calibrate the model, call as_export_module and store the returned adapter, then calibrate the model again on different data and call as_export_module a 2nd time, won't this affect the 1st adapter? The 1st and the 2nd adapters use the same wrapped module, don't they?
Here's why I prefer to keep all observers and static tensors in QuantGemma4VisionPooler:
- "Traditionally" observers belong to classes that inherit from QuantModuleBase. The export adapter doesn't inherit from QuantModuleBase, so I don't want to let the adapter own observers.
- Most export adapters in export_adapters.py are very simple and concise, as they delegate most of the work to the original wrapped modules. In my mind even the term "adapter" suggests that it must be simple. So, I don't want to complicate the adapter by moving observers and tensors there. I feel better when the complexity is encapsulated in the original wrapper class.
There was a problem hiding this comment.
collect(self.pool_weights) after freeze_qparams() is likely not doing what it looks like. freeze_qparams() disables observers, and ObserverBase.collect() returns immediately when enabled == False. So the newly-created buffer is not actually collected into the observer stats before compute_qparams().
👍 Good point, thank you. I've fixed this right in the as_export_module:
class QuantGemma4VisionPooler(QuantModuleBase):
    ...
    def as_export_module(
        ...
        # Collect statistics about pool_weights and compute qparams
        obs_pool_weight_enabled: bool = self.obs_pool_weight.enabled
        self.obs_pool_weight.enabled = True
        self.obs_pool_weight.reset()
        self.obs_pool_weight.collect(self.pool_weights)
        self.obs_pool_weight.compute_qparams()
        self.obs_pool_weight.enabled = obs_pool_weight_enabled
        ...
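The enable/collect/restore pattern in the fix above can be sketched in isolation. `MinMaxObserverSketch` is a hypothetical toy observer, not TICO's `ObserverBase`. It only reproduces the behaviour described in the review, where `collect()` returns early when the observer is disabled.

```python
class MinMaxObserverSketch:
    """Hypothetical observer that tracks min/max and ignores data when disabled."""

    def __init__(self):
        self.enabled = True
        self.reset()

    def reset(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def collect(self, values):
        if not self.enabled:
            return  # mirrors the early return described in the review
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))


def collect_while_frozen(observer, values):
    """Temporarily enable the observer, collect once, then restore its state."""
    was_enabled = observer.enabled
    observer.enabled = True
    try:
        observer.reset()
        observer.collect(values)
    finally:
        observer.enabled = was_enabled


frozen = MinMaxObserverSketch()
frozen.enabled = False           # stands in for the state after freezing
frozen.collect([0.25, 0.5])      # silently ignored
ignored_min = frozen.min_val

collect_while_frozen(frozen, [0.25, 0.5])
```

The `try`/`finally` is an addition in this sketch, so the observer's previous state is restored even if collection raises. The snippet in the thread restores the flag without it.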
There was a problem hiding this comment.
The 1st and the 2nd adapters use the same wrapped module don't they?
Thanks, that is a fair point. I agree that the practical risk is limited; that is why I tagged the comment as optional.
I just wanted to make it a bit more robust against future changes, since there have been many changes recently.
Documenting or enforcing the single-use/profile contract in another PR may be sufficient.
| output_length: Number of soft tokens to produce. When ``None`` | ||
| the pooler returns the same sequence length. | ||
| """ | ||
| assert output_length is not None |
There was a problem hiding this comment.
I guess the original module checks whether hidden_states.shape[1] == output_length. If you intentionally left that logic out, it would be better to add an assertion for it.
There was a problem hiding this comment.
👍 Fixed, thank you for catching that.
5c54e97 to
8be9bda
Compare
Implements quantization wrapper for Gemma4 vision pooler. TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
What
This PR implements a complete Post-Training Quantization (PTQ) wrapper for the Gemma4 vision pooler module. The implementation includes a dual-mode wrapper with both flexible and exportable forward methods, an export adapter for static-shape torch.export, comprehensive unit tests, smoke tests, and an example script demonstrating the full PTQ workflow.
Why
The Gemma4 vision pooler performs spatial pooling of vision patch tokens and scales the result by sqrt(hidden_size) in float32. The original implementation uses dynamic operations (F.one_hot, torch.div) that are not torch.export-friendly. This PR provides:
Key Design Decisions
- Dual Forward Methods: The wrapper implements two forward methods:
  - forward(): Flexible, supports dynamic shapes, used during calibration
  - forward_export(): Static, exportable, uses precomputed buffers for Circle conversion
- Precomputed Weight Matrix: The pooling operation is decomposed into a precomputed weight matrix (pool_weights) and output mask (pool_mask), replacing dynamic F.one_hot + torch.div with a static matmul.
- Six Observers: The wrapper observes:
  - obs_act_in: Input activation (after initial fake-quant)
  - obs_pool_in: Pooling input (after masked_fill)
  - obs_pool_weight: Precomputed weight matrix
  - obs_root_hidden_size: Scaling factor
  - obs_pool_matmul_out: Matmul output
  - obs_pool_out: Final pooled output
- Export Adapter Pattern: Gemma4VisionPoolerExportAdapter is a thin wrapper that redirects to forward_export(), keeping the export graph clean and separating quantization logic from export contracts.
Changes
- tico/quantization/wrapq/wrappers/gemma4/quant_vision_pooler.py: forward(), forward_export(), as_export_module(), and _build_pool_weights() static method
- tico/quantization/wrapq/wrappers/gemma4/export_adapters.py: Gemma4VisionPoolerExportAdapter class that redirects to forward_export()
- tico/quantization/wrapq/wrappers/registry.py: quant_vision_pooler in _CORE_MODULES (uncommented)
- tico/quantization/wrapq/utils/version.py: 4.59.0)
- test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py: forward_export(), and as_export_module()
- test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py
- tico/quantization/wrapq/examples/gemma4/quantize_vision_pooler.py
- tico/quantization/wrapq/examples/gemma4/__init__.py
- tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py
Tests
Unit Tests (test_quant_vision_pooler.py):
- test_no_quant_forward_matches_fp: Verifies wrapper matches FP module in NO_QUANT mode
- test_no_quant_output_shape: Validates output shape
- test_mode_transitions: Tests NO_QUANT → CALIB → QUANT lifecycle
- test_observers_are_collected: Verifies all 6 observers are registered
- test_input_is_observed_in_calib_mode: Confirms input observation during calibration
- test_output_is_fake_quantized_in_quant_mode: Confirms fake quantization in QUANT mode
- test_quant_mode_output_is_finite: Validates finite output in QUANT mode
- test_dtype_override: Tests PTQConfig dtype overrides
- test_forward_export_requires_precomputed_buffers: Tests forward_export() prerequisites
- test_forward_export_matches_forward: Compares forward_export() with forward()
- test_as_export_module_requires_quant_mode: Validates QUANT mode requirement
- test_as_export_module_returns_adapter: Verifies adapter type
- test_as_export_module_precomputes_buffers_on_wrapper: Checks buffer registration
- test_export_adapter_decomposed_forward_matches_fp: Validates adapter output
- test_build_pool_weights_matches_original: Compares decomposed pooling with original
Internal Tests (test_quantize_vision_pooler.py):
- test_no_quant_vision_pooler_matches_reference: Wrapper parity test
- test_prepare_convert_vision_pooler_flow: Full PTQ flow test
- test_as_export_module_flow: Export module flow test
Smoke Test
Example Scripts
quantize_vision_pooler.py: as_export_module() gemma4_vision_pooler.q.circle
The example script produces quantization error statistics and a visualization of FP vs. quantized output distributions.
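The "Precomputed Weight Matrix" decision in the description can be illustrated with a pure-Python sketch. It rests on stated assumptions: patches lie on a square grid, each output token averages a `kernel x kernel` block, and nested lists stand in for tensors. The function names, the bucket formula, and the mask rule are illustrative and are not taken from the PR's `_build_pool_weights()`.

```python
def bucket_index(x, y, kernel, grid_width):
    """Map a patch position to its output token (assumed row-major layout)."""
    return (y // kernel) * (grid_width // kernel) + (x // kernel)


def build_pool_weights(position_ids, output_length, kernel, grid_width):
    """Precompute an (output_length x S) averaging matrix and an output mask."""
    weights = [[0.0] * len(position_ids) for _ in range(output_length)]
    for s, (x, y) in enumerate(position_ids):
        weights[bucket_index(x, y, kernel, grid_width)][s] = 1.0 / (kernel * kernel)
    # Assumption: an output slot with no contributing patch counts as padding.
    mask = [all(w == 0.0 for w in row) for row in weights]
    return weights, mask


def matmul(a, b):
    """Plain (N x S) @ (S x D) matrix product on nested lists."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def pool_dynamic(hidden, position_ids, output_length, kernel, grid_width):
    """Reference path: scatter each patch into its bucket at run time."""
    dim = len(hidden[0])
    pooled = [[0.0] * dim for _ in range(output_length)]
    for vec, (x, y) in zip(hidden, position_ids):
        bucket = bucket_index(x, y, kernel, grid_width)
        for d in range(dim):
            pooled[bucket][d] += vec[d] / (kernel * kernel)
    return pooled


# 4x4 patch grid, 2x2 pooling, so 4 output tokens; hidden size 2.
position_ids = [(x, y) for y in range(4) for x in range(4)]
hidden = [[float(s), 2.0 * s] for s in range(16)]

pool_weights, pool_mask = build_pool_weights(position_ids, 4, 2, 4)
pooled_static = matmul(pool_weights, hidden)
pooled_dynamic = pool_dynamic(hidden, position_ids, 4, 2, 4)
```

Once `pool_weights` is fixed for a given set of position ids, the run-time work is a single matrix product. This is the static shape that, per the description, suits torch.export.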