11#include < bridge.h>
22
33#include < torch/torch.h>
4+ #include < Aten/ATen.h>
5+
46#include < torch/script.h>
57
68// #include <torch/script.h>
2729
2830
2931
32+ // Globals
33+
34+
35+ torch::Device get_best_device ();
36+ torch::ScalarType get_best_dtype ();
37+
38+ auto best_device = get_best_device();
39+ auto best_dtype = get_best_dtype();
40+
41+ torch::NoGradGuard no_grad;
42+ torch::AutoGradMode enable_grad (false );
43+
44+ bool debug_cpu_only = false ;
45+
46+
47+
48+ torch::Device get_best_device () {
49+ if (debug_cpu_only)
50+ return torch::Device (torch::kCPU );
51+
52+ if (torch::hasMPS ()) {
53+ return torch::Device (torch::kMPS );
54+ } else if (torch::hasCUDA ()) {
55+ return torch::Device (torch::kCUDA );
56+ } else {
57+ return torch::Device (torch::kCPU );
58+ }
59+ }
60+
61+ extern " C" void debug_cpu_only_mode (bool_t mode) {
62+ debug_cpu_only = mode;
63+ if (debug_cpu_only) {
64+ best_device = torch::Device (torch::kCPU );
65+ } else {
66+ best_device = get_best_device ();
67+ }
68+ }
69+
70+ extern " C" bool_t accelerator_available () {
71+ return (best_device == torch::Device (torch::kCUDA ) || best_device == torch::Device (torch::kMPS ));
72+ }
73+
74+ torch::ScalarType get_best_dtype () {
75+ if (torch::hasMPS ()) {
76+ return torch::kFloat16 ;
77+ } else if (torch::hasCUDA ()) {
78+ return torch::kFloat16 ;
79+ } else {
80+ return torch::kFloat32 ;
81+ }
82+ }
83+
3084int bridge_tensor_elements (bridge_tensor_t &bt) {
3185 int size = 1 ;
3286 for (int i = 0 ; i < bt.dim ; ++i) {
@@ -39,14 +93,14 @@ size_t bridge_tensor_size(bridge_tensor_t &bt) {
3993 return sizeof (float32_t ) * bridge_tensor_elements (bt);
4094}
4195
42- void store_tensor (torch ::Tensor &input, float32_t * dest) {
96+ void store_tensor (at ::Tensor &input, float32_t * dest) {
4397 float32_t * data = input.data_ptr <float32_t >();
4498 size_t bytes_size = sizeof (float32_t ) * input.numel ();
4599 // std::memmove(dest,data,bytes_size);
46100 std::memcpy (dest,data,bytes_size);
47101}
48102
49- bridge_tensor_t torch_to_bridge (torch ::Tensor &tensor) {
103+ bridge_tensor_t torch_to_bridge (at ::Tensor &tensor) {
50104 bridge_tensor_t result;
51105 result.created_by_c = true ;
52106 result.dim = tensor.dim ();
@@ -59,13 +113,13 @@ bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
59113 return result;
60114}
61115
62- torch ::Tensor bridge_to_torch (bridge_tensor_t &bt) {
116+ at ::Tensor bridge_to_torch (bridge_tensor_t &bt) {
63117 std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
64118 auto shape = torch::IntArrayRef (sizes_vec);
65119 return torch::from_blob (bt.data , shape, torch::kFloat );
66120}
67121
68- torch ::Tensor bridge_to_torch (bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32 ) {
122+ at ::Tensor bridge_to_torch (bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32 ) {
69123 std::vector<int64_t > sizes_vec (bt.sizes , bt.sizes + bt.dim );
70124 auto shape = torch::IntArrayRef (sizes_vec);
71125 auto t = torch::from_blob (bt.data , shape, torch::kFloat );
@@ -144,6 +198,8 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
144198}
145199
146200
201+
202+
147203extern " C" bridge_pt_model_t load_model (const uint8_t * model_path) {
148204
149205 std::cout << " Begin loading model from path: " << model_path << std::endl;
@@ -153,21 +209,12 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
153209 std::cout.flush ();
154210
155211 try {
156-
157212 auto * module = new torch::jit::Module (torch::jit::load (path));
158- module ->to (torch:: kMPS ,torch:: kFloat16 ,false );
213+ module ->to (best_device,best_dtype ,false );
159214 module ->eval ();
160215 std::cout << " Model loaded successfully!" << std::endl;
161216 std::cout.flush ();
162217 return { static_cast <void *>(module ) };
163-
164- // torch::jit::Module tmp = torch::jit::load(path);
165- // std::cout << "Model loaded successfully!" << std::endl;
166- // std::cout.flush();
167- // auto* module = new torch::jit::Module(std::move(tmp));
168- // std::cout << "Model moved successfully!" << std::endl;
169- // std::cout.flush();
170- // return { static_cast<void*>(module) };
171218 } catch (const c10::Error& e) {
172219 std::cerr << " error loading the model\n " << e.msg ();
173220 std::cout << " error loading the model\n " << e.msg ();
@@ -178,49 +225,30 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
178225 std::cout.flush ();
179226
180227 return { nullptr };
181-
182-
183-
184- // bridge_pt_model_t model_wrapper;
185- // torch::jit::Module* pt_module = new torch::jit::Module(); // = (torch::jit::Module*) model_wrapper.pt_module;
186- // try {
187- // *pt_module = torch::jit::load(mp);
188- // std::cout << "Model loaded successfully!" << std::endl;
189- // std::cout.flush();
190- // model_wrapper.pt_module = pt_module;
191- // } catch (const c10::Error& e) {
192- // std::cerr << "error loading the model\n" << e.msg();
193- // std::cout << "error loading the model\n" << e.msg();
194- // std::cout.flush();
195- // std::cerr.flush();
196- // }
197-
198- // std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
199- // std::cout.flush();
200-
201- // return model_wrapper;
202228}
203229
204230
205231
206232bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
207-
208- auto tn_mps = bridge_to_torch (input,torch:: kMPS , true ,torch:: kFloat16 );
209- // auto tn_mps = tn.to(torch::kMPS,false,true );
233+ auto tn_mps = bridge_to_torch (input,best_device, true ,best_dtype);
234+ // tn_mps = tn_mps.permute({2, 0, 1}).contiguous( );
235+ // tn_mps.unsqueeze_(0);//.contiguous( );
210236 auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
211237
212238 std::vector<torch::jit::IValue> ins;
213239 ins.push_back (tn);
214240
215241 auto * module = static_cast <torch::jit::Module*>(model.pt_module );
216242 auto o = module ->forward (ins).toTensor ();
243+ // auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();
217244 auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
218245
219246 if (is_vgg_based_model) {
220- tn_out = tn_out / 255.0 ;
247+ tn_out. div_ ( 255.0 ) ;
221248 }
222249
223250 auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
251+
224252 return torch_to_bridge (tn_out_cpu);
225253
226254}
@@ -233,6 +261,22 @@ extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model,
233261 return model_forward (model, input, true );
234262}
235263
264+ // std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
265+ // // if (best_device == torch::kMPS || best_device == torch::kCUDA)
266+ // if (accelerator_available())
267+ // return std::make_tuple(width, height);
268+ // uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
269+ // uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
270+ // return std::make_tuple(new_width, new_height);
271+ // }
272+
273+ // extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
274+ // return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
275+ // }
276+ // extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
277+ // return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
278+ // }
279+
236280
237281extern " C" void hello_world (void ) {
238282 std::cout << " Hello from C++!" << std::endl;
0 commit comments