Skip to content

Commit 6668e67

Browse files
committed
Add Sobel image example. Going to try using CLion.
1 parent 347944a commit 6668e67

4 files changed

Lines changed: 110 additions & 45 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ libtorch/
4545
build-release/
4646
libtorch_static/
4747
examples/vgg/**/*.pt
48+
segmentation_models.pytorch/

demos/video/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ find_library(FOUNDATION Foundation REQUIRED)
1212
add_executable(VidStreamer
1313
${CMAKE_CURRENT_SOURCE_DIR}/webcam_infer.cpp
1414
${CMAKE_CURRENT_SOURCE_DIR}/cvutil.hpp
15+
${CMAKE_CURRENT_SOURCE_DIR}/imageops.hpp
1516
)
1617

1718
target_include_directories(VidStreamer

demos/video/imageops.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
6+
namespace imageops {
7+
8+
/// \brief Apply Sobel edge detection to each RGB channel separately.
9+
/// \param img Input tensor [1, 3, H, W] (float32/float64, CPU or CUDA)
10+
/// \return Tensor [1, 3, H, W] with gradient magnitude per channel.
11+
torch::Tensor sobel_rgb(const torch::Tensor& img)
12+
{
13+
TORCH_CHECK(img.dim() == 4 &&
14+
img.size(0) == 1 &&
15+
img.size(1) == 3,
16+
"Expected input of shape [1, 3, H, W]");
17+
18+
// Preserve dtype & device
19+
const auto opts = img.options();
20+
const auto dev = img.device();
21+
22+
/* ----- build 3-channel Sobel kernels -------------------------------- */
23+
// (out_channels, in_channels / groups, kH, kW) with groups = 3
24+
torch::Tensor kx = torch::tensor({{ -1, 0, 1},
25+
{ -2, 0, 2},
26+
{ -1, 0, 1}}, opts);
27+
torch::Tensor ky = torch::tensor({{ -1, -2, -1},
28+
{ 0, 0, 0},
29+
{ 1, 2, 1}}, opts);
30+
31+
// Replicate each kernel for the three groups (RGB)
32+
torch::Tensor weight_x = kx.expand({3, 1, 3, 3}).clone();
33+
torch::Tensor weight_y = ky.expand({3, 1, 3, 3}).clone();
34+
35+
/* ----- convolutions -------------------------------------------------- */
36+
const int64_t groups = 3;
37+
const int64_t padding = 1; // keep spatial size
38+
39+
torch::Tensor gx = torch::conv2d(img, weight_x, /*bias=*/{}, /*stride=*/1,
40+
padding, /*dilation=*/1, groups);
41+
torch::Tensor gy = torch::conv2d(img, weight_y, /*bias=*/{}, /*stride=*/1,
42+
padding, /*dilation=*/1, groups);
43+
44+
/* ----- gradient magnitude ------------------------------------------- */
45+
torch::Tensor magnitude = torch::sqrt(gx.pow(2) + gy.pow(2) + 1e-12);
46+
47+
return magnitude;
48+
}
49+
50+
}

demos/video/webcam_infer.cpp

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,43 @@
66
#include <utility>
77

88
#include "cvutil.hpp"
9+
#include "imageops.hpp"
10+
11+
12+
torch::Tensor sobel_dx = torch::tensor({{-1, 0, 1},
13+
{-2, 0, 2},
14+
{-1, 0, 1}}).to(torch::kFloat32);
15+
torch::Tensor sobel_dy = torch::tensor({{-1, -2, -1},
16+
{0, 0, 0},
17+
{1, 2, 1}}).to(torch::kFloat32);
18+
19+
torch::Tensor sobel_kernel = torch::cat({sobel_dx, sobel_dy}, 0).unsqueeze(0).unsqueeze(0);
20+
21+
/// \brief Convolve \p input with the file-level `sobel_kernel`.
///
/// Runs a single torch::conv2d with no bias, stride 1 and padding 1 and
/// returns its result unchanged.
///
/// \param input Tensor handed to conv2d as the input. It is only read, so it
///              is taken by const reference (const tensors and temporaries
///              are accepted as well).
/// \return The conv2d output.
///
/// NOTE(review): `sobel_kernel` is built in this file by concatenating the two
/// 3x3 stencils along dim 0, i.e. it is one [1, 1, 6, 3] filter rather than two
/// 3x3 filters. This therefore presumably expects a single-channel
/// [N, 1, H, W] input and does not compute a gradient magnitude - confirm the
/// intent before relying on it.
/// TODO: an earlier draft converted the input to float32 and combined separate
/// x/y responses as sqrt(gx^2 + gy^2); that path is not implemented here.
torch::Tensor sobel_conv(const torch::Tensor& input) {
    return torch::conv2d(input, sobel_kernel, /*bias=*/{}, /*stride=*/1, /*padding=*/1);
}
}
937

1038
struct Model : torch::nn::Module {
1139
Model() {
1240
fc1 = register_module("fc1", torch::nn::Linear(784, 64));
1341
fc2 = register_module("fc2", torch::nn::Linear(64, 10));
1442
r = register_parameter("r", torch::rand({1, 3, 224, 224}));
1543
uninitialized = true;
44+
45+
std::cout << "Sobel kernel: " << sobel_kernel.sizes() << std::endl;
1646
}
1747

1848
torch::Tensor forward(torch::Tensor x) {
@@ -23,7 +53,9 @@ struct Model : torch::nn::Module {
2353
std::cout << "Input sizes: " << x.sizes() << std::endl;
2454
}
2555
// auto output = x + r;
26-
auto output = imagenet_normalize_tensor(x);
56+
auto input = x;
57+
auto output = imageops::sobel_rgb(input);
58+
// auto output = imagenet_normalize_tensor(input);
2759
return output;
2860
}
2961

@@ -33,6 +65,8 @@ struct Model : torch::nn::Module {
3365
};
3466

3567

68+
// Target upper bound on the frame loop rate, in frames per second; main()
// derives its per-frame sleep time from this value.
int max_fps = 60;
69+
3670
int main(int argc, char** argv) {
3771

3872
if (argc < 2) {
@@ -141,8 +175,6 @@ int main(int argc, char** argv) {
141175

142176
std::chrono::time_point<std::chrono::system_clock> start_total = std::chrono::system_clock::now();
143177
std::chrono::time_point<std::chrono::system_clock> last_update = std::chrono::system_clock::now();
144-
const double max_fps = 30.0;
145-
const double max_frame_delay = 1000.0 / max_fps;
146178

147179
size_t frame_count = 0;
148180
size_t last_frame_count = 0;
@@ -158,27 +190,34 @@ int main(int argc, char** argv) {
158190
start_total = std::chrono::system_clock::now();
159191
last_update = std::chrono::system_clock::now(); // ??? not sure
160192
cap.set(cv::CAP_PROP_POS_FRAMES, 0);
161-
std::cout << "\r[INFO] Replaying video..." << std::flush;
193+
std::cout << "[INFO] Replaying video..." << std::endl;
162194
continue;
163195
}
164196
std::cerr << "[WARN] Empty frame, exiting" << std::endl;
165197
break;
166198
}
167199

168-
// Convert BGR -> RGB + float32 scaled 0‑1 directly into pre‑allocated tensor
169-
// cv::cvtColor(frame_bgr, frame_rgb, cv::COLOR_BGR2RGB);
170-
// frame_rgb.convertTo(frame_rgb, CV_32F, 1.f/255.f);
171200

172-
// Rearrange NHWC -> NCHW (in‑place view, no copy)
173-
// torch::Tensor input_tensor = frame_tensor_cpu.permute({0,3,1,2});
201+
++frame_count;
202+
const std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
203+
auto delta = now - last_update;
204+
// std::chrono::milliseconds delta_millis = std::chrono::duration_cast<std::chrono::microseconds>(delta);
205+
double delta_time = std::chrono::duration_cast<std::chrono::duration<double>>(delta).count();
206+
auto fps = 1.0 / delta_time;
207+
std::cout << "\r[INFO] FPS: " << fps << " fps" << std::flush;
174208

209+
// Display (optional)
175210

211+
double sleep_time = (1.0 / ((double)max_fps)) - delta_time;
212+
213+
std::this_thread::sleep_for(std::chrono::duration<double>(sleep_time));
176214

177-
auto input_tensor = to_tensor(frame_bgr);
178-
// auto input_tensor = frame_tensor_device->permute({0, 3, 1, 2});
179215

180216

181217

218+
auto input_tensor = to_tensor(frame_bgr);
219+
// auto input_tensor = frame_tensor_device->permute({0, 3, 1, 2});
220+
182221
// std::cout << "input_tensor: " << input_tensor.sizes() << std::endl;
183222
// std::cout << "input_tensor max: " << torch::max(input_tensor) << std::endl;
184223
// input_tensor = to_mps(input_tensor);
@@ -189,41 +228,26 @@ int main(int argc, char** argv) {
189228
// auto output_bgr = to_mat(output);
190229
output_bgr = to_mat(output);
191230

231+
cv::imshow("webcam", output_bgr);
232+
192233
// Display FPS
193-
++frame_count;
194-
const std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
195-
auto delta = now - last_update;
196-
// std::chrono::milliseconds delta_millis = std::chrono::duration_cast<std::chrono::microseconds>(delta);
197-
double delta_time = std::chrono::duration_cast<std::chrono::duration<double,std::milli>>(delta).count();
198-
std::cout << "\r[INFO] Frame time: " << delta_time * 1000.0 << " ms" << std::flush;
199234

235+
// std::cout << "[INFO] dt: " << delta_time << std::endl;
236+
// std::cout << "[INFO] FPS: " << fps << std::endl;
200237

201-
// auto now = std::chrono::steady_clock::now();
202-
// auto delta = std::chrono::duration_cast<std::chrono::duration<double>>(now - start_total);
203-
// double seconds = std::chrono::duration_cast<std::chrono::duration<double>>(delta).count();
204-
// double max_frame_count = max_fps / seconds;
205-
// double fps = frame_count / seconds;
206238

207239

208240

209-
// std::this_thread::sleep_for(delta);
241+
// std::thread::sleep_for(std::chrono::milliseconds(700));
210242

211-
// Sleep just to avoid too high FPS
243+
// std::thread::sleep_for()
212244

213-
// if (fps > max_fps) {
214-
// double missed_frames = fps - max_fps;
215-
// std::this_thread::sleep_for(std::chrono::milliseconds(missed_frames / max_fps));
216-
// }
217-
218245

246+
// std::thread::sleep_for(std::chrono::milliseconds(expected_time_index - (delta_time + last_time_index)));
219247

220-
double fps = (frame_count - last_frame_count) / delta_time;
221-
std::cout << "\r[INFO] FPS: " << fps << std::flush;
222-
last_update = now;
223248

224-
// Display (optional)
225-
cv::imshow("webcam", output_bgr);
226249
last_frame_count = frame_count;
250+
last_update = now; // std::chrono::system_clock::now();
227251
if (cv::waitKey(1) == 27) { // ESC key
228252
break;
229253
}
@@ -233,14 +257,3 @@ int main(int argc, char** argv) {
233257
cv::destroyAllWindows();
234258
return 0;
235259
}
236-
237-
238-
239-
240-
241-
// if (seconds >= 1.0) {
242-
// double fps = frame_count / seconds;
243-
// std::cout << "\r[INFO] FPS: " << fps << std::flush;
244-
// frame_count = 0;
245-
// start_total = now;
246-
// }

0 commit comments

Comments
 (0)