Update ARM CPU experimental kernels from AO to leverage pip install (#1458)
* update experimental kernels in torchchat
* Update docs/quantization.md
Co-authored-by: Jack-Khuu <[email protected]>
* Update torchchat/utils/quantize.py
Co-authored-by: Jack-Khuu <[email protected]>
* Update torchchat/utils/quantize.py
Co-authored-by: Jack-Khuu <[email protected]>
* Fixing import typo in quantize.py
* Bump ET pin to pick up AO changes
* Bump torchao-pin to match ET and torchchat
* Update torchao-pin.txt
* Split up AOTI and ET tests
* Bump ET pin to 2-26-25 with new AO pin
* Undo et pin bump; fails basic install
* update
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
---------
Co-authored-by: Jack-Khuu <[email protected]>
docs/quantization.md (+10 −6)
```diff
@@ -120,13 +120,15 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 ## Experimental TorchAO lowbit kernels
 
-WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are on a Mac with Apple Silicon, we have 1-8 bit quantization available for embedding and linear layers, backed by CPU and MPS kernels.
+
+The CPU kernels are installed automatically by the torchchat install script and can be used out of the box. To use the MPS kernels, follow the setup instructions below.
 
 ### Use
 
 #### linear:a8wxdq
 The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
-It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
+It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7, 8), groupsize (-1 if channelwise desired), and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
```
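For readers skimming the diff: these scheme arguments are passed through torchchat's `--quantize` JSON config. A minimal sketch of what the documented scheme looks like in practice, assuming a llama3 model alias and illustrative argument values (none of this is taken from the diff itself):

```bash
# Illustrative only: dynamic 8-bit activations with 4-bit groupwise weights.
# The model alias and the bitwidth/groupsize/has_weight_zeros values are examples.
python3 torchchat.py generate llama3 \
  --device cpu --dtype float32 \
  --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 32, "has_weight_zeros": false}}' \
  --prompt "Hello my name is"
```

With these values the scheme is roughly comparable to GGML's Q4_0, per the doc text above.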
```diff
@@ -138,7 +140,9 @@ The quantization scheme embedding:wx quantizes embeddings in a groupwise manner
 You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are using the torchao ops from Python (i.e., not with a C++ runner), they are available out of the box on a Mac with Apple Silicon, and you can skip these setup steps.
+
+If you plan to use the kernels from the AOTI/ExecuTorch C++ runners, follow the setup steps below.
```
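The embedding:wx scheme named in this hunk's context takes bitwidth and groupsize through the same `--quantize` config. A hedged sketch, again with illustrative values, keeping groupsize divisible by 32 to stay on the fast ARM path:

```bash
# Illustrative only: 2-bit groupwise embedding quantization; values are examples.
python3 torchchat.py generate llama3 \
  --device cpu --dtype float32 \
  --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}}' \
  --prompt "Hello my name is"
```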
```diff
-Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
+When building the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
```
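For the C++ runner path, the flag goes to the runner build scripts. A sketch of what that invocation could look like, assuming torchchat's build_native.sh layout (the script path and positional arguments are assumptions, not part of this diff; verify against your checkout):

```bash
# Assumed script path and argument spelling; not taken from this diff.
sh torchchat/utils/scripts/build_native.sh aoti link_torchao_ops  # AOTI runner
sh torchchat/utils/scripts/build_native.sh et link_torchao_ops    # ExecuTorch runner
```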
torchchat/utils/quantize.py

```diff
 # Flag for whether the a8wxdq quantizer is available.
@@ -117,7 +129,45 @@ def quantize_model(
             unwrap_tensor_subclass(model)
             continue
 
-        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
+        if quantizer == "linear:a8wxdq":
+            if get_precision() != torch.float32:
+                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
```