11#include < bridge.h>
22
33#include < torch/torch.h>
4+ #include < Aten/ATen.h>
5+
46#include < torch/script.h>
57
68// #include <torch/script.h>
#include <chrono>
#include <exception>
#include <memory>
#include <thread>
1820
19- #include < opencv2/opencv.hpp>
2021
2122
2223#define def_bridge_simple (Name ) \
2829
2930
3031
32+ // Globals
33+
34+
35+ torch::Device get_best_device ();
36+ torch::ScalarType get_best_dtype ();
37+
38+ auto best_device = get_best_device();
39+ auto best_dtype = get_best_dtype();
40+
41+ torch::NoGradGuard no_grad;
42+ torch::AutoGradMode enable_grad (false );
43+
44+ bool debug_cpu_only = false ;
45+
46+
47+
48+ torch::Device get_best_device () {
49+ if (debug_cpu_only)
50+ return torch::Device (torch::kCPU );
51+
52+ if (torch::hasMPS ()) {
53+ return torch::Device (torch::kMPS );
54+ } else if (torch::hasCUDA ()) {
55+ return torch::Device (torch::kCUDA );
56+ } else {
57+ return torch::Device (torch::kCPU );
58+ }
59+ }
60+
61+ extern " C" void debug_cpu_only_mode (bool_t mode) {
62+ debug_cpu_only = mode;
63+ if (debug_cpu_only) {
64+ best_device = torch::Device (torch::kCPU );
65+ } else {
66+ best_device = get_best_device ();
67+ }
68+ }
69+
70+ extern " C" bool_t accelerator_available () {
71+ return (best_device == torch::Device (torch::kCUDA ) || best_device == torch::Device (torch::kMPS ));
72+ }
73+
74+ torch::ScalarType get_best_dtype () {
75+ if (torch::hasMPS ()) {
76+ return torch::kFloat16 ;
77+ } else if (torch::hasCUDA ()) {
78+ return torch::kFloat16 ;
79+ } else {
80+ return torch::kFloat32 ;
81+ }
82+ }
83+
3184int bridge_tensor_elements (bridge_tensor_t &bt) {
3285 int size = 1 ;
3386 for (int i = 0 ; i < bt.dim ; ++i) {
@@ -40,14 +93,14 @@ size_t bridge_tensor_size(bridge_tensor_t &bt) {
4093 return sizeof (float32_t ) * bridge_tensor_elements (bt);
4194}
4295
43- void store_tensor (torch ::Tensor &input, float32_t * dest) {
96+ void store_tensor (at ::Tensor &input, float32_t * dest) {
4497 float32_t * data = input.data_ptr <float32_t >();
4598 size_t bytes_size = sizeof (float32_t ) * input.numel ();
4699 // std::memmove(dest,data,bytes_size);
47100 std::memcpy (dest,data,bytes_size);
48101}
49102
50- bridge_tensor_t torch_to_bridge (torch ::Tensor &tensor) {
103+ bridge_tensor_t torch_to_bridge (at ::Tensor &tensor) {
51104 bridge_tensor_t result;
52105 result.created_by_c = true ;
53106 result.dim = tensor.dim ();
@@ -60,12 +113,25 @@ bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
60113 return result;
61114}
62115
63- torch ::Tensor bridge_to_torch (bridge_tensor_t &bt) {
116+ at ::Tensor bridge_to_torch (bridge_tensor_t &bt) {
64117 std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
65118 auto shape = torch::IntArrayRef (sizes_vec);
66119 return torch::from_blob (bt.data , shape, torch::kFloat );
67120}
68121
122+ at::Tensor bridge_to_torch (bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32 ) {
123+ std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
124+ auto shape = torch::IntArrayRef (sizes_vec);
125+ auto t = torch::from_blob (bt.data , shape, torch::kFloat );
126+ if (device != torch::kCPU )
127+ copy = true ;
128+ if (copy)
129+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ true );
130+ else
131+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ false );
132+
133+ }
134+
69135extern " C" float32_t * unsafe (const float32_t * arr) {
70136 return const_cast <float32_t *>(arr);
71137}
@@ -131,6 +197,92 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
131197 return torch_to_bridge (output);
132198}
133199
200+
201+
202+
203+ extern " C" bridge_pt_model_t load_model (const uint8_t * model_path) {
204+
205+ std::cout << " Begin loading model from path: " << model_path << std::endl;
206+ std::cout.flush ();
207+ std::string path (reinterpret_cast <const char *>(model_path));
208+ std::cout << " Loading model from path: " << path << std::endl;
209+ std::cout.flush ();
210+
211+ try {
212+ auto * module = new torch::jit::Module (torch::jit::load (path));
213+ module ->to (best_device,best_dtype,false );
214+ module ->eval ();
215+ std::cout << " Model loaded successfully!" << std::endl;
216+ std::cout.flush ();
217+ return { static_cast <void *>(module ) };
218+ } catch (const c10::Error& e) {
219+ std::cerr << " error loading the model\n " << e.msg ();
220+ std::cout << " error loading the model\n " << e.msg ();
221+ std::cout.flush ();
222+ std::cerr.flush ();
223+ }
224+ std::cout << " Model loading failed!" << std::endl;
225+ std::cout.flush ();
226+
227+ return { nullptr };
228+ }
229+
230+
231+
232+ bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
233+ auto tn_mps = bridge_to_torch (input,best_device,true ,best_dtype);
234+ // tn_mps = tn_mps.permute({2, 0, 1}).contiguous();
235+ // tn_mps.unsqueeze_(0);//.contiguous();
236+ auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
237+
238+ std::vector<torch::jit::IValue> ins;
239+ ins.push_back (tn);
240+
241+ auto * module = static_cast <torch::jit::Module*>(model.pt_module );
242+ auto o = module ->forward (ins).toTensor ();
243+ // auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();
244+ auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
245+
246+ if (is_vgg_based_model) {
247+ tn_out.div_ (255.0 );
248+ }
249+
250+ auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
251+
252+ return torch_to_bridge (tn_out_cpu);
253+
254+ }
255+
256+ extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
257+ return model_forward (model, input, false );
258+ }
259+
260+ extern " C" bridge_tensor_t model_forward_style_transfer (bridge_pt_model_t model, bridge_tensor_t input) {
261+ return model_forward (model, input, true );
262+ }
263+
264+ // std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
265+ // // if (best_device == torch::kMPS || best_device == torch::kCUDA)
266+ // if (accelerator_available())
267+ // return std::make_tuple(width, height);
268+ // uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
269+ // uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
270+ // return std::make_tuple(new_width, new_height);
271+ // }
272+
273+ // extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
274+ // return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
275+ // }
276+ // extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
277+ // return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
278+ // }
279+
280+
// Smoke-test entry point: prints a fixed greeting to stdout and flushes it so
// the caller sees the line immediately.
extern "C" void hello_world(void) {
    const char* const greeting = "Hello from C++!";
    std::cout << greeting << std::endl;
    std::cout.flush();
}
285+
134286extern " C" bridge_tensor_t increment3 (bridge_tensor_t arr) {
135287 auto t = bridge_to_torch (arr);
136288 // Increment the tensor
@@ -404,37 +556,37 @@ extern "C" void split_loop_filler(int64_t n,int64_t* ret) {
404556
405557
406558
407- cv::VideoCapture open_camera (int cam_index) {
408- cv::VideoCapture cap (cam_index, cv::CAP_AVFOUNDATION);
409- if (!cap.isOpened ()) {
410- std::cerr << " Could not open camera index " << cam_index << std::endl;
411- return cv::VideoCapture ();
412- }
413- cap.set (cv::CAP_PROP_BUFFERSIZE, 1 ); // minimal internal buffering
414- cap.set (cv::CAP_PROP_FPS, 60 ); // request higher FPS if possible
415- return cap;
416- }
559+ // cv::VideoCapture open_camera(int cam_index) {
560+ // cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
561+ // if (!cap.isOpened()) {
562+ // std::cerr << "Could not open camera index " << cam_index << std::endl;
563+ // return cv::VideoCapture();
564+ // }
565+ // cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
566+ // cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
567+ // return cap;
568+ // }
417569
418570
419- extern " C" void show_webcam (void ) {
420- cv::VideoCapture cap;
421- cap = open_camera (0 );
571+ // extern "C" void show_webcam(void) {
572+ // cv::VideoCapture cap;
573+ // cap = open_camera(0);
422574
423- cv::Mat frame_bgr;
575+ // cv::Mat frame_bgr;
424576
425- while (true ) {
426- if (!cap.read (frame_bgr) || frame_bgr.empty ()) {
427- std::cerr << " [WARN] Empty frame, exiting" << std::endl;
428- break ;
429- }
577+ // while (true) {
578+ // if (!cap.read(frame_bgr) || frame_bgr.empty()) {
579+ // std::cerr << "[WARN] Empty frame, exiting" << std::endl;
580+ // break;
581+ // }
430582
431- cv::imshow (" webcam" , frame_bgr);
583+ // cv::imshow("webcam", frame_bgr);
432584
433- if (cv::waitKey (1 ) == 27 ) { // ESC key
434- break ;
435- }
436- }
585+ // if (cv::waitKey(1) == 27) { // ESC key
586+ // break;
587+ // }
588+ // }
437589
438- cap.release ();
439- cv::destroyAllWindows ();
440- }
590+ // cap.release();
591+ // cv::destroyAllWindows();
592+ // }
0 commit comments