Commit 1b87ef2

Merge commit 1b87ef2 (2 parents: 99ba2cb + b2842c9)

4 files changed: +44 -42 lines

Diff for: ldm/modules/pruningckptio.py (+15)

@@ -0,0 +1,15 @@
+from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
+
+from typing import Any, Callable, Dict, Optional
+from pytorch_lightning.utilities.types import _PATH
+from ldm.pruner import prune_checkpoint
+
+class PruningCheckpointIO(TorchCheckpointIO):
+    def save_checkpoint(
+        self,
+        checkpoint: Dict[str, Any],
+        path: _PATH,
+        storage_options: Optional[Any] = None
+    ) -> None:
+        pruned_checkpoint = prune_checkpoint(checkpoint)
+        TorchCheckpointIO.save_checkpoint(self, pruned_checkpoint, path, storage_options)
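
A minimal sketch (not part of the commit) of how a CheckpointIO plugin like this one gets handed to a PyTorch Lightning Trainer; max_steps and model are illustrative placeholders, and the actual wiring for this repo is the main.py change further down.

import pytorch_lightning as pl
from ldm.modules.pruningckptio import PruningCheckpointIO

# Every checkpoint the Trainer writes now passes through prune_checkpoint() first.
trainer = pl.Trainer(max_steps=10000, plugins=PruningCheckpointIO())
trainer.fit(model)  # model stands in for the usual LightningModule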

Diff for: ldm/pruner.py (+17)

@@ -0,0 +1,17 @@
+def prune_checkpoint(old_state):
+    print(f"Prunin' Checkpoint")
+    pruned_checkpoint = dict()
+    print(f"Checkpoint Keys: {old_state.keys()}")
+    for key in old_state.keys():
+        if key != "optimizer_states":
+            pruned_checkpoint[key] = old_state[key]
+        else:
+            print("Removing optimizer states from checkpoint")
+    if "global_step" in old_state:
+        print(f"This is global step {old_state['global_step']}.")
+    old_state = pruned_checkpoint['state_dict'].copy()
+    new_state = dict()
+    for key in old_state:
+        new_state[key] = old_state[key].half()
+    pruned_checkpoint['state_dict'] = new_state
+    return pruned_checkpoint
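
In effect, prune_checkpoint copies every top-level key except "optimizer_states" and re-casts each tensor in "state_dict" to fp16, which is where the size saving comes from. A minimal sketch of calling it directly (the file names are hypothetical; prune_ckpt.py below wraps exactly this):

import torch
from ldm.pruner import prune_checkpoint

ckpt = torch.load("last.ckpt", map_location="cpu")  # hypothetical checkpoint path
pruned = prune_checkpoint(ckpt)                     # drops optimizer_states, halves the weights
torch.save(pruned, "last-pruned.ckpt")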

Diff for: main.py (+3 -1)

@@ -1,4 +1,5 @@
 import argparse, os, sys, datetime, glob, importlib, csv
+from ldm.modules.pruningckptio import PruningCheckpointIO
 import numpy as np
 import time
 import torch
@@ -785,7 +786,8 @@ def on_train_epoch_start(self, trainer, pl_module):

         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
         trainer_kwargs["max_steps"] = trainer_opt.max_steps
-
+        trainer_kwargs["plugins"] = PruningCheckpointIO()
+
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir ###

Diff for: prune_ckpt.py (+9 -41)

@@ -1,58 +1,26 @@
 import os
+from ldm.pruner import prune_checkpoint
 import torch
 import argparse
-import glob


 parser = argparse.ArgumentParser(description='Pruning')
 parser.add_argument('--ckpt', type=str, default=None, help='path to model ckpt')
 args = parser.parse_args()
 ckpt = args.ckpt

-def prune_it(p, keep_only_ema=False):
-    print(f"prunin' in path: {p}")
-    size_initial = os.path.getsize(p)
-    nsd = dict()
-    sd = torch.load(p, map_location="cpu")
-    print(sd.keys())
-    for k in sd.keys():
-        if k != "optimizer_states":
-            nsd[k] = sd[k]
-        else:
-            print(f"removing optimizer states for path {p}")
-    if "global_step" in sd:
-        print(f"This is global step {sd['global_step']}.")
-    if keep_only_ema:
-        sd = nsd["state_dict"].copy()
-        # infer ema keys
-        ema_keys = {k: "model_ema." + k[6:].replace(".", ".") for k in sd.keys() if k.startswith("model.")}
-        new_sd = dict()
-
-        for k in sd:
-            if k in ema_keys:
-                new_sd[k] = sd[ema_keys[k]].half()
-            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
-                new_sd[k] = sd[k].half()
-
-        assert len(new_sd) == len(sd) - len(ema_keys)
-        nsd["state_dict"] = new_sd
-    else:
-        sd = nsd['state_dict'].copy()
-        new_sd = dict()
-        for k in sd:
-            new_sd[k] = sd[k].half()
-        nsd['state_dict'] = new_sd
-
-    fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
-    print(f"saving pruned checkpoint at: {fn}")
-    torch.save(nsd, fn)
+def prune_it(checkpoint_path):
+    print(f"Prunin' checkpoint from path: {checkpoint_path}")
+    size_initial = os.path.getsize(checkpoint_path)
+    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    pruned = prune_checkpoint(checkpoint)
+    fn = f"{os.path.splitext(checkpoint_path)[0]}-pruned.ckpt"
+    print(f"Saving pruned checkpoint at: {fn}")
+    torch.save(pruned, fn)
     newsize = os.path.getsize(fn)
     MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
           f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states"
-    if keep_only_ema:
-        MSG += " and non-EMA weights"
     print(MSG)

-
 if __name__ == "__main__":
     prune_it(ckpt)
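
Invocation of the script is unchanged by the refactor; with a hypothetical path:

    python prune_ckpt.py --ckpt logs/run/checkpoints/last.ckpt

This writes logs/run/checkpoints/last-pruned.ckpt next to the original and prints how much space was saved. The keep_only_ema flag, and the EMA-only pruning branch that went with it, is dropped; the script now defers to prune_checkpoint from ldm/pruner.py.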
