@@ -8,10 +8,32 @@
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin):
def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
tensor = cast_to_device(tensor, input.device)
if (
tensor is not None
and input.is_floating_point()
and tensor.is_floating_point()
and not isinstance(tensor, GGMLTensor)
and tensor.dtype != input.dtype
):
tensor = tensor.to(dtype=input.dtype)
return tensor

def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
# See the matching method on CustomInvokeLinearNF4 for the rationale. Int8Params doesn't have
# the same packed-shape problem as Params4bit, but we still substitute a meta tensor so that
# patches don't accidentally read the quantized weight values.
weight = torch.empty(get_param_shape(self.weight), device="meta")
bias = self._cast_tensor_for_input(self.bias, input)
return weight, bias

def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)

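As an aside, a rough sketch of why a meta-device stand-in is enough here: a shape-only patch such as LoRA computes its contribution from the input and its own low-rank matrices, and consults the substituted weight only for its logical (out_features, in_features) shape. The function below is illustrative (assumed names, not the actual sidecar-patch API).

import torch

def lora_contribution(
    x: torch.Tensor,
    weight_meta: torch.Tensor,  # meta-device stand-in for the quantized weight; shape only
    lora_up: torch.Tensor,  # (out_features, rank)
    lora_down: torch.Tensor,  # (rank, in_features)
) -> torch.Tensor:
    out_features, in_features = weight_meta.shape
    assert lora_up.shape[0] == out_features and lora_down.shape[1] == in_features
    # Two small matmuls; the quantized base weight is never read or dequantized.
    return torch.nn.functional.linear(torch.nn.functional.linear(x, lora_down), lora_up)
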
@@ -10,10 +10,34 @@
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin):
def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
tensor = cast_to_device(tensor, input.device)
if (
tensor is not None
and input.is_floating_point()
and tensor.is_floating_point()
and not isinstance(tensor, GGMLTensor)
and tensor.dtype != input.dtype
):
tensor = tensor.to(dtype=input.dtype)
return tensor

def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
# The NF4 weight is a Params4bit whose .shape reports the *packed-byte* layout, not the logical
# (out_features, in_features) shape. We hand patches a meta-device tensor with the correct
# logical shape so that shape-only patches (LoRA, LoHA, MergedLayerPatch over LoRA, ...) work.
# Patches that read the original weight values (e.g. SetParameterLayer, DoRA) are not supported
# on NF4-quantized modules.
weight = torch.empty(get_param_shape(self.weight), device="meta")
bias = self._cast_tensor_for_input(self.bias, input)
return weight, bias

def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights)

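For context, a hedged sketch of what get_param_shape has to handle for bitsandbytes Params4bit; this is an assumption about param_shape_utils, not its actual implementation. Params4bit stores the weight as a packed byte buffer, so the logical (out_features, in_features) shape has to come from the recorded quant_state rather than from .shape.

import torch
from bitsandbytes.nn import Params4bit

def get_param_shape_sketch(param: torch.Tensor) -> torch.Size:
    if isinstance(param, Params4bit) and param.quant_state is not None:
        # The logical shape is recorded on the quant_state when the weight is packed.
        return param.quant_state.shape
    # Plain tensors (and Int8Params, which is one byte per element) already report
    # the logical shape.
    return param.shape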