@@ -132,34 +132,60 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
132132
133133
134134extern " C" bridge_pt_model_t load_model (const uint8_t * model_path) {
135+
135136 std::cout << " Begin loading model from path: " << model_path << std::endl;
136137 std::cout.flush ();
137- std::string mp (reinterpret_cast <const char *>(model_path));
138- std::cout << " Loading model from path: " << mp << std::endl;
138+ std::string path (reinterpret_cast <const char *>(model_path));
139+ std::cout << " Loading model from path: " << path << std::endl;
139140 std::cout.flush ();
140141
141- bridge_pt_model_t model_wrapper;
142- torch::jit::Module* pt_module = new torch::jit::Module (); // = (torch::jit::Module*) model_wrapper.pt_module;
143142 try {
144- *pt_module = torch::jit::load (mp );
143+ torch::jit::Module tmp = torch::jit::load (path );
145144 std::cout << " Model loaded successfully!" << std::endl;
146145 std::cout.flush ();
147- model_wrapper.pt_module = pt_module;
146+ auto * module = new torch::jit::Module (std::move (tmp));
147+ std::cout << " Model moved successfully!" << std::endl;
148+ std::cout.flush ();
149+ return { static_cast <void *>(module ) };
148150 } catch (const c10::Error& e) {
149151 std::cerr << " error loading the model\n " << e.msg ();
150152 std::cout << " error loading the model\n " << e.msg ();
151153 std::cout.flush ();
152154 std::cerr.flush ();
153155 }
154-
155- std::cout << pt_module->dump_to_str (false ,false ,false ) << std::endl;
156+ std::cout << " Model loading failed!" << std::endl;
156157 std::cout.flush ();
157158
158- return model_wrapper;
159+ return { nullptr };
160+
161+
162+
163+ // bridge_pt_model_t model_wrapper;
164+ // torch::jit::Module* pt_module = new torch::jit::Module(); // = (torch::jit::Module*) model_wrapper.pt_module;
165+ // try {
166+ // *pt_module = torch::jit::load(mp);
167+ // std::cout << "Model loaded successfully!" << std::endl;
168+ // std::cout.flush();
169+ // model_wrapper.pt_module = pt_module;
170+ // } catch (const c10::Error& e) {
171+ // std::cerr << "error loading the model\n" << e.msg();
172+ // std::cout << "error loading the model\n" << e.msg();
173+ // std::cout.flush();
174+ // std::cerr.flush();
175+ // }
176+
177+ // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
178+ // std::cout.flush();
179+
180+ // return model_wrapper;
159181}
160182
161183extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
162184 auto t_input = bridge_to_torch (input);
185+
186+ std::cout << " Input tensor: " << t_input.sizes () << std::endl;
187+ std::cout.flush ();
188+
163189 std::vector<torch::jit::IValue> inputs;
164190 inputs.push_back (t_input);
165191 // torch::jit::Module* pt_module = (torch::jit::Module*) model.pt_module;
@@ -168,6 +194,42 @@ extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_
168194 return torch_to_bridge (output);
169195}
170196
197+ extern " C" bridge_tensor_t model_forward_style_transfer (bridge_pt_model_t model, bridge_tensor_t input) {
198+ auto input_tensor = bridge_to_torch (input);
199+ auto t_input = input_tensor.clone ();
200+
201+ std::cout << " Input tensor: " << t_input.sizes () << std::endl;
202+ std::cout.flush ();
203+
204+ t_input = t_input.permute ({2 , 0 , 1 }).unsqueeze (0 );
205+
206+ std::cout << " Input tensor reshaped: " << t_input.sizes () << std::endl;
207+ std::cout.flush ();
208+
209+ std::vector<torch::jit::IValue> inputs;
210+ inputs.push_back (t_input);
211+
212+ std::cout << " Constructed inputs: " << inputs.size () << std::endl;
213+ std::cout.flush ();
214+
215+ // torch::jit::script::Module & pt_module = model.pt_module;
216+
217+ auto * pt_module = static_cast <torch::jit::Module*>(model.pt_module );
218+
219+ // torch::jit::script::Module* pt_module = (torch::jit::script::Module*)model.pt_module;
220+ // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
221+ // // std::cout.flush();
222+ auto output = pt_module->forward (inputs).toTensor ();
223+ std::cout << " Output tensor: " << output.sizes () << std::endl;
224+ std::cout.flush ();
225+ // output = output.squeeze(0).permute({1, 2, 0}).clamp(0, 1).mul(255.0);
226+
227+ // std::cout << "Processed utput tensor: " << output.sizes() << std::endl;
228+ // std::cout.flush();
229+
230+ return torch_to_bridge (input_tensor);
231+ }
232+
171233
172234extern " C" void hello_world (void ) {
173235 std::cout << " Hello from C++!" << std::endl;
0 commit comments