@@ -168,14 +168,6 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
168168 std::cout << " Model loaded successfully!" << std::endl;
169169 std::cout.flush ();
170170 return { static_cast <void *>(module ) };
171-
172- // torch::jit::Module tmp = torch::jit::load(path);
173- // std::cout << "Model loaded successfully!" << std::endl;
174- // std::cout.flush();
175- // auto* module = new torch::jit::Module(std::move(tmp));
176- // std::cout << "Model moved successfully!" << std::endl;
177- // std::cout.flush();
178- // return { static_cast<void*>(module) };
179171 } catch (const c10::Error& e) {
180172 std::cerr << " error loading the model\n " << e.msg ();
181173 std::cout << " error loading the model\n " << e.msg ();
@@ -186,37 +178,18 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
186178 std::cout.flush ();
187179
188180 return { nullptr };
189-
190-
191-
192- // bridge_pt_model_t model_wrapper;
193- // torch::jit::Module* pt_module = new torch::jit::Module(); // = (torch::jit::Module*) model_wrapper.pt_module;
194- // try {
195- // *pt_module = torch::jit::load(mp);
196- // std::cout << "Model loaded successfully!" << std::endl;
197- // std::cout.flush();
198- // model_wrapper.pt_module = pt_module;
199- // } catch (const c10::Error& e) {
200- // std::cerr << "error loading the model\n" << e.msg();
201- // std::cout << "error loading the model\n" << e.msg();
202- // std::cout.flush();
203- // std::cerr.flush();
204- // }
205-
206- // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
207- // std::cout.flush();
208-
209- // return model_wrapper;
210181}
211182
212183
213184
214185bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
215186 auto tn_mps = bridge_to_torch (input,DEVICE,true ,DTYPE);
216- auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
187+ tn_mps = tn_mps.permute ({2 , 0 , 1 }).contiguous ();
188+ tn_mps.unsqueeze_ (0 );// .contiguous();
189+ // auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
217190
218191 std::vector<torch::jit::IValue> ins;
219- ins.push_back (tn );
192+ ins.push_back (tn_mps );
220193
221194 auto * module = static_cast <torch::jit::Module*>(model.pt_module );
222195 auto o = module ->forward (ins).toTensor ();
@@ -227,8 +200,8 @@ bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bo
227200 tn_out.div_ (255.0 );
228201 }
229202
230- auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
231-
203+ auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,false );
204+
232205 return torch_to_bridge (tn_out_cpu);
233206
234207}
0 commit comments