Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328
@@ -26,7 +26,7 @@
 # from functools import reduce
 # from math import gcd
-from typing import Dict, Optional, Callable, Any, List
+from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.nn as nn
@@ -37,6 +37,7 @@
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
+    int8_weight_only,
     Int8DynActInt4WeightQuantizer,
     quantize_,
 )
@@ -45,8 +46,8 @@
     find_multiple,
     get_device_str,
     get_precision,
-    set_precision,
     name_to_dtype,
+    set_precision,
     state_dict_device,
     use_et_backend,
 )
@@ -60,28 +61,36 @@
 import inspect


 def get_named_parameters(func: Callable) -> List[str]:
     # Get the signature of the function
     signature = inspect.signature(func)

     # Extract the parameters from the signature
     parameters = signature.parameters

     # Filter and return named parameters
     named_params = [
-        name for name, param in parameters.items()
-        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+        name
+        for name, param in parameters.items()
+        if param.kind
+        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
     ]
     return named_params

-def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
+
+def validate_args(
+    named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
+) -> Dict[str, Any]:
     for key in q_kwargs.keys():
         if key not in named_params:
-            print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
+            print(
+                f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
+            )
             del q_kwargs[key]
     return q_kwargs


 #########################################################################
 ###                torchchat quantization API                         ###
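These two helpers exist so that the keyword arguments a user supplies for a quantizer can be checked against the signature of the callable that will consume them. A minimal, self-contained sketch of the introspection step (toy_quantizer is a hypothetical stand-in, not a torchao API):

    import inspect
    from typing import Callable, List


    def get_named_parameters(func: Callable) -> List[str]:
        # Same introspection as the helper above: keep positional-or-keyword
        # and keyword-only parameters of the callable's signature.
        return [
            name
            for name, param in inspect.signature(func).parameters.items()
            if param.kind
            in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        ]


    def toy_quantizer(groupsize: int = 128, symmetric: bool = True):
        # hypothetical stand-in for a quantizer config constructor
        pass


    print(get_named_parameters(toy_quantizer))  # ['groupsize', 'symmetric']
    # validate_args() then compares a user-supplied q_kwargs dict against this
    # list and drops (with a warning) any key the target quantizer does not accept.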
@@ -110,21 +119,30 @@ def quantize_model(
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         else:
+            ao_quant = True
             # Use tensor subclass API for int4 weight only.
             if device == "cuda" and quantizer == "linear:int4":
                 quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+            elif quantizer == "linear:int8":
+                print("quantizer is linear int8")
+                quantize_(model, int8_weight_only())
Review discussion on this change:
- Why not integrate it into a QuantHandler class dispatched thru the handler dict at a single call site rather than build a chain of if statements?
- Hi @mikekgfb, we will refactor this part in the future after all quant APIs are moved to torchao, I think.
- torchAO already has a class-based API that is used for other quantizers. Why do these differently, and then later refactor them? Or why not do them all a consistent way now, and if you refactor later, do that?
- Yeah, the quantizer API is deprecated in favor of …
+            else:
+                ao_quant = False
+            if ao_quant:
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue

         if quantizer in ["linear:a8wxdq", "embedding:wx"]:
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
-                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                print(
+                    f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
+                )
                 set_precision(torch.float32)
-        # We set global precision from quantize options if it is specified at cli.py:485

+        # We set global precision from quantize options if it is specified at cli.py:485
         # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
         precision = get_precision()
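The review discussion above asks why the new int8 path adds another elif branch instead of going through a table at a single call site. A rough sketch of what that table-driven version could look like, using only the torchao calls already present in this diff (the ao_config_factories / try_ao_quantize names are illustrative, not code from the PR):

    from torchao.quantization.quant_api import int4_weight_only, int8_weight_only, quantize_

    # Hypothetical dispatch table: quantizer name -> callable building a torchao
    # config from the user's q_kwargs.
    ao_config_factories = {
        "linear:int4": lambda q_kwargs: int4_weight_only(q_kwargs["groupsize"]),
        "linear:int8": lambda q_kwargs: int8_weight_only(),
    }


    def try_ao_quantize(model, quantizer, q_kwargs, device) -> bool:
        # Single call site: build the config and hand the model to quantize_().
        if quantizer == "linear:int4" and device != "cuda":
            return False  # the PR only routes int4 through this path on CUDA
        factory = ao_config_factories.get(quantizer)
        if factory is None:
            return False  # fall back to the QuantHandler path
        quantize_(model, factory(q_kwargs))
        return True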
@@ -141,14 +159,19 @@ def quantize_model(
             model = quant_handler.quantize(model)


 #########################################################################
 ###                    QuantHandler API definition                    ###
 ###                  (unify with torchao in future)                   ###


 class QuantHandler:
-    def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
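For orientation, QuantHandler is the interface that entries in quantizer_class_dict implement: a handler is constructed with (model, device, precision, tokenizer) and exposes quantize(model) and quantized_model(). A minimal hypothetical subclass, only to show the shape of that API (not part of the PR):

    import torch.nn as nn


    class NoopQuantHandler(QuantHandler):
        # Hypothetical example: a handler that returns the model unchanged.
        def quantize(self, model: nn.Module) -> nn.Module:
            # a real handler rewrites submodules (e.g. swapping nn.Linear) here
            return model

        def quantized_model(self) -> nn.Module:
            return self.quantize(self.model_)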
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:
 class PrecisionHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        dtype,
+    ):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:
 class ExecutorHandler(QuantHandler):
-    def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
+    def __init__(
+        self,
+        model: Optional[nn.Module] = None,
+        device="cpu",
+        precision=None,
+        tokenizer=None,
+        *,
+        accelerator,
+    ):
         self.model_ = model

         if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
     )


-class WeightOnlyInt8Linear(nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales: torch.Tensor
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        bias=None,
-        device=None,
-        dtype=None,
-        *,
-        weight: Optional[torch.Tensor] = None,
-        scales: Optional[torch.Tensor] = None,
-        groupsize: Optional[int] = None,
-    ):
-        super().__init__()
-        if dtype is None:
-            dtype = torch.get_default_dtype()
-
-        if device is None:
-            device = "cpu"
-
-        assert not bias, "Bias is not supported by LinearInt8"
-        self.in_features = in_features
-        self.out_features = out_features
-
-        assert (weight is None) == bool(
-            scales is None
-        ), "must specify both weights and scales, or neither"
-        if weight is None:
-            weight = torch.empty(
-                (out_features, in_features),
-                dtype=torch.int8,
-                device=device,
-            )
-            if groupsize is None or (groupsize == 0):
-                scales = torch.empty(out_features, dtype=dtype, device=device)
-            else:
-                n_groups = (in_features + groupsize - 1) // groupsize
-                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)
-
-        self.register_buffer("weight", weight.to(device))
-        self.register_buffer("scales", scales.to(device))
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
-    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_aoti(input, self.weight, self.scales)
-
-    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8_et(input, self.weight, self.scales)
Review comment: Int8 seems like it is special-cased for ET; reminder to check that as well.
-
-
-class WeightOnlyInt8QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: Optional[nn.Module] = None,
-        device=None,
-        precision=None,
-        tokenizer=None,
-        *,
-        node_type: str = "*",
-        bitwidth: Optional[int] = None,
-        groupsize: Optional[int] = None,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.node_type = node_type
-        if bitwidth is None:
-            self.bitwidth = 8
-        else:
-            self.bitwidth = bitwidth
-
-    @torch.no_grad()
-    def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        if self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, nn.Linear):
-                if (
-                    (self.node_type == "*")
-                    or (self.node_type == "output" and name == "output")
-                    or (self.node_type == "!output" and name != "output")
-                ):
-                    # print(f"{name, child}")
-                    input_weight = child.weight.float()
-                    # print(f"{name, child}")
-                    # print(f"in_features: {child.in_features}")
-                    # print(f"out_features: {child.out_features}")
-
-                    # print(f"expanded weight shape {input_weight.shape}")
-                    weight, scales, _ = dynamically_quantize_per_channel(
-                        input_weight,
-                        range_min,
-                        range_max,
-                        torch.int8,
-                        self.groupsize,
-                        scales_dtype=child.weight.dtype,
-                    )
-
-                    setattr(
-                        module,
-                        name,
-                        WeightOnlyInt8Linear(
-                            in_features=child.in_features,
-                            out_features=child.out_features,
-                            device=self.device,
-                            # update variables from quantization
-                            weight=weight,
-                            scales=scales,
-                            groupsize=self.groupsize,
-                        ),
-                    )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 #########################################################################
 #####                embedding table quantization                  #####
 ###                  (unify with torchao in future)                   ###
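For context on the deletion: the removed handler called dynamically_quantize_per_channel with range [-128, 127] to produce an int8 weight plus per-channel scales, functionality now delegated to torchao's int8_weight_only. A small sketch of that kind of symmetric per-channel quantization in plain PyTorch (an illustration of the math, not the exact dynamically_quantize_per_channel implementation):

    import torch


    def per_channel_int8_quantize(w: torch.Tensor, range_min: int = -128, range_max: int = 127):
        # One scale per output channel (per row of the [out_features, in_features] weight).
        max_abs = w.abs().amax(dim=1, keepdim=True)
        scales = max_abs.clamp(min=1e-8) / range_max
        q = torch.clamp(torch.round(w / scales), range_min, range_max).to(torch.int8)
        return q, scales.squeeze(1)


    w = torch.randn(4, 8)
    q, scales = per_channel_int8_quantize(w)
    w_hat = q.to(torch.float32) * scales.unsqueeze(1)  # dequantized weight used at matmul time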
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:int8": int8_weight_only,
Review discussion on this entry:
- Do we need this?
- We can probably use None for now, and remove this later.
- We check for int8_weight_only and finish the check before it looks at the table, I think. @vmpuri, can you check?
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
         print("Slow fallback kernels will be used.")

 except Exception as e:

     class ErrorHandler(QuantHandler):
-        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
+        def __init__(
+            self, model: Optional[nn.Module] = None, device="cpu", precision=None
+        ):
             global torchao_experimental_load_error
-            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
+            raise Exception(
+                f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
+            )

     torchao_experimental_load_error = e
     quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
     quantizer_class_dict["embedding:wx"] = ErrorHandler
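The ErrorHandler registration above is a deferred-failure pattern: if the torchao experimental quantizers fail to import, the corresponding quantizer_class_dict entries stay present but re-raise the original load error only when actually used. A stripped-down sketch of the same pattern (the module and dict names here are hypothetical):

    handlers = {}

    try:
        from some_optional_backend import FancyQuantizer  # hypothetical optional dependency
        handlers["fancy"] = FancyQuantizer
    except Exception as e:
        load_error = e

        class ErrorHandler:
            def __init__(self, *args, **kwargs):
                # Surface the original import failure only when this quantizer is requested.
                raise Exception(f"Failed to load optional quantizer: {load_error}")

        handlers["fancy"] = ErrorHandler

    # handlers["fancy"]() now raises with the original import error attached.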