Commit d43d52e

vmpuri authored and committed
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization
1 parent 7fe2c86 commit d43d52e

File tree: 1 file changed (+11, -142 lines)

torchchat/utils/quantize.py (+11, -142)
@@ -36,6 +36,7 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
 from torchao.quantization.quant_api import (
     int4_weight_only,
+    int8_weight_only,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
     quantize_,
@@ -110,12 +111,20 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
 
         if quantizer in ["linear:a8wxdq", "embedding:wx"]:
             # These quantizers require float32 input weights. Note that after quantization,
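For context, the new "linear:int8" branch uses TorchAO's tensor-subclass flow instead of swapping in a custom module. A minimal sketch of the same call on a toy model (the model and shapes are illustrative assumptions, not part of this commit; it assumes a torchao build that exposes int8_weight_only):

    import torch
    import torch.nn as nn
    from torchao.quantization.quant_api import int8_weight_only, quantize_

    # Toy model standing in for torchchat's transformer (illustration only).
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

    # quantize_ mutates the model in place: each nn.Linear keeps its class, but its
    # weight is replaced with an int8 weight-only quantized tensor subclass.
    quantize_(model, int8_weight_only())

    with torch.no_grad():
        out = model(torch.randn(2, 64))
    print(type(model[0].weight).__name__, out.shape)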
@@ -529,147 +538,6 @@ def linear_int8_et(input, weight, scales):
     )
 
 
-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device = None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu" # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 ##### embedding table quantization ######
 ### (unify with torchao in future) ###
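For reference, the deleted handler's core transform was per-channel int8 weight quantization (via dynamically_quantize_per_channel) followed by a scaled F.linear. A simplified stand-in sketch of that math, assuming symmetric per-channel scales (the real helper also returns zero points, which the removed forward paths did not use):

    import torch
    import torch.nn.functional as F

    def quantize_per_channel_int8(w: torch.Tensor):
        # One scale per output channel, mapping the largest magnitude to 127.
        scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
        return q, scales.squeeze(1)

    def int8_weight_only_linear(x, q, scales):
        # Same shape as the removed linear_int8_aoti path: matmul in the
        # activation dtype, then rescale each output channel.
        return F.linear(x, q.to(dtype=x.dtype)) * scales

    w = torch.randn(8, 16)
    q, scales = quantize_per_channel_int8(w)
    x = torch.randn(2, 16)
    print((int8_weight_only_linear(x, q, scales) - x @ w.t()).abs().max())  # small quantization error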
@@ -886,10 +754,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
 
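Note that the "linear:int8" key is kept so the quantizer-not-in-quantizer_class_dict check in quantize_model still passes, but its value changes kind: int8_weight_only is a TorchAO config factory consumed by quantize_() in the elif branch above (which, unlike the int4 branch, forwards no q_kwargs such as groupsize), and it is never instantiated the way the remaining handler classes are. A small hedged illustration of that distinction:

    from torchao.quantization.quant_api import int8_weight_only

    cfg = int8_weight_only()   # a TorchAO quantization config/callable, not a module transform by itself
    print(type(cfg).__name__)  # it only takes effect when passed to quantize_(model, cfg)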
@@ -917,6 +785,7 @@ def quantized_model(self) -> nn.Module:
         IntxWeightEmbeddingQuantizer,
     )
 
+
     quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
     quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
 
0 commit comments