@@ -140,13 +140,19 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
140140 std::cout.flush ();
141141
142142 try {
143- torch::jit::Module tmp = torch::jit::load (path);
143+
144+ auto * module = new torch::jit::Module (torch::jit::load (path));
144145 std::cout << " Model loaded successfully!" << std::endl;
145146 std::cout.flush ();
146- auto * module = new torch::jit::Module (std::move (tmp));
147- std::cout << " Model moved successfully!" << std::endl;
148- std::cout.flush ();
149147 return { static_cast <void *>(module ) };
148+
149+ // torch::jit::Module tmp = torch::jit::load(path);
150+ // std::cout << "Model loaded successfully!" << std::endl;
151+ // std::cout.flush();
152+ // auto* module = new torch::jit::Module(std::move(tmp));
153+ // std::cout << "Model moved successfully!" << std::endl;
154+ // std::cout.flush();
155+ // return { static_cast<void*>(module) };
150156 } catch (const c10::Error& e) {
151157 std::cerr << " error loading the model\n " << e.msg ();
152158 std::cout << " error loading the model\n " << e.msg ();
@@ -190,44 +196,68 @@ extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_
190196 inputs.push_back (t_input);
191197 // torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
192198 // auto output = pt_module->forward(inputs).toTensor();
193- auto output = t_input;
199+ auto * module = static_cast <torch::jit::Module*>(model.pt_module );
200+ auto output = module ->forward (inputs).toTensor ();
201+ std::cout << " Output tensor: " << output.sizes () << std::endl;
202+ std::cout.flush ();
203+ // auto output = t_input;
194204 return torch_to_bridge (output);
195205}
196206
197207extern " C" bridge_tensor_t model_forward_style_transfer (bridge_pt_model_t model, bridge_tensor_t input) {
198208 auto input_tensor = bridge_to_torch (input);
199- auto t_input = input_tensor.clone ();
209+ auto input_tensor_copy = input_tensor.clone ().contiguous ();
210+ auto t_input = input_tensor_copy;
211+ auto * module = static_cast <torch::jit::Module*>(model.pt_module );
212+
213+ std::cout << " Model: " << module ->dump_to_str (false , false , false ) << std::endl;
214+ std::cout.flush ();
200215
201216 std::cout << " Input tensor: " << t_input.sizes () << std::endl;
202217 std::cout.flush ();
203218
204- t_input = t_input .permute ({2 , 0 , 1 }).unsqueeze (0 );
219+ auto model_input = input_tensor_copy .permute ({2 , 0 , 1 }).unsqueeze (0 );
205220
206- std::cout << " Input tensor reshaped: " << t_input.sizes () << std::endl;
207- std::cout.flush ();
221+ // std::cout << "Input tensor reshaped: " << model_input.sizes() << std::endl;
222+ // std::cout.flush();
223+
224+ // std::vector<torch::jit::IValue> inputs;
225+ // inputs.push_back(model_input);
226+
227+ // std::cout << "Constructed inputs: " << inputs.size() << std::endl;
228+ // std::cout.flush();
229+
230+ // return torch_to_bridge(input_tensor_copy);
208231
209232 std::vector<torch::jit::IValue> inputs;
210- inputs.push_back (t_input );
233+ inputs.push_back (model_input );
211234
212- std::cout << " Constructed inputs : " << inputs. size () << std::endl;
235+ std::cout << " Model input : " << model_input. sizes () << std::endl;
213236 std::cout.flush ();
214237
238+ auto model_output = module ->forward (inputs).toTensor ();
239+ std::cout << " Output tensor: " << model_output.sizes () << std::endl;
240+ std::cout.flush ();
241+
242+ auto output = model_output.div (255.0 ).squeeze (0 ).permute ({1 , 2 , 0 }).clamp (0 , 1 );
243+ return torch_to_bridge (output);
244+
215245 // torch::jit::script::Module & pt_module = model.pt_module;
216246
217- auto * pt_module = static_cast <torch::jit::Module*>(model.pt_module );
247+ // auto* pt_module = static_cast<torch::jit::Module*>(model.pt_module);
218248
219- // torch::jit::script::Module* pt_module = (torch::jit::script::Module*)model.pt_module;
220- // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
221- // // std::cout.flush();
222- auto output = pt_module->forward (inputs).toTensor ();
249+ // // torch::jit::script::Module* pt_module = (torch::jit::script::Module*)model.pt_module;
250+ // // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
251+ // // // std::cout.flush();
252+ // auto output = pt_module->forward(inputs).toTensor();
223253 std::cout << " Output tensor: " << output.sizes () << std::endl;
224254 std::cout.flush ();
225255 // output = output.squeeze(0).permute({1, 2, 0}).clamp(0, 1).mul(255.0);
226256
227257 // std::cout << "Processed utput tensor: " << output.sizes() << std::endl;
228258 // std::cout.flush();
229259
230- return torch_to_bridge (input_tensor );
260+ return torch_to_bridge (input_tensor_copy );
231261}
232262
233263
0 commit comments