import os
+from ldm.pruner import prune_checkpoint
import torch
import argparse
-import glob


parser = argparse.ArgumentParser(description='Pruning')
parser.add_argument('--ckpt', type=str, default=None, help='path to model ckpt')
args = parser.parse_args()
ckpt = args.ckpt


-def prune_it(p, keep_only_ema=False):
-    print(f"prunin' in path: {p}")
-    size_initial = os.path.getsize(p)
-    nsd = dict()
-    sd = torch.load(p, map_location="cpu")
-    print(sd.keys())
-    for k in sd.keys():
-        if k != "optimizer_states":
-            nsd[k] = sd[k]
-        else:
-            print(f"removing optimizer states for path {p}")
-    if "global_step" in sd:
-        print(f"This is global step {sd['global_step']}.")
-    if keep_only_ema:
-        sd = nsd["state_dict"].copy()
-        # infer ema keys
-        ema_keys = {k: "model_ema." + k[6:].replace(".", "") for k in sd.keys() if k.startswith("model.")}
-        new_sd = dict()
-
-        for k in sd:
-            if k in ema_keys:
-                new_sd[k] = sd[ema_keys[k]].half()
-            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
-                new_sd[k] = sd[k].half()
-
-        assert len(new_sd) == len(sd) - len(ema_keys)
-        nsd["state_dict"] = new_sd
-    else:
-        sd = nsd['state_dict'].copy()
-        new_sd = dict()
-        for k in sd:
-            new_sd[k] = sd[k].half()
-        nsd['state_dict'] = new_sd
-
-    fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
-    print(f"saving pruned checkpoint at: {fn}")
-    torch.save(nsd, fn)
+def prune_it(checkpoint_path):
+    print(f"Prunin' checkpoint from path: {checkpoint_path}")
+    size_initial = os.path.getsize(checkpoint_path)
+    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    pruned = prune_checkpoint(checkpoint)
+    fn = f"{os.path.splitext(checkpoint_path)[0]}-pruned.ckpt"
+    print(f"Saving pruned checkpoint at: {fn}")
+    torch.save(pruned, fn)
    newsize = os.path.getsize(fn)
    MSG = f"New ckpt size: {newsize * 1e-9:.2f} GB. " + \
          f"Saved {(size_initial - newsize) * 1e-9:.2f} GB by removing optimizer states"
-    if keep_only_ema:
-        MSG += " and non-EMA weights"
    print(MSG)
-

if __name__ == "__main__":
    prune_it(ckpt)
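For context, this change factors the pruning logic itself out of the script and into ldm.pruner.prune_checkpoint, whose contents are not part of the diff. Below is a minimal sketch of what that function might look like, assuming it reproduces the removed inline behavior: drop the large optimizer_states entry and cast the state_dict weights to fp16. The module body is an assumption for illustration, not code from this commit.

# ldm/pruner.py -- hypothetical sketch; the real module is not shown in this diff.
import torch


def prune_checkpoint(checkpoint):
    # Keep every top-level entry except the optimizer states, which dominate
    # the on-disk size of a training checkpoint.
    pruned = {k: v for k, v in checkpoint.items() if k != "optimizer_states"}
    # Cast floating-point weights to half precision, mirroring the .half()
    # calls in the removed inline implementation; leave non-float entries alone.
    pruned["state_dict"] = {
        k: v.half() if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in checkpoint["state_dict"].items()
    }
    return pruned

With that in place, the script is invoked the same way as before, e.g. python prune.py --ckpt path/to/model.ckpt (the script filename here is assumed), and writes path/to/model-pruned.ckpt next to the input checkpoint.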