Skip to content

Commit 87339f6

Browse files
Merge branch 'Iainmon:main' into main
2 parents 9f12057 + 92b241e commit 87339f6

38 files changed

Lines changed: 2431 additions & 64 deletions

CMakeLists.txt

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,22 @@ target_include_directories(
131131
)
132132

133133

134+
# Bundle the object files of the `bridge` target into a static archive target named bridge_objs.
add_library(bridge_objs STATIC $<TARGET_OBJECTS:bridge>)
135+
# Emit the archive directly into CMAKE_BINARY_DIR.
set_target_properties(bridge_objs
136+
PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}"
137+
)
138+
139+
134140
set(BRIDGE_OBJECT_FILES $<TARGET_OBJECTS:bridge>)
135141

142+
# add_custom_command(
143+
# TARGET bridge
144+
# POST_BUILD
145+
# COMMAND ${CMAKE_COMMAND} -E copy_directory
146+
# "${CMAKE_CURRENT_SOURCE_DIR}/style-transfer/models"
147+
# "$<TARGET_FILE_DIR:StyleTransfer>/style-transfer/models"
148+
# COMMENT "NOT! Copying ${PROJECT_ROOT_DIR}/examples/vgg/images to $<TARGET_FILE_DIR:vgg>/images"
149+
# )
136150

137151

138152

@@ -214,12 +228,8 @@ add_dependencies(TinyLayerTest ChAI)
214228
target_link_options(TinyLayerTest
215229
PRIVATE
216230
--main-module layer_test.chpl
217-
-M ${PROJECT_ROOT_DIR}/lib
218-
${BRIDGE_DIR}/include/bridge.h
219-
${BRIDGE_OBJECT_FILES}
220-
-L ${LIBTORCH_DIR}/lib
221-
${LIBTORCH_LIBS_LINKER_ARGS}
222-
--ldflags "-Wl,-rpath,${LIBTORCH_DIR}/lib"
231+
# -M ${PROJECT_ROOT_DIR}/lib
232+
${CHAI_LINKER_ARGS}
223233
)
224234
# chpl test/tiny/layer_test.chpl -M lib bridge/include/bridge.h build/CMakeFiles/bridge.dir/bridge/lib/bridge.cpp.o -L libtorch/lib -ltorch -ltorch_cpu -lc10 -ltorch_global_deps --ldflags "-Wl,-rpath,libtorch/lib"
225235

@@ -237,7 +247,8 @@ set(CHAI_LINKER_ARGS
237247
${BRIDGE_OBJECT_FILES}
238248
-L ${LIBTORCH_DIR}/lib
239249
${LIBTORCH_LIBS_LINKER_ARGS}
240-
--ldflags "-Wl,-rpath,${LIBTORCH_DIR}/lib"
250+
--ccflags "-I${BRIDGE_DIR}/include -L${PROJECT_ROOT_DIR}/build"
251+
--ldflags "-L${PROJECT_ROOT_DIR}/build -Wl,-rpath,${LIBTORCH_DIR}/lib"
241252
)
242253

243254

bridge/.DS_Store

0 Bytes
Binary file not shown.

bridge/include/bridge.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ typedef double float64_t;
1515
typedef char bool_t;
1616
typedef unsigned char uint8_t;
1717
typedef unsigned int uint32_t;
18+
// 64-bit unsigned alias, declared by hand like the other fixed-width typedefs in this header.
// NOTE(review): on LP64 Linux <stdint.h> declares uint64_t as `unsigned long`; if that header is
// visible in the same translation unit the two typedefs would conflict — confirm how this
// header is guarded before building off macOS.
typedef unsigned long long uint64_t;
19+
20+
void debug_cpu_only_mode(bool_t mode);
1821

1922
typedef struct bridge_tensor_t {
2023
float* data;
@@ -24,6 +27,17 @@ typedef struct bridge_tensor_t {
2427
} bridge_tensor_t;
2528

2629

30+
// Opaque model handle that can be passed by value through the C interface.
typedef struct bridge_pt_model_t {
31+
void* pt_module; // type-erased module pointer; load_model() in bridge.cpp stores a heap-allocated torch::jit::Module* here, or nullptr when loading fails
32+
} bridge_pt_model_t;
33+
34+
// Struct holding a single int pointer.
// NOTE(review): no use of this type is visible in this view — presumably an interop test fixture; confirm.
typedef struct test_struct_t {
35+
int* field;
36+
} test_struct_t;
37+
38+
39+
void hello_world(void);
40+
2741
typedef struct nil_scalar_tensor_t {
2842
float scalar;
2943
bridge_tensor_t tensor;
@@ -36,6 +50,14 @@ float* unsafe(const float* arr);
3650
bridge_tensor_t load_tensor_from_file(const uint8_t* file_path);
3751
bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key);
3852
bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input);
53+
54+
bridge_pt_model_t load_model(const uint8_t* model_path);
55+
56+
bool_t accelerator_available(void);
57+
58+
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input);
59+
bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input);
60+
3961
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);
4062
bridge_tensor_t imagenet_normalize(bridge_tensor_t input);
4163

@@ -107,7 +129,6 @@ proto_bridge_simple(tanhshrink);
107129
void split_loop(int64_t idx, int64_t n);
108130
void split_loop_filler(int64_t n,int64_t* ret);
109131

110-
void show_webcam(void);
111132

112133

113134
// bridge_tensor_t conv2d(

bridge/lib/bridge.cpp

Lines changed: 183 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <bridge.h>
22

33
#include <torch/torch.h>
4+
#include <ATen/ATen.h>  // the libtorch directory is spelled "ATen"; "Aten" only resolves on case-insensitive filesystems
5+
46
#include <torch/script.h>
57

68
// #include <torch/script.h>
@@ -16,7 +18,6 @@
1618
#include <chrono>
1719
#include <thread>
1820

19-
#include <opencv2/opencv.hpp>
2021

2122

2223
#define def_bridge_simple(Name) \
@@ -28,6 +29,58 @@
2829

2930

3031

32+
// Globals
33+
34+
35+
// Forward declarations: the globals below are initialized by calling these
// functions, which are defined further down in this file.
torch::Device get_best_device();
36+
torch::ScalarType get_best_dtype();
37+
38+
// Device and dtype used by load_model() / model_forward(). Both are computed
// once when the globals are initialized; debug_cpu_only_mode() later reassigns
// best_device but does not touch best_dtype.
auto best_device = get_best_device();
39+
auto best_dtype = get_best_dtype();
40+
41+
// Namespace-scope guard objects meant to keep autograd disabled for the bridge.
// NOTE(review): grad mode is presumably thread-local in libtorch, so these may
// only affect the thread that runs static initialization — confirm.
torch::NoGradGuard no_grad;
42+
torch::AutoGradMode enable_grad(false);
43+
44+
// When true, get_best_device() returns the CPU regardless of which
// accelerators are present. Written by debug_cpu_only_mode().
bool debug_cpu_only = false;
45+
46+
47+
48+
// Pick the device the bridge should run on.
//
// Returns the CPU when debug_cpu_only is set; otherwise MPS if available,
// then CUDA, and the CPU as the last resort.
torch::Device get_best_device() {
    auto chosen = torch::kCPU;
    if (!debug_cpu_only) {
        if (torch::hasMPS()) {
            chosen = torch::kMPS;
        } else if (torch::hasCUDA()) {
            chosen = torch::kCUDA;
        }
    }
    return torch::Device(chosen);
}
60+
61+
// Enable or disable CPU-only debugging and refresh the cached device and dtype.
//
// mode: non-zero forces later load_model() / model_forward() calls onto the
//       CPU; zero restores automatic accelerator selection.
//
// The dtype is refreshed together with the device. Leaving best_dtype at the
// accelerator's Float16 while best_device is the CPU would make load_model()
// and model_forward() run half precision on the CPU.
extern "C" void debug_cpu_only_mode(bool_t mode) {
    debug_cpu_only = mode;
    // get_best_device() already honours debug_cpu_only, so a single call covers both cases.
    best_device = get_best_device();
    best_dtype = debug_cpu_only ? torch::kFloat32 : get_best_dtype();
}
69+
70+
// Report whether the currently selected best_device is an accelerator
// (CUDA or MPS) rather than the CPU.
extern "C" bool_t accelerator_available() {
    const torch::Device cuda_device(torch::kCUDA);
    const torch::Device mps_device(torch::kMPS);
    if (best_device == cuda_device) {
        return true;
    }
    return best_device == mps_device;
}
73+
74+
// Pick the floating-point dtype that matches the device get_best_device() selects.
//
// Returns Float32 when debug_cpu_only is set or when neither MPS nor CUDA is
// available, and Float16 when an accelerator is present. The debug_cpu_only
// check mirrors get_best_device() so the two selectors cannot disagree.
torch::ScalarType get_best_dtype() {
    if (debug_cpu_only)
        return torch::kFloat32;

    const bool has_accelerator = torch::hasMPS() || torch::hasCUDA();
    return has_accelerator ? torch::kFloat16 : torch::kFloat32;
}
83+
3184
int bridge_tensor_elements(bridge_tensor_t &bt) {
3285
int size = 1;
3386
for (int i = 0; i < bt.dim; ++i) {
@@ -40,14 +93,14 @@ size_t bridge_tensor_size(bridge_tensor_t &bt) {
4093
return sizeof(float32_t) * bridge_tensor_elements(bt);
4194
}
4295

43-
void store_tensor(torch::Tensor &input, float32_t* dest) {
96+
// Copy the elements of `input` into `dest` as densely packed float32 values.
//
// input: tensor to export; it is not modified.
// dest:  caller-owned buffer with room for input.numel() float32 values.
//
// The tensor is first brought to the CPU, converted to float32 and made
// contiguous, so the raw byte copy below always reads elements in logical
// order from host memory. For a tensor that is already a contiguous CPU
// float32 tensor these conversions do not copy.
void store_tensor(at::Tensor &input, float32_t* dest) {
    const at::Tensor staged = input.to(torch::kCPU, torch::kFloat32).contiguous();
    const size_t bytes_size = sizeof(float32_t) * staged.numel();
    if (bytes_size == 0) {
        return;  // nothing to copy; also avoids handing memcpy a possibly-null pointer
    }
    const float32_t* src = staged.data_ptr<float32_t>();
    std::memcpy(dest, src, bytes_size);
}
49102

50-
bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
103+
bridge_tensor_t torch_to_bridge(at::Tensor &tensor) {
51104
bridge_tensor_t result;
52105
result.created_by_c = true;
53106
result.dim = tensor.dim();
@@ -60,12 +113,25 @@ bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
60113
return result;
61114
}
62115

63-
torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
116+
// Expose a bridge tensor's buffer as a float tensor of the same shape.
// The result is built with torch::from_blob over bt.data; no element copy is
// requested here.
at::Tensor bridge_to_torch(bridge_tensor_t &bt) {
    const std::vector<int64_t> dims(bt.sizes, bt.sizes + bt.dim);
    return torch::from_blob(bt.data, torch::IntArrayRef(dims), torch::kFloat);
}
68121

122+
// Convert a bridge tensor to a torch tensor on `device` with element type `dtype`.
//
// bt:     source tensor; its buffer is read as float32 with shape bt.sizes[0..bt.dim).
// device: target device for the result.
// copy:   request a fresh copy even when no conversion is needed; treated as
//         true whenever `device` is not the CPU.
// dtype:  element type of the result (Float32 by default).
at::Tensor bridge_to_torch(bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32) {
    const std::vector<int64_t> dims(bt.sizes, bt.sizes + bt.dim);
    const at::Tensor view = torch::from_blob(bt.data, torch::IntArrayRef(dims), torch::kFloat);
    const bool force_copy = copy || (device != torch::kCPU);
    return view.to(device, dtype, /*non_blocking=*/false, /*copy=*/force_copy);
}
134+
69135
extern "C" float32_t* unsafe(const float32_t* arr) {
70136
return const_cast<float32_t*>(arr);
71137
}
@@ -131,6 +197,92 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
131197
return torch_to_bridge(output);
132198
}
133199

200+
201+
202+
203+
// Load a TorchScript module from disk and hand it to the caller as an opaque handle.
//
// model_path: NUL-terminated path to the serialized module.
//
// Returns a handle whose pt_module points to a heap-allocated
// torch::jit::Module that has been moved to best_device / best_dtype and put
// in eval mode. On any failure the handle's pt_module is nullptr and the error
// is written to both stdout and stderr.
//
// No exception leaves this function: it is called through a C interface, so
// every std::exception (c10::Error included) is caught here, and a module that
// was allocated before the failure is released instead of leaked.
extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
    if (model_path == nullptr) {
        // Streaming or converting a null C string is undefined behaviour; fail cleanly instead.
        std::cerr << "error loading the model\nmodel path is null" << std::endl;
        return { nullptr };
    }

    std::cout << "Begin loading model from path: " << model_path << std::endl;
    const std::string path(reinterpret_cast<const char*>(model_path));
    std::cout << "Loading model from path: " << path << std::endl;

    torch::jit::Module* module = nullptr;
    std::string failure;
    try {
        module = new torch::jit::Module(torch::jit::load(path));
        module->to(best_device, best_dtype, false);
        module->eval();
        std::cout << "Model loaded successfully!" << std::endl;
        return { static_cast<void*>(module) };
    } catch (const c10::Error& e) {
        failure = e.msg();
    } catch (const std::exception& e) {
        failure = e.what();
    }

    // Reached only on failure: `module` is either still null or was allocated
    // before to()/eval() threw, so it must be freed here.
    delete module;

    std::cerr << "error loading the model\n" << failure << '\n';
    std::cout << "error loading the model\n" << failure << '\n';
    std::cout.flush();
    std::cerr.flush();
    std::cout << "Model loading failed!" << std::endl;

    return { nullptr };
}
229+
230+
231+
232+
// Run a loaded TorchScript module on one input tensor and return the result
// as a CPU float32 bridge tensor.
//
// model:              handle produced by load_model(); its pt_module must be non-null.
// input:              3-D bridge tensor. Its last axis is moved to the front and a
//                     leading axis of size 1 is added before the forward pass
//                     (presumably HWC -> NCHW; confirm against callers).
// is_vgg_based_model: when true, the output is divided by 255 before it is returned.
//
// Raises c10::Error (via TORCH_CHECK) when the handle is null, instead of
// dereferencing the null pointer that a failed load_model() returns.
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
    auto* module = static_cast<torch::jit::Module*>(model.pt_module);
    TORCH_CHECK(module != nullptr, "model_forward: model handle is null (did load_model fail?)");

    // Copy the input onto the device / dtype the model was loaded with.
    auto staged = bridge_to_torch(input, best_device, true, best_dtype);
    auto batched = staged.permute({2, 0, 1}).unsqueeze(0).contiguous();

    std::vector<torch::jit::IValue> ins;
    ins.push_back(batched);

    auto o = module->forward(ins).toTensor();

    // Undo the input rearrangement: drop the leading axis, move the first axis back to the end.
    auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();

    if (is_vgg_based_model) {
        tn_out.div_(255.0);
    }

    // The bridge hands float32 host memory back to the caller.
    auto tn_out_cpu = tn_out.to(torch::kCPU, torch::kFloat32, false, true);

    return torch_to_bridge(tn_out_cpu);
}
255+
256+
// C entry point: run the model on `input` without the divide-by-255 output step.
extern "C" bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input) {
    const bool is_vgg_based_model = false;
    return model_forward(model, input, is_vgg_based_model);
}
259+
260+
// C entry point: run the model on `input` and divide the output by 255 before returning it.
extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input) {
    const bool is_vgg_based_model = true;
    return model_forward(model, input, is_vgg_based_model);
}
263+
264+
// std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
265+
// // if (best_device == torch::kMPS || best_device == torch::kCUDA)
266+
// if (accelerator_available())
267+
// return std::make_tuple(width, height);
268+
// uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
269+
// uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
270+
// return std::make_tuple(new_width, new_height);
271+
// }
272+
273+
// extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
274+
// return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
275+
// }
276+
// extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
277+
// return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
278+
// }
279+
280+
281+
// Print a fixed greeting line to stdout and flush it.
extern "C" void hello_world(void) {
    std::cout << "Hello from C++!" << '\n' << std::flush;
}
285+
134286
extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
135287
auto t = bridge_to_torch(arr);
136288
// Increment the tensor
@@ -404,37 +556,37 @@ extern "C" void split_loop_filler(int64_t n,int64_t* ret) {
404556

405557

406558

407-
cv::VideoCapture open_camera(int cam_index) {
408-
cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
409-
if (!cap.isOpened()) {
410-
std::cerr << "Could not open camera index " << cam_index << std::endl;
411-
return cv::VideoCapture();
412-
}
413-
cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
414-
cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
415-
return cap;
416-
}
559+
// cv::VideoCapture open_camera(int cam_index) {
560+
// cv::VideoCapture cap(cam_index, cv::CAP_AVFOUNDATION);
561+
// if (!cap.isOpened()) {
562+
// std::cerr << "Could not open camera index " << cam_index << std::endl;
563+
// return cv::VideoCapture();
564+
// }
565+
// cap.set(cv::CAP_PROP_BUFFERSIZE, 1); // minimal internal buffering
566+
// cap.set(cv::CAP_PROP_FPS, 60); // request higher FPS if possible
567+
// return cap;
568+
// }
417569

418570

419-
extern "C" void show_webcam(void) {
420-
cv::VideoCapture cap;
421-
cap = open_camera(0);
571+
// extern "C" void show_webcam(void) {
572+
// cv::VideoCapture cap;
573+
// cap = open_camera(0);
422574

423-
cv::Mat frame_bgr;
575+
// cv::Mat frame_bgr;
424576

425-
while (true) {
426-
if (!cap.read(frame_bgr) || frame_bgr.empty()) {
427-
std::cerr << "[WARN] Empty frame, exiting" << std::endl;
428-
break;
429-
}
577+
// while (true) {
578+
// if (!cap.read(frame_bgr) || frame_bgr.empty()) {
579+
// std::cerr << "[WARN] Empty frame, exiting" << std::endl;
580+
// break;
581+
// }
430582

431-
cv::imshow("webcam", frame_bgr);
583+
// cv::imshow("webcam", frame_bgr);
432584

433-
if (cv::waitKey(1) == 27) { // ESC key
434-
break;
435-
}
436-
}
585+
// if (cv::waitKey(1) == 27) { // ESC key
586+
// break;
587+
// }
588+
// }
437589

438-
cap.release();
439-
cv::destroyAllWindows();
440-
}
590+
// cap.release();
591+
// cv::destroyAllWindows();
592+
// }

demos/models/readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This folder contains the model architectures used in the demos.

0 commit comments

Comments
 (0)