Skip to content

Commit 973e7af

Browse files
committed
cpp model construction demo works with cam.cpp.
1 parent 31fce77 commit 973e7af

2 files changed

Lines changed: 158 additions & 1 deletion

File tree

demos/video/cpp-model-construction/CMakeLists.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ include(CMakePrintHelpers)
44
# set(CMAKE_CXX_STANDARD 17)
55
# list(APPEND CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/libtorch/share/cmake")
66
find_package(Torch REQUIRED)
7-
# find_package(OpenCV REQUIRED)
7+
find_package(OpenCV REQUIRED)
88
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++23 -lm -ldl")
99

1010
cmake_print_variables(TORCH_LIBRARIES)
@@ -20,4 +20,15 @@ set_property(TARGET CPPModelConstruction PROPERTY CXX_STANDARD 23)
2020

2121
set_property(TARGET CPPModelConstruction PROPERTY
2222
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
23+
)
24+
25+
# Webcam demo executable built from cam.cpp; the header is listed so IDE
# generators show it alongside the source.
add_executable(CPPModelCam
    ${CMAKE_CURRENT_SOURCE_DIR}/cam.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_net.hpp
)

# Attach the OpenCV headers to this target only, instead of using the
# directory-wide include_directories(), which would also add them to every
# other target declared in this CMakeLists.txt.
target_include_directories(CPPModelCam PRIVATE ${OpenCV_INCLUDE_DIRS})

target_link_libraries(CPPModelCam ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET CPPModelCam PROPERTY CXX_STANDARD 23)

# Place the binary at the top of the build tree.
set_property(TARGET CPPModelCam PROPERTY
    RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#include "transformer_net.hpp"
2+
3+
4+
#include <opencv2/opencv.hpp>
5+
#include <iostream>
6+
7+
8+
/**
 * @brief Creates a TransformerNet, loads its weights and runs one dummy
 *        forward pass as a smoke test.
 *
 * Progress is logged to std::cout at each step.
 *
 * @param weights_path File handed to TransformerNet::load_parameters(). The
 *        default is the original developer-machine path, kept for backward
 *        compatibility; pass an explicit path on any other machine.
 * @return The model, switched to evaluation mode.
 *
 * NOTE(review): load_parameters() is declared in transformer_net.hpp, which is
 * not visible here; its failure behaviour for a missing or malformed file
 * should be confirmed.
 */
TransformerNet load_net(
    const char* weights_path =
        "/Users/iainmoncrief/Documents/Github/ChAI/demos/video/cpp-model-construction/state_dict_raw.pt") {
    // Fixed seed so the random smoke-test input below is reproducible.
    torch::manual_seed(0);
    std::cout << "set seed\n";

    TransformerNet model;
    std::cout << "TransformerNet model created.\n";

    model->load_parameters(weights_path);

    // Actually switch to evaluation mode before announcing it; previously the
    // message was printed while the eval() call was commented out.
    // Presumably TransformerNet is a torch::nn module holder -- confirm.
    model->eval();
    std::cout << "Model is in evaluation mode.\n";

    {
        // Smoke test only: no gradients are needed, so skip autograd tracking.
        torch::NoGradGuard no_grad;

        auto input = torch::randn({1, 3, 256, 256});
        std::cout << "Input shape: " << input.sizes() << "\n";

        auto output = model->forward(input);
        std::cout << "Output shape: " << output.sizes() << "\n";
    }

    return model;
}
32+
33+
34+
/**
 * @brief Runs one BGR frame through the model and returns the result as an
 *        8-bit BGR image.
 *
 * Steps: BGR -> RGB, scale to float by 1/255, HWC -> NCHW tensor, forward
 * pass, NCHW -> HWC, divide by 255, then convert back to 8-bit (scale 255)
 * and RGB -> BGR.
 *
 * @param frame Input image; passed to cv::cvtColor with COLOR_BGR2RGB, so it
 *        is presumably a non-empty 3-channel BGR image (the caller checks
 *        for emptiness). It is not modified.
 * @param model Network used for the forward pass.
 * @return A newly allocated CV_8UC3 image sized after the model's output.
 * @throws c10::Error if the model output is not an H x W x 3 tensor after
 *         the batch dimension is squeezed.
 */
cv::Mat new_frame(cv::Mat &frame, TransformerNet &model) {
    // Inference only: avoid recording an autograd graph for every frame.
    torch::NoGradGuard no_grad;

    cv::Mat rgb_float_frame;
    cv::cvtColor(frame, rgb_float_frame, cv::COLOR_BGR2RGB);
    rgb_float_frame.convertTo(rgb_float_frame, CV_32FC3, 1.0f / 255.0f);

    const int64_t height = rgb_float_frame.rows;
    const int64_t width = rgb_float_frame.cols;
    const int64_t channels = rgb_float_frame.channels();

    // from_blob only borrows the Mat's buffer; the copying .to() gives the
    // tensor its own storage before the Mat goes out of scope.
    torch::Tensor tensor =
        torch::from_blob(rgb_float_frame.data,
                         {height, width, channels},
                         torch::kFloat32)
            .to(torch::kCPU, torch::kFloat32, /*non_blocking=*/false, /*copy=*/true);

    // HWC -> 1 x C x H x W.
    tensor = tensor.permute({2, 0, 1}).unsqueeze(0).contiguous();

    torch::Tensor output_tensor = model->forward(tensor);

    // 1 x C x H x W -> H x W x C, densely packed so a cv::Mat can wrap it.
    output_tensor = output_tensor.squeeze(0).permute({1, 2, 0}).contiguous();
    output_tensor = output_tensor.to(torch::kCPU, torch::kFloat32, /*non_blocking=*/false, /*copy=*/true);
    output_tensor.div_(255.0);

    TORCH_CHECK(output_tensor.dim() == 3 && output_tensor.size(2) == 3,
                "new_frame: expected an H x W x 3 model output");

    // Size the Mat from the tensor actually produced rather than from the
    // input: the network is not guaranteed to preserve spatial dimensions,
    // and a mismatch would read outside the tensor's buffer.
    const int out_rows = static_cast<int>(output_tensor.size(0));
    const int out_cols = static_cast<int>(output_tensor.size(1));
    const cv::Mat tensor_view(out_rows, out_cols, CV_32FC3,
                              output_tensor.data_ptr<float>());

    // Convert into a separate Mat so the returned image owns its pixels and
    // does not alias tensor memory that is released on return.
    cv::Mat output_frame;
    tensor_view.convertTo(output_frame, CV_8UC3, 255.0f);
    cv::cvtColor(output_frame, output_frame, cv::COLOR_RGB2BGR);

    return output_frame;
}
79+
80+
81+
int mirror() {
82+
cv::VideoCapture cap(0);
83+
if (!cap.isOpened()) {
84+
std::cerr << "Error: Cannot open the webcam.\n";
85+
return -1;
86+
}
87+
88+
TransformerNet model = load_net();
89+
90+
cv::Mat frame;
91+
const std::string windowName = "Webcam Feed";
92+
cv::namedWindow(windowName, cv::WINDOW_AUTOSIZE);
93+
94+
cv::Size original_frame_size;
95+
cv::Size processed_frame_size;
96+
97+
while (true) {
98+
99+
uint64_t start = cv::getTickCount();
100+
101+
// Capture a new frame from webcam
102+
cap >> frame;
103+
if (frame.empty()) {
104+
std::cerr << "Error: Empty frame captured.\n";
105+
break;
106+
}
107+
108+
original_frame_size = frame.size();
109+
110+
const int width = (original_frame_size.width * 0.2);
111+
const int height = (original_frame_size.height * 0.2);
112+
processed_frame_size = cv::Size(width, height);
113+
cv::resize(frame, frame, processed_frame_size);
114+
115+
// std::cout << "Frame size: " << frame.size() << std::endl;
116+
// std::cout << "New frame size: " << processed_frame_size << std::endl;
117+
118+
cv::Mat next_frame = new_frame(frame,model);
119+
120+
cv::resize(next_frame, next_frame, original_frame_size);
121+
122+
// Display the captured frame
123+
cv::imshow(windowName, next_frame);
124+
125+
// Wait for 30ms or until 'q' key is pressed
126+
char key = static_cast<char>(cv::waitKey(1));
127+
if (key == 'q' || key == 27) { // 'q' or ESC to quit
128+
break;
129+
}
130+
131+
double fps = cv::getTickFrequency() / (cv::getTickCount() - start);
132+
std::cout << "\rcv::FPS : " << fps << "\t\r" << std::flush;
133+
}
134+
135+
// Release the camera and destroy all windows
136+
cap.release();
137+
cv::destroyAllWindows();
138+
return 0;
139+
}
140+
141+
142+
int main(int argc, char* argv[]) {
143+
return mirror();
144+
}
145+
146+

0 commit comments

Comments
 (0)