diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 622ec1cbc..3780c1f06 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -164,6 +164,117 @@ model = torch.compile(model, mode='max-autotune') model(input) ``` +## Affine Quantization +Affine quantization refers to the type of quantization that maps from floating point numbers to quantized numbers (typically integer) with an affine transformation, i.e.: `quantized_val = float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data. + +### Quantization Primitives +We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. + +### Quantized Tensor Subclass +We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) + +### Quantization Flow +What we need to do afterwards is roughly the following + +``` +from torchao.dtypes.aqt import to_aq +def apply_int8wo_quant(weight): + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + block_size = (1, weight.shape[1]) + return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + +for n, m in model.named_modules(): + if isinstance(m, torch.nn.Linear): + # optional filtering for module name, shape etc. + m.weight = nn.Parameter(apply_int8wo_quant(m.weight)) + # note: quantization for activation need to be applied after the weight quantization + # quantization activation (needed by dynamic quantization) + # input_quant_func = apply_int8wo_quant # specify how input activation is quantized + # m.weight = nn.Parameter(to_laq(m.weight, input_quant_func)) +``` +The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support +`torch.export.export` and `torch.aot_compile` with the following workaround: +``` +from torchao.quantization.utils import unwrap_tensor_subclass +m_unwrapped = unwrap_tensor_subclass(m) + + +# export +m = torch.export.export(m_unwrapped, example_inputs).module() + +# aot_compile +torch._export.aot_compile(m_unwrapped, example_inputs) +``` + +But we expect this will be integrated into the export path by default in the future. + + +### Example +Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul +as an example: +```python +import torch +from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain +from torchao.dtypes import to_aq +from torch._inductor.runtime.runtime_utils import do_bench_gpu +import copy +from torchao.quantization.quant_api import ( + quantize, + get_apply_int4wo_quant, +) + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): + return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + +dtype = torch.bfloat16 +m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") +m_bf16 = copy.deepcopy(m) +example_inputs = m.example_inputs(dtype=dtype, device="cuda") + +m_bf16 = torch.compile(m_bf16, mode='max-autotune') +# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) +groupsize = 32 +m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize)) + +torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.use_mixed_mm = True + +# temporary workaround for tensor subclass + torch.compile +from torchao.quantization.utils import unwrap_tensor_subclass +m = unwrap_tensor_subclass(m) +# compile the model to improve performance +m = torch.compile(m, mode='max-autotune') + +# benchmark to see the speedup +from torchao.utils import benchmark_model + +num_runs = 100 +torch._dynamo.reset() +bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0]) +print(f"bf16 mean time: {bf16_time}") +int4_time = benchmark_model(m, num_runs, example_inputs[0]) +print(f"int4 weight only quantized mean time: {int4_time}") +print(f"speedup: {bf16_time / int4_time}") + +# output (1xA100 GPU machine) +bf16 mean time: 71.457685546875 +int4 weight only quantized mean time: 31.4580908203125 +speedup: 2.2715200981216173 +``` ## Notes