Skip to content

Latest commit

 

History

History
340 lines (234 loc) · 19.8 KB

README_ZH.md

File metadata and controls

340 lines (234 loc) · 19.8 KB
**Nunchaku** 是一款专为4-bit神经网络优化的高性能推理引擎,基于我们的论文 [SVDQuant](http://arxiv.org/abs/2411.05007) 提出。底层量化库请参考 [DeepCompressor](https://github.com/mit-han-lab/deepcompressor)。

欢迎加入我们的用户群:SlackDiscord微信,与社区交流!更多详情请见此处。如有任何问题、建议或贡献意向,欢迎随时联系!

最新动态

更多动态

项目概览

teaser SVDQuant 是一种支持4-bit权重和激活的后训练量化技术,能有效保持视觉质量。在12B FLUX.1-dev模型上,相比BF16模型实现了3.6倍内存压缩。通过消除CPU offloading,在16GB笔记本RTX 4090上比16位模型快8.7倍,比NF4 W4A16基线快3倍。在PixArt-∑模型上,其视觉质量显著优于其他W4A4甚至W4A8方案。"E2E"表示包含文本编码器和VAE解码器的端到端延迟。

SVDQuant: 通过低秩分量吸收异常值实现4-bit扩散模型量化
Muyang Li*, Yujun Lin*, Zhekai Zhang*, Tianle Cai, Xiuyu Li, Junxian Guo, Enze Xie, Chenlin Meng, Jun-Yan Zhu, Song Han
麻省理工学院、英伟达、卡内基梅隆大学、普林斯顿大学、加州大学伯克利分校、上海交通大学、pika实验室

方法原理

量化方法 -- SVDQuant

intuitionSVDQuant三阶段示意图。阶段1:原始激活 $\boldsymbol{X}$ 和权重 $\boldsymbol{W}$ 均含异常值,4-bit量化困难。阶段2:将激活异常值迁移至权重,得到更易量化的激活 $\hat{\boldsymbol{X}}$ 和更难量化的权重 $\hat{\boldsymbol{W}}$ 。阶段3:通过SVD将 $\hat{\boldsymbol{W}}$ 分解为低秩分量 $\boldsymbol{L}_1\boldsymbol{L}_2$ 和残差 $\hat{\boldsymbol{W}}-\boldsymbol{L}_1\boldsymbol{L}_2$ ,低秩分支以16位精度运行缓解量化难度。

Nunchaku引擎设计

engine (a) 原始低秩分支(秩32)因额外读写16位数据引入57%的延迟。Nunchaku通过核融合优化。(b) 将下投影与量化、上投影与4-bit计算分别融合,减少数据搬运。

性能表现

efficiencySVDQuant 将12B FLUX.1模型的体积压缩了3.6倍,同时将原始16位模型的显存占用减少了3.5倍。借助Nunchaku,我们的INT4模型在桌面和笔记本的NVIDIA RTX 4090 GPU上比NF4 W4A16基线快了3.0倍。值得一提的是,在笔记本4090上,通过消除CPU offloading,总体加速达到了10.1倍。我们的NVFP4模型在RTX 5090 GPU上也比BF16和NF4快了3.1倍。

安装指南

我们提供了在 Windows 上安装和使用 Nunchaku 的教学视频,支持英文中文两个版本。同时,你也可以参考对应的图文教程 docs/setup_windows.md。如果在安装过程中遇到问题,建议优先查阅这些资源。

Wheel包安装

前置条件

确保已安装 PyTorch>=2.5。例如:

pip install torch==2.6 torchvision==0.21 torchaudio==2.6

安装nunchaku

Hugging FaceModelScopeGitHub release选择对应Python和PyTorch版本的wheel。例如Python 3.11和PyTorch 2.6:

pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.2.0+torch2.6-cp311-cp311-linux_x86_64.whl
ComfyUI用户

若使用ComfyUI便携包,请确保将nunchaku安装到ComfyUI自带的Python环境。查看ComfyUI日志获取Python路径:

** Python executable: G:\ComfyuI\python\python.exe

使用该Python安装wheel:

"G:\ComfyUI\python\python.exe" -m pip install <your-wheel-file>.whl

示例:为Python 3.11和PyTorch 2.6安装:

"G:\ComfyUI\python\python.exe" -m pip install https://github.com/mit-han-lab/nunchaku/releases/download/v0.2.0/nunchaku-0.2.0+torch2.6-cp311-cp311-linux_x86_64.whl
Blackwell显卡用户(50系列)

若使用Blackwell显卡(如50系列),请安装PyTorch 2.7及以上版本,并使用FP4模型

源码编译

注意

  • Linux需CUDA≥12.2,Windows需CUDA≥12.6。Blackwell显卡需CUDA≥12.8。
  • Windows用户请参考此问题升级MSVC编译器。
  • 支持SM_75(Turing:RTX 2080)、SM_86(Ampere:RTX 3090)、SM_89(Ada:RTX 4090)、SM_80(A100)架构显卡,详见此问题
  1. 安装依赖:

    conda create -n nunchaku python=3.11
    conda activate nunchaku
    pip install torch torchvision torchaudio
    pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
    
    # Gradio演示依赖
    pip install peft opencv-python gradio spaces GPUtil  

    Blackwell用户需安装PyTorch nightly(CUDA 12.8):

    pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
  2. 编译安装: 确保gcc/g++≥11。Linux用户可通过Conda安装:

    conda install -c conda-forge gxx=11 gcc=11

    Windows用户请安装最新Visual Studio

    编译命令:

    git clone https://github.com/mit-han-lab/nunchaku.git
    cd nunchaku
    git submodule init
    git submodule update
    python setup.py develop

    打包wheel:

    NUNCHAKU_INSTALL_MODE=ALL NUNCHAKU_BUILD_WHEELS=1 python -m build --wheel --no-isolation

    设置NUNCHAKU_INSTALL_MODE=ALL确保wheel支持所有显卡架构。

使用示例

示例中,我们提供了运行4-bitFLUX.1SANA模型的极简脚本,API与diffusers兼容。例如FLUX.1-dev脚本:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # 自动检测GPU支持的精度(int4或fp4)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("举着'Hello World'标牌的猫咪", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save(f"flux.1-dev-{precision}.png")

注意:**Turing显卡用户(如20系列)**需设置torch_dtype=torch.float16并使用nunchaku-fp16注意力模块,完整示例见examples/flux.1-dev-turing.py

FP16 Attention

除FlashAttention-2外,Nunchaku提供定制FP16 Attention实现,在30/40/50系显卡上提速1.2倍且无损精度。启用方式:

transformer.set_attention_impl("nunchaku-fp16")

完整示例见examples/flux.1-dev-fp16attn.py

First-Block Cache

Nunchaku支持First-Block Cache加速长步去噪。启用方式:

apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)

residual_diff_threshold越大速度越快但可能影响质量,推荐值0.12,50步推理提速2倍,30步提速1.4倍。完整示例见examples/flux.1-dev-cache.py

CPU offloading

最小化显存占用至4 GiB,设置offload=True并启用CPU offloading:

pipeline.enable_sequential_cpu_offload()

完整示例见examples/flux.1-dev-offload.py

自定义LoRA

lora

SVDQuant 可以无缝集成现有的 LoRA,而无需重新量化。你可以简单地通过以下方式使用你的 LoRA:

transformer.update_lora_params(path_to_your_lora)
transformer.set_lora_strength(lora_strength)

path_to_your_lora 也可以是一个远程的 HuggingFace 路径。在 examples/flux.1-dev-lora.py 中,我们提供了一个运行 Ghibsky LoRA 的最小示例脚本,结合了 SVDQuant 的 4-bit FLUX.1-dev:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # 自动检测你的精度是 'int4' 还是 'fp4',取决于你的 GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

### LoRA 相关代码 ###
transformer.update_lora_params(
    "aleksa-codes/flux-ghibsky-illustration/lora.safetensors"
)  # 你的 LoRA safetensors 路径,也可以是远程 HuggingFace 路径
transformer.set_lora_strength(1)  # 在这里设置你的 LoRA 强度
### LoRA 相关代码结束 ###

image = pipeline(
    "GHIBSKY 风格,被雪覆盖的舒适山间小屋,烟囱里冒出袅袅炊烟,窗户透出温暖诱人的灯光",  # noqa: E501
    num_inference_steps=25,
    guidance_scale=3.5,
).images[0]
image.save(f"flux.1-dev-ghibsky-{precision}.png")

如果需要组合多个 LoRA,可以使用 nunchaku.lora.flux.compose.compose_lora 来实现组合。用法如下:

composed_lora = compose_lora(
    [
        ("PATH_OR_STATE_DICT_OF_LORA1", lora_strength1),
        ("PATH_OR_STATE_DICT_OF_LORA2", lora_strength2),
        # 根据需要添加更多 LoRA
    ]
)  # 在使用组合 LoRA 时在此处设置每个 LoRA 的强度
transformer.update_lora_params(composed_lora)

你可以为列表中的每个 LoRA 指定单独的强度。完整的示例请参考 examples/flux.1-dev-multiple-lora.py

对于 ComfyUI 用户,你可以直接使用我们的 LoRA 加载器。转换后的 LoRA 已被弃用,请参考 mit-han-lab/ComfyUI-nunchaku 获取更多详细信息。

ControlNets

Nunchaku 支持 FLUX.1-toolsFLUX.1-dev-ControlNet-Union-Pro 模型。示例脚本可以在 examples 目录中找到。

control

ComfyUI

请参考 mit-han-lab/ComfyUI-nunchaku 获取在 ComfyUI 中的使用方法。

使用演示

自定义模型量化

请参考 mit-han-lab/deepcompressor。更简单的流程即将推出。

基准测试

请参考 app/flux/t2i/README.md 获取重现我们论文质量结果和对 FLUX.1 模型进行推理延迟基准测试的说明。

路线图

请查看 此处 获取四月的路线图。

引用

如果你觉得 nunchaku 对你的研究有用或相关,请引用我们的论文:

@inproceedings{
  li2024svdquant,
  title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
  author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025}
}

相关项目

联系我们

对于有兴趣采用 SVDQuant 或 Nunchaku 的企业,包括技术咨询、赞助机会或合作意向,请联系 [email protected]

致谢

感谢 MIT-IBM Watson AI Lab、MIT 和Amazon Science Hub、MIT AI Hardware Program、National Science Foundation、Packard Foundation、Dell、LG、Hyundai和Samsung对本研究的支持。感谢 NVIDIA 捐赠 DGX 服务器。

我们使用 img2img-turbo 训练草图生成图像的 LoRA。我们的文生图和图像生成用户界面基于 playground-v.25img2img-turbo 构建。我们的安全检查器来自 hart

Nunchaku 还受到许多开源库的启发,包括(但不限于)TensorRT-LLMvLLMQServeAWQFlashAttention-2Atom