Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions test/quantization/recipes/test_llama_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from unittest.mock import patch

import tico.quantization.recipes.adapters.llama as llama_mod
import torch
from tico.quantization.recipes.adapters.llama import LlamaAdapter
from tico.quantization.recipes.context import RecipeContext

Expand All @@ -40,6 +41,26 @@ def _fake_llama_context(cfg):
return RecipeContext(cfg=cfg, adapter=LlamaAdapter(), model=model)


class TiedEmbeddingLlamaModel(torch.nn.Module):
    """Tiny LLaMA-shaped test double whose LM head reuses the embedding weight.

    The embedding table and the bias-free output projection hold the very same
    ``Parameter`` object, so the two modules share one weight.
    """

    def __init__(self):
        super().__init__()
        vocab_size, hidden_size = 8, 4

        # Plain namespaces stand in for the decoder stack and the model config.
        self.model = SimpleNamespace(layers=[object(), object()])
        self.config = SimpleNamespace(max_position_embeddings=16)

        token_embedding = torch.nn.Embedding(vocab_size, hidden_size)
        output_projection = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie the projection to the embedding table by sharing one Parameter.
        output_projection.weight = token_embedding.weight

        self.embed_tokens = token_embedding
        self.lm_head = output_projection

    def get_input_embeddings(self):
        """Return the module stored as ``embed_tokens``."""
        return self.embed_tokens

    def get_output_embeddings(self):
        """Return the module stored as ``lm_head``."""
        return self.lm_head


class TestLlamaAdapter(unittest.TestCase):
def test_build_ptq_config_forwards_profile_and_weight_options(self):
"""LlamaAdapter should forward recipe PTQ options to the config builder."""
Expand Down Expand Up @@ -69,9 +90,60 @@ def fake_build_llm_ptq_config(**kwargs):
self.assertEqual(captured["model_type"], "llama")
self.assertEqual(captured["num_hidden_layers"], 2)
self.assertEqual(captured["linear_weight_bits"], 4)
self.assertEqual(captured["embedding_weight_bits"], 8)
self.assertEqual(captured["lm_head_weight_bits"], 8)
self.assertEqual(captured["profile"], "reference_eval")
self.assertFalse(captured["strict_wrap"])

def test_build_ptq_config_defaults_lm_head_bits_to_embedding_bits(self):
    """An omitted ``lm_head_weight_bits`` should fall back to ``embedding_weight_bits``."""
    forwarded_kwargs = {}

    def record_builder_call(**kwargs):
        # Capture what the adapter hands to the config builder.
        forwarded_kwargs.update(kwargs)
        return {"ptq": "config"}

    # The stage config deliberately leaves out "lm_head_weight_bits".
    stage_cfg = {
        "activation_dtype": "int16",
        "default_qscheme": "per_tensor_symm",
        "linear_weight_bits": 4,
        "embedding_weight_bits": 8,
        "norm_weight_dtype": "int16",
        "strict_wrap": False,
    }
    ctx = _fake_llama_context({"model_args": {"profile": "reference_eval"}})

    with patch.object(llama_mod, "build_llm_ptq_config", record_builder_call):
        config = LlamaAdapter().build_ptq_config(ctx, stage_cfg)

    self.assertEqual(config, {"ptq": "config"})
    for option in ("embedding_weight_bits", "lm_head_weight_bits"):
        self.assertEqual(forwarded_kwargs[option], 8)

def test_build_ptq_config_rejects_mismatched_tied_embedding_bits(self):
    """Differing embedding/LM-head bit-widths on a tied model should raise ValueError."""
    tied_ctx = RecipeContext(
        cfg={"model_args": {"profile": "reference_eval"}},
        adapter=LlamaAdapter(),
        model=TiedEmbeddingLlamaModel(),
    )
    # Embedding (8-bit) and LM head (4-bit) disagree on purpose.
    mismatched_stage_cfg = {
        "activation_dtype": "int16",
        "default_qscheme": "per_tensor_symm",
        "linear_weight_bits": 4,
        "embedding_weight_bits": 8,
        "lm_head_weight_bits": 4,
        "norm_weight_dtype": "int16",
        "strict_wrap": False,
    }
    expected_message = "tied input embedding and lm_head"

    with self.assertRaisesRegex(ValueError, expected_message):
        LlamaAdapter().build_ptq_config(tied_ctx, mismatched_stage_cfg)

def test_build_calibration_inputs_rejects_non_positive_effective_sequence_length(
self,
):
Expand Down
65 changes: 65 additions & 0 deletions test/quantization/recipes/test_qwen_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from unittest.mock import patch

import tico.quantization.recipes.adapters.qwen3_vl as qwen_mod
import tico.quantization.recipes.data.vlm as vlm_data

import torch
from tico.quantization.recipes.adapters.qwen3_vl import Qwen3VLAdapter
Expand Down Expand Up @@ -107,6 +108,70 @@ def fake_build_qwen3_vl_ptq_config(**kwargs):
self.assertEqual(captured["model_args"]["vision"]["grid_thw"], (1, 8, 8))
self.assertFalse(captured["strict_wrap"])

def test_build_calibration_inputs_routes_mixed_dataset_config(self):
    """A ``calibration.datasets`` list should reach the VLM data helper unchanged."""
    mixed_datasets = [
        {"dataset": "vqav2", "split": "testdev", "n_samples": 3},
        {"dataset": "wikitext2", "split": "train", "n_samples": 5},
    ]
    recorded_kwargs = {}

    def record_helper_call(**kwargs):
        # Capture the keyword arguments the adapter forwards to the helper.
        recorded_kwargs.update(kwargs)
        return [{"input_ids": torch.ones(1, 2)}]

    recipe_cfg = {
        "runtime": {"seed": 7},
        "calibration": {"datasets": mixed_datasets, "seq_len": 128},
    }
    ctx = RecipeContext(
        cfg=recipe_cfg,
        adapter=Qwen3VLAdapter(),
        model=_fake_qwen_model(),
    )
    ctx.processor = object()

    with patch.object(
        qwen_mod, "build_vlm_calibration_inputs", record_helper_call
    ):
        batches = Qwen3VLAdapter().build_calibration_inputs(ctx)

    self.assertEqual(len(batches), 1)
    self.assertTrue(torch.equal(batches[0]["input_ids"], torch.ones(1, 2)))

    expected_forwarded = {
        "datasets": mixed_datasets,
        "max_seq_len": 128,
        "seed": 7,
    }
    for key, expected in expected_forwarded.items():
        self.assertEqual(recorded_kwargs[key], expected)

def test_vlm_data_helper_parses_mixed_dataset_string(self):
    """A ``name:split:count,...`` dataset string should become a per-dataset config."""
    recorded_kwargs = {}

    def record_mixed_loader_call(**kwargs):
        # Capture what the helper passes to the mixed-dataset loader.
        recorded_kwargs.update(kwargs)
        return ["mixed"]

    expected_dataset_config = {
        "vqav2": {"split": "testdev", "n_samples": 3},
        "wikitext2": {"split": "train", "n_samples": 5},
    }

    with patch.object(
        vlm_data, "get_mixed_calib_inputs", record_mixed_loader_call
    ):
        outcome = vlm_data.build_vlm_calibration_inputs(
            processor=object(),
            dataset="vqav2:testdev:3,wikitext2:train:5",
            n_samples=1,
            max_seq_len=128,
            seed=11,
        )

    self.assertEqual(outcome, ["mixed"])
    self.assertEqual(recorded_kwargs["dataset_config"], expected_dataset_config)
    self.assertEqual(recorded_kwargs["max_seq_len"], 128)
    self.assertEqual(recorded_kwargs["seed"], 11)

def test_apply_smoothquant_maps_component_selection_to_excluded_appliers(self):
"""SmoothQuant component selection should translate to excluded applier names."""
adapter = Qwen3VLAdapter()
Expand Down
2 changes: 1 addition & 1 deletion tico/quantization/examples/configs/llama_gptq_ptq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pipeline:
default_qscheme: per_tensor_symm
linear_weight_bits: 4
embedding_weight_bits: 8
lm_head_weight_bits: 4
lm_head_weight_bits: 8
spin_rotation_weight_bits: 16
norm_weight_dtype: int16
strict_wrap: true
Expand Down
9 changes: 7 additions & 2 deletions tico/quantization/examples/configs/qwen3_vl_gptq_ptq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ runtime:
show_progress: true

calibration:
dataset: vqav2
n_samples: 128
datasets:
- dataset: vqav2
split: testdev
n_samples: 128
- dataset: wikitext2
split: train
n_samples: 128
seq_len: 2048

model_args:
Expand Down
116 changes: 114 additions & 2 deletions tico/quantization/recipes/adapters/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,111 @@
)


def _weights_share_storage(left: torch.Tensor, right: torch.Tensor) -> bool:
"""Return True if two weight tensors share the exact same storage slice."""
if left is right:
return True

if not isinstance(left, torch.Tensor) or not isinstance(right, torch.Tensor):
return False

if left.device != right.device:
return False

if left.device.type == "meta" or right.device.type == "meta":
return False

if left.numel() == 0 or right.numel() == 0:
return False

return (
left.untyped_storage().data_ptr() == right.untyped_storage().data_ptr()
and left.storage_offset() == right.storage_offset()
and tuple(left.shape) == tuple(right.shape)
and tuple(left.stride()) == tuple(right.stride())
)


def has_tied_input_output_embeddings(model: torch.nn.Module) -> bool:
    """Report whether the model's input embedding and output head share a weight.

    Returns False when either ``get_input_embeddings`` or
    ``get_output_embeddings`` is missing or not callable, when either returns
    None, or when either returned module has no ``weight``. Otherwise the two
    weights are compared with ``_weights_share_storage``.
    """
    accessors = [
        getattr(model, accessor_name, None)
        for accessor_name in ("get_input_embeddings", "get_output_embeddings")
    ]
    if not all(callable(accessor) for accessor in accessors):
        return False

    # Input accessor is invoked first, then the output accessor.
    embedding_modules = [accessor() for accessor in accessors]
    if any(module is None for module in embedding_modules):
        return False

    weights = [getattr(module, "weight", None) for module in embedding_modules]
    if any(weight is None for weight in weights):
        return False

    input_weight, output_weight = weights
    return _weights_share_storage(input_weight, output_weight)


def _resolve_lm_head_weight_bits(stage_cfg: Mapping[str, Any]) -> int | None:
"""Resolve LM head bit-width using embedding bit-width as its default."""
lm_head_weight_bits = stage_cfg.get("lm_head_weight_bits")
if lm_head_weight_bits is not None:
return int(lm_head_weight_bits)

embedding_weight_bits = stage_cfg.get("embedding_weight_bits")
if embedding_weight_bits is None:
return None

return int(embedding_weight_bits)


def _resolve_embedding_weight_bits(stage_cfg: Mapping[str, Any]) -> int | None:
"""Resolve the optional input embedding bit-width from stage configuration."""
embedding_weight_bits = stage_cfg.get("embedding_weight_bits")
if embedding_weight_bits is None:
return None
return int(embedding_weight_bits)


def validate_tied_embedding_weight_bits(
model: torch.nn.Module,
embedding_weight_bits: int | None,
lm_head_weight_bits: int | None,
) -> None:
"""
Reject different embedding and LM head bit-widths for tied weights.

Args:
model: Model whose input embedding and output projection are inspected.
embedding_weight_bits: Bit-width requested for input embedding weights.
lm_head_weight_bits: Bit-width requested for LM head weights.

Raises:
ValueError: If the model ties input embedding and LM head weights while
their requested bit-widths differ.
"""
if embedding_weight_bits is None or lm_head_weight_bits is None:
return

if embedding_weight_bits == lm_head_weight_bits:
return

if not has_tied_input_output_embeddings(model):
return

raise ValueError(
"Cannot use different bit-widths for tied input embedding and lm_head "
"weights: "
f"embedding_weight_bits={embedding_weight_bits}, "
f"lm_head_weight_bits={lm_head_weight_bits}. "
"Set both options to the same value or use a model with untied "
"input/output embeddings."
)


class LlamaAdapter(ModelAdapter):
family = "llama"

Expand Down Expand Up @@ -186,15 +291,22 @@ def build_ptq_config(self, ctx: RecipeContext, stage_cfg: Mapping[str, Any]):
if _is_stage_enabled(ctx.cfg, "spinquant")
else None
)
embedding_weight_bits = _resolve_embedding_weight_bits(stage_cfg)
lm_head_weight_bits = _resolve_lm_head_weight_bits(stage_cfg)
validate_tied_embedding_weight_bits(
ctx.model,
embedding_weight_bits,
lm_head_weight_bits,
)

return build_llm_ptq_config(
model_type="llama",
num_hidden_layers=num_hidden_layers,
activation_dtype=activation_dtype,
default_qscheme=default_qscheme,
linear_weight_bits=stage_cfg.get("linear_weight_bits"),
embedding_weight_bits=stage_cfg.get("embedding_weight_bits"),
lm_head_weight_bits=stage_cfg.get("lm_head_weight_bits"),
embedding_weight_bits=embedding_weight_bits,
lm_head_weight_bits=lm_head_weight_bits,
spin_rotation_weight_bits=spin_rotation_weight_bits,
norm_dtype=wrapq_dtype_from_name(stage_cfg["norm_dtype"])
if stage_cfg.get("norm_dtype")
Expand Down
4 changes: 4 additions & 0 deletions tico/quantization/recipes/adapters/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,15 @@ def _disable_cache(model: Any) -> None:

def build_calibration_inputs(self, ctx: RecipeContext) -> list[dict]:
    """Build calibration inputs from the recipe's calibration and runtime settings.

    Reads the ``calibration`` and ``runtime`` sections of ``ctx.cfg`` (each
    treated as empty when absent) and forwards them, together with
    ``ctx.processor``, to ``build_vlm_calibration_inputs``. Missing options
    default to dataset ``"vqav2"``, 128 samples, split ``"testdev"`` and
    seed 42; ``datasets`` and ``seq_len`` are passed as None when unset.
    """
    calibration_cfg = ctx.cfg.get("calibration", {})
    runtime_cfg = ctx.cfg.get("runtime", {})

    # NOTE(review): how the helper chooses between "dataset" and "datasets"
    # is not visible from this method -- confirm in the data helper.
    helper_kwargs = {
        "processor": ctx.processor,
        "dataset": calibration_cfg.get("dataset", "vqav2"),
        "datasets": calibration_cfg.get("datasets"),
        "n_samples": int(calibration_cfg.get("n_samples", 128)),
        "split": calibration_cfg.get("split", "testdev"),
        "max_seq_len": calibration_cfg.get("seq_len"),
        "seed": int(runtime_cfg.get("seed", 42)),
    }
    return build_vlm_calibration_inputs(**helper_kwargs)

def forward_calibration(
Expand Down
Loading
Loading