Skip to content

Commit 2de27e5

Browse files
committed
Improved style transfer demo performance by a small amount.
1 parent 6c3b08a commit 2de27e5

1 file changed

Lines changed: 6 additions & 33 deletions

File tree

bridge/lib/bridge.cpp

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,6 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
168168
std::cout << "Model loaded successfully!" << std::endl;
169169
std::cout.flush();
170170
return { static_cast<void*>(module) };
171-
172-
// torch::jit::Module tmp = torch::jit::load(path);
173-
// std::cout << "Model loaded successfully!" << std::endl;
174-
// std::cout.flush();
175-
// auto* module = new torch::jit::Module(std::move(tmp));
176-
// std::cout << "Model moved successfully!" << std::endl;
177-
// std::cout.flush();
178-
// return { static_cast<void*>(module) };
179171
} catch (const c10::Error& e) {
180172
std::cerr << "error loading the model\n" << e.msg();
181173
std::cout << "error loading the model\n" << e.msg();
@@ -186,37 +178,18 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
186178
std::cout.flush();
187179

188180
return { nullptr };
189-
190-
191-
192-
// bridge_pt_model_t model_wrapper;
193-
// torch::jit::Module* pt_module = new torch::jit::Module(); // = (torch::jit::Module*) model_wrapper.pt_module;
194-
// try {
195-
// *pt_module = torch::jit::load(mp);
196-
// std::cout << "Model loaded successfully!" << std::endl;
197-
// std::cout.flush();
198-
// model_wrapper.pt_module = pt_module;
199-
// } catch (const c10::Error& e) {
200-
// std::cerr << "error loading the model\n" << e.msg();
201-
// std::cout << "error loading the model\n" << e.msg();
202-
// std::cout.flush();
203-
// std::cerr.flush();
204-
// }
205-
206-
// std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
207-
// std::cout.flush();
208-
209-
// return model_wrapper;
210181
}
211182

212183

213184

214185
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
215186
auto tn_mps = bridge_to_torch(input,DEVICE,true,DTYPE);
216-
auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
187+
tn_mps = tn_mps.permute({2, 0, 1}).contiguous();
188+
tn_mps.unsqueeze_(0);//.contiguous();
189+
// auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
217190

218191
std::vector<torch::jit::IValue> ins;
219-
ins.push_back(tn);
192+
ins.push_back(tn_mps);
220193

221194
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
222195
auto o = module->forward(ins).toTensor();
@@ -227,8 +200,8 @@ bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bo
227200
tn_out.div_(255.0);
228201
}
229202

230-
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,true);
231-
203+
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,false);
204+
232205
return torch_to_bridge(tn_out_cpu);
233206

234207
}

0 commit comments

Comments
 (0)