|
6 | 6 | import vgg |
7 | 7 |
|
8 | 8 |
|
9 | | -# from vgg import vgg16 |
| 9 | +from vgg import vgg16 |
10 | 10 |
|
11 | | -# model = vgg16(pretrained=True) |
| 11 | +model = vgg16(pretrained=True) |
12 | 12 |
|
13 | | -# model = model.to(torch.float16) |
| 13 | +model = model.to(torch.float16) |
14 | 14 |
|
15 | | -# # Add the scripts directory to the sys.path |
16 | | -# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts'))) |
| 15 | +# Add the scripts directory to the sys.path |
| 16 | +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts'))) |
17 | 17 |
|
18 | | -# import chai |
| 18 | +import chai |
19 | 19 |
|
20 | | -# os.makedirs('models/vgg16', exist_ok=True) |
21 | | -# model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True) |
| 20 | +os.makedirs('models/vgg16', exist_ok=True) |
| 21 | +model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True) |
22 | 22 |
|
23 | 23 | # # print(model.state_dict().keys()) |
24 | 24 | # # print([(n,w.dtype) for (n,w) in model.state_dict().items()]) |
|
27 | 27 | for vggxx in vgg.model_urls.keys(): |
28 | 28 | print(vggxx) |
29 | 29 | model = vgg.__dict__[vggxx](pretrained=True) |
30 | | - model = model.to(torch.float32) |
| 30 | + model = model.to(torch.float16) |
31 | 31 | model.eval() |
32 | 32 |
|
33 | 33 | # # Add the scripts directory to the sys.path |
| 34 | + |
34 | 35 | # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts'))) |
35 | 36 |
|
36 | 37 | # import chai |
|
43 | 44 | # traced_model = torch.jit.trace(model, faux_input) |
44 | 45 | # traced_model.save(f'models/traced_{vggxx}.pt') |
45 | 46 |
|
46 | | - print(f'models/{vggxx}.pt') |
47 | | - sd = model.state_dict() |
48 | | - torch.save(sd, f'models/{vggxx}.pt') |
| 47 | + # print(f'models/{vggxx}.pt') |
| 48 | + # sd = model.state_dict() |
| 49 | + # torch.save(sd, f'models/{vggxx}.pt') |
49 | 50 |
|
50 | | - # print(f'models/traced_{vggxx}.pt') |
51 | | - # faux_input = torch.randn(3,720,1280).to(torch.float16) |
52 | | - # traced_model = torch.jit.trace(model, faux_input) |
53 | | - # torch.save(traced_model,f'models/traced_{vggxx}.pt') |
| 51 | + # # print(f'models/traced_{vggxx}.pt') |
| 52 | + # # faux_input = torch.randn(3,720,1280).to(torch.float16) |
| 53 | + # # traced_model = torch.jit.trace(model, faux_input) |
| 54 | + # # torch.save(traced_model,f'models/traced_{vggxx}.pt') |
54 | 55 |
|
55 | | - print(f'models/script_{vggxx}.pt') |
56 | | - script_model = torch.jit.script(model) |
57 | | - script_model.save(f'models/script_{vggxx}.pt') |
| 56 | + # print(f'models/script_{vggxx}.pt') |
| 57 | + # script_model = torch.jit.script(model) |
| 58 | + # script_model.save(f'models/script_{vggxx}.pt') |
58 | 59 |
|
59 | | - test = torch.rand(1,3,720,1280).to(torch.float32) |
| 60 | + # test = torch.rand(1,3,720,1280).to(torch.float32) |
60 | 61 | # print(model(test).shape) |
61 | 62 |
|
62 | 63 | model.eval() |
63 | 64 | print(f'models/trace_{vggxx}.pt') |
64 | | - faux_input = torch.rand(1,3,720,1280).to(torch.float32) |
| 65 | + faux_input = torch.rand(1,3,720,1280).to(torch.float16) |
65 | 66 | print(model(faux_input).shape) |
66 | 67 | trace_model = torch.jit.trace(model,faux_input) |
67 | 68 | trace_model.save(f'models/trace_{vggxx}.pt') |
|
0 commit comments