1616#include < chrono>
1717#include < thread>
1818
19- #include < opencv2/opencv.hpp>
2019
2120
2221#define def_bridge_simple (Name ) \
@@ -66,6 +65,19 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
6665 return torch::from_blob (bt.data , shape, torch::kFloat );
6766}
6867
68+ torch::Tensor bridge_to_torch (bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32 ) {
69+ std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
70+ auto shape = torch::IntArrayRef (sizes_vec);
71+ auto t = torch::from_blob (bt.data , shape, torch::kFloat );
72+ if (device != torch::kCPU )
73+ copy = true ;
74+ if (copy)
75+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ true );
76+ else
77+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ false );
78+
79+ }
80+
6981extern " C" float32_t * unsafe (const float32_t * arr) {
7082 return const_cast <float32_t *>(arr);
7183}
@@ -131,6 +143,102 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
131143 return torch_to_bridge (output);
132144}
133145
146+
147+ extern " C" bridge_pt_model_t load_model (const uint8_t * model_path) {
148+
149+ std::cout << " Begin loading model from path: " << model_path << std::endl;
150+ std::cout.flush ();
151+ std::string path (reinterpret_cast <const char *>(model_path));
152+ std::cout << " Loading model from path: " << path << std::endl;
153+ std::cout.flush ();
154+
155+ try {
156+
157+ auto * module = new torch::jit::Module (torch::jit::load (path));
158+ module ->to (torch::kMPS ,torch::kFloat16 ,false );
159+ module ->eval ();
160+ std::cout << " Model loaded successfully!" << std::endl;
161+ std::cout.flush ();
162+ return { static_cast <void *>(module ) };
163+
164+ // torch::jit::Module tmp = torch::jit::load(path);
165+ // std::cout << "Model loaded successfully!" << std::endl;
166+ // std::cout.flush();
167+ // auto* module = new torch::jit::Module(std::move(tmp));
168+ // std::cout << "Model moved successfully!" << std::endl;
169+ // std::cout.flush();
170+ // return { static_cast<void*>(module) };
171+ } catch (const c10::Error& e) {
172+ std::cerr << " error loading the model\n " << e.msg ();
173+ std::cout << " error loading the model\n " << e.msg ();
174+ std::cout.flush ();
175+ std::cerr.flush ();
176+ }
177+ std::cout << " Model loading failed!" << std::endl;
178+ std::cout.flush ();
179+
180+ return { nullptr };
181+
182+
183+
184+ // bridge_pt_model_t model_wrapper;
185+ // torch::jit::Module* pt_module = new torch::jit::Module(); // = (torch::jit::Module*) model_wrapper.pt_module;
186+ // try {
187+ // *pt_module = torch::jit::load(mp);
188+ // std::cout << "Model loaded successfully!" << std::endl;
189+ // std::cout.flush();
190+ // model_wrapper.pt_module = pt_module;
191+ // } catch (const c10::Error& e) {
192+ // std::cerr << "error loading the model\n" << e.msg();
193+ // std::cout << "error loading the model\n" << e.msg();
194+ // std::cout.flush();
195+ // std::cerr.flush();
196+ // }
197+
198+ // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
199+ // std::cout.flush();
200+
201+ // return model_wrapper;
202+ }
203+
204+
205+
206+ bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
207+
208+ auto tn_mps = bridge_to_torch (input,torch::kMPS ,true ,torch::kFloat16 );
209+ // auto tn_mps = tn.to(torch::kMPS,false,true);
210+ auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
211+
212+ std::vector<torch::jit::IValue> ins;
213+ ins.push_back (tn);
214+
215+ auto * module = static_cast <torch::jit::Module*>(model.pt_module );
216+ auto o = module ->forward (ins).toTensor ();
217+ auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
218+
219+ if (is_vgg_based_model) {
220+ tn_out = tn_out / 255.0 ;
221+ }
222+
223+ auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
224+ return torch_to_bridge (tn_out_cpu);
225+
226+ }
227+
228+ extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
229+ return model_forward (model, input, false );
230+ }
231+
232+ extern " C" bridge_tensor_t model_forward_style_transfer (bridge_pt_model_t model, bridge_tensor_t input) {
233+ return model_forward (model, input, true );
234+ }
235+
236+
237+ extern " C" void hello_world (void ) {
238+ std::cout << " Hello from C++!" << std::endl;
239+ std::cout.flush ();
240+ }
241+
134242extern " C" bridge_tensor_t increment3 (bridge_tensor_t arr) {
135243 auto t = bridge_to_torch (arr);
136244 // Increment the tensor
@@ -404,37 +512,37 @@ extern "C" void split_loop_filler(int64_t n,int64_t* ret) {
404512
405513
406514
407- cv::VideoCapture open_camera (int cam_index) {
408- cv::VideoCapture cap (cam_index, cv::CAP_AVFOUNDATION);
409- if (!cap.isOpened ()) {
410- std::cerr << " Could not open camera index " << cam_index << std::endl;
411- return cv::VideoCapture ();
412- }
413- cap.set (cv::CAP_PROP_BUFFERSIZE, 1 ); // minimal internal buffering
414- cap.set (cv::CAP_PROP_FPS, 60 ); // request higher FPS if possible
415- return cap;
416- }
515+ // cv::VideoCapture open_camera(int cam_index) {
516+ // cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
517+ // if (!cap.isOpened()) {
518+ // std::cerr << "Could not open camera index " << cam_index << std::endl;
519+ // return cv::VideoCapture();
520+ // }
521+ // cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
522+ // cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
523+ // return cap;
524+ // }
417525
418526
419- extern " C" void show_webcam (void ) {
420- cv::VideoCapture cap;
421- cap = open_camera (0 );
527+ // extern "C" void show_webcam(void) {
528+ // cv::VideoCapture cap;
529+ // cap = open_camera(0);
422530
423- cv::Mat frame_bgr;
531+ // cv::Mat frame_bgr;
424532
425- while (true ) {
426- if (!cap.read (frame_bgr) || frame_bgr.empty ()) {
427- std::cerr << " [WARN] Empty frame, exiting" << std::endl;
428- break ;
429- }
533+ // while (true) {
534+ // if (!cap.read(frame_bgr) || frame_bgr.empty()) {
535+ // std::cerr << "[WARN] Empty frame, exiting" << std::endl;
536+ // break;
537+ // }
430538
431- cv::imshow (" webcam" , frame_bgr);
539+ // cv::imshow("webcam", frame_bgr);
432540
433- if (cv::waitKey (1 ) == 27 ) { // ESC key
434- break ;
435- }
436- }
541+ // if (cv::waitKey(1) == 27) { // ESC key
542+ // break;
543+ // }
544+ // }
437545
438- cap.release ();
439- cv::destroyAllWindows ();
440- }
546+ // cap.release();
547+ // cv::destroyAllWindows();
548+ // }
0 commit comments