Skip to content

Commit 51b79ff

Browse files
committed
CPP style transfer demo working with artifacts.
1 parent a0d42d5 commit 51b79ff

2 files changed

Lines changed: 19 additions & 7 deletions

File tree

demos/video/include/cvtool.hpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,28 @@ cv::Mat to_mat(at::Tensor &tensor) {
179179
int height = tensor.size(2);
180180
int width = tensor.size(3);
181181
auto t = tensor
182-
.mul(255)
183-
.squeeze()
184182
.detach()
183+
.squeeze()
184+
.mul(255.0)
185+
.clamp(0, 255)
185186
.permute({1, 2, 0})
186187
.contiguous()
187188
.to(torch::kUInt8)
188-
// .clamp(0, 255)
189189
.clone()
190-
// .to(cvtool::get_default_device(), /*non_blocking=*/true, /*copy=*/true)
191190
.to(torch::kCPU);
191+
192+
193+
// auto t = tensor
194+
// .mul(255)
195+
// .squeeze()
196+
// .detach()
197+
// .permute({1, 2, 0})
198+
// .contiguous()
199+
// .to(torch::kUInt8)
200+
// // .clamp(0, 255)
201+
// .clone()
202+
// // .to(cvtool::get_default_device(), /*non_blocking=*/true, /*copy=*/true)
203+
// .to(torch::kCPU);
192204
cv::Mat mat = cv::Mat(height, width, CV_8UC3, t.data_ptr());
193205
return mat.clone();
194206

demos/video/style-transfer/style_transfer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ int main() {
123123
// default_device = default_device_st;
124124

125125
// std::string model_path = "style-transfer/models/mosaic_float32.pt";
126-
std::string model_path = "style-transfer/models/sobel_edge_float32.pt" ;
126+
std::string model_path = "style-transfer/models/mosaic_float16.pt" ;
127127
torch::jit::Module module = load_model(model_path);
128128
/*
129129
// module.to(torch::kFloat16);
@@ -225,8 +225,8 @@ int run_webcam_model(torch::jit::Module& module, int cam_index, int max_fps, boo
225225

226226

227227
// // // works
228-
auto processed_input = run_model(module,prepped_input);
229-
auto out_processed_input = processed_input.to(torch::kCPU,true);
228+
auto processed_input = run_model(module,prepped_input.to(torch::kFloat16)) / 255.0;
229+
auto out_processed_input = processed_input; // processed_input.to(torch::kCPU,true);
230230
output_bgr = to_mat(out_processed_input, cv::COLOR_RGB2BGR);
231231

232232
// // works

0 commit comments

Comments
 (0)