2929
3030
3131
32+ // Globals
33+
34+
35+ torch::Device get_best_device ();
36+ torch::ScalarType get_best_dtype ();
37+
38+ auto best_device = get_best_device();
39+ auto best_dtype = get_best_dtype();
40+
3241torch::NoGradGuard no_grad;
3342torch::AutoGradMode enable_grad (false );
3443
44+
45+
46+
47+
48+
49+ torch::Device get_best_device () {
50+ if (torch::hasMPS ()) {
51+ return torch::Device (torch::kMPS );
52+ } else if (torch::hasCUDA ()) {
53+ return torch::Device (torch::kCUDA );
54+ } else {
55+ return torch::Device (torch::kCPU );
56+ }
57+ }
58+
59+ extern " C" bool_t accelerator_available () {
60+ return false ;
61+ // return torch::hasMPS() || torch::hasCUDA();
62+ }
63+
64+ torch::ScalarType get_best_dtype () {
65+ if (torch::hasMPS ()) {
66+ return torch::kFloat16 ;
67+ } else if (torch::hasCUDA ()) {
68+ return torch::kFloat16 ;
69+ } else {
70+ return torch::kFloat32 ;
71+ }
72+ }
73+
3574int bridge_tensor_elements (bridge_tensor_t &bt) {
3675 int size = 1 ;
3776 for (int i = 0 ; i < bt.dim ; ++i) {
@@ -149,8 +188,6 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
149188}
150189
151190
152- #define DEVICE torch::kMPS
153- #define DTYPE torch::kFloat16
154191
155192
156193extern " C" bridge_pt_model_t load_model (const uint8_t * model_path) {
@@ -163,7 +200,7 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
163200
164201 try {
165202 auto * module = new torch::jit::Module (torch::jit::load (path));
166- module ->to (DEVICE,DTYPE ,false );
203+ module ->to (best_device,best_dtype ,false );
167204 module ->eval ();
168205 std::cout << " Model loaded successfully!" << std::endl;
169206 std::cout.flush ();
@@ -183,24 +220,24 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
183220
184221
185222bridge_tensor_t model_forward (bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
186- auto tn_mps = bridge_to_torch (input,DEVICE ,true ,DTYPE );
187- tn_mps = tn_mps.permute ({2 , 0 , 1 }).contiguous ();
188- tn_mps.unsqueeze_ (0 );// .contiguous();
189- // auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
223+ auto tn_mps = bridge_to_torch (input,best_device ,true ,best_dtype );
224+ // tn_mps = tn_mps.permute({2, 0, 1}).contiguous();
225+ // tn_mps.unsqueeze_(0);//.contiguous();
226+ auto tn = tn_mps.permute ({2 , 0 , 1 }).unsqueeze (0 ).contiguous ();
190227
191228 std::vector<torch::jit::IValue> ins;
192- ins.push_back (tn_mps );
229+ ins.push_back (tn );
193230
194231 auto * module = static_cast <torch::jit::Module*>(model.pt_module );
195232 auto o = module ->forward (ins).toTensor ();
196- auto tn_out = o.squeeze (0 ).permute ({1 , 2 , 0 }).contiguous ();
197- // auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
233+ // auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();
234+ auto tn_out = o.squeeze (0 ).contiguous ().permute ({1 , 2 , 0 }).contiguous ();
198235
199236 if (is_vgg_based_model) {
200237 tn_out.div_ (255.0 );
201238 }
202239
203- auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,false );
240+ auto tn_out_cpu = tn_out.to (torch::kCPU ,torch::kFloat32 ,false ,true );
204241
205242 return torch_to_bridge (tn_out_cpu);
206243
@@ -214,6 +251,22 @@ extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model,
214251 return model_forward (model, input, true );
215252}
216253
254+ std::tuple<uint64_t , uint64_t > get_cpu_frame_size (uint64_t width, uint64_t height, float32_t scale_factor) {
255+ // if (best_device == torch::kMPS || best_device == torch::kCUDA)
256+ if (accelerator_available ())
257+ return std::make_tuple (width, height);
258+ uint64_t new_width = static_cast <uint64_t >(width * scale_factor);
259+ uint64_t new_height = static_cast <uint64_t >(height * scale_factor);
260+ return std::make_tuple (new_width, new_height);
261+ }
262+
263+ extern " C" uint64_t get_cpu_frame_width (uint64_t width,float32_t scale_factor) {
264+ return std::get<0 >(get_cpu_frame_size (width, 0 , scale_factor));
265+ }
266+ extern " C" uint64_t get_cpu_frame_height (uint64_t height,float32_t scale_factor) {
267+ return std::get<1 >(get_cpu_frame_size (0 , height, scale_factor));
268+ }
269+
217270
218271extern " C" void hello_world (void ) {
219272 std::cout << " Hello from C++!" << std::endl;
0 commit comments