@@ -26,7 +26,8 @@ torch::jit::Module load_model(const std::string& model_path) {
2626 std::cout << " Model loaded successfully." << std::endl;
2727
2828 std::cout << " Moving model to device..." << std::endl;
29- module .to (default_device_st);
29+ auto device = cvtool::get_default_device ();
30+ module .to (device);
3031 std::cout << " Model moved to device." << std::endl;
3132
3233 } catch (const c10::Error& e) {
@@ -50,9 +51,9 @@ torch::Tensor run_model(torch::jit::Module& module, const torch::Tensor& input)
5051 std::vector<torch::jit::IValue> inputs;
5152 inputs.push_back (input);
5253
53- std::cout << " Input tensor: " << input.sizes () << std::endl;
54+ // std::cout << "Input tensor: " << input.sizes() << std::endl;
5455 auto output = module .forward (inputs).toTensor ();
55- std::cout << " Model output: " << output.sizes () << std::endl;
56+ // std::cout << "Model output: " << output.sizes() << std::endl;
5657 return output;
5758}
5859
@@ -71,14 +72,17 @@ int main() {
7172 std::cout << " MPS is available and set as the default device." << std::endl;
7273 } else {
7374 default_device_st = torch::Device (torch::kCPU );
74- std::cout << " MPS is not available. Using CPU instead." << std::endl;
75+ std::cout << " MPS is not available. Using CPU instead. " << std::endl;
7576 }
77+ cvtool::set_default_device (default_device_st);
78+
79+ auto device = cvtool::get_default_device ();
7680
7781 // default_device = default_device_st;
7882
7983 std::string model_path = " style-transfer/models/mosaic.pt" ;
8084 torch::jit::Module module = load_model (model_path);
81- torch::Tensor input = torch::randn ({1 , 3 , 1428 , 1904 }, default_device_st );
85+ torch::Tensor input = torch::randn ({1 , 3 , 1428 , 1904 }, device );
8286 torch::Tensor output = run_model (module , input);
8387
8488 // Print the output tensor
@@ -90,7 +94,7 @@ int main() {
9094
9195int run_webcam_model (torch::jit::Module& module , int cam_index, int max_fps, bool is_video_loop, std::string vid_path = " " ) {
9296
93- torch::Device device = default_device_st ;
97+ torch::Device device = cvtool::get_default_device () ;
9498
9599 module .eval ();
96100 module .to (device);
@@ -104,31 +108,13 @@ int run_webcam_model(torch::jit::Module& module, int cam_index, int max_fps, boo
104108 cap = open_camera (cam_index);
105109 }
106110
107-
108- // 4. Pre‑allocate tensor to avoid dynamic allocations
109- // int width = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
110- // int height = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));
111111 auto camera_resolution = get_camera_resolution (cap);
112112 int height = std::get<0 >(camera_resolution);
113113 int width = std::get<1 >(camera_resolution);
114114
115115
116-
117- // // NHWC float32 frame buffer (1, H, W, 3)
118- // auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
119- // torch::Tensor frame_tensor_cpu = torch::empty({1, height, width, 3}, options_cpu);
120-
121- // // MPS device tensor gets created lazily (to avoid copies when MPS unavailable)
122- // torch::Tensor frame_tensor_device;
123- // if (device.is_mps()) {
124- // frame_tensor_device = frame_tensor_cpu.to(device, /*non_blocking=*/true);
125- // }
126-
127- // auto frame_tensor_device = create_frame_buffer_tensor(height, width, device);
128-
129116 cv::Mat frame_bgr;
130117 cv::Mat output_bgr;
131- // cv::Mat frame_rgb(height, width, CV_32FC3, frame_tensor_device->data_ptr());
132118
133119 const auto to_mps = [&](torch::Tensor& t){ return device.is_mps () ? t.to (device, /* non_blocking=*/ true ) : t; };
134120
@@ -162,26 +148,22 @@ int run_webcam_model(torch::jit::Module& module, int cam_index, int max_fps, boo
162148 ++frame_count;
163149 const std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now ();
164150 auto delta = now - last_update;
165- // std::chrono::milliseconds delta_millis = std::chrono::duration_cast<std::chrono::microseconds>(delta);
166151 double delta_time = std::chrono::duration_cast<std::chrono::duration<double >>(delta).count ();
167152 auto fps = 1.0 / delta_time;
168153 std::cout << " \r [INFO] FPS: " << fps << " fps" << std::flush;
169-
170- // Display (optional)
171-
172154 double sleep_time = (1.0 / ((double )max_fps)) - delta_time;
173-
174155 std::this_thread::sleep_for (std::chrono::duration<double >(sleep_time));
175156
176157
177- auto input_tensor = to_tensor (frame_bgr).clone ();
178- auto mps_tensor = input_tensor.to (torch::kMPS ,true ).clone ();
158+
159+ auto input_tensor = to_tensor (frame_bgr);
160+ auto mps_tensor = input_tensor.to (device,true );
179161
180162 auto prepped_input = preprocess_input (mps_tensor);
181163
182164 // Forward pass
183- auto output = run_model (module , prepped_input). clone () ;
184- auto processed_output = output.to (torch::kCPU ,true ). clone () ;
165+ auto output = run_model (module , prepped_input);
166+ auto processed_output = output.to (torch::kCPU ,true );
185167
186168 output_bgr = to_mat (processed_output);
187169
0 commit comments