Conversation


@phatwila phatwila commented Dec 26, 2025

Enable LoRA support for Z-Image (NextDiT/Lumina2) models, including multi-LoRA fusion, state reset, and test workflows.

Summary of Changes

  • SVDQuant Linear Wrapper
    Implemented _LoRALinear in wrappers/zimage.py to correctly apply LoRA weights to W4A4 SVDQuant linear modules (a minimal sketch of the idea appears after this list).

  • Fusion Order Parity
    Added _fuse_qkv_lora and _fuse_w13_lora logic to preserve parity with Z-Image fused projection layouts, ensuring LoRA composition matches the model’s native weight fusion order.

  • Robust Weight Management
    Implemented reset_lora to safely restore original unpatched modules between workflow executions, preventing LoRA state leakage across runs.

  • Node Integration
    Fully integrated ComfyZImageWrapper into NunchakuZImageDiTLoader, enabling automatic LoRA patching through standard ComfyUI workflows.
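
To make the first bullet concrete, here is a minimal sketch of the wrapper idea, assuming the quantized base layer is left untouched and low-rank deltas are simply added to its output; the class and attribute names below are illustrative, not the PR's exact implementation.

```python
import torch
import torch.nn as nn
from typing import List, Tuple


class LoRALinearSketch(nn.Module):
    """Illustrative stand-in for _LoRALinear: wraps a (quantized) linear layer
    and adds low-rank deltas without touching the packed W4A4 weights."""

    def __init__(self, base: nn.Module):
        super().__init__()
        self.base = base  # e.g. an SVDQW4A4Linear, treated as a black box here
        self.loras: List[Tuple[torch.Tensor, torch.Tensor, float]] = []  # (A, B, strength)

    def add_lora(self, A: torch.Tensor, B: torch.Tensor, strength: float = 1.0) -> None:
        # A: (rank, in_features), B: (out_features, rank)
        self.loras.append((A, B, strength))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)  # quantized path, unchanged
        for A, B, s in self.loras:
            # delta = (x @ A.T) @ B.T, scaled by the per-LoRA strength
            out = out + s * ((x @ A.transpose(-1, -2)) @ B.transpose(-1, -2))
        return out
```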


Motivation

Z-Image turbo models (NextDiT/Lumina2) are commonly deployed in Nunchaku-quantized form for performance reasons, but this previously prevented the use of LoRAs due to incompatible weight formats and fused projection layers.

This PR enables full LoRA support for these models within ComfyUI, allowing users to apply stylistic and domain LoRAs (for example anime, illustration, realism or vintage cartoon styles) while retaining the performance and memory advantages of W4A4 quantization.

The goal is functional parity with standard ComfyUI LoRA workflows, without requiring users to load or maintain full-precision model variants.


Modifications

wrappers/zimage.py

  • Added _LoRALinear, a wrapper for applying LoRA deltas to SVDQuant linear layers.
  • Implemented _fuse_qkv_lora and _fuse_w13_lora to correctly merge LoRA weights into Z-Image fused projection matrices (see the sketch after this list).
  • Added reset_lora to restore original module state between executions.
  • Added compose_loras to manage multi-LoRA composition when standard patching is not applicable.
  • Introduced ComfyZImageWrapper, which manages LoRA state and applies composition during forward passes when required.
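
Since the fused projections store q, k and v (or w1 and w3) stacked along the output dimension, LoRA factors trained against the separate projections have to be reassembled in that same order before they can be applied. A rough sketch of the idea behind _fuse_qkv_lora, with hypothetical shapes and a hypothetical function name:

```python
import torch


def fuse_qkv_lora_sketch(q_A, q_B, k_A, k_B, v_A, v_B):
    """Hypothetical illustration of fusing per-projection LoRA factors so the
    resulting delta lines up with a fused qkv weight of shape (3 * dim, in_dim).

    Each *_A has shape (rank, in_dim); each *_B has shape (dim, rank). The A
    factors are stacked along the rank axis and the B factors are placed
    block-diagonally, so the q/k/v deltas land in their own output slices in
    the model's native q-then-k-then-v fusion order.
    """
    A = torch.cat([q_A, k_A, v_A], dim=0)   # (3 * rank, in_dim)
    B = torch.block_diag(q_B, k_B, v_B)     # (3 * dim, 3 * rank)
    return A, B  # delta_W = B @ A has the fused (3 * dim, in_dim) layout
```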

nodes/lora/zimage.py

  • Added NunchakuZImageLoraLoader, a dedicated LoRA loader node for Z-Image models.
  • Uses copy_with_ctx to clone model wrappers with attached LoRA state, preserving standard ComfyUI model cloning semantics.
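
For readers less familiar with ComfyUI's node plumbing, a loader node of this general shape typically looks like the skeleton below. Only the node's role, the MODEL input/output, and the use of copy_with_ctx come from the PR description; everything else (inputs, ranges, category) is a hedged guess, not the actual implementation.

```python
import folder_paths  # ComfyUI helper for locating files under models/loras


class ZImageLoraLoaderSketch:
    """Hedged skeleton of a Z-Image LoRA loader node (not the PR's exact code)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"),),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora"
    CATEGORY = "Nunchaku"

    def load_lora(self, model, lora_name, strength):
        lora_path = folder_paths.get_full_path("loras", lora_name)
        # Clone so upstream nodes keep their own LoRA state (the PR does this
        # via copy_with_ctx); then record the LoRA for the wrapper to compose.
        new_model = model.clone()
        # ... attach (lora_path, strength) to the cloned wrapper's LoRA list ...
        return (new_model,)
```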

nodes/models/zimage.py

  • Updated NunchakuZImageDiTLoader to wrap the diffusion model with ComfyZImageWrapper, enabling transparent LoRA support at load time.
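
In practice this means the loader swaps in the wrapper right after the diffusion model object is built, so downstream nodes transparently see the LoRA-capable module. A two-line sketch of the pattern (names from the PR description; the surrounding loader code is omitted):

```python
# Inside the loader, after ComfyUI has constructed the model object:
model.diffusion_model = ComfyZImageWrapper(model.diffusion_model)  # LoRA-aware wrapper
```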

Tests

  • Added a new test workflow under tests/workflows/nunchaku-z-image-turbo-lora/ demonstrating a vintage cartoon style LoRA.
  • Test parameters: seed 88888, steps 9, shift 7, Euler sampler.
  • Reference images are uploaded and registered in test_cases.json.

Compatibility Fix

  • Fixed model_patcher.load() to return (self,), ensuring correct state tracking by ComfyUI’s execution engine.
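
Assuming ComfyUI's execution engine inspects the value returned by load() to track which patchers are currently loaded, the fix presumably amounts to something like the following hedged sketch (hypothetical class name; not the exact code):

```python
from comfy.model_patcher import ModelPatcher


class TrackedModelPatcher(ModelPatcher):  # hypothetical name for illustration
    def load(self, *args, **kwargs):
        super().load(*args, **kwargs)
        # Return a tuple containing this patcher so the execution engine can
        # record it as a loaded model (per the compatibility note above).
        return (self,)
```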

Checklist

  • Code is formatted using Pre-Commit hooks (pre-commit run --all-files)
  • Relevant test workflows are added under tests/workflows
  • Reference images are uploaded and registered in test_cases.json
  • Additional models registered in scripts/download_models.py and test_data/models.yaml

- Add NunchakuZImageLoraLoader node for applying LoRAs to Z-Image models

- Add ComfyZImageWrapper with _LoRALinear for SVDQuant linear layers

- Add _fuse_qkv_lora and _fuse_w13_lora for fusion order parity

- Add reset_lora for robust weight management between runs

- Add test workflow for vintage cartoon style LoRA

- Register vintage cartoon LoRA in models.yaml

Ph0rk0z commented Dec 27, 2025

Tested and it works.

@flybirdxx

I tested this PR, but I got UNET misses and noisy images.

@phatwila
Author

@flybirdxx
I'm not able to replicate this issue, but looking at the unet missing warnings, they include all layers (layers.0 through layers.29), so quantization weights are missing. You might be using the wrong model... I see you're using the svdq-fp4_r128 model, which is the FP4 variant for RTX 50-series GPUs. If you're running on an older GPU (pre-50-series), try the INT4 svdq-int4_r128 model instead and see if that helps.

@phatwila phatwila closed this Dec 29, 2025
@phatwila phatwila reopened this Dec 29, 2025

ssitu commented Dec 29, 2025

LoRAs seem to stay on the GPU when the --lowvram argument is enabled, which offloads everything to the CPU after use. I've attempted a quick fix by setting the device_id to 'cpu' in the ctx_for_copy variable when CPU offloading is detected. This reduces the leftover VRAM in use by half, but there's still the other half being reserved, the source of which is not obvious to me.

- Store adaLN LoRA A/B as buffers on the underlying nn.Linear so ModelPatcher leaf-module .to() moves them in lowvram mode
- Keep Z-Image wrapper .to()/to_safely returning self
- Ensure ModelPatcher is created with offload_device when copying Z-Image models
@phatwila
Author

@ssitu
Nice catch! Give the latest commit a try and let me know if it fixes the VRAM retention issue.


ssitu commented Dec 30, 2025

The last commit gives the same savings as my quick hack, but there is still VRAM being retained. I tried some things out, and the diff below fixes the issue for me. I'd guess the buffers are not registered correctly, but I'm not sure. Any ideas? I am on torch 2.9.0 if that has any impact.

```diff
+++ b/wrappers/zimage.py
@@ -29,6 +29,7 @@ from nunchaku.lora.flux.nunchaku_converter import (
     unpack_lowrank_weight,
 )
 from nunchaku.models.linear import SVDQW4A4Linear
+from ..model_patcher import NunchakuModelPatcher
 from nunchaku.utils import load_state_dict_in_safetensors

 logger = logging.getLogger(__name__)
@@ -77,6 +78,7 @@ class _LoRALinear(nn.Module):
     def __init__(self, base: nn.Linear):
         super().__init__()
         self.base = base
+        self.loras: List[Tuple[torch.Tensor, torch.Tensor]] = []  # (A, B) where delta = (x @ A.T) @ B.T

     @property
     def in_features(self) -> int:
@@ -93,6 +95,13 @@ class _LoRALinear(nn.Module):
     @property
     def bias(self) -> Optional[torch.Tensor]:
         return self.base.bias
+    
+    def _apply(self, fn):
+        """Override _apply to also apply to LoRA weights when model is moved."""
+        super()._apply(fn)
+        if self.loras:
+            self.loras = [(fn(A), fn(B)) for A, B in self.loras]
+        return self

     @staticmethod
     def _register_or_set_buffer(module: nn.Module, name: str, tensor: torch.Tensor) -> None:
@@ -671,5 +680,5 @@ def copy_with_ctx(model_wrapper: ComfyZImageWrapper) -> Tuple[ComfyZImageWrapper
     device_id = ctx_for_copy.get("device_id", 0)
     offload_device = ctx_for_copy.get("offload_device", torch.device("cpu"))

-    ret_model = ModelPatcher(model_base, load_device=device, offload_device=offload_device)
+    ret_model = NunchakuModelPatcher(model_base, load_device=device, offload_device=offload_device)
     return ret_model_wrapper, ret_model
```

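For context on why an override like this helps: tensors kept in a plain Python list are invisible to nn.Module.to(), which only walks parameters, buffers, and submodules, so they stay on their original device unless _apply forwards the move. A small standalone demonstration (not PR code):

```python
import torch
import torch.nn as nn


class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = [torch.zeros(4)]                 # not registered: .to() ignores it
        self.register_buffer("buf", torch.zeros(4))   # registered: .to() moves it

    def _apply(self, fn):
        super()._apply(fn)
        self.plain = [fn(t) for t in self.plain]      # manually forward the device move
        return self


m = Holder()
if torch.cuda.is_available():
    m = m.to("cuda")
    print(m.buf.device, m.plain[0].device)  # both cuda:0 thanks to the _apply override
```
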

judian17 commented Dec 30, 2025

> @flybirdxx I'm not able to replicate this issue, but looking at the unet missing warnings, they include all layers (layers.0 through layers.29), so quantization weights are missing. You might be using the wrong model... I see you're using the svdq-fp4_r128 model, which is the FP4 variant for RTX 50-series GPUs. If you're running on an older GPU (pre-50-series), try the INT4 svdq-int4_r128 model instead and see if that helps.

I'm having the same issue as @flybirdxx. Could there be differences between FP4 and INT4 models? Maybe those who got it working aren't using RTX 50-series GPUs? I run ComfyUI on a 5070 Ti.

To add: I can successfully generate images with the FP4 Z-Image model using the official ComfyUI-Nunchaku, but with this PR, even without a LoRA, it shows the error messages above and produces noisy output instead.

```
model_type FLOW
unet missing: ['noise_refiner.0.attention.out.wcscales', 'noise_refiner.0.feed_forward.w13.wcscales', 'noise_refiner.0.feed_forward.w2.wcscales', 'noise_refiner.1.attention.out.wcscales', 'noise_refiner.1.feed_forward.w13.wcscales', 'noise_refiner.1.feed_forward.w2.wcscales', 'context_refiner.0.attention.out.wcscales', 'context_refiner.0.feed_forward.w13.wcscales', 'context_refiner.0.feed_forward.w2.wcscales', 'context_refiner.1.attention.out.wcscales', 'context_refiner.1.feed_forward.w13.wcscales', 'context_refiner.1.feed_forward.w2.wcscales', 'layers.0.attention.out.wcscales', 'layers.0.feed_forward.w13.wcscales', 'layers.0.feed_forward.w2.wcscales', 'layers.1.attention.out.wcscales', 'layers.1.feed_forward.w13.wcscales', 'layers.1.feed_forward.w2.wcscales', 'layers.2.attention.out.wcscales', 'layers.2.feed_forward.w13.wcscales', 'layers.2.feed_forward.w2.wcscales', 'layers.3.attention.out.wcscales', 'layers.3.feed_forward.w13.wcscales', 'layers.3.feed_forward.w2.wcscales', 'layers.4.attention.out.wcscales', 'layers.4.feed_forward.w13.wcscales', 'layers.4.feed_forward.w2.wcscales', 'layers.5.attention.out.wcscales', 'layers.5.feed_forward.w13.wcscales', 'layers.5.feed_forward.w2.wcscales', 'layers.6.attention.out.wcscales', 'layers.6.feed_forward.w13.wcscales', 'layers.6.feed_forward.w2.wcscales', 'layers.7.attention.out.wcscales', 'layers.7.feed_forward.w13.wcscales', 'layers.7.feed_forward.w2.wcscales', 'layers.8.attention.out.wcscales', 'layers.8.feed_forward.w13.wcscales', 'layers.8.feed_forward.w2.wcscales', 'layers.9.attention.out.wcscales', 'layers.9.feed_forward.w13.wcscales', 'layers.9.feed_forward.w2.wcscales', 'layers.10.attention.out.wcscales', 'layers.10.feed_forward.w13.wcscales', 'layers.10.feed_forward.w2.wcscales', 'layers.11.attention.out.wcscales', 'layers.11.feed_forward.w13.wcscales', 'layers.11.feed_forward.w2.wcscales', 'layers.12.attention.out.wcscales', 'layers.12.feed_forward.w13.wcscales', 'layers.12.feed_forward.w2.wcscales', 'layers.13.attention.out.wcscales', 'layers.13.feed_forward.w13.wcscales', 'layers.13.feed_forward.w2.wcscales', 'layers.14.attention.out.wcscales', 'layers.14.feed_forward.w13.wcscales', 'layers.14.feed_forward.w2.wcscales', 'layers.15.attention.out.wcscales', 'layers.15.feed_forward.w13.wcscales', 'layers.15.feed_forward.w2.wcscales', 'layers.16.attention.out.wcscales', 'layers.16.feed_forward.w13.wcscales', 'layers.16.feed_forward.w2.wcscales', 'layers.17.attention.out.wcscales', 'layers.17.feed_forward.w13.wcscales', 'layers.17.feed_forward.w2.wcscales', 'layers.18.attention.out.wcscales', 'layers.18.feed_forward.w13.wcscales', 'layers.18.feed_forward.w2.wcscales', 'layers.19.attention.out.wcscales', 'layers.19.feed_forward.w13.wcscales', 'layers.19.feed_forward.w2.wcscales', 'layers.20.attention.out.wcscales', 'layers.20.feed_forward.w13.wcscales', 'layers.20.feed_forward.w2.wcscales', 'layers.21.attention.out.wcscales', 'layers.21.feed_forward.w13.wcscales', 'layers.21.feed_forward.w2.wcscales', 'layers.22.attention.out.wcscales', 'layers.22.feed_forward.w13.wcscales', 'layers.22.feed_forward.w2.wcscales', 'layers.23.attention.out.wcscales', 'layers.23.feed_forward.w13.wcscales', 'layers.23.feed_forward.w2.wcscales', 'layers.24.attention.out.wcscales', 'layers.24.feed_forward.w13.wcscales', 'layers.24.feed_forward.w2.wcscales', 'layers.25.attention.out.wcscales', 'layers.25.feed_forward.w13.wcscales', 'layers.25.feed_forward.w2.wcscales', 'layers.26.attention.out.wcscales', 'layers.26.feed_forward.w13.wcscales', 
'layers.26.feed_forward.w2.wcscales', 'layers.27.attention.out.wcscales', 'layers.27.feed_forward.w13.wcscales', 'layers.27.feed_forward.w2.wcscales', 'layers.28.attention.out.wcscales', 'layers.28.feed_forward.w13.wcscales', 'layers.28.feed_forward.w2.wcscales', 'layers.29.attention.out.wcscales', 'layers.29.feed_forward.w13.wcscales', 'layers.29.feed_forward.w2.wcscales']
unet unexpected: ['noise_refiner.0.attention.out.wtscale', 'noise_refiner.0.feed_forward.w13.wtscale', 'noise_refiner.0.feed_forward.w2.wtscale', 'noise_refiner.1.attention.out.wtscale', 'noise_refiner.1.feed_forward.w13.wtscale', 'noise_refiner.1.feed_forward.w2.wtscale', 'context_refiner.0.attention.out.wtscale', 'context_refiner.0.feed_forward.w13.wtscale', 'context_refiner.0.feed_forward.w2.wtscale', 'context_refiner.1.attention.out.wtscale', 'context_refiner.1.feed_forward.w13.wtscale', 'context_refiner.1.feed_forward.w2.wtscale', 'layers.0.attention.out.wtscale', 'layers.0.feed_forward.w13.wtscale', 'layers.0.feed_forward.w2.wtscale', 'layers.1.attention.out.wtscale', 'layers.1.feed_forward.w13.wtscale', 'layers.1.feed_forward.w2.wtscale', 'layers.2.attention.out.wtscale', 'layers.2.feed_forward.w13.wtscale', 'layers.2.feed_forward.w2.wtscale', 'layers.3.attention.out.wtscale', 'layers.3.feed_forward.w13.wtscale', 'layers.3.feed_forward.w2.wtscale', 'layers.4.attention.out.wtscale', 'layers.4.feed_forward.w13.wtscale', 'layers.4.feed_forward.w2.wtscale', 'layers.5.attention.out.wtscale', 'layers.5.feed_forward.w13.wtscale', 'layers.5.feed_forward.w2.wtscale', 'layers.6.attention.out.wtscale', 'layers.6.feed_forward.w13.wtscale', 'layers.6.feed_forward.w2.wtscale', 'layers.7.attention.out.wtscale', 'layers.7.feed_forward.w13.wtscale', 'layers.7.feed_forward.w2.wtscale', 'layers.8.attention.out.wtscale', 'layers.8.feed_forward.w13.wtscale', 'layers.8.feed_forward.w2.wtscale', 'layers.9.attention.out.wtscale', 'layers.9.feed_forward.w13.wtscale', 'layers.9.feed_forward.w2.wtscale', 'layers.10.attention.out.wtscale', 'layers.10.feed_forward.w13.wtscale', 'layers.10.feed_forward.w2.wtscale', 'layers.11.attention.out.wtscale', 'layers.11.feed_forward.w13.wtscale', 'layers.11.feed_forward.w2.wtscale', 'layers.12.attention.out.wtscale', 'layers.12.feed_forward.w13.wtscale', 'layers.12.feed_forward.w2.wtscale', 'layers.13.attention.out.wtscale', 'layers.13.feed_forward.w13.wtscale', 'layers.13.feed_forward.w2.wtscale', 'layers.14.attention.out.wtscale', 'layers.14.feed_forward.w13.wtscale', 'layers.14.feed_forward.w2.wtscale', 'layers.15.attention.out.wtscale', 'layers.15.feed_forward.w13.wtscale', 'layers.15.feed_forward.w2.wtscale', 'layers.16.attention.out.wtscale', 'layers.16.feed_forward.w13.wtscale', 'layers.16.feed_forward.w2.wtscale', 'layers.17.attention.out.wtscale', 'layers.17.feed_forward.w13.wtscale', 'layers.17.feed_forward.w2.wtscale', 'layers.18.attention.out.wtscale', 'layers.18.feed_forward.w13.wtscale', 'layers.18.feed_forward.w2.wtscale', 'layers.19.attention.out.wtscale', 'layers.19.feed_forward.w13.wtscale', 'layers.19.feed_forward.w2.wtscale', 'layers.20.attention.out.wtscale', 'layers.20.feed_forward.w13.wtscale', 'layers.20.feed_forward.w2.wtscale', 'layers.21.attention.out.wtscale', 'layers.21.feed_forward.w13.wtscale', 'layers.21.feed_forward.w2.wtscale', 'layers.22.attention.out.wtscale', 'layers.22.feed_forward.w13.wtscale', 'layers.22.feed_forward.w2.wtscale', 'layers.23.attention.out.wtscale', 'layers.23.feed_forward.w13.wtscale', 'layers.23.feed_forward.w2.wtscale', 'layers.24.attention.out.wtscale', 'layers.24.feed_forward.w13.wtscale', 'layers.24.feed_forward.w2.wtscale', 'layers.25.attention.out.wtscale', 'layers.25.feed_forward.w13.wtscale', 'layers.25.feed_forward.w2.wtscale', 'layers.26.attention.out.wtscale', 'layers.26.feed_forward.w13.wtscale', 'layers.26.feed_forward.w2.wtscale', 'layers.27.attention.out.wtscale', 'layers.27.feed_forward.w13.wtscale', 
'layers.27.feed_forward.w2.wtscale', 'layers.28.attention.out.wtscale', 'layers.28.feed_forward.w13.wtscale', 'layers.28.feed_forward.w2.wtscale', 'layers.29.attention.out.wtscale', 'layers.29.feed_forward.w13.wtscale', 'layers.29.feed_forward.w2.wtscale']
Requested to load Lumina2
loaded completely; 9476.67 MB usable, 4007.28 MB loaded, full load: True
Requested to load AutoencodingEngine
```

@phatwila
Author

@ssitu Thanks for digging into this. I think the main difference with the NunchakuModelPatcher approach is that it bypasses ComfyUI's standard ModelPatcher.load() bookkeeping (e.g. model.model_loaded_weight_memory). That can change ComfyUI's smart-memory decisions and make the logs/VRAM behavior harder to interpret. Patching NunchakuModelPatcher to fix that is, I think, outside the scope of this PR, so I opted to keep the standard ModelPatcher for now.

Also, nvidia-smi reports allocator “reserved” VRAM; a real leak would show up as torch.cuda.memory_allocated() staying high after offload/unload.

I tested a small CUDA memory report to compare allocated vs reserved. On my side:

  • --lowvram: ComfyUI keeps the model allocated after the run (smart caching), but after a forced unload_all_models it drops to ~8 MB allocated / ~32 MB reserved.
  • --lowvram --disable-smart-memory: allocated drops to ~168 MB after VAE decode, and to ~8 MB / ~32 MB after unload_all_models.

So I'm no longer seeing persistent allocations consistent with LoRA weights stuck on the GPU on my end; as I understand it, any remaining "VRAM used" in nvidia-smi is allocator caching (reserved), not live tensors (allocated).
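
For anyone who wants to reproduce the comparison, a minimal version of such a report in plain PyTorch (not the exact script used here) is:

```python
import torch


def cuda_memory_report(tag: str) -> None:
    """Print live-tensor memory vs. allocator-cached memory for the current device."""
    allocated = torch.cuda.memory_allocated() / 2**20  # live tensors (MB)
    reserved = torch.cuda.memory_reserved() / 2**20    # cached by the CUDA allocator (MB)
    print(f"[{tag}] allocated={allocated:.1f} MB, reserved={reserved:.1f} MB")


# e.g. call cuda_memory_report("after unload_all_models") to check whether the
# "VRAM used" shown by nvidia-smi is live tensors or just allocator cache.
```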


judian17 commented Dec 30, 2025

> I tested this PR, but I got UNET misses and noisy images.

@flybirdxx @phatwila I consulted an AI assistant regarding this issue, and it suggested modifying custom_nodes\ComfyUI-nunchaku\nodes\models\zimage.py.

After applying the changes below, I successfully ran generation using LoRA on an RTX 5070 Ti. I am sharing this for the author's reference, though I am not entirely sure if this modification is sufficient for all use cases.

Modifications in nodes/models/zimage.py

1. Add the import to the top of the file:

```python
from nunchaku.models.transformers.utils import patch_scale_key
```

2. Update the model loading section:

```python
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(patched_sd, "")

# --- Start of modification ---
# Apply the patch_scale_key before loading weights
patch_scale_key(model.diffusion_model, patched_sd)
# --- End of modification ---

model.load_model_weights(patched_sd, "")

# Preserve the actual CUDA index when running multi-GPU.
# ... (subsequent code) ...
```

Introduces a new function '_apply_nvfp4_scale_keys()' to handle Nunchaku SVDQW4A4Linear scale parameters correctly:
- Fill missing 'wcscales' with ones (stable default behavior)
- Remap per-channel scales stored under 'wtscale' to 'wcscales' if shape matches
- Pop float 'wtscale' from state dict and assign to module attribute

Also: robust 'attention.to_out' key replacement for broader checkpoint formats.

Fixes noisy output / "unexpected key" errors on RTX 50-series GPUs.
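
Based on the three behaviors listed in that message, the helper plausibly looks something like the sketch below. This is a hypothetical reconstruction: only the wcscales/wtscale key names come from the logs and commit message; the module walk and everything else is assumed.

```python
import torch
import torch.nn as nn


def apply_nvfp4_scale_keys_sketch(model: nn.Module, sd: dict) -> None:
    """Hypothetical reconstruction of the scale-key handling described above."""
    for name, module in model.named_modules():
        if not hasattr(module, "wcscales"):
            continue  # only Nunchaku quantized linears carry these scales
        wc_key, wt_key = f"{name}.wcscales", f"{name}.wtscale"

        if wc_key not in sd:
            if wt_key in sd and sd[wt_key].shape == module.wcscales.shape:
                # Per-channel scales stored under the FP4 'wtscale' name: remap them.
                sd[wc_key] = sd.pop(wt_key)
            else:
                # Missing entirely: fall back to ones (stable default behavior).
                sd[wc_key] = torch.ones_like(module.wcscales)

        if wt_key in sd and sd[wt_key].numel() == 1:
            # Scalar tensor-wise scale: keep it as a module attribute, not a state-dict entry.
            module.wtscale = sd.pop(wt_key).item()
```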
@phatwila
Author

@flybirdxx @judian17
Thanks! Let me know if the latest commit fixes the issue for you. I don't have a 50-series card so hopefully I nailed the fix first try.

@judian17

> @flybirdxx @judian17 Thanks! Let me know if the latest commit fixes the issue for you. I don't have a 50-series card so hopefully I nailed the fix first try.

Yes, the latest version works very well. Thank you!


ssitu commented Dec 30, 2025

> • --lowvram: ComfyUI keeps the model allocated after the run (smart caching), but after a forced unload_all_models it drops to ~8 MB allocated / ~32 MB reserved.
> • --lowvram --disable-smart-memory: allocated drops to ~168 MB after VAE decode, and to ~8 MB / ~32 MB after unload_all_models.

Ah, gotcha. I didn't have smart memory turned off, so it wasn't clear to me whether it was a leak or not. Turning it off makes it clear that it's fixed, thanks!
This all stemmed from a comparison with Flux, where I would get crashes running SVDQuants of both Flux and ZiT. Looks like the Nunchaku Flux implementation might have some leaking as well for lowvram users, but that's an issue for another time.
