Update ARM CPU experimental kernels from AO to leverage pip install #1458

Merged: 29 commits, Mar 11, 2025

Commits
- bdac616 update experimental kernels in torchchat (metascroy, Jan 15, 2025)
- 74363e4 Update docs/quantization.md (metascroy, Jan 16, 2025)
- 48f568d Update torchchat/utils/quantize.py (metascroy, Jan 16, 2025)
- 525701d Update torchchat/utils/quantize.py (metascroy, Jan 16, 2025)
- f9a7bb9 Fixing import typo in quantize.py (Jack-Khuu, Jan 17, 2025)
- 0abe175 Bump ET pin to pick up AO changes (Jack-Khuu, Jan 18, 2025)
- 95304b8 Merge branch 'main' into new-intx-quantizer (Jack-Khuu, Feb 11, 2025)
- 76e8ec5 Bump torchao-pin to match ET and torchchat (Jack-Khuu, Feb 11, 2025)
- c2108d6 Merge branch 'main' into new-intx-quantizer (Jack-Khuu, Feb 20, 2025)
- 4ee1b96 Merge branch 'main' into new-intx-quantizer (Jack-Khuu, Feb 24, 2025)
- 61a1c62 Merge branch 'main' into new-intx-quantizer (Jack-Khuu, Feb 26, 2025)
- 3e04645 Update torchao-pin.txt (Jack-Khuu, Feb 27, 2025)
- 94fcd9a Split up AOTI and ET tests (Jack-Khuu, Feb 27, 2025)
- 7e56c55 Bump ET pin to 2-26-25 with new AO pin (Jack-Khuu, Feb 27, 2025)
- 77e8a62 Undo et pin bump; fails basic install (Jack-Khuu, Feb 27, 2025)
- 67dd729 Merge branch 'main' into new-intx-quantizer (Jack-Khuu, Mar 5, 2025)
- 94ad51a update (metascroy, Mar 11, 2025)
- 34cb931 up (metascroy, Mar 11, 2025)
- b564fc1 up (metascroy, Mar 11, 2025)
- 9eed5d1 up (metascroy, Mar 11, 2025)
- 14365c4 up (metascroy, Mar 11, 2025)
- 66d90e1 up (metascroy, Mar 11, 2025)
- 12cbd13 up (metascroy, Mar 11, 2025)
- 28d1a99 up (metascroy, Mar 11, 2025)
- d2cc25a up (metascroy, Mar 11, 2025)
- d79f870 up (metascroy, Mar 11, 2025)
- a8106fd up (metascroy, Mar 11, 2025)
- aa6fb70 up (metascroy, Mar 11, 2025)
- 8a9a644 up (metascroy, Mar 11, 2025)
81 changes: 61 additions & 20 deletions .github/workflows/pull.yml
@@ -292,7 +292,7 @@ jobs:
echo "::endgroup::"

echo "::group::Run inference with quantize file"
for DEVICE in cpu; do # cuda
# cuda - fails because `AttributeError: 'Linear' object has no attribute '_linear_extra_repr'`
# follow up with torchao as a separate PR
echo "saving snapshot for device ${DEVICE} and dtype bfloat16, and reloading as snapshot"
@@ -349,7 +349,7 @@ jobs:
# python3 torchchat.py export --output-snap model.tc --dtype float32 --quantize torchchat/quant_config/cuda-32.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
# python3 torchchat.py generate --snap model.tc --dtype float32 --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
# echo "::endgroup::"

test-gpu-aoti-float16:
permissions:
id-token: write
@@ -1075,7 +1075,7 @@ jobs:
./runner/build_android.sh
echo "Tests complete."

test-torchao-aoti-experimental:
test-torchao-experimental-python:
strategy:
matrix:
runner: [macos-14-xlarge]
@@ -1107,13 +1107,60 @@ jobs:
./install/install_requirements.sh
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
- name: Install torchao-ops
id: install-torchao-ops
- name: Run inference
run: |
bash torchchat/utils/scripts/build_torchao_ops.sh
- name: Install runner AOTI
id: install-runner-aoti
python torchchat.py download stories110M
wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
export PRMT="Once upon a time in a land far away"
echo "Generate eager"
python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
echo "Generate compile"
python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
echo "Export AOTI"
python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
echo "Generate AOTI"
python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
echo "Tests complete."

test-torchao-experimental-cpp:
strategy:
matrix:
runner: [macos-14-xlarge]
runs-on: ${{matrix.runner}}
steps:
- name: Checkout repo
uses: actions/checkout@v3
with:
submodules: true
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: 3.10.11
- name: Setup Xcode
if: runner.os == 'macOS'
uses: maxim-lobanov/setup-xcode@v1
with:
xcode-version: '15.3'
- name: Print machine info
run: |
uname -a
if [ $(uname -s) == Darwin ]; then
sysctl machdep.cpu.brand_string
sysctl machdep.cpu.core_count
fi
- name: Install torchchat
run: |
echo "Intalling pip3 packages"
./install/install_requirements.sh
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
- name: Clone torchao
id: clone-torchao
run: |
bash torchchat/utils/scripts/clone_torchao.sh
- name: Install runner
run: |
echo "Installing runner"
bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
- name: Run inference
run: |
@@ -1123,11 +1170,9 @@ jobs:
echo "Export and run AOTI (C++ runner)"
python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "${PRMT}"
echo "Generate AOTI"
python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
echo "Tests complete."

test-torchao-et-experimental:
test-torchao-experimental-et:
strategy:
matrix:
runner: [macos-14-xlarge]
@@ -1159,15 +1204,15 @@ jobs:
./install/install_requirements.sh
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
- name: Install torchao-ops
id: install-torchao-ops
run: |
bash torchchat/utils/scripts/build_torchao_ops.sh
- name: Install ET
run: |
echo "Installing ExecuTorch"
export TORCHCHAT_ROOT=${PWD}
bash torchchat/utils/scripts/install_et.sh
- name: Clone torchao
id: clone-torchao
run: |
bash torchchat/utils/scripts/clone_torchao.sh
- name: Install runner
run: |
echo "Installing runner"
@@ -1177,14 +1222,9 @@ jobs:
python torchchat.py download stories110M
wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
export PRMT="Once upon a time in a land far away"
echo "Generate eager"
python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
echo "Generate compile"
python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
echo "Export and run ET (C++ runner)"
python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
echo "Tests complete."

test-torchao-experimental-mps:
strategy:
@@ -1216,6 +1256,7 @@ jobs:
- name: Install torchao-ops-mps
id: install-torchao-ops-mps
run: |
bash torchchat/utils/scripts/clone_torchao.sh
bash torchchat/utils/scripts/build_torchao_ops.sh mps
- name: Run inference
run: |
16 changes: 10 additions & 6 deletions docs/quantization.md
@@ -120,13 +120,15 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n

## Experimental TorchAO lowbit kernels

WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
If you are on a Mac with Apple Silicon, we have 1-8 bit quantization available for embedding and linear layers, backed by CPU and MPS kernels.

The CPU kernels are installed automatically by the torchchat install script and can be used out of the box. To use the MPS kernels, follow the setup instructions below.

### Use

#### linear:a8wxdq
The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7, 8), groupsize (-1 if channelwise desired), and has_weight_zeros (true, false).
The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
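To make the weight half of this scheme concrete, here is a minimal sketch of scale-only groupwise quantization (has_weight_zeros: false). It illustrates the arithmetic only; the function name is hypothetical and this is not the torchao kernel implementation:

```
import torch

def quantize_groupwise_scale_only(weight: torch.Tensor, bitwidth: int, groupsize: int):
    # Hypothetical helper sketching linear:a8wxdq's weight quantization:
    # each group of `groupsize` values shares one scale, with no zero points.
    qmax = 2 ** (bitwidth - 1) - 1              # e.g. 7 for bitwidth=4
    qmin = -(2 ** (bitwidth - 1))               # e.g. -8 for bitwidth=4
    groups = weight.reshape(-1, groupsize)
    scales = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-9) / qmax
    quants = torch.clamp(torch.round(groups / scales), qmin, qmax).to(torch.int8)
    return quants.reshape(weight.shape), scales

w = torch.randn(64, 128)                        # each 128-wide row splits into four groups of 32
q, scales = quantize_groupwise_scale_only(w, bitwidth=4, groupsize=32)
dequant = (q.reshape(-1, 32).float() * scales).reshape(w.shape)  # approximate reconstruction
```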

@@ -138,7 +140,9 @@ The quantization scheme embedding:wx quantizes embeddings in a groupwise manner
You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
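As a concrete reference, the CI workflows in this PR exercise embedding:wx with bitwidth 2 and groupsize 32, which satisfies the divisible-by-32 guidance. A standalone invocation, assuming the scheme works without a linear quantizer alongside it, would look like:

```
python torchchat.py generate stories110M --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}}' --prompt "Once upon a time"
```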

### Setup
To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
If you are using the torchao ops from Python (i.e., not with a C++ runner), they are available out of the box on a Mac with Apple Silicon, and you can skip these setup steps.

If you plan to use the kernels from the AOTI/ExecuTorch C++ runners, follow the setup steps below.

From the torchchat root directory, run
```
Expand All @@ -147,7 +151,7 @@ bash torchchat/utils/scripts/build_torchao_ops.sh

This should take about 10 seconds to complete.

Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
When building the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.

```
bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
```
@@ -175,8 +179,8 @@ OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype fl

#### AOTI
```
OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-aoti-package-path llama3_1.pt2
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --aoti-package-path llama3_1.pt2 --prompt "Once upon a time," --num-samples 5
```

If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:
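The runner invocation follows the same pattern as the CI workflow in this PR; a sketch, assuming the export above produced llama3_1.pt2 and that tokenizer.model is in the current directory:

```
./cmake-out/aoti_run ./llama3_1.pt2 -z ./tokenizer.model -t 0 -i "Once upon a time,"
```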
2 changes: 1 addition & 1 deletion install/.pins/torchao-pin.txt
@@ -1 +1 @@
2e032c6b0de960dee554dcb08126ace718b14c6d
711fa0809f06fc97febd0c3fe72563c3fe227e51
7 changes: 1 addition & 6 deletions install/install_requirements.sh
@@ -126,12 +126,7 @@ then
)
fi

# For torchao need to install from github since nightly build doesn't have macos build.
# TODO: Remove this and install nightly build, once it supports macos
(
set -x
$PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@7d8794622f3ac7ffa98761314019a20fba06edef
)
bash install/install_torchao.sh

if [[ -x "$(command -v nvidia-smi)" ]]; then
(
39 changes: 39 additions & 0 deletions install/install_torchao.sh
@@ -0,0 +1,39 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# USE_CPP=1 indicates that the torchao experimental aten kernels will be built and loaded
# if on Mac with Apple Silicon

if [ -z "${PYTHON_EXECUTABLE:-}" ];
then
if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
then
PYTHON_EXECUTABLE=python3
else
PYTHON_EXECUTABLE=python
fi
fi
echo "Using python executable: $PYTHON_EXECUTABLE"

if [[ "$PYTHON_EXECUTABLE" == "python" ]];
then
PIP_EXECUTABLE=pip
elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
then
PIP_EXECUTABLE=pip3
else
PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
fi
echo "Using pip executable: $PIP_EXECUTABLE"


export TORCHAO_PIN=$(cat install/.pins/torchao-pin.txt)
(
set -x
USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN}
)
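Since install_requirements.sh now delegates to this script, a quick post-install sanity check is to import the pinned package, in the same style as the CI workflow's torch check; a minimal sketch, assuming the wheel exposes the standard version attribute:

```
python3 -c 'import torchao; print(f"torchao: {torchao.__version__}")'
```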
3 changes: 2 additions & 1 deletion torchchat/export.py
@@ -439,7 +439,8 @@ def main(args):
tokenizer,
max_seq_length=builder_args.max_seq_length,
support_tensor_subclass=output_dso_path is None
and output_aoti_package_path is None,
and output_aoti_package_path is None
and output_pte_path is None,
)
model_to_pte = model
model_to_dso = model
73 changes: 54 additions & 19 deletions torchchat/utils/quantize.py
@@ -50,6 +50,18 @@
state_dict_device,
use_et_backend,
)
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import (
int8_dynamic_activation_intx_weight,
IntxWeightEmbeddingQuantizer,
)
from torchao.quantization.granularity import (
PerGroup,
PerRow,
)
from torchao.dtypes import PlainLayout


# Flag for whether the a8wxdq quantizer is available.
@@ -117,7 +129,45 @@ def quantize_model(
unwrap_tensor_subclass(model)
continue

if quantizer in ["linear:a8wxdq", "embedding:wx"]:
if quantizer == "linear:a8wxdq":
if get_precision() != torch.float32:
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
set_precision(torch.float32)

group_size = q_kwargs["groupsize"]
bit_width = q_kwargs["bitwidth"]
has_weight_zeros = q_kwargs["has_weight_zeros"]
granularity = PerRow() if group_size == -1 else PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bit_width}")

try:
quantize_(
model,
int8_dynamic_activation_intx_weight(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
),
)
except Exception as e:
print("Encountered error during quantization: {e}")
print("Trying with PlainLayout")
quantize_(
model,
int8_dynamic_activation_intx_weight(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
layout=PlainLayout(),
),
)

if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue

if quantizer == "embedding:wx":
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32, but lowbit integers
if get_precision() != torch.float32:
@@ -889,10 +939,12 @@ def quantized_model(self) -> nn.Module:
# class references
quantizer_class_dict = {
"embedding": EmbeddingOnlyQuantHandler,
"embedding:wx": IntxWeightEmbeddingQuantizer,
"linear:int8": WeightOnlyInt8QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:a8wxdq": None, # uses quantize_ API
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}

@@ -915,27 +967,10 @@ def quantized_model(self) -> nn.Module:
torchao_experimental_quant_api_spec.loader.exec_module(
torchao_experimental_quant_api
)
from torchao_experimental_quant_api import (
Int8DynActIntxWeightLinearQuantizer,
IntxWeightEmbeddingQuantizer,
UIntxWeightOnlyLinearQuantizer,
)

quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
from torchao_experimental_quant_api import UIntxWeightOnlyLinearQuantizer
quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer

# Try loading custom op
try:
import glob

libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])
print("Loaded torchao cpu ops.")
except Exception as e:
print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")

try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
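For reference, the new quantize_ path above amounts to the following standalone usage; a minimal sketch, assuming the torchao pin from this PR, mirroring the arguments quantize_model passes for {bitwidth: 4, groupsize: 32, has_weight_zeros: false}:

```
import torch
import torch.nn as nn
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# linear:a8wxdq requires float32 weights (quantize_model enforces this above).
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128)).to(torch.float32)

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,      # bitwidth 4, via getattr(torch, f"int{bit_width}")
        granularity=PerGroup(32),     # groupsize 32; PerRow() when groupsize == -1
        has_weight_zeros=False,       # scales only, no zero points
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)

out = model(torch.randn(2, 128, dtype=torch.float32))
```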