-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatch_comfyui_nunchaku_lora.py
More file actions
131 lines (103 loc) Β· 4.53 KB
/
patch_comfyui_nunchaku_lora.py
File metadata and controls
131 lines (103 loc) Β· 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import safetensors.torch
from safetensors import safe_open
import torch
import os
import glob
def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
    """
    Add dummy (all-zero) adaLN LoRA weights if missing, using the
    final_layer linear LoRA weights as a shape reference.

    Args:
        state_dict (dict): mapping of tensor key -> torch.Tensor.
        prefix (str): base name for the final_layer keys to look for.
        verbose (bool): print debug info while checking/patching.

    Returns:
        dict: the same state_dict, patched in place (and returned for chaining).
    """
    adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
    adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
    linear_down_key = f"{prefix}_linear.lora_down.weight"
    linear_up_key = f"{prefix}_linear.lora_up.weight"

    # Use the linear LoRA tensors (if present) as the shape template.
    final_layer_linear_down = state_dict.get(linear_down_key)
    final_layer_linear_up = state_dict.get(linear_up_key)

    has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
    has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None

    if verbose:
        print(f"\n[SCAN] Checking for final_layer keys with prefix: '{prefix}'")
        print(f"  Linear down: {linear_down_key}")
        print(f"  Linear up:   {linear_up_key}")
        print(f"  Has final_layer.linear: {has_linear}")
        print(f"  Has final_layer.adaLN_modulation_1: {has_adaLN}")

    if has_linear and not has_adaLN:
        # Zero LoRA deltas are a numerical no-op, so the dummies only
        # satisfy loaders that insist on the adaLN keys being present.
        dummy_down = torch.zeros_like(final_layer_linear_down)
        dummy_up = torch.zeros_like(final_layer_linear_up)
        state_dict[adaLN_down_key] = dummy_down
        state_dict[adaLN_up_key] = dummy_up
        if verbose:
            print("[OK] Added dummy adaLN weights:")
            print(f"  {adaLN_down_key} (shape: {dummy_down.shape})")
            print(f"  {adaLN_up_key} (shape: {dummy_up.shape})")
    elif verbose:
        print("[OK] No patch needed - adaLN weights already present or no final_layer.linear found.")
    return state_dict
def main():
    """
    Scan the 'lora' directory for .safetensors LoRA files and patch each
    one that is missing final_layer adaLN weights, writing the result
    next to the original with a '_patched' filename suffix.
    """
    print("[INFO] Universal final_layer.adaLN LoRA patcher (.safetensors)")
    print("Looking for .safetensors files in the 'lora' directory...")
    lora_dir = "lora"
    lora_files = glob.glob(os.path.join(lora_dir, "*.safetensors"))
    if not lora_files:
        print(f"\n[ERROR] No `.safetensors` files found in the '{lora_dir}' directory.")
        return
    print(f"\nFound {len(lora_files)} file(s) to process.")
    patched_count = 0
    for input_path in lora_files:
        # Never re-patch our own output files.
        if "_patched" in os.path.basename(input_path):
            print(f"\n[SKIP] Skipping already patched file: {input_path}")
            continue
        print("\n-----------------------------------------------------")
        print(f"Processing: {input_path}")
        base, ext = os.path.splitext(input_path)
        output_path = f"{base}_patched{ext}"
        # Load all tensors onto CPU.
        state_dict = {}
        with safe_open(input_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        print(f"[OK] Loaded {len(state_dict)} tensors.")
        # Try common key-naming prefixes in order of likelihood; a patch
        # adds keys, so a growing dict signals success.
        prefixes = [
            "lora_unet_final_layer",
            "final_layer",
            "base_model.model.final_layer",
        ]
        patched = False
        before_count = len(state_dict)
        for prefix in prefixes:
            state_dict = patch_final_layer_adaLN(state_dict, prefix=prefix, verbose=False)
            if len(state_dict) > before_count:
                print(f"[OK] Patch applied using prefix '{prefix}'.")
                patched = True
                break  # Stop after the first successful patch.
        if not patched:
            print("[OK] No patch needed for this file.")
            continue
        # Save the patched copy.
        safetensors.torch.save_file(state_dict, output_path)
        print(f"[OK] Patched file saved to: {output_path}")
        patched_count += 1
        # Re-open the output to verify the adaLN keys really landed on disk.
        with safe_open(output_path, framework="pt", device="cpu") as f:
            has_adaLN_after = any("adaLN_modulation_1" in k for k in f.keys())
        if has_adaLN_after:
            print("[OK] Verification successful: `adaLN` keys are present.")
        else:
            print("[ERROR] Verification failed: `adaLN` keys are missing in the output file.")
    print("\n-----------------------------------------------------")
    print(f"[DONE] Done. Patched {patched_count} file(s).")
# Run the patcher only when executed as a script, not on import.
if __name__ == "__main__":
    main()