@@ -7,11 +7,13 @@
# mypy: ignore-errors
import copy
import unittest
+import itertools

import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code
from torch._dynamo import config
+import torchao
from torch.ao.quantization import MinMaxObserver, QConfigMapping

from torchao.quantization.dynamic_quant import (
@@ -54,6 +56,13 @@
    _fqn_to_op_to_shape_to_count,
    LoggingTensorMode,
)
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3
+
+)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
from parameterized import parameterized
@@ -71,6 +80,12 @@
    ("cuda", torch.bfloat16),
]

+def combine_parameters(a, b):
+    new_tuples = []
+    for (tuple1, tuple2) in itertools.product(a, b):
+        new_tuples.append(tuple1 + tuple2)
+    return new_tuples
+
def run_supported_device_dtype(test_method):
    def wrapper(*args, **kwargs):
        if args[2] == "cuda" and not torch.cuda.is_available():
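
For reference, `combine_parameters` just takes the cross product of two parameter lists so that a single `parameterized.expand` call can sweep the device/dtype pairs and the matmul shapes together. A minimal sketch of the resulting tuples (the device/dtype entries and shapes below are an illustrative subset, not the full lists used in this file):

    import itertools
    import torch

    COMMON_DEVICE_DTYPE = [("cpu", torch.float32), ("cuda", torch.bfloat16)]
    SHAPES = [(16, 128, 128), (64, 256, 128)]  # (m, k, n) triples

    def combine_parameters(a, b):
        # Same behavior as the helper above, written as a comprehension.
        return [t1 + t2 for t1, t2 in itertools.product(a, b)]

    # Each combined tuple feeds one test case as (device, dtype, m, k, n).
    print(combine_parameters(COMMON_DEVICE_DTYPE, SHAPES)[0])
    # ('cpu', torch.float32, 16, 128, 128)
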
@@ -907,6 +922,30 @@ def test_int8_weight_only_quant_subclass(self, device, dtype):
            Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
        )

+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype
+        )
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
+        )
+
    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
    def test_int4_weight_only_quant_subclass(self, device, dtype):
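
For context, each of these tests hands a `from_float` constructor to `self._test_lin_weight_subclass_impl`, defined elsewhere in this file; the numeric argument (35 or 40) appears to be the minimum SQNR the quantized output must reach. A rough sketch of the underlying pattern, assuming `from_float` wraps an existing float weight tensor and reusing names already available in this file (`Int8WeightOnlyQuantizedLinearWeight`, `SQNR`); the shapes, device, and dtype are illustrative:

    # Wrap a Linear's float weight with a quantized tensor subclass and
    # compare the quantized output against the original float output.
    lin = torch.nn.Linear(128, 256, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)
    ref = lin(x)
    lin.weight = torch.nn.Parameter(
        Int8WeightOnlyQuantizedLinearWeight.from_float(lin.weight),
        requires_grad=False,
    )
    test = lin(x)
    assert SQNR(ref, test) >= 35  # same threshold used by the AQ tests above
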
@@ -1290,6 +1329,74 @@ def test_on_dummy_distilbert(self):
        print("sqnr_pt_quant", sqnr_pt_quant)
        self.assertTrue(sqnr_sq >= 8.0)

+class TestAutoQuant(unittest.TestCase):
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (16, 128, 128),
+            (64, 128, 128),
+            # (2**15, 128, 128), TODO: Runs out of shared memory on T4
+            (16, 128, 256),
+            # (64, 128, 256), # TODO: Runs out of shared memory on T4
+            (16, 256, 128),
+            (64, 256, 128),
+            # (256, 256, 128), TODO: Runs out of shared memory on T4
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_one_input(self, device, dtype, m, k, n):
+        print("(m, k, n): ", (m, k, n))
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m == 1:
+                self.skipTest(f"Shape {(m, k, n)} requires sm80+")
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        out = model(example_input)
+        torchao.autoquant(model, example_input)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
+
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (1, 1, 128, 128),
+            (1, 32, 128, 128),
+            (32, 32, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+            if m1 == 1 or m2 == 1:
+                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        example_input = torch.randn(m1, k, device=device, dtype=dtype)
+        example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
+        torchao.quantization.change_linears_to_autoquantizable(model)
+        out = model(example_input)
+        model(example_input2)
+        torchao.quantization.change_autoquantizable_to_quantized(model)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)

if __name__ == "__main__":
    unittest.main()
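
Taken together, the two TestAutoQuant cases cover the two autoquant entry points used above: the one-shot `torchao.autoquant(model, example_input)` call and the two-step flow for models that see more than one input shape. A condensed sketch of both flows, using only the calls that appear in the tests; the toy model, shapes, and the `make_model` helper are illustrative, and the inline comments describe the intent as inferred from the tests:

    import torch
    import torchao

    def make_model():
        # Throwaway toy model for the sketch; the tests use ReLU -> Linear -> ReLU.
        return torch.nn.Sequential(torch.nn.Linear(128, 256)).to("cuda").to(torch.bfloat16)

    x1 = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)
    x2 = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)

    # Flow 1 (test_autoquant_one_input): pick quantized weights from a single
    # example input; the model is modified in place.
    m = make_model()
    torchao.autoquant(m, x1)
    out = m(x1)

    # Flow 2 (test_autoquant_multi_input): mark Linear weights as
    # auto-quantizable, run every input shape of interest so it is recorded,
    # then finalize the per-layer quantization choice.
    m = make_model()
    torchao.quantization.change_linears_to_autoquantizable(m)
    m(x1)
    m(x2)
    torchao.quantization.change_autoquantizable_to_quantized(m)
    out = m(x1)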