
Commit 3c51f8a

Add verbose weights
```bash
python lora-inspector.py -v /mnt/900/lora/boo_v2.safetensors
```

```
down_block_0_proj_in weight average magnitude: 1.4924503799105646
down_block_0_proj_in weight average strength: 0.017584813472947592
down_block_0_proj_out weight average magnitude: 1.5157774537695803
down_block_0_proj_out weight average strength: 0.017854483676551985
unet_down_block_0_transformer_block_0_attentions_to_k weight average magnitude: 1.462021831711949
unet_down_block_0_transformer_block_0_attentions_to_k weight average strength: 0.014734332086778964
unet_down_block_0_transformer_block_0_attentions_to_out_0 weight average magnitude: 1.5096795979790278
unet_down_block_0_transformer_block_0_attentions_to_out_0 weight average strength: 0.017774835308342455
unet_down_block_0_transformer_block_0_attentions_to_q weight average magnitude: 1.445551043693953
unet_down_block_0_transformer_block_0_attentions_to_q weight average strength: 0.017098382122298614
unet_down_block_0_transformer_block_0_attentions_to_v weight average magnitude: 1.4463436289083411
unet_down_block_0_transformer_block_0_attentions_to_v weight average strength: 0.01456201074246142
```
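A minimal sketch (not part of the commit) of how these per-group numbers are produced: each LoRA weight tensor is flattened, assigned to a block/attention group by name, and the magnitude and strength of every tensor in a group are averaged. The `magnitude` and `strength` helpers below are assumptions (L2 norm and mean absolute value); the script's actual definitions live in `get_vector_data_magnitude` and `get_vector_data_strength`.

```python
# Sketch only: assumed definitions of magnitude (L2 norm) and strength (mean
# absolute value); the real formulas are in lora-inspector.py.
import math


def magnitude(values: list[float]) -> float:
    # L2 norm of a flattened weight tensor (assumed definition)
    return math.sqrt(sum(v * v for v in values))


def strength(values: list[float]) -> float:
    # mean absolute value of a flattened weight tensor (assumed definition)
    return sum(abs(v) for v in values) / len(values)


def average_by_group(groups: dict[str, list[list[float]]]) -> None:
    # `groups` maps a block/attention group name to its flattened weight tensors
    for name, tensors in groups.items():
        avg_mag = sum(magnitude(t) for t in tensors) / len(tensors)
        avg_str = sum(strength(t) for t in tensors) / len(tensors)
        print(f"{name} weight average magnitude: {avg_mag}")
        print(f"{name} weight average strength: {avg_str}")


# Tiny made-up example to show the output shape
average_by_group({"down_block_0_proj_in": [[0.1, -0.2, 0.05], [0.3, 0.0, -0.1]]})
```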
1 parent 9414f0e commit 3c51f8a

File tree

1 file changed: +126 -20 lines changed

lora-inspector.py

Lines changed: 126 additions & 20 deletions
```diff
@@ -3,6 +3,7 @@
 import math
 import os
 from collections import OrderedDict
+import re
 from datetime import datetime
 from pathlib import Path
 from typing import Callable, OrderedDict
@@ -210,12 +211,109 @@ def find_vectors_weights(vectors):
     print(f"Text Encoder weight average magnitude: {avg_mag}")
     print(f"Text Encoder weight average strength: {avg_str}")
 
+    _keys = [
+        unet_attn_weight_results.keys(),
+        unet_conv_weight_results.keys(),
+        text_encoder_weight_results.keys(),
+    ]
+
     return {
         "unet": unet_attn_weight_results,
         "text_encoder": text_encoder_weight_results,
     }
 
 
+def find_vectors_weight_blocks(vectors):
+    weight = ".weight"
+
+    results = {}
+
+    print(f"model key count: {len(vectors.keys())}")
+
+    for k in vectors.keys():
+        if k.endswith(".weight") is False:
+            continue
+
+        x = find_group(k)
+
+        if x is not None:
+            if x not in results.keys():
+                results[x] = []
+            results[x].append(torch.flatten(vectors.get_tensor(k)).tolist())
+
+    for key in results.keys():
+        sum_mag = 0  # average magnitude
+        sum_str = 0  # average strength
+        for vectors in results.get(key):
+            sum_mag += get_vector_data_magnitude(vectors)
+            sum_str += get_vector_data_strength(vectors)
+
+        avg_mag = sum_mag / len(results[key])
+        avg_str = sum_str / len(results[key])
+
+        print(f"{key} weight average magnitude: {avg_mag}")
+        print(f"{key} weight average strength: {avg_str}")
+
+    return results
+
+
+def find_group(key):
+    """
+    Find the group we want to put these keys into
+    TODO: describe this better
+    """
+    r = r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"
+
+    matches = re.search(r, key)
+
+    if matches is not None and matches.group(3) == "attentions":
+        r2 = r"(transformer_blocks)_(\d+)_(attn\d+)_(to_[^\.]+).(lora_up|lora_down)"
+
+        matches2 = re.search(r2, key)
+
+        if matches2 is not None:
+            # print(
+            #     matches2.group(1),
+            #     matches2.group(2),
+            #     matches2.group(3),
+            #     matches2.group(4),
+            #     matches2.group(5),
+            # )
+
+            block_name = matches.group(1) + "_block"  # up|down
+            block_id = matches.group(2)  # \d
+            transformer_block = matches2.group(2)
+            attn = matches.group(3)
+
+            attn_to = matches2.group(4)
+
+            return f"unet_{block_name}_{block_id}_transformer_block_{transformer_block}_{attn}_{attn_to}"
+
+            # atten
+            # atten2
+            # ff
+
+        else:
+            # lora_unet_up_blocks_3_attentions_0_proj_in.lora_down.weight
+            # lora_unet_up_blocks_3_attentions_0_proj_in.lora_up.weight
+            # lora_unet_up_blocks_3_attentions_0_proj_out.lora_down.weight
+            # lora_unet_up_blocks_3_attentions_0_proj_out.lora_up.weight
+            # proj
+            r3 = r"(proj_(in|out)).(lora_up|lora_down)"
+            # key2 = "transformer_blocks_0_attn1_to_k.lora_down.weight"
+
+            matches3 = re.search(r3, key)
+
+            if matches3 is not None:
+                block_name = matches.group(1) + "_block"  # up|down
+                block_id = matches.group(2)  # \d
+                in_out = matches3.group(1)
+
+                return f"{block_name}_{block_id}_{in_out}"
+
+    return None
+
+
 def get_vector_data_strength(data: dict[int, Tensor]) -> float:
     value = 0
     for n in data:
@@ -253,21 +351,25 @@ def process_safetensor_file(file, args):
     filename = os.path.basename(file)
     print(file)
 
-    parsed = {}
+    meta = {}
 
     if metadata is not None:
-        parsed = parse_metadata(metadata)
+        meta = parse_metadata(metadata)
     else:
-        parsed = {}
+        meta = {}
 
-    parsed["file"] = file
-    parsed["filename"] = filename
+    meta["file"] = file
+    meta["filename"] = filename
 
     if args.weights:
-        find_vectors_weights(f)
+        weights = find_vectors_weights(f)
+    elif args.verbose_weights:
+        weights = find_vectors_weight_blocks(f)
+    else:
+        weights = None
 
     print("----------------------")
-    return parsed
+    return (weights, meta)
 
 
 def print_list(list):
@@ -380,7 +482,7 @@ def process(args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(description="LoRA Inspector")
 
     parser.add_argument(
         "lora_file_or_dir", type=str, help="Directory containing the lora files"
@@ -407,12 +509,19 @@ def process(args):
         help="Show the most common tags in the training set",
     )
 
+    parser.add_argument(
+        "-v",
+        "--verbose_weights",
+        action="store_true",
+        help="Experimental. Average magnitude and strength, separated, for all blocks and attention",
+    )
+
     args = parser.parse_args()
-    results = process(args)
+    (weights, meta) = process(args)
 
     if args.save_meta:
-        if type(results) == list:
-            for result in results:
+        if type(meta) == list:
+            for result in meta:
                # print("result", json.dumps(result, indent=4, sort_keys=True, default=str))
                if "ss_session_id" in result:
                    newfile = (
@@ -426,16 +535,11 @@ def process(args):
                save_metadata(newfile, result)
                print(f"Metadata saved to {newfile}.json")
        else:
-            if "ss_session_id" in results:
-                newfile = (
-                    "meta/"
-                    + str(results["filename"])
-                    + "-"
-                    + results["ss_session_id"]
-                )
+            if "ss_session_id" in meta:
+                newfile = "meta/" + str(meta["filename"]) + "-" + meta["ss_session_id"]
            else:
-                newfile = "meta/" + str(results["filename"])
-                save_metadata(newfile, results)
+                newfile = "meta/" + str(meta["filename"])
+            save_metadata(newfile, meta)
            print(f"Metadata saved to {newfile}.json")
 
    if args.tags:
@@ -477,3 +581,5 @@ def process(args):
        else:
            print("No tag frequency found")
    # print(results)
+    # if weights is not None:
+    #     print(weights.keys())
```
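For reference, a short illustration (not part of the commit) of the grouping `find_group` performs. The key names below are hypothetical kohya-style LoRA keys chosen to exercise both branches of the new function, and the printed group names mirror the f-strings in the diff above.

```python
import re

# Hypothetical example keys (not from a real file): one attention to_* weight
# and one proj_in weight, to hit both branches of the grouping logic.
keys = [
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_k.lora_down.weight",
    "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight",
]

# Regexes copied from the commit
block_re = r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"
attn_re = r"(transformer_blocks)_(\d+)_(attn\d+)_(to_[^\.]+).(lora_up|lora_down)"
proj_re = r"(proj_(in|out)).(lora_up|lora_down)"

for key in keys:
    m = re.search(block_re, key)
    if m is None:
        continue
    block = f"{m.group(1)}_block_{m.group(2)}"
    m2 = re.search(attn_re, key)
    if m2 is not None:
        # prints: unet_down_block_0_transformer_block_0_attentions_to_k
        print(f"unet_{block}_transformer_block_{m2.group(2)}_{m.group(3)}_{m2.group(4)}")
        continue
    m3 = re.search(proj_re, key)
    if m3 is not None:
        # prints: down_block_0_proj_in
        print(f"{block}_{m3.group(1)}")
```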

0 commit comments