33import sys
44import os
55
6+ import vgg
67
7- from vgg import vgg16
88
9- model = vgg16 ( pretrained = True )
9+ # from vgg import vgg16
1010
11- model = model . to ( torch . float16 )
11+ # model = vgg16(pretrained=True )
1212
13- # Add the scripts directory to the sys.path
14- sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , '..' , 'scripts' )))
13+ # model = model.to(torch.float16)
1514
16- import chai
15+ # # Add the scripts directory to the sys.path
16+ # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
1717
18- os .makedirs ('models/vgg16' , exist_ok = True )
19- model .chai_dump ('models/vgg16' ,'vgg16' , with_json = False , verbose = True )
18+ # import chai
2019
21- # print(model.state_dict().keys())
22- # print([(n,w.dtype) for (n,w) in model.state_dict().items()])
20+ # os.makedirs('models/vgg16', exist_ok=True)
21+ # model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True)
22+
23+ # # print(model.state_dict().keys())
24+ # # print([(n,w.dtype) for (n,w) in model.state_dict().items()])
25+
26+
27+ for vggxx in vgg .model_urls .keys ():
28+ print (vggxx )
29+ model = vgg .__dict__ [vggxx ](pretrained = True )
30+ model = model .to (torch .float32 )
31+ model .eval ()
32+
33+ # # Add the scripts directory to the sys.path
34+ # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
35+
36+ # import chai
37+
38+ # os.makedirs('models/vgg16', exist_ok=True)
39+ # model.chai_dump('models/vgg16',vggxx, with_json=False, verbose=True)
40+
41+ # print(f'models/traced_{vggxx}.pt')
42+ # faux_input = torch.randn(3,720,1280)
43+ # traced_model = torch.jit.trace(model, faux_input)
44+ # traced_model.save(f'models/traced_{vggxx}.pt')
45+
46+ print (f'models/{ vggxx } .pt' )
47+ sd = model .state_dict ()
48+ torch .save (sd , f'models/{ vggxx } .pt' )
49+
50+ # print(f'models/traced_{vggxx}.pt')
51+ # faux_input = torch.randn(3,720,1280).to(torch.float16)
52+ # traced_model = torch.jit.trace(model, faux_input)
53+ # torch.save(traced_model,f'models/traced_{vggxx}.pt')
54+
55+ print (f'models/script_{ vggxx } .pt' )
56+ script_model = torch .jit .script (model )
57+ script_model .save (f'models/script_{ vggxx } .pt' )
58+
59+ test = torch .rand (1 ,3 ,720 ,1280 ).to (torch .float32 )
60+ # print(model(test).shape)
61+
62+ model .eval ()
63+ print (f'models/trace_{ vggxx } .pt' )
64+ faux_input = torch .rand (1 ,3 ,720 ,1280 ).to (torch .float32 )
65+ print (model (faux_input ).shape )
66+ trace_model = torch .jit .trace (model ,faux_input )
67+ trace_model .save (f'models/trace_{ vggxx } .pt' )
68+ # torch.save(trace_model,f'models/trace_{vggxx}.pt')
69+
70+
71+ # print(model.state_dict().keys())
72+ # print([(n,w.dtype) for (n,w) in model.state_dict().items()])
73+
0 commit comments