Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion models/zimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
from nunchaku.utils import pad_tensor


def add_comfy_cast_weights_attr(svdq_linear: "SVDQW4A4Linear", comfy_linear: "nn.Linear"):
    """
    Add dummy ``comfy_cast_weights`` and ``weight`` fields to a SVDQW4A4Linear module.

    This makes the quantized linear module compatible with the offloading
    mechanism of ComfyUI's ``ModelPatcher`` class in low-VRAM conditions.

    Parameters
    ----------
    svdq_linear : SVDQW4A4Linear
        The quantized replacement module to patch in place.
    comfy_linear : nn.Linear
        The original ComfyUI linear module whose offloading flag is mirrored.

    Note
    ----
    See method `comfy.model_patcher.ModelPatcher#_load_list` and method
    `comfy.model_patcher.ModelPatcher#load`, which iterate modules and expect
    both attributes to be present on offloadable layers.
    """
    # Only mirror the flag when the source module actually carries it;
    # a plain nn.Linear outside ComfyUI's cast machinery is left untouched.
    if hasattr(comfy_linear, "comfy_cast_weights"):
        svdq_linear.comfy_cast_weights = comfy_linear.comfy_cast_weights
        # Dummy attribute only: ModelPatcher's load path probes `.weight`;
        # the quantized module stores its real weights elsewhere.
        svdq_linear.weight = None


def fuse_to_svdquant_linear(comfy_linear1: nn.Linear, comfy_linear2: nn.Linear, **kwargs) -> SVDQW4A4Linear:
"""
Fuse two linear modules into one SVDQW4A4Linear.
Expand All @@ -45,14 +60,16 @@ def fuse_to_svdquant_linear(comfy_linear1: nn.Linear, comfy_linear2: nn.Linear,
assert comfy_linear1.in_features == comfy_linear2.in_features
assert comfy_linear1.bias is None and comfy_linear2.bias is None
torch_dtype = kwargs.pop("torch_dtype", comfy_linear1.weight.dtype)
return SVDQW4A4Linear(
svdq_linear = SVDQW4A4Linear(
comfy_linear1.in_features,
comfy_linear1.out_features + comfy_linear2.out_features,
bias=False,
torch_dtype=torch_dtype,
device=comfy_linear1.weight.device,
**kwargs,
)
add_comfy_cast_weights_attr(svdq_linear, comfy_linear1)
return svdq_linear


def fused_qkv_norm_rotary(
Expand Down Expand Up @@ -114,7 +131,9 @@ def __init__(self, orig_attn: JointAttention, **kwargs):
self.head_dim = orig_attn.head_dim

self.qkv = SVDQW4A4Linear.from_linear(orig_attn.qkv, **kwargs)
add_comfy_cast_weights_attr(self.qkv, orig_attn.qkv)
self.out = SVDQW4A4Linear.from_linear(orig_attn.out, **kwargs)
add_comfy_cast_weights_attr(self.out, orig_attn.out)

self.q_norm = orig_attn.q_norm
self.k_norm = orig_attn.k_norm
Expand Down Expand Up @@ -180,6 +199,7 @@ def __init__(self, orig_ff: FeedForward, **kwargs):
super().__init__()
self.w13 = fuse_to_svdquant_linear(orig_ff.w1, orig_ff.w3, **kwargs)
self.w2 = SVDQW4A4Linear.from_linear(orig_ff.w2, **kwargs)
add_comfy_cast_weights_attr(self.w2, orig_ff.w2)

def _forward_silu_gating(self, x1, x3):
return clamp_fp16(F.silu(x1) * x3)
Expand Down
12 changes: 2 additions & 10 deletions nodes/models/zimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ def _load(sd: dict[str, torch.Tensor], metadata: dict[str, str] = {}):
if len(temp_sd) > 0:
sd = temp_sd

parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)

load_device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()
check_hardware_compatibility(quantization_config, load_device)
Expand All @@ -137,13 +134,8 @@ def _load(sd: dict[str, torch.Tensor], metadata: dict[str, str] = {}):
model_config = NunchakuZImage(rank=rank, precision=precision, skip_refiners=skip_refiners)

if not is_turing():
unet_weight_dtype = list(model_config.supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(
model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype
)
manual_cast_dtype = model_management.unet_manual_cast(
unet_dtype, load_device, model_config.supported_inference_dtypes
)
unet_dtype = torch.bfloat16
manual_cast_dtype = None
torch_dtype = torch.bfloat16
else:
unet_dtype = torch.bfloat16
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [ "setuptools>=61", "wheel" ]

[project]
name = "comfyui-nunchaku"
version = "1.2.0"
version = "1.2.1"
description = "Nunchaku ComfyUI Node. Nunchaku is a high-performance inference engine optimized for low-bit neural networks. See more details: https://github.com/nunchaku-tech/nunchaku"
license = { file = "LICENCE.txt" }
requires-python = ">=3.10,<3.14"
Expand Down