Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion models/zimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
from nunchaku.utils import pad_tensor


def add_comfy_cast_weights_attr(svdq_linear: SVDQW4A4Linear, comfy_linear: nn.Linear):
    """
    Add dummy ``comfy_cast_weights`` and ``weight`` fields to a SVDQW4A4Linear module.

    This makes the quantized module compatible with the offloading mechanism of
    ComfyUI's ``ModelPatcher`` class under low-VRAM conditions: the patcher
    inspects these attributes when building its load list, so both must exist
    on the module. ``weight`` is set to ``None`` as a placeholder — presumably
    because the offload code probes ``module.weight`` directly; confirm against
    the ModelPatcher implementation.

    Parameters
    ----------
    svdq_linear : SVDQW4A4Linear
        The quantized linear module to patch in place.
    comfy_linear : nn.Linear
        The original ComfyUI linear layer. If it carries a
        ``comfy_cast_weights`` flag, that flag is copied onto ``svdq_linear``;
        otherwise this function is a no-op.

    Note
    ----
    See method ``comfy.model_patcher.ModelPatcher#_load_list`` and method
    ``comfy.model_patcher.ModelPatcher#load``.
    """
    if hasattr(comfy_linear, "comfy_cast_weights"):
        svdq_linear.comfy_cast_weights = comfy_linear.comfy_cast_weights
        svdq_linear.weight = None


def fuse_to_svdquant_linear(comfy_linear1: nn.Linear, comfy_linear2: nn.Linear, **kwargs) -> SVDQW4A4Linear:
"""
Fuse two linear modules into one SVDQW4A4Linear.
Expand All @@ -45,14 +60,16 @@ def fuse_to_svdquant_linear(comfy_linear1: nn.Linear, comfy_linear2: nn.Linear,
assert comfy_linear1.in_features == comfy_linear2.in_features
assert comfy_linear1.bias is None and comfy_linear2.bias is None
torch_dtype = kwargs.pop("torch_dtype", comfy_linear1.weight.dtype)
return SVDQW4A4Linear(
svdq_linear = SVDQW4A4Linear(
comfy_linear1.in_features,
comfy_linear1.out_features + comfy_linear2.out_features,
bias=False,
torch_dtype=torch_dtype,
device=comfy_linear1.weight.device,
**kwargs,
)
add_comfy_cast_weights_attr(svdq_linear, comfy_linear1)
return svdq_linear


def fused_qkv_norm_rotary(
Expand Down Expand Up @@ -114,7 +131,9 @@ def __init__(self, orig_attn: JointAttention, **kwargs):
self.head_dim = orig_attn.head_dim

self.qkv = SVDQW4A4Linear.from_linear(orig_attn.qkv, **kwargs)
add_comfy_cast_weights_attr(self.qkv, orig_attn.qkv)
self.out = SVDQW4A4Linear.from_linear(orig_attn.out, **kwargs)
add_comfy_cast_weights_attr(self.out, orig_attn.out)

self.q_norm = orig_attn.q_norm
self.k_norm = orig_attn.k_norm
Expand Down Expand Up @@ -180,6 +199,7 @@ def __init__(self, orig_ff: FeedForward, **kwargs):
super().__init__()
self.w13 = fuse_to_svdquant_linear(orig_ff.w1, orig_ff.w3, **kwargs)
self.w2 = SVDQW4A4Linear.from_linear(orig_ff.w2, **kwargs)
add_comfy_cast_weights_attr(self.w2, orig_ff.w2)

def _forward_silu_gating(self, x1, x3):
    """Compute ``clamp_fp16(silu(x1) * x3)`` — the SwiGLU-style gate of the FFN."""
    gated = F.silu(x1) * x3
    return clamp_fp16(gated)
Expand Down