Skip to content

Commit 664ee7d

Browse files
committed
Making progress on the style transfer demo (C++).
1 parent e24a7a2 commit 664ee7d

6 files changed

Lines changed: 41 additions & 34 deletions

File tree

demos/video/include/cvtool.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,28 @@
88
#include <utility>
99

1010

11+
namespace cvtool {
12+
static torch::Device default_device(torch::kCPU);
13+
static bool default_device_set = false;
14+
static torch::Device set_default_device(torch::Device device) {
15+
default_device = device;
16+
default_device_set = true;
17+
return default_device;
18+
}
19+
torch::Device get_default_device() {
20+
if (!default_device_set) {
21+
if (torch::mps::is_available()) {
22+
std::cout << "[INFO] Running on MPS" << std::endl;
23+
default_device = torch::Device(torch::kMPS);
24+
} else {
25+
std::cout << "[INFO] MPS not available, falling back to CPU" << std::endl;
26+
default_device = torch::Device(torch::kCPU);
27+
}
28+
}
29+
return default_device;
30+
}
31+
}
32+
1133
static torch::Device default_device(torch::kCPU);
1234
torch::Device get_default_device();
1335

@@ -114,8 +136,8 @@ cv::Mat to_mat(at::Tensor &tensor) {
114136

115137
torch::Device get_default_device() {
116138
if (torch::mps::is_available()) {
117-
// default_device = torch::Device(torch::kMPS);
118139
std::cout << "[INFO] Running on MPS" << std::endl;
140+
default_device = torch::Device(torch::kMPS);
119141
} else {
120142
std::cout << "[INFO] MPS not available, falling back to CPU" << std::endl;
121143
}
0 Bytes
Binary file not shown.
3.27 MB
Binary file not shown.

demos/video/style-transfer/style_transfer.cpp

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ torch::jit::Module load_model(const std::string& model_path) {
2626
std::cout << "Model loaded successfully." << std::endl;
2727

2828
std::cout << "Moving model to device..." << std::endl;
29-
module.to(default_device_st);
29+
auto device = cvtool::get_default_device();
30+
module.to(device);
3031
std::cout << "Model moved to device." << std::endl;
3132

3233
} catch (const c10::Error& e) {
@@ -50,9 +51,9 @@ torch::Tensor run_model(torch::jit::Module& module, const torch::Tensor& input)
5051
std::vector<torch::jit::IValue> inputs;
5152
inputs.push_back(input);
5253

53-
std::cout << "Input tensor: " << input.sizes() << std::endl;
54+
// std::cout << "Input tensor: " << input.sizes() << std::endl;
5455
auto output = module.forward(inputs).toTensor();
55-
std::cout << "Model output: " << output.sizes() << std::endl;
56+
// std::cout << "Model output: " << output.sizes() << std::endl;
5657
return output;
5758
}
5859

@@ -71,14 +72,17 @@ int main() {
7172
std::cout << "MPS is available and set as the default device." << std::endl;
7273
} else {
7374
default_device_st = torch::Device(torch::kCPU);
74-
std::cout << "MPS is not available. Using CPU instead." << std::endl;
75+
std::cout << "MPS is not available. Using CPU instead. " << std::endl;
7576
}
77+
cvtool::set_default_device(default_device_st);
78+
79+
auto device = cvtool::get_default_device();
7680

7781
// default_device = default_device_st;
7882

7983
std::string model_path = "style-transfer/models/mosaic.pt";
8084
torch::jit::Module module = load_model(model_path);
81-
torch::Tensor input = torch::randn({1, 3, 1428, 1904}, default_device_st);
85+
torch::Tensor input = torch::randn({1, 3, 1428, 1904}, device);
8286
torch::Tensor output = run_model(module, input);
8387

8488
// Print the output tensor
@@ -90,7 +94,7 @@ int main() {
9094

9195
int run_webcam_model(torch::jit::Module& module, int cam_index, int max_fps, bool is_video_loop, std::string vid_path = "") {
9296

93-
torch::Device device = default_device_st;
97+
torch::Device device = cvtool::get_default_device();
9498

9599
module.eval();
96100
module.to(device);
@@ -104,31 +108,13 @@ int run_webcam_model(torch::jit::Module& module, int cam_index, int max_fps, boo
104108
cap = open_camera(cam_index);
105109
}
106110

107-
108-
// 4. Pre‑allocate tensor to avoid dynamic allocations
109-
// int width = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
110-
// int height = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));
111111
auto camera_resolution = get_camera_resolution(cap);
112112
int height = std::get<0>(camera_resolution);
113113
int width = std::get<1>(camera_resolution);
114114

115115

116-
117-
// // NHWC float32 frame buffer (1, H, W, 3)
118-
// auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
119-
// torch::Tensor frame_tensor_cpu = torch::empty({1, height, width, 3}, options_cpu);
120-
121-
// // MPS device tensor gets created lazily (to avoid copies when MPS unavailable)
122-
// torch::Tensor frame_tensor_device;
123-
// if (device.is_mps()) {
124-
// frame_tensor_device = frame_tensor_cpu.to(device, /*non_blocking=*/true);
125-
// }
126-
127-
// auto frame_tensor_device = create_frame_buffer_tensor(height, width, device);
128-
129116
cv::Mat frame_bgr;
130117
cv::Mat output_bgr;
131-
// cv::Mat frame_rgb(height, width, CV_32FC3, frame_tensor_device->data_ptr());
132118

133119
const auto to_mps = [&](torch::Tensor& t){ return device.is_mps() ? t.to(device, /*non_blocking=*/true) : t; };
134120

@@ -162,26 +148,22 @@ int run_webcam_model(torch::jit::Module& module, int cam_index, int max_fps, boo
162148
++frame_count;
163149
const std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
164150
auto delta = now - last_update;
165-
// std::chrono::milliseconds delta_millis = std::chrono::duration_cast<std::chrono::microseconds>(delta);
166151
double delta_time = std::chrono::duration_cast<std::chrono::duration<double>>(delta).count();
167152
auto fps = 1.0 / delta_time;
168153
std::cout << "\r[INFO] FPS: " << fps << " fps" << std::flush;
169-
170-
// Display (optional)
171-
172154
double sleep_time = (1.0 / ((double)max_fps)) - delta_time;
173-
174155
std::this_thread::sleep_for(std::chrono::duration<double>(sleep_time));
175156

176157

177-
auto input_tensor = to_tensor(frame_bgr).clone();
178-
auto mps_tensor = input_tensor.to(torch::kMPS,true).clone();
158+
159+
auto input_tensor = to_tensor(frame_bgr);
160+
auto mps_tensor = input_tensor.to(device,true);
179161

180162
auto prepped_input = preprocess_input(mps_tensor);
181163

182164
// Forward pass
183-
auto output = run_model(module, prepped_input).clone();
184-
auto processed_output = output.to(torch::kCPU,true).clone();
165+
auto output = run_model(module, prepped_input);
166+
auto processed_output = output.to(torch::kCPU,true);
185167

186168
output_bgr = to_mat(processed_output);
187169

-6.47 MB
Binary file not shown.

examples/pytorch-examples/fast_neural_style/neural_style/neural_style.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def stylize(args):
188188
sm = torch.jit.script(style_model)
189189
sm.save(f"models/{model_name}.pt")
190190

191+
sm = torch.jit.script(style_model.to(torch.float16))
192+
sm.save(f"models/{model_name}_float16.pt")
193+
191194
utils.save_image(args.output_image, output[0])
192195

193196

0 commit comments

Comments
 (0)