3
3
import math
4
4
import os
5
5
from collections import OrderedDict
6
+ import re
6
7
from datetime import datetime
7
8
from pathlib import Path
8
9
from typing import Callable , OrderedDict
@@ -210,12 +211,109 @@ def find_vectors_weights(vectors):
210
211
print (f"Text Encoder weight average magnitude: { avg_mag } " )
211
212
print (f"Text Encoder weight average strength: { avg_str } " )
212
213
214
+ _keys = [
215
+ unet_attn_weight_results .keys (),
216
+ unet_conv_weight_results .keys (),
217
+ text_encoder_weight_results .keys (),
218
+ ]
219
+
213
220
return {
214
221
"unet" : unet_attn_weight_results ,
215
222
"text_encoder" : text_encoder_weight_results ,
216
223
}
217
224
218
225
226
def find_vectors_weight_blocks(vectors):
    """
    Print and return per-group average magnitude/strength of LoRA weights.

    Groups ".weight" tensors by the block name produced by find_group()
    and reports, for each group, the average vector magnitude and strength.

    Args:
        vectors: an open safetensors file handle — assumes it exposes
            .keys() and .get_tensor(key) (TODO confirm against caller).

    Returns:
        dict mapping group name -> list of flattened weight vectors.
    """
    results = {}

    print(f"model key count: {len(vectors.keys())}")

    for k in vectors.keys():
        # Only weight tensors are of interest (skip alphas, biases, etc.).
        if not k.endswith(".weight"):
            continue

        group = find_group(k)
        if group is None:
            continue

        # Collect the flattened tensor under its block group.
        results.setdefault(group, []).append(
            torch.flatten(vectors.get_tensor(k)).tolist()
        )

    for key, vector_list in results.items():
        sum_mag = 0  # accumulated magnitude
        sum_str = 0  # accumulated strength
        # NOTE: iterate under a fresh name — the original code reused the
        # `vectors` parameter name here, clobbering the file handle.
        for vec in vector_list:
            sum_mag += get_vector_data_magnitude(vec)
            sum_str += get_vector_data_strength(vec)

        avg_mag = sum_mag / len(vector_list)
        avg_str = sum_str / len(vector_list)

        print(f"{key} weight average magnitude: {avg_mag}")
        print(f"{key} weight average strength: {avg_str}")

    return results
258
+
259
+
260
def find_group(key):
    """
    Map a LoRA state-dict key to the group name it should be averaged under.

    Only unet up/down-block "attentions" keys are grouped; everything else
    (resnets, samplers, text-encoder keys, non-matching keys) yields None.

    Args:
        key: a LoRA weight key, e.g.
            "lora_unet_up_blocks_3_attentions_0_transformer_blocks_0_attn1_to_k.lora_down.weight"
            "lora_unet_up_blocks_3_attentions_0_proj_in.lora_down.weight"

    Returns:
        The group name string, or None when the key is not recognized.
    """
    block_re = r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"

    matches = re.search(block_re, key)

    # Only attention blocks are grouped; bail out early for everything else.
    if matches is None or matches.group(3) != "attentions":
        return None

    # Transformer-block attention projections (to_k / to_q / to_v / to_out).
    # NOTE: the dot before (lora_up|lora_down) is escaped — it is a literal
    # key separator, not "any character".
    attn_re = r"(transformer_blocks)_(\d+)_(attn\d+)_(to_[^\.]+)\.(lora_up|lora_down)"
    matches2 = re.search(attn_re, key)

    if matches2 is not None:
        block_name = matches.group(1) + "_block"  # "up_block" | "down_block"
        block_id = matches.group(2)  # block index, e.g. "3"
        transformer_block = matches2.group(2)
        attn = matches.group(3)  # always "attentions" here
        attn_to = matches2.group(4)  # e.g. "to_k"

        return f"unet_{block_name}_{block_id}_transformer_block_{transformer_block}_{attn}_{attn_to}"

    # proj_in / proj_out keys, e.g.
    # lora_unet_up_blocks_3_attentions_0_proj_in.lora_down.weight
    proj_re = r"(proj_(in|out))\.(lora_up|lora_down)"
    matches3 = re.search(proj_re, key)

    if matches3 is not None:
        block_name = matches.group(1) + "_block"  # "up_block" | "down_block"
        block_id = matches.group(2)
        in_out = matches3.group(1)  # "proj_in" | "proj_out"

        return f"{block_name}_{block_id}_{in_out}"

    return None
315
+
316
+
219
317
def get_vector_data_strength (data : dict [int , Tensor ]) -> float :
220
318
value = 0
221
319
for n in data :
@@ -253,21 +351,25 @@ def process_safetensor_file(file, args):
253
351
filename = os .path .basename (file )
254
352
print (file )
255
353
256
- parsed = {}
354
+ meta = {}
257
355
258
356
if metadata is not None :
259
- parsed = parse_metadata (metadata )
357
+ meta = parse_metadata (metadata )
260
358
else :
261
- parsed = {}
359
+ meta = {}
262
360
263
- parsed ["file" ] = file
264
- parsed ["filename" ] = filename
361
+ meta ["file" ] = file
362
+ meta ["filename" ] = filename
265
363
266
364
if args .weights :
267
- find_vectors_weights (f )
365
+ weights = find_vectors_weights (f )
366
+ elif args .verbose_weights :
367
+ weights = find_vectors_weight_blocks (f )
368
+ else :
369
+ weights = None
268
370
269
371
print ("----------------------" )
270
- return parsed
372
+ return ( weights , meta )
271
373
272
374
273
375
def print_list (list ):
@@ -380,7 +482,7 @@ def process(args):
380
482
381
483
382
484
if __name__ == "__main__" :
383
- parser = argparse .ArgumentParser ()
485
+ parser = argparse .ArgumentParser (description = "LoRA Inspector" )
384
486
385
487
parser .add_argument (
386
488
"lora_file_or_dir" , type = str , help = "Directory containing the lora files"
@@ -407,12 +509,19 @@ def process(args):
407
509
help = "Show the most common tags in the training set" ,
408
510
)
409
511
512
+ parser .add_argument (
513
+ "-v" ,
514
+ "--verbose_weights" ,
515
+ action = "store_true" ,
516
+ help = "Experimental. Average magnitude and strength, separated, for all blocks and attention" ,
517
+ )
518
+
410
519
args = parser .parse_args ()
411
- results = process (args )
520
+ ( weights , meta ) = process (args )
412
521
413
522
if args .save_meta :
414
- if type (results ) == list :
415
- for result in results :
523
+ if type (meta ) == list :
524
+ for result in meta :
416
525
# print("result", json.dumps(result, indent=4, sort_keys=True, default=str))
417
526
if "ss_session_id" in result :
418
527
newfile = (
@@ -426,16 +535,11 @@ def process(args):
426
535
save_metadata (newfile , result )
427
536
print (f"Metadata saved to { newfile } .json" )
428
537
else :
429
- if "ss_session_id" in results :
430
- newfile = (
431
- "meta/"
432
- + str (results ["filename" ])
433
- + "-"
434
- + results ["ss_session_id" ]
435
- )
538
+ if "ss_session_id" in meta :
539
+ newfile = "meta/" + str (meta ["filename" ]) + "-" + meta ["ss_session_id" ]
436
540
else :
437
- newfile = "meta/" + str (results ["filename" ])
438
- save_metadata (newfile , results )
541
+ newfile = "meta/" + str (meta ["filename" ])
542
+ save_metadata (newfile , meta )
439
543
print (f"Metadata saved to { newfile } .json" )
440
544
441
545
if args .tags :
@@ -477,3 +581,5 @@ def process(args):
477
581
else :
478
582
print ("No tag frequency found" )
479
583
# print(results)
584
+ # if weights is not None:
585
+ # print(weights.keys())
0 commit comments