Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328

Closed · wants to merge 84 commits

Changes from 2 commits

Commits (84):
92e0a9d
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization
Oct 24, 2024
1a42fb6
Merge branch 'main' into torchao_int8_weight_only
vmpuri Nov 12, 2024
93d9876
fix: enforce python version install requirements (#1345)
leseb Nov 12, 2024
d1d6aa1
Remove last references to use_distributed argument (#1353)
mreso Nov 13, 2024
a286e58
Add cstdint to tokenizer (missing include) (#1339)
byjlw Nov 13, 2024
a655d58
Setup a SIGINT handler to gracefully exit the program once the user p…
leseb Nov 13, 2024
8811c7e
Update cli.py to make --device/--dtype pre-empt quantize dict-specifi…
mikekgfb Nov 13, 2024
483928b
Update Caching logic to only trigger on the first inference sample (#…
Jack-Khuu Nov 13, 2024
add35e8
Minor typo + Update install_requirements.sh to support python 3.10 >=…
Jack-Khuu Nov 13, 2024
008fea0
fix: Remove dup gguf dependency (#1371)
leseb Nov 14, 2024
d2e4995
Bug Fix: Check for explicit cli device (fast) (#1374)
Jack-Khuu Nov 14, 2024
bc2c2d0
fix: do not print perf stat when NaN (#1375)
leseb Nov 15, 2024
4eb7fbb
fix: Fail gracefully when "model" arg is missing when downloading (#1…
leseb Nov 16, 2024
d62680c
Ignore tokens per sec from jit_compile iteration (#1378)
yanbing-j Nov 19, 2024
c0630a6
Download fix (#1366)
gabe-l-hart Nov 19, 2024
fe76c85
Update builder.py (#1387)
mikekgfb Nov 19, 2024
8478e5d
Add multimodal to possible tests (#1382)
mikekgfb Nov 19, 2024
5e18de7
Fix typo in RuntimeException in builder.py (#1386)
mikekgfb Nov 20, 2024
8475c79
Bug fix: Enable fast to override quantize json (#1377)
Jack-Khuu Nov 20, 2024
731936d
Changing the referenced AAR so that it uses the AAR from the docs (#1…
infil00p Nov 23, 2024
554cf86
Typo fixes in native-execution.md (#1394)
mikekgfb Nov 26, 2024
dadaade
Improvements for readability in ADVANCED-USERS.md (#1393)
mikekgfb Nov 26, 2024
c7bb8b9
Update multimodal.md to exercise server as part of test (#1391)
mikekgfb Nov 26, 2024
b0abf27
Update quantization.md link to quantize.py (#1392)
Jack-Khuu Dec 3, 2024
b870f7e
Bump torch pin to 20241010 (#1400)
larryliu0820 Dec 6, 2024
4e621ce
Use pytorch-labs/tokenizers and remove tokenizer/ (#1401)
larryliu0820 Dec 7, 2024
6e40ec0
Update PT Pin to 1013 (#1407)
Jack-Khuu Dec 9, 2024
d979da1
Update docs for max-autotune usage (#1405)
yanbing-j Dec 9, 2024
6d6f2b9
Update run-docs to include `run-docs native` (#1403)
mikekgfb Dec 9, 2024
46b784e
Update README.md to run and query server during test (#1384)
mikekgfb Dec 9, 2024
2c03a2a
Update run-docs to enable `run-docs evaluation` (#1383)
mikekgfb Dec 9, 2024
e1fefc0
Revert "Use pytorch-labs/tokenizers and remove tokenizer/ (#1401)" (#…
Jack-Khuu Dec 10, 2024
bc0c1dc
Update README.md (whitespace) (#1412)
mikekgfb Dec 10, 2024
dfbd865
Update evaluation.md to include AOTI (#1411)
mikekgfb Dec 10, 2024
19ecd95
Update ADVANCED-USERS.md (#1396)
mikekgfb Dec 11, 2024
1315275
Bump PT pin to 20241028 (#1419)
Jack-Khuu Dec 12, 2024
1d7e71f
Avoid curl fails due to server startup time in CI (#1418)
mikekgfb Dec 12, 2024
36d0712
Add torchao mps ops (#1415)
manuelcandales Dec 13, 2024
5bc5552
Multi Pin Bumps across PT/AO/tune/ET: pt dev20241213 (#1367)
Jack-Khuu Dec 14, 2024
902542d
Update int4pack related in torchchat gguf (#1404)
yanbing-j Dec 17, 2024
6de1a01
update torchao pin: optimized shaders (#1428)
manuelcandales Dec 18, 2024
ff2d53c
Update install_requirements.sh to tune + pt/pt dev20241218 (#1426)
Jack-Khuu Dec 19, 2024
5e16167
Add Granite code support (#1336)
gabe-l-hart Dec 19, 2024
582e558
Fix 3.2 11B inference, by updating padded_collate_tiled_images_and_ma…
Jack-Khuu Dec 19, 2024
7dad56f
Integrate distributed inference with chat/server (#1381)
mreso Dec 19, 2024
155bd4b
Granite 3.0 / 3.1 dense support (#1432)
gabe-l-hart Dec 20, 2024
a325191
Fix typo in quantize.py (#1434)
mikekgfb Dec 23, 2024
86efcd3
Update sh -> bash in quantization.md (#1437)
mikekgfb Dec 23, 2024
a1ba6a1
Output explicit selection of /bin/bash as interpreter for test script…
mikekgfb Dec 23, 2024
490ad39
Fix how stream flag is read from request (#1441)
mreso Dec 25, 2024
b95074b
[retry] Use pytorch-labs/tokenizers and remove tokenizer/ (#1401) (#1…
larryliu0820 Jan 3, 2025
3f0fec3
Update README.md to include granite (#1445)
mikekgfb Jan 5, 2025
c121ed2
Create local-model.md (#1448)
mikekgfb Jan 6, 2025
e60680b
Update evaluation.md (#1442)
mikekgfb Jan 6, 2025
1ba40d7
Create distributed.md (#1438)
mikekgfb Jan 6, 2025
06e78ce
[aoti] Remove need for -l in cmake (#1159)
angelayi Jan 15, 2025
6bfc5c8
Bumping ET Pin to Jan16 2025 (#1459)
Jack-Khuu Jan 17, 2025
d625f72
Fix typo in quantize.py (#1461)
mikekgfb Jan 17, 2025
e5543e2
Update run-readme-pr-mps.yml for typo (#1460)
mikekgfb Jan 17, 2025
2d96e48
Add Intel XPU device support to generate and serve (#1361)
jenniew Jan 18, 2025
defc225
Create run-readme-pr-linuxaarch64 (#1350)
mikekgfb Jan 21, 2025
2227014
Bump test-readme-mps-macos timeout (#1451)
mikekgfb Jan 21, 2025
bc0f93a
Update torch/tune/vision pins to 1/19/25 (#1467)
Jack-Khuu Jan 22, 2025
cd10377
Add warning in PTEModel when not defined (#1468)
Jack-Khuu Jan 22, 2025
ef58fce
Add attention_backend as a configurable option (#1456)
yanbing-j Jan 22, 2025
601f2d1
Update import of sdpa_with_kv_cache to custom_ops (#1470)
Jack-Khuu Jan 22, 2025
083960b
Typo: Fix generate signature type hint for attention_backend (#1471)
Jack-Khuu Jan 22, 2025
a942c16
chat: Change role to user for user prompts (#1447)
vladoovtcharov Jan 22, 2025
f514b35
Update run-readme-pr-linuxaarch64.yml to use correct runner (#1469)
Jack-Khuu Jan 23, 2025
c536da4
Increment start_pos by encoded size in generate (#1462)
nlpfollower Jan 23, 2025
8662471
Explicitly turning off pybindings for ExecuTorch unless requested (#1…
Jack-Khuu Jan 24, 2025
a64b9e3
Replace RMSNorm by nn.RMSNorm (#1464)
manuelcandales Jan 24, 2025
84d2232
Update aoti calls to utilize new export and packaging APIs (#1455)
angelayi Jan 24, 2025
1c2f5aa
Update numpy requirements to no longer upper bound on 2.0 (#1479)
Jack-Khuu Jan 24, 2025
59e168e
Add evaluation, multimodal, native tests to run-readme-pr-macos.yml (…
mikekgfb Jan 24, 2025
7b3a5fd
Add evaluation, multimodal, native tests to run-readme-pr-mps.yml (#1…
mikekgfb Jan 24, 2025
4e2c384
Force run-readme-pr-macos.yml to use CPU instead of incorrectly loadi…
mikekgfb Jan 24, 2025
8bae547
Add distributed tests to run-readme-pr.yml (#1466)
mikekgfb Jan 27, 2025
eba2b07
Update run-docs to avoid code duplication (#1439)
mikekgfb Jan 30, 2025
2f34fee
Add `export --output-snapshot-path snap.tc`, and `--snapshot-path sna…
mikekgfb Jan 31, 2025
ad7f85a
Update check_gibberish to check for aspell availability(#1487)
mikekgfb Jan 31, 2025
31ecb18
Add DeepSeek R1 Distill 8B (#1488)
Jack-Khuu Feb 3, 2025
5f9b347
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization
Oct 24, 2024
8b1af3f
Fallback to original quantization if float16
Feb 4, 2025
229 changes: 66 additions & 163 deletions torchchat/utils/quantize.py
@@ -26,7 +26,7 @@

# from functools import reduce
# from math import gcd
from typing import Dict, Optional, Callable, Any, List
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
@@ -37,6 +37,7 @@
from torchao.quantization.quant_api import (
int4_weight_only,
Int4WeightOnlyQuantizer,
int8_weight_only,
Int8DynActInt4WeightQuantizer,
quantize_,
)
@@ -45,8 +46,8 @@
find_multiple,
get_device_str,
get_precision,
set_precision,
name_to_dtype,
set_precision,
state_dict_device,
use_et_backend,
)
@@ -60,28 +61,36 @@

import inspect


def get_named_parameters(func: Callable) -> List[str]:
# Get the signature of the function
signature = inspect.signature(func)

# Extract the parameters from the signature
parameters = signature.parameters

# Filter and return named parameters
named_params = [
name for name, param in parameters.items()
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
name
for name, param in parameters.items()
if param.kind
in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
]
return named_params

def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:

def validate_args(
named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None
) -> Dict[str, Any]:
for key in q_kwargs.keys():
if key not in named_params:
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
print(
f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring."
)
del q_kwargs[key]
return q_kwargs


#########################################################################
### torchchat quantization API ###

@@ -110,21 +119,30 @@ def quantize_model(
if quantizer not in quantizer_class_dict:
raise RuntimeError(f"unknown quantizer {quantizer} specified")
else:
ao_quant = True
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
print("quantizer is linear int8")
Review comment (Contributor): suggested change, remove the debug print("quantizer is linear int8").
quantize_(model, int8_weight_only())
Review thread (Contributors):
- Why not integrate it into a QuantHandler class dispatched thru the handler dict at a single call site rather than build a chain of if statements?
- Hi @mikekgfb, we will refactor this part in the future after all quant APIs are moved to torchao I think
- torchAO already has a class-based API that is used for other quantizers? Why do these differently, and then later refactor them? Or why not do them all a consistent way now, and if you refactor later, do that?
- yeah, quantizer API is deprecated in favor of quantize_, that's why we are gradually refactoring the quantizer APIs to use quantize_; the reason we do it one by one is because there might be missing support/alignment on numerics etc. that we need to do during the migration
(A standalone sketch of this quantize_ path appears at the end of this hunk.)
else:
ao_quant = False
if ao_quant:
if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue

if quantizer in ["linear:a8wxdq", "embedding:wx"]:
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32, but lowbit integers
if get_precision() != torch.float32:
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
print(
f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32."
)
set_precision(torch.float32)
# We set global precision from quantize options if it is specified at cli.py:485

# We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
precision = get_precision()
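
A minimal standalone sketch of the tensor-subclass path discussed in the review thread above (illustration only, not part of this diff; it assumes the torchao APIs imported at the top of this file and a toy model):

    import torch
    import torch.nn as nn
    from torchao.quantization.quant_api import int8_weight_only, quantize_

    # quantize_ mutates the model in place: each nn.Linear keeps its module class,
    # but its weight tensor is swapped for an int8 weight-only quantized subclass.
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
    quantize_(model, int8_weight_only())

    x = torch.randn(1, 64)
    y = model(x)  # forward now runs through torchao's int8 weight-only kernels

For backends that cannot trace tensor subclasses (the support_tensor_subclass guard above), the diff calls unwrap_tensor_subclass(model) to flatten the quantized weights back into plain tensors before export; that helper is assumed to come from torchao, since its import is not shown in these hunks.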

@@ -141,14 +159,19 @@ def quantize_model(
model = quant_handler.quantize(model)



#########################################################################
### QuantHandler API definition ###
### (unify with torchao in future) ###


class QuantHandler:
def __init__(self, model: Optional[nn.Module] = None, device="cpu", precision=None, tokenizer=None):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
@@ -176,7 +199,15 @@ def quantize(self, model: nn.Module) -> nn.Module:


class PrecisionHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, dtype):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
*,
dtype,
):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
@@ -205,7 +236,15 @@ def quantized_model(self) -> nn.Module:


class ExecutorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None, tokenizer=None, *, accelerator):
def __init__(
self,
model: Optional[nn.Module] = None,
device="cpu",
precision=None,
tokenizer=None,
*,
accelerator,
):
self.model_ = model

if isinstance(accelerator, str):
@@ -529,147 +568,6 @@ def linear_int8_et(input, weight, scales):
)


class WeightOnlyInt8Linear(nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales: torch.Tensor

def __init__(
self,
in_features,
out_features,
bias=None,
device=None,
dtype=None,
*,
weight: Optional[torch.Tensor] = None,
scales: Optional[torch.Tensor] = None,
groupsize: Optional[int] = None,
):
super().__init__()
if dtype is None:
dtype = torch.get_default_dtype()

if device is None:
device = "cpu"

assert not bias, "Bias is not supported by LinearInt8"
self.in_features = in_features
self.out_features = out_features

assert (weight is None) == bool(
scales is None
), "must specify both weights and scales, or neither"
if weight is None:
weight = torch.empty(
(out_features, in_features),
dtype=torch.int8,
device=device,
)
if groupsize is None or (groupsize == 0):
scales = torch.empty(out_features, dtype=dtype, device=device)
else:
n_groups = (in_features + groupsize - 1) // groupsize
scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)

self.register_buffer("weight", weight.to(device))
self.register_buffer("scales", scales.to(device))

if use_et_backend():
self.forward = self.et_forward
else:
self.forward = self.aoti_forward

def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
return linear_int8_aoti(input, self.weight, self.scales)

def et_forward(self, input: torch.Tensor) -> torch.Tensor:
return linear_int8_et(input, self.weight, self.scales)
Review comment (Contributor): Int8 seems like it is special-cased for ET; reminder to check that as well.



class WeightOnlyInt8QuantHandler(QuantHandler):
def __init__(
self,
model: Optional[nn.Module] = None,
device = None,
precision=None,
tokenizer=None,
*,
node_type: str = "*",
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
self.model_ = model
self.device = device
self.groupsize = groupsize
self.node_type = node_type
if bitwidth is None:
self.bitwidth = 8
else:
self.bitwidth = bitwidth

@torch.no_grad()
def quantize(self, module):
# cur_state_dict = state_dict_device(self.model_.state_dict())
# dict_device = "cpu" # self.device

if self.bitwidth == 4:
range_min = -8
range_max = 7
elif self.bitwidth == 8:
range_min = -128
range_max = 127
else:
raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

for name, child in module.named_children():
# print(f"name: {name}")
if isinstance(child, nn.Linear):
if (
(self.node_type == "*")
or (self.node_type == "output" and name == "output")
or (self.node_type == "!output" and name != "output")
):
# print(f"{name, child}")
input_weight = child.weight.float()
# print(f"{name, child}")
# print(f"in_features: {child.in_features}")
# print(f"out_features: {child.out_features}")

# print(f"expanded weight shape {input_weight.shape}")
weight, scales, _ = dynamically_quantize_per_channel(
input_weight,
range_min,
range_max,
torch.int8,
self.groupsize,
scales_dtype=child.weight.dtype,
)

setattr(
module,
name,
WeightOnlyInt8Linear(
in_features=child.in_features,
out_features=child.out_features,
device=self.device,
# update variables from quantization
weight=weight,
scales=scales,
groupsize=self.groupsize,
),
)
else:
self.quantize(child)

return module

def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)


#########################################################################
##### embedding table quantization ######
### (unify with torchao in future) ###
@@ -886,10 +784,10 @@ def quantized_model(self) -> nn.Module:
# class references
quantizer_class_dict = {
"embedding": EmbeddingOnlyQuantHandler,
"linear:int8": WeightOnlyInt8QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:int8": int8_weight_only,
Review thread (Contributors):
- Do we need this?
- we can probably use None for now, and remove this later
- We check for int8_weight_only and finish the check before it looks at the table, I think. @vmpuri can you check?
(See the dispatch-order sketch after this hunk.)

"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}
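
For reference, a runnable sketch of the dispatch order questioned in the review thread above (a simplified illustration with hypothetical names, not the actual quantize_model):

    import torch.nn as nn
    from torchao.quantization.quant_api import int8_weight_only, quantize_

    # The "linear:int8" entry is only consulted for the membership check;
    # the elif chain intercepts it before any table-based handler is built.
    sketch_class_dict = {"linear:int8": int8_weight_only}

    def quantize_model_sketch(model, quantize_options):
        for quantizer, q_kwargs in quantize_options.items():
            if quantizer not in sketch_class_dict:
                raise RuntimeError(f"unknown quantizer {quantizer} specified")
            if quantizer == "linear:int8":
                quantize_(model, int8_weight_only())
                continue  # the dict value itself is never instantiated
        return model

    quantize_model_sketch(nn.Sequential(nn.Linear(16, 16)), {"linear:int8": {}})

Under this reading, the "linear:int8" value in quantizer_class_dict could be a placeholder such as None, as suggested in the thread.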

@@ -932,11 +830,16 @@ def quantized_model(self) -> nn.Module:
print("Slow fallback kernels will be used.")

except Exception as e:

class ErrorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
def __init__(
self, model: Optional[nn.Module] = None, device="cpu", precision=None
):
global torchao_experimental_load_error
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")

raise Exception(
f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}"
)

torchao_experimental_load_error = e
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
quantizer_class_dict["embedding:wx"] = ErrorHandler