From ba1a8e1f0dd4832695b843a8047e060e6da41cac Mon Sep 17 00:00:00 2001
From: Alexander Eichhorn
Date: Mon, 4 May 2026 00:24:38 +0200
Subject: [PATCH] fix(mm): support diffusers FLUX LoRAs on NF4/8-bit quantized
 base models

CustomInvokeLinearNF4 and CustomInvokeLinear8bitLt were missing the
_cast_weight_bias_for_input / _cast_tensor_for_input methods that the
sidecar-patch branch of autocast_linear_forward_sidecar_patches calls.
This caused an AttributeError whenever a patch other than LoRALayer /
FluxControlLoRALayer (e.g. the MergedLayerPatch that the diffusers FLUX
LoRA converter produces when fusing Q/K/V/mlp into linear1) was applied
to a quantized FLUX module.

The fix exposes the weight to patches as a meta-device tensor with the
correct logical shape (read from quant_state for Params4bit, since its
.shape reports the packed-byte layout). Shape-only patches (LoRA, LoHA,
MergedLayerPatch) work; SetParameterLayer / DoRA on quantized modules
remain unsupported.
---
 .../custom_invoke_linear_8_bit_lt.py          | 22 +++++++++++++++++
 .../custom_invoke_linear_nf4.py               | 24 +++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py
index 2b9d8e9e98e..0f538caa5a4 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py
@@ -8,10 +8,32 @@
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
+from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
 from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
+    def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
+        tensor = cast_to_device(tensor, input.device)
+        if (
+            tensor is not None
+            and input.is_floating_point()
+            and tensor.is_floating_point()
+            and not isinstance(tensor, GGMLTensor)
+            and tensor.dtype != input.dtype
+        ):
+            tensor = tensor.to(dtype=input.dtype)
+        return tensor
+
+    def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # See the matching method on CustomInvokeLinearNF4 for the rationale. Int8Params doesn't have
+        # the same packed-shape problem as Params4bit, but we still substitute a meta tensor so that
+        # patches don't accidentally read the quantized weight values.
+        weight = torch.empty(get_param_shape(self.weight), device="meta")
+        bias = self._cast_tensor_for_input(self.bias, input)
+        return weight, bias
+
     def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
         return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
 
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py
index 89284d5509a..82596901704 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py
@@ -10,10 +10,34 @@
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
+from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
 from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
+    def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
+        tensor = cast_to_device(tensor, input.device)
+        if (
+            tensor is not None
+            and input.is_floating_point()
+            and tensor.is_floating_point()
+            and not isinstance(tensor, GGMLTensor)
+            and tensor.dtype != input.dtype
+        ):
+            tensor = tensor.to(dtype=input.dtype)
+        return tensor
+
+    def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # The NF4 weight is a Params4bit whose .shape reports the *packed-byte* layout, not the logical
+        # (out_features, in_features) shape. We hand patches a meta-device tensor with the correct
+        # logical shape so that shape-only patches (LoRA, LoHA, MergedLayerPatch over LoRA, ...) work.
+        # Patches that read the original weight values (e.g. SetParameterLayer, DoRA) are not supported
+        # on NF4-quantized modules.
+        weight = torch.empty(get_param_shape(self.weight), device="meta")
+        bias = self._cast_tensor_for_input(self.bias, input)
+        return weight, bias
+
     def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
         return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)
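Note for reviewers (not part of the patch): the snippet below is a minimal
sketch of the contract the new _cast_weight_bias_for_input implementations
satisfy. _FakeQuantizedLinear is a hypothetical stand-in for the quantized
module; it assumes only what the diff shows, namely that the method returns a
meta-device weight with the logical (out_features, in_features) shape plus the
(possibly None) bias.

    import torch


    class _FakeQuantizedLinear:
        """Hypothetical stand-in for CustomInvokeLinearNF4 / CustomInvokeLinear8bitLt."""

        def __init__(self, out_features: int, in_features: int) -> None:
            self._logical_shape = (out_features, in_features)
            self.bias = None

        def _cast_weight_bias_for_input(self, input: torch.Tensor):
            # As in the patch: a meta-device tensor with the logical shape, no real values.
            return torch.empty(self._logical_shape, device="meta"), self.bias


    module = _FakeQuantizedLinear(out_features=8, in_features=16)
    x = torch.randn(2, 16)
    weight, bias = module._cast_weight_bias_for_input(x)

    # Shape-only patches (LoRA, LoHA, MergedLayerPatch) need just the logical shape:
    out_features, in_features = weight.shape
    rank = 4
    delta = torch.zeros(out_features, rank) @ torch.zeros(rank, in_features)

    # Value-reading patches (SetParameterLayer, DoRA) would have to mix the meta
    # weight with real data, which raises -- hence "remain unsupported":
    try:
        _ = weight + delta
    except RuntimeError as err:  # mixing meta and CPU tensors is an error in torch
        print(f"value access fails as intended: {err}")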