103 changes: 103 additions & 0 deletions src/python/py/models/builder.py
@@ -302,6 +302,7 @@ def __init__(
"algo_config": int4_algo_config,
},
"use_qdq": extra_options.get("use_qdq", False),
"use_tmac": extra_options.get("use_tmac", False)
}
if self.quant_type is not None:
# Create quantized attributes from quantization config
@@ -358,6 +359,7 @@ def make_attention_init(self):
and not self.matmul_attrs["use_lora"]
and not self.attention_attrs["q_norm"]
and not self.attention_attrs["k_norm"]
and not self.quant_attrs["use_tmac"]
)

# Some EPs don't support fusing rotary embeddings inside GQA yet
@@ -799,6 +801,8 @@ def make_matmul(self, matmul, basename, root_input, **kwargs):
def make_matmul_op(self, matmul, basename, root_input, **kwargs):
if self.onnx_dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16, ir.DataType.FLOAT}:
return self.make_matmul_float(matmul, basename, root_input, **kwargs)
elif self.quant_attrs["use_tmac"]:
return self.make_matmul_tmac(matmul, basename, root_input, **kwargs)
elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
if self.quant_attrs["use_qdq"]:
return self.make_matmul_int4_qdq(matmul, basename, root_input, **kwargs)
@@ -818,6 +822,87 @@ def make_matmul_float(self, matmul, name, root_input, **kwargs):

return name

def make_matmul_tmac(self, matmul, basename, root_input, **kwargs):
if not hasattr(matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
# print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
# print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
return self.make_matmul_float(matmul, basename, root_input, **kwargs)

if "q_proj" in basename or "k_proj" in basename:
n_heads = self.num_attn_heads if "q_proj" in basename else self.num_kv_heads
head_dim = matmul.out_features // n_heads

# Permutation indices
perm = torch.arange(0, matmul.out_features, dtype=torch.long, device=matmul.qweight.device)
perm = perm.view(n_heads, 2, head_dim // 2).transpose(1, 2).reshape(-1)

# Apply to qweight
if matmul.qweight.shape[0] == matmul.out_features:
matmul.qweight = matmul.qweight[perm, :]
elif matmul.qweight.shape[1] == matmul.out_features:
matmul.qweight = matmul.qweight[:, perm]

# Apply to scales
if matmul.scales.shape[0] == matmul.out_features:
if matmul.scales.dim() == 1:
matmul.scales = matmul.scales[perm]
else:
matmul.scales = matmul.scales[perm, :]
elif matmul.scales.dim() > 1 and matmul.scales.shape[1] == matmul.out_features:
matmul.scales = matmul.scales[:, perm]

# Apply to qzeros
if hasattr(matmul, "qzeros") and matmul.qzeros is not None:
if hasattr(self, "quant_model"):
# Unpack qzeros, permute, and repack
self.quant_model.unpack_qzeros(matmul)

if matmul.qzeros.shape[0] == matmul.out_features:
matmul.qzeros = matmul.qzeros[perm, :]
elif matmul.qzeros.shape[1] == matmul.out_features:
matmul.qzeros = matmul.qzeros[:, perm]

self.quant_model.pack_qzeros(matmul)
else:
# Fallback if quant_model is not available (e.g. manual quantization)
# Assuming qzeros is not packed along out_features or we can't handle it
if matmul.qzeros.shape[0] == matmul.out_features:
matmul.qzeros = matmul.qzeros[perm, :]
elif matmul.qzeros.shape[1] == matmul.out_features:
matmul.qzeros = matmul.qzeros[:, perm]

name = f"{basename}NBits"

# Input weights are quantized, save quantized MatMul weights for onnx model
weight = name[1:].replace("/", ".") + ".qweight"
self.make_initializer(matmul.qweight, weight)
scales = name[1:].replace("/", ".") + ".scales"
self.make_initializer(matmul.scales, scales, to=self.io_dtype)

inputs = [root_input, weight, scales]

if hasattr(matmul, "qzeros") and matmul.qzeros is not None:
zeros = name[1:].replace("/", ".") + ".qzeros"
self.make_initializer(matmul.qzeros, zeros)
inputs.append(zeros)

if hasattr(matmul, "g_idx") and matmul.g_idx is not None:
g_idx = name[1:].replace("/", ".") + ".g_idx"
self.make_initializer(matmul.g_idx, g_idx, to=ir.DataType.INT32)
inputs.append(g_idx)

output = "logits" if kwargs.get("logits", False) else f"{name}/output_0"
self.make_node(
"MatMulNBits", inputs=inputs, outputs=[output], name=name, domain="com.microsoft",
accuracy_level=5,
bits=matmul.bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features,
)
self.make_value(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])

return name

def make_matmul_int4(self, matmul, basename, root_input, **kwargs):
if not hasattr(matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
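A note on the q/k permutation in make_matmul_tmac above: the index math appears to regroup each head's output channels from a half-split rotary layout into interleaved pairs before qweight, scales, and (unpacked) qzeros are reordered. A minimal sketch on toy sizes (the dimensions below are illustrative only, not taken from this PR):

import torch

# Toy dimensions chosen only for illustration.
n_heads, head_dim = 2, 4
out_features = n_heads * head_dim  # 8

# Same permutation as in make_matmul_tmac: per head, split the channels into
# two halves and interleave them.
perm = torch.arange(out_features).view(n_heads, 2, head_dim // 2).transpose(1, 2).reshape(-1)
print(perm.tolist())  # [0, 2, 1, 3, 4, 6, 5, 7]

# Applied to a tensor whose leading dimension is out_features, it reorders rows;
# the PR reuses the same indices for scales and qzeros.
w = torch.arange(out_features * 3).view(out_features, 3)
assert w[perm, :].shape == w.shape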
@@ -2544,6 +2629,7 @@ def make_model(self, input_path):
q_size = self.num_attn_heads * self.head_size
kv_size = self.num_kv_heads * self.head_size
model = QuantModel.from_pretrained(self.quant_type, input_path=input_path, quant_attrs=self.quant_attrs, q_size=q_size, kv_size=kv_size, intermediate_size=self.intermediate_size, num_layers=self.num_layers)
self.quant_model = model

else:
# Load PyTorch model
@@ -3640,6 +3726,20 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.rotemb_attrs["rescale_factors"] = 1.0 / config.compression_ratio


class BitnetModel(MistralModel):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)

# def make_mlp_proj(self, layer_id, mlp, root_input):
# BitNetMLP: self.down_proj(self.ffn_layernorm(self.act_fn(self.gate_proj(x) * self.up_proj(x))))

# def make_attention(self, layer_id, attention, root_input, **kwargs):
# BitLinear layer


def check_extra_options(kv_pairs):
"""
Check key-value pairs and set values correctly
@@ -3828,6 +3928,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
onnx_model = QwenModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "Qwen3ForCausalLM":
onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "BitnetForCausalLM":
onnx_model = BitnetModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
else:
raise NotImplementedError(f"The {hf_name} model is not currently supported.")

@@ -3960,6 +4062,7 @@ def get_args():
Use this option to create quantized ONNX models that use BF16 precision.
adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights).
Use this option for LoRA models.
use_tmac = Use T-MAC for quantization. Default is false. Supports the TMAC_*, Q4_0, and TQ types, and GPTQ, GPTQv2, BitNet, and BitDistiller models.
"""),
)
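For context on how the new option reaches the builder: extra options arrive as key=value pairs (see check_extra_options above) and are read from the extra_options dict in __init__. A hypothetical invocation sketch; the cache_dir position and the keyword forwarding of extra options are assumptions based on the truncated create_model signature shown above, and the model id and paths are placeholders:

extra_options = {"use_tmac": True}   # e.g. parsed from use_tmac=true on the command line
create_model(
    "example/bitnet-model",      # placeholder model name, not a real checkpoint reference
    "path/to/quantized/model",   # input_path containing the quantized safetensors
    "path/to/onnx/output",       # output_dir
    "int4",                      # precision
    "cpu",                       # execution_provider
    "cache_dir",                 # assumed cache_dir parameter
    **extra_options,             # assumed keyword forwarding of extra options
)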

66 changes: 64 additions & 2 deletions src/python/py/models/quantized_model.py
@@ -18,6 +18,7 @@

import os
import re
import numpy as np


class QuantizedTensorModule:
@@ -103,7 +104,7 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
self.num_layers = num_layers
self._quant_attrs = quant_attrs
self._load_quant_config(quant_attrs) # codeql[py/init-calls-subclass]

print(f"Loading quantized model from {input_path} with quantization type {quant_type}")
for weight_file in os.listdir(input_path):
if weight_file.endswith(".safetensors"):
weights = load_file(os.path.join(input_path, weight_file))
@@ -114,6 +115,7 @@
# Per-layer quantization support
local_bits = self.get_layer_bits(name) # codeql[py/init-calls-subclass]
local_group_size = self.get_layer_group_size(name) # codeql[py/init-calls-subclass]
local_sym = self.get_layer_sym(name) # codeql[py/init-calls-subclass]

if name == "model.embed_tokens.weight" or name == "transformer.embedding.word_embeddings.weight":
self.embedding.weight = tensor
@@ -127,7 +129,8 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
self.lm_head.bias = tensor
elif name == "transformer.rotary_pos_emb.inv_freq":
# transformer.rotary_pos_emb.inv_freq in ChatGLM3.
# Skip rotary embedding weights since they can be re-calculated when looping through the model
# Skip rotary embedding weights since they can be re-calculated when looping
# through the model
continue
elif name == "lm_head.qweight" or name == "transformer.output_layer.qweight":
self._initialize_quantized_lm_head(local_bits, local_group_size)
@@ -437,6 +440,7 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
def _load_quant_config(self, quant_attrs):
self.global_group_size = quant_attrs["config"]["group_size"]
self.global_bits = quant_attrs["config"]["bits"]
self.global_sym = quant_attrs["config"].get("sym", False)

def get_layer_bits(self, layer_name):
# 'bits' is globally defined for all layers
@@ -445,6 +449,10 @@ def get_layer_group_size(self, layer_name):
def get_layer_group_size(self, layer_name):
# 'group_size' is globally defined for all layers
return self.global_group_size

def get_layer_sym(self, layer_name):
# 'sym' is globally defined for all layers
return self.global_sym

def _initialize_quantized_lm_head(self, bits, group_size):
"""
@@ -839,6 +847,7 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
# Set `g_idx` to None since it's not used in `MatMulNBits`
self.lm_head.g_idx = None


def handle_qzeros(self, module):
"""
Re-pack `qzeros` to handle extra `-1`s
@@ -863,6 +872,58 @@ def __init__(self, module):
self.pack_qzeros(temp_module)
module.qzeros = temp_module.qzeros


def unpack_pack_tmac(self, module):
"""
Return T-MAC biased uint8 weight in [0, 2 ** bits), fp16 scales, and biased fp16 zeros
"""

# Check that qweight and qzeros are of type torch.int32
if module.qweight.dtype != torch.int32 or module.qzeros.dtype != torch.int32:
raise ValueError("T-MAC unpacking requires qweight and qzeros to be of type torch.int32.")

bits = 32 // (module.scales.shape[1] // module.qzeros.shape[1])
K = module.qweight.shape[0] * (32 // bits)
M = module.qweight.shape[1]
group_size = K // module.scales.shape[0]

# Currently only supports models whose weights all match the quantization config
if bits != module.bits or group_size != module.group_size:
raise ValueError(f"Error in T-MAC unpacking: bits {bits} and group_size {group_size} do not match module's bits {module.bits} and group_size {module.group_size}.")

qweight = module.qweight.numpy()
qzeros = module.qzeros.numpy()
scales = module.scales.numpy()

# TODO: use unpack_on_row with transpose = true
qweights = [(qweight >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)]
weight = np.stack(qweights, axis=1).reshape(K, M).T.astype("uint8")

scales = scales.T

# Unpack qzeros
zeros = [(qzeros >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)]
zeros = np.stack(zeros, axis=-1).reshape(K // group_size, M).T.astype(scales.dtype)
if not self.gptq_v2:
# `zeros = zeros - 1` in AutoGPTQ, but not in GPTQModel
zeros += 1
zeros = (zeros - (2 ** (bits - 1))) * scales

# get packed weight
# TODO: use torch operations for packing, can't use pack_on_row
mask = (1 << bits) - 1
flattened_bits = np.unpackbits((weight & mask).astype(np.uint8)).reshape(-1, 8)[:, -bits:]
weight_packed = np.packbits(flattened_bits)

if not self.is_symmetric and zeros is not None:
module.qzeros = torch.from_numpy(zeros.astype(np.float16).copy().view(np.uint8).flatten())
else:
module.qzeros = None

module.scales = torch.from_numpy(scales.astype(np.float16))
module.qweight = torch.from_numpy(weight_packed.reshape(M, K * bits // 8))


class QuarkModel(QuantizedModel):
def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
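A minimal sketch of the shift-and-mask unpacking and the shape arithmetic that unpack_pack_tmac above relies on, using toy sizes; the GPTQ int32 packing layout spelled out in the comments is an assumption inferred from the code, not something stated in this PR:

# Toy GPTQ-style packing: eight 4-bit values share one int32 word,
# least-significant nibble first.
bits = 4
values = [1, 2, 3, 4, 5, 6, 7, 8]
packed = 0
for i, v in enumerate(values):
    packed |= v << (i * bits)

# Same shift-and-mask pattern as in unpack_pack_tmac.
unpacked = [(packed >> off) & ((1 << bits) - 1) for off in range(0, 32, bits)]
assert unpacked == values

# How the method infers bits and group_size from tensor shapes, assuming
# qweight: (K * bits // 32, M), scales: (K // group_size, M),
# qzeros:  (K // group_size, M * bits // 32).
M, K, group_size = 16, 64, 32
scales_cols = M
qzeros_cols = M * bits // 32
assert 32 // (scales_cols // qzeros_cols) == bits   # matches the bits formula
assert (K * bits // 32) * (32 // bits) == K         # matches the K formula
assert K // (K // group_size) == group_size         # matches the group_size formula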
@@ -989,6 +1050,7 @@ def get_layer_group_size(self, layer_name):
name = ".".join(layer_name.split(".")[:-1])
return self.overrides.get(name, {}).get("group_size", self.global_group_size)


class QuantModel:
@staticmethod
def from_pretrained(quant_type, **kwargs):