Commit 3c7e839

metascroy and Jack-Khuu authored
Update ARM CPU experimental kernels from AO to leverage pip install (#1458)
* update experimental kernels in torchchat
* Update docs/quantization.md (Co-authored-by: Jack-Khuu <[email protected]>)
* Update torchchat/utils/quantize.py (Co-authored-by: Jack-Khuu <[email protected]>)
* Update torchchat/utils/quantize.py (Co-authored-by: Jack-Khuu <[email protected]>)
* Fixing import typo in quantize.py
* Bump ET pin to pick up AO changes
* Bump torchao-pin to match ET and torchchat
* Update torchao-pin.txt
* Split up AOTI and ET tests
* Bump ET pin to 2-26-25 with new AO pin
* Undo et pin bump; fails basic install
* update
* up (×12)

---------

Co-authored-by: Jack-Khuu <[email protected]>
1 parent cbbbf50 commit 3c7e839

12 files changed: +194 -60 lines

.github/workflows/pull.yml

+61 -20
@@ -292,7 +292,7 @@ jobs:
           echo "::endgroup::"

           echo "::group::Run inference with quantize file"
-          for DEVICE in cpu; do # cuda
+          for DEVICE in cpu; do # cuda
           # cuda - fails because `AttributeError: 'Linear' object has no attribute '_linear_extra_repr'`
           # follow up with torchao as a separate PR
           echo "saving snapshot for device ${DEVICE} and dtype bfloat16, and reloading as snapshot"
@@ -349,7 +349,7 @@ jobs:
           # python3 torchchat.py export --output-snap model.tc --dtype float32 --quantize torchchat/quant_config/cuda-32.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
           # python3 torchchat.py generate --snap model.tc --dtype float32 --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
           # echo "::endgroup::"
-
+
   test-gpu-aoti-float16:
     permissions:
       id-token: write
@@ -1075,7 +1075,7 @@ jobs:
         ./runner/build_android.sh
         echo "Tests complete."

-  test-torchao-aoti-experimental:
+  test-torchao-experimental-python:
     strategy:
       matrix:
         runner: [macos-14-xlarge]
@@ -1107,13 +1107,60 @@ jobs:
           ./install/install_requirements.sh
           pip3 list
           python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
-      - name: Install torchao-ops
-        id: install-torchao-ops
+      - name: Run inference
         run: |
-          bash torchchat/utils/scripts/build_torchao_ops.sh
-      - name: Install runner AOTI
-        id: install-runner-aoti
+          python torchchat.py download stories110M
+          wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+          export PRMT="Once upon a time in a land far away"
+          echo "Generate eager"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
+          echo "Generate compile"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
+          echo "Export AOTI"
+          python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
+          echo "Generate AOTI"
+          python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
+          echo "Tests complete."
+
+  test-torchao-experimental-cpp:
+    strategy:
+      matrix:
+        runner: [macos-14-xlarge]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Setup Xcode
+        if: runner.os == 'macOS'
+        uses: maxim-lobanov/setup-xcode@v1
+        with:
+          xcode-version: '15.3'
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Install torchchat
+        run: |
+          echo "Installing pip3 packages"
+          ./install/install_requirements.sh
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+      - name: Clone torchao
+        id: clone-torchao
+        run: |
+          bash torchchat/utils/scripts/clone_torchao.sh
+      - name: Install runner
         run: |
+          echo "Installing runner"
           bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
       - name: Run inference
         run: |
@@ -1123,11 +1170,9 @@ jobs:
           echo "Export and run AOTI (C++ runner)"
           python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "${PRMT}"
-          echo "Generate AOTI"
-          python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
           echo "Tests complete."

-  test-torchao-et-experimental:
+  test-torchao-experimental-et:
     strategy:
       matrix:
         runner: [macos-14-xlarge]
@@ -1159,15 +1204,15 @@ jobs:
           ./install/install_requirements.sh
           pip3 list
           python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
-      - name: Install torchao-ops
-        id: install-torchao-ops
-        run: |
-          bash torchchat/utils/scripts/build_torchao_ops.sh
       - name: Install ET
         run: |
           echo "Installing ExecuTorch"
           export TORCHCHAT_ROOT=${PWD}
           bash torchchat/utils/scripts/install_et.sh
+      - name: Clone torchao
+        id: clone-torchao
+        run: |
+          bash torchchat/utils/scripts/clone_torchao.sh
       - name: Install runner
         run: |
           echo "Installing runner"
@@ -1177,14 +1222,9 @@ jobs:
           python torchchat.py download stories110M
           wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           export PRMT="Once upon a time in a land far away"
-          echo "Generate eager"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
-          echo "Generate compile"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
           echo "Export and run ET (C++ runner)"
           python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
-          echo "Tests complete."

   test-torchao-experimental-mps:
     strategy:
@@ -1216,6 +1256,7 @@ jobs:
       - name: Install torchao-ops-mps
         id: install-torchao-ops-mps
         run: |
+          bash torchchat/utils/scripts/clone_torchao.sh
           bash torchchat/utils/scripts/build_torchao_ops.sh mps
       - name: Run inference
         run: |

docs/quantization.md

+10 -6
@@ -120,13 +120,15 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n

 ## Experimental TorchAO lowbit kernels

-WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are on a Mac with Apple Silicon, we have 1-8 bit quantization available for embedding and linear layers, backed by CPU and MPS kernels.
+
+The CPU kernels are installed automatically by the torchchat install script and can be used out of the box. To use the MPS kernels, follow the setup instructions below.

 ### Use

 #### linear:a8wxdq
 The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
-It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
+It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7, 8), groupsize (or -1 for channelwise quantization), and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
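For reference, this scheme is exercised end-to-end by the CI jobs in this commit; a representative invocation on the stories110M test model looks like:

```
python torchchat.py generate stories110M --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time in a land far away"
```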

@@ -138,7 +140,9 @@ The quantization scheme embedding:wx quantizes embeddings in a groupwise manner
 You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.

 ### Setup
-To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+If you are using the torchao ops from Python (i.e., not with a C++ runner), they are available out of the box on a Mac with Apple Silicon, and you can skip these setup steps.
+
+If you plan to use the kernels from the AOTI/ExecuTorch C++ runners, follow the setup steps below.

 From the torchchat root directory, run
 ```
@@ -147,7 +151,7 @@ bash torchchat/utils/scripts/build_torchao_ops.sh

 This should take about 10 seconds to complete.

-Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
+When building the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.

 ```
 bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
@@ -175,8 +179,8 @@ OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype fl

 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-aoti-package-path llama3_1.pt2
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --aoti-package-path llama3_1.pt2 --prompt "Once upon a time," --num-samples 5
 ```

 If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:
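The runner command itself falls outside this hunk; for the stories110M artifacts the CI builds above, the invocation is:

```
./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "Once upon a time in a land far away"
```

Here -z points at the tokenizer, -t sets the temperature, and -i passes the prompt, mirroring the --temperature/--prompt flags of the Python CLI.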

install/.pins/torchao-pin.txt

+1 -1
@@ -1 +1 @@
-2e032c6b0de960dee554dcb08126ace718b14c6d
+711fa0809f06fc97febd0c3fe72563c3fe227e51

install/install_requirements.sh

+1 -6
@@ -126,12 +126,7 @@ then
   )
 fi

-# For torchao need to install from github since nightly build doesn't have macos build.
-# TODO: Remove this and install nightly build, once it supports macos
-(
-  set -x
-  $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@7d8794622f3ac7ffa98761314019a20fba06edef
-)
+bash install/install_torchao.sh

 if [[ -x "$(command -v nvidia-smi)" ]]; then
 (

install/install_torchao.sh

+39
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# USE_CPP=1 indicates that the torchao experimental aten kernels will be built and loaded
+# if on Mac with Apple Silicon
+
+if [ -z "${PYTHON_EXECUTABLE:-}" ];
+then
+  if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
+  then
+    PYTHON_EXECUTABLE=python3
+  else
+    PYTHON_EXECUTABLE=python
+  fi
+fi
+echo "Using python executable: $PYTHON_EXECUTABLE"
+
+if [[ "$PYTHON_EXECUTABLE" == "python" ]];
+then
+  PIP_EXECUTABLE=pip
+elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
+then
+  PIP_EXECUTABLE=pip3
+else
+  PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
+fi
+echo "Using pip executable: $PIP_EXECUTABLE"
+
+
+export TORCHAO_PIN=$(cat install/.pins/torchao-pin.txt)
+(
+  set -x
+  USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN}
+)
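install_requirements.sh now delegates to this script (see the hunk above); it can also be run standalone from the torchchat root:

```
bash install/install_torchao.sh
```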

torchchat/export.py

+2 -1
@@ -439,7 +439,8 @@ def main(args):
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
         support_tensor_subclass=output_dso_path is None
-        and output_aoti_package_path is None,
+        and output_aoti_package_path is None
+        and output_pte_path is None,
     )
     model_to_pte = model
     model_to_dso = model

torchchat/utils/quantize.py

+54 -19
@@ -50,6 +50,18 @@
     state_dict_device,
     use_et_backend,
 )
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+    IntxWeightEmbeddingQuantizer,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.dtypes import PlainLayout


 # Flag for whether the a8wxdq quantizer is available.
@@ -117,7 +129,45 @@ def quantize_model(
             unwrap_tensor_subclass(model)
             continue

-        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
+        if quantizer == "linear:a8wxdq":
+            if get_precision() != torch.float32:
+                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                set_precision(torch.float32)
+
+            group_size = q_kwargs["groupsize"]
+            bit_width = q_kwargs["bitwidth"]
+            has_weight_zeros = q_kwargs["has_weight_zeros"]
+            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+            weight_dtype = getattr(torch, f"int{bit_width}")
+
+            try:
+                quantize_(
+                    model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                    ),
+                )
+            except Exception as e:
+                print(f"Encountered error during quantization: {e}")
+                print("Trying with PlainLayout")
+                quantize_(
+                    model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PlainLayout(),
+                    ),
+                )
+
+            if not support_tensor_subclass:
+                unwrap_tensor_subclass(model)
+            continue
+
+        if quantizer == "embedding:wx":
             # These quantizers require float32 input weights. Note that after quantization,
             # the weights will no longer be float32, but lowbit integers
             if get_precision() != torch.float32:
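For illustration, here is a minimal standalone sketch of the quantize_ path this branch adds, using the same config the CI exercises ({"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}). The toy nn.Sequential model is a hypothetical stand-in, and the sketch assumes a torchao build (at the pinned commit) exposing these experimental APIs:

```
import torch
import torch.nn as nn

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.quant_api import quantize_

# Hypothetical stand-in for the torchchat model; in_features is a multiple
# of the groupsize so the packed kernels apply.
model = nn.Sequential(nn.Linear(256, 256))

q_kwargs = {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": False}

# groupsize == -1 selects channelwise (PerRow) granularity, otherwise groupwise.
granularity = (
    PerRow() if q_kwargs["groupsize"] == -1 else PerGroup(q_kwargs["groupsize"])
)

# quantize_ rewrites the model in place; quantize.py falls back to PlainLayout()
# if the packed layout raises.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=getattr(torch, f"int{q_kwargs['bitwidth']}"),  # torch.int3
        granularity=granularity,
        has_weight_zeros=q_kwargs["has_weight_zeros"],
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)
```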
@@ -889,10 +939,12 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
+    "embedding:wx": IntxWeightEmbeddingQuantizer,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
     "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:a8wxdq": None,  # uses quantize_ API
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }

@@ -915,27 +967,10 @@ def quantized_model(self) -> nn.Module:
     torchao_experimental_quant_api_spec.loader.exec_module(
         torchao_experimental_quant_api
     )
-    from torchao_experimental_quant_api import (
-        Int8DynActIntxWeightLinearQuantizer,
-        IntxWeightEmbeddingQuantizer,
-        UIntxWeightOnlyLinearQuantizer,
-    )
-
-    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
-    quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
+    from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer
     quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer

     # Try loading custom op
-    try:
-        import glob
-
-        libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
-        libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
-        torch.ops.load_library(libs[0])
-        print("Loaded torchao cpu ops.")
-    except Exception as e:
-        print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")
-
     try:
         libname = "libtorchao_ops_mps_aten.dylib"
         libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
