150 changes: 146 additions & 4 deletions networks/svd_merge_lora.py
@@ -1,5 +1,8 @@
import argparse
import itertools
import json
import os
import re
import time
import torch
from safetensors.torch import load_file, save_file
@@ -14,6 +17,106 @@

CLAMP_QUANTILE = 0.99

ACCEPTABLE = [12, 17, 20, 26]
SDXL_LAYER_NUM = [12, 20]

LAYER12 = {
"BASE": True,
"IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True,
"IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
"MID": True,
"OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
"OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False
}

LAYER17 = {
"BASE": True,
"IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True,
"IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
"MID": True,
"OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True,
"OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
}

LAYER20 = {
"BASE": True,
"IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
"IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
"MID": True,
"OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
"OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False,
}

LAYER26 = {
"BASE": True,
"IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
"IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True,
"MID": True,
"OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
"OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
}

assert len([v for v in LAYER12.values() if v]) == 12
assert len([v for v in LAYER17.values() if v]) == 17
assert len([v for v in LAYER20.values() if v]) == 20
assert len([v for v in LAYER26.values() if v]) == 26

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
    # block index 0 is reserved for the text encoder, so return 0 for text encoder LoRA modules
if "text_model_encoder_" in lora_name: # LoRA for text encoder
return 0

    # lbw block index is 1-based for the U-Net; there is no "input_blocks.0" in CompVis SD, so "input_blocks.1" has index 2
block_idx = -1 # invalid lora name
if not is_sdxl:
NUM_OF_BLOCKS = 12 # up/down blocks
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
up_down = g[0]
i = int(g[1])
j = int(g[3])
if up_down == "down":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j + 1
elif g[2] == "downsamplers":
idx = 3 * (i + 1)
else:
return block_idx # invalid lora name
elif up_down == "up":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers":
idx = 3 * i + 2
else:
return block_idx # invalid lora name

if g[0] == "down":
block_idx = 1 + idx # 1-based index, down block index
elif g[0] == "up":
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index

elif "mid_block_" in lora_name:
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
else:
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
block_idx = 1
elif name.startswith("input_blocks_"): # 1-8 to 2-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 10
block_idx = 10
elif name.startswith("output_blocks_"): # 0-8 to 11-19
block_idx = 11 + int(name.split("_")[2])
elif name.startswith("out_"): # 20, No LoRA in sd-scripts
block_idx = 20

return block_idx
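
    # Illustrative return values (the key names below are hypothetical but follow sd-scripts LoRA naming):
    #   get_lbw_block_index("lora_te_text_model_encoder_layers_0_mlp_fc1", False)   -> 0   (text encoder / BASE)
    #   get_lbw_block_index("lora_unet_down_blocks_0_attentions_0_proj_in", False)  -> 2   (IN01)
    #   get_lbw_block_index("lora_unet_mid_block_attentions_0_proj_in", False)      -> 13  (MID)
    #   get_lbw_block_index("lora_unet_input_blocks_4_1_proj_in", True)             -> 5   (SDXL input block 4)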


def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -42,12 +145,34 @@ def save_to_file(file_name, state_dict, dtype, metadata)
torch.save(state_dict, file_name)


def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
v2 = None  # this means LoRA metadata v2, not SD2
base_model = None
for model, ratio in zip(models, ratios):

if lbws:
try:
# lbws are expected to be JSON strings like "[1,1,1,1,1,1,1,1,1,1,1,1]"
lbws = [json.loads(lbw) for lbw in lbws]
except Exception:
raise ValueError(f"format of lbws must be JSON / 層別適用率はJSON形式で書いてください")
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws must be lists / 層別適用率はリストにしてください"
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of each lbw must be one of {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbw must be numbers / 層別適用率の値はすべて数値にしてください"

layer_num = len(lbws[0])
is_sdxl = layer_num in SDXL_LAYER_NUM
FLAGS = {
"12": LAYER12.values(),
"17": LAYER17.values(),
"20": LAYER20.values(),
"26": LAYER26.values(),
}[str(layer_num)]
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
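# e.g. for a 12-value lbw, FLAGS is LAYER12.values() and LBW_TARGET_IDX == [0, 5, 6, 8, 9, 13, 14, 15, 16, 17, 18, 19] (BASE, IN04, IN05, IN07, IN08, MID, OUT00-OUT05)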

for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

@@ -57,6 +182,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw weights: {dict(zip(LAYER26.keys(), lbw_weights))}")

# merge
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
@@ -93,6 +224,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
# W <- W + U * D
scale = alpha / network_dim

if lbw:
index = get_lbw_block_index(key, is_sdxl)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index]  # if the key is an lbw target, apply the per-block weight

if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)

@@ -170,6 +307,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty

def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(args.lbws), f"number of models must be equal to number of lbws / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = []  # use an empty list when lbws is not specified so that zip_longest can handle it

def str_to_dtype(p):
if p == "float":
@@ -187,7 +328,7 @@ def str_to_dtype(p):

new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict, metadata, v2, base_model = merge_lora_models(
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
)

logger.info(f"calculating hashes and creating metadata...")
@@ -237,6 +378,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
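
For reference, here is a minimal standalone sketch (not part of the diff) of how a single --lbws entry expands to the 26 named blocks, mirroring the merge logic added above. The 17-value weight string and the import path are assumptions for illustration only:

import json

# assumed import path; LAYER17 and LAYER26 are the constants defined in the script shown in this diff
from networks.svd_merge_lora import LAYER17, LAYER26

# a hypothetical 17-value layer-wise weight string, as it would be passed to --lbws
lbw = json.loads("[1,0,0,0,1,1,0,1,1,1,1,1,1,1,1,1,1]")

# positions of the 17 enabled blocks among the full 26 (this is LBW_TARGET_IDX for LAYER17)
target_idx = [i for i, flag in enumerate(LAYER17.values()) if flag]

# expand to per-block weights over all 26 blocks, leaving untouched blocks at 1
lbw_weights = [1] * 26
for index, value in zip(target_idx, lbw):
    lbw_weights[index] = value

print(dict(zip(LAYER26.keys(), lbw_weights)))

Each --lbws entry is such a JSON string, one per model passed to --models; blocks whose flag is False in the corresponding LAYERxx table keep an effective weight of 1 during the merge.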