Skip to content
Open
Show file tree
Hide file tree
Changes from 71 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
488cacc
Support scale estimation inside GPTQ
alexsu52 Jun 10, 2024
ee64877
fix for INT4_ASYM
alexsu52 Sep 4, 2024
f22e411
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 23, 2024
51b4d7b
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 26, 2024
f66cd1e
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Sep 30, 2024
7ce5a53
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Oct 2, 2024
f74d156
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
5288c79
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 11, 2024
1becf15
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Nov 14, 2024
047d7d9
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 10, 2024
c0c7e57
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 16, 2024
b74dea1
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Dec 27, 2024
26a9a77
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jan 7, 2025
25fcc2c
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Feb 25, 2025
26d4887
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Mar 12, 2025
7748233
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 1, 2025
df251b3
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 8, 2025
4c134c4
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 9, 2025
6147097
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Apr 14, 2025
2b94d28
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 7, 2025
5e312a5
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 9, 2025
2c5e983
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 12, 2025
1d8db1e
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 23, 2025
7244f18
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr May 28, 2025
443048c
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jun 2, 2025
80d2d8a
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jun 11, 2025
06bb19b
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jun 26, 2025
5d97d87
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jul 2, 2025
ae7cece
Merge remote-tracking branch 'upstream/develop' into develop
andreyanufr Jul 10, 2025
3bcd47b
Initial codebook estimation algorithm.
andreyanufr Jul 11, 2025
eb93fdb
First working example for layer wise codebook.
andreyanufr Jul 14, 2025
5bfccee
Experiment.
andreyanufr Jul 16, 2025
509b6ef
Experiment with accuracy improvement.
andreyanufr Jul 17, 2025
872b025
Fix in histogram computation.
andreyanufr Jul 18, 2025
9a8d08b
Experimrnt.
andreyanufr Jul 28, 2025
ad518c2
Search best codebook by minimizing MatMul diff.
andreyanufr Jul 30, 2025
2f8ec00
Merge remote-tracking branch 'upstream/develop' into aanuf/LUT_per_la…
andreyanufr Sep 10, 2025
b5c4c4a
Merge remote-tracking branch 'upstream/develop' into aanuf/LUT_per_la…
andreyanufr Sep 25, 2025
2fb21b2
Removed unused code.
andreyanufr Sep 25, 2025
812cbed
Remove unused code.
andreyanufr Sep 25, 2025
8c896c8
Replace np by fns.
andreyanufr Sep 26, 2025
a792c0b
Replace np by fns.
andreyanufr Sep 26, 2025
d3c2ab8
Replace np by fns.
andreyanufr Sep 26, 2025
9eec3e3
Replace np by fns.
andreyanufr Sep 29, 2025
5a66fda
Fixed problems with fp64 data types.
andreyanufr Sep 30, 2025
ec432bd
Resolved merge conflict with signed scale.
andreyanufr Oct 8, 2025
735c809
Fixed.
andreyanufr Oct 8, 2025
8ea3946
Removed unused code.
andreyanufr Oct 9, 2025
037a255
Fixed bug with close centroids.
andreyanufr Oct 10, 2025
817a790
Fixed error with argmin/cumsum args.
andreyanufr Oct 13, 2025
be6029a
Removed unused fuction.
andreyanufr Oct 13, 2025
58a64d8
Fix.
andreyanufr Oct 13, 2025
ffe0cf4
Fix.
andreyanufr Oct 13, 2025
72af4fd
Fix.
andreyanufr Oct 13, 2025
c6f72ee
Fixed bug with codebook type..
andreyanufr Oct 13, 2025
8497f4e
Fixed bug with cb4 codebook conversion to fp8.
andreyanufr Oct 14, 2025
535d2da
Disabled codebook estimation for onnx and torch.
andreyanufr Oct 14, 2025
6b2e7f7
Temporal fix for empty cluster.
andreyanufr Nov 5, 2025
4e54047
Per MatMul type codebook.
andreyanufr Dec 1, 2025
998f996
Changed interval step to number of intervals.
andreyanufr Dec 2, 2025
5708ceb
Weighted codebook selection.
andreyanufr Dec 8, 2025
aded7f2
Changed codebook data type.
andreyanufr Dec 10, 2025
2aac843
Fixed bug.
andreyanufr Dec 23, 2025
fb5b7d8
Codebook datatype for experiments.
andreyanufr Dec 24, 2025
708595a
Fixed merge conflicts.
andreyanufr Jan 5, 2026
37765b3
Removed codebook_estimation paramater and replaced it with ADAPTIVE_C…
andreyanufr Jan 7, 2026
97ad50d
Added adaptive codebook parameters.
andreyanufr Jan 7, 2026
8064c3e
1) Added example with adaptive codebook.
andreyanufr Jan 9, 2026
0be0f7d
Added example to test.
andreyanufr Jan 9, 2026
1b25dc1
Added test for adaptiva codebook.
andreyanufr Jan 12, 2026
836d7f8
Added codebook parameters check.
andreyanufr Jan 12, 2026
880d5fb
Applied comments.
andreyanufr Jan 13, 2026
3d00b66
Added support of group_size for per-tensor codebook.
andreyanufr Jan 13, 2026
3eee52d
Fixed merge conflict.
andreyanufr Jan 13, 2026
c048a6a
Check advanced codebook paramaters only in case of right mode.
andreyanufr Jan 14, 2026
4e14f96
Fixed bug in merging.
andreyanufr Jan 14, 2026
ce11a60
Fixed bug with empty histogram bins.
andreyanufr Jan 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Large Language Models FP8 Compression Example

This example demonstrates how to apply codebook compression to [HuggingFaceTB/SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct) model. It can be useful for evaluation and early HW enablement purposes.

## Prerequisites

Before running this example, ensure you have Python 3.10+ installed and set up your environment:

### 1. Create and activate a virtual environment

```bash
python3 -m venv nncf_env
source nncf_env/bin/activate # On Windows: nncf_env\Scripts\activate.bat
```

### 2. Install NNCF and other dependencies

```bash
python3 -m pip install ../../../../ -r requirements.txt
```

## Run Example

To run example:

```bash
python main.py
```

This will automatically:

- Download the SmolLM2 model and dataset
- Apply weight compression using NNCF
- Save the optimized model
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright (c) 2026 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial

import datasets
import numpy as np
from optimum.intel.openvino import OVModelForCausalLM
from scipy.stats import norm
from torch.jit import TracerWarning
from transformers import AutoTokenizer
from transformers import logging

import nncf
from nncf.quantization.advanced_parameters import AdvancedAdaptiveCodebookParameters

logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=TracerWarning)


MODEL_ID = "HuggingFaceTB/SmolLM2-360M-Instruct"
COMPRESSED_MODEL_ID = "smollm2_360m_compressed_codebook"


def get_dataset(model, tokenizer):
def transform_func(item, tokenizer, input_shapes, max_tokens=128):
text = item["text"]
tokens = tokenizer(text)

res = {
"input_ids": np.expand_dims(np.array(tokens["input_ids"][:max_tokens]), 0),
"attention_mask": np.expand_dims(np.array(tokens["attention_mask"][:max_tokens]), 0),
}

if "position_ids" in input_shapes:
position_ids = np.cumsum(res["attention_mask"], axis=1) - 1
position_ids[res["attention_mask"] == 0] = 1
res["position_ids"] = position_ids
batch_size = res["input_ids"].shape[0]

if "beam_idx" in input_shapes:
res["beam_idx"] = np.arange(batch_size, dtype=int)

return res

def get_input_shapes(model, batch_size=1):
inputs = {}

for val in model.model.inputs:
name = val.any_name
shape = list(val.partial_shape.get_min_shape())
shape[0] = batch_size
inputs[name] = shape

return inputs

input_shapes = get_input_shapes(model, batch_size=1)

dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
dataset = dataset.filter(lambda example: len(example["text"]) > 128)

def preprocess_fn(example):
return {"text": tokenizer.apply_chat_template(example["messages"], add_generation_prompt=False, tokenize=False)}

num_samples = 2048
ds = datasets.load_dataset("neuralmagic/LLM_compression_calibration", split="train")
ds = ds.shuffle(seed=42).select(range(num_samples))
ds = ds.map(preprocess_fn)
dataset = ds

quantization_dataset = nncf.Dataset(
dataset, partial(transform_func, tokenizer=tokenizer, input_shapes=input_shapes)
)
return quantization_dataset


def create_normal_distributed_values(n_levels=8) -> np.ndarray:
probs = (np.arange(n_levels) + 0.5) / n_levels

# Inverse CDF (quantiles) of standard normal distribution
values = norm.ppf(probs)

# Normalize to [-1, 1]
values = values / np.max(np.abs(values))

return values.astype(np.float32)


def generate_answers(
questions: list[str], model: OVModelForCausalLM, tokenizer: AutoTokenizer, max_new_tokens: int = 10
) -> dict[str, str]:
"""
Generate answers for a list of questions using the provided model and tokenizer.

:param questions: List of questions to be answered.
:param model: The model to use for generating answers.
:param tokenizer: The tokenizer to use for processing the input and output.
:param max_new_tokens: Maximum number of new tokens to generate for each answer. Defaults to 50.
:return: A dictionary mapping each question to its corresponding answer.
"""
messages = [
{"role": "system", "content": "You are a chatbot who always responds as short as possible."},
{"role": "user", "content": "What is the capital of Spain?"},
{"role": "assistant", "content": "Madrid."},
]
answers_by_questions = {}

for question in questions:
messages.append({"role": "user", "content": question})
input_ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(device=model.device)
input_len = len(input_ids[0])

output = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)[0]
answer = tokenizer.decode(output[input_len:], skip_special_tokens=True)
answers_by_questions[question] = answer
messages.append({"role": "assistant", "content": answer})

return answers_by_questions


def print_answers(header: str, answers_by_questions: list[str]) -> None:
"""
Print the answers to the console.

:param header: Header to print before the answers.
:param answers_by_questions: Dictionary mapping questions to their answers.
"""
print(header)
for question, answer in answers_by_questions.items():
print(f"Q: {question}\nA: {answer}\n")


QUESTIONS = [
"What is the capital of France?",
"What is the highest peak in the Alps?",
"What is the largest city in Canada?",
"What is the most visited city in Japan?",
]


def load_model_and_tokenizer(model_id: str, export=True) -> tuple[OVModelForCausalLM, AutoTokenizer]:
"""
Load the model and tokenizer from the specified model ID.

:param model_id: The identifier of the model to load.
:param export: Whether to export the model for OpenVINO. Defaults to True.
:return: A tuple containing the loaded model and tokenizer.
"""
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = OVModelForCausalLM.from_pretrained(
model_id,
export=export,
load_in_8bit=False,
)
return model, tokenizer


def codebook_example(
model_id: str, compressed_model_id: str, adaptive_codebook: bool = False, num_elements: int = 10
) -> list[str]:
"""
Example of using the custom codebook compression.

:param model_id: The identifier of the model to load.
:param compressed_model_id: The identifier for the compressed model to save.
:param adaptive_codebook: Whether to use adaptive codebook compression. Defaults to False.
:param num_parameters: Number of parameters in the codebook. Defaults to 8.
:return: A list of answers generated by the model after compression.
"""
model, tokenizer = load_model_and_tokenizer(model_id)

answers_by_questions = generate_answers(QUESTIONS, model, tokenizer)
print_answers("Non-optimized model outputs:\n", answers_by_questions)

codebook = create_normal_distributed_values(num_elements)

adaptive_codebook_params = AdvancedAdaptiveCodebookParameters(
num_elements=num_elements, value_type=nncf.tensor.TensorDataType.float16, per_block=False
)
quantization_dataset = get_dataset(model, tokenizer)

model.model = nncf.compress_weights(
model.model,
mode=nncf.CompressWeightsMode.ADAPTIVE_CODEBOOK if adaptive_codebook else nncf.CompressWeightsMode.CODEBOOK,
ratio=1.0,
group_size=-1,
scale_estimation=True,
dataset=quantization_dataset,
advanced_parameters=nncf.AdvancedCompressionParameters(
codebook=codebook, adaptive_codebook_params=adaptive_codebook_params if adaptive_codebook else None
),
)
model.save_pretrained(compressed_model_id)
tokenizer.save_pretrained(compressed_model_id)

model, tokenizer = load_model_and_tokenizer(compressed_model_id, False)
answers_by_questions = generate_answers(QUESTIONS, model, tokenizer)
print_answers("Optimized model outputs:\n", answers_by_questions)

return list(answers_by_questions.values())


def main():
res = codebook_example(MODEL_ID, COMPRESSED_MODEL_ID)
res += codebook_example(MODEL_ID, COMPRESSED_MODEL_ID + "_adaptive", adaptive_codebook=True)
return res


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
datasets==4.4.1
openvino==2025.4.1
optimum-intel[openvino]==1.26.0
optimum-onnx==0.0.3
optimum==2.0.0
transformers==4.53.0
onnx==1.17.0
torch==2.9.0
torchvision==0.24.0
pillow==12.0.0
2 changes: 2 additions & 0 deletions src/nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class CompressWeightsMode(StrEnum):
:param FP8_E4M3: A FP8 format with E4M3 values sharing group-level fp16 scale.
:param FP4: A FP4 format with E2M1 values sharing group-level fp16 scale.
:param CODEBOOK: Codebook (LUT) quantization format.
:param ADAPTIVE_CODEBOOK: Adaptive codebook (LUT) quantization format.
:param CB4_F8E4M3: Codebook (LUT) format with 16 fixed fp8 values in E4M3 format.
"""

Expand All @@ -110,6 +111,7 @@ class CompressWeightsMode(StrEnum):
FP8_E4M3 = "fp8_e4m3"
FP4 = "fp4"
CODEBOOK = "codebook"
ADAPTIVE_CODEBOOK = "adaptive_codebook"


@api(canonical_alias="nncf.CompressionFormat")
Expand Down
23 changes: 23 additions & 0 deletions src/nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from nncf.quantization.range_estimator import AggregatorType
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.quantization.range_estimator import StatisticsType
from nncf.tensor import TensorDataType

TTensor = Any

Expand Down Expand Up @@ -384,6 +385,25 @@ class AdvancedLoraCorrectionParameters:
use_int8_adapters: bool = True


@api()
@dataclass
class AdvancedAdaptiveCodebookParameters:
"""
Contains advanced parameters for adaptive codebook estimation.

:param value_type: The target tensor data type for the codebook.
:type value_type: TensorDataType
:param per_block: Whether to use per-block codebooks (e.g., all down_proj has the same codeboook).
:type per_block: bool
:param num_elements: The number of elements in each codebook entry.
:type num_elements: int
"""

value_type: TensorDataType = TensorDataType.f8e4m3
per_block: bool = True
num_elements: int = 16


@api()
@dataclass
class AdvancedCompressionParameters:
Expand Down Expand Up @@ -426,6 +446,9 @@ class AdvancedCompressionParameters:
lora_correction_params: AdvancedLoraCorrectionParameters = field(default_factory=AdvancedLoraCorrectionParameters)
backend_params: dict[str, Any] = field(default_factory=dict)
codebook: Optional[TTensor] = None
adaptive_codebook_params: AdvancedAdaptiveCodebookParameters = field(
default_factory=AdvancedAdaptiveCodebookParameters
)


@api()
Expand Down
Loading