
Commit c403580

Reapply Autoquant (#82) (#109)
1 parent fc5d2c8 commit c403580

8 files changed: 546 additions, 18 deletions

.github/workflows/regression_test.yml

Lines changed: 1 addition & 1 deletion
@@ -53,4 +53,4 @@ jobs:
 
       - name: Run tests
        run: |
-          pytest test --verbose -s -x
+          pytest test --verbose -s

README.md

Lines changed: 25 additions & 8 deletions
@@ -1,8 +1,8 @@
-# torchao: PyTorch Architecture Optimization
+# torchao: PyTorch Architecture Optimization
 
 **Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue**
 
-The `torchao` package allows you to quantize and prune your models using native PyTorch.
+The `torchao` package allows you to quantize and prune your models using native PyTorch.
 
 The repo hosts both
 1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4
@@ -38,30 +38,46 @@ pip install -e .
 
 Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change.
 
-### A8W8 Dynamic Quantization
+### Autoquantization
 
-```Python
+The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
+of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer.
+
+```python
 import torch
-from torchao.quantization import quant_api
+import torchao
 
-# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
+# inductor settings which improve torch.compile performance for quantized modules
 torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
 
 # Plug in your model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+# perform autoquantization
+torchao.autoquant(model, (input))
 
 # compile the model to improve performance
 model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
+
+### A8W8 Dynamic Quantization
+
+```python
+# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+from torchao.quantization import quant_api
+# convert linear modules to quantized tensor subclasses
+quant_api.change_linear_weights_to_int8_dqtensors(model)
+```
+
 ### A16W8 WeightOnly Quantization
 
 ```python
+from torchao.quantization import quant_api
 quant_api.change_linear_weights_to_int8_woqtensors(model)
 ```
 
@@ -71,6 +87,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ### A16W4 WeightOnly Quantization
 
 ```python
+from torchao.quantization import quant_api
 quant_api.change_linear_weights_to_int4_woqtensors(model)
 ```
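The README's new autoquant example shows the end-to-end call; as a supplementary sketch (not part of this commit), here is one way you might check which weight representation each `nn.Linear` ended up with after `torchao.autoquant` has run. The `report_quantized_linears` helper is hypothetical; the AQ* subclass names it would surface come from `torchao.quantization.autoquant`, as imported in the test diff below.

```python
import torch
import torchao

# Hypothetical helper (not part of this commit): after autoquant has swapped in a
# quantized weight subclass for each linear layer, print what each layer ended up with.
# Depending on how the swap wraps nn.Parameter, this may report the AQ* subclass
# directly or a Parameter holding it.
def report_quantized_linears(model: torch.nn.Module) -> None:
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            print(f"{name}: weight type is {type(module.weight).__name__}")

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

torchao.autoquant(model, example_input)
model(example_input)              # exercise the quantized model
report_quantized_linears(model)
```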

test/integration/test_integration.py

Lines changed: 113 additions & 0 deletions
@@ -7,11 +7,13 @@
 # mypy: ignore-errors
 import copy
 import unittest
+import itertools
 
 import torch
 import torch.nn as nn
 from torch._inductor.utils import run_and_get_code
 from torch._dynamo import config
+import torchao
 from torch.ao.quantization import MinMaxObserver, QConfigMapping
 
 from torchao.quantization.dynamic_quant import (
@@ -54,6 +56,13 @@
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
 )
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3
+
+)
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
 from parameterized import parameterized
@@ -71,6 +80,12 @@
     ("cuda", torch.bfloat16),
 ]
 
+def combine_parameters(a, b):
+    new_tuples = []
+    for (tuple1, tuple2) in itertools.product(a, b):
+        new_tuples.append(tuple1 + tuple2)
+    return new_tuples
+
 def run_supported_device_dtype(test_method):
     def wrapper(*args, **kwargs):
         if args[2] == "cuda" and not torch.cuda.is_available():
@@ -907,6 +922,36 @@ def test_int8_weight_only_quant_subclass(self, device, dtype):
             Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
@@ -1290,6 +1335,74 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (16, 128, 128),
+            (64, 128, 128),
+            # (2**15, 128, 128), TODO: Runs out of shared memory on T4
+            (16, 128, 256),
+            # (64, 128, 256), # TODO: Runs out of shared memory on T4
+            (16, 256, 128),
+            (64, 256, 128),
+            # (256, 256, 128), TODO: Runs out of shared memory on T4
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_one_input(self, device, dtype, m, k, n):
+        print("(m, k, n): ", (m, k, n))
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m == 1:
+                self.skipTest(f"Shape {(m, k, n)} requires sm80+")
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k,n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        out = model(example_input)
+        torchao.autoquant(model, example_input)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
+
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (1, 1, 128, 128),
+            (1, 32, 128, 128),
+            (32, 32, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m1 == 1 or m2 == 1:
+                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k,n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        example_input = torch.randn(m1, k, device=device, dtype=dtype)
+        example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
+        torchao.quantization.change_linears_to_autoquantizable(model)
+        out = model(example_input)
+        model(example_input2)
+        torchao.quantization.change_autoquantizable_to_quantized(model)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
 
 if __name__ == "__main__":
     unittest.main()
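As a small standalone illustration of the `combine_parameters` helper added above: each generated test case is the concatenation of one device/dtype pair with one shape tuple, which `parameterized.expand` then unpacks into the test's arguments. The device/dtype and shape lists below are illustrative stand-ins for the real `COMMON_DEVICE_DTYPE` and shape lists in the diff.

```python
import itertools
import torch

# Same logic as the combine_parameters helper in the diff above, shown standalone.
def combine_parameters(a, b):
    new_tuples = []
    for (tuple1, tuple2) in itertools.product(a, b):
        new_tuples.append(tuple1 + tuple2)
    return new_tuples

device_dtype = [("cuda", torch.bfloat16)]          # illustrative subset of COMMON_DEVICE_DTYPE
shapes = [(16, 128, 128), (64, 256, 128)]
for case in combine_parameters(device_dtype, shapes):
    print(case)  # e.g. ('cuda', torch.bfloat16, 16, 128, 128) -> (device, dtype, m, k, n)
```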

torchao/__init__.py

Lines changed: 9 additions & 4 deletions
@@ -1,8 +1,13 @@
+from torchao.quantization import (
+    apply_weight_only_int8_quant,
+    apply_dynamic_quant,
+    autoquant,
+)
 from . import dtypes
-from .quantization.quant_api import apply_dynamic_quant
-from .quantization.quant_api import apply_weight_only_int8_quant
 
 __all__ = [
-    "dtypes",
-    "apply_dynamic_quant",
+    "dtypes",
+    "apply_dynamic_quant",
+    "apply_weight_only_int8_quant",
+    "autoquant",
 ]
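A short usage sketch (not part of the diff) of what the re-exports above make available at the package top level; it assumes each entry point accepts the model (plus an example input for `autoquant`) as shown in the README and tests in this commit.

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# One-shot autoquantization, as in the updated README example.
torchao.autoquant(model, example_input)

# The other re-exported entry points, shown commented out since only one
# quantization scheme would normally be applied to a given model.
# torchao.apply_dynamic_quant(model)
# torchao.apply_weight_only_int8_quant(model)
```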

torchao/quantization/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,6 +11,7 @@
 from .utils import * # noqa: F403
 from .weight_only import * # noqa: F403
 from .unified import *
+from .autoquant import *
 
 __all__ = [
     "DynamicallyPerAxisQuantizedLinear",
@@ -26,6 +27,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",
