Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/llmcompressor/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Conversion utilities for importing third-party quantized checkpoints."""

from llmcompressor.conversion.autoawq_to_ct import convert_autoawq_to_ct

__all__ = ["convert_autoawq_to_ct"]
5 changes: 5 additions & 0 deletions src/llmcompressor/conversion/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Allow ``python -m llmcompressor.conversion.autoawq_to_ct``."""
Copy link

Copilot AI Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module docstring says python -m llmcompressor.conversion.autoawq_to_ct, but src/llmcompressor/conversion/__main__.py is only used by python -m llmcompressor.conversion. Either update the docstring to reflect the actual invocation, or consider moving this entrypoint to autoawq_to_ct/__main__.py (or rely solely on if __name__ == '__main__' already present in autoawq_to_ct.py).

Suggested change
"""Allow ``python -m llmcompressor.conversion.autoawq_to_ct``."""
"""Allow ``python -m llmcompressor.conversion``."""

Copilot uses AI. Check for mistakes.

from llmcompressor.conversion.autoawq_to_ct import main

main()
372 changes: 372 additions & 0 deletions src/llmcompressor/conversion/autoawq_to_ct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,372 @@
"""
Conversion tool for converting AutoAWQ quantized checkpoints to the
compressed-tensors format (``pack_quantized`` compressor).

AutoAWQ stores int4 weights in int32 tensors with an interleaved packing
order ``[0, 2, 4, 6, 1, 3, 5, 7]``, while compressed-tensors uses the
sequential order ``[0, 1, 2, 3, 4, 5, 6, 7]``. This module handles the
re-packing and metadata generation so the output model can be loaded
directly by vLLM.

Usage (CLI)::

python -m llmcompressor.conversion.autoawq_to_ct \\
--model-path /path/to/autoawq-model \\
--output-path /path/to/output \\
--num-bits 4 --group-size 128

Usage (Python API)::

from llmcompressor.conversion.autoawq_to_ct import convert_autoawq_to_ct

convert_autoawq_to_ct(
model_path="/path/to/autoawq-model",
output_path="/path/to/output",
)
"""

from __future__ import annotations

import argparse
import json
import logging
import shutil
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

logger = logging.getLogger(__name__)

__all__ = ["convert_autoawq_to_ct"]

# AutoAWQ packs 8 int4 values into int32 using the interleaved order
# ``[0, 2, 4, 6, 1, 3, 5, 7]``. The *inverse* permutation
# ``[0, 4, 1, 5, 2, 6, 3, 7]`` maps bit-positions back to the
# sequential column indices expected by compressed-tensors.
_AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
_AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]  # inverse of _AWQ_ORDER

# AutoAWQ tensor suffix → compressed-tensors tensor suffix.
# Used by ``_rename_key`` when rewriting the safetensors index file.
_KEY_MAP = {
    ".qweight": ".weight_packed",
    ".scales": ".weight_scale",
    ".qzeros": ".weight_zero_point",
}


# ---------------------------------------------------------------------------
# Weight conversion helpers
# ---------------------------------------------------------------------------


def _unpack_awq_int4(packed: torch.Tensor) -> torch.Tensor:
"""Unpack int4 values from AutoAWQ's **interleaved** int32 packing.

AutoAWQ's ``gemm_pack`` packs 8 int4 values per int32 using::

order_map = [0, 2, 4, 6, 1, 3, 5, 7]
for i in range(8):
int_weight[:, col] |= weight[:, col*8 + order_map[i]] << (i * 4)

This function reverses that process and returns values in the natural
(sequential) column order.

:param packed: ``(rows, cols // 8)`` int32 tensor.
:return: ``(rows, cols)`` int32 tensor with values in ``[0, 15]``.
"""
rows, packed_cols = packed.shape
cols = packed_cols * 8

# Step 1: extract the 8 nibbles stored at each bit-position.
raw = torch.zeros((rows, cols), dtype=torch.int32, device=packed.device)
for bit_pos in range(8):
raw[:, bit_pos::8] = (packed >> (bit_pos * 4)) & 0xF

# Step 2: undo the interleaving.
# Bit-position ``bit_pos`` holds the original column ``_AWQ_ORDER[bit_pos]``
# within each group of 8. We scatter back to sequential order.
result = torch.zeros_like(raw)
for seq_idx, bit_pos in enumerate(_AWQ_REVERSE_ORDER):
result[:, seq_idx::8] = raw[:, bit_pos::8]

return result


def _pack_ct_int4(values: torch.Tensor) -> torch.Tensor:
"""Pack int4 values into compressed-tensors' **sequential** int32 format.

compressed-tensors stores 8 int4 values per int32 in natural order:
``value[i]`` occupies bits ``4*i … 4*i+3``.

:param values: ``(rows, cols)`` int32 tensor (each element in ``[0, 15]``).
:return: ``(rows, cols // 8)`` int32 tensor.
"""
rows, cols = values.shape
if cols % 8 != 0:
raise ValueError(f"columns must be divisible by 8, got {cols}")

packed = torch.zeros(
(rows, cols // 8), dtype=torch.int32, device=values.device
)
for i in range(8):
packed |= (values[:, i::8] & 0xF).to(torch.int32) << (i * 4)
return packed


def _repack_awq_to_ct(packed_awq: torch.Tensor) -> torch.Tensor:
"""One-shot conversion: AWQ-packed int32 → CT-packed int32."""
return _pack_ct_int4(_unpack_awq_int4(packed_awq))


Comment on lines +120 to +124
Copy link

Copilot AI Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_repack_awq_to_ct expands the packed int32 weights into a full int32 matrix in _unpack_awq_int4 (8× more elements) and then repacks. For large models this intermediate can be tens of GB and may OOM CPU RAM. Consider repacking by extracting/reordering nibbles within the packed int32 tensor (staying in the packed shape) or using a chunked/streaming approach to cap peak memory.

Suggested change
def _repack_awq_to_ct(packed_awq: torch.Tensor) -> torch.Tensor:
"""One-shot conversion: AWQ-packed int32 → CT-packed int32."""
return _pack_ct_int4(_unpack_awq_int4(packed_awq))
def _repack_awq_to_ct(
packed_awq: torch.Tensor,
max_chunk_bytes: int = 256 * 1024 * 1024,
) -> torch.Tensor:
"""Convert AWQ-packed int32CT-packed int32 with bounded peak memory.
The naive implementation would unpack the entire tensor to an 8× larger
int32 matrix and then repack it. For large models this can require tens
of GB of RAM. To avoid that, we process the tensor in row-wise chunks:
each chunk is unpacked and repacked independently, and the intermediate
is immediately discarded.
:param packed_awq: AWQ-packed int32 tensor of shape ``(rows, cols_packed)``.
:param max_chunk_bytes: Approximate upper bound on the size of the
unpacked intermediate per chunk, in bytes.
:return: CT-packed int32 tensor with the same shape as ``packed_awq``.
"""
if packed_awq.dim() != 2:
# Keep behavior simple and explicit: this helper is for 2D weight
# matrices. If other shapes are needed, they should be reshaped by
# the caller.
raise ValueError(
f"_repack_awq_to_ct expects a 2D tensor, got shape {tuple(packed_awq.shape)}"
)
rows, cols_packed = packed_awq.shape
if rows == 0 or cols_packed == 0:
return packed_awq.clone()
# Each packed column expands to 8 int32 values in the unpacked matrix.
cols_unpacked = cols_packed * 8
bytes_per_row_unpacked = cols_unpacked * 4 # int32 = 4 bytes
# Compute a chunk size (number of rows) that keeps the unpacked
# intermediate for a chunk under `max_chunk_bytes`. Always process at
# least one row.
max_rows_per_chunk = max(1, max_chunk_bytes // max(bytes_per_row_unpacked, 1))
# Preallocate output tensor in the packed CT layout.
packed_ct = torch.empty_like(packed_awq)
for start in range(0, rows, max_rows_per_chunk):
end = min(rows, start + max_rows_per_chunk)
# Slice the current chunk of rows, convert layout, and write back.
chunk_packed_awq = packed_awq[start:end]
chunk_unpacked = _unpack_awq_int4(chunk_packed_awq)
chunk_packed_ct = _pack_ct_int4(chunk_unpacked)
packed_ct[start:end] = chunk_packed_ct
return packed_ct

Copilot uses AI. Check for mistakes.
# ---------------------------------------------------------------------------
# Key renaming helpers
# ---------------------------------------------------------------------------


def _rename_key(key: str, awq_prefixes: set[str]) -> str:
    """Return the compressed-tensors key name for *key*, or *key* unchanged.

    A key is renamed only when it starts with one of *awq_prefixes* and
    its remaining suffix is a known AWQ tensor name (see ``_KEY_MAP``).
    """
    for prefix in awq_prefixes:
        if not key.startswith(prefix):
            continue
        new_suffix = _KEY_MAP.get(key[len(prefix):])
        if new_suffix is not None:
            return prefix + new_suffix
    return key


# ---------------------------------------------------------------------------
# Main conversion
# ---------------------------------------------------------------------------


def convert_autoawq_to_ct(
    model_path: str | Path,
    output_path: str | Path,
    num_bits: int = 4,
    group_size: int = 128,
    symmetric: bool = False,
) -> None:
    """Convert an AutoAWQ checkpoint to the compressed-tensors ``pack_quantized``
    format so that the resulting model can be loaded directly in vLLM.

    .. note::
        When the source ``config.json`` contains a ``quantization_config``
        section, the values stored there override the ``num_bits`` /
        ``group_size`` / ``symmetric`` arguments — checkpoint metadata is
        treated as authoritative.

    :param model_path: directory containing the AutoAWQ model.
    :param output_path: destination directory for the converted model.
    :param num_bits: quantization bit-width (default 4).
    :param group_size: quantization group size (default 128).
    :param symmetric: ``True`` for symmetric quantisation (AutoAWQ default
        is *asymmetric*, i.e. ``False``).
    """
    model_path = Path(model_path)
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info("Converting AutoAWQ model: %s → %s", model_path, output_path)

    # ----- Load model config -----
    config = AutoConfig.from_pretrained(model_path)
    awq_config = getattr(config, "quantization_config", None)
    if awq_config and isinstance(awq_config, dict):
        num_bits = awq_config.get("bits", num_bits)
        group_size = awq_config.get("group_size", group_size)
        # AutoAWQ uses ``zero_point: True`` to indicate *asymmetric* quant.
        symmetric = not awq_config.get("zero_point", True)
    logger.info(
        "Quantisation params: bits=%d group_size=%d symmetric=%s",
        num_bits, group_size, symmetric,
    )

    # ----- Discover safetensors shards -----
    st_files = sorted(model_path.glob("*.safetensors"))
    if not st_files:
        raise FileNotFoundError(
            f"No .safetensors files in {model_path}. "
            "Make sure the model was saved in safetensors format."
        )
    logger.info("Found %d safetensors shard(s)", len(st_files))

    # All AWQ quantised layer prefixes seen anywhere in the checkpoint,
    # needed later to rewrite the safetensors index file.
    all_awq_prefixes: set[str] = set()

    # ----- Convert each shard -----
    for st_file in tqdm(st_files, desc="Converting shards"):
        converted: dict[str, torch.Tensor] = {}

        with safe_open(str(st_file), framework="pt", device="cpu") as f:
            keys = list(f.keys())

            for key in tqdm(keys, desc=f" {st_file.name}", leave=False):
                tensor = f.get_tensor(key)

                # Classify each tensor by its suffix (O(1) per key). This
                # also converts ``.scales`` / ``.qzeros`` whose matching
                # ``.qweight`` lives in a *different* shard, which a
                # per-shard prefix scan would silently pass through while
                # the index rewrite still renamed them (broken checkpoint).
                if key.endswith(".qweight"):
                    prefix = key.removesuffix(".qweight")
                    all_awq_prefixes.add(prefix)
                    converted[f"{prefix}.weight_packed"] = (
                        _repack_awq_to_ct(tensor)
                    )
                elif key.endswith(".scales"):
                    prefix = key.removesuffix(".scales")
                    all_awq_prefixes.add(prefix)
                    converted[f"{prefix}.weight_scale"] = tensor
                elif key.endswith(".qzeros"):
                    # Zero-points are also packed with the AWQ interleave.
                    prefix = key.removesuffix(".qzeros")
                    all_awq_prefixes.add(prefix)
                    converted[f"{prefix}.weight_zero_point"] = (
                        _unpack_awq_int4(tensor)
                    )
                else:
                    # Non-quantised parameter (bias, norms, embeddings, …)
                    # passes through unchanged.
                    converted[key] = tensor

        save_file(converted, str(output_path / st_file.name))

    # ----- Build compressed-tensors quantization_config -----
    # NOTE(review): this dict is assembled by hand; confirm it matches the
    # schema expected by downstream compressed-tensors loaders.
    strategy = "group" if group_size > 0 else "channel"
    quant_config = {
        "quant_method": "compressed-tensors",
        "format": "pack_quantized",
        "global_compression_ratio": None,
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": num_bits,
                    "type": "int",
                    "symmetric": symmetric,
                    "strategy": strategy,
                    "group_size": group_size if group_size > 0 else None,
                },
                "input_activations": None,
                "output_activations": None,
            }
        },
        "ignore": ["lm_head"],
    }

    # ----- Write config.json -----
    config_dict = config.to_dict()
    config_dict["quantization_config"] = quant_config
    with open(output_path / "config.json", "w") as fp:
        json.dump(config_dict, fp, indent=2)
    logger.info("Wrote config.json with compressed-tensors quantization_config")

    # ----- Tokenizer -----
    # Best-effort: a missing or broken tokenizer should not abort the
    # weight conversion.
    try:
        tok = AutoTokenizer.from_pretrained(model_path)
        tok.save_pretrained(output_path)
        logger.info("Saved tokenizer")
    except Exception as exc:
        logger.warning("Could not copy tokenizer: %s", exc)

    # ----- Rewrite safetensors index (multi-shard models) -----
    for idx_file in model_path.glob("*.safetensors.index.json"):
        with open(idx_file) as fp:
            index = json.load(fp)

        new_map: dict[str, str] = {}
        for old_key, shard_name in index.get("weight_map", {}).items():
            new_map[_rename_key(old_key, all_awq_prefixes)] = shard_name
        index["weight_map"] = new_map

        with open(output_path / idx_file.name, "w") as fp:
            json.dump(index, fp, indent=2)
        logger.info("Rewrote %s", idx_file.name)

    # ----- Copy any remaining auxiliary files -----
    _auxiliary_globs = [
        "generation_config.json",
        "special_tokens_map.json",
        "merges.txt",
    ]
    for pattern in _auxiliary_globs:
        for src in model_path.glob(pattern):
            dst = output_path / src.name
            if not dst.exists():
                shutil.copy2(src, dst)

    logger.info("Conversion complete.")


# ---------------------------------------------------------------------------
# CLI entry-point
# ---------------------------------------------------------------------------


def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description=(
"Convert an AutoAWQ quantized model checkpoint to the "
"compressed-tensors pack_quantized format."
),
)
parser.add_argument(
"--model-path",
type=str,
required=True,
help="Path to the AutoAWQ model directory.",
)
parser.add_argument(
"--output-path",
type=str,
required=True,
help="Destination directory for the converted model.",
)
parser.add_argument(
"--num-bits",
type=int,
default=4,
help="Quantization bit-width (default: 4).",
)
parser.add_argument(
"--group-size",
type=int,
default=128,
help="Quantization group size (default: 128).",
)
parser.add_argument(
"--symmetric",
action="store_true",
default=False,
help="Treat quantisation as symmetric (default: asymmetric).",
)
return parser


def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse *argv* (``sys.argv`` if ``None``) and convert."""
    parsed = _build_parser().parse_args(argv)
    logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
    # The parser's destinations mirror convert_autoawq_to_ct's keyword
    # arguments exactly, so the namespace can be forwarded wholesale.
    convert_autoawq_to_ct(**vars(parsed))


if __name__ == "__main__":
main()
Empty file.
Loading
Loading