338 changes: 338 additions & 0 deletions src/llmcompressor/modifiers/awq/convert_autoawq.py
Collaborator

This may be more suitable for a higher scope, for example placing it into examples/ or tools/, or even directly into compressed-tensors, since it does not involve any source code in llmcompressor.

This is the first time we've added a feature like this; just posting here to see what the rest of the team thinks, and we can decide after that.

Collaborator

Agree, this is not a modifier and should probably be in tools/ or maybe utils/?

Contributor Author

Here's my take from a user-experience perspective. None of these are strong opinions, btw 😆:

  • Users with an AutoAWQ checkpoint would naturally visit llm-compressor's documentation site and look for things related to AWQ. Putting the conversion script under the AWQ modifier module improves discoverability.
  • Putting it somewhere else (examples/, tools/) and mentioning it in the AWQ modifier's description should also work, but adds an extra layer of maintenance effort if things are subject to change.
  • compressed-tensors has no documentation at the moment, so users would have to be redirected from warning messages, the AWQ modifier documentation, or the vLLM documentation to become aware of this tool.

Also, just out of curiosity, a couple of questions regarding the purpose of this tool:

  • vLLM seems to support serving AutoAWQ checkpoints [1]. Why do we need to convert the format? Is vLLM planning to drop that support and remove the AutoAWQ kernels?
  • If so, then mentioning this conversion tool in the vLLM docs and in the warning/error message makes the most sense.
  • Otherwise, what are the reasons and potential entry points where you expect users to need to convert AutoAWQ checkpoints?

Footnotes

  1. vLLM AutoAWQ page

@@ -0,0 +1,338 @@
"""
Convert AutoAWQ models to llmcompressor-compatible models.

This module offers the functionality to convert models quantized with AutoAWQ into
compressed models in llmcompressor's format, which can then be served with vLLM.
This module can be used as a CLI tool or as a Python API.

## CLI Usage

```sh
python -m llmcompressor.modifiers.awq.convert_autoawq \
--model-name-or-path /path/to/model \
--output-dir /path/to/compressed/model \
--quantization-format naive-quantized
```

For more information, run `python -m llmcompressor.modifiers.awq.convert_autoawq --help`
or refer to the `ConversionArgs` dataclass below.

## Python API Usage

```python
from llmcompressor.modifiers.awq.convert_autoawq import load_and_convert_from_autoawq

awq_model_path = "/path/to/model" # can also be model_id on huggingface hub
model = load_and_convert_from_autoawq(awq_model_path)
model.generate(...) # the converted model is now ready to be used.
```
"""

import glob
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Literal, cast

import torch
import transformers
from auto_round.export.export_to_awq.utils import (
reverse_awq_order,
unpack_awq,
)
from compressed_tensors import ModelCompressor
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationConfig,
QuantizationScheme,
QuantizationStatus,
QuantizationStrategy,
QuantizationType,
)
from huggingface_hub import load_state_dict_from_file, snapshot_download


def is_autoawq_model(model_path: Path, trust_remote_code: bool = False) -> bool:
config = transformers.AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code
)
if not hasattr(config, "quantization_config"):
return False

quantization_config = cast(dict[str, Any], config.quantization_config)
return quantization_config.get("quant_method") == "awq"


def resolve_model_path(model_name_or_path: str) -> Path:
if os.path.isdir(model_name_or_path):
return Path(model_name_or_path)
else:
# If the input is a model ID, download the model from the Hugging Face Hub and
# return the path to the local directory.
return Path(snapshot_download(model_name_or_path))


def load_state_dict_from_model_dir(model_path: Path) -> dict[str, torch.Tensor]:
weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
weight_files = glob.glob(str(model_path / "*.bin"))

state_dict = {}
for weight_file in weight_files:
state_dict.update(
load_state_dict_from_file(
weight_file, map_location="cpu", weights_only=True
)
)
return state_dict


def dequantize_gemm(
state_dict: dict[str, torch.Tensor], prefix: str, autoawq_config: dict[str, Any]
) -> None:
num_bits = cast(int, autoawq_config.get("bits"))
group_size = cast(int, autoawq_config.get("group_size"))

qweight = state_dict.pop(f"{prefix}.qweight")
scales = state_dict.pop(f"{prefix}.scales")
qzeros = state_dict.pop(f"{prefix}.qzeros")

def dequantize_gemm_original(
qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor,
bits: int,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Modified from auto_round.export.export_to_awq.utils.dequantize_gemm."""
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)

# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)

# fp16 weights
scales_interleaved = scales.repeat_interleave(group_size, dim=0)
izeros_interleaved = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros_interleaved) * scales_interleaved

return fweight, izeros

weight, zero_point = dequantize_gemm_original(
qweight, qzeros, scales, num_bits, group_size
)

# AutoAWQ uses [0, 2^bits - 1], e.g., [0, 15], for quantized weights, but
# compressed-tensors uses [-2^(bits - 1), 2^(bits - 1) - 1], e.g., [-8, 7].
# Therefore, we need to shift the zero point by 2^(bits - 1) to match the range
# of compressed-tensors and to allow correct quant/dequantization.
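# For example, with 4 bits an AutoAWQ zero point of 8 maps to 8 - 2**3 = 0 here.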
shifted_zero_point = zero_point - 2 ** (num_bits - 1)

state_dict.update(
{
f"{prefix}.weight": weight.T,
f"{prefix}.weight_scale": scales.T,
f"{prefix}.weight_zero_point": shifted_zero_point.T,
}
)


def dequantize_autoawq_state_dict(
state_dict: dict[str, torch.Tensor], autoawq_config: dict[str, Any]
) -> dict[str, torch.Tensor]:
version = cast(str, autoawq_config.get("version"))

# TODO: maybe add support for other versions?
match version:
case "gemm":
dequantize_fn = dequantize_gemm
case _:
raise ValueError(f"Unsupported version: {version}")

keys = list(state_dict.keys())
for key in filter(lambda k: k.endswith("qweight"), keys):
prefix = key.removesuffix(".qweight")
dequantize_fn(state_dict, prefix, autoawq_config)

return state_dict


def convert_and_save(
model_name_or_path: str,
output_dir: str,
quantization_format: str,
overwrite: bool = False,
trust_remote_code: bool = False,
) -> None:
"""Convert an AutoAWQ model to a compressed model and save it.

Steps:

1. Load the model weights directly.
2. Dequantize the weights accordingly.
3. Load the model with the dequantized weights.
4. Add the quantization parameters to the model.
5. Re-pack the weights using `ModelCompressor` with the correct configuration.
6. Save the model to the output directory.

:param model_name_or_path: Model ID on huggingface hub or path to local model.
:param output_dir: Path to save the converted model.
:param quantization_format: Compression format to be saved.
:param overwrite: Overwrite the existing output directory if it exists.
:param trust_remote_code: Whether to trust remote code.
"""
is_empty_dir = (
os.path.isdir(output_dir) and next(os.scandir(output_dir), None) is None
)
if not is_empty_dir and not overwrite:
raise FileExistsError(
f"Output directory {output_dir} already exists. Set `overwrite=True` to"
" overwrite the existing directory."
)

model_path = resolve_model_path(model_name_or_path)
if not is_autoawq_model(model_path, trust_remote_code):
raise ValueError("Model is not an AutoAWQ model")

config = transformers.AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code
)
autoawq_config = cast(dict[str, Any], config.quantization_config)
num_bits = cast(int, autoawq_config.get("bits"))
is_symmetric = not autoawq_config.get("zero_point")
group_size = cast(int, autoawq_config.get("group_size"))

# Convert AutoAWQ's substring-based ignore list to llm-compressor's regex format
# Usage in AutoAWQ:
# ```python
# if any(key in name for key in modules_to_not_convert): ...
# ```
# See https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/utils/module.py#L62
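# For example, a substring like "visual" in modules_to_not_convert becomes the regex "re:.*visual.*".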
modules_to_not_convert = autoawq_config.get("modules_to_not_convert", None)
ignore = []
if modules_to_not_convert is not None:
# Convert each substring pattern to a regex pattern that matches it anywhere
for module in modules_to_not_convert:
ignore.append(f"re:.*{re.escape(module)}.*")

ignore.append("lm_head") # AutoAWQ ignores lm_head by default

# 1. Load the model weights directly.
state_dict = load_state_dict_from_model_dir(model_path)

# 2. Dequantize the weights accordingly.
state_dict = dequantize_autoawq_state_dict(state_dict, autoawq_config)

# 3. Load the model with the dequantized weights.
del config.quantization_config # remove to avoid loading with AutoAWQ.
with transformers.modeling_utils.no_init_weights():
model = transformers.AutoModelForCausalLM.from_config(
config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
)

model.load_state_dict(state_dict, strict=False)

# 4. Add the quantization parameters to the model.
quantization_scheme = QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(
num_bits=num_bits,
type=QuantizationType.INT,
symmetric=is_symmetric,
group_size=group_size,
strategy=QuantizationStrategy.GROUP,
),
)

for key in filter(lambda k: k.endswith("weight_zero_point"), state_dict.keys()):
module_name = key.removesuffix(".weight_zero_point")
setattr(
model.get_submodule(module_name), "quantization_scheme", quantization_scheme
)

quant_config = QuantizationConfig(
config_groups={"group_0": quantization_scheme},
quant_method="compressed-tensors",
quantization_status=QuantizationStatus.COMPRESSED,
format=quantization_format,
ignore=ignore,
)

# 5. Re-pack the weights using `ModelCompressor`.
compressor = ModelCompressor(quantization_config=quant_config)
compressed_state_dict = compressor.compress(model, state_dict, show_progress=True)

# 6. Save the model.
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code
)
model.save_pretrained(output_dir, state_dict=compressed_state_dict)
tokenizer.save_pretrained(output_dir)
compressor.update_config(output_dir)


def load_and_convert_from_autoawq(
model_name_or_path: str,
quantization_format: str = "naive-quantized",
trust_remote_code: bool = False,
) -> transformers.modeling_utils.PreTrainedModel:
"""
Load an AutoAWQ checkpoint and convert it to a compressed model.

:param model_name_or_path: Model ID on huggingface hub or path to local model.
:param quantization_format: Compression format to be saved.
:param trust_remote_code: Whether to trust remote code.
:return: A compressed model.
"""
with TemporaryDirectory() as temp_dir:
convert_and_save(
model_name_or_path,
temp_dir,
quantization_format,
trust_remote_code=trust_remote_code,
)
return transformers.AutoModelForCausalLM.from_pretrained(
temp_dir, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
)


@dataclass
class ConversionArgs:
model_name_or_path: str = field(
metadata={"help": "Model ID on huggingface hub or path to local model."},
)
output_dir: str = field(
metadata={"help": "Path to save the converted model."},
)
quantization_format: Literal["naive-quantized", "packed-quantized"] = field(
default="naive-quantized",
metadata={"help": "Compression format to be saved."},
) # TODO: switch default to packed-quantized once supported by llm-compressor.
overwrite: bool = field(
default=False,
metadata={"help": "Overwrite the existing output directory if it exists."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust remote code."},
)


__all__ = ["convert_and_save", "load_and_convert_from_autoawq", "ConversionArgs"]


if __name__ == "__main__":
parser = transformers.HfArgumentParser(ConversionArgs)
args = parser.parse_args_into_dataclasses()[0]
convert_and_save(
args.model_name_or_path,
args.output_dir,
args.quantization_format,
args.overwrite,
args.trust_remote_code,
)
56 changes: 56 additions & 0 deletions tests/llmcompressor/modifiers/awq/test_convert_autoawq.py
@@ -0,0 +1,56 @@
from tempfile import TemporaryDirectory

from lm_eval.evaluator import simple_evaluate

from llmcompressor.modifiers.awq.convert_autoawq import convert_and_save
from tests.testing_utils import requires_gpu


def run_lm_eval(model_name_or_path: str):
results = simple_evaluate(
model="hf",
model_args=f"pretrained={model_name_or_path},dtype=float16",
tasks=["arc_challenge", "arc_easy"],
num_fewshot=5,
batch_size=16,
)

return results


def compare_models(model_name_or_path: str):
Collaborator

Running lm_eval can be expensive. When comparing models we just want to ensure the logits are the same for a given set of input_ids, so one way to make this cheaper would be:

# tokenizer, orig_model (AutoAWQ checkpoint), and new_model (converted) are assumed to be loaded already
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
orig_logits = orig_model.forward(input_ids=input_ids).logits
new_logits = new_model.forward(input_ids=input_ids).logits

# possible things to compare
print(f"Norm Diff {(orig_logits - new_logits).norm()}")
print(f"Norm MSE {torch.nn.MSELoss()(orig_logits, new_logits).norm()}")
print(f"Norm {orig_logits.norm()}, {new_logits.norm()}")

Contributor Author

Got it. Will work on this and see what I can do.

autoawq_result = run_lm_eval(model_name_or_path)
with TemporaryDirectory() as converted_model_dir:
convert_and_save(model_name_or_path, converted_model_dir, "naive-quantized")
converted_result = run_lm_eval(converted_model_dir)

arc_c_autoawq = autoawq_result["results"]["arc_challenge"]["acc_norm,none"]
arc_c_converted = converted_result["results"]["arc_challenge"]["acc_norm,none"]
arc_e_autoawq = autoawq_result["results"]["arc_easy"]["acc_norm,none"]
arc_e_converted = converted_result["results"]["arc_easy"]["acc_norm,none"]

assert abs(arc_e_autoawq - arc_e_converted) < 1e-2, (
f"Arc Easy: autoawq={arc_e_autoawq} != converted={arc_e_converted}."
)
assert abs(arc_c_autoawq - arc_c_converted) < 1e-2, (
f"Arc Challenge: autoawq={arc_c_autoawq} != converted={arc_c_converted}."
)


@requires_gpu
def test_mistral():
compare_models(
"fbaldassarri/mistralai_Mistral-7B-Instruct-v0.3-autoawq-int4-gs128-asym"
)


@requires_gpu
def test_qwen():
compare_models(
"ruikangliu/DeepSeek-R1-Distill-Qwen-1.5B-quantized.awq-autoawq-w4g128"
)


@requires_gpu
def test_llama():
compare_models("AMead10/Llama-3.2-3B-Instruct-AWQ")