-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathreparameterization.py
More file actions
109 lines (103 loc) · 5.5 KB
/
Copy pathreparameterization.py
File metadata and controls
109 lines (103 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
# current_working_directory = os.getcwd()
# # print output to the console
# print(current_working_directory)
# os.chdir("/Users/tuannguyenxuan/Documents/Ventura/GT/fisheye8k/yolov9/")
# print(os.getcwd())
import torch
from models.yolo import Model
import argparse
def _resolve_device(spec):
    """Turn a CLI device spec into a ``torch.device``.

    Accepts anything ``torch.device`` accepts (``cpu``, ``cuda``, ``cuda:1``)
    plus a bare GPU index such as ``0``. For a comma list such as ``0,1`` only
    the first index is used, because the conversion runs on a single device.
    """
    spec = str(spec).strip()
    first = spec.split(",")[0].strip()
    if first.isdigit():
        return torch.device("cuda:{}".format(first))
    return torch.device(spec)


def _layer_index(key):
    """Return N for a state-dict key of the form ``model.N.<rest>``, else None."""
    parts = key.split(".", 2)
    if len(parts) == 3 and parts[0] == "model" and parts[1].isdigit():
        return int(parts[1])
    return None


def _source_key(key, idx, variant):
    """Map a key of the freshly built model to the checkpoint key it is copied from.

    Args:
        key: state-dict key of the target model, e.g. ``model.3.cv1.conv.weight``.
        idx: the layer index N parsed from ``key``.
        variant: ``"c"`` or ``"e"`` (the ``--model`` option).

    Returns:
        The checkpoint key to read, or None when no mapping rule applies.
    """
    prefix = "model.{}.".format(idx)
    if (variant == "c" and idx < 22) or (variant == "e" and idx < 29):
        # Early layers: same index for 'e', shifted by +1 for 'c'.
        return key.replace(prefix, "model.{}.".format(idx + 1 if variant == "c" else idx))
    if variant == "e" and idx < 42:
        return key.replace(prefix, "model.{}.".format(idx + 7))
    # Remaining layers: only the cv2 / cv3 / dfl sub-modules are mapped, onto
    # cv4 / cv5 / dfl2 of a layer offset by +16 ('c') or +7 ('e').
    # NOTE(review): presumably the detection head — confirm against the yaml.
    src_idx = idx + 16 if variant == "c" else idx + 7
    for src, dst in (("cv2", "cv4"), ("cv3", "cv5"), ("dfl", "dfl2")):
        old = "model.{}.{}.".format(idx, src)
        if old in key:
            return key.replace(old, "model.{}.{}.".format(src_idx, dst))
    return None


def main(args):
    """Copy weights from a trained checkpoint into a model built from a cfg, then save it.

    Reads from ``args``: ``cfg`` (model yaml), ``model`` (``"c"`` or ``"e"``,
    selects the key-mapping rules), ``weights`` (source checkpoint),
    ``device``, ``classes_num`` and ``save`` (output path).

    The output is a dict whose ``'model'`` entry holds the half-precision
    model; every other training field is None and ``epoch`` is -1.

    Raises:
        ValueError: if ``args.model`` is not ``"c"`` or ``"e"``.
        KeyError: if a mapped key is missing from the source checkpoint.
    """
    variant = args.model
    if variant not in ("c", "e"):
        # Any other value used to fall through to the 'e' head offsets and
        # silently produce a broken checkpoint.
        raise ValueError("--model must be 'c' or 'e', got {!r}".format(variant))

    device = _resolve_device(args.device)
    model = Model(args.cfg, ch=3, nc=args.classes_num, anchors=3)
    model = model.to(device)
    _ = model.eval()

    # NOTE: torch.load unpickles arbitrary objects — only convert trusted files.
    # The checkpoint stores a whole nn.Module, so torch >= 2.6 (which defaults
    # to weights_only=True) needs weights_only=False explicitly.
    try:
        ckpt = torch.load(args.weights, map_location="cpu", weights_only=False)
    except TypeError:  # torch < 1.13 has no weights_only argument
        ckpt = torch.load(args.weights, map_location="cpu")
    src_model = ckpt["model"]
    model.names = src_model.names
    model.nc = src_model.nc
    src_state = src_model.state_dict()  # build once, not once per key

    unmatched = []
    with torch.no_grad():
        # state_dict() tensors share storage with the model's parameters and
        # buffers, so copy_ writes straight into the model. copy_ also moves the
        # CPU checkpoint tensor onto the model's device.
        for k, v in model.state_dict().items():
            idx = _layer_index(k)
            kr = None if idx is None else _source_key(k, idx, variant)
            if kr is None:
                unmatched.append(k)
                continue
            if kr not in src_state:
                raise KeyError("{} maps to {}, which is missing from {}".format(k, kr, args.weights))
            v.copy_(src_state[kr])
            print(k, "perfectly matched!!")

    if unmatched:
        # These keep their random initial values in the saved model.
        print("WARNING: {} keys were not copied from the checkpoint:".format(len(unmatched)))
        for k in unmatched:
            print("  ", k)

    _ = model.eval()
    m_ckpt = {'model': model.half(),
              'optimizer': None,
              'best_fitness': None,
              'ema': None,
              'updates': None,
              'opt': None,
              'git': None,
              'date': None,
              'epoch': -1}
    torch.save(m_ckpt, args.save)
if __name__ == "__main__":
    # Command-line entry point: parse options and run the conversion.
    parser = argparse.ArgumentParser(
        description="Copy trained weights into the model defined by --cfg and save the result.")
    parser.add_argument('--cfg', type=str, default='../models/detect/gelan-e-our.yaml', help='model.yaml path')
    # Only 'c' and 'e' have key-mapping rules; reject anything else up front.
    parser.add_argument('--model', type=str, default='e', choices=('c', 'e'), help='convert model type (c or e)')
    parser.add_argument('--weights', type=str, default='./yolov9-e-modify-trained.pt', help='weights path')
    parser.add_argument('--device', default='cpu', help='torch device, e.g. cpu or cuda:0')
    parser.add_argument('--classes_num', default=80, type=int, help='number of classes')
    parser.add_argument('--save', default='./yolov9-e-modify-converted.pt', type=str, help='save path')
    args = parser.parse_args()
    main(args)