@@ -41,12 +41,14 @@ auto best_dtype = get_best_dtype();
4141torch::NoGradGuard no_grad;
4242torch::AutoGradMode enable_grad (false );
4343
44-
45-
44+ bool debug_cpu_only = false ;
4645
4746
4847
4948torch::Device get_best_device () {
49+ if (debug_cpu_only)
50+ return torch::Device (torch::kCPU );
51+
5052 if (torch::hasMPS ()) {
5153 return torch::Device (torch::kMPS );
5254 } else if (torch::hasCUDA ()) {
@@ -56,9 +58,17 @@ torch::Device get_best_device() {
5658 }
5759}
5860
61+ extern " C" void debug_cpu_only_mode (bool_t mode) {
62+ debug_cpu_only = mode;
63+ if (debug_cpu_only) {
64+ best_device = torch::Device (torch::kCPU );
65+ } else {
66+ best_device = get_best_device ();
67+ }
68+ }
69+
5970extern " C" bool_t accelerator_available () {
60- return false ;
61- // return torch::hasMPS() || torch::hasCUDA();
71+ return (best_device == torch::Device (torch::kCUDA ) || best_device == torch::Device (torch::kMPS ));
6272}
6373
6474torch::ScalarType get_best_dtype () {
@@ -251,21 +261,21 @@ extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model,
251261 return model_forward (model, input, true );
252262}
253263
254- std::tuple<uint64_t , uint64_t > get_cpu_frame_size (uint64_t width, uint64_t height, float32_t scale_factor) {
255- // if (best_device == torch::kMPS || best_device == torch::kCUDA)
256- if (accelerator_available ())
257- return std::make_tuple (width, height);
258- uint64_t new_width = static_cast <uint64_t >(width * scale_factor);
259- uint64_t new_height = static_cast <uint64_t >(height * scale_factor);
260- return std::make_tuple (new_width, new_height);
261- }
264+ // std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
265+ // // if (best_device == torch::kMPS || best_device == torch::kCUDA)
266+ // if (accelerator_available())
267+ // return std::make_tuple(width, height);
268+ // uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
269+ // uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
270+ // return std::make_tuple(new_width, new_height);
271+ // }
262272
263- extern " C" uint64_t get_cpu_frame_width (uint64_t width,float32_t scale_factor) {
264- return std::get<0 >(get_cpu_frame_size (width, 0 , scale_factor));
265- }
266- extern " C" uint64_t get_cpu_frame_height (uint64_t height,float32_t scale_factor) {
267- return std::get<1 >(get_cpu_frame_size (0 , height, scale_factor));
268- }
273+ // extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
274+ // return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
275+ // }
276+ // extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
277+ // return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
278+ // }
269279
270280
271281extern " C" void hello_world (void ) {
0 commit comments