
Commit 92e0a9d

vmpuri authored and committed
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization
1 parent 7fe2c86 commit 92e0a9d
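
For context, a minimal sketch of the call path this commit switches to. It is not torchchat's `quantize_model`, just a toy module quantized with the same TorchAO APIs (`quantize_` and `int8_weight_only` from `torchao.quantization.quant_api`, which the diff below imports):

```python
# Minimal sketch of the TorchAO int8 weight-only path this commit adopts.
# Toy model for illustration; torchchat's real entry point is quantize_model
# in torchchat/utils/quantize.py (see the diff below).
import torch
import torch.nn as nn
from torchao.quantization.quant_api import int8_weight_only, quantize_

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# quantize_ rewrites eligible nn.Linear weights in place as int8 weight-only
# tensor subclasses, replacing the module-swap approach of the removed
# WeightOnlyInt8Linear / WeightOnlyInt8QuantHandler classes.
quantize_(model, int8_weight_only())

with torch.no_grad():
    out = model(torch.randn(2, 64))
```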

File tree

1 file changed: +66 -163 lines changed

torchchat/utils/quantize.py

+66 -163
@@ -26,7 +26,7 @@
 
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional, Callable, Any, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -37,6 +37,7 @@
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
+    int8_weight_only,
     Int8DynActInt4WeightQuantizer,
     quantize_,
 )
@@ -45,8 +46,8 @@
     find_multiple,
     get_device_str,
     get_precision,
-    set_precision,
     name_to_dtype,
+    set_precision,
     state_dict_device,
     use_et_backend,
 )
@@ -60,28 +61,36 @@
 
 import inspect
 
+
 def get_named_parameters(func: Callable) -> List[str]:
     # Get the signature of the function
     signature = inspect.signature(func)
-
+
     # Extract the parameters from the signature
     parameters = signature.parameters
-
+
     # Filter and return named parameters
     named_params = [
-        name for name, param in parameters.items()
-        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        name
+        for name, param in parameters.items()
+        if param.kind
+        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
     ]
     return named_params
 
-def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+
+def validate_args(
+    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
+) -> Dict[str, Any]:
     for key in q_kwargs.keys():
         if key not in named_params:
-            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            print(
+                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
+            )
             del q_kwargs[key]
     return q_kwargs
-
-
+
+
 #########################################################################
 ### torchchat quantization API ###
 
@@ -110,21 +119,30 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
-
+
             if quantizer in ["linear:a8wxdq", "embedding:wx"]:
                 # These quantizers require float32 input weights. Note that after quantization,
                 # the weights will no longer be float32, but lowbit integers
                 if get_precision() != torch.float32:
-                    print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                    print(
+                        f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                    )
                     set_precision(torch.float32)
-
-            # We set global precision from quantize options if it is specified at cli.py:485
+
+            # We set global precision from quantize options if it is specified at cli.py:485
             # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
             precision = get_precision()
 
@@ -141,14 +159,19 @@ def quantize_model(
             model = quant_handler.quantize(model)
 
 
-
 #########################################################################
 ### QuantHandler API definition ###
 ### (unify with torchao in future) ###
 
 
 class QuantHandler:
-    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:
 
 
 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        dtype,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:
 
 
 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        accelerator,
+    ):
         self.model_ = model
 
         if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
     )
 
 
-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device = None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu" # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 ##### embedding table quantization ######
 ### (unify with torchao in future) ###
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
 
@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")
 
 except Exception as e:
+
     class ErrorHandler(QuantHandler):
-        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
+        def __init__(
+            self, model: Optional[nn.Module] = None, device="cpu", precision=None
+        ):
             global torchao_experimental_load_error
-            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
-
+            raise Exception(
+                f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
+            )
+
     torchao_experimental_load_error = e
     quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
     quantizer_class_dict["embedding:wx"] = ErrorHandler

0 commit comments