# Modular backend - LoRA/LyCORIS (#6667)
## Summary

Code for LoRA patching from #6577.
Additionally, patching was extended so that a LoRA can patch not only `weight` but
also `bias`, because some LoRAs in the wild patch the bias as well.
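
As a rough illustration (a minimal sketch, not the PR's actual classes; the names here are hypothetical), a layer that patches both parameters can expose them through a `get_parameters()`-style hook like the one added to `invokeai/backend/lora.py` below:

```python
from typing import Dict, Optional

import torch


class ExampleLoRALayer:
    """Illustrative low-rank layer that patches both `weight` and `bias`."""

    def __init__(self, up: torch.Tensor, down: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.up = up      # (out_features, rank)
        self.down = down  # (rank, in_features)
        self.bias = bias  # optional (out_features,) bias delta

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # Low-rank weight delta; same shape as orig_weight.
        return self.up @ self.down

    def get_bias(self, orig_bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        # Collect every parameter this layer patches; a patcher would add each
        # delta (scaled by the LoRA weight) onto the corresponding original.
        params = {"weight": self.get_weight(orig_module.weight)}
        bias = self.get_bias(getattr(orig_module, "bias", None))
        if bias is not None:
            params["bias"] = bias
        return params
```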

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.
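
For example (a hedged sketch; the `invokeai-web` launch command and the exact flag handling are assumptions about the local setup):

```python
# Hypothetical QA helper: launch the app once with USE_MODULAR_DENOISE set and
# once without it, then compare outputs for the same prompt/seed by hand.
import os
import subprocess

for use_modular in (True, False):
    env = dict(os.environ)
    if use_modular:
        env["USE_MODULAR_DENOISE"] = "1"
    else:
        env.pop("USE_MODULAR_DENOISE", None)
    # Assumed launch command; substitute however you normally start InvokeAI.
    subprocess.run(["invokeai-web"], env=env, check=False)
```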

## Merge Plan

Replace the old LoRA patcher with the new one once review is done.
If you think tests should be added, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
RyanJDick authored Jul 31, 2024
2 parents 94d64b8 + 5a9173f commit 4ce64b6
Showing 9 changed files with 331 additions and 129 deletions.
8 changes: 4 additions & 4 deletions invokeai/app/invocations/compel.py
@@ -80,12 +80,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
15 changes: 13 additions & 2 deletions invokeai/app/invocations/denoise_latents.py
@@ -62,6 +62,7 @@
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@@ -845,6 +846,16 @@ def step_callback(state: PipelineIntermediateState) -> None:
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

### lora
if self.unet.loras:
for lora_field in self.unet.loras:
ext_manager.add_extension(
LoRAExt(
node_context=context,
model_id=lora_field.lora,
weight=lora_field.weight,
)
)
### seamless
if self.unet.seamless_axes:
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@@ -964,14 +975,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
):
assert isinstance(unet, UNet2DConditionModel)
135 changes: 85 additions & 50 deletions invokeai/backend/lora.py
@@ -3,12 +3,13 @@

import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from safetensors.torch import load_file
from typing_extensions import Self

import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel

@@ -46,9 +47,19 @@ def __init__(
self.rank = None # set in layer implementation
self.layer_key = layer_key

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()

def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias

def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params

def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
@@ -60,6 +71,17 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)

def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)


# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
Expand All @@ -76,14 +98,19 @@ def __init__(

self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.mid = values.get("lora_mid.weight", None)

self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +152,23 @@ def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]

if "hada_t1" in values:
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None

if "hada_t2" in values:
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)

self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -186,37 +216,39 @@ def __init__(
):
super().__init__(layer_key, values)

if "lokr_w1" in values:
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]

if "lokr_w2" in values:
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]

if "lokr_t2" in values:
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
self.t2 = values.get("lokr_t2", None)

if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)

def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
@@ -272,7 +304,9 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]


class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]

def __init__(
self,
@@ -282,15 +316,12 @@ def __init__(
super().__init__(layer_key, values)

self.weight = values["diff"]

if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.bias = values.get("diff_b", None)

self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight

def calc_size(self) -> int:
@@ -319,8 +350,9 @@ def __init__(
self.on_input = values["on_input"]

self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})

def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
@@ -458,24 +490,27 @@ def from_checkpoint(
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic (`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

# lora and locon
if "lora_down.weight" in values:
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)

# loha
elif "hada_w1_b" in values:
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)

# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)

# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)

# ia3
elif "weight" in values and "on_input" in values:
elif "on_input" in values:
layer = IA3Layer(layer_key, values)

else: