mit-han-lab/nunchaku

Nunchaku is a high-performance inference engine optimized for 4-bit neural networks, as introduced in our paper SVDQuant. For the underlying quantization library, check out DeepCompressor.

Join our user groups on Slack, Discord and WeChat to engage in discussions with the community! More details can be found here. If you have any questions, run into issues, or are interested in contributing, don’t hesitate to reach out!

News

  • [2025-04-16] 🎥 Released tutorial videos in both English and Chinese to assist installation and usage.
  • [2025-04-09] 📢 Published the April roadmap and an FAQ to help the community get started and stay up to date with Nunchaku’s development.
  • [2025-04-05] 🚀 Nunchaku v0.2.0 released! This release brings multi-LoRA and ControlNet support with even faster performance powered by FP16 attention and First-Block Cache. We've also added compatibility for 20-series GPUs — Nunchaku is now more accessible than ever!
  • [2025-03-17] 🚀 Released NVFP4 4-bit Shuttle-Jaguar and FLUX.1-tools and also upgraded the INT4 FLUX.1-tool models. Download and update your models from our HuggingFace or ModelScope collections!
  • [2025-03-13] 📦 Separated the ComfyUI node into a standalone repository for easier installation and released node v0.1.6! Plus, 4-bit Shuttle-Jaguar is now fully supported!
  • [2025-03-07] 🚀 Nunchaku v0.1.4 released! It adds a 4-bit text encoder and per-layer CPU offloading, reducing FLUX's minimum memory requirement to just 4 GiB while maintaining a 2–3× speedup. This update also fixes various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
  • [2025-02-20] 🚀 Pre-built wheels are now available to simplify installation! Check here for guidance!
  • [2025-02-20] 🚀 NVFP4 precision is now supported on the NVIDIA RTX 5090! NVFP4 delivers superior image quality compared to INT4, offering ~3× speedup on the RTX 5090 over BF16. Learn more in our blog, check out the examples for usage, and try our demo online!
  • [2025-02-18] 🔥 Customized LoRA conversion and model quantization instructions are now available! ComfyUI workflows now support customized LoRA, along with FLUX.1-Tools!
  • [2025-02-11] 🎉 SVDQuant has been selected as an ICLR 2025 Spotlight! FLUX.1-tools Gradio demos are now available! Check here for the usage details! Our new depth-to-image demo is also online, so try it out!
  • [2025-02-04] 🚀 4-bit FLUX.1-tools is here! Enjoy a 2-3× speedup over the original models. Check out the examples for usage. ComfyUI integration is coming soon!
  • [2025-01-23] 🚀 4-bit SANA support is here! Experience a 2-3× speedup compared to the 16-bit model. Check out the usage example and the deployment guide for more details. Explore our live demo at svdquant.mit.edu!
  • [2025-01-22] 🎉 SVDQuant has been accepted to ICLR 2025!
  • [2024-12-08] Support ComfyUI. Please check mit-han-lab/ComfyUI-nunchaku for the usage.
  • [2024-11-07] 🔥 Our latest W4A4 Diffusion model quantization work SVDQuant is publicly released! Check DeepCompressor for the quantization library.

Overview

SVDQuant is a post-training quantization technique for 4-bit weights and activations that preserves visual fidelity well. On the 12B FLUX.1-dev model, it reduces memory usage by 3.6× compared to the BF16 model. By eliminating CPU offloading, it runs 8.7× faster than the 16-bit model on a 16GB laptop 4090 GPU and 3× faster than the NF4 W4A16 baseline. On PixArt-Σ, it delivers significantly better visual quality than other W4A4 and even W4A8 baselines. "E2E" denotes the end-to-end latency, including the text encoder and VAE decoder.

SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
Muyang Li*, Yujun Lin*, Zhekai Zhang*, Tianle Cai, Xiuyu Li, Junxian Guo, Enze Xie, Chenlin Meng, Jun-Yan Zhu, and Song Han
MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs

Method

Quantization Method -- SVDQuant

Overview of SVDQuant. Stage 1: Originally, both the activations $\boldsymbol{X}$ and weights $\boldsymbol{W}$ contain outliers, making 4-bit quantization challenging. Stage 2: We migrate the outliers from activations to weights, resulting in the updated activations $\hat{\boldsymbol{X}}$ and weights $\hat{\boldsymbol{W}}$. While $\hat{\boldsymbol{X}}$ becomes easier to quantize, $\hat{\boldsymbol{W}}$ now becomes more difficult. Stage 3: SVDQuant further decomposes $\hat{\boldsymbol{W}}$ into a low-rank component $\boldsymbol{L}_1\boldsymbol{L}_2$ and a residual $\hat{\boldsymbol{W}}-\boldsymbol{L}_1\boldsymbol{L}_2$ with SVD. Thus, the quantization difficulty is alleviated by the low-rank branch, which runs at 16-bit precision.
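
The three stages can be summarized as a single identity. This is a sketch in the caption's notation; the per-channel smoothing factor $\boldsymbol{\lambda}$, the residual symbol $\boldsymbol{R}$, and the 4-bit quantizer $Q(\cdot)$ are symbols introduced here for illustration and do not appear in the caption.

```latex
\begin{aligned}
\hat{\boldsymbol{X}} &= \boldsymbol{X}\,\mathrm{diag}(\boldsymbol{\lambda})^{-1},
\qquad
\hat{\boldsymbol{W}} = \mathrm{diag}(\boldsymbol{\lambda})\,\boldsymbol{W},
\qquad
\boldsymbol{R} = \hat{\boldsymbol{W}} - \boldsymbol{L}_1\boldsymbol{L}_2, \\
\boldsymbol{X}\boldsymbol{W}
&= \hat{\boldsymbol{X}}\hat{\boldsymbol{W}}
 = \hat{\boldsymbol{X}}\boldsymbol{L}_1\boldsymbol{L}_2 + \hat{\boldsymbol{X}}\boldsymbol{R}
 \approx \underbrace{\hat{\boldsymbol{X}}\boldsymbol{L}_1\boldsymbol{L}_2}_{\text{16-bit low-rank branch}}
 + \underbrace{Q(\hat{\boldsymbol{X}})\,Q(\boldsymbol{R})}_{\text{4-bit branch}}
\end{aligned}
```

Only the last step is approximate: the first two equalities hold exactly, so all quantization error comes from $Q(\hat{\boldsymbol{X}})\,Q(\boldsymbol{R})$, where the residual $\boldsymbol{R}$ is easier to quantize than $\hat{\boldsymbol{W}}$ itself.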

Nunchaku Engine Design

(a) Naïvely running the low-rank branch with rank 32 introduces 57% latency overhead due to an extra read of the 16-bit inputs in Down Projection and an extra write of the 16-bit outputs in Up Projection. Nunchaku reduces this overhead with kernel fusion. (b) The Down Projection and Quantize kernels share the same input, while the Up Projection and 4-Bit Compute kernels share the same output. To reduce data movement overhead, we fuse the first two kernels together and the latter two kernels together.

Performance

SVDQuant reduces the 12B FLUX.1 model size by 3.6× and cuts the 16-bit model's memory usage by 3.5×. With Nunchaku, our INT4 model runs 3.0× faster than the NF4 W4A16 baseline on both desktop and laptop NVIDIA RTX 4090 GPUs. Notably, on the laptop 4090, it achieves a total 10.1× speedup by eliminating CPU offloading. Our NVFP4 model is also 3.1× faster than both BF16 and NF4 on the RTX 5090 GPU.
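
A back-of-the-envelope sketch of why the size reduction lands somewhat below the ideal 4× for 4-bit weights: each quantized layer also stores quantization scales and the 16-bit low-rank branch. The rank of 32 comes from the engine description above; the group size of 64, the 16-bit scales, and the 3072-wide layer shapes are assumptions made purely for illustration, not measured properties of the released models.

```python
def bits_per_weight(d_in, d_out, rank=32, group_size=64,
                    weight_bits=4, scale_bits=16, branch_bits=16):
    """Average storage cost, in bits, per original weight of one linear layer."""
    n_weights = d_in * d_out
    quantized = n_weights * weight_bits              # 4-bit residual weights
    scales = (n_weights / group_size) * scale_bits   # one scale per group (assumed)
    low_rank = rank * (d_in + d_out) * branch_bits   # L1 (d_in x r) and L2 (r x d_out)
    return (quantized + scales + low_rank) / n_weights


def compression_ratio(d_in, d_out, baseline_bits=16, **kwargs):
    """Size of the 16-bit layer divided by the size of the quantized layer."""
    return baseline_bits / bits_per_weight(d_in, d_out, **kwargs)


if __name__ == "__main__":
    # Hypothetical layer shapes: a square projection and a 4x-wide MLP layer.
    for shape in [(3072, 3072), (3072, 12288)]:
        print(shape, round(compression_ratio(*shape), 2))
```

Under these assumptions the per-layer ratio comes out at roughly 3.5× to 3.6×, in line with the reported figures; the larger the layer, the smaller the relative cost of the low-rank branch.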

Installation

We provide tutorial videos to help you install and use Nunchaku on Windows, available in both English and Chinese. You can also follow the corresponding step-by-step text guide at docs/setup_windows.md. If you run into issues, these resources are a good place to start.

Wheels

Prerequisites

Before installation, ensure you have PyTorch>=2.5 installed. For example, you can use the following command to install PyTorch 2.6:

pip install torch==2.6 torchvision==0.21 torchaudio==2.6

Install nunchaku

Once PyTorch is installed, you can directly install nunchaku from Hugging Face, ModelScope or GitHub release. Be sure to select the appropriate wheel for your Python and PyTorch version. For example, for Python 3.11 and PyTorch 2.6:

pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.2.0+torch2.6-cp311-cp311-linux_x86_64.whl
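
To pick the right wheel for the current environment, the file name can be assembled from the Python and PyTorch versions. The sketch below infers the naming pattern from the single example URL above; the helper name `wheel_filename` is hypothetical, and platform tags other than `linux_x86_64` are not confirmed here, so check the release page for the files that actually exist.

```python
import sys


def wheel_filename(nunchaku_version, torch_version,
                   python_version=None, platform_tag="linux_x86_64"):
    """Build a wheel file name following the pattern of the example URL.

    python_version is a (major, minor) tuple; it defaults to the running
    interpreter, which must be the same one used for `pip install`.
    """
    if python_version is None:
        python_version = sys.version_info[:2]
    major, minor = python_version
    py_tag = f"cp{major}{minor}"  # e.g. Python 3.11 -> cp311
    return (f"nunchaku-{nunchaku_version}+torch{torch_version}"
            f"-{py_tag}-{py_tag}-{platform_tag}.whl")


if __name__ == "__main__":
    print(wheel_filename("0.2.0", "2.6", (3, 11)))
```

If PyTorch is installed, `torch.__version__` reports the version to match against the `+torch…` part of the name.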

For ComfyUI Users

If you're using the ComfyUI portable package, make sure to install nunchaku into the correct Python environment bundled with ComfyUI. To find the right Python path, launch ComfyUI and check the log output. You'll see something like this in the first several lines:

** Python executable: G:\ComfyUI\python\python.exe

Use that Python executable to install nunchaku:

"G:\ComfyUI\python\python.exe" -m pip install <your-wheel-file>.whl

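Example: installing for Python 3.11 and PyTorch 2.6. The URL below points to the Linux wheel; on Windows, substitute the matching Windows wheel from the same release: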

"G:\ComfyUI\python\python.exe" -m pip install https://github.com/mit-han-lab/nunchaku/releases/download/v0.2.0/nunchaku-0.2.0+torch2.6-cp311-cp311-linux_x86_64.whl

For Blackwell GPUs (50-series)

If you're using a Blackwell GPU (e.g., 50-series GPUs), install a wheel built for PyTorch 2.7 or higher. Additionally, use FP4 models instead of INT4 models.

Build from Source

Note:

  • Make sure your CUDA version is at least 12.2 on Linux and at least 12.6 on Windows. If you're using a Blackwell GPU (e.g., 50-series GPUs), CUDA 12.8 or higher is required.

  • For Windows users, please refer to this issue for instructions. Please upgrade your MSVC compiler to the latest version.

  • We currently support only NVIDIA GPUs with architectures sm_75 (Turing: RTX 2080), sm_86 (Ampere: RTX 3090, A6000), sm_89 (Ada: RTX 4090), and sm_80 (A100). See this issue for more details.
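
The requirements in these notes can be encoded as a small pre-flight check. This is a minimal sketch: the helper names are hypothetical, the table only mirrors the architectures and CUDA versions listed above, and Blackwell is handled through an explicit flag because it is not part of that architecture list.

```python
# Compute capabilities listed in the note above, as (major, minor) tuples.
LISTED_ARCHS = {
    (7, 5): "Turing (RTX 2080)",
    (8, 0): "Ampere (A100)",
    (8, 6): "Ampere (RTX 3090, A6000)",
    (8, 9): "Ada (RTX 4090)",
}


def is_listed_arch(capability):
    """True if the (major, minor) compute capability appears in the list."""
    return tuple(capability) in LISTED_ARCHS


def min_cuda_version(system, blackwell=False):
    """Minimum CUDA version from the notes, as a (major, minor) tuple."""
    if blackwell:
        return (12, 8)
    return (12, 6) if system == "Windows" else (12, 2)


def cuda_ok(cuda_version, system, blackwell=False):
    """Compare a version string such as '12.4' against the minimum."""
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    return (major, minor) >= min_cuda_version(system, blackwell)
```

With PyTorch installed, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple, `torch.version.cuda` gives the CUDA version string, and `platform.system()` returns `"Linux"` or `"Windows"`.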

  1. Install dependencies:

    conda create -n nunchaku python=3.11
    conda activate nunchaku
    pip install torch torchvision torchaudio
    pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
    
    # For gradio demos
    pip install peft opencv-python gradio spaces GPUtil  

    To enable NVFP4 on Blackwell GPUs (e.g., 50-series GPUs), please install nightly PyTorch with CUDA 12.8. The installation command can be:

    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128

  2. Install nunchaku package: Make sure you have gcc/g++>=11. If you don't, you can install it via Conda on Linux:

    conda install -c conda-forge gxx=11 gcc=11

    For Windows users, you can download and install the latest Visual Studio.

    Then build the package from source with

    git clone https://github.com/mit-han-lab/nunchaku.git
    cd nunchaku
    git submodule init
    git submodule update
    python setup.py develop

    If you are building wheels for distribution, use:

    NUNCHAKU_INSTALL_MODE=ALL NUNCHAKU_BUILD_WHEELS=1 python -m build --wheel --no-isolation

    Make sure to set the environment variable NUNCHAKU_INSTALL_MODE to ALL. Otherwise, the generated wheels will only work on GPUs with the same architecture as the build machine.

Usage Example

In examples, we provide minimal scripts for running INT4 FLUX.1 and SANA models with Nunchaku. Nunchaku shares the same APIs as diffusers and can be used in a similar way. For example, the script for FLUX.1-dev is as follows:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # auto-detects whether to use 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save(f"flux.1-dev-{precision}.png")

Note: If you're using a Turing GPU (e.g., NVIDIA 20-series), make sure to set torch_dtype=torch.float16 and use our nunchaku-fp16 attention module as below. A complete example is available in examples/flux.1-dev-turing.py.
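
The choice described in this note can be made from the GPU's compute capability. The helper below is a hypothetical sketch with no dependency on PyTorch or Nunchaku: it returns setting names as strings, assumes Turing corresponds to compute capability 7.5, and leaves the attention implementation at its default (`None`) for every other GPU.

```python
from typing import Dict, Optional, Tuple


def runtime_settings(capability: Tuple[int, int]) -> Dict[str, Optional[str]]:
    """Map a CUDA compute capability to the settings named in the note."""
    if tuple(capability) == (7, 5):  # Turing, e.g. NVIDIA 20-series
        return {"torch_dtype": "float16", "attention_impl": "nunchaku-fp16"}
    # Other GPUs: BF16 as in the example script; attention left unchanged.
    return {"torch_dtype": "bfloat16", "attention_impl": None}


if __name__ == "__main__":
    print(runtime_settings((7, 5)))
    print(runtime_settings((8, 9)))
```

In a real script, `getattr(torch, settings["torch_dtype"])` yields the dtype to pass as `torch_dtype`, and a non-`None` `attention_impl` is passed to `transformer.set_attention_impl(...)`. The complete, authoritative version is examples/flux.1-dev-turing.py.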

FP16 Attention

In addition to FlashAttention-2, Nunchaku introduces a custom FP16 attention implementation that runs up to 1.2× faster on NVIDIA 30-, 40-, and even 50-series GPUs without loss of precision. To enable it, simply use:

transformer.set_attention_impl("nunchaku-fp16")

See examples/flux.1-dev-fp16attn.py for a complete example.

First-Block Cache

Nunchaku supports First-Block Cache to accelerate long-step denoising. Enable it easily with:

apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)

You can tune the residual_diff_threshold to balance speed and quality: larger values yield faster inference at the cost of some quality. A recommended value is 0.12, which provides up to 2× speedup for 50-step denoising and 1.4× speedup for 30-step denoising. See the full example in examples/flux.1-dev-cache.py.
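
To make the role of residual_diff_threshold concrete, here is a simplified, list-based sketch of the general idea behind first-block caching: run only the first block, compare its residual with the one from the last fully computed step, and reuse the cached contribution of the remaining blocks when the relative change is below the threshold. This is not Nunchaku's implementation; the class, the function names, and the exact difference metric are assumptions for illustration.

```python
from typing import Callable, List, Optional

Vector = List[float]
Block = Callable[[Vector], Vector]


def relative_diff(prev: Vector, curr: Vector) -> float:
    """Mean absolute change of `curr` relative to the mean magnitude of `prev`."""
    num = sum(abs(c - p) for p, c in zip(prev, curr))
    den = sum(abs(p) for p in prev)
    return num / den if den > 0 else float("inf")


class FirstBlockCache:
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.prev_first_residual: Optional[Vector] = None
        self.cached_rest_residual: Optional[Vector] = None
        self.hits = 0  # number of steps that skipped the remaining blocks

    def step(self, x: Vector, first_block: Block, remaining_blocks: Block) -> Vector:
        h = first_block(x)
        first_residual = [a - b for a, b in zip(h, x)]
        if (self.prev_first_residual is not None
                and relative_diff(self.prev_first_residual, first_residual) < self.threshold):
            # Cache hit: skip the remaining blocks and reuse their residual.
            self.hits += 1
            return [a + r for a, r in zip(h, self.cached_rest_residual)]
        # Cache miss: run everything and refresh the cache.
        out = remaining_blocks(h)
        self.prev_first_residual = first_residual
        self.cached_rest_residual = [o - a for o, a in zip(out, h)]
        return out
```

A larger threshold accepts bigger changes as "similar", so more steps are skipped, which matches the speed-for-quality trade-off described above.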

CPU Offloading

To minimize GPU memory usage, Nunchaku supports CPU offloading—requiring as little as 4 GiB of GPU memory. You can enable it by setting offload=True when initializing NunchakuFluxTransformer2dModel, and then calling:

pipeline.enable_sequential_cpu_offload()

For a complete example, refer to examples/flux.1-dev-offload.py.

Customized LoRA

SVDQuant seamlessly integrates with off-the-shelf LoRAs without requiring requantization. You can simply use your LoRA with:

transformer.update_lora_params(path_to_your_lora)
transformer.set_lora_strength(lora_strength)

path_to_your_lora can also be a remote HuggingFace path. In examples/flux.1-dev-lora.py, we provide a minimal example script for running Ghibsky LoRA with SVDQuant's 4-bit FLUX.1-dev:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # auto-detects whether to use 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

### LoRA Related Code ###
transformer.update_lora_params(
    "aleksa-codes/flux-ghibsky-illustration/lora.safetensors"
)  # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(1)  # Your LoRA strength here
### End of LoRA Related Code ###

image = pipeline(
    "GHIBSKY style, cozy mountain cabin covered in snow, with smoke curling from the chimney and a warm, inviting light spilling through the windows",  # noqa: E501
    num_inference_steps=25,
    guidance_scale=3.5,
).images[0]
image.save(f"flux.1-dev-ghibsky-{precision}.png")

To combine multiple LoRAs, use nunchaku.lora.flux.compose.compose_lora:

from nunchaku.lora.flux.compose import compose_lora

composed_lora = compose_lora(
    [
        ("PATH_OR_STATE_DICT_OF_LORA1", lora_strength1),
        ("PATH_OR_STATE_DICT_OF_LORA2", lora_strength2),
        # Add more LoRAs as needed
    ]
)  # set each LoRA's strength here when composing LoRAs
transformer.update_lora_params(composed_lora)

You can specify individual strengths for each LoRA in the list. For a complete example, refer to examples/flux.1-dev-multiple-lora.py.
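
One way to see why several LoRAs with individual strengths can be merged into a single update is that a weighted sum of low-rank products equals one product of concatenated factors. The toy sketch below demonstrates that identity on plain Python lists; whether compose_lora works exactly this way internally is an assumption, and `compose_low_rank` and `matmul` are hypothetical helpers, not part of Nunchaku.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product for small list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def compose_low_rank(loras: Sequence[Tuple[Matrix, Matrix, float]]) -> Tuple[Matrix, Matrix]:
    """Merge LoRAs given as (up, down, strength) into one (up, down) pair.

    up is d_out x r_i and down is r_i x d_in. The merged pair has rank
    sum(r_i) and satisfies up_cat @ down_cat == sum(s_i * up_i @ down_i).
    """
    d_out = len(loras[0][0])
    up_cat: Matrix = [[] for _ in range(d_out)]
    down_cat: Matrix = []
    for up, down, strength in loras:
        for merged_row, row in zip(up_cat, up):
            merged_row.extend(strength * v for v in row)  # fold strength into up
        down_cat.extend(list(row) for row in down)        # stack along the rank axis
    return up_cat, down_cat


if __name__ == "__main__":
    lora1 = ([[1.0], [2.0]], [[3.0, 4.0]], 0.5)  # rank-1 LoRA, strength 0.5
    lora2 = ([[1.0], [0.0]], [[0.0, 1.0]], 2.0)  # rank-1 LoRA, strength 2.0
    up, down = compose_low_rank([lora1, lora2])
    print(matmul(up, down))
```

The merged update has rank equal to the sum of the individual ranks, so composing many LoRAs increases the size of the low-rank branch accordingly.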

For ComfyUI users, you can directly use our LoRA loader. The converted LoRA is deprecated. Please refer to mit-han-lab/ComfyUI-nunchaku for more details.

ControlNets

Nunchaku supports both the FLUX.1-tools and the FLUX.1-dev-ControlNet-Union-Pro models. Example scripts can be found in the examples directory.

ComfyUI

Please refer to mit-han-lab/ComfyUI-nunchaku for the usage in ComfyUI.

Gradio Demos

Customized Model Quantization

Please refer to mit-han-lab/deepcompressor. A simpler workflow is coming soon.

Benchmark

Please refer to app/flux/t2i/README.md for instructions on reproducing our paper's quality results and benchmarking inference latency on FLUX.1 models.

Roadmap

Please check here for the roadmap for April.

Citation

If you find nunchaku useful or relevant to your research, please cite our paper:

@inproceedings{
  li2024svdquant,
  title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
  author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025}
}

Related Projects

Contact Us

For enterprises interested in adopting SVDQuant or Nunchaku, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at [email protected].

Acknowledgments

We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Packard Foundation, Dell, LG, Hyundai, and Samsung for supporting this research. We thank NVIDIA for donating the DGX server.

We use img2img-turbo to train the sketch-to-image LoRA. Our text-to-image and image-to-image UIs are built upon playground-v2.5 and img2img-turbo, respectively. Our safety checker is borrowed from hart.

Nunchaku is also inspired by many open-source libraries, including (but not limited to) TensorRT-LLM, vLLM, QServe, AWQ, FlashAttention-2, and Atom.