Skip to content

Commit 0b180cd

Browse files
committed
Improve something.
1 parent 5aba91a commit 0b180cd

31 files changed

Lines changed: 452 additions & 245 deletions

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ build/
4343
libtorch/
4444
.cache/
4545
build-release/
46-
libtorch_static/
46+
libtorch_static/
47+
examples/vgg/**/*.pt

bridge/include/bridge.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ typedef struct nil_scalar_tensor_t {
3131

3232
float* unsafe(const float* arr);
3333
bridge_tensor_t load_tensor_from_file(const uint8_t* file_path);
34+
bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key);
35+
bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input);
3436

3537

3638
int baz(void);

bridge/lib/bridge.cpp

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <bridge.h>
22

33
#include <torch/torch.h>
4+
#include <torch/script.h>
5+
46
// #include <torch/script.h>
57
// #include <Aten/ATen.h>
68
#include <iostream>
@@ -78,26 +80,62 @@ std::vector<char> get_the_bytes(std::string filename) {
7880
extern "C" bridge_tensor_t load_tensor_from_file(const uint8_t* file_path) {
7981
// // Load the tensor from a file
8082
// torch::Tensor tensor;
81-
// // torch::load(tensor,file_path);
83+
// torch::load(tensor,file_path);
8284

8385
// std::cout << "Tensor loaded from file: " << tensor.sizes() << std::endl;
8486

85-
// // Convert the tensor to a bridge_tensor_t
87+
std::string fp(reinterpret_cast<const char*>(file_path));
88+
std::vector<char> f = get_the_bytes(fp);
89+
torch::IValue x = torch::pickle_load(f);
90+
torch::Tensor t = x.toTensor();
91+
return torch_to_bridge(t);
92+
}
93+
94+
extern "C" bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key) {
95+
96+
97+
std::cout << "Loading tensor from file: " << file_path << std::endl;
98+
std::cout << "Tensor key: " << tensor_key << std::endl;
99+
100+
std::cout.flush();
86101

87102
std::string fp(reinterpret_cast<const char*>(file_path));
88-
std::cout << "File path: " << fp << std::endl;
103+
std::string tk(reinterpret_cast<const char*>(tensor_key));
89104

90-
std::vector<char> f = get_the_bytes(fp);
91-
std::cout << "File size: " << f.size() << std::endl;
105+
torch::jit::script::Module container = torch::jit::load(fp);
106+
torch::Tensor tensor = container.attr(tk).toTensor();
92107

93-
torch::IValue x = torch::pickle_load(f);
94-
// std::cout << "IValue loaded from file: " << x << std::endl;
108+
return torch_to_bridge(tensor);
95109

96-
torch::Tensor t = x.toTensor();
97-
std::cout << "Tensor loaded from IValue: " << t.sizes() << std::endl;
98-
std::cout << "Tensor sum: " << t.sum() << std::endl;
110+
}
99111

100-
return torch_to_bridge(t);
112+
extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input) {
113+
auto t_input = bridge_to_torch(input);
114+
std::string mp(reinterpret_cast<const char*>(model_path));
115+
116+
std::cout << "Loading model from path: " << mp << std::endl;
117+
std::cout.flush();
118+
119+
120+
torch::jit::Module module;
121+
try
122+
{
123+
// Deserialize the ScriptModule from a file using torch::jit::load().
124+
module = torch::jit::load(mp);
125+
}
126+
catch (const c10::Error& e)
127+
{
128+
std::cerr << "error loading the model\n" << e.msg();
129+
std::system("pause");
130+
}
131+
132+
std::vector<torch::jit::IValue> inputs;
133+
inputs.push_back(t_input);
134+
135+
auto output = module.forward(inputs).toTensor();
136+
137+
std::cout << "Model output: " << output.sizes() << std::endl;
138+
return torch_to_bridge(output);
101139
}
102140

103141

@@ -108,6 +146,7 @@ extern "C" bridge_tensor_t load_tensor_from_file(const uint8_t* file_path) {
108146

109147

110148

149+
111150
extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
112151
auto t = bridge_to_torch(arr);
113152
// Increment the tensor

examples/vgg/dump_weights.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,71 @@
33
import sys
44
import os
55

6+
import vgg
67

7-
from vgg import vgg16
88

9-
model = vgg16(pretrained=True)
9+
# from vgg import vgg16
1010

11-
model = model.to(torch.float16)
11+
# model = vgg16(pretrained=True)
1212

13-
# Add the scripts directory to the sys.path
14-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
13+
# model = model.to(torch.float16)
1514

16-
import chai
15+
# # Add the scripts directory to the sys.path
16+
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
1717

18-
os.makedirs('models/vgg16', exist_ok=True)
19-
model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True)
18+
# import chai
2019

21-
# print(model.state_dict().keys())
22-
# print([(n,w.dtype) for (n,w) in model.state_dict().items()])
20+
# os.makedirs('models/vgg16', exist_ok=True)
21+
# model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True)
22+
23+
# # print(model.state_dict().keys())
24+
# # print([(n,w.dtype) for (n,w) in model.state_dict().items()])
25+
26+
27+
for vggxx in vgg.model_urls.keys():
28+
print(vggxx)
29+
model = vgg.__dict__[vggxx](pretrained=True)
30+
model = model.to(torch.float32)
31+
model.eval()
32+
33+
# # Add the scripts directory to the sys.path
34+
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
35+
36+
# import chai
37+
38+
# os.makedirs('models/vgg16', exist_ok=True)
39+
# model.chai_dump('models/vgg16',vggxx, with_json=False, verbose=True)
40+
41+
# print(f'models/traced_{vggxx}.pt')
42+
# faux_input = torch.randn(3,720,1280)
43+
# traced_model = torch.jit.trace(model, faux_input)
44+
# traced_model.save(f'models/traced_{vggxx}.pt')
45+
46+
print(f'models/{vggxx}.pt')
47+
sd = model.state_dict()
48+
torch.save(sd, f'models/{vggxx}.pt')
49+
50+
# print(f'models/traced_{vggxx}.pt')
51+
# faux_input = torch.randn(3,720,1280).to(torch.float16)
52+
# traced_model = torch.jit.trace(model, faux_input)
53+
# torch.save(traced_model,f'models/traced_{vggxx}.pt')
54+
55+
print(f'models/script_{vggxx}.pt')
56+
script_model = torch.jit.script(model)
57+
script_model.save(f'models/script_{vggxx}.pt')
58+
59+
test = torch.rand(1,3,720,1280).to(torch.float32)
60+
# print(model(test).shape)
61+
62+
model.eval()
63+
print(f'models/trace_{vggxx}.pt')
64+
faux_input = torch.rand(1,3,720,1280).to(torch.float32)
65+
print(model(faux_input).shape)
66+
trace_model = torch.jit.trace(model,faux_input)
67+
trace_model.save(f'models/trace_{vggxx}.pt')
68+
# torch.save(trace_model,f'models/trace_{vggxx}.pt')
69+
70+
71+
# print(model.state_dict().keys())
72+
# print([(n,w.dtype) for (n,w) in model.state_dict().items()])
73+

examples/vgg/mktensor.ipynb

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,37 @@
7878
},
7979
{
8080
"cell_type": "code",
81-
"execution_count": null,
81+
"execution_count": 24,
8282
"id": "1a17ec15",
8383
"metadata": {},
8484
"outputs": [],
85-
"source": []
85+
"source": [
86+
"m = {'a': x, 'b': x + 1}\n",
87+
"torch.save(m, 'my_tensor_dict.pt')"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "841d4dfe",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"class MyModule(torch.nn.Module):\n",
98+
" def __init__(self):\n",
99+
" super(MyModule, self).__init__()\n",
100+
" self.conv1 = torch.nn.Conv2d(in_channels=3,\n",
101+
" out_channels=16,\n",
102+
" kernel_size=3,\n",
103+
" stride=2)\n",
104+
" self.conv2 = torch.nn.Conv2d(in_channels=16,\n",
105+
" out_channels=32,\n",
106+
" kernel_size=3,\n",
107+
" stride=2)\n",
108+
"\n",
109+
" def forward(self, x):\n",
110+
" return self.linear(x)"
111+
]
86112
}
87113
],
88114
"metadata": {

examples/vgg/models/vgg16/classifier.0.bias.meta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
],
66
"rank": 1,
77
"size": 4096,
8-
"dtype": "float16",
9-
"element_bits": 16
8+
"dtype": "float32",
9+
"element_bits": 32
1010
}

examples/vgg/models/vgg16/classifier.0.weight.meta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
],
77
"rank": 2,
88
"size": 102760448,
9-
"dtype": "float16",
10-
"element_bits": 16
9+
"dtype": "float32",
10+
"element_bits": 32
1111
}

examples/vgg/models/vgg16/classifier.3.bias.meta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
],
66
"rank": 1,
77
"size": 4096,
8-
"dtype": "float16",
9-
"element_bits": 16
8+
"dtype": "float32",
9+
"element_bits": 32
1010
}

examples/vgg/models/vgg16/classifier.3.weight.meta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
],
77
"rank": 2,
88
"size": 16777216,
9-
"dtype": "float16",
10-
"element_bits": 16
9+
"dtype": "float32",
10+
"element_bits": 32
1111
}

examples/vgg/models/vgg16/classifier.6.bias.meta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
],
66
"rank": 1,
77
"size": 1000,
8-
"dtype": "float16",
9-
"element_bits": 16
8+
"dtype": "float32",
9+
"element_bits": 32
1010
}

0 commit comments

Comments
 (0)