Convert AutoAWQ checkpoints to compressed-tensors #2112
base: main
Changes from 9 commits
New file: `convert_autoawq.py` (module `llmcompressor.modifiers.awq.convert_autoawq`, 338 added lines)

````python
| """ | ||
| Convert AutoAWQ models to llmcompressor-compatible models. | ||
|
|
||
| This module offers the functionality to convert models quantized with AutoAWQ into | ||
| compressed models in llmcompressor's format, which can then be served with vLLM. | ||
| This module can be used as a CLI tool or as a Python API. | ||
|
|
||
| ## CLI Usage | ||
|
|
||
| ```sh | ||
| python -m llmcompressor.modifiers.awq.convert_autoawq \ | ||
| --model-name-or-path /path/to/model \ | ||
| --output-dir /path/to/compressed/model \ | ||
| --quantization-format naive-quantized | ||
| ``` | ||
|
|
||
| For more information, run `python -m llmcompressor.modifiers.awq.convert_autoawq --help` | ||
| or refer to the `ConversionArgs` dataclass below. | ||
|
|
||
| ## Python API Usage | ||
|
|
||
| ```python | ||
| from llmcompressor.modifiers.awq.convert_autoawq import load_and_convert_from_autoawq | ||
|
|
||
| awq_model_path = "/path/to/model" # can also be model_id on huggingface hub | ||
| model = load_and_convert_from_autoawq(awq_model_path) | ||
| model.generate(...) # the converted model is now ready to be used. | ||
| ``` | ||
| """ | ||
|
|
||
| import glob | ||
| import os | ||
| import re | ||
| from dataclasses import dataclass, field | ||
| from pathlib import Path | ||
| from tempfile import TemporaryDirectory | ||
| from typing import Any, Literal, cast | ||
|
|
||
| import torch | ||
| import transformers | ||
| from auto_round.export.export_to_awq.utils import ( | ||
| reverse_awq_order, | ||
| unpack_awq, | ||
| ) | ||
| from compressed_tensors import ModelCompressor | ||
| from compressed_tensors.quantization import ( | ||
| QuantizationArgs, | ||
| QuantizationConfig, | ||
| QuantizationScheme, | ||
| QuantizationStatus, | ||
| QuantizationStrategy, | ||
| QuantizationType, | ||
| ) | ||
| from huggingface_hub import load_state_dict_from_file, snapshot_download | ||
|
|
||
|
|
||
| def is_autoawq_model(model_path: Path, trust_remote_code: bool = False) -> bool: | ||
| config = transformers.AutoConfig.from_pretrained( | ||
| model_path, trust_remote_code=trust_remote_code | ||
| ) | ||
| if not hasattr(config, "quantization_config"): | ||
| return False | ||
|
|
||
| quantization_config = cast(dict[str, Any], config.quantization_config) | ||
| return quantization_config.get("quant_method") == "awq" | ||
|
|
||
|
|
||
| def resolve_model_path(model_name_or_path: str) -> Path: | ||
| if os.path.isdir(model_name_or_path): | ||
| return Path(model_name_or_path) | ||
| else: | ||
| # If the input is a model ID, download the model from the Hugging Face Hub and | ||
| # return the path to the local directory. | ||
| return Path(snapshot_download(model_name_or_path)) | ||
|
|
||
|
|
||
| def load_state_dict_from_model_dir(model_path: Path) -> dict[str, torch.Tensor]: | ||
| weight_files = glob.glob(str(model_path / "*.safetensors")) | ||
| if not weight_files: | ||
| weight_files = glob.glob(str(model_path / "*.bin")) | ||
|
|
||
| state_dict = {} | ||
| for weight_file in weight_files: | ||
| state_dict.update( | ||
| load_state_dict_from_file( | ||
| weight_file, map_location="cpu", weights_only=True | ||
| ) | ||
| ) | ||
| return state_dict | ||
|
|
||
|
|
||
| def dequantize_gemm( | ||
| state_dict: dict[str, torch.Tensor], prefix: str, autoawq_config: dict[str, Any] | ||
| ) -> None: | ||
| num_bits = cast(int, autoawq_config.get("bits")) | ||
| group_size = cast(int, autoawq_config.get("group_size")) | ||
|
|
||
| qweight = state_dict.pop(f"{prefix}.qweight") | ||
| scales = state_dict.pop(f"{prefix}.scales") | ||
| qzeros = state_dict.pop(f"{prefix}.qzeros") | ||
|
|
||
| def dequantize_gemm_original( | ||
| qweight: torch.Tensor, | ||
| qzeros: torch.Tensor, | ||
| scales: torch.Tensor, | ||
| bits: int, | ||
| group_size: int, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Modified from auto_round.export.export_to_awq.utils.dequantize_gemm.""" | ||
| # Unpack the qweight and qzeros tensors | ||
| iweight, izeros = unpack_awq(qweight, qzeros, bits) | ||
| # Reverse the order of the iweight and izeros tensors | ||
| iweight, izeros = reverse_awq_order(iweight, izeros, bits) | ||
|
|
||
| # overflow checks | ||
| iweight = torch.bitwise_and(iweight, (2**bits) - 1) | ||
| izeros = torch.bitwise_and(izeros, (2**bits) - 1) | ||
|
|
||
| # fp16 weights | ||
| scales_interleaved = scales.repeat_interleave(group_size, dim=0) | ||
| izeros_interleaved = izeros.repeat_interleave(group_size, dim=0) | ||
| fweight = (iweight - izeros_interleaved) * scales_interleaved | ||
mutichung marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return fweight, izeros | ||
|
|
||
| weight, zero_point = dequantize_gemm_original( | ||
| qweight, qzeros, scales, num_bits, group_size | ||
| ) | ||
|
|
||
| # AutoAWQ uses [0, 2^bits - 1], e.g., [0, 15], for quantized weights, but | ||
| # compressed-tensors uses [-2^(bits - 1), 2^(bits - 1) - 1], e.g., [-8, 7]. | ||
| # Therefore, we need to shift the zero point by 2^(bits - 1) to match the range | ||
| # of compressed-tensors and to allow correct quant/dequantization. | ||
| shifted_zero_point = zero_point - 2 ** (num_bits - 1) | ||
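    # For example, with num_bits=4 the AutoAWQ zero points lie in [0, 15], and
    # subtracting 2 ** 3 = 8 maps a stored zero point of 11 to 3 in [-8, 7].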

    state_dict.update(
        {
            f"{prefix}.weight": weight.T,
            f"{prefix}.weight_scale": scales.T,
            f"{prefix}.weight_zero_point": shifted_zero_point.T,
        }
    )


def dequantize_autoawq_state_dict(
    state_dict: dict[str, torch.Tensor], autoawq_config: dict[str, Any]
) -> dict[str, torch.Tensor]:
    version = cast(str, autoawq_config.get("version"))

    # TODO: maybe add support for other versions?
    match version:
        case "gemm":
            dequantize_fn = dequantize_gemm
        case _:
            raise ValueError(f"Unsupported version: {version}")

    keys = list(state_dict.keys())
    for key in filter(lambda k: k.endswith("qweight"), keys):
        prefix = key.removesuffix(".qweight")
        dequantize_fn(state_dict, prefix, autoawq_config)

    return state_dict


def convert_and_save(
    model_name_or_path: str,
    output_dir: str,
    quantization_format: str,
    overwrite: bool = False,
    trust_remote_code: bool = False,
) -> None:
    """Convert an AutoAWQ model to a compressed model and save it.

    Steps:

    1. Load the model weights directly.
    2. Dequantize the weights accordingly.
    3. Load the model with the dequantized weights.
    4. Add the quantization parameters to the model.
    5. Re-pack the weights using `ModelCompressor` with the correct configuration.
    6. Save the model to the output directory.

    :param model_name_or_path: Model ID on huggingface hub or path to local model.
    :param output_dir: Path to save the converted model.
    :param quantization_format: Compression format to be saved.
    :param overwrite: Overwrite the existing output directory if it exists.
    :param trust_remote_code: Whether to trust remote code.
    """
    is_empty_dir = (
        os.path.isdir(output_dir) and next(os.scandir(output_dir), None) is None
    )
    if not is_empty_dir and not overwrite:
        raise FileExistsError(
            f"Output directory {output_dir} already exists. Set `overwrite=True` to"
            " overwrite the existing directory."
        )

    model_path = resolve_model_path(model_name_or_path)
    if not is_autoawq_model(model_path, trust_remote_code):
        raise ValueError("Model is not an AutoAWQ model")

    config = transformers.AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    autoawq_config = cast(dict[str, Any], config.quantization_config)
    num_bits = cast(int, autoawq_config.get("bits"))
    is_symmetric = not autoawq_config.get("zero_point")
    group_size = cast(int, autoawq_config.get("group_size"))

    # Convert AutoAWQ's substring-based ignore list to llm-compressor's regex format
    # Usage in AutoAWQ:
    # ```python
    # if any(key in name for key in modules_to_not_convert): ...
    # ```
    # See https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/utils/module.py#L62
    modules_to_not_convert = autoawq_config.get("modules_to_not_convert", None)
    ignore = []
    if modules_to_not_convert is not None:
        # Convert each substring pattern to a regex pattern that matches it anywhere
        for module in modules_to_not_convert:
            ignore.append(f"re:.*{re.escape(module)}.*")

    ignore.append("lm_head")  # AutoAWQ ignores lm_head by default

    # 1. Load the model weights directly.
    state_dict = load_state_dict_from_model_dir(model_path)

    # 2. Dequantize the weights accordingly.
    state_dict = dequantize_autoawq_state_dict(state_dict, autoawq_config)

    # 3. Load the model with the dequantized weights.
    del config.quantization_config  # remove to avoid loading with AutoAWQ.
    with transformers.modeling_utils.no_init_weights():
        model = transformers.AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )

    model.load_state_dict(state_dict, strict=False)

    # 4. Add the quantization parameters to the model.
    quantization_scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=num_bits,
            type=QuantizationType.INT,
            symmetric=is_symmetric,
            group_size=group_size,
            strategy=QuantizationStrategy.GROUP,
        ),
    )

    for key in filter(lambda k: k.endswith("weight_zero_point"), state_dict.keys()):
        module_name = key.removesuffix(".weight_zero_point")
        setattr(
            model.get_submodule(module_name), "quantization_scheme", quantization_scheme
        )

    quant_config = QuantizationConfig(
        config_groups={"group_0": quantization_scheme},
        quant_method="compressed-tensors",
        quantization_status=QuantizationStatus.COMPRESSED,
        format=quantization_format,
        ignore=ignore,
    )

    # 5. Re-pack the weights using `ModelCompressor`.
    compressor = ModelCompressor(quantization_config=quant_config)
    compressed_state_dict = compressor.compress(model, state_dict, show_progress=True)

    # 6. Save the model.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(output_dir, state_dict=compressed_state_dict)
    tokenizer.save_pretrained(output_dir)
    compressor.update_config(output_dir)


def load_and_convert_from_autoawq(
    model_name_or_path: str,
    quantization_format: str = "naive-quantized",
    trust_remote_code: bool = False,
) -> transformers.modeling_utils.PreTrainedModel:
    """
    Load an AutoAWQ checkpoint and convert it to a compressed model.

    :param model_name_or_path: Model ID on huggingface hub or path to local model.
    :param quantization_format: Compression format to be saved.
    :param trust_remote_code: Whether to trust remote code.
    :return: A compressed model.
    """
    with TemporaryDirectory() as temp_dir:
        convert_and_save(
            model_name_or_path,
            temp_dir,
            quantization_format,
            trust_remote_code=trust_remote_code,
        )
        return transformers.AutoModelForCausalLM.from_pretrained(
            temp_dir, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )


@dataclass
class ConversionArgs:
    model_name_or_path: str = field(
        metadata={"help": "Model ID on huggingface hub or path to local model."},
    )
    output_dir: str = field(
        metadata={"help": "Path to save the converted model."},
    )
    quantization_format: Literal["naive-quantized", "packed-quantized"] = field(
        default="naive-quantized",
        metadata={"help": "Compression format to be saved."},
    )  # TODO: switch default to packed-quantized once supported by llm-compressor.
    overwrite: bool = field(
        default=False,
        metadata={"help": "Overwrite the existing output directory if it exists."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code."},
    )


__all__ = ["convert_and_save", "load_and_convert_from_autoawq", "ConversionArgs"]


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(ConversionArgs)
    args = parser.parse_args_into_dataclasses()[0]
    convert_and_save(
        args.model_name_or_path,
        args.output_dir,
        args.quantization_format,
        args.overwrite,
        args.trust_remote_code,
    )
````
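As the module docstring notes, the converted checkpoint is meant to be served with vLLM. A minimal, untested sketch of what loading it might look like; the checkpoint path is a placeholder, and vLLM with compressed-tensors support is assumed to be installed:

```python
from vllm import LLM, SamplingParams

# Placeholder path: the output_dir previously passed to convert_and_save / the CLI.
llm = LLM(model="/path/to/compressed/model")
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```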
New test file (56 added lines)

```python
from tempfile import TemporaryDirectory

from lm_eval.evaluator import simple_evaluate

from llmcompressor.modifiers.awq.convert_autoawq import convert_and_save
from tests.testing_utils import requires_gpu


def run_lm_eval(model_name_or_path: str):
    results = simple_evaluate(
        model="hf",
        model_args=f"pretrained={model_name_or_path},dtype=float16",
        tasks=["arc_challenge", "arc_easy"],
        num_fewshot=5,
        batch_size=16,
    )

    return results


def compare_models(model_name_or_path: str):
```

Collaborator:

Running lm_eval can be expensive. When comparing models we just want to ensure the logits are the same for a given set of input_ids, so one way to make this cheaper would be:

```python
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
orig_logits = orig_model.forward(input_ids=input_ids).logits
new_logits = new_model.forward(input_ids=input_ids).logits
# possible things to compare
print(f"Norm Diff {(orig_logits-new_logits).norm()}")
print(f"Norm MSE {torch.nn.MSELoss()(orig_logits,new_logits).norm()}")
print(f"Norm {orig_logits.norm()}, {new_logits.norm()}")
```

Contributor (Author):

Got it. Will work on this and see what I can do.
```python
    autoawq_result = run_lm_eval(model_name_or_path)
    with TemporaryDirectory() as converted_model_dir:
        convert_and_save(model_name_or_path, converted_model_dir, "naive-quantized")
        converted_result = run_lm_eval(converted_model_dir)

    arc_c_autoawq = autoawq_result["results"]["arc_challenge"]["acc_norm,none"]
    arc_c_converted = converted_result["results"]["arc_challenge"]["acc_norm,none"]
    arc_e_autoawq = autoawq_result["results"]["arc_easy"]["acc_norm,none"]
    arc_e_converted = converted_result["results"]["arc_easy"]["acc_norm,none"]

    assert abs(arc_e_autoawq - arc_e_converted) < 1e-2, (
        f"Arc Easy: autoawq={arc_e_autoawq} != converted={arc_e_converted}."
    )
    assert abs(arc_c_autoawq - arc_c_converted) < 1e-2, (
        f"Arc Challenge: autoawq={arc_c_autoawq} != converted={arc_c_converted}."
    )


@requires_gpu
def test_mistral():
    compare_models(
        "fbaldassarri/mistralai_Mistral-7B-Instruct-v0.3-autoawq-int4-gs128-asym"
    )


@requires_gpu
def test_qwen():
    compare_models(
        "ruikangliu/DeepSeek-R1-Distill-Qwen-1.5B-quantized.awq-autoawq-w4g128"
    )


@requires_gpu
def test_llama():
    compare_models("AMead10/Llama-3.2-3B-Instruct-AWQ")
```
Collaborator:

This may be more suitable for a higher scope, for example placing it into examples/ or tools/, or even directly into compressed-tensors, as it does not involve any source code in llmcompressor. This is the first time we've added a feature like this; just posting here to see what the rest of the team thinks, and we can decide after that.

Collaborator:

Agree, this is not a modifier and should probably be in tools or maybe utils?

Contributor (Author):

Here's my take from a user experience perspective. None of them are strong opinions btw 😆:

- Users would come to `llm-compressor` and look for things related to AWQ. Putting the conversion script under the AWQ modifier module improves discoverability.
- A higher scope (`examples/`, `tools/`) plus a mention in the AWQ modifier's description should also work, but it adds an extra layer of maintenance effort if things are subject to change.
- `compressed-tensors` does not have any documentation at the moment. Users would have to be redirected from warning messages, the AWQ modifier documentation, or the vLLM documentation to be aware of the existence of this tool.

Also, just curious, here's a question regarding the purpose of this tool:

Footnotes

1. vLLM AutoAWQ page ↩