Modular backend - LoRA/LyCORIS #6667

Merged (14 commits) on Jul 31, 2024
11 changes: 11 additions & 0 deletions invokeai/app/invocations/denoise_latents.py
@@ -60,6 +60,7 @@
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -833,6 +834,16 @@ def step_callback(state: PipelineIntermediateState) -> None:
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

### lora
if self.unet.loras:
ext_manager.add_extension(
LoRAPatcherExt(
node_context=context,
loras=self.unet.loras,
prefix="lora_unet_",
)
)

# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
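The prefix="lora_unet_" argument above follows the common kohya-ss key convention, in which every entry in a LoRA/LyCORIS file is namespaced by the model it targets (lora_unet_... for the UNet, lora_te_... for a text encoder). A minimal sketch, assuming that convention, of how such a prefix can select and re-key the layers that apply to one model; the helper below is hypothetical and not part of this PR:

from typing import Dict, TypeVar

T = TypeVar("T")

def select_layers_for_model(layers: Dict[str, T], prefix: str) -> Dict[str, T]:
    # Keep only entries namespaced for this model and strip the prefix so the
    # remainder can be mapped onto the model's module names.
    return {key[len(prefix):]: layer for key, layer in layers.items() if key.startswith(prefix)}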
131 changes: 81 additions & 50 deletions invokeai/backend/lora.py
@@ -3,7 +3,7 @@

import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from safetensors.torch import load_file
@@ -46,9 +46,19 @@ def __init__(
self.rank = None # set in layer implementation
self.layer_key = layer_key

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()

def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias

def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params

def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
@@ -60,6 +70,15 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)

def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
# TODO: how to warn log?
print(
f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
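# Illustrative sketch (not part of this diff): one way a patcher could consume the
# new get_parameters() API above. The helper name is hypothetical, and per-layer
# alpha/rank scaling as well as multiplicative layer types (e.g. IA3) are glossed over.
def _sketch_apply_layer(module: torch.nn.Module, layer: "LoRALayerBase", lora_weight: float) -> None:
    for name, value in layer.get_parameters(module).items():
        param = getattr(module, name)  # "weight", plus "bias" when the layer provides one
        if param is None:
            continue  # a real patcher would have to create the missing bias parameter
        param.data += lora_weight * value.to(device=param.data.device, dtype=param.data.dtype)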


# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
@@ -76,14 +95,19 @@ def __init__(

self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.mid = values.get("lora_mid.weight", None)

self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
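# Illustrative note (not part of this diff): for a Linear layer with 320 input and 640
# output features at rank 16, "lora_down.weight" has shape (16, 320) and "lora_up.weight"
# (640, 16), so rank == self.down.shape[0] == 16 and get_weight() below composes the
# residual up @ down of shape (640, 320); "lora_mid.weight" is an optional extra factor
# used by Tucker-decomposed conv (LoCon) layers.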

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +149,23 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]

if "hada_t1" in values:
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None

if "hada_t2" in values:
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)

self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
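# Illustrative note (not part of this diff): LoHA builds the weight delta as the
# Hadamard (element-wise) product of two low-rank reconstructions,
# (w1_a @ w1_b) * (w2_a @ w2_b), as computed in get_weight() below; hada_t1/hada_t2
# are optional Tucker cores used for conv layers.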

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -186,37 +213,41 @@ def __init__(
):
super().__init__(layer_key, values)

if "lokr_w1" in values:
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]

if "lokr_w2" in values:
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]

if "lokr_t2" in values:
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
self.t2 = values.get("lokr_t2", None)

if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
# Although lokr_t1 is not used in the algorithm, it is still defined in LoKR weights
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t1",
"lokr_t2",
},
)
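# Illustrative note (not part of this diff): LoKR factors the weight delta as a
# Kronecker product of two blocks, each of which may itself be stored low-rank
# (w1 = w1_a @ w1_b, w2 = w2_a @ w2_b), with lokr_t2 as an optional Tucker core.
# For example, torch.kron(torch.zeros(4, 8), torch.zeros(80, 40)).shape == (320, 320).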

def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
@@ -272,7 +303,9 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]


class FullLayer(LoRALayerBase):
# bias is handled in LoRALayerBase (calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]

def __init__(
self,
@@ -282,15 +315,12 @@ def __init__(
super().__init__(layer_key, values)

self.weight = values["diff"]

if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.bias = values.get("diff_b", None)

self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight

def calc_size(self) -> int:
@@ -319,8 +349,9 @@ def __init__(
self.on_input = values["on_input"]

self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
@@ -459,23 +490,23 @@ def from_checkpoint(

for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)

# loha
elif "hada_w1_b" in values:
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)

# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)

# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)

# ia3
elif "weight" in values and "on_input" in values:
elif "on_input" in values:
layer = IA3Layer(layer_key, values)

else:
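The updated detection above keys each layer type off a tensor name its format is guaranteed to contain (lora_up.weight, hada_w1_a, lokr_w1 or lokr_w1_a, diff, on_input). A minimal sketch of the grouped, per-layer value dicts this loop iterates over; the layer names and key sets below are hypothetical examples, not taken from the PR:

example_grouped_state_dict = {
    # matched by the "lora_up.weight" branch -> LoRALayer
    "lora_unet_down_blocks_0_attentions_0_proj_in": {
        "lora_up.weight": ..., "lora_down.weight": ..., "alpha": ...,
    },
    # matched by the "hada_w1_a" branch -> LoHALayer
    "lora_unet_mid_block_attentions_0_transformer_blocks_0_attn1_to_q": {
        "hada_w1_a": ..., "hada_w1_b": ..., "hada_w2_a": ..., "hada_w2_b": ..., "alpha": ...,
    },
    # matched by the "lokr_w1" / "lokr_w1_a" branch -> LoKRLayer
    "lora_te_text_model_encoder_layers_0_mlp_fc1": {
        "lokr_w1": ..., "lokr_w2_a": ..., "lokr_w2_b": ..., "alpha": ...,
    },
}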