[quantization] Add Gemma4VisionPooler PTQ wrapper with export support #787

Merged
mhs4670go merged 1 commit into Samsung:main from dvsav:vision_pooler
Jun 22, 2026

@dvsav dvsav commented Jun 16, 2026

Contributor

What

This PR implements a complete Post-Training Quantization (PTQ) wrapper for the Gemma4 vision pooler module. The implementation includes a dual-mode wrapper with both flexible and exportable forward methods, an export adapter for static-shape torch.export, comprehensive unit tests, smoke tests, and an example script demonstrating the full PTQ workflow.

Why

The Gemma4 vision pooler performs spatial pooling of vision patch tokens and scales the result by sqrt(hidden_size) in float32. The original implementation uses dynamic operations (F.one_hot, torch.div) that are not torch.export-friendly. This PR provides:

  • A PTQ wrapper that observes input and output activations for calibration
  • A decomposed static implementation suitable for NPU export
  • Complete test coverage to ensure correctness
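For readers new to PTQ calibration, the observe-then-fake-quantize idea can be sketched in plain Python. This is a generic min/max affine scheme with hypothetical names (MinMaxObserver and its methods); it is not TICO's observer API, and it uses lists instead of tensors so that it runs without PyTorch.

```python
class MinMaxObserver:
    """Toy affine observer (hypothetical, not TICO's ObserverBase)."""

    def __init__(self, num_bits: int = 8) -> None:
        self.qmin, self.qmax = 0, 2**num_bits - 1
        self.lo, self.hi = float("inf"), float("-inf")
        self.scale, self.zero_point = 1.0, 0

    def collect(self, values: list[float]) -> None:
        # Calibration: track the running min/max of everything seen.
        self.lo = min(self.lo, min(values))
        self.hi = max(self.hi, max(values))

    def compute_qparams(self) -> None:
        # Extend the range to contain 0 so that zero is exactly representable.
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)
        self.scale = (hi - lo) / (self.qmax - self.qmin) or 1.0
        self.zero_point = round(self.qmin - lo / self.scale)

    def fake_quant(self, values: list[float]) -> list[float]:
        # Quantize to the integer grid, clamp, then dequantize back to float.
        out = []
        for v in values:
            q = round(v / self.scale) + self.zero_point
            q = max(self.qmin, min(self.qmax, q))
            out.append((q - self.zero_point) * self.scale)
        return out


obs = MinMaxObserver()
obs.collect([-1.0, 0.25, 0.5])   # calibration batch 1
obs.collect([0.75, 3.0])         # calibration batch 2
obs.compute_qparams()
quantized = obs.fake_quant([-1.0, 0.3, 3.0, 10.0])
```

Values inside the calibrated range come back within half a quantization step of the input, while the out-of-range value 10.0 saturates at the largest representable value.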

Key Design Decisions

  1. Dual Forward Methods: The wrapper implements two forward methods:

    • forward(): Flexible, supports dynamic shapes, used during calibration
    • forward_export(): Static, exportable, uses precomputed buffers for Circle conversion
  2. Precomputed Weight Matrix: The pooling operation is decomposed into a precomputed weight matrix (pool_weights) and output mask (pool_mask), replacing dynamic F.one_hot + torch.div with a static matmul.

  3. Six Observers: The wrapper observes:

    • obs_act_in: Input activation (after initial fake-quant)
    • obs_pool_in: Pooling input (after masked_fill)
    • obs_pool_weight: Precomputed weight matrix
    • obs_root_hidden_size: Scaling factor
    • obs_pool_matmul_out: Matmul output
    • obs_pool_out: Final pooled output
  4. Export Adapter Pattern: Gemma4VisionPoolerExportAdapter is a thin wrapper that redirects to forward_export(), keeping the export graph clean and separating quantization logic from export contracts.
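The decomposition in item 2 can be illustrated without PyTorch. The sketch below assumes that each patch at position (x, y) is averaged into the k×k cell that contains it, that cells are numbered row by row, and that there are no padding patches. The function names and the plain-list representation are illustrative only and do not reproduce the wrapper's actual _build_pool_weights().

```python
import math


def build_pool_weights(position_ids, output_length, kernel):
    """Return (weights, mask) for average pooling expressed as a matmul.

    weights is an output_length x seq_len matrix; mask[i] is True when
    output cell i receives no patch at all.
    """
    cells_per_row = (max(x for x, _ in position_ids) + 1) // kernel
    weights = [[0.0] * len(position_ids) for _ in range(output_length)]
    for col, (x, y) in enumerate(position_ids):
        row = (x // kernel) + cells_per_row * (y // kernel)
        weights[row][col] = 1.0 / (kernel * kernel)
    mask = [not any(row) for row in weights]
    return weights, mask


def matmul(a, b):
    """Plain-Python matrix product of a (m x n) and b (n x p)."""
    return [[sum(a_ik * b_kj for a_ik, b_kj in zip(a_row, b_col))
             for b_col in zip(*b)] for a_row in a]


# 4x4 grid of patches, 2x2 pooling -> 4 output tokens, hidden_size = 2.
position_ids = [(x, y) for y in range(4) for x in range(4)]
hidden = [[float(i), 2.0 * i] for i in range(16)]

# The matrix depends only on the positions, so it can be built once, offline.
weights, mask = build_pool_weights(position_ids, output_length=4, kernel=2)

# At run time only a static matmul and a constant scale remain.
root_hidden_size = math.sqrt(len(hidden[0]))
pooled = [[v * root_hidden_size for v in row] for row in matmul(weights, hidden)]
```

Because the weight matrix is a constant once the positions are fixed, the exported graph contains no data-dependent indexing, which is the property the static export path relies on.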

Changes

Changes by file:

  • tico/quantization/wrapq/wrappers/gemma4/quant_vision_pooler.py: Implemented full PTQ wrapper with 6 observers, forward(), forward_export(), as_export_module(), and _build_pool_weights() static method
  • tico/quantization/wrapq/wrappers/gemma4/export_adapters.py: Added Gemma4VisionPoolerExportAdapter class that redirects to forward_export()
  • tico/quantization/wrapq/wrappers/registry.py: Enabled quant_vision_pooler in _CORE_MODULES (uncommented)
  • tico/quantization/wrapq/utils/version.py: Added Gemma4 minimum version requirement (4.59.0)
  • test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py: Added 15 unit tests covering NO_QUANT mode, mode transitions, calibration, fake quantization, dtype overrides, forward_export(), and as_export_module()
  • test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py: Added 3 smoke tests for prepare-calibrate-convert flow and export module
  • tico/quantization/wrapq/examples/gemma4/quantize_vision_pooler.py: Added example script demonstrating full PTQ workflow with Circle export
  • tico/quantization/wrapq/examples/gemma4/__init__.py: Added package init file
  • tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py: Added smoke test case for Gemma4 vision pooler

Tests

Unit Tests (test_quant_vision_pooler.py):

  • test_no_quant_forward_matches_fp: Verifies wrapper matches FP module in NO_QUANT mode
  • test_no_quant_output_shape: Validates output shape
  • test_mode_transitions: Tests NO_QUANT → CALIB → QUANT lifecycle
  • test_observers_are_collected: Verifies all 6 observers are registered
  • test_input_is_observed_in_calib_mode: Confirms input observation during calibration
  • test_output_is_fake_quantized_in_quant_mode: Confirms fake quantization in QUANT mode
  • test_quant_mode_output_is_finite: Validates finite output in QUANT mode
  • test_dtype_override: Tests PTQConfig dtype overrides
  • test_forward_export_requires_precomputed_buffers: Tests forward_export() prerequisites
  • test_forward_export_matches_forward: Compares forward_export() with forward()
  • test_as_export_module_requires_quant_mode: Validates QUANT mode requirement
  • test_as_export_module_returns_adapter: Verifies adapter type
  • test_as_export_module_precomputes_buffers_on_wrapper: Checks buffer registration
  • test_export_adapter_decomposed_forward_matches_fp: Validates adapter output
  • test_build_pool_weights_matches_original: Compares decomposed pooling with original
$ python3 -m pytest test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py -v
==================================================================== test session starts =====================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 15 items                                                                                                                                           

test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_as_export_module_precomputes_buffers_on_wrapper PASSED [  6%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_as_export_module_requires_quant_mode PASSED     [ 13%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_as_export_module_returns_adapter PASSED         [ 20%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_build_pool_weights_matches_original PASSED      [ 26%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_dtype_override PASSED                           [ 33%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_export_adapter_decomposed_forward_matches_fp PASSED [ 40%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_forward_export_matches_forward PASSED           [ 46%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_forward_export_requires_precomputed_buffers PASSED [ 53%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_input_is_observed_in_calib_mode PASSED          [ 60%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_mode_transitions PASSED                         [ 66%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_no_quant_forward_matches_fp PASSED              [ 73%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_no_quant_output_shape PASSED                    [ 80%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_observers_are_collected PASSED                  [ 86%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_output_is_fake_quantized_in_quant_mode PASSED   [ 93%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_pooler.py::TestQuantGemma4VisionPooler::test_quant_mode_output_is_finite PASSED              [100%]

=============================================================== 15 passed, 2 warnings in 5.30s ===============================================================

Internal Tests (test_quantize_vision_pooler.py):

  • test_no_quant_vision_pooler_matches_reference: Wrapper parity test
  • test_prepare_convert_vision_pooler_flow: Full PTQ flow test
  • test_as_export_module_flow: Export module flow test
$ RUN_INTERNAL_TESTS=1 python3 -m pytest test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py -v
==================================================================== test session starts =====================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 3 items                                                                                                                                            

test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py::TestGemma4VisionPoolerSmoke::test_as_export_module_flow PASSED                 [ 33%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py::TestGemma4VisionPoolerSmoke::test_no_quant_vision_pooler_matches_reference PASSED [ 66%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py::TestGemma4VisionPoolerSmoke::test_prepare_convert_vision_pooler_flow PASSED    [100%]

=============================================================== 3 passed, 2 warnings in 5.17s ================================================================

Smoke Test

$ python -m tico.quantization.examples.inspect \
    --config tico/quantization/examples/configs/wrapper_smoke.yaml \
    --mode wrapper-smoke \
    --case gemma4_vision_pooler \
    --export circle \
    --output-dir ./out/wrapper_smoke
[QuantCheck] WARNING: 6 nodes without qparam detected (see logs).
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_vision_pooler
│ Status           : PASS
│ Mean |diff|      : 0.029907
│ Max |diff|       : 0.554016
│ PEIR             : 0.038011
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
Artifacts:
  - circle: out/wrapper_smoke/gemma4_vision_pooler.q.circle
    ┌────────────────────────────────────────────┐
 9.4┤                                            │
    │                                        ••  │
    │                                            │
 6.8┤                                    •       │
    │                                •••         │
    │                             •••            │
 4.1┤                           •••              │
    │                        •••                 │
 1.4┤                      ••                    │
    │                   •••                      │
    │                 ••                         │
-1.3┤              •••                           │
    │           •••                              │
    │         •••                                │
-3.9┤      •••                                   │
    │    • •                                     │
    │  •••                                       │
-6.6┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -6.6       -2.6        1.4       5.4       9.4

Example Scripts

quantize_vision_pooler.py:

  • Creates a tiny Gemma4VisionPooler with synthetic config
  • Generates 20 calibration samples
  • Runs prepare → calibrate → convert flow
  • Computes PEIR (Peak Error-to-Interval Ratio) between FP and quantized outputs
  • Exports to Circle format via as_export_module()
  • Saves output as gemma4_vision_pooler.q.circle

The example script produces quantization error statistics and a visualization of FP vs. quantized output distributions.
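A minimal sketch of how such an error summary can be computed, assuming PEIR is the peak absolute error divided by the value range (interval) of the FP reference. This reading follows the metric's name; the helper below is illustrative and is not the script's actual implementation.

```python
def error_summary(reference, candidate):
    """Mean/max absolute difference and PEIR between two flat sequences."""
    diffs = [abs(r - c) for r, c in zip(reference, candidate)]
    peak = max(diffs)
    interval = max(reference) - min(reference)
    if interval > 0:
        peir = peak / interval
    else:
        # Degenerate reference: the ratio is only defined when there is no error.
        peir = 0.0 if peak == 0 else float("inf")
    return {
        "mean_abs_diff": sum(diffs) / len(diffs),
        "max_abs_diff": peak,
        "peir": peir,
    }


fp_out = [-1.0, 0.0, 1.0, 3.0]   # stand-in for the FP pooled output
q_out = [-1.0, 0.5, 1.0, 3.0]    # stand-in for the fake-quantized output
summary = error_summary(fp_out, q_out)
```

Note that the example script prints PEIR with a % sign, while the smoke summary prints it without one, so the two numbers may not be on the same scale.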

$ python tico/quantization/wrapq/examples/gemma4/quantize_vision_pooler.py
┌───────────── Quantization Error Summary ─────────────
│ FP pooled shape    : (1, 4, 32)
│ Quant pooled shape : (1, 4, 32)
│ Mean |diff|        : 0.026594
│ PEIR               : 0.494674 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 7.8┤                                            │
    │                                       • •  │
 5.0┤                                    •••     │
    │                                ••••        │
    │                              •••           │
 2.3┤                          ••••              │
    │                       ••••                 │
-0.5┤                    ••••                    │
    │                 ••••                       │
-3.2┤              ••••                          │
    │            ••                              │
    │        •••                                 │
-6.0┤                                            │
    │  •                                         │
-8.7┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -8.7       -4.6       -0.5       3.7       7.8 

[QuantCheck] WARNING: 6 nodes without qparam detected (see logs).
Circle model saved as 'gemma4_vision_pooler.q.circle'

@dvsav dvsav marked this pull request as ready for review June 17, 2026 07:16
@dvsav dvsav requested review from Torrero and mhs4670go June 17, 2026 07:16
@dvsav dvsav force-pushed the vision_pooler branch 2 times, most recently from 302b968 to 58198ea June 17, 2026 14:20
Comment on lines +795 to +799
return wrapped.as_export_module(
    "prefill",
    output_length=self.output_length,
    pixel_position_ids=pixel_pos_ids,
).eval()

Contributor

The "prefill" string is passed as a positional argument, which causes the smoke test to fail:

python -m tico.quantization.examples.inspect   --config tico/quantization/examples/configs/wrapper_smoke.yaml   --mode wrapper-smoke   --case gemma4_vision_pooler   --export circle   --output-dir ./out/wrapper_smoke

Output:

 - Circle export failed: QuantGemma4VisionPooler.as_export_module() takes 1 positional argument but 2 positional arguments (and 2 keyword-only arguments) were given
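The failure can be reproduced in isolation. The class below is a hypothetical stand-in with the same keyword-only signature as in the diff; only the argument-binding behavior is relevant here.

```python
class PoolerStub:
    """Stand-in for the wrapper; everything after ``*`` is keyword-only."""

    def as_export_module(self, *, output_length, pixel_position_ids):
        return ("adapter", output_length, pixel_position_ids)


stub = PoolerStub()

try:
    # "prefill" has no positional parameter to bind to, so this raises.
    stub.as_export_module("prefill", output_length=4, pixel_position_ids=[])
    error = None
except TypeError as exc:
    error = str(exc)

# Passing everything by keyword binds correctly.
adapter = stub.as_export_module(output_length=4, pixel_position_ids=[])
```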

Contributor Author

👍 Fixed, thank you for catching that!

Comment on lines +281 to +285
Input contract:
``hidden_states`` has shape ``(1, S, D)`` where ``S`` is the fixed
vision encoder sequence length.
``pixel_position_ids`` has shape ``(1, S, 2)`` — pre-computed on CPU.
``padding_positions`` has shape ``(1, S)`` — pre-computed on CPU.

@Torrero Torrero Jun 17, 2026

Contributor

If I understand correctly, this adapter is for "prefill" mode. Could you consider renaming it to "Gemma4VisionPoolerPrefillExportAdapter"?

Contributor Author

👍 done

*,
output_length: int,
pixel_position_ids: torch.Tensor,
) -> nn.Module:

Contributor

Also, as far as I know, QuantGemma4VisionPooler won't work in the decode phase.
Maybe we could consider a change like this:

    def as_export_module(
        self,
        *,
        mode: ExportMode = "prefill",
        output_length: int,
        pixel_position_ids: torch.Tensor,
    ) -> nn.Module:
        """Return a static export adapter for the requested execution mode."""
        if mode != "prefill":
            raise ValueError(
                f"Unsupported Gemma4 VisionPooler export mode: {mode!r}"
            )
        ...
        return Gemma4VisionPoolerPrefillExportAdapter(self)  # renamed adapter class

Contributor Author

👍 done

@dvsav dvsav force-pushed the vision_pooler branch 3 times, most recently from a3ef615 to 2e7ec45 June 18, 2026 07:57
Torrero previously approved these changes Jun 18, 2026

@Torrero Torrero left a comment

Contributor

LGTM, thank you!

@mhs4670go

Contributor

Sorry for the late review. I was in a training session. I'll review this PR soon.

Comment on lines +266 to +267
for obs in self._all_observers():
    assert obs.has_qparams

Contributor

For the MX types, we can add an isinstance check here, since MXObserver doesn't have has_qparams.

Suggested change
- for obs in self._all_observers():
-     assert obs.has_qparams
+ for obs in self._all_observers():
+     if isinstance(obs, AffineObserverBase):
+         assert obs.has_qparams

Contributor Author

👍 done
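A self-contained sketch of the guarded check, using hypothetical stand-in classes in place of TICO's observers. It assumes, as the comment above states, that only affine observers expose has_qparams.

```python
class ObserverBase:
    def __init__(self, name):
        self.name = name


class AffineObserverBase(ObserverBase):
    def __init__(self, name, has_qparams=False):
        super().__init__(name)
        self.has_qparams = has_qparams


class MXObserver(ObserverBase):
    """MX-style stand-in: it has no has_qparams attribute at all."""


def missing_qparams(observers):
    """Names of affine observers that have not computed qparams yet."""
    return [
        obs.name
        for obs in observers
        if isinstance(obs, AffineObserverBase) and not obs.has_qparams
    ]


observers = [
    AffineObserverBase("obs_act_in", has_qparams=True),
    MXObserver("obs_pool_weight"),   # skipped instead of raising AttributeError
    AffineObserverBase("obs_pool_out", has_qparams=False),
]
not_ready = missing_qparams(observers)
```

Returning the offending names instead of using a bare assert also keeps the check active under python -O and allows a more informative error message.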


return pooled, updated_padding

def as_export_module(

Contributor

(Optional) I think the static pooling profile should be owned by the export adapter rather than mutating the original wrapper. The current implementation does this in as_export_module().

This makes as_export_module() different from most other wrappers. For most export adapters, we only adapt the input/output contract and keep using the existing module state. In this pooler, however, as_export_module() also specializes the wrapper for a particular static profile by materializing pool_weights and pool_mask from pixel_position_ids / output_length.

There are two issues with that:

  1. The static profile is stored on the original QuantGemma4VisionPooler, not on the returned adapter. If as_export_module() is called again with another profile, the later call overwrites self.pool_weights / self.pool_mask, so an already-created adapter can silently start using the new profile.
  2. collect(self.pool_weights) after freeze_qparams() is likely not doing what it looks like. freeze_qparams() disables observers, and ObserverBase.collect() returns immediately when enabled == False. So the newly-created buffer is not actually collected into the observer stats before compute_qparams().

I think a cleaner design is to make the export adapter own the profile-specific tensors and a profile-local observer for pool_weights.

Something like:

    @staticmethod
    def _make_profile_observer(
        observer: ObserverBase,
        value: torch.Tensor,
    ) -> ObserverBase:
        """
        Return an observer calibrated only for a static export-profile tensor.
        """
        profile_observer = copy.deepcopy(observer)

        profile_observer.enabled = True
        profile_observer.reset()
        profile_observer.collect(value.detach())
        profile_observer.compute_qparams()
        profile_observer.enabled = False

        return profile_observer


    def as_export_module(
        self,
        mode: ExportMode = "prefill",
        *,
        output_length: int,
        pixel_position_ids: torch.Tensor,
    ) -> nn.Module:
        if mode != "prefill":
            raise ValueError(f"Unsupported Gemma4 VisionPooler export mode: {mode!r}")

        if self._mode is not Mode.QUANT:
            raise RuntimeError(
                "Gemma4VisionPooler can be exported only after freeze_qparams()."
            )

        # Affine observers expose ``has_qparams``; non-affine observers such as
        # MX may not.  Do not assume the attribute exists for every observer.
        for obs in self._all_observers():
            if hasattr(obs, "has_qparams") and not obs.has_qparams:
                raise RuntimeError(f"Observer {obs.name!r} does not have qparams.")

        from tico.quantization.wrapq.wrappers.gemma4.export_adapters import (
            Gemma4VisionPoolerPrefillExportAdapter,
        )

        pool_weights, pool_mask = self._build_pool_weights(
            seq_len=pixel_position_ids.shape[1],
            output_length=output_length,
            pixel_position_ids=pixel_position_ids,
        )

        pool_weight_observer = self._make_profile_observer(
            self.obs_pool_weight,
            pool_weights,
        )

        return Gemma4VisionPoolerPrefillExportAdapter(
            self,
            pool_weights=pool_weights,
            pool_mask=pool_mask,
            pool_weight_observer=pool_weight_observer,
        )

And the adapter can own the static buffers:

class Gemma4VisionPoolerPrefillExportAdapter(nn.Module):
    def __init__(
        self,
        wrapped_pooler: nn.Module,
        *,
        pool_weights: torch.Tensor,
        pool_mask: torch.Tensor,
        pool_weight_observer: ObserverBase,
    ):
        super().__init__()

        self.wrapped_pooler = wrapped_pooler
        self.pool_weight_observer = pool_weight_observer

        self.register_buffer("pool_weights", pool_weights.detach().clone())
        self.register_buffer("pool_mask", pool_mask.detach().clone())

    def forward(
        self,
        hidden_states: torch.Tensor,
        padding_positions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.wrapped_pooler.forward_export(
            hidden_states,
            padding_positions,
            pool_weights=self.pool_weights,
            pool_mask=self.pool_mask,
            pool_weight_observer=self.pool_weight_observer,
        )

Then forward_export() can use the adapter-provided profile state:

    def forward_export(
        self,
        hidden_states: torch.Tensor,
        padding_positions: torch.Tensor,
        *,
        pool_weights: torch.Tensor,
        pool_mask: torch.Tensor,
        pool_weight_observer: ObserverBase,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the decomposed static pooler forward.

        The export adapter owns ``pool_weights``, ``pool_mask``, and the
        profile-local observer for ``pool_weights``.  This keeps each static
        export profile independent from other adapters created from the same
        quantized wrapper.
        """
        input_dtype = hidden_states.dtype

        hidden_states = self._fq(hidden_states, self.obs_act_in)

        # Step 1: Zero out padding positions.
        hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0)

        # Step 2: Fake-quantize input.
        hidden_states = self._fq(hidden_states, self.obs_pool_in)

        # Step 3: Fake-quantize static pool weights with the profile-local observer.
        # Fall back to the raw weights outside QUANT mode so the name is always bound.
        pool_weights_q = pool_weights
        if self._mode is Mode.QUANT:
            pool_weights_q = pool_weight_observer.fake_quant(pool_weights)

        # Step 4: Spatial pooling.
        pooled = pool_weights_q @ hidden_states.float()
        pooled = pooled.to(input_dtype)

        # Step 5: Fake-quantize matmul output.
        pooled = self._fq(pooled, self.obs_pool_matmul_out)

        # Step 6: Scale by sqrt(hidden_size).
        root_hidden_size = self.root_hidden_size
        if self._mode is Mode.QUANT:
            root_hidden_size = self.obs_root_hidden_size.fake_quant(root_hidden_size)
        pooled = pooled * root_hidden_size

        # Step 7: Fake-quantize final output.
        pooled = self._fq(pooled, self.obs_pool_out)

        updated_padding = pool_mask.expand(pooled.shape[0], -1)
        return pooled, updated_padding

@dvsav dvsav Jun 22, 2026

Contributor Author

> The static profile is stored on the original QuantGemma4VisionPooler, not on the returned adapter. If as_export_module() is called again with another profile, the later call overwrites self.pool_weights / self.pool_mask, so an already-created adapter can silently start using the new profile.

Yes, that's true. And yet, I think, there's something to discuss.

Let me try to imagine a situation in which we call as_export_module() multiple times. In such a situation we would need to:

  1. Calibrate the model once.
  2. Export the model (call as_export_module and store the returned adapter).
  3. Calibrate the model for the 2nd time.
  4. Export the model again (call as_export_module again and store the adapter).
  5. Convert the 1st export adapter to Circle.
  6. Convert the 2nd adapter to Circle.

Doesn't this scenario look a little exotic?


Let's take a look at the export adapters defined in tico/quantization/wrapq/wrappers/gemma4/export_adapters.py:

  • Gemma4TokenEmbeddingExportAdapter
  • Gemma4VisionPrefillExportAdapter
  • Gemma4VisionEncoderLayerPrefillExportAdapter
  • Gemma4TextDecoderLayerPrefillExportAdapter
  • Gemma4TextDecoderLayerDecodeExportAdapter

In their forward methods they call their wrapped modules via self.wrapped(...). So, I guess, the above scenario applies equally to them. If you calibrate the model, call as_export_module() and store the returned adapter, then calibrate the model again on different data and call as_export_module() a second time, won't this affect the 1st adapter? The 1st and the 2nd adapters use the same wrapped module, don't they?


Here's why I prefer to keep all observers and static tensors in QuantGemma4VisionPooler:

  • "Traditionally", observers belong to classes that inherit from QuantModuleBase. The export adapter doesn't inherit from QuantModuleBase, so I don't want to let the adapter own observers.
  • Most export adapters in export_adapters.py are simple and concise because they delegate most of the work to the original wrapped modules. In my mind, even the term "adapter" suggests that it must be simple. So I don't want to complicate the adapter by moving observers and tensors there. I feel better when the complexity is encapsulated in the original wrapper class.

Contributor Author

> collect(self.pool_weights) after freeze_qparams() is likely not doing what it looks like. freeze_qparams() disables observers, and ObserverBase.collect() returns immediately when enabled == False. So the newly-created buffer is not actually collected into the observer stats before compute_qparams().

👍 Good point, thank you. I've fixed this directly in as_export_module():

class QuantGemma4VisionPooler(QuantModuleBase):
    ...
    def as_export_module(
        ...
        # Collect statistics about pool_weights and compute qparams
        obs_pool_weight_enabled: bool = self.obs_pool_weight.enabled
        self.obs_pool_weight.enabled = True
        self.obs_pool_weight.reset()
        self.obs_pool_weight.collect(self.pool_weights)
        self.obs_pool_weight.compute_qparams()
        self.obs_pool_weight.enabled = obs_pool_weight_enabled
        ...
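The same enable, reset, collect, compute, restore sequence can also be wrapped in a context manager, so that the enabled flag is restored even if collection raises. The observer below is a toy stand-in (hypothetical ToyObserver and temporarily_enabled), written only to show that collect() is a no-op while disabled, as described in the quoted comment; it is not TICO's ObserverBase.

```python
from contextlib import contextmanager


class ToyObserver:
    def __init__(self):
        self.enabled = True
        self.min_val, self.max_val = float("inf"), float("-inf")
        self.scale = None

    def reset(self):
        self.min_val, self.max_val = float("inf"), float("-inf")

    def collect(self, values):
        if not self.enabled:  # mirrors the early return described above
            return
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def compute_qparams(self):
        self.scale = (self.max_val - self.min_val) / 255


@contextmanager
def temporarily_enabled(observer):
    """Enable the observer inside the block, then restore its previous state."""
    previous = observer.enabled
    observer.enabled = True
    try:
        yield observer
    finally:
        observer.enabled = previous


obs = ToyObserver()
obs.enabled = False              # state after qparams have been frozen
obs.collect([0.0, 0.25])         # silently ignored
ignored = obs.max_val == float("-inf")

with temporarily_enabled(obs):
    obs.reset()
    obs.collect([0.0, 0.25])     # now actually recorded
    obs.compute_qparams()
```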

Contributor

@dvsav

> The 1st and the 2nd adapters use the same wrapped module, don't they?

Thanks, that is a fair point. I agree that the practical risk is limited. Actually, that is why I tagged the comment as optional.

I just wanted to make it a bit more robust against future changes, since there have been many changes recently.

Documenting or enforcing the single-use/profile contract in another PR may be sufficient.
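One way such a contract could be enforced is to remember the first profile and reject a different one. This is only a sketch with hypothetical names (ProfileGuard, a tuple-based profile key, positions given as plain pairs); nothing like it is part of this PR.

```python
class ProfileGuard:
    """Remembers the first export profile and rejects any other."""

    def __init__(self):
        self._profile = None

    def check(self, output_length, pixel_position_ids):
        # Build a hashable, comparable key from the static profile inputs.
        key = (output_length, tuple(map(tuple, pixel_position_ids)))
        if self._profile is None:
            self._profile = key
        elif self._profile != key:
            raise RuntimeError(
                "as_export_module() was already called with a different "
                "static profile; create a new wrapper for another profile."
            )


guard = ProfileGuard()
guard.check(4, [(0, 0), (1, 0)])
guard.check(4, [(0, 0), (1, 0)])     # same profile: allowed

try:
    guard.check(2, [(0, 0), (1, 0)])  # different output_length: rejected
    rejected = False
except RuntimeError:
    rejected = True
```

A guard like this turns the silent overwrite described in the first issue into an explicit error, without moving any state into the adapter.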

output_length: Number of soft tokens to produce. When ``None``
the pooler returns the same sequence length.
"""
assert output_length is not None

Contributor

I guess the original module checks whether hidden_states.shape[1] == output_length. If you intentionally skipped that logic, it would be better to add an assertion for it.

Contributor Author

👍 Fixed, thank you for catching that.

@dvsav dvsav requested a review from mhs4670go June 22, 2026 07:02
@dvsav dvsav force-pushed the vision_pooler branch 4 times, most recently from 5c54e97 to 8be9bda June 22, 2026 08:18
Implements quantization wrapper for Gemma4 vision pooler.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>

@mhs4670go mhs4670go left a comment

Contributor

LGTM. Thanks.

@mhs4670go mhs4670go merged commit efe50a7 into Samsung:main Jun 22, 2026
7 checks passed
@dvsav dvsav deleted the vision_pooler branch June 22, 2026 08:39