Skip to content

[quantization] Add QuantGemma4TextScaledWordEmbedding PTQ wrapper#788

Merged
dvsav merged 1 commit into
Samsung:mainfrom
dvsav:word_emb
Jun 22, 2026
Merged

[quantization] Add QuantGemma4TextScaledWordEmbedding PTQ wrapper#788
dvsav merged 1 commit into
Samsung:mainfrom
dvsav:word_emb

Conversation

@dvsav

@dvsav dvsav commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

What

This PR adds complete Post-Training Quantization (PTQ) support for the Gemma4TextScaledWordEmbedding module, a key component of the Gemma4 multimodal model family. The implementation includes a comprehensive PTQ wrapper with per-channel weight quantization, embed_scale fake quantization, full test coverage (14 unit tests + 3 smoke tests), and an example script demonstrating the complete quantization flow with Circle format export.


Why

The Gemma4TextScaledWordEmbedding module extends standard embedding layers by multiplying token embeddings with a scalar scale factor (embed_scale). This scaling operation is critical for Gemma4's numerical stability and must be properly quantized to maintain accuracy in the static-shape NPU inference flow.

Prior to this change, the wrapper existed as a skeleton with only basic weight observation. This PR completes the implementation by:

  • Adding proper fake quantization for both the embedding output and the scale factor
  • Enabling per-channel asymmetric weight quantization (matching the pattern used in QuantEmbedding)
  • Making the wrapper exportable for Circle format conversion
  • Providing comprehensive test coverage to validate correctness

This change is part of the broader Gemma4 E2B static PTQ skeleton effort, moving individual wrappers from skeleton status to fully functional quantization modules.


Key Design Decisions

1. Per-Channel Asymmetric Weight Quantization

Decision: Use QScheme.PER_CHANNEL_ASYMM with channel_axis=0 for the weight observer.

Rationale: This matches the pattern established in QuantEmbedding (tico/quantization/wrapq/wrappers/nn/quant_embedding.py). Per-channel quantization provides better accuracy for embedding tables because each embedding vector can have its own scale/zero-point, accommodating varying value ranges across the vocabulary.

2. Four-Observer Architecture

Decision: Add 4 observers: obs_weight, obs_embedding, obs_embed_scale, obs_act_out.

Rationale:

  • obs_weight: Quantizes the embedding weight matrix (per-channel)
  • obs_embedding: Quantizes the raw embedding output before scaling
  • obs_embed_scale: Quantizes the scalar scale factor itself
  • obs_act_out: Quantizes the final scaled output

This granular observation ensures all intermediate tensors are properly quantized, maintaining numerical consistency through the embedding → scale → output chain.

3. Fake Quantization on embed_scale

Decision: Apply fake quantization to the scale factor in QUANT mode.

Rationale: While embed_scale is a scalar, quantizing it ensures the multiplication hidden_states * scale operates on quantized values, maintaining consistency with the fake quantization paradigm. The quantization error on a scalar is negligible but including it ensures the graph accurately represents the quantized inference behavior.

4. as_export_module Returns Self

Decision: The as_export_module() method returns self after asserting QUANT mode.

Rationale: The wrapper is already exportable—its forward method uses only torch.export-compatible operations (embedding lookup, multiplication, fake_quant). No additional adaptation is needed, following the pattern used in other simple wrappers like QuantGemma4VisionPooler.


Changes

File Changes
tico/quantization/wrapq/wrappers/gemma4/quant_text_scaled_word_embedding.py Added QScheme import; changed weight observer to per-channel asymmetric (channel_axis=0); added obs_embedding and obs_embed_scale observers; added embed_scale collection in enable_calibration(); added fake quantization for embedding output and scale in forward(); added as_export_module() method; updated _all_observers() to return all 4 observers
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py NEW - 14 unit tests covering: NO_QUANT mode forward match, output shape, mode transitions, observer collection, weight/scale observation in CALIB mode, fake quantization in QUANT mode, finite output, dtype overrides, qscheme defaults, as_export_module requirements
test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py NEW - 3 smoke tests for prepare-calibrate-convert flow: NO_QUANT match with FP reference, full PTQ flow validation, as_export_module export flow (skipped by default, requires RUN_INTERNAL_TESTS=1)
tico/quantization/wrapq/examples/gemma4/quantize_text_scaled_word_embedding.py NEW - Complete example script demonstrating: tiny model creation, calibration data generation, PTQ preparation/calibration/conversion, PEIR error computation, visualization, Circle format export
tico/quantization/wrapq/wrappers/registry.py Uncommented quant_text_scaled_word_embedding module registration, enabling automatic wrapper discovery
tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py Added smoke test case for Gemma4 scaled word embedding module

Tests

Unit Tests (14 tests)

File: test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py

Test Purpose
test_no_quant_forward_matches_fp Verifies NO_QUANT mode delegates to original module
test_no_quant_output_shape Validates output shape (batch, seq, dim)
test_mode_transitions Tests NO_QUANT → CALIB → QUANT lifecycle
test_observers_are_collected Confirms all 4 observers present
test_weight_is_observed_in_calib_mode Verifies weight collection during calibration
test_embed_scale_is_observed_in_calib_mode Verifies scale collection during calibration
test_output_is_fake_quantized_in_quant_mode Confirms fake_quant applied in QUANT mode
test_quant_mode_output_is_finite Validates finite output with correct shape
test_dtype_override Tests PTQConfig dtype override propagation
test_weight_uses_per_channel_asymm_by_default Confirms default qscheme
test_as_export_module_requires_quant_mode Verifies QUANT mode assertion
test_as_export_module_returns_self Confirms self-return for export
$ python -m pytest test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py -v
==================================================================== test session starts ====================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 12 items                                                                                                                                          

test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_as_export_module_requires_quant_mode PASSED [  8%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_as_export_module_returns_self PASSED [ 16%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_dtype_override PASSED  [ 25%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_embed_scale_is_observed_in_calib_mode PASSED [ 33%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_mode_transitions PASSED [ 41%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_no_quant_forward_matches_fp PASSED [ 50%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_no_quant_output_shape PASSED [ 58%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_observers_are_collected PASSED [ 66%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_output_is_fake_quantized_in_quant_mode PASSED [ 75%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_quant_mode_output_is_finite PASSED [ 83%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_weight_is_observed_in_calib_mode PASSED [ 91%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_scaled_word_embedding.py::TestQuantGemma4TextScaledWordEmbedding::test_weight_uses_per_channel_asymm_by_default PASSED [100%]

============================================================== 12 passed, 2 warnings in 4.56s ===============================================================

Internal Tests (3 tests)

File: test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py

Test Purpose
test_no_quant_embedding_matches_reference Wrapper matches FP before quantization
test_prepare_convert_embedding_flow Full PTQ prepare-calibrate-convert flow
test_as_export_module_flow Export module validation for Circle
$ RUN_INTERNAL_TESTS=1 python -m pytest test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py -v
==================================================================== test session starts ====================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 3 items                                                                                                                                           

test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py::TestGemma4TextScaledWordEmbeddingSmoke::test_as_export_module_flow PASSED [ 33%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py::TestGemma4TextScaledWordEmbeddingSmoke::test_no_quant_embedding_matches_reference PASSED [ 66%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py::TestGemma4TextScaledWordEmbeddingSmoke::test_prepare_convert_embedding_flow PASSED [100%]

=============================================================== 3 passed, 2 warnings in 4.72s ===============================================================

Smoke Test

$ python -m tico.quantization.examples.inspect \
    --config tico/quantization/examples/configs/wrapper_smoke.yaml \
    --mode wrapper-smoke \
    --case gemma4_text_scaled_word_embedding \
    --export circle \
    --output-dir ./out/wrapper_smoke
[QuantCheck] WARNING: 1 nodes without qparam detected (see logs).
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_text_scaled_word_embedding
│ Status           : PASS
│ Mean |diff|      : 0.000965
│ Max |diff|       : 0.002763
│ PEIR             : 0.003481
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
Artifacts:
  - circle: out/wrapper_smoke/gemma4_text_scaled_word_embedding.q.circle
     ┌───────────────────────────────────────────┐
 0.40┤                                           │
     │                                      •••  │
 0.25┤                                  •• •     │
     │                               ••••        │
 0.11┤                           •••••           │
     │                        ••••               │
-0.04┤                     ••••                  │
     │                  ••••                     │
     │              •••••                        │
-0.18┤           •••••                           │
     │         •••                               │
-0.33┤      •••                                  │
     │  •                                        │
-0.47┤                                           │
     └┬──────────┬─────────┬──────────┬─────────┬┘
    -0.47      -0.26     -0.04      0.18     0.40

Example Script

File: tico/quantization/wrapq/examples/gemma4/quantize_text_scaled_word_embedding.py

The example script demonstrates the complete PTQ workflow:

  1. Model Creation: Creates a tiny Gemma4TextScaledWordEmbedding (vocab=1000, dim=64) without downloading pretrained weights
  2. Calibration Data: Generates 20 synthetic input ID sequences
  3. PTQ Flow: Prepare → Calibrate → Convert
  4. Error Analysis: Computes PEIR (Peak Error-to-Interval Ratio) and displays visualization
  5. Circle Export: Exports to Circle format via tico.convert()
$ python tico/quantization/wrapq/examples/gemma4/quantize_text_scaled_word_embedding.py
┌───────────── Quantization Error Summary ─────────────
│ FP output shape    : (1, 16, 64)
│ Quant output shape : (1, 16, 64)
│ Mean |diff|        : 0.001094
│ PEIR               : 0.396075 %
└──────────────────────────────────────────────────────
     ┌───────────────────────────────────────────┐
 0.44┤                                           │
     │                                        •  │
     │                                    •••    │
 0.30┤                                 •••       │
     │                               •••         │
     │                            ••••           │
 0.15┤                          •••              │
     │                       ••••                │
 0.00┤                     ••••                  │
     │                   •••                     │
     │                ••••                       │
-0.14┤              •••                          │
     │           ••••                            │
     │         •••                               │
-0.29┤      •••                                  │
     │    •••                                    │
     │  ••                                       │
-0.43┤                                           │
     └┬──────────┬─────────┬──────────┬─────────┬┘
    -0.43      -0.21     0.00       0.22     0.44 


Converting to Circle format...
[QuantCheck] WARNING: 1 nodes without qparam detected (see logs).
Circle model saved as 'gemma4_text_scaled_word_embedding.q.circle'

@dvsav dvsav requested a review from Torrero June 18, 2026 08:41
@dvsav dvsav force-pushed the word_emb branch 2 times, most recently from 1ed9d8d to 3ef7fa3 Compare June 22, 2026 08:52
@dvsav dvsav marked this pull request as ready for review June 22, 2026 09:02
@dvsav dvsav requested a review from mhs4670go June 22, 2026 09:02
Torrero
Torrero previously approved these changes Jun 22, 2026

@Torrero Torrero left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mhs4670go

Copy link
Copy Markdown
Contributor

The test command is test_quantize_text_scaled_word_embedding.py but seems that the result is about test_quantize_vision_pooler.py. Maybe something wrong?

$ RUN_INTERNAL_TESTS=1 python -m pytest test/quantization/wrapq/wrappers/gemma4/test_quantize_text_scaled_word_embedding.py -v
================================================================================ test session starts =================================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 3 items                                                                                                                                                                    

test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py::TestGemma4VisionPoolerSmoke::test_as_export_module_flow PASSED                                         [ 33%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py::TestGemma4VisionPoolerSmoke::test_no_quant_vision_pooler_matches_reference PASSED                      [ 66%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_pooler.py::TestGemma4VisionPoolerSmoke::test_prepare_convert_vision_pooler_flow PASSED                            [100%]

=========================================================================== 3 passed, 2 warnings in 4.54s ============================================================================

Comment on lines +55 to +57
self.register_buffer(
"embed_scale", torch.tensor(fp.embed_scale), persistent=False
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just using original scale instead of copying as a separate buffer?

def enable_calibration(self) -> None:
    super().enable_calibration()
    self.obs_weight.collect(self.module.weight)
    self.obs_embed_scale.collect(self.module.embed_scale)

# forward
scale = self.module.embed_scale
if self._mode is Mode.QUANT:
    scale = self.obs_embed_scale.fake_quant(scale)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 done

Add complete PTQ quantization support for Gemma4TextScaledWordEmbedding.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>

@mhs4670go mhs4670go left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dvsav dvsav merged commit bb32f01 into Samsung:main Jun 22, 2026
7 checks passed
@dvsav dvsav deleted the word_emb branch June 22, 2026 13:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants