
[quantization] Add a gemma wrapper for TextModel #791

Open
mhs4670go wants to merge 3 commits into Samsung:main from mhs4670go:tex

Conversation


@mhs4670go mhs4670go commented Jun 22, 2026

Contributor

This commit adds a wrapper for the Gemma text model.

python -m tico.quantization.examples.inspect \
  --config tico/quantization/examples/configs/wrapper_smoke.yaml \
  --mode wrapper-smoke \
  --case gemma4_text_model \
  --strict
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_text_model
│ Status           : PASS
│ Mean |diff|      : 0.079053
│ Max |diff|       : 0.483883
│ PEIR             : 0.083751
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 3.7┤                                            │
    │                                            │
    │                                     •   •  │
 2.6┤                                            │
    │                                 •          │
    │                              •  ••         │
    │                             ••••           │
 1.6┤                            •••             │
    │                        •••••               │
    │                      •••••                 │
 0.5┤                    •••••                   │
    │                   •••••                    │
    │                ••••••                      │
    │              •••••                         │
-0.5┤            ••••••                          │
    │          •••••                             │
    │         ••••                               │
-1.6┤      ••••                                  │
    │    ••••••                                  │
    │   •••                                      │
    │  •••                                       │
-2.7┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -2.7       -1.1        0.5       2.1       3.7 

TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>

@mhs4670go mhs4670go force-pushed the tex branch 2 times, most recently from 84fe0d5 to f177e87 Compare June 22, 2026 14:00
@mhs4670go mhs4670go requested a review from dvsav June 22, 2026 14:06

dvsav commented Jun 22, 2026

Contributor

Note for myself

What

This commit adds a full PTQ wrapper (QuantGemma4TextModel) for the dense Gemma4 E2B text model, enabling quantization of the complete Gemma4TextModel architecture. It replaces the previous stub implementation. The wrapper constructs bounded static attention masks and RoPE templates internally (without calling HuggingFace mask factories), supports Gemma4-specific Per-Layer Embeddings (PLE) and shared KV states, and is compatible with the torch.export-based conversion pipeline. The wrapper is also activated in the module registry.

Why

The Gemma4 model family requires a text-model-level wrapper to enable end-to-end PTQ quantization and conversion to Circle format. Previously, only the sub-component wrappers (attention, MLP, decoder layer) were registered, meaning the full text model could not be prepared as a single quantizable unit. This commit fills that gap by providing the top-level QuantGemma4TextModel wrapper that orchestrates the sub-layer wrappers and handles Gemma4-specific concerns:

  • Static mask construction: NPU export requires precomputed bounded masks rather than dynamic HuggingFace mask factory calls. The wrapper builds full causal and sliding-window causal mask templates as registered buffers, sliced at runtime.
  • Finite mask fill values: Using PTQConfig.attention_mask_fill_value (a finite negative number) instead of -inf keeps affine observer ranges usable during calibration.
  • Static RoPE templates: Precomputed cos/sin tables per layer type avoid dynamic RoPE computation during torch.export tracing.
  • PLE support: Gemma4's Per-Layer Embeddings mechanism (token-identity + context-aware projection) needs special handling in the quantized path, including reverse input-id lookup from embeddings for multimodal use cases.
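The static mask construction described above can be sketched in plain Python. This is an illustrative sketch, not the wrapper's code: it uses nested lists instead of tensors, the names `build_causal_template` and `slice_template` are hypothetical, and `FILL_VALUE` is only a stand-in for PTQConfig.attention_mask_fill_value, whose actual value is not stated in this PR.

```python
# Hypothetical stand-in for PTQConfig.attention_mask_fill_value.
FILL_VALUE = -120.0


def build_causal_template(max_len, fill_value, window=None):
    """Build a max_len x max_len additive mask template.

    Entry [q][k] is 0.0 when query position q may attend to key position k
    and fill_value otherwise. With window=None the mask is plain causal;
    otherwise q only sees the most recent `window` keys (k > q - window).
    """
    template = []
    for q in range(max_len):
        row = []
        for k in range(max_len):
            visible = k <= q and (window is None or k > q - window)
            row.append(0.0 if visible else fill_value)
        template.append(row)
    return template


def slice_template(template, seq_len):
    """Take the top-left seq_len x seq_len block, as a runtime slice would."""
    return [row[:seq_len] for row in template[:seq_len]]


# Built once up front (the wrapper registers its templates as buffers).
FULL = build_causal_template(8, FILL_VALUE)
SLIDING = build_causal_template(8, FILL_VALUE, window=2)

# Sliced per call for the actual sequence length.
full_mask = slice_template(FULL, 4)
sliding_mask = slice_template(SLIDING, 4)
```

Because both templates are fixed-size constants, the forward pass only slices them, so no data-dependent mask factory has to run during tracing.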

Key Design Decisions

  1. Bounded static mask templates over HF mask factories: The wrapper pre-builds full causal and sliding-window causal masks as buffers during __init__, then slices them at runtime. This eliminates the dependency on HuggingFace's create_causal_mask / create_sliding_window_causal_mask during the forward pass, which is essential for torch.export compatibility and NPU deployment.

  2. Finite fill value instead of -inf: Masked positions use PTQConfig.attention_mask_fill_value (a large finite negative number) rather than -inf or the minimum finite value of the activation dtype. This prevents affine observers from seeing extreme outlier values that would distort quantization parameter computation.

  3. Static RoPE templates with dynamic fallback: Position embeddings are precomputed and registered as buffers for the export/static path. Outside an export context, when position_ids are explicitly provided, the wrapper falls back to calling self.rotary_emb() dynamically. This balances export compatibility with calibration flexibility.

  4. No as_export_module on TextModel: The complete QuantGemma4TextModel is intentionally not the static NPU export unit. Export happens at the decoder-layer granularity. The wrapper explicitly raises NotImplementedError if static inputs are required without a pre-provided mask mapping.

  5. MoE rejection: The E2B scope explicitly rejects Mixture-of-Experts configs via assert_gemma4_e2b_no_moe, keeping the wrapper focused on dense decoder layers only.

  6. Reverse input-id lookup for PLE: When only inputs_embeds is provided (typical in multimodal VLM pipelines), the wrapper can reverse-lookup input_ids from the embedding matrix to compute per-layer token-identity embeddings. This avoids requiring callers to provide both.

  7. force_export flag + torch.compiler.is_compiling() detection: The wrapper detects whether it's running inside a torch.export context and automatically switches to static template paths. The force_export flag allows explicit override.
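To make design decision 2 concrete, here is a minimal sketch of how the mask fill value affects a min/max affine observer. The observer below is a generic textbook version, not TICO's implementation; `affine_scale`, the sample scores, and the value of `FINITE_FILL` are assumptions chosen for illustration.

```python
import math


def affine_scale(values, qmin=0, qmax=255):
    """Scale of a min/max affine observer whose range must include 0."""
    lo = min(min(values), 0.0)
    hi = max(max(values), 0.0)
    return (hi - lo) / (qmax - qmin)


scores = [-8.0, -1.5, 0.0, 2.5, 8.0]  # made-up unmasked attention scores
FP16_MIN = -65504.0                   # most negative finite float16 value
FINITE_FILL = -120.0                  # hypothetical attention_mask_fill_value

# The observer also sees the masked positions, so the fill value sets the range.
scale_inf = affine_scale(scores + [float("-inf")])
scale_dtype_min = affine_scale(scores + [FP16_MIN])
scale_bounded = affine_scale(scores + [FINITE_FILL])

print(math.isinf(scale_inf))  # the range is unusable with -inf
print(scale_dtype_min)        # one step is wider than the whole score range
print(scale_bounded)          # scores are still resolved into many steps
```

With the dtype minimum as fill, a single 8-bit step spans more than the entire 16-wide score range, so all unmasked scores collapse to one or two levels. A bounded fill keeps the step small enough to separate them.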

Changes

  • tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py - Major expansion from a stub to a full 848-line implementation of QuantGemma4TextModel. Key additions include:

    • Bounded static mask template construction (_build_full_attention_mask_template, _build_sliding_attention_mask_template) and registration as buffers
    • Static RoPE cos/sin template precomputation and registration
    • Position ID creation/validation with static template support (_make_position_ids)
    • Attention mask normalization and construction from static templates (_normalize_attention_mask_for_layer, _create_attention_mask_mapping, _observe_attention_mask_mapping)
    • Per-Layer Embeddings (PLE) support: get_per_layer_inputs, project_per_layer_inputs, _reverse_input_ids_from_embeddings
    • Full forward method with DynamicCache support, shared KV states, return_dict/output_hidden_states/output_attentions/return_shared_kv_states output options
    • Comprehensive observer collection (_all_observers) including attention masks, position embeddings, and PLE observers
    • Input validation for invalid argument combinations
    • _requires_static_inputs() detection for torch.export context
    • _unwrap_layer_output helper for decoder layer output extraction
    • _output_cls() returning Gemma4TextModelOutputWithPast for HF-compatible output
  • test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py - New 372-line test file with 10 unit tests covering:

    • Registry-based prepare wrapping
    • No-quant forward parity with HuggingFace (input_ids path)
    • Static CPU-provided mask mapping path
    • PLE path with input_ids
    • PLE path with inputs_embeds and explicit per_layer_inputs (multimodal-style)
    • Shared KV states return
    • Input validation error messages
    • Absence of as_export_module on TextModel
    • Mode transitions and observer collection lifecycle
    • MoE config rejection
  • tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py - Added Gemma4TextModelCase smoke case class and included it in the GEMMA4_CASES tuple. The case builds a tiny Gemma4TextModel with sliding+full attention, provides calibration/eval inputs with input_ids and attention_mask, and uses return_dict=True.

  • tico/quantization/wrapq/wrappers/registry.py - Uncommented the quant_text_model module line in _CORE_MODULES, activating the QuantGemma4TextModel wrapper in the registry so it is discovered and used by prepare().
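The static RoPE cos/sin template precomputation listed among the changes can be sketched as follows. This is a generic rotary-embedding table in plain Python, assuming the common formulation with inverse frequencies base^(-2i/d) and the half-dimension angles duplicated across the head dimension. The function names, `base`, and the sizes are hypothetical, and Gemma4's actual per-layer-type parameters are not reproduced here.

```python
import math


def build_rope_template(max_len, head_dim, base=10000.0):
    """Precompute cos/sin tables of shape [max_len][head_dim]."""
    half = head_dim // 2
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    cos, sin = [], []
    for pos in range(max_len):
        angles = [pos * f for f in inv_freq]
        angles = angles + angles  # duplicate to cover the full head_dim
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


def slice_rope(cos, sin, start, length):
    """Select the rows for positions start .. start + length - 1."""
    return cos[start:start + length], sin[start:start + length]


# Built once; one such pair per layer type would be built the same way.
COS, SIN = build_rope_template(max_len=16, head_dim=4)

# At runtime only a slice is taken, so no trigonometry runs during tracing.
cos_slice, sin_slice = slice_rope(COS, SIN, start=2, length=3)
```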

Tests

The changes are tested through two complementary mechanisms:

  1. Unit tests (test_quant_text_model.py): 10 test cases covering the core functionality:

    • Registry integration: verifies prepare() produces a PTQWrapper containing QuantGemma4TextModel
    • Numerical parity: no-quant wrapper output matches HuggingFace reference for both input_ids and inputs_embeds paths (using torch.allclose with atol=1e-5, rtol=1e-5)
    • Static mask mapping: verifies the CPU-provided mask dict path matches HF output
    • PLE paths: both token-identity (input_ids) and multimodal-style (inputs_embeds + explicit per_layer_inputs)
    • Shared KV states: verifies output contains shared KV states when requested
    • Validation: checks error messages for invalid input combinations (both input_ids and inputs_embeds, per_layer_inputs with input_ids, static export without mask dict)
    • Export contract: confirms TextModel does not own as_export_module
    • Lifecycle: calibration → freeze → quantization mode transitions with observer verification
    • Scope: MoE configs are explicitly rejected

    All tests are gated by @unittest.skipUnless(_has_gemma4(), ...) to skip gracefully when transformers doesn't provide Gemma4.

$ python -m pytest test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py -v
============================================================================== test session starts ===============================================================================
platform linux -- Python 3.10.12, pytest-9.0.3, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.13.0
collected 10 items                                                                                                                                                               

test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_00_prepare_wraps_text_model_when_registered PASSED                        [ 10%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_mode_transitions_and_observer_collection PASSED                           [ 20%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_moe_text_model_is_rejected_for_e2b_scope PASSED                           [ 30%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_no_quant_forward_matches_hf_text_model_with_input_ids PASSED              [ 40%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_ple_path_matches_hf_with_input_ids PASSED                                 [ 50%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_ple_path_matches_hf_with_inputs_embeds_and_explicit_per_layer_inputs PASSED [ 60%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_shared_kv_text_model_returns_shared_state_when_requested PASSED           [ 70%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_static_attention_mask_mapping_matches_hf_text_model PASSED                [ 80%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_text_model_wrapper_does_not_own_export_adapter_hook PASSED                [ 90%]
test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py::TestQuantGemma4TextModel::test_validation_errors_match_expected_contract PASSED                          [100%]

============================================================================== 10 passed in 11.32s ===============================================================================

  2. Smoke test case (gemma4.py): Gemma4TextDecoderLayerSlidingPrefillCase added to the GEMMA4_CASES tuple for integration-level wrapper smoke testing.
$ python -m tico.quantization.examples.inspect \
    --config tico/quantization/examples/configs/wrapper_smoke.yaml \
    --mode wrapper-smoke \
    --case gemma4_text_decoder_layer_sliding_prefill \
    --export circle \
    --output-dir ./out/wrapper_smoke
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_text_decoder_layer_sliding_prefill
│ Status           : PASS
│ Mean |diff|      : 0.116630
│ Max |diff|       : 1.095015
│ PEIR             : 0.118720
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
Artifacts:
  - circle: out/wrapper_smoke/gemma4_text_decoder_layer_sliding_prefill.q.circle
    ┌────────────────────────────────────────────┐
 4.9┤                                            │
    │                                       ••   │
    │                                      •• •  │
 3.2┤                                    •••     │
    │                                  ••••      │
    │                               ••••         │
    │                             ••••           │
 1.5┤                           ••••             │
    │                        •••••••             │
    │                      •••••                 │
-0.2┤                     ••••                   │
    │                 •••••••                    │
    │                 ••••                       │
    │             ••••••                         │
-1.9┤            •••••                           │
    │        •  ••••                             │
    │         •••                                │
-3.6┤       •••                                  │
    │     •••                                    │
    │     ••                                     │
    │  ••                                        │
-5.3┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.3       -2.7       -0.2       2.4       4.9 

Torrero previously approved these changes Jun 23, 2026

@Torrero Torrero left a comment

Contributor

LGTM

Comment on lines +613 to +615
matches = (inputs_embeds[:, :, None, :] == weight[None, None, :, :]).all(
    dim=-1
)

@dvsav dvsav Jun 23, 2026

Contributor

I suspect that the exact equality comparison inputs_embeds[...] == weight[...] may fail:

  1. In QUANT mode, inputs_embeds is fake-quantized before the reverse lookup is called, so it no longer exactly matches any row in the raw weight table.
  2. With dtype casting, rounding breaks the comparison: in general (weight * scale).to(fp16) != weight.to(fp16) * scale.to(fp16), because the multiplication is rounded at a different precision on each side.
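The first concern can be reproduced with a small plain-Python sketch. It is an illustration under assumptions, not the wrapper's code: `reverse_lookup` mimics an exact-equality row match, `fake_quantize` is a simple round-to-grid stand-in for fake quantization, and the table values and the scale of 0.1 are made up.

```python
def reverse_lookup(embeds, table):
    """Return the matching table row index per embedding, or -1 if none."""
    ids = []
    for vec in embeds:
        hits = [i for i, row in enumerate(table) if row == vec]
        ids.append(hits[0] if hits else -1)
    return ids


def fake_quantize(vec, scale=0.1):
    """Snap each element to the nearest multiple of scale."""
    return [round(x / scale) * scale for x in vec]


TABLE = [[0.12, -0.57], [0.33, 0.91], [-0.48, 0.05]]  # made-up embeddings
token_ids = [2, 0, 1]
embeds = [TABLE[i] for i in token_ids]

# Untouched embeddings match their source rows exactly.
exact_ids = reverse_lookup(embeds, TABLE)

# After fake quantization no row matches any more.
quant_ids = reverse_lookup([fake_quantize(v) for v in embeds], TABLE)

print(exact_ids)
print(quant_ids)
```

The lookup recovers the token ids only while the embeddings are bit-identical to the table rows. Any value-changing step in between makes every comparison fail.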

Contributor Author

Good catch. The reverse lookup is inherently a floating-point exact-match operation against the raw embedding table as you said, so it does not fit well with the quantization path.

I changed the QUANT path so that, when PLE is enabled and inputs_embeds is provided, callers must also provide explicit per_layer_inputs. This prevents the reverse lookup from running on fake-quantized embeddings.
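A guard of the kind described might look like the sketch below. The function name, its boolean parameters, and the error message are hypothetical; the actual check in quant_text_model.py may be structured differently.

```python
def check_ple_inputs(quant_mode, ple_enabled, has_inputs_embeds,
                     has_per_layer_inputs):
    """Reject the one combination that would need a reverse lookup
    on fake-quantized embeddings."""
    if (quant_mode and ple_enabled and has_inputs_embeds
            and not has_per_layer_inputs):
        raise ValueError(
            "QUANT mode with PLE and inputs_embeds requires explicit "
            "per_layer_inputs."
        )


# Accepted: the caller supplies per_layer_inputs alongside inputs_embeds.
check_ple_inputs(True, True, True, True)

# Rejected: the reverse lookup would run on fake-quantized embeddings.
try:
    check_ple_inputs(True, True, True, False)
    rejected = False
except ValueError:
    rejected = True
```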


dvsav commented Jun 23, 2026

Contributor

🤔

Gemma4TextModelCase

$ python -m tico.quantization.examples.inspect \
    --config tico/quantization/examples/configs/wrapper_smoke.yaml \
    --mode wrapper-smoke \
    --case gemma4_text_model \
    --export circle \
    --output-dir ./out/wrapper_smoke
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_text_model
│ Status           : FAIL
│ Mean |diff|      : 0.079053
│ Max |diff|       : 0.483881
│ PEIR             : 0.083751
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
Messages:
  - Circle export failed: Observed exception
  Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
  Hint: Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance("force_eager")`. 
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: raised exception NotImplementedError([ConstantVariable(str: 'QuantGemma4TextModel static/export mode requires static masks as a dict keyed by layer type.')])

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html

from user code:
   File "myenv/lib/python3.10/site-packages/torch/_dynamo/functional_export.py", line 221, in forward
    res = self._export_root(*args, **kwargs)
  File "TICO/tico/quantization/wrapq/wrappers/ptq_wrapper.py", line 46, in forward
    return self.wrapped(*args, **kwargs)
  File "TICO/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py", line 761, in forward
    mask_mapping = self._create_attention_mask_mapping(
  File "TICO/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py", line 532, in _create_attention_mask_mapping
    raise NotImplementedError(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

    ┌────────────────────────────────────────────┐
 3.7┤                                            │
    │                                         •  │
 2.6┤                                     •      │
    │                                 ••         │
    │                             •••••          │
 1.6┤                            •••             │
    │                        •••••               │
 0.5┤                    ••••••                  │
    │                  ••••••                    │
    │              •••••• •                      │
-0.5┤            ••••••                          │
    │          •••••                             │
-1.6┤      ••••••                                │
    │    ••••••                                  │
    │  •••                                       │
-2.7┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -2.7       -1.1        0.5       2.1       3.7

This commit adds a wrapper for gemma text model.

TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>

mhs4670go commented Jun 24, 2026

Contributor Author

@dvsav

The export failure is expected because Gemma4TextModelCase currently validates only PTQ preparation, calibration, conversion, and numerical parity. Exporting the full text model is out of scope here. It can be reconsidered later, when we design a runtime.

To make this limitation clear, I explicitly marked Circle export as unsupported for this case so that an export request fails early with a descriptive message, rather than surfacing as an unexpected conversion error.

python -m tico.quantization.examples.inspect \
    --config tico/quantization/examples/configs/wrapper_smoke.yaml \
    --mode wrapper-smoke \
    --case gemma4_text_model \
    --export circle \
    --output-dir ./out/wrapper_smoke
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_text_model
│ Status           : FAIL
│ Mean |diff|      : 0.079053
│ Max |diff|       : 0.483883
│ PEIR             : 0.083751
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
Messages:
  - This case validates PTQ numerical parity only. Full Gemma4TextModel Circle export requires a dedicated static adapter.
    ┌────────────────────────────────────────────┐
 3.7┤                                            │
    │                                            │
    │                                     •   •  │
 2.6┤                                            │
    │                                 •          │
    │                              •  ••         │
    │                             ••••           │
 1.6┤                            •••             │
    │                        •••••               │
    │                      •••••                 │
 0.5┤                    •••••                   │
    │                   •••••                    │
    │                ••••••                      │
    │              •••••                         │
-0.5┤            ••••••                          │
    │          •••••                             │
    │         ••••                               │
-1.6┤      ••••                                  │
    │    ••••••                                  │
    │   •••                                      │
    │  •••                                       │
-2.7┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -2.7       -1.1        0.5       2.1       3.7
