import os
import gc
import torch
from aiohttp import web
from safetensors.torch import load_file
from folder_paths import get_filename_list, get_full_path
from comfy.sd import load_lora_for_models
from comfy_extras.nodes_model_merging import save_checkpoint

# Setup API route for block listing
def setup():
    from server import PromptServer

    @PromptServer.instance.routes.get("/custom/flux_block_lora_merger/list_blocks")
    async def list_blocks(request):
        file = request.rel_url.query.get("file")
        if not file:
            return web.json_response({"error": "Missing file parameter"}, status=400)
        try:
            print(f"[DEBUG] Requested LoRA file: '{file}'")
            full_path = get_full_path("loras", file)
            print(f"[DEBUG] Full path resolved: {full_path}")
            # get_full_path returns None for unknown files, so guard before os.path.exists
            if full_path is None or not os.path.exists(full_path):
                return web.json_response({"error": f"File not found at path: {full_path}"}, status=404)
            lora_sd = load_file(full_path)
            print("[DEBUG] Keys in LoRA:")
            for k in lora_sd.keys():
                print(" -", k)
            blocks = set()
            for k in lora_sd:
                if not k.startswith("lora_unet_"):
                    continue
                try:
                    parts = k.split(".")[0].split("_")
                    if "blocks" in parts:
                        idx = parts.index("blocks")
                        block_id = parts[idx + 1]
                        group_type = parts[idx - 1]
                        block = f"{group_type}_blocks_{block_id}"
                        blocks.add(block)
                except Exception as e:
                    print(f"[WARN] Could not extract block from key '{k}': {e}")
            print("[DEBUG] Block groups:", sorted(blocks))
            return web.json_response({"blocks": sorted(blocks)})
        except Exception as e:
            print(f"[ERROR] Failed to list blocks: {e}")
            return web.json_response({"error": f"Failed to load blocks: {str(e)}"}, status=500)
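
# Illustrative trace of the block-name extraction above, using a typical
# kohya-style Flux LoRA key (the exact naming depends on the trainer):
#   "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight"
#   -> parts      = ["lora", "unet", "double", "blocks", "0", "img", "attn", "proj"]
#   -> group_type = "double", block_id = "0"
#   -> block group "double_blocks_0"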

# Main merge node class
class FluxBlockLoraMerger:
    @classmethod
    def INPUT_TYPES(cls):
        lora_list = get_filename_list("loras")
        return {
            "required": {
                "unet_model": ("MODEL",),
                "lora_path": (lora_list,),
                "weight": ("FLOAT", {"default": 1.0}),
                "save_model": ("BOOLEAN", {"default": False}),
                "save_filename": ("STRING", {"default": "flux_block_merged.safetensors"}),
                "block_prefixes": ("STRING", {"multiline": True, "default": ""})
            }
        }

    RETURN_TYPES = ("MODEL", "STRING",)
    RETURN_NAMES = ("model", "merge_report",)
    FUNCTION = "merge_selected_blocks"
    CATEGORY = "flux/dev"

    def merge_selected_blocks(self, unet_model, lora_path, weight, save_model, save_filename, block_prefixes):
        model = unet_model.clone()
        lora_path_full = get_full_path("loras", lora_path)
        lora_sd_full = load_file(lora_path_full)
        # strip lines first so entries with leading whitespace are still matched
        excluded_blocks = [
            p.strip().replace("block:", "")
            for p in block_prefixes.splitlines()
            if p.strip().startswith("block:")
        ]
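        # The comprehension above expects one "block:"-prefixed group per line.
        # Example contents of block_prefixes (hypothetical block names):
        #   block:double_blocks_0
        #   block:single_blocks_7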
        merged_keys = {}
        excluded_keys = {}
        for k, v in lora_sd_full.items():
            if not k.startswith("lora_unet_"):
                continue
            try:
                parts = k.split(".")[0].split("_")
                block_group = None
                if "blocks" in parts:
                    idx = parts.index("blocks")
                    block_id = parts[idx + 1]
                    group_type = parts[idx - 1]
                    block_group = f"{group_type}_blocks_{block_id}"
            except Exception as e:
                print(f"[WARN] Failed to parse block group for key: {k} → {e}")
                block_group = None
            if block_group and block_group in excluded_blocks:
                excluded_keys[k] = v
                continue
            merged_keys[k] = v
        ignored_keys = [k for k in lora_sd_full if k not in merged_keys and k not in excluded_keys]
        print("[BLOCK SCAN] Available blocks in LoRA:")
        for k in sorted(set(k.split(".")[0] for k in merged_keys)):
            print(f" - {k}")
        print(f"[BLOCK MERGE] Excluded blocks: {excluded_blocks}")
        print(f" → Loaded {len(merged_keys)} keys from allowed blocks")
        print(f" → Skipped {len(excluded_keys)} keys from excluded blocks")
        print(f" → Ignored {len(ignored_keys)} keys (non-UNet or text encoder)")
        # clip is None, so only the UNet is patched (clip strength 0.0)
        model, _ = load_lora_for_models(model, None, merged_keys, weight, 0.0)
        if save_model:
            print(f"[SAVE] Saving model '{save_filename}'; clearing VRAM first")
            torch.cuda.empty_cache()
            gc.collect()
            output_path = os.path.join(os.getcwd(), "output")
            os.makedirs(output_path, exist_ok=True)
            # save_checkpoint appends a counter suffix to the prefix, so the
            # file on disk will not be named save_filename exactly
            save_checkpoint(
                model=model,
                filename_prefix=os.path.splitext(save_filename)[0],
                output_dir=output_path,
                prompt=None,
                extra_pnginfo=None
            )
            print(f"[SAVE] Model saved in {output_path} (prefix: {os.path.splitext(save_filename)[0]})")
        report = f"✔️ Merged {len(merged_keys)} keys (excluded: {len(excluded_keys)}, ignored: {len(ignored_keys)})"
        return (model, report)

NODE_CLASS_MAPPINGS = {
    "FluxBlockLoraMerger": FluxBlockLoraMerger
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FluxBlockLoraMerger": "Flux Block LoRA Merger 🧩"
}

setup()
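
# Example query against the route registered by setup(), assuming a LoRA named
# "my_flux_lora.safetensors" exists in the loras folder (hypothetical name):
#   GET /custom/flux_block_lora_merger/list_blocks?file=my_flux_lora.safetensors
#   -> {"blocks": ["double_blocks_0", "double_blocks_1", "single_blocks_0", ...]}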