diff --git a/.ci/cspell_dict.txt b/.ci/cspell_dict.txt
index fda0c856c1e..edbcb99ef05 100644
--- a/.ci/cspell_dict.txt
+++ b/.ci/cspell_dict.txt
@@ -126,6 +126,7 @@ elts
eltwise
eltwises
embeddingbag
+emnlp
eprint
eprinttype
errstate
@@ -172,6 +173,7 @@ hardswish
hardtanh
hawq
headport
+hellaswag
hiddens
hparam
hparams
@@ -199,6 +201,7 @@ inputless
interp
ioff
iterval
+jinjie
junitxml
keepdim
keepdims
@@ -307,6 +310,7 @@ onnxp
onnxqdq
onnxroi
onnxx
+openbookqa
openvino
openvinotoolkit
oplambda
@@ -428,6 +432,7 @@ sparsifiers
sparsifies
sparsify
sparsifying
+sqft
squeezenet
stabilityai_stablelm
stdv
@@ -481,6 +486,7 @@ vtype
weakrefs
weightable
whowhatbench
+winogrande
xlabel
xnli
xticklabels
diff --git a/examples/llm_compression/torch/qat_with_nls_downstream/README.md b/examples/llm_compression/torch/qat_with_nls_downstream/README.md
new file mode 100644
index 00000000000..10a72d15b7b
--- /dev/null
+++ b/examples/llm_compression/torch/qat_with_nls_downstream/README.md
@@ -0,0 +1,97 @@
+# Quantization-Aware NLS Tuning for Improving Accuracy on Downstream Tasks
+
+This example demonstrates how to improve the accuracy of Large Language Models (LLMs) with 4-bit weights by
+quantization-aware training with **Neural Low-Rank Adapter Search (NLS)** on downstream tasks.
+
+![LoRA vs NLS](pics/lora_vs_nls.png)
+
+[main.py](main.py) supports fine-tuning and evaluating a language model with quantization-aware training and **Neural Low-Rank Adapter Search (NLS)** proposed by [Shears](https://arxiv.org/abs/2404.10934) and [SQFT](https://arxiv.org/abs/2410.03750) on various downstream tasks. For example, to run the script for the task [openbookqa](https://huggingface.co/datasets/allenai/openbookqa), you can use the following command:
+
+```bash
+python main.py --pretrained Qwen/Qwen2.5-3B-Instruct --output_dir output --do_train --task openbookqa --lr 1e-4 --epochs 3 --batch_size 16 --eval_batch_size 64 --lora_rank_space 32 24 16
+```
+
+- `--pretrained`: The model ID or path of a pretrained Hugging Face model configuration.
+- `--output_dir`: Path to the directory for storing logs, tuning checkpoints, compressed models, and evaluation results.
+- `--do_train`: Whether to perform training. If not specified, the script will only evaluate the compressed model.
+- `--task`: The evaluation task to be performed. Choices: ["gsm8k", "hellaswag", "openbookqa", "winogrande", "arc_challenge", "arc_easy"].
+- `--lr`: Learning rate for fine-tuning.
+- `--epochs`: Number of epochs for training.
+- `--batch_size`: Size of the training batch.
+- `--eval_batch_size`: Size of the batch for evaluation.
+- `--lora_rank_space`: Specifies the search space for LoRA adapter ranks. For example, [32, 24, 16] indicates the ranks to be considered during NLS training and searching.
+- `--resume`: Whether to resume training from a checkpoint. If specified, the script will load the trained checkpoint and continue training or evaluation.
+- `--custom_rank_config`: Specifies the LoRA rank of adapters per layer.
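+
+Under the hood, [main.py](main.py) compresses the model with NNCF weight compression using the `FQ_LORA_NLS` compression format. A condensed sketch of that call (model loading, training, and evaluation omitted; `model` and `device` are assumed to be set up as in the script) looks like this:
+
+```python
+from nncf.data.dataset import Dataset
+from nncf.parameters import CompressionFormat, CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+
+# Insert fake quantizers with elastic (NLS) LoRA adapters into the Hugging Face model.
+model = compress_weights(
+    model,
+    mode=CompressWeightsMode.INT4_ASYM,
+    group_size=64,
+    dataset=Dataset([{k: v.to(device) for k, v in model.dummy_inputs.items()}]),
+    compression_format=CompressionFormat.FQ_LORA_NLS,
+    # The adapter rank equals the maximal value of `--lora_rank_space`, e.g. 32 for `32 24 16`.
+    advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=32),
+)
+```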
+
+Regarding evaluation, the script automatically uses a heuristic to pick a good sub-adapter configuration. This default strategy leverages information gathered during training and requires evaluating only 7 suggested configurations; the best of these candidates is returned to the user. More powerful elastic LoRA NLS configurations can optionally be obtained through more advanced search algorithms. We also support testing a custom configuration after training. The following command will load the trained checkpoint and evaluate the specified LoRA rank configuration:
+
+```bash
+python main.py --pretrained Qwen/Qwen2.5-3B-Instruct --output_dir output --resume --task openbookqa --lora_rank_space 32 24 16 --custom_rank_config 32 24 16 24 24 32 24 32 32 16 24 16 24 32 24 16 24 24 32 32 24 32 32 16 32 32 24 32
+```
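+
+For reference, the 7 suggested configurations are: the configuration that uses the median rank of `--lora_rank_space` in every layer, the configuration built from the per-layer ranks activated most frequently during training, and the 5 sampled configurations with the lowest average training loss. A simplified sketch of the last selection step (mirroring the helper in [main.py](main.py)):
+
+```python
+def get_top_k_min_loss_configs(loss_recorder: dict[tuple[int, ...], list[float]], k: int = 5) -> list[list[int]]:
+    # Average the losses recorded for each sampled rank configuration and keep the k best ones.
+    avg_loss_configs = [(config, sum(losses) / len(losses)) for config, losses in loss_recorder.items()]
+    avg_loss_configs.sort(key=lambda item: item[1])
+    return [list(config) for config, _ in avg_loss_configs[:k]]
+
+
+# Toy example with two layers and a rank space of [32, 24, 16]:
+loss_recorder = {(32, 16): [2.1, 1.9], (24, 24): [1.7, 1.6]}
+print(get_top_k_min_loss_configs(loss_recorder, k=1))  # [[24, 24]]
+```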
+
+This script also supports the vanilla LoRA method: simply pass a single value for `--lora_rank_space`, such as `--lora_rank_space 32`. The training time of LoRA and NLS is very similar, since activating different sub-adapters during training adds almost no overhead. For instance, fine-tuning the compressed Llama-3.2-3B-Instruct model for 3 epochs on [arc-challenge](https://huggingface.co/datasets/allenai/ai2_arc) takes 161.83 seconds with LoRA and 164.89 seconds with NLS.
+
+## Results
+
+The table below shows that quantization-aware training with absorbable LoRA adapters (QAT + LoRA) or with NLS (QAT + NLS) substantially improves the accuracy of compressed models on downstream tasks, and that QAT + NLS performs better than QAT + LoRA overall.
+
+The average score in the table is the mean accuracy over four downstream tasks: [openbookqa](https://huggingface.co/datasets/allenai/openbookqa), [winogrande](https://huggingface.co/datasets/allenai/winogrande), [arc-challenge](https://huggingface.co/datasets/allenai/ai2_arc) and [arc-easy](https://huggingface.co/datasets/allenai/ai2_arc) (all metrics are "acc_norm" from [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness), except winogrande, which uses "acc"). For QAT + LoRA and QAT + NLS, we ran experiments with 3, 4, and 5 epochs, LoRA ranks of 16 and 32, and the corresponding NLS rank spaces `[16, 12, 8]` and `[32, 24, 16]`, and we present the best result for each method. All quantization methods compressed the models to `INT4_ASYM` precision with a group size of `64`.
+
+| Model | Precision | Average score |
+|--------------------------------------|--------------------|---------------|
+| google/gemma-2-2b-it | INT4 (QAT + LoRA) | 0.6801 |
+| | INT4 (QAT + NLS) | **0.6843** |
+| Qwen/Qwen2.5-3B-Instruct | INT4 (QAT + LoRA) | 0.6916 |
+| | INT4 (QAT + NLS) | **0.6966** |
+| mistralai/Mistral-7B-v0.3 | INT4 (QAT + LoRA) | 0.7164 |
+| | INT4 (QAT + NLS) | **0.7291** |
+| meta-llama/Llama-3.2-3B-Instruct | INT4 (QAT + LoRA) | 0.6510 |
+| | INT4 (QAT + NLS) | **0.6570** |
+| HuggingFaceTB/SmolLM-1.7B-Instruct | INT4 (QAT + LoRA) | **0.5765** |
+| | INT4 (QAT + NLS) | 0.5733 |
+| meta-llama/Meta-Llama-3-8B | INT4 (QAT + LoRA) | 0.7236 |
+| | INT4 (QAT + NLS) | **0.7350** |
+| meta-llama/Meta-Llama-3-8B-Instruct | INT4 (QAT + LoRA) | 0.7076 |
+| | INT4 (QAT + NLS) | **0.7128** |
+| meta-llama/Llama-3.1-8B | INT4 (QAT + LoRA) | 0.7243 |
+| | INT4 (QAT + NLS) | **0.7297** |
+| meta-llama/Llama-3.1-8B-Instruct | INT4 (QAT + LoRA) | 0.7140 |
+| | INT4 (QAT + NLS) | **0.7166** |
+| Qwen/Qwen2.5-7B | INT4 (QAT + LoRA) | 0.7366 |
+| | INT4 (QAT + NLS) | **0.7408** |
+| Qwen/Qwen2.5-7B-Instruct | INT4 (QAT + LoRA) | 0.7356 |
+| | INT4 (QAT + NLS) | **0.7382** |
+
+## Citation
+
+If you find this code and the NLS technique helpful, please kindly cite:
+
+```bibtex
+@inproceedings{munoz2025low,
+  title={Low-Rank Adapters Meet Neural Architecture Search for LLM Compression},
+  author={Munoz, J. Pablo and
+          Yuan, Jinjie and
+          Jain, Nilesh},
+ booktitle={AAAI'25 workshop on CoLoRAI - Connecting Low-Rank Representations in AI},
+ year={2025},
+ url={https://arxiv.org/abs/2501.16372}
+}
+```
+
+```bibtex
+@inproceedings{munoz-2024-sqft,
+ title = "{SQFT}: Low-cost Model Adaptation in Low-precision Sparse Foundation Models",
+ author = "Munoz, Juan Pablo and
+ Yuan, Jinjie and
+ Jain, Nilesh",
+ booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2024",
+ month = nov,
+ year = "2024",
+ address = "Miami, Florida, USA",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2024.findings-emnlp.749",
+ pages = "12817--12832",
+}
+```
diff --git a/examples/llm_compression/torch/qat_with_nls_downstream/main.py b/examples/llm_compression/torch/qat_with_nls_downstream/main.py
new file mode 100644
index 00000000000..76b9d1e5445
--- /dev/null
+++ b/examples/llm_compression/torch/qat_with_nls_downstream/main.py
@@ -0,0 +1,709 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import random
+import re
+import shutil
+import sys
+import warnings
+from collections import defaultdict
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from typing import Union
+
+import datasets
+import numpy as np
+import torch
+import transformers
+from lm_eval import evaluator
+from lm_eval.models.huggingface import HFLM
+from torch import Tensor
+from torch import nn
+from torch.jit import TracerWarning
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
+from transformers import get_cosine_schedule_with_warmup
+
+from examples.llm_compression.torch.qat_with_lora.main import load_checkpoint
+from examples.llm_compression.torch.qat_with_lora.main import save_checkpoint
+from examples.llm_compression.torch.qat_with_lora.main import set_trainable
+from nncf.common.logging.track_progress import track
+from nncf.data.dataset import Dataset
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.torch.function_hook.wrapper import get_hook_storage
+from nncf.torch.quantization.layers import AsymmetricLoraNLSQuantizer
+from nncf.torch.quantization.layers import SymmetricLoraNLSQuantizer
+
+warnings.filterwarnings("ignore", category=TracerWarning)
+
+
+def get_gsm8k() -> list[str]:
+ """
+ Loads and processes the GSM8K dataset.
+
+ This function loads the GSM8K dataset, processes each sample to extract relevant fields,
+ and formats the data into prompts suitable for training.
+
+ :return: A list of processed prompts from the GSM8K dataset.
+ """
+ train_dataset = datasets.load_dataset("gsm8k", "main", split="train")
+ processed_train_dataset = []
+ for sample in train_dataset:
+ prompt = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
+ processed_train_dataset.append(prompt)
+
+ return processed_train_dataset
+
+
+def get_hellaswag() -> list[str]:
+ """
+ Loads and processes the HellaSwag dataset.
+
+ :return: A list of processed prompts from the HellaSwag dataset.
+ """
+
+ def preprocess(text):
+ """Preprocess the text by removing unwanted characters and formatting."""
+ text = text.strip()
+ text = text.replace(" [title]", ". ")
+ text = re.sub("\\[.*?\\]", "", text)
+ text = text.replace(" ", " ")
+ return text
+
+ train_dataset = datasets.load_dataset("hellaswag", split="train")
+ processed_train_dataset = []
+ for sample in train_dataset:
+ context = sample["ctx_a"] + " " + sample["ctx_b"].capitalize()
+ document = {
+ "query": preprocess(sample["activity_label"] + ": " + context),
+ "choices": [preprocess(ending) for ending in sample["endings"]],
+ "gold": int(sample["label"]),
+ }
+ query = document["query"]
+ answer = document["choices"][document["gold"]]
+ prompt = query + " " + answer
+ processed_train_dataset.append(prompt)
+
+ return processed_train_dataset
+
+
+def get_openbookqa() -> list[str]:
+ """
+ Loads and processes the OpenBookQA dataset.
+
+ :return: A list of processed prompts from the OpenBookQA dataset.
+ """
+ train_dataset = datasets.load_dataset("openbookqa", split="train")
+ processed_train_dataset = []
+ for sample in train_dataset:
+ document = {
+ "id": sample["id"],
+ "query": sample["question_stem"],
+ "choices": sample["choices"]["text"],
+ "gold": ["A", "B", "C", "D"].index(sample["answerKey"].strip()),
+ }
+ prompt = document["query"]
+ answer = document["choices"][document["gold"]]
+ prompt = prompt + " " + answer
+ processed_train_dataset.append(prompt)
+
+ return processed_train_dataset
+
+
+def get_winogrande() -> list[str]:
+ """
+ Loads and processes the Winogrande dataset.
+
+ :return: A list of processed prompts from the Winogrande dataset.
+ """
+ train_dataset = datasets.load_dataset("winogrande", "winogrande_debiased", split="train")
+ processed_train_dataset = []
+ for sample in train_dataset:
+ pronoun_location = sample["sentence"].index("_")
+ answer = sample["option" + sample["answer"]]
+ prompt = sample["sentence"][:pronoun_location] + answer + sample["sentence"][pronoun_location + 1 :]
+ processed_train_dataset.append(prompt)
+
+ return processed_train_dataset
+
+
+def get_arc(name: str = "ARC-Easy") -> list[str]:
+ """
+ Loads and processes the ARC (ARC-Easy or ARC-Challenge) dataset.
+
+ :return: A list of processed prompts from the ARC dataset.
+ """
+ train_dataset = datasets.load_dataset("ai2_arc", name, split="train")
+ processed_train_dataset = []
+ for sample in train_dataset:
+ # Map numeric answer keys to letter representations.
+ num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
+ sample["answerKey"] = num_to_letter.get(sample["answerKey"], sample["answerKey"])
+
+ # Process the ARC document to extract relevant fields.
+ processed_document = {
+ "id": sample["id"],
+ "query": "Question: " + sample["question"] + "\nAnswer:",
+ "choices": sample["choices"]["text"],
+ "gold": ["A", "B", "C", "D", "E"].index(sample["answerKey"]),
+ }
+
+ # Construct the prompt with the correct answer.
+ answer = processed_document["choices"][processed_document["gold"]]
+ prompt = processed_document["query"] + " " + answer
+ processed_train_dataset.append(prompt)
+
+ return processed_train_dataset
+
+
+def lm_eval(model: nn.Module, tokenizer: AutoTokenizer, task: str, batch_size: int = 1) -> dict[str, Any]:
+ """
+ Evaluates a language model on a specified task using the lm-eval library.
+
+ This function initializes a HFLM (from lm-eval) with the provided model and tokenizer,
+ and then evaluates it on the specified task.
+
+ :param model: The language model to be evaluated.
+ :param tokenizer: The tokenizer corresponding to the language model.
+ :param task: The evaluation tasks or task configs.
+ :param batch_size: The batch size to be used during evaluation.
+ :return: A dictionary containing the evaluation results.
+ """
+ lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size)
+ results = evaluator.simple_evaluate(lm, tasks=task, log_samples=False)["results"]
+ return results[task]
+
+
+def tokenize(
+ tokenizer: AutoTokenizer,
+ prompt: str,
+ add_eos_token: bool = True,
+ max_length: int = 256,
+) -> dict[str, list[int]]:
+ """
+ Tokenize the given prompt.
+
+ :param tokenizer: The tokenizer to use.
+ :param prompt: The prompt to tokenize.
+ :param add_eos_token: Whether to add an eos token.
+ :param max_length: The maximum length of the tokenized input.
+ :return: A dictionary containing tokenized input ids, attention mask, and labels.
+ """
+ result = tokenizer(
+ prompt,
+ truncation=True,
+ max_length=max_length,
+ padding=True,
+ return_tensors=None,
+ )
+ if result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < max_length and add_eos_token:
+ result["input_ids"].append(tokenizer.eos_token_id)
+ result["attention_mask"].append(1)
+
+ result["labels"] = result["input_ids"].copy()
+ return result
+
+
+def get_layer_id_vs_lora_quantizers_map(
+ model: nn.Module,
+) -> dict[int, list[Union["AsymmetricLoraNLSQuantizer", "SymmetricLoraNLSQuantizer"]]]:
+ """
+ Maps layer IDs to their corresponding LoRA quantizers.
+
+ :param model: The model containing LoRA quantizers.
+ :return: A dictionary mapping layer IDs to lists of LoRA quantizers.
+ """
+ hook_storage = get_hook_storage(model)
+ layer_id_vs_lora_quantizers_map = defaultdict(list)
+
+ for name, module in hook_storage.named_hooks():
+ if isinstance(module, (AsymmetricLoraNLSQuantizer, SymmetricLoraNLSQuantizer)) and (module.num_bits == 4):
+ match = re.search(r"layers:(\d+):", name)
+ if match is None:
+ msg = (
+ "Model is supposed to have a specific structure with Transformer blocks "
+ "stored as follows: self.layers = nn.ModuleList(...)"
+ )
+ raise ValueError(msg)
+ layer_id = int(match.group(1))
+ layer_id_vs_lora_quantizers_map[layer_id].append(module)
+
+ return layer_id_vs_lora_quantizers_map
+
+
+@torch.no_grad()
+def configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map: dict[int, list[Union["AsymmetricLoraNLSQuantizer", "SymmetricLoraNLSQuantizer"]]],
+ lora_rank_space: list[int] = None,
+ adapter_strategy: str = None,
+ specific_rank_config: list[int] = None,
+) -> list[int]:
+ """
+ Configures sub-adapters with specified ranks (or adapter strategy) for each layer in the model.
+
+ :param layer_id_vs_lora_quantizers_map: A dictionary mapping layer IDs to lists of LoRA quantizers.
+ :param lora_rank_space: A list of possible ranks for the LoRA adapters.
+ :param adapter_strategy: Strategy to select the rank from the `lora_rank_space`.
+ Options are 'maximal', 'median', 'minimal', 'random'.
+ :param specific_rank_config: A specific configuration of ranks for each layer.
+ :return: A list of activated ranks for each layer.
+ """
+ # Ensure that either [`lora_rank_space` and `adapter_strategy`] or [`specific_rank_config`] is provided
+ if specific_rank_config is None:
+ assert lora_rank_space and adapter_strategy, (
+ "`specific_rank_config` is not provided, both `lora_rank_space` and `adapter_strategy` must be specified."
+ )
+ else:
+ assert len(specific_rank_config) == len(layer_id_vs_lora_quantizers_map), (
+ "Length of specific_rank_config must match the number of layers."
+ )
+
+ activated_rank_config = []
+ for layer, lora_quantizers in layer_id_vs_lora_quantizers_map.items():
+ if specific_rank_config is not None:
+ selected_rank = specific_rank_config[layer]
+ else:
+ if adapter_strategy == "maximal":
+ selected_rank = lora_rank_space[0]
+ elif adapter_strategy == "median":
+ selected_rank = lora_rank_space[(len(lora_rank_space) - 1) // 2]
+ elif adapter_strategy == "minimal":
+ selected_rank = lora_rank_space[-1]
+ elif adapter_strategy == "random":
+ selected_rank = int(np.random.choice(lora_rank_space))
+ else:
+ error_message = "Invalid adapter strategy"
+ raise ValueError(error_message)
+
+ # Activate the sub-adapter with the selected rank
+ for lora_quantizer in lora_quantizers:
+ lora_quantizer.set_active_rank(selected_rank)
+ activated_rank_config.append(selected_rank)
+
+ return activated_rank_config
+
+
+def get_argument_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(add_help=True)
+
+ # Model params
+ parser.add_argument(
+ "--pretrained",
+ type=str,
+ default="Qwen/Qwen2.5-3B-Instruct",
+ help="The model id or path of a pretrained HF model configuration.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=Path,
+ default="output",
+        help="Path to the directory for storing logs, tuning checkpoint, compressed model, and evaluation results.",
+ )
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="Whether to start from previously saved checkpoint. If not specified or checkpoint does not exist, "
+ "start from scratch by post-training weight compression initialization.",
+ )
+ parser.add_argument(
+ "--do_train",
+ action="store_true",
+ )
+
+ # Downstream task
+ parser.add_argument(
+ "--task",
+ type=str,
+ choices=[
+ "openbookqa",
+ "winogrande",
+ "arc_challenge",
+ "arc_easy",
+ "gsm8k",
+ "hellaswag",
+ ],
+ default="openbookqa",
+ help="Evaluation task",
+ )
+ parser.add_argument(
+ "--lm_eval_metric",
+ type=str,
+ default="acc_norm,none",
+ help="The metrics of the lm-eval task. Different tasks have different metrics.",
+ )
+
+ # Training params
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1e-4,
+ help="Learning rate for fine-tuning. "
+ "For larger models (over 2 billion parameters), a learning rate of 5e-4 is recommended.",
+ )
+ parser.add_argument("--epochs", type=int, default=3, help="Number of epochs.")
+ parser.add_argument("--batch_size", type=int, default=16, help="Size of training batch.")
+ parser.add_argument(
+ "--microbatch_size",
+ type=int,
+ default=16,
+ help="Size of each training microbatch. Gradients will be accumulated until the batch size is reached.",
+ )
+ parser.add_argument("--eval_batch_size", type=int, default=64, help="Size of batch for evaluation.")
+
+ # Neural Low-rank Adapter Search (NLS) params
+ parser.add_argument(
+ "--lora_rank_space",
+ type=int,
+ nargs="+",
+ default=None,
+ help="Search space for LoRA adapter ranks. For example, if the (maximum) rank is 32, "
+ "this can be [32, 24, 16] to specify the ranks to be used during NLS.",
+ )
+ parser.add_argument(
+ "--custom_rank_config",
+ type=int,
+ nargs="+",
+ default=None,
+ help="Custom LoRA rank configuration (NLS) for evaluation.",
+ )
+ return parser
+
+
+def main(argv) -> float:
+ """
+ Fine-tuning and evaluating a language model with quantization-aware training and LoRA adapters,
+ including optional Neural Low-rank Adapter Search (NLS).
+ """
+ parser = get_argument_parser()
+ args = parser.parse_args(argv)
+ assert torch.cuda.is_available()
+ transformers.set_seed(42)
+ device = "cuda"
+ torch_dtype = torch.bfloat16
+ lora_rank = max(args.lora_rank_space)
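+    # A single value in --lora_rank_space corresponds to vanilla LoRA tuning; multiple values enable NLS.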
+ disable_nls = len(args.lora_rank_space) == 1
+ compression_format = CompressionFormat.FQ_LORA if disable_nls else CompressionFormat.FQ_LORA_NLS
+ compression_config = dict(
+ mode=CompressWeightsMode.INT4_ASYM,
+ group_size=64,
+ compression_format=compression_format,
+ advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=lora_rank),
+ )
+
+ # Configure output and log files.
+ output_dir = Path(args.output_dir)
+ tensorboard_dir = output_dir / "tb" / datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
+ last_dir = output_dir / "last"
+ result_file = output_dir / "result.json"
+
+ if not args.resume:
+ shutil.rmtree(output_dir, ignore_errors=True)
+ for path in [output_dir, tensorboard_dir, last_dir]:
+ path.mkdir(exist_ok=True, parents=True)
+ ckpt_file = last_dir / "nncf_checkpoint.pth"
+ print(f"To visualize the loss, open Tensorboard using the logs from: {tensorboard_dir}")
+ tb = SummaryWriter(tensorboard_dir, "QAT with absorbable LoRA")
+ overall_result = {}
+
+ # Load original model and tokenizer.
+ model = AutoModelForCausalLM.from_pretrained(args.pretrained, torch_dtype=torch_dtype, device_map=device)
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ results_before_compression = lm_eval(model, tokenizer, task=args.task, batch_size=args.eval_batch_size)
+ print(f"Results before compression={json.dumps(results_before_compression, indent=4)}")
+ overall_result["results_before_compression"] = results_before_compression
+
+ # Dataset preparation
+ train_dataset = None
+ if args.task == "gsm8k":
+ train_dataset = get_gsm8k()
+ elif args.task == "hellaswag":
+ train_dataset = get_hellaswag()
+ elif args.task == "openbookqa":
+ train_dataset = get_openbookqa()
+ elif args.task == "winogrande":
+ train_dataset = get_winogrande()
+ elif args.task == "arc_challenge":
+ train_dataset = get_arc(name="ARC-Challenge")
+ elif args.task == "arc_easy":
+ train_dataset = get_arc(name="ARC-Easy")
+ else:
+ error_message = f"Unsupported task: {args.task}."
+ raise ValueError(error_message)
+ model_input = model.dummy_inputs
+ train_dataset = [tokenize(tokenizer, sample) for sample in train_dataset]
+ random.shuffle(train_dataset)
+
+ model = compress_weights(
+ model,
+ dataset=Dataset([{k: v.to(device) for k, v in model_input.items()}]),
+ **compression_config,
+ )
+ results_of_compressed_model = lm_eval(model, tokenizer, task=args.task, batch_size=args.eval_batch_size)
+ print(f"Results of NNCF compressed model={json.dumps(results_of_compressed_model, indent=4)}")
+ overall_result["results_of_compressed_model"] = results_of_compressed_model
+ initial_result = results_of_compressed_model[args.lm_eval_metric]
+
+ # Create or load model to tune with Fake Quantizers and absorbable LoRA adapters.
+ if args.resume and ckpt_file.exists():
+ model = AutoModelForCausalLM.from_pretrained(args.pretrained, torch_dtype=torch_dtype, device_map=device)
+ model = load_checkpoint(model, model_input, ckpt_file)
+ else:
+ save_checkpoint(model, ckpt_file)
+
+ layer_id_vs_lora_quantizers_map = None
+ if not disable_nls:
+ layer_id_vs_lora_quantizers_map = get_layer_id_vs_lora_quantizers_map(model)
+
+ if args.do_train:
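+        # FakeQuantize parameters are tuned with a 10x smaller learning rate than the LoRA adapters.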
+ fq_lr = args.lr / 10
+ weight_decay = args.lr
+ param_to_train = set_trainable(model, lora_lr=args.lr, fq_lr=fq_lr)
+ opt = torch.optim.AdamW(param_to_train, weight_decay=weight_decay)
+
+ grad_accumulation_steps = args.batch_size // args.microbatch_size
+ num_samples = len(train_dataset)
+ epoch_samples = num_samples - num_samples % args.microbatch_size
+ microbatches_per_epoch = epoch_samples // args.microbatch_size
+ aggregated_loss = float("nan")
+ loss_numerator = grad_steps = total_microbatches = 0
+ data_collator = transformers.DataCollatorForSeq2Seq(
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+ )
+ total_steps = (microbatches_per_epoch * args.epochs) // grad_accumulation_steps
+ scheduler = get_cosine_schedule_with_warmup(
+ opt, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
+ )
+
+ if disable_nls:
+ activation_counter = None
+ loss_recorder = None
+ else:
+ # Initialize the counter for tracking activation counts during training
+ maximal_lora_rank_config = configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map,
+ lora_rank_space=args.lora_rank_space,
+ adapter_strategy="maximal",
+ )
+ activation_counter = [
+ {rank: 0 for rank in args.lora_rank_space} for _ in range(len(maximal_lora_rank_config))
+ ]
+
+ # Initialize the loss recorder for tracking losses during training (for each sub-adapter)
+ loss_recorder = defaultdict(list)
+
+ for epoch in range(args.epochs):
+ batch_indices_epoch = torch.randperm(num_samples)[:epoch_samples].chunk(microbatches_per_epoch)
+ for indices in track(batch_indices_epoch, description=f"Train epoch {epoch}"):
+ # If Neural Low-rank Adapter Search (NLS) is enabled,
+ # configure the LoRA adapters with a random rank configuration from the specified rank space.
+ if not disable_nls and grad_steps == 0:
+ current_config = configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map,
+ lora_rank_space=args.lora_rank_space,
+ adapter_strategy="random",
+ )
+ # Update the activation counter
+ for idx, rank in enumerate(current_config):
+ activation_counter[idx][rank] += 1
+ current_config_tuple = tuple(current_config)
+
+ indices = indices.tolist()
+ total_microbatches += 1
+
+ def form_batch(inputs: list[Tensor]):
+ batch = [inputs[i] for i in indices]
+ batch = data_collator(batch)
+ batch = {k: v.to(device) for k, v in batch.items()}
+ return batch
+
+ inputs = form_batch(train_dataset)
+ outputs = model(**inputs)
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+ # Record the loss for the current configuration
+ if not disable_nls:
+ loss_recorder[current_config_tuple].append(loss.item())
+
+ # Perform an optimization step after accumulating gradients over multiple minibatches.
+ loss_numerator += loss.item()
+ grad_steps += 1
+ if not torch.isfinite(loss).item():
+ err = f"Fine-tuning loss is {loss}"
+ raise ValueError(err)
+ (loss / grad_accumulation_steps).backward()
+ if grad_steps == grad_accumulation_steps:
+ opt.step()
+ scheduler.step()
+ opt.zero_grad()
+ aggregated_loss = loss_numerator / grad_steps
+ loss_numerator = grad_steps = 0
+
+ current_lr = scheduler.get_last_lr()[0]
+ if total_microbatches % 10 == 0:
+ print(
+ f"Epoch: {epoch + 1}, "
+ f"Step: {total_microbatches}, "
+ f"Loss: {aggregated_loss:.4f}, "
+ f"Learning Rate: {current_lr:.6f}"
+ )
+ tb.add_scalar("learning_rate", current_lr, total_microbatches)
+ tb.add_scalar("loss", aggregated_loss, total_microbatches)
+
+ save_checkpoint(model, ckpt_file)
+
+ # Start evaluation
+ if disable_nls:
+ results_of_lora_finetuned_compressed_model = lm_eval(
+ model, tokenizer, task=args.task, batch_size=args.eval_batch_size
+ )
+ print(
+ f"Results of quantization-aware-finetuned (LoRA) NNCF compressed model="
+ f"{json.dumps(results_of_lora_finetuned_compressed_model, indent=4)}"
+ )
+ overall_result["lora_results"] = results_of_lora_finetuned_compressed_model
+ best_result = results_of_lora_finetuned_compressed_model[args.lm_eval_metric]
+ else:
+ overall_result["nls_results"] = []
+ # Use some of the signals from training to find some heuristic configurations for evaluation.
+ if args.do_train:
+ # Extract the most frequently activated configuration
+ def get_most_frequent_config(activation_counter):
+ most_frequent_config = []
+ for layer_counter in activation_counter:
+ most_frequent_rank = max(layer_counter, key=layer_counter.get)
+ most_frequent_config.append(most_frequent_rank)
+ return most_frequent_config
+
+ # Calculate the average loss for each configuration and select the top k with the minimum loss
+ def get_top_k_min_loss_configs(loss_recorder, k=5):
+ avg_loss_configs = [(config, sum(losses) / len(losses)) for config, losses in loss_recorder.items()]
+ avg_loss_configs.sort(key=lambda x: x[1])
+ top_k_configs = [list(config) for config, _ in avg_loss_configs[:k]]
+ return top_k_configs
+
+ best_result = initial_result
+ # Test the median configuration
+ median_lora_rank_config = configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map,
+ lora_rank_space=args.lora_rank_space,
+ adapter_strategy="median",
+ )
+ results_of_nls_finetuned_compressed_model_median = lm_eval(
+ model, tokenizer, task=args.task, batch_size=args.eval_batch_size
+ )
+ print(
+ f"Results of quantization-aware-finetuned (NLS-Median) NNCF compressed model="
+ f"{json.dumps(results_of_nls_finetuned_compressed_model_median, indent=4)}"
+ )
+ overall_result["nls_results"].append(
+ {
+ "type": "median",
+ "config": median_lora_rank_config,
+ "results": results_of_nls_finetuned_compressed_model_median,
+ }
+ )
+ best_result = max(
+ best_result,
+ results_of_nls_finetuned_compressed_model_median[args.lm_eval_metric],
+ )
+
+ # Test the most frequent configuration
+ most_frequent_lora_rank_config = get_most_frequent_config(activation_counter)
+ configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map,
+ specific_rank_config=most_frequent_lora_rank_config,
+ )
+ results_of_nls_finetuned_compressed_model_most_frequent = lm_eval(
+ model, tokenizer, task=args.task, batch_size=args.eval_batch_size
+ )
+ print(
+ f"Results of quantization-aware-finetuned (NLS-Most-Frequent) NNCF compressed model="
+ f"{json.dumps(results_of_nls_finetuned_compressed_model_most_frequent, indent=4)}"
+ )
+ overall_result["nls_results"].append(
+ {
+ "type": "most-frequent",
+ "config": most_frequent_lora_rank_config,
+ "results": results_of_nls_finetuned_compressed_model_most_frequent,
+ }
+ )
+ best_result = max(
+ best_result,
+ results_of_nls_finetuned_compressed_model_most_frequent[args.lm_eval_metric],
+ )
+
+ # Test the top 5 min loss configurations
+ top_5_min_loss_configs = get_top_k_min_loss_configs(loss_recorder, k=5)
+ for i, min_loss_config in enumerate(top_5_min_loss_configs):
+ configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map,
+ specific_rank_config=min_loss_config,
+ )
+ results_of_nls_finetuned_compressed_model_min_loss = lm_eval(
+ model, tokenizer, task=args.task, batch_size=args.eval_batch_size
+ )
+ print(
+ f"Results of quantization-aware-finetuned (NLS-Min-Loss-{i + 1}) NNCF compressed model="
+ f"{json.dumps(results_of_nls_finetuned_compressed_model_min_loss, indent=4)}"
+ )
+ overall_result["nls_results"].append(
+ {
+ "type": f"min-loss-{i + 1}",
+ "config": min_loss_config,
+ "results": results_of_nls_finetuned_compressed_model_min_loss,
+ }
+ )
+ best_result = max(
+ best_result,
+ results_of_nls_finetuned_compressed_model_min_loss[args.lm_eval_metric],
+ )
+ else:
+ assert args.custom_rank_config is not None, "Please provide `custom_rank_config` for evaluation."
+ configure_lora_adapters(
+ layer_id_vs_lora_quantizers_map,
+ specific_rank_config=args.custom_rank_config,
+ )
+ results_of_nls_finetuned_compressed_model_custom = lm_eval(
+ model, tokenizer, task=args.task, batch_size=args.eval_batch_size
+ )
+ print(
+ f"Results of quantization-aware-finetuned (NLS with custom config) NNCF compressed model="
+ f"{json.dumps(results_of_nls_finetuned_compressed_model_custom, indent=4)}"
+ )
+ overall_result["nls_results"].append(
+ {
+ "type": "custom",
+ "config": args.custom_rank_config,
+ "results": results_of_nls_finetuned_compressed_model_custom,
+ }
+ )
+ best_result = results_of_nls_finetuned_compressed_model_custom[args.lm_eval_metric]
+
+ print(f"Overall result: {json.dumps(overall_result, indent=4)}")
+ # Save results
+ with open(result_file, "w") as f:
+ json.dump(overall_result, f, indent=4)
+
+ result_diff = best_result - initial_result
+ result_diff = round(result_diff, 2)
+ return result_diff
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/examples/llm_compression/torch/qat_with_nls_downstream/pics/lora_vs_nls.png b/examples/llm_compression/torch/qat_with_nls_downstream/pics/lora_vs_nls.png
new file mode 100644
index 00000000000..2d504dea9b6
--- /dev/null
+++ b/examples/llm_compression/torch/qat_with_nls_downstream/pics/lora_vs_nls.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a34b754bc6beb859bc04fd265c1bbeb98649818c0a9d50207298b308e73b6e16
+size 36016
diff --git a/examples/llm_compression/torch/qat_with_nls_downstream/requirements.txt b/examples/llm_compression/torch/qat_with_nls_downstream/requirements.txt
new file mode 100644
index 00000000000..cb6d5ccf15d
--- /dev/null
+++ b/examples/llm_compression/torch/qat_with_nls_downstream/requirements.txt
@@ -0,0 +1,8 @@
+tensorboard==2.13.0
+torch==2.7.0
+whowhatbench @ git+https://github.com/openvinotoolkit/openvino.genai@2025.1.0.0#subdirectory=tools/who_what_benchmark
+numpy>=1.23.5,<2
+openvino==2025.1
+optimum-intel>=1.22.0
+transformers>=4.48.0
+lm_eval==0.4.8
diff --git a/nncf/common/quantization/structs.py b/nncf/common/quantization/structs.py
index de4a61fef56..a9dae4792a4 100644
--- a/nncf/common/quantization/structs.py
+++ b/nncf/common/quantization/structs.py
@@ -35,14 +35,20 @@ class QuantizationScheme(StrEnum):
representing the lower and upper boundaries of the range, respectively.
:param SYMMETRIC_LORA: Symmetric quantization with Low-Rank Adapters (LoRA), involving the sum of weights and
the multiplication of low-rank adapters.
+ :param SYMMETRIC_LORA_NLS: Symmetric quantization with Low-Rank Adapters (LoRA) and Neural Low-Rank Adapter Search
+ (NLS), involving the sum of weights and the multiplication of low-rank adapters.
:param ASYMMETRIC_LORA: Asymmetric quantization with Low-Rank Adapters (LoRA), involving the sum of weights and
the multiplication of low-rank adapters.
+ :param ASYMMETRIC_LORA_NLS: Asymmetric quantization with Low-Rank Adapters (LoRA) and Neural Low-Rank Adapter Search
+ (NLS), involving the sum of weights and the multiplication of low-rank adapters.
"""
SYMMETRIC = "symmetric"
ASYMMETRIC = "asymmetric"
SYMMETRIC_LORA = "symmetric_lora"
+ SYMMETRIC_LORA_NLS = "symmetric_lora_nls"
ASYMMETRIC_LORA = "asymmetric_lora"
+ ASYMMETRIC_LORA_NLS = "asymmetric_lora_nls"
class QuantizerConfig:
diff --git a/nncf/parameters.py b/nncf/parameters.py
index 92b158fa9a6..be323837df7 100644
--- a/nncf/parameters.py
+++ b/nncf/parameters.py
@@ -112,11 +112,15 @@ class CompressionFormat(StrEnum):
the multiplication of adapters. This makes quantization-aware training (QAT) more efficient in terms of
accuracy, as adapters can also be tuned and remain computationally affordable during training due to their
small dimensions.
+ :param FQ_LORA_NLS: Represents the 'fake_quantize_with_lora_nls' format, which extends FQ_LORA with elastic
+ absorbable low-rank adapters (LoRA). Quantization is applied similarly to FQ_LORA, and utilizing NLS often
+ results in better performance for downstream task fine-tuning.
"""
DQ = "dequantize"
FQ = "fake_quantize"
FQ_LORA = "fake_quantize_with_lora"
+ FQ_LORA_NLS = "fake_quantize_with_lora_nls"
@api(canonical_alias="nncf.StripFormat")
diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 79363722934..9d67dba4ec1 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -199,8 +199,12 @@ def check_user_compression_configuration(
requires a dataset, but it's not provided."
raise nncf.ValidationError(msg)
- if lora_correction and compression_format in [CompressionFormat.FQ, CompressionFormat.FQ_LORA]:
- msg = "LoRA Correction algorithm is not compatible with FQ and FQ_LORA compression formats."
+ if lora_correction and compression_format in [
+ CompressionFormat.FQ,
+ CompressionFormat.FQ_LORA,
+ CompressionFormat.FQ_LORA_NLS,
+ ]:
+ msg = "LoRA Correction algorithm is not compatible with FQ, FQ_LORA and FQ_LORA_NLS compression formats."
raise nncf.ValidationError(msg)
diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
index 50f765c35c3..13e6abc751a 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -72,6 +72,7 @@
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import PTLoraNLSSpec
from nncf.torch.quantization.layers import PTLoraSpec
from nncf.torch.quantization.layers import PTQuantizerSpec
@@ -303,6 +304,12 @@ def get_fq_insertion_command(
if is_all_8bit and compression_format == CompressionFormat.FQ_LORA:
mode_vs_schema_map[CompressWeightsMode.INT8_ASYM] = QuantizationScheme.ASYMMETRIC_LORA
mode_vs_schema_map[CompressWeightsMode.INT8_SYM] = QuantizationScheme.SYMMETRIC_LORA
+ if compression_format == CompressionFormat.FQ_LORA_NLS:
+ mode_vs_schema_map[CompressWeightsMode.INT4_ASYM] = QuantizationScheme.ASYMMETRIC_LORA_NLS
+ mode_vs_schema_map[CompressWeightsMode.INT4_SYM] = QuantizationScheme.SYMMETRIC_LORA_NLS
+ if is_all_8bit:
+ mode_vs_schema_map[CompressWeightsMode.INT8_ASYM] = QuantizationScheme.ASYMMETRIC_LORA_NLS
+ mode_vs_schema_map[CompressWeightsMode.INT8_SYM] = QuantizationScheme.SYMMETRIC_LORA_NLS
schema = mode_vs_schema_map[compression_config.mode]
@@ -322,10 +329,23 @@ def get_fq_insertion_command(
)
quantizer_cls = QUANTIZATION_MODULES.get(schema)
- if schema in [QuantizationScheme.ASYMMETRIC_LORA, QuantizationScheme.SYMMETRIC_LORA]:
- lora_spec = PTLoraSpec(
- lora_rank=lora_adapter_rank, orig_weight_shape=orig_weight_shape, weight_shape=weight_shape
- )
+ if schema in [
+ QuantizationScheme.ASYMMETRIC_LORA,
+ QuantizationScheme.ASYMMETRIC_LORA_NLS,
+ QuantizationScheme.SYMMETRIC_LORA,
+ QuantizationScheme.SYMMETRIC_LORA_NLS,
+ ]:
+ if schema in [QuantizationScheme.ASYMMETRIC_LORA, QuantizationScheme.SYMMETRIC_LORA]:
+ lora_spec = PTLoraSpec(
+ lora_rank=lora_adapter_rank, orig_weight_shape=orig_weight_shape, weight_shape=weight_shape
+ )
+ else:
+ lora_spec = PTLoraNLSSpec(
+ lora_rank=lora_adapter_rank,
+ active_lora_rank=lora_adapter_rank,
+ orig_weight_shape=orig_weight_shape,
+ weight_shape=weight_shape,
+ )
quantizer = quantizer_cls(quantizer_spec, lora_spec)
lora_dtype = quantizer.lora_A.dtype
svd_residual = torch.rand(weight_shape).to(device) * scale / 100 # value on [0,1] * (1/100 of quant size)
@@ -337,7 +357,11 @@ def get_fq_insertion_command(
quantizer = quantizer_cls(quantizer_spec)
levels = quantizer.levels
- if schema in [QuantizationScheme.ASYMMETRIC_LORA, QuantizationScheme.ASYMMETRIC]:
+ if schema in [
+ QuantizationScheme.ASYMMETRIC_LORA,
+ QuantizationScheme.ASYMMETRIC_LORA_NLS,
+ QuantizationScheme.ASYMMETRIC,
+ ]:
zero_point = compressed_weight.zero_point.data
dtype = quantizer.input_low.dtype
# NOTE: Lose some accuracy, because of inversion of round
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index 41a90b212cd..6b69b654657 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -577,8 +577,8 @@ def compress_weights(
msg = "TorchFX does not supports statistics caching."
raise nncf.ParameterNotSupportedError(msg)
- if compression_format in [CompressionFormat.FQ, CompressionFormat.FQ_LORA]:
- msg = "Torch FX backend does not support FQ and FQ_LORA compression formats."
+ if compression_format in [CompressionFormat.FQ, CompressionFormat.FQ_LORA, CompressionFormat.FQ_LORA_NLS]:
+ msg = "Torch FX backend does not support FQ, FQ_LORA and FQ_LORA_NLS compression formats."
raise nncf.ParameterNotSupportedError(msg)
if (
@@ -606,8 +606,8 @@ def compress_weights(
msg = "Simultaneous use of Lora correction and GPTQ algorithms is not supported. Select one of them."
raise nncf.ParameterNotSupportedError(msg)
- if compression_format in [CompressionFormat.FQ, CompressionFormat.FQ_LORA]:
- msg = "OpenVINO backend does not support FQ and FQ_LORA compression formats."
+ if compression_format in [CompressionFormat.FQ, CompressionFormat.FQ_LORA, CompressionFormat.FQ_LORA_NLS]:
+ msg = "OpenVINO backend does not support FQ, FQ_LORA and FQ_LORA_NLS compression formats."
raise nncf.ParameterNotSupportedError(msg)
compression_weights_impl = ov_compress_weights_impl
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 2eab2c555ce..75f32d2e909 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -197,6 +197,27 @@ def get_state(self):
return {arg: getattr(self, arg) for arg in self._arg_names}
+class PTLoraNLSSpec(PTLoraSpec):
+ _arg_names = PTLoraSpec._arg_names + ["active_lora_rank"]
+
+ def __init__(
+ self,
+ lora_rank: int,
+ active_lora_rank: int,
+ orig_weight_shape: list[int],
+ weight_shape: list[int],
+ ):
+ """
+ :param lora_rank: The rank of the adapters.
+ :param active_lora_rank: The active rank of the adapters.
+ :param orig_weight_shape: The shape of the original weight tensor.
+ :param weight_shape: The shape of the weight tensor before applying quantization. In case of group-wise
+ quantization, weights are reshaped from [Cout, Cin] to [Cout, Cin // group_size, group_size].
+ """
+ super().__init__(lora_rank, orig_weight_shape, weight_shape)
+ self.active_lora_rank = active_lora_rank
+
+
class PTQPointStateNames:
QSPEC = "qspec"
TARGET_POINT = "target_point"
@@ -1103,6 +1124,39 @@ def get_adapters(self) -> dict[str, torch.Tensor]:
}
+class LoraNLSMixin(LoraMixin):
+ """
+ Represents learnable LoRA (Low-Rank Adaptation) adapters for quantization modules,
+    and uses the Neural Low-Rank Adapter Search (NLS) algorithm to make the adapters elastic.
+ """
+
+ def init_lora(self, lspec: PTLoraNLSSpec):
+ super().init_lora(lspec)
+ self.max_lora_rank = lspec.lora_rank
+ self.active_lora_rank = lspec.active_lora_rank
+
+ def set_active_rank(self, rank: int):
+ """
+ Set the active rank for the LoRA adapters.
+
+ :param rank: The rank to be set as active.
+ """
+ if rank > self.max_lora_rank:
+ msg = f"Activated rank {rank} cannot exceed the maximum LoRA rank {self.max_lora_rank}"
+ raise ValueError(msg)
+ self.active_lora_rank = rank
+
+ def get_active_adapters(self) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the currently active LoRA adapters.
+
+        :return: A tuple containing the active (sliced) LoRA A and B adapters.
+ """
+ lora_A = self.lora_A[: self.active_lora_rank, :]
+ lora_B = self.lora_B[:, : self.active_lora_rank]
+ return lora_A, lora_B
+
+
@COMPRESSION_MODULES.register()
@QUANTIZATION_MODULES.register(QuantizationMode.ASYMMETRIC_LORA)
class AsymmetricLoraQuantizer(AsymmetricQuantizer, LoraMixin):
@@ -1154,6 +1208,36 @@ def from_config(cls, state) -> "AsymmetricLoraQuantizer":
return cls(qspec, lspec)
+@COMPRESSION_MODULES.register()
+@QUANTIZATION_MODULES.register(QuantizationMode.ASYMMETRIC_LORA_NLS)
+class AsymmetricLoraNLSQuantizer(AsymmetricLoraQuantizer, LoraNLSMixin):
+ def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+ # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+ with DisableTorchFunction():
+ # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+ self.to(x.device)
+ lora_A, lora_B = self.get_active_adapters()
+ return asymmetric_quantize_lora(
+ x,
+ self._lspec.weight_shape,
+ lora_A,
+ lora_B,
+ self.input_low,
+ self.input_range,
+ self.level_low,
+ self.level_high,
+ self.levels,
+ self.eps,
+ skip=execute_traced_op_as_identity,
+ )
+
+ @classmethod
+ def from_config(cls, state) -> "AsymmetricLoraNLSQuantizer":
+ qspec = PTQuantizerSpec.from_state(state["qspec"])
+ lspec = PTLoraNLSSpec.from_state(state["lspec"])
+ return cls(qspec, lspec)
+
+
@COMPRESSION_MODULES.register()
@QUANTIZATION_MODULES.register(QuantizationMode.SYMMETRIC_LORA)
class SymmetricLoraQuantizer(SymmetricQuantizer, LoraMixin):
@@ -1202,6 +1286,35 @@ def from_config(cls, state) -> "SymmetricLoraQuantizer":
return cls(qspec, lspec)
+@COMPRESSION_MODULES.register()
+@QUANTIZATION_MODULES.register(QuantizationMode.SYMMETRIC_LORA_NLS)
+class SymmetricLoraNLSQuantizer(SymmetricLoraQuantizer, LoraNLSMixin):
+    def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+ # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+ with DisableTorchFunction():
+ # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+ self.to(x.device)
+ lora_A, lora_B = self.get_active_adapters()
+ return symmetric_quantize_lora(
+ x,
+ self._lspec.weight_shape,
+ lora_A,
+ lora_B,
+ self.scale,
+ self.level_low,
+ self.level_high,
+ self.levels,
+ self.eps,
+ skip=execute_traced_op_as_identity,
+ )
+
+ @classmethod
+ def from_config(cls, state) -> "SymmetricLoraNLSQuantizer":
+ qspec = PTQuantizerSpec.from_state(state["qspec"])
+ lspec = PTLoraNLSSpec.from_state(state["lspec"])
+ return cls(qspec, lspec)
+
+
def get_per_channel_scale_shape(input_shape, is_weights, channel_idx: Optional[int] = None) -> list[int]:
scale_shape = [1 for _ in input_shape]
if channel_idx is None:
diff --git a/tests/cross_fw/examples/.test_durations b/tests/cross_fw/examples/.test_durations
index e88198a394f..6e20fe44187 100644
--- a/tests/cross_fw/examples/.test_durations
+++ b/tests/cross_fw/examples/.test_durations
@@ -16,5 +16,6 @@
"tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_fx_resnet18]": 412.243,
"tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69,
"tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_tensorflow_mobilenet_v2]": 1500.00,
- "tests/cross_fw/examples/test_examples.py::test_examples[llm_compression_qat_with_lora]": 665
+ "tests/cross_fw/examples/test_examples.py::test_examples[llm_compression_qat_with_lora]": 665,
+ "tests/cross_fw/examples/test_examples.py::test_examples[llm_compression_qat_with_nls]": 1030
}
diff --git a/tests/cross_fw/examples/example_scope.json b/tests/cross_fw/examples/example_scope.json
index 93c19d583f7..2d4997addf0 100644
--- a/tests/cross_fw/examples/example_scope.json
+++ b/tests/cross_fw/examples/example_scope.json
@@ -285,6 +285,16 @@
"similarity_diff": 0.034
}
},
+ "llm_compression_qat_with_nls": {
+ "backend": "torch",
+ "device": "cuda",
+ "requirements": "examples/llm_compression/torch/qat_with_nls_downstream/requirements.txt",
+ "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
+ "accuracy_tolerance": 0.02,
+ "accuracy_metrics": {
+ "accuracy_diff": 0.03
+ }
+ },
"quantization_aware_training_tensorflow_mobilenet_v2": {
"backend": "tf",
"requirements": "examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt",
diff --git a/tests/cross_fw/examples/run_example.py b/tests/cross_fw/examples/run_example.py
index e9c2b4004c1..d31be84bb5b 100644
--- a/tests/cross_fw/examples/run_example.py
+++ b/tests/cross_fw/examples/run_example.py
@@ -213,6 +213,29 @@ def llm_compression_qat_with_lora() -> float:
return {"similarity_diff": similarity_diff}
+def llm_compression_qat_with_nls() -> float:
+ from examples.llm_compression.torch.qat_with_nls_downstream.main import main as qat_with_nls_main
+
+ set_torch_cuda_seed()
+
+ args = [
+ "--pretrained=HuggingFaceTB/SmolLM2-135M-Instruct",
+ "--do_train",
+ "--task=arc_challenge",
+ "--epochs=2",
+ "--batch_size=16",
+ "--lr=5e-4",
+ "--lora_rank_space",
+ "16",
+ "12",
+ "8",
+ ]
+
+ accuracy_diff = qat_with_nls_main(args)
+
+ return {"accuracy_diff": accuracy_diff}
+
+
def post_training_quantization_torch_fx_resnet18():
from examples.post_training_quantization.torch_fx.resnet18.main import main as resnet18_main