11#include < bridge.h>
22
33#include < torch/torch.h>
4+ #include < Aten/ATen.h>
5+
46#include < torch/script.h>
57
68// #include <torch/script.h>
1618#include < chrono>
1719#include < thread>
1820
19- #include < opencv2/opencv.hpp>
2021
2122
2223namespace tnf = torch::nn::functional;
@@ -29,6 +30,57 @@ namespace tnf = torch::nn::functional;
2930 return torch_to_bridge (t_output); \
3031 }
3132
33+ // Globals
34+
35+
36+ torch::Device get_best_device ();
37+ torch::ScalarType get_best_dtype ();
38+
39+ auto best_device = get_best_device();
40+ auto best_dtype = get_best_dtype();
41+
42+ torch::NoGradGuard no_grad;
43+ torch::AutoGradMode enable_grad (false );
44+
45+ bool debug_cpu_only = false ;
46+
47+
48+
49+ torch::Device get_best_device () {
50+ if (debug_cpu_only)
51+ return torch::Device (torch::kCPU );
52+
53+ if (torch::hasMPS ()) {
54+ return torch::Device (torch::kMPS );
55+ } else if (torch::hasCUDA ()) {
56+ return torch::Device (torch::kCUDA );
57+ } else {
58+ return torch::Device (torch::kCPU );
59+ }
60+ }
61+
62+ extern " C" void debug_cpu_only_mode (bool_t mode) {
63+ debug_cpu_only = mode;
64+ if (debug_cpu_only) {
65+ best_device = torch::Device (torch::kCPU );
66+ } else {
67+ best_device = get_best_device ();
68+ }
69+ }
70+
71+ extern " C" bool_t accelerator_available () {
72+ return (best_device == torch::Device (torch::kCUDA ) || best_device == torch::Device (torch::kMPS ));
73+ }
74+
75+ torch::ScalarType get_best_dtype () {
76+ if (torch::hasMPS ()) {
77+ return torch::kFloat16 ;
78+ } else if (torch::hasCUDA ()) {
79+ return torch::kFloat16 ;
80+ } else {
81+ return torch::kFloat32 ;
82+ }
83+ }
3284
3385int bridge_tensor_elements (bridge_tensor_t &bt) {
3486 int size = 1 ;
@@ -42,14 +94,14 @@ size_t bridge_tensor_size(bridge_tensor_t &bt) {
4294 return sizeof (float32_t ) * bridge_tensor_elements (bt);
4395}
4496
45- void store_tensor (torch ::Tensor &input, float32_t * dest) {
97+ void store_tensor (at ::Tensor &input, float32_t * dest) {
4698 float32_t * data = input.data_ptr <float32_t >();
4799 size_t bytes_size = sizeof (float32_t ) * input.numel ();
48100 // std::memmove(dest,data,bytes_size);
49101 std::memcpy (dest,data,bytes_size);
50102}
51103
52- bridge_tensor_t torch_to_bridge (torch ::Tensor &tensor) {
104+ bridge_tensor_t torch_to_bridge (at ::Tensor &tensor) {
53105 bridge_tensor_t result;
54106 result.created_by_c = true ;
55107 result.dim = tensor.dim ();
@@ -62,12 +114,25 @@ bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
62114 return result;
63115}
64116
65- torch ::Tensor bridge_to_torch (bridge_tensor_t &bt) {
117+ at ::Tensor bridge_to_torch (bridge_tensor_t &bt) {
66118 std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
67119 auto shape = torch::IntArrayRef (sizes_vec);
68120 return torch::from_blob (bt.data , shape, torch::kFloat );
69121}
70122
123+ at::Tensor bridge_to_torch (bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32 ) {
124+ std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
125+ auto shape = torch::IntArrayRef (sizes_vec);
126+ auto t = torch::from_blob (bt.data , shape, torch::kFloat );
127+ if (device != torch::kCPU )
128+ copy = true ;
129+ if (copy)
130+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ true );
131+ else
132+ return t.to (device, dtype, /* non_blocking=*/ false , /* copy=*/ false );
133+
134+ }
135+
71136extern " C" float32_t * unsafe (const float32_t * arr) {
72137 return const_cast <float32_t *>(arr);
73138}
@@ -133,6 +198,92 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
133198 return torch_to_bridge (output);
134199}
135200
201+
202+
203+
204+ extern " C" bridge_pt_model_t load_model (const uint8_t * model_path) {
205+
206+ std::cout << " Begin loading model from path: " << model_path << std::endl;
207+ std::cout.flush ();
208+ std::string path (reinterpret_cast <const char *>(model_path));
209+ std::cout << " Loading model from path: " << path << std::endl;
210+ std::cout.flush ();
211+
212+ try {
213+ auto * module = new torch::jit::Module (torch::jit::load (path));
214+ module ->to (best_device,best_dtype,false );
215+ module ->eval ();
216+ std::cout << " Model loaded successfully!" << std::endl;
217+ std::cout.flush ();
218+ return { static_cast <void *>(module ) };
219+ } catch (const c10::Error& e) {
220+ std::cerr << " error loading the model\n " << e.msg ();
221+ std::cout << " error loading the model\n " << e.msg ();
222+ std::cout.flush ();
223+ std::cerr.flush ();
224+ }
225+ std::cout << " Model loading failed!" << std::endl;
226+ std::cout.flush ();
227+
228+ return { nullptr };
229+ }
230+
231+
232+
233+ bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
234+ auto tn_mps = bridge_to_torch (input,best_device,true ,best_dtype);
235+ // tn_mps = tn_mps.permute({2, 0, 1}).contiguous();
236+ // tn_mps.unsqueeze_(0);//.contiguous();
237+ auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
238+
239+ std::vector<torch::jit::IValue> ins;
240+ ins.push_back (tn);
241+
242+ auto * module = static_cast <torch::jit::Module*>(model.pt_module );
243+ auto o = module ->forward (ins).toTensor ();
244+ // auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();
245+ auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
246+
247+ if (is_vgg_based_model) {
248+ tn_out.div_ (255.0 );
249+ }
250+
251+ auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
252+
253+ return torch_to_bridge (tn_out_cpu);
254+
255+ }
256+
257+ extern " C" bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input) {
258+ return model_forward (model, input, false );
259+ }
260+
261+ extern " C" bridge_tensor_t model_forward_style_transfer (bridge_pt_model_t model, bridge_tensor_t input) {
262+ return model_forward (model, input, true );
263+ }
264+
265+ // std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
266+ // // if (best_device == torch::kMPS || best_device == torch::kCUDA)
267+ // if (accelerator_available())
268+ // return std::make_tuple(width, height);
269+ // uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
270+ // uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
271+ // return std::make_tuple(new_width, new_height);
272+ // }
273+
274+ // extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
275+ // return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
276+ // }
277+ // extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
278+ // return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
279+ // }
280+
281+
/// Prints a fixed greeting line to standard output and flushes the stream.
extern "C" void hello_world(void) {
    std::ostream& out = std::cout;
    out << " Hello from C++!" << '\n';
    out.flush();
}
286+
136287extern " C" bridge_tensor_t increment3 (bridge_tensor_t arr) {
137288 auto t = bridge_to_torch (arr);
138289 // Increment the tensor
@@ -406,40 +557,36 @@ extern "C" void split_loop_filler(int64_t n,int64_t* ret) {
406557
407558
408559
409- cv::VideoCapture open_camera (int cam_index) {
410- cv::VideoCapture cap (cam_index, cv::CAP_AVFOUNDATION);
411- if (!cap.isOpened ()) {
412- std::cerr << " Could not open camera index " << cam_index << std::endl;
413- return cv::VideoCapture ();
414- }
415- cap.set (cv::CAP_PROP_BUFFERSIZE, 1 ); // minimal internal buffering
416- cap.set (cv::CAP_PROP_FPS, 60 ); // request higher FPS if possible
417- return cap;
418- }
419-
560+ // cv::VideoCapture open_camera(int cam_index) {
561+ // cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
562+ // if (!cap.isOpened()) {
563+ // std::cerr << "Could not open camera index " << cam_index << std::endl;
564+ // return cv::VideoCapture();
565+ // }
566+ // cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
567+ // cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
568+ // return cap;
569+ // }
420570
421- extern " C" void show_webcam (void ) {
422- cv::VideoCapture cap;
423- cap = open_camera (0 );
424571
425- cv::Mat frame_bgr;
572+ // extern "C" void show_webcam(void) {
573+ // cv::VideoCapture cap;
574+ // cap = open_camera(0);
426575
427- while (true ) {
428- if (!cap.read (frame_bgr) || frame_bgr.empty ()) {
429- std::cerr << " [WARN] Empty frame, exiting" << std::endl;
430- break ;
431- }
576+ // cv::Mat frame_bgr;
432577
433- cv::imshow (" webcam" , frame_bgr);
578+ // while (true) {
579+ // if (!cap.read(frame_bgr) || frame_bgr.empty()) {
580+ // std::cerr << "[WARN] Empty frame, exiting" << std::endl;
581+ // break;
582+ // }
434583
435- if (cv::waitKey (1 ) == 27 ) { // ESC key
436- break ;
437- }
438- }
584+ // cv::imshow("webcam", frame_bgr);
439585
440- cap.release ();
441- cv::destroyAllWindows ();
442- }
586+ // if (cv::waitKey(1) == 27) { // ESC key
587+ // break;
588+ // }
589+ // }
443590
444591
445592// Simple activation function defs
@@ -687,4 +834,4 @@ extern "C" bridge_tensor_t dropout3d(
687834 .training (training));
688835
689836 return torch_to_bridge (t_output);
690- }
837+ }
0 commit comments