diff --git a/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/README.md b/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/README.md
new file mode 100644
index 00000000000..644b7461e19
--- /dev/null
+++ b/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/README.md
@@ -0,0 +1,34 @@
+# Large Language Models Adaptive Codebook Compression Example
+
+This example demonstrates how to apply codebook and adaptive codebook compression to the [HuggingFaceTB/SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct) model. It can be useful for evaluation and early HW enablement purposes.
+
+## Prerequisites
+
+Before running this example, ensure you have Python 3.10+ installed and set up your environment:
+
+### 1. Create and activate a virtual environment
+
+```bash
+python3 -m venv nncf_env
+source nncf_env/bin/activate  # On Windows: nncf_env\Scripts\activate.bat
+```
+
+### 2. Install NNCF and other dependencies
+
+```bash
+python3 -m pip install ../../../../ -r requirements.txt
+```
+
+## Run Example
+
+To run the example:
+
+```bash
+python main.py
+```
+
+This will automatically:
+
+- Download the SmolLM2 model and the calibration dataset
+- Apply codebook and adaptive codebook weight compression using NNCF
+- Save the optimized models
diff --git a/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/main.py b/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/main.py
new file mode 100644
index 00000000000..0b28559fdc5
--- /dev/null
+++ b/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/main.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2026 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
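+
+# Overview of this example (summary of the code below): the SmolLM2-360M-Instruct model is
+# loaded through optimum-intel, baseline answers are generated, a calibration nncf.Dataset
+# is built from the neuralmagic/LLM_compression_calibration dataset, and
+# nncf.compress_weights() is applied twice: once in CODEBOOK mode with a fixed
+# normal-distributed codebook, and once in ADAPTIVE_CODEBOOK mode, where the codebook is
+# estimated from the calibration data. The answers of each compressed model are printed
+# for comparison with the original ones.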
+
+import warnings
+from functools import partial
+
+import datasets
+import numpy as np
+from optimum.intel.openvino import OVModelForCausalLM
+from scipy.stats import norm
+from torch.jit import TracerWarning
+from transformers import AutoTokenizer
+from transformers import logging
+
+import nncf
+from nncf.quantization.advanced_parameters import AdvancedAdaptiveCodebookParameters
+
+logging.set_verbosity_error()
+warnings.filterwarnings("ignore", category=TracerWarning)
+
+
+MODEL_ID = "HuggingFaceTB/SmolLM2-360M-Instruct"
+COMPRESSED_MODEL_ID = "smollm2_360m_compressed_codebook"
+
+
+def get_dataset(model, tokenizer):
+    def transform_func(item, tokenizer, input_shapes, max_tokens=128):
+        text = item["text"]
+        tokens = tokenizer(text)
+
+        res = {
+            "input_ids": np.expand_dims(np.array(tokens["input_ids"][:max_tokens]), 0),
+            "attention_mask": np.expand_dims(np.array(tokens["attention_mask"][:max_tokens]), 0),
+        }
+
+        if "position_ids" in input_shapes:
+            position_ids = np.cumsum(res["attention_mask"], axis=1) - 1
+            position_ids[res["attention_mask"] == 0] = 1
+            res["position_ids"] = position_ids
+        batch_size = res["input_ids"].shape[0]
+
+        if "beam_idx" in input_shapes:
+            res["beam_idx"] = np.arange(batch_size, dtype=int)
+
+        return res
+
+    def get_input_shapes(model, batch_size=1):
+        inputs = {}
+
+        for val in model.model.inputs:
+            name = val.any_name
+            shape = list(val.partial_shape.get_min_shape())
+            shape[0] = batch_size
+            inputs[name] = shape
+
+        return inputs
+
+    input_shapes = get_input_shapes(model, batch_size=1)
+
+    def preprocess_fn(example):
+        return {"text": tokenizer.apply_chat_template(example["messages"], add_generation_prompt=False, tokenize=False)}
+
+    num_samples = 2048
+    ds = datasets.load_dataset("neuralmagic/LLM_compression_calibration", split="train")
+    ds = ds.shuffle(seed=42).select(range(num_samples))
+    ds = ds.map(preprocess_fn)
+    dataset = ds
+
+    quantization_dataset = nncf.Dataset(
+        dataset, partial(transform_func, tokenizer=tokenizer, input_shapes=input_shapes)
+    )
+    return quantization_dataset
+
+
+def create_normal_distributed_values(n_levels=8) -> np.ndarray:
+    probs = (np.arange(n_levels) + 0.5) / n_levels
+
+    # Inverse CDF (quantiles) of standard normal distribution
+    values = norm.ppf(probs)
+
+    # Normalize to [-1, 1]
+    values = values / np.max(np.abs(values))
+
+    return values.astype(np.float32)
+
+
+def generate_answers(
+    questions: list[str], model: OVModelForCausalLM, tokenizer: AutoTokenizer, max_new_tokens: int = 10
+) -> dict[str, str]:
+    """
+    Generate answers for a list of questions using the provided model and tokenizer.
+
+    :param questions: List of questions to be answered.
+    :param model: The model to use for generating answers.
+    :param tokenizer: The tokenizer to use for processing the input and output.
+    :param max_new_tokens: Maximum number of new tokens to generate for each answer. Defaults to 10.
+    :return: A dictionary mapping each question to its corresponding answer.
+ """ + messages = [ + {"role": "system", "content": "You are a chatbot who always responds as short as possible."}, + {"role": "user", "content": "What is the capital of Spain?"}, + {"role": "assistant", "content": "Madrid."}, + ] + answers_by_questions = {} + + for question in questions: + messages.append({"role": "user", "content": question}) + input_ids = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" + ).to(device=model.device) + input_len = len(input_ids[0]) + + output = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)[0] + answer = tokenizer.decode(output[input_len:], skip_special_tokens=True) + answers_by_questions[question] = answer + messages.append({"role": "assistant", "content": answer}) + + return answers_by_questions + + +def print_answers(header: str, answers_by_questions: list[str]) -> None: + """ + Print the answers to the console. + + :param header: Header to print before the answers. + :param answers_by_questions: Dictionary mapping questions to their answers. + """ + print(header) + for question, answer in answers_by_questions.items(): + print(f"Q: {question}\nA: {answer}\n") + + +QUESTIONS = [ + "What is the capital of France?", + "What is the highest peak in the Alps?", + "What is the largest city in Canada?", + "What is the most visited city in Japan?", +] + + +def load_model_and_tokenizer(model_id: str, export=True) -> tuple[OVModelForCausalLM, AutoTokenizer]: + """ + Load the model and tokenizer from the specified model ID. + + :param model_id: The identifier of the model to load. + :param export: Whether to export the model for OpenVINO. Defaults to True. + :return: A tuple containing the loaded model and tokenizer. + """ + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) + model = OVModelForCausalLM.from_pretrained( + model_id, + export=export, + load_in_8bit=False, + ) + return model, tokenizer + + +def codebook_example( + model_id: str, compressed_model_id: str, adaptive_codebook: bool = False, num_elements: int = 10 +) -> list[str]: + """ + Example of using the custom codebook compression. + + :param model_id: The identifier of the model to load. + :param compressed_model_id: The identifier for the compressed model to save. + :param adaptive_codebook: Whether to use adaptive codebook compression. Defaults to False. + :param num_parameters: Number of parameters in the codebook. Defaults to 8. + :return: A list of answers generated by the model after compression. 
+ """ + model, tokenizer = load_model_and_tokenizer(model_id) + + answers_by_questions = generate_answers(QUESTIONS, model, tokenizer) + print_answers("Non-optimized model outputs:\n", answers_by_questions) + + codebook = create_normal_distributed_values(num_elements) + + adaptive_codebook_params = AdvancedAdaptiveCodebookParameters( + num_elements=num_elements, value_type=nncf.tensor.TensorDataType.float16, per_block=False + ) + quantization_dataset = get_dataset(model, tokenizer) + + model.model = nncf.compress_weights( + model.model, + mode=nncf.CompressWeightsMode.ADAPTIVE_CODEBOOK if adaptive_codebook else nncf.CompressWeightsMode.CODEBOOK, + ratio=1.0, + group_size=-1, + scale_estimation=True, + dataset=quantization_dataset, + advanced_parameters=nncf.AdvancedCompressionParameters( + codebook=codebook, adaptive_codebook_params=adaptive_codebook_params if adaptive_codebook else None + ), + ) + model.save_pretrained(compressed_model_id) + tokenizer.save_pretrained(compressed_model_id) + + model, tokenizer = load_model_and_tokenizer(compressed_model_id, False) + answers_by_questions = generate_answers(QUESTIONS, model, tokenizer) + print_answers("Optimized model outputs:\n", answers_by_questions) + + return list(answers_by_questions.values()) + + +def main(): + res = codebook_example(MODEL_ID, COMPRESSED_MODEL_ID) + res += codebook_example(MODEL_ID, COMPRESSED_MODEL_ID + "_adaptive", adaptive_codebook=True) + return res + + +if __name__ == "__main__": + main() diff --git a/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/requirements.txt b/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/requirements.txt new file mode 100644 index 00000000000..2b42ca1ffb4 --- /dev/null +++ b/examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/requirements.txt @@ -0,0 +1,10 @@ +datasets==4.4.1 +openvino==2025.4.1 +optimum-intel[openvino]==1.26.0 +optimum-onnx==0.0.3 +optimum==2.0.0 +transformers==4.53.0 +onnx==1.17.0 +torch==2.9.0 +torchvision==0.24.0 +pillow==12.0.0 diff --git a/src/nncf/parameters.py b/src/nncf/parameters.py index 21c20743490..49c23176607 100644 --- a/src/nncf/parameters.py +++ b/src/nncf/parameters.py @@ -95,6 +95,7 @@ class CompressWeightsMode(StrEnum): :param FP8_E4M3: A FP8 format with E4M3 values sharing group-level fp16 scale. :param FP4: A FP4 format with E2M1 values sharing group-level fp16 scale. :param CODEBOOK: Codebook (LUT) quantization format. + :param ADAPTIVE_CODEBOOK: Adaptive codebook (LUT) quantization format. :param CB4_F8E4M3: Codebook (LUT) format with 16 fixed fp8 values in E4M3 format. """ @@ -110,6 +111,7 @@ class CompressWeightsMode(StrEnum): FP8_E4M3 = "fp8_e4m3" FP4 = "fp4" CODEBOOK = "codebook" + ADAPTIVE_CODEBOOK = "adaptive_codebook" @api(canonical_alias="nncf.CompressionFormat") diff --git a/src/nncf/quantization/advanced_parameters.py b/src/nncf/quantization/advanced_parameters.py index 35c45ad994d..ffe8359c72c 100644 --- a/src/nncf/quantization/advanced_parameters.py +++ b/src/nncf/quantization/advanced_parameters.py @@ -28,6 +28,7 @@ from nncf.quantization.range_estimator import AggregatorType from nncf.quantization.range_estimator import RangeEstimatorParameters from nncf.quantization.range_estimator import StatisticsType +from nncf.tensor import TensorDataType TTensor = Any @@ -384,6 +385,25 @@ class AdvancedLoraCorrectionParameters: use_int8_adapters: bool = True +@api() +@dataclass +class AdvancedAdaptiveCodebookParameters: + """ + Contains advanced parameters for adaptive codebook estimation. 
+
+    :param value_type: The target tensor data type for the codebook.
+    :type value_type: TensorDataType
+    :param per_block: Whether to share a single codebook across same-named layers (e.g., all down_proj weights).
+    :type per_block: bool
+    :param num_elements: The number of values in the codebook.
+    :type num_elements: int
+    """
+
+    value_type: TensorDataType = TensorDataType.f8e4m3
+    per_block: bool = False
+    num_elements: int = 16
+
+
 @api()
 @dataclass
 class AdvancedCompressionParameters:
@@ -426,6 +446,9 @@ class AdvancedCompressionParameters:
     lora_correction_params: AdvancedLoraCorrectionParameters = field(default_factory=AdvancedLoraCorrectionParameters)
     backend_params: dict[str, Any] = field(default_factory=dict)
     codebook: Optional[TTensor] = None
+    adaptive_codebook_params: AdvancedAdaptiveCodebookParameters = field(
+        default_factory=AdvancedAdaptiveCodebookParameters
+    )
 
 
 @api()
diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
index 90b1f585a7f..0b091af4654 100644
--- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -41,6 +41,7 @@
 from nncf.quantization.advanced_parameters import convert_to_dict_recursively
 from nncf.quantization.algorithms.algorithm import Algorithm
 from nncf.quantization.algorithms.weight_compression.awq import AWQ
+from nncf.quantization.algorithms.weight_compression.codebook_estimation import CodebookEstimation
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.constants import CB4_QUANTILES
 from nncf.quantization.algorithms.weight_compression.gptq import GPTQ
@@ -93,7 +94,11 @@ def get_weight_compression_configuration(
     elif group_size is None and mode in NON_INT8_MODES:
         if mode in [CompressWeightsMode.MXFP4, CompressWeightsMode.MXFP8_E4M3]:
             group_size = 32
-        elif mode in [CompressWeightsMode.CODEBOOK, CompressWeightsMode.CB4_F8E4M3]:
+        elif mode in [
+            CompressWeightsMode.CODEBOOK,
+            CompressWeightsMode.CB4_F8E4M3,
+            CompressWeightsMode.ADAPTIVE_CODEBOOK,
+        ]:
             group_size = -1
         else:
             group_size = 128
@@ -203,6 +208,27 @@ def check_user_compression_configuration(
     if msg:
         raise nncf.ValidationError(msg)
 
+    if (
+        advanced_parameters.adaptive_codebook_params is not None
+        and codebook is not None
+        and mode == CompressWeightsMode.ADAPTIVE_CODEBOOK
+    ):
+        cb_params = advanced_parameters.adaptive_codebook_params
+        if cb_params.num_elements is not None and cb_params.num_elements != np_codebook.size:
+            msg = (
+                "The 'num_elements' parameter in Adaptive Codebook parameters "
+                "must match the size of the provided codebook. "
+                f"Expected {np_codebook.size}, but got {cb_params.num_elements}."
+            )
+            raise nncf.ValidationError(msg)
+
+        if cb_params.per_block and (group_size and group_size != -1):
+            msg = (
+                "When 'per_block' is set to True in Adaptive Codebook parameters, "
+                "the 'group_size' must be -1 (no grouping) or None."
+            )
+            raise nncf.ValidationError(msg)
+
     for size in values_to_check:
         if size <= 0:
             msg = f"The subset_size value should be positive, but subset_size={size} is given."
@@ -231,7 +257,7 @@ def check_user_compression_configuration(
         msg = "LoRA Correction algorithm is not compatible with FQ, FQ_LORA and FQ_LORA_NLS compression formats."
raise nncf.ValidationError(msg) - if mode == CompressWeightsMode.CODEBOOK and (advanced_parameters is None or advanced_parameters.codebook is None): + if mode in [CompressWeightsMode.CODEBOOK] and (advanced_parameters is None or advanced_parameters.codebook is None): msg = "Codebook compression mode requires codebook parameters to be specified in advanced_parameters." raise nncf.ValidationError(msg) @@ -337,6 +363,7 @@ def __init__( self._scale_estimation = scale_estimation self._gptq = gptq self._lora_correction = lora_correction + self._codebook_estimation = mode == CompressWeightsMode.ADAPTIVE_CODEBOOK self._backup_mode = backup_mode self._compression_format = compression_format self._advanced_parameters = ( @@ -377,6 +404,14 @@ def __init__( scale_estimation_params.weight_penalty, ) + if self._codebook_estimation: + codebook_estimation_params = self._advanced_parameters.adaptive_codebook_params + self._codebook_estimation_algo = CodebookEstimation( + codebook_estimation_params.value_type, + codebook_estimation_params.per_block, + codebook_estimation_params.num_elements, + ) + self._data_aware_mixed_precision = ( self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0 ) @@ -385,6 +420,7 @@ def __init__( or self._scale_estimation or self._lora_correction or self._gptq + or self._codebook_estimation ) @property @@ -543,6 +579,10 @@ def _get_primary_config(self, group_size: int) -> WeightCompressionConfig: codebook_values = Tensor(CB4_QUANTILES) elif self._mode == CompressWeightsMode.CODEBOOK: codebook_values = Tensor(self._advanced_parameters.codebook) + elif self._mode == CompressWeightsMode.ADAPTIVE_CODEBOOK: + codebook_values = Tensor( + self._advanced_parameters.codebook if self._advanced_parameters.codebook is not None else CB4_QUANTILES + ) return WeightCompressionConfig( mode=self._mode, @@ -1083,6 +1123,15 @@ def apply_with_parameters( lora_correction_algo = None description = "Applying Weight Compression" + if self._codebook_estimation: + precomputed_compressed_weights = self._codebook_estimation_algo.apply( + model=model, + graph=graph, + all_weight_params=all_weight_params, + statistics=statistics, + backend_entity=self._backend_entity, + ) + if self._gptq: del statistics model, precomputed_compressed_weights = self._gptq_algo.apply( diff --git a/src/nncf/quantization/algorithms/weight_compression/codebook_estimation.py b/src/nncf/quantization/algorithms/weight_compression/codebook_estimation.py new file mode 100644 index 00000000000..3532303bfd6 --- /dev/null +++ b/src/nncf/quantization/algorithms/weight_compression/codebook_estimation.py @@ -0,0 +1,577 @@ +# Copyright (c) 2026 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
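+
+# Adaptive codebook estimation (summary of the code below): the weights of 4-bit
+# codebook-compressed MatMuls are normalized by a per-channel scale and clustered with an
+# importance-weighted k-means, where the importance is derived from collected activation
+# statistics. Several codebook variants produced during clustering are scored by the mean
+# absolute difference between the original and the quantize-dequantize MatMul outputs, and
+# the best variant is kept, either per weight or per group of same-named weights (per_block).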
+
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Optional, TypeVar
+
+import nncf
+from nncf.common.graph.graph import NNCFGraph
+from nncf.common.logging.track_progress import track
+from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
+from nncf.common.tensor_statistics.statistics import WCTensorStatistic
+from nncf.common.utils.backend import BackendType
+from nncf.common.utils.backend import get_backend
+from nncf.quantization.algorithms.algorithm import Algorithm
+from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
+from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
+from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
+from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
+from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
+from nncf.quantization.algorithms.weight_compression.weight_lowering import _calculate_normalized_weight
+from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_float_quantization_params
+from nncf.quantization.algorithms.weight_compression.weight_lowering import float_quantize_dequantize_weight
+from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
+from nncf.tensor import Tensor
+from nncf.tensor import TensorDataType
+from nncf.tensor import functions as fns
+
+TModel = TypeVar("TModel")
+
+
+class CodebookEstimation(Algorithm):
+    """
+    Codebook estimation algorithm implementation.
+    """
+
+    def __init__(
+        self, value_type: TensorDataType = TensorDataType.f8e4m3, per_block: bool = True, num_elements: int = 16
+    ):
+        """
+        Initializes the CodebookEstimation algorithm.
+
+        :param value_type: The target tensor data type for the codebook.
+        :param per_block: Whether to estimate a single codebook shared by same-named layers.
+        :param num_elements: The number of values in the codebook.
+        """
+        super().__init__()
+
+        self._value_type = value_type
+        self._per_block = per_block
+        self._num_elements = num_elements
+
+    @property
+    def available_backends(self) -> list[BackendType]:
+        return [BackendType.OPENVINO]
+
+    def _set_backend_entity(self, model: TModel) -> None:
+        """
+        Creates a helper class with backend-specific logic of the algorithm.
+
+        :param model: Backend-specific input model.
+        """
+        model_backend = get_backend(model)
+        if model_backend == BackendType.OPENVINO:
+            from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend
+
+            self._backend_entity = OVWeightCompressionAlgoBackend(model)
+        else:
+            msg = (
+                "Cannot return backend-specific Codebook Estimation entity because"
+                f" {model_backend.value} is not supported!"
+            )
+            raise nncf.UnsupportedBackendError(msg)
+
+    def apply(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        all_weight_params: list[WeightCompressionParameters],
+        statistics: dict[str, WCTensorStatistic],
+        backend_entity: Optional[WeightCompressionAlgoBackend] = None,
+    ) -> dict[str, CompressedWeight]:
+        """
+        Estimates a better codebook by minimizing the difference between the floating-point MatMul
+        and the MatMul with compressed weights.
+        The algorithm computes the codebook and indexes for MatMul compression.
+
+        :param model: Model for applying algorithm.
+        :param graph: Model graph.
+        :param all_weight_params: List of all weight parameters.
+        :param statistics: Input activation statistics for each node.
+        :param backend_entity: Weight compression algorithm backend.
+        :return: A dictionary that maps weight names to CompressedWeight with codebook, codebook indexes and scale.
+        """
+        self._backend_entity = backend_entity
+        if self._backend_entity is None:
+            self._set_backend_entity(model)
+
+        if self._per_block:
+            return self.apply_per_group(model, graph, all_weight_params, statistics, backend_entity)
+
+        res = dict()
+
+        for wp in track(all_weight_params, description="Applying Codebook Estimation"):
+            weight_name = wp.weight_name
+            node_name = wp.node_with_weight.node_name
+            config = wp.compression_config
+
+            if config.num_bits != 4:  # or node_name not in statistics:
+                res[weight_name] = CompressedWeight()
+                continue
+
+            stats = statistics[node_name]
+
+            weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
+            if len(weight_data) != 1:  # not supported by the algorithm
+                continue
+            _, weight_port_id = weight_data[0]
+
+            weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
+
+            codebook = self.calculate_codebook(stats, weight, wp.reduction_axes, config, wp)
+            res[weight_name] = CompressedWeight(None, None, None, codebook)
+
+        return res
+
+    def apply_per_group(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        all_weight_params: list[WeightCompressionParameters],
+        statistics: dict[str, WCTensorStatistic],
+        backend_entity: Optional[WeightCompressionAlgoBackend] = None,
+    ) -> dict[str, CompressedWeight]:
+        """
+        Estimates a better codebook shared by a group of weights grouped by layer name (down_proj, up_proj, etc.)
+        by minimizing the difference between the floating-point MatMul and the MatMul with compressed weights.
+        The algorithm computes the codebook and indexes for MatMul compression.
+
+        :param model: Model for applying algorithm.
+        :param graph: Model graph.
+        :param all_weight_params: List of all weight parameters.
+        :param statistics: Input activation statistics for each node.
+        :param backend_entity: Weight compression algorithm backend.
+        :return: A dictionary that maps weight names to CompressedWeight with codebook, codebook indexes and scale.
+ """ + self._backend_entity = backend_entity + if self._backend_entity is None: + self._set_backend_entity(model) + res = dict() + + for wp in track(all_weight_params, description="Applying Codebook Estimation per group"): + weight_name = wp.weight_name + node_name = wp.node_with_weight.node_name + config = wp.compression_config + + if weight_name in res: + continue + + if config.num_bits != 4: # or node_name not in statistics: + res[weight_name] = CompressedWeight() + continue + + weight = self.get_weight(model, graph, wp) + if weight is None: + continue + + weights = [weight] + stats = [statistics[node_name]] + group_weights_params = [wp] + + clear_weight_name = "".join(filter(lambda x: x.isalpha(), weight_name)) + + for other_wp in all_weight_params: + if other_wp.weight_name == weight_name: + continue + other_weight_name = "".join(filter(lambda x: x.isalpha(), other_wp.weight_name)) + other_node_name = other_wp.node_with_weight.node_name + + if clear_weight_name == other_weight_name: + other_weight = self.get_weight(model, graph, other_wp) + if other_weight is not None: + weights.append(other_weight) + stats.append(statistics[other_node_name]) + group_weights_params.append(other_wp) + + codebook = self.calculate_codebook_for_group(stats, weights, wp.reduction_axes, config, wp) + + for gwp in group_weights_params: + res[gwp.weight_name] = CompressedWeight(None, None, None, codebook) + gwp.compression_config.codebook_values = codebook + + return res + + def get_weight(self, model: TModel, graph: NNCFGraph, wp: WeightCompressionParameters) -> Tensor: + weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) + if len(weight_data) != 1: # not supported by the algorithm + return None + _, weight_port_id = weight_data[0] + + return self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph) + + def calculate_codebook( + self, + statistics: WCTensorStatistic, + weight: Tensor, + reduction_axes: tuple[int, ...], + config: WeightCompressionConfig, + wp: WeightCompressionParameters, + ) -> Tensor: + reduction_axis = reduction_axes[0] + weight = deepcopy(weight.astype(TensorDataType.float32)) + + s, X = process_stats(statistics, -1) + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + orig_shape = weight.shape + + if config.group_size != -1: + weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size) + s = fns.unsqueeze(s, -2) + s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size) + + importance = fns.ones_like(weight) + importance = importance * s + + scale = calculate_float_quantization_params(weight, reduction_axes, config, signed=True) + norm_weight = _calculate_normalized_weight(weight, scale) + + codebook, indexes, variants = weights_clusterization_k_means( + norm_weight, importance, n_centroids=self._num_elements + ) + + indexes = indexes.reshape(weight.shape) + + best_codebook = codebook.as_openvino_tensor().astype(self._value_type) + + diff = float("inf") + + if self._num_elements == config.get_numpy_codebook().size: + variants[0] = fns.tensor( + config.get_numpy_codebook().data, backend=weight.backend, dtype=TensorDataType.float16 + ) + variants[1] = fns.tensor( + list(range(-self._num_elements // 2, self._num_elements - self._num_elements // 2)), + backend=weight.backend, + dtype=TensorDataType.float16, + ) + + weight = fns.reshape(weight, orig_shape) + + fp_outs = fns.matmul(weight, X) + for var in variants: + 
var = var.as_openvino_tensor().astype(self._value_type) + config.codebook_values = Tensor(var) + qw = float_quantize_dequantize_weight(weight, config, wp.reduction_axes) + q_outs = fns.matmul(fns.reshape(qw, orig_shape), X) + + cur_diff = fns.mean(fns.abs(fp_outs - q_outs)).item() + if cur_diff < diff: + diff = cur_diff + best_codebook = var + + return Tensor(best_codebook) + + def calculate_codebook_for_group( + self, + statistics: list[WCTensorStatistic], + weights: list[Tensor], + reduction_axes: tuple[int, ...], + config: WeightCompressionConfig, + wp: WeightCompressionParameters, + ) -> Tensor: + reduction_axis = reduction_axes[0] + + norm_weight = [] + importances = [] + Xs = [] + fp_outs = [] + + for stat, weight in zip(statistics, weights): + weight = deepcopy(weight.astype(TensorDataType.float32)) + s, X = process_stats(stat, -1) + Xs.append(X) + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + if config.group_size != -1: + weight, reduction_axes = reshape_weight_for_grouped_quantization( + weight, reduction_axes, config.group_size + ) + + fp_outs.append(fns.matmul(weight, X)) + + importance = fns.ones_like(weight) + importance = importance * s + importances.append(importance) + + scale = calculate_float_quantization_params(weight, reduction_axes, config, signed=False) + norm_weight.append(_calculate_normalized_weight(weight, scale)) + + norm_weight = fns.concatenate(norm_weight, axis=0) + importance = fns.concatenate(importances, axis=0) + + codebook, _, variants = weights_clusterization_k_means(norm_weight, importance, n_centroids=self._num_elements) + + best_codebook = codebook.as_openvino_tensor().astype(self._value_type) + + diff = float("inf") + + if self._num_elements == config.get_numpy_codebook().size: + variants[0] = fns.tensor( + config.get_numpy_codebook().data, backend=weight.backend, dtype=TensorDataType.float16 + ) + variants[1] = fns.tensor( + list(range(-self._num_elements // 2, self._num_elements - self._num_elements // 2)), + backend=weight.backend, + dtype=TensorDataType.float16, + ) + + coeffs = [fns.mean(fns.abs(X)).item() for X in Xs] + + for var in variants: + var = var.as_openvino_tensor().astype(self._value_type) + config.codebook_values = Tensor(var) + + cur_diff = 0.0 + for weight, X, fp_out, c in zip(weights, Xs, fp_outs, coeffs): + qw = float_quantize_dequantize_weight(weight, config, wp.reduction_axes) + q_out = fns.matmul(qw, X) + cur_diff += c * fns.mean(fns.abs(fp_out - q_out)).item() + if cur_diff < diff: + diff = cur_diff + best_codebook = var + + return Tensor(best_codebook) + + def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + """ + Returns statistic points, for which StatisticsCollector should collect statistics. + + :param model: Model for statistics collection. + :param graph: Model graph. + :return: Statistic points, for which StatisticsCollector should collect statistics. 
+ """ + return StatisticPointsContainer() + + +def round_to_left(quantiles, values): + center_of_quantiles = 0.5 * (quantiles[1:] + quantiles[:-1]) + return fns.searchsorted(center_of_quantiles, values, side="left", sorter=None) + + +@dataclass +class KMeansAlgoData: + centroids: Tensor + weighted_centroids: Tensor + weighted_importance: Tensor | None = None + + +class KMeansWeighted: + def __init__(self, n_clusters=8, max_iter=300): + self.n_clusters = n_clusters + self.max_iter = max_iter + self.variants = [] + self.centroids = None + + @staticmethod + def get_init(values, frequencies, n_clusters): + step = 1.0 / (n_clusters - 1) + denum = fns.sum(frequencies) + quants = [i * step for i in range(n_clusters)] + n_frequencies = frequencies / denum + n_frequencies = fns.cumsum(n_frequencies, axis=0) + + res = fns.zeros((n_clusters,), backend=values.backend, dtype=values.dtype) + for i in range(n_clusters): + if i == 0: + res[i] = values[0] + elif i == n_clusters - 1: + res[i] = values[-1] + else: + prev_val = values[fns.nonzero(n_frequencies <= quants[i])[0][-1].item()].item() + next_val = values[fns.nonzero(n_frequencies <= quants[i + 1])[0][-1].item()].item() + res[i] = (prev_val + next_val) / 2 + + # avoid close centroids + th = 0.05 + for i in range(1, n_clusters - 1): + if (res[i] - res[i + 1]).abs() / max(res[i].abs(), res[i + 1].abs()) < th: + res[i] = (res[i - 1] + res[i + 1]) / 2 + + return res + + @staticmethod + def create_histogramm(data, granularity=0.01): + centers = [] + step = granularity + + data_range = (data.min().item(), data.max().item()) + prev = data_range[0] + + while prev < data_range[1]: + centers.append(prev + step / 2) + prev += step + + centers = fns.tensor(centers, backend=data.backend) + centroid_idxs = round_to_left(centers, data) + + res = [[], [], []] + for i in range(centers.size): + idxs = fns.nonzero(centroid_idxs == i) + if len(idxs[0]) == 0: + continue + res[0].append(centers[i]) + res[1].append(fns.sum(data[idxs])) + res[2].append(len(idxs[0])) + + res[0] = fns.tensor(res[0], backend=data.backend) # centers of histogram bins + res[1] = fns.tensor(res[1], backend=data.backend) # sum of values in each bin + res[2] = fns.tensor(res[2], backend=data.backend) # count of values in each bin + + return res + + @staticmethod + def create_histogramm_sorted(data_, importance, intervals=100000): + centers = [] + ranges = [] + + intervals = max(min(intervals, int(0.005 * data_.size)), 100) + + step = data_.max().item() - data_.min().item() + step /= intervals + + sorted_idx = fns.argsort(data_) + data = data_[sorted_idx] + importance = importance[sorted_idx] + + data_range = (data.min().item(), data.max().item()) + prev = data_range[0] + + while prev < data_range[1]: + centers.append(prev + step / 2) + prev += step + + if len(centers) > 1: + ranges.append(0.5 * (centers[-2] + centers[-1])) + ranges.append(centers[-1]) + + centers = fns.tensor(centers, backend=data_.backend, dtype=data_.dtype) + ranges = fns.tensor(ranges, backend=data_.backend, dtype=data_.dtype) + + ranges_idxs = round_to_left(data, ranges) + + res_centers = [] + weighted_data = [] + weighted_importance = [] + + for i in range(centers.size): + if i == 0: + data_range, importance_range = data[: ranges_idxs[1].item()], importance[: ranges_idxs[1].item()] + elif i == centers.size - 1: + data_range, importance_range = data[ranges_idxs[-2].item() :], importance[ranges_idxs[-2].item() :] + else: + idx = 2 * i + data_range, importance_range = ( + data[ranges_idxs[idx - 1].item() : ranges_idxs[idx + 
1].item()], + importance[ranges_idxs[idx - 1].item() : ranges_idxs[idx + 1].item()], + ) + + if data_range.size == 0: + continue + res_centers.append(centers[i].item()) + weighted_data.append(fns.sum(fns.multiply(data_range, importance_range)).item()) + weighted_importance.append(fns.sum(importance_range).item()) + + res = KMeansAlgoData( + fns.tensor(res_centers, backend=data_.backend, dtype=data_.dtype), + fns.tensor(weighted_data, backend=data_.backend, dtype=data_.dtype), + fns.tensor(weighted_importance, backend=data_.backend, dtype=data_.dtype), + ) + return res + + def fit(self, X_train, importance, init, fixed=None): + if self.max_iter == 1: + self.centroids = deepcopy(init) + return + if fixed is None: + fixed = [0, len(init) // 2, len(init) - 1] + + self.hist = KMeansWeighted.create_histogramm_sorted(X_train, importance) + + init_by_hist = self.get_init(self.hist.centroids, self.hist.weighted_importance, self.n_clusters) + init_by_hist[0] = init[0] + init_by_hist[-1] = init[-1] + zero_idx = fns.argmin(fns.abs(init_by_hist[:]), axis=0).item() + + if init[0] <= 0.0: + init_by_hist[zero_idx] = 0.0 # to have zero in codebook + fixed[1] = zero_idx + init = init_by_hist + + self.centroids = deepcopy(init) + + # not only last variant is stored, + # but also intermediate ones for choosing codebook which gives minimum diff in MatMul + saving_intervals = 5 + iteration = 0 + prev_centroids = self.centroids + while iteration < self.max_iter: + prev_centroids = deepcopy(self.centroids) + + if iteration % saving_intervals == 0: + self.variants.append(deepcopy(self.centroids)) + + centroid_idxs = round_to_left(self.centroids, self.hist.centroids) + for i in range(self.n_clusters): + idxs = fns.nonzero(centroid_idxs == i) + if len(idxs[0]) == 0: + continue + self.centroids[i] = ( + fns.sum(self.hist.weighted_centroids[idxs]).item() + / fns.sum(self.hist.weighted_importance[idxs]).item() + ) + + for idx in fixed: + self.centroids[idx] = init[idx] + iteration += 1 + if fns.any(fns.all(fns.abs(self.centroids - prev_centroids) < 0.00001)): + break + + if (iteration - 1) % saving_intervals != 0: + self.variants.append(deepcopy(self.centroids)) + + def evaluate(self, X): + centroid_idxs = round_to_left(self.centroids, X) + return deepcopy(self.centroids).flatten(), centroid_idxs + + +def weights_clusterization_k_means(weight, importance, n_centroids=2**4): + orig_shape = weight.shape + weight = weight.flatten() + importance = importance.flatten() + + n_init = [0, 0] + n_init[0] = weight.min() + n_init[-1] = weight.max() + + kmeans = KMeansWeighted(n_centroids, max_iter=70) + + # fixed centroids: min, zero, max + kmeans.fit( + weight, + importance, + n_init, + fixed=[0, n_centroids // 2 - 1, n_centroids - 1] if n_init[0] < 0.0 else [0, n_centroids - 1], + ) + codebook, indexes = kmeans.evaluate(weight) + + indexes = fns.reshape(indexes, orig_shape) + + return codebook, indexes, kmeans.variants diff --git a/src/nncf/quantization/algorithms/weight_compression/config.py b/src/nncf/quantization/algorithms/weight_compression/config.py index 25d475212ef..761c269e3da 100644 --- a/src/nncf/quantization/algorithms/weight_compression/config.py +++ b/src/nncf/quantization/algorithms/weight_compression/config.py @@ -71,6 +71,7 @@ def is_integer(self): CompressWeightsMode.FP8_E4M3, CompressWeightsMode.FP4, CompressWeightsMode.CODEBOOK, + CompressWeightsMode.ADAPTIVE_CODEBOOK, CompressWeightsMode.CB4_F8E4M3, ] @@ -79,7 +80,11 @@ def is_codebook(self): """ :return: True if compression type is codebook, else False. 
""" - return self.mode in [CompressWeightsMode.CODEBOOK, CompressWeightsMode.CB4_F8E4M3] + return self.mode in [ + CompressWeightsMode.CODEBOOK, + CompressWeightsMode.CB4_F8E4M3, + CompressWeightsMode.ADAPTIVE_CODEBOOK, + ] @property def compression_dtype(self) -> TensorDataType: diff --git a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 56c5282d2e0..7ff93fdc337 100644 --- a/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/src/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -246,7 +246,7 @@ def _create_compression_subgraph( if compression_config.is_codebook: converted_const = create_ov_codebook_subgraph( codebook=compressed_weight.codebook - if compression_config.mode == CompressWeightsMode.CODEBOOK + if compression_config.mode in [CompressWeightsMode.CODEBOOK, CompressWeightsMode.ADAPTIVE_CODEBOOK] else compressed_weight.codebook.as_openvino_tensor().astype(TensorDataType.f8e4m3), indexes=compressed_weight.tensor, dtype=compression_dtype, diff --git a/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 1a7e47f9fe6..93f3d572f97 100644 --- a/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/src/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -82,7 +82,7 @@ def reshape_weight_for_grouped_quantization( def calculate_float_quantization_params( - weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig + weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig, signed: bool = False ) -> Tensor: """ Calculates the scale for nf4 or mxfp8_e4m3/mxfp4/fp8_e4m3/fp4 quantization. @@ -90,6 +90,7 @@ def calculate_float_quantization_params( :param weight: Weight array to compress. :param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max). :param config: Weight compression configuration. + :param signed: Whether to use signed scale for quantization. :return: Scale tensor of float32 type for float quantization. 
""" assert not config.is_integer @@ -97,7 +98,12 @@ def calculate_float_quantization_params( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) + if signed: + scale_neg = fns.min(weight, axis=reduction_axes, keepdims=True) + scale_pos = fns.max(weight, axis=reduction_axes, keepdims=True) + scale = fns.where(fns.abs(scale_neg) >= fns.abs(scale_pos), scale_neg, scale_pos) + else: + scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) if config.mode != CompressWeightsMode.NF4: if config.compression_dtype in FP_MAX_VALUES: max_val = FP_MAX_VALUES[config.compression_dtype] @@ -340,6 +346,13 @@ def compress_weight( ) if not config.is_integer: + if ( + precomputed_compressed_weight is not None + and precomputed_compressed_weight.tensor is not None + and precomputed_compressed_weight.codebook is not None + ): + return precomputed_compressed_weight + compressed_weight, scale, indexes = do_float_quantization(weight, config, reduction_axes, precomputed_scale) if indexes is not None: return CompressedWeight( diff --git a/src/nncf/quantization/quantize_model.py b/src/nncf/quantization/quantize_model.py index 2986ab91e4d..edcb6d9befe 100644 --- a/src/nncf/quantization/quantize_model.py +++ b/src/nncf/quantization/quantize_model.py @@ -511,6 +511,7 @@ def compress_weights( CompressWeightsMode.FP8_E4M3, CompressWeightsMode.FP4, CompressWeightsMode.CODEBOOK, + CompressWeightsMode.ADAPTIVE_CODEBOOK, CompressWeightsMode.CB4_F8E4M3, ] if mode in not_supported_modes: @@ -559,6 +560,7 @@ def compress_weights( CompressWeightsMode.FP8_E4M3, CompressWeightsMode.FP4, CompressWeightsMode.CODEBOOK, + CompressWeightsMode.ADAPTIVE_CODEBOOK, CompressWeightsMode.CB4_F8E4M3, ] if mode in not_supported_modes: @@ -567,10 +569,7 @@ def compress_weights( ) raise nncf.ParameterNotSupportedError(msg) - options = { - "gptq": gptq, - "lora_correction": lora_correction, - } + options = {"gptq": gptq, "lora_correction": lora_correction} unsupported_options = [name for name, value in options.items() if value is not None] if unsupported_options: msg = f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None." @@ -634,6 +633,7 @@ def compress_weights( CompressWeightsMode.FP8_E4M3, CompressWeightsMode.FP4, CompressWeightsMode.CODEBOOK, + CompressWeightsMode.ADAPTIVE_CODEBOOK, CompressWeightsMode.CB4_F8E4M3, ] if mode in not_supported_modes: @@ -642,10 +642,7 @@ def compress_weights( ) raise nncf.ParameterNotSupportedError(msg) - options = { - "gptq": gptq, - "lora_correction": lora_correction, - } + options = {"gptq": gptq, "lora_correction": lora_correction} unsupported_options = [name for name, value in options.items() if value is not None] if unsupported_options: msg = f"ONNX backend does not support {', '.join(unsupported_options)} option(s). Set them to None." 
diff --git a/src/nncf/tensor/functions/__init__.py b/src/nncf/tensor/functions/__init__.py
index d624a02befb..11f55c3d3a8 100644
--- a/src/nncf/tensor/functions/__init__.py
+++ b/src/nncf/tensor/functions/__init__.py
@@ -16,6 +16,7 @@
 from nncf.tensor.functions.numeric import allclose as allclose
 from nncf.tensor.functions.numeric import any as any
 from nncf.tensor.functions.numeric import arange as arange
+from nncf.tensor.functions.numeric import argmin as argmin
 from nncf.tensor.functions.numeric import argsort as argsort
 from nncf.tensor.functions.numeric import as_tensor_like as as_tensor_like
 from nncf.tensor.functions.numeric import astype as astype
@@ -53,6 +54,7 @@
 from nncf.tensor.functions.numeric import minimum as minimum
 from nncf.tensor.functions.numeric import moveaxis as moveaxis
 from nncf.tensor.functions.numeric import multiply as multiply
+from nncf.tensor.functions.numeric import nonzero as nonzero
 from nncf.tensor.functions.numeric import ones_like as ones_like
 from nncf.tensor.functions.numeric import percentile as percentile
 from nncf.tensor.functions.numeric import power as power
diff --git a/src/nncf/tensor/functions/numeric.py b/src/nncf/tensor/functions/numeric.py
index 8418fe89a39..ac9886000ba 100644
--- a/src/nncf/tensor/functions/numeric.py
+++ b/src/nncf/tensor/functions/numeric.py
@@ -335,6 +335,16 @@ def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) -> Tensor:
     """
 
 
+@tensor_dispatcher
+def nonzero(condition: Tensor) -> tuple[Tensor, ...]:
+    """
+    Return the indices of the elements that are non-zero.
+
+    :param condition: The input tensor.
+    :return: A tuple of tensors with the indices of the non-zero elements, one tensor per dimension of the input.
+    """
+
+
 @tensor_dispatcher
 def sign(a: Tensor) -> Tensor:
     """
@@ -683,6 +693,17 @@ def argsort(a: Tensor, axis: int = -1, descending: bool = False, stable: bool = False) -> Tensor:
     """
 
 
+@tensor_dispatcher
+def argmin(a: Tensor, axis: Union[int, None]) -> Tensor:
+    """
+    Returns the indices of the minimum values along an axis.
+
+    :param a: The tensor for which to find the minimum values.
+    :param axis: Axis along which to find the minimum values. If None, the index is into the flattened tensor.
+    :return: Indices of the minimum values along an axis.
+ """ + + @tensor_dispatcher def diag(a: Tensor, k: int = 0) -> Tensor: """ diff --git a/src/nncf/tensor/functions/numpy_numeric.py b/src/nncf/tensor/functions/numpy_numeric.py index 0513825bae8..8958c23fc62 100644 --- a/src/nncf/tensor/functions/numpy_numeric.py +++ b/src/nncf/tensor/functions/numpy_numeric.py @@ -200,6 +200,13 @@ def _( return np.where(condition, x, y) +@numeric.nonzero.register +def _( + condition: T_NUMPY, +) -> tuple[T_NUMPY_ARRAY, ...]: + return np.nonzero(condition) + + @numeric.sign.register def _(a: T_NUMPY) -> T_NUMPY: return np.sign(a) @@ -326,16 +333,16 @@ def _(a: T_NUMPY) -> T_NUMBER: return a.item() -@numeric.cumsum.register -def _(a: T_NUMPY, axis: int) -> T_NUMPY: - return np.cumsum(a, axis=axis) - - @numeric.sum.register def _(a: T_NUMPY, axis: T_AXIS = None, keepdims: bool = False) -> T_NUMPY_ARRAY: return np.array(np.sum(a, axis=axis, keepdims=keepdims)) +@numeric.cumsum.register +def _(a: T_NUMPY, axis: int) -> T_NUMPY: + return np.cumsum(a, axis=axis) + + @numeric.multiply.register def _(x1: T_NUMPY, x2: Union[T_NUMPY, float]) -> T_NUMPY_ARRAY: return np.multiply(x1, x2) @@ -380,6 +387,11 @@ def _(a: T_NUMPY, axis: int = -1, descending: bool = False, stable: bool = False return np.argsort(a, axis=axis, kind="stable" if stable else None) +@numeric.argmin.register +def _(a: T_NUMPY, axis: None) -> T_NUMPY: + return np.argmin(a, axis=axis) + + @numeric.diag.register def _(a: T_NUMPY, k: int = 0) -> T_NUMPY_ARRAY: return np.diag(a, k=k) diff --git a/src/nncf/tensor/functions/torch_numeric.py b/src/nncf/tensor/functions/torch_numeric.py index a9f80e7a0d4..4169afe24df 100644 --- a/src/nncf/tensor/functions/torch_numeric.py +++ b/src/nncf/tensor/functions/torch_numeric.py @@ -348,16 +348,16 @@ def _(a: torch.Tensor) -> T_NUMBER: return a.item() -@numeric.cumsum.register -def _(a: torch.Tensor, axis: int) -> torch.Tensor: - return torch.cumsum(a, dim=axis) - - @numeric.sum.register def _(a: torch.Tensor, axis: T_AXIS = None, keepdims: bool = False) -> torch.Tensor: return torch.sum(a, dim=axis, keepdim=keepdims) +@numeric.cumsum.register +def _(a: torch.Tensor, axis: int) -> torch.Tensor: + return torch.cumsum(a, dim=axis) + + @numeric.multiply.register def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: return torch.multiply(x1, x2) diff --git a/tests/cross_fw/examples/example_scope.json b/tests/cross_fw/examples/example_scope.json index ad2f5f56f8c..ced11baa8ce 100644 --- a/tests/cross_fw/examples/example_scope.json +++ b/tests/cross_fw/examples/example_scope.json @@ -304,6 +304,23 @@ ] } }, + "adaptive_codebook_llm_compression": { + "backend": "openvino", + "requirements": "examples/llm_compression/openvino/smollm2_360m_adaptive_codebook/requirements.txt", + "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz", + "accuracy_metrics": { + "answers": [ + "The capital of France is Paris.", + "The highest peak in the Alps is Montaldes", + "Torridia.", + "To the most visited city in Japan, Tokyo is", + " Paris.", + " The highest peak in the Alps is Mont Blanc.", + " Ottawa.", + " Tokyo." 
+ ] + } + }, "llm_compression_distillation_qat_with_lora": { "backend": "torch", "device": "cuda", diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 9d2ae47c892..24fb3824aa3 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -37,6 +37,7 @@ from nncf.parameters import BackupMode from nncf.parameters import CompressionFormat from nncf.quantization import compress_weights +from nncf.quantization.advanced_parameters import AdvancedAdaptiveCodebookParameters as CodebookParams from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as CompressionParams from nncf.quantization.advanced_parameters import AdvancedGPTQParameters as GPTQParams @@ -2147,6 +2148,39 @@ def test_codebook_is_correct_array(codebook): ) +@pytest.mark.parametrize("value_type", [None, TensorDataType.float16, TensorDataType.f8e4m3, TensorDataType.int8]) +@pytest.mark.parametrize("group_size", [-1, 4]) +def test_adaptive_codebooks(value_type, group_size): + model = AWQMatmulModel().ov_model + dataset = Dataset([np.ones([1, 8, 8])]) + advanced_parameters = ( + CompressionParams() + if value_type is None + else CompressionParams(adaptive_codebook_params=CodebookParams(value_type=value_type)) + ) + + n_matmuls = 0 + for op in model.get_ordered_ops(): + if op.get_type_name() == "MatMul": + n_matmuls += 1 + + compressed_model = compress_weights( + model, + mode=CompressWeightsMode.ADAPTIVE_CODEBOOK, + group_size=group_size, + dataset=dataset, + advanced_parameters=advanced_parameters, + ) + + n_gathers = 0 + for op in compressed_model.get_ordered_ops(): + if op.get_type_name() == "Gather": + n_gathers += 1 + + # For each MatMul except lm_head, there should be one Gather operation to fetch from the codebook + assert n_gathers == n_matmuls - 1 + + class TestOVTemplateWeightCompression(TemplateWeightCompression): @staticmethod def get_matmul_model() -> ov.Model: