@@ -201,66 +201,36 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
201201 // return model_wrapper;
202202}
203203
204- extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
204+
205+
206+ bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
205207
206208 auto tn_mps = bridge_to_torch (input,torch::kMPS ,true ,torch::kFloat16 );
207209 // auto tn_mps = tn.to(torch::kMPS,false,true);
208- auto tn_ = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
210+ auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
209211
210212 std::vector<torch::jit::IValue> ins;
211- ins.push_back (tn_ );
213+ ins.push_back (tn );
212214
213215 auto * module = static_cast <torch::jit::Module*>(model.pt_module );
214216 auto o = module ->forward (ins).toTensor ();
215217 auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
216- auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
217- return torch_to_bridge (tn_out_cpu);
218-
219- /*
220218
221- auto t = bridge_to_torch(input).clone();
222- auto t_input = t.permute({2, 0, 1}).unsqueeze(0); // Add batch dimension
219+ if (is_vgg_based_model) {
220+ tn_out = tn_out / 255.0 ;
221+ }
223222
224- std::cout << "Input tensor: " << t_input.sizes() << std::endl ;
225- std::cout.flush( );
223+ auto tn_out_cpu = tn_out. to (torch:: kCPU ,torch:: kFloat32 , false , true ) ;
224+ return torch_to_bridge (tn_out_cpu );
226225
227- std::vector<torch::jit::IValue> inputs;
228- inputs.push_back(t_input);
229- // torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
230- // auto output = pt_module->forward(inputs).toTensor();
231- auto* module = static_cast<torch::jit::Module*>(model.pt_module);
232- auto output = module->forward(inputs).toTensor();
233- std::cout << "Output tensor: " << output.sizes() << std::endl;
234- std::cout.flush();
226+ }
235227
236- auto output_reshaped = output.squeeze(0).permute({1, 2, 0}); // Remove batch dimension and permute back to HWC
237- std::cout << "Output reshaped tensor: " << output_reshaped.sizes() << std::endl;
238- std::cout.flush();
239- // auto output = t_input;
240- return torch_to_bridge(output_reshaped);
241- */
228+ extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
229+ return model_forward (model, input, false );
242230}
243231
244232extern " C" bridge_tensor_t model_forward_style_transfer (bridge_pt_model_t model, bridge_tensor_t input) {
245- auto bt = bridge_to_torch (input).clone ();
246- auto t_input = bt.permute ({2 , 0 , 1 }).unsqueeze (0 ); // Convert from CHW to HWC
247-
248- std::cout << " Input tensor: " << t_input.sizes () << std::endl;
249- std::cout.flush ();
250-
251- std::vector<torch::jit::IValue> inputs;
252- inputs.push_back (t_input);
253- // torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
254- // auto output = pt_module->forward(inputs).toTensor();
255- // auto* module = static_cast<torch::jit::Module*>(model.pt_module);
256- auto module = *static_cast <torch::jit::Module*>(model.pt_module );
257- std::cout << " Module: " << module .dump_to_str (false , false , false ) << std::endl;
258- std::cout.flush ();
259- auto output = module .forward (inputs).toTensor ();
260- std::cout << " Output tensor: " << output.sizes () << std::endl;
261- std::cout.flush ();
262- // auto output = t_input;
263- return torch_to_bridge (output);
233+ return model_forward (model, input, true );
264234}
265235
266236
0 commit comments