Skip to content

Commit b322889

Browse files
authored
Work on style transfer demos (#71)
- Add example calling OpenCV loop from Chapel (chapel-lang/chapel#27242) `demos/video/style-transfer` - Create demo in python for segmented style transfer for each face `demos/video/style-transfer/face-recognition` - Create demo in Chapel for style transfer `demos/video/chapel-webcam`
2 parents 0a2984e + 2b9368f commit b322889

30 files changed

Lines changed: 1608 additions & 51 deletions

CMakeLists.txt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,22 @@ target_include_directories(
131131
)
132132

133133

134+
add_library(bridge_objs STATIC $<TARGET_OBJECTS:bridge>)
135+
set_target_properties(bridge_objs
136+
PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
137+
)
138+
139+
134140
set(BRIDGE_OBJECT_FILES $<TARGET_OBJECTS:bridge>)
135141

142+
# add_custom_command(
143+
# TARGET bridge
144+
# POST_BUILD
145+
# COMMAND ${CMAKE_COMMAND} -E copy_directory
146+
# "${CMAKE_CURRENT_SOURCE_DIR}/style-transfer/models"
147+
# "$<TARGET_FILE_DIR:StyleTransfer>/style-transfer/models"
148+
# COMMENT "NOT! Copying ${PROJECT_ROOT_DIR}/examples/vgg/images to $<TARGET_FILE_DIR:vgg>/images"
149+
# )
136150

137151

138152

@@ -237,7 +251,8 @@ set(CHAI_LINKER_ARGS
237251
${BRIDGE_OBJECT_FILES}
238252
-L ${LIBTORCH_DIR}/lib
239253
${LIBTORCH_LIBS_LINKER_ARGS}
240-
--ldflags "-Wl,-rpath,${LIBTORCH_DIR}/lib"
254+
--ccflags "-I${BRIDGE_DIR}/include -L${PROJECT_ROOT_DIR}/build"
255+
--ldflags "-L${PROJECT_ROOT_DIR}/build -Wl,-rpath,${LIBTORCH_DIR}/lib"
241256
)
242257

243258

bridge/include/bridge.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ typedef double float64_t;
1515
typedef char bool_t;
1616
typedef unsigned char uint8_t;
1717
typedef unsigned int uint32_t;
18+
typedef unsigned long long uint64_t;
1819

1920
typedef struct bridge_tensor_t {
2021
float* data;
@@ -24,6 +25,17 @@ typedef struct bridge_tensor_t {
2425
} bridge_tensor_t;
2526

2627

28+
typedef struct bridge_pt_model_t {
29+
void* pt_module;
30+
} bridge_pt_model_t;
31+
32+
typedef struct test_struct_t {
33+
int* field;
34+
} test_struct_t;
35+
36+
37+
void hello_world(void);
38+
2739
typedef struct nil_scalar_tensor_t {
2840
float scalar;
2941
bridge_tensor_t tensor;
@@ -36,6 +48,14 @@ float* unsafe(const float* arr);
3648
bridge_tensor_t load_tensor_from_file(const uint8_t* file_path);
3749
bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key);
3850
bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input);
51+
52+
bridge_pt_model_t load_model(const uint8_t* model_path);
53+
54+
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input);
55+
56+
57+
bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input);
58+
3959
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);
4060
bridge_tensor_t imagenet_normalize(bridge_tensor_t input);
4161

@@ -107,7 +127,6 @@ proto_bridge_simple(tanhshrink);
107127
void split_loop(int64_t idx, int64_t n);
108128
void split_loop_filler(int64_t n,int64_t* ret);
109129

110-
void show_webcam(void);
111130

112131

113132
// bridge_tensor_t conv2d(

bridge/lib/bridge.cpp

Lines changed: 136 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include <chrono>
1717
#include <thread>
1818

19-
#include <opencv2/opencv.hpp>
2019

2120

2221
#define def_bridge_simple(Name) \
@@ -66,6 +65,19 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
6665
return torch::from_blob(bt.data, shape, torch::kFloat);
6766
}
6867

68+
torch::Tensor bridge_to_torch(bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32) {
69+
std::vector<int64_t> sizes_vec(bt.sizes, bt.sizes + bt.dim);
70+
auto shape = torch::IntArrayRef(sizes_vec);
71+
auto t = torch::from_blob(bt.data, shape, torch::kFloat);
72+
if (device != torch::kCPU)
73+
copy = true;
74+
if (copy)
75+
return t.to(device, dtype, /*non_blocking=*/false, /*copy=*/true);
76+
else
77+
return t.to(device, dtype, /*non_blocking=*/false, /*copy=*/false);
78+
79+
}
80+
6981
extern "C" float32_t* unsafe(const float32_t* arr) {
7082
return const_cast<float32_t*>(arr);
7183
}
@@ -131,6 +143,102 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
131143
return torch_to_bridge(output);
132144
}
133145

146+
147+
extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
148+
149+
std::cout << "Begin loading model from path: " << model_path << std::endl;
150+
std::cout.flush();
151+
std::string path(reinterpret_cast<const char*>(model_path));
152+
std::cout << "Loading model from path: " << path << std::endl;
153+
std::cout.flush();
154+
155+
try {
156+
157+
auto* module = new torch::jit::Module(torch::jit::load(path));
158+
module->to(torch::kMPS,torch::kFloat16,false);
159+
module->eval();
160+
std::cout << "Model loaded successfully!" << std::endl;
161+
std::cout.flush();
162+
return { static_cast<void*>(module) };
163+
164+
// torch::jit::Module tmp = torch::jit::load(path);
165+
// std::cout << "Model loaded successfully!" << std::endl;
166+
// std::cout.flush();
167+
// auto* module = new torch::jit::Module(std::move(tmp));
168+
// std::cout << "Model moved successfully!" << std::endl;
169+
// std::cout.flush();
170+
// return { static_cast<void*>(module) };
171+
} catch (const c10::Error& e) {
172+
std::cerr << "error loading the model\n" << e.msg();
173+
std::cout << "error loading the model\n" << e.msg();
174+
std::cout.flush();
175+
std::cerr.flush();
176+
}
177+
std::cout << "Model loading failed!" << std::endl;
178+
std::cout.flush();
179+
180+
return { nullptr };
181+
182+
183+
184+
// bridge_pt_model_t model_wrapper;
185+
// torch::jit::Module* pt_module = new torch::jit::Module(); // = (torch::jit::Module*) model_wrapper.pt_module;
186+
// try {
187+
// *pt_module = torch::jit::load(mp);
188+
// std::cout << "Model loaded successfully!" << std::endl;
189+
// std::cout.flush();
190+
// model_wrapper.pt_module = pt_module;
191+
// } catch (const c10::Error& e) {
192+
// std::cerr << "error loading the model\n" << e.msg();
193+
// std::cout << "error loading the model\n" << e.msg();
194+
// std::cout.flush();
195+
// std::cerr.flush();
196+
// }
197+
198+
// std::cout << pt_module->dump_to_str(false,false,false) << std::endl;
199+
// std::cout.flush();
200+
201+
// return model_wrapper;
202+
}
203+
204+
205+
206+
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
207+
208+
auto tn_mps = bridge_to_torch(input,torch::kMPS,true,torch::kFloat16);
209+
// auto tn_mps = tn.to(torch::kMPS,false,true);
210+
auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
211+
212+
std::vector<torch::jit::IValue> ins;
213+
ins.push_back(tn);
214+
215+
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
216+
auto o = module->forward(ins).toTensor();
217+
auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
218+
219+
if (is_vgg_based_model) {
220+
tn_out = tn_out / 255.0;
221+
}
222+
223+
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,true);
224+
return torch_to_bridge(tn_out_cpu);
225+
226+
}
227+
228+
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
229+
return model_forward(model, input, false);
230+
}
231+
232+
extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input) {
233+
return model_forward(model, input, true);
234+
}
235+
236+
237+
extern "C" void hello_world(void) {
238+
std::cout << "Hello from C++!" << std::endl;
239+
std::cout.flush();
240+
}
241+
134242
extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
135243
auto t = bridge_to_torch(arr);
136244
// Increment the tensor
@@ -404,37 +512,37 @@ extern "C" void split_loop_filler(int64_t n,int64_t* ret) {
404512

405513

406514

407-
cv::VideoCapture open_camera(int cam_index) {
408-
cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
409-
if (!cap.isOpened()) {
410-
std::cerr << "Could not open camera index " << cam_index << std::endl;
411-
return cv::VideoCapture();
412-
}
413-
cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
414-
cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
415-
return cap;
416-
}
515+
// cv::VideoCapture open_camera(int cam_index) {
516+
// cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
517+
// if (!cap.isOpened()) {
518+
// std::cerr << "Could not open camera index " << cam_index << std::endl;
519+
// return cv::VideoCapture();
520+
// }
521+
// cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
522+
// cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
523+
// return cap;
524+
// }
417525

418526

419-
extern "C" void show_webcam(void) {
420-
cv::VideoCapture cap;
421-
cap = open_camera(0);
527+
// extern "C" void show_webcam(void) {
528+
// cv::VideoCapture cap;
529+
// cap = open_camera(0);
422530

423-
cv::Mat frame_bgr;
531+
// cv::Mat frame_bgr;
424532

425-
while (true) {
426-
if (!cap.read(frame_bgr) || frame_bgr.empty()) {
427-
std::cerr << "[WARN] Empty frame, exiting" << std::endl;
428-
break;
429-
}
533+
// while (true) {
534+
// if (!cap.read(frame_bgr) || frame_bgr.empty()) {
535+
// std::cerr << "[WARN] Empty frame, exiting" << std::endl;
536+
// break;
537+
// }
430538

431-
cv::imshow("webcam", frame_bgr);
539+
// cv::imshow("webcam", frame_bgr);
432540

433-
if (cv::waitKey(1) == 27) { // ESC key
434-
break;
435-
}
436-
}
541+
// if (cv::waitKey(1) == 27) { // ESC key
542+
// break;
543+
// }
544+
// }
437545

438-
cap.release();
439-
cv::destroyAllWindows();
440-
}
546+
// cap.release();
547+
// cv::destroyAllWindows();
548+
// }

demos/video/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
# export OpenCV_DIR="$(brew --prefix opencv)/share/opencv4"
44
# /opt/homebrew/opt/opencv/share/opencv4
5+
6+
7+
8+
9+
# This is messy
510
find_package(OpenCV 4 REQUIRED)
611

712
find_library(ACCELERATE Accelerate REQUIRED)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
*.a
2+
*.so
3+
*.o
4+
*.dylib
5+
lib/savec

0 commit comments

Comments
 (0)