diff --git a/README.md b/README.md index 0da273f91..64b05b9b5 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![](https://dcbadge.vercel.app/api/server/gpumode?style=flat)](https://discord.gg/gpumode) -[Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Composability](#composability) | [Custom Kernels](#custom-kernels) | [Alpha Features](#alpha-features) | [Installation](#installation) | [Integrations](#integrations) | [Videos](#videos) | [License](#license) | [Citation](#citation) +[Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Installation](#installation) |[Composability](#composability) | [Custom Kernels](#custom-kernels) | [Prototype Features](#prototype-features) | [Integrations](#integrations) | [Videos](#videos) | [License](#license) | [Citation](#citation) ## Introduction @@ -17,21 +17,29 @@ torchao just works with `torch.compile()` and `FSDP2` over most PyTorch models o ## Inference +Our optimizations deliver significant speedups and memory savings: + +- **INT4 Weight-Only Quantization**: 2x higher throughput (201 vs 107 tokens/sec) with 65% less memory (4.9GB vs 13.9GB) on LLaMA-2-7B +- **Float8 Dynamic Quantization**: Demonstrates 53.88% speedup on Flux.1-Dev* and 27.33% speedup on CogVideoX-5b on H100 GPU while preserving image quality +- **INT4 + 2:4 Sparsity**: 2.4x throughput increase (226 vs 95 tokens/sec) with 80% memory reduction (5.3GB vs 16.4GB peak memory) on LLaMA-3-8B + +For detailed benchmarks across models and techniques, see our [quantization documentation](torchao/quantization/README.md). + ### Post Training Quantization -Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/torchao/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py) +Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/README.md), sparsity [here](torchao/sparsity/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py) For inference, we have the option of 1. Quantize only the weights: works best for memory bound models 2. Quantize the weights and activations: works best for compute bound models -2. Quantize the activations and weights and sparsify the weight +3. Quantize the activations and weights and sparsify the weight ```python -from torchao.quantization.quant_api import ( +from torchao.quantization import ( quantize_, int8_dynamic_activation_int8_weight, + float8_dynamic_activation_float8_weight, int4_weight_only, - int8_weight_only ) quantize_(m, int4_weight_only()) ``` @@ -52,13 +60,14 @@ We also provide a developer facing API so you can implement your own quantizatio We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference. -In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md) +In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md#kv-cache-quantization---memory-efficient-inference) + ## Training ### Quantization Aware Training -Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md). +Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with [Torchtune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/) ```python from torchao.quantization import ( @@ -107,6 +116,8 @@ We've added support for semi-structured 2:4 sparsity with **6% end-to-end speedu The code change is a 1 liner with the full example available [here](torchao/sparsity/training/) ```python +from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear + swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear}) ``` @@ -128,68 +139,65 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True) optim.load_state_dict(ckpt["optim"]) ``` -## Composability +## Installation -1. `torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels. -3. [FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. +`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch, see [getting started](https://pytorch.org/get-started/locally/) for more details. -The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you. +Install the stable release (recommended): +```bash +pip install torchao +``` -## Custom Kernels +Other options: +```bash +# Nightly build +pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu124 -We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow +# Different CUDA versions +pip install torchao --index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8 +pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only -1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))` -2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 -3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference +``` -If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) +### Development Install +``` +USE_CPP=0 python setup.py develop # Skip C++/CUDA extensions +``` +## Composability +`torch.compile`: A key design principle for us is composability - any custom dtype or memory layout should work with our compiler. We enable kernel implementations in PyTorch, CUDA, C++, or Triton. This allows researchers and engineers to start with high-level dtype and layout logic in pure PyTorch, then progressively optimize performance by implementing lower-level kernels as needed, while maintaining compatibility with the compile infrastructure. -## Alpha features +[FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. -Things we're excited about but need more time to cook in the oven +The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you. -1. [MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet. -2. [Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet. -3. [IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype -4. [Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support. +Our framework makes it straightforward to add tensor parallel support to your custom quantized tensor subclass. Check out our [tensor parallel tutorial](tutorials/developer_api_guide/tensor_parallel.py) to see how a quantized tensor subclass can be extended to support column and row-wise tensor sharding while maintaining compatibility with `torch.compile`. -## Installation +## Custom Kernels -`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch. +We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow -Stable release from Pypi which will default to CUDA 12.1 +1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))` +2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 +3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference -```Shell -pip install torchao -``` +If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) or feel free to contribute directly to the repo. -Stable Release from the PyTorch index -```Shell -pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124 -``` -Nightly Release -```Shell -pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124 -``` +## Prototype Features -For *most* developers you probably want to skip building custom C++/CUDA extensions for faster iteration +Check out our [prototype directory](torchao/prototype/README.md) where we experiment with cutting-edge model optimization techniques for both training and inference. If you're interested in contributing experimental research or want to explore feel free to open an issue or PR. -```Shell -USE_CPP=0 pip install -e . -``` ## OSS Integrations We're also fortunate to be integrated into some of the leading open-source libraries including 1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865) -2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) +2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md) 3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) -4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes +4. [TorchTune](https://pytorch.org/torchtune/main/tutorials/qlora_finetune.html?highlight=qlora) for our QLoRA and QAT recipes 5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization -6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341). +6. SGLang for LLM serving: [usage](https://docs.sglang.ai/backend/server_arguments.html#server-arguments) and the major [PR](https://github.com/sgl-project/sglang/pull/1341). ## Videos * [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)