Skip to content

Commit 57347f7

Browse files
committed
Add CPU and accelerator options to style transfer demo.
1 parent ef9de0c commit 57347f7

6 files changed

Lines changed: 71 additions & 48 deletions

File tree

bridge/include/bridge.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ typedef unsigned char uint8_t;
1717
typedef unsigned int uint32_t;
1818
typedef unsigned long long uint64_t;
1919

20+
void debug_cpu_only_mode(bool_t mode);
21+
2022
typedef struct bridge_tensor_t {
2123
float* data;
2224
int* sizes;
@@ -56,9 +58,6 @@ bool_t accelerator_available(void);
5658
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input);
5759
bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input);
5860

59-
uint64_t get_cpu_frame_width(uint64_t width, float32_t scale_factor);
60-
uint64_t get_cpu_frame_height(uint64_t height, float32_t scale_factor);
61-
6261
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);
6362
bridge_tensor_t imagenet_normalize(bridge_tensor_t input);
6463

bridge/lib/bridge.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ auto best_dtype = get_best_dtype();
4141
torch::NoGradGuard no_grad;
4242
torch::AutoGradMode enable_grad(false);
4343

44-
45-
44+
bool debug_cpu_only = false;
4645

4746

4847

4948
torch::Device get_best_device() {
49+
if (debug_cpu_only)
50+
return torch::Device(torch::kCPU);
51+
5052
if (torch::hasMPS()) {
5153
return torch::Device(torch::kMPS);
5254
} else if (torch::hasCUDA()) {
@@ -56,9 +58,17 @@ torch::Device get_best_device() {
5658
}
5759
}
5860

61+
extern "C" void debug_cpu_only_mode(bool_t mode) {
62+
debug_cpu_only = mode;
63+
if (debug_cpu_only) {
64+
best_device = torch::Device(torch::kCPU);
65+
} else {
66+
best_device = get_best_device();
67+
}
68+
}
69+
5970
extern "C" bool_t accelerator_available() {
60-
return false;
61-
// return torch::hasMPS() || torch::hasCUDA();
71+
return (best_device == torch::Device(torch::kCUDA) || best_device == torch::Device(torch::kMPS));
6272
}
6373

6474
torch::ScalarType get_best_dtype() {
@@ -251,21 +261,21 @@ extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model,
251261
return model_forward(model, input, true);
252262
}
253263

254-
std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
255-
// if (best_device == torch::kMPS || best_device == torch::kCUDA)
256-
if (accelerator_available())
257-
return std::make_tuple(width, height);
258-
uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
259-
uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
260-
return std::make_tuple(new_width, new_height);
261-
}
264+
// std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
265+
// // if (best_device == torch::kMPS || best_device == torch::kCUDA)
266+
// if (accelerator_available())
267+
// return std::make_tuple(width, height);
268+
// uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
269+
// uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
270+
// return std::make_tuple(new_width, new_height);
271+
// }
262272

263-
extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
264-
return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
265-
}
266-
extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
267-
return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
268-
}
273+
// extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
274+
// return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
275+
// }
276+
// extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
277+
// return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
278+
// }
269279

270280

271281
extern "C" void hello_world(void) {

demos/video/chapel-webcam/lib/smol.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ void chpl__init_ndarrayRandom(int64_t _ln,
3232
void chpl__init_smol(int64_t _ln,
3333
int32_t _fn);
3434
chpl_bool acceleratorAvailable(void);
35-
int64_t getCPUFrameWidth(int64_t width);
36-
int64_t getCPUFrameHeight(int64_t height);
35+
int64_t getScaledFrameWidth(int64_t width);
36+
int64_t getScaledFrameHeight(int64_t height);
3737
int64_t square(int64_t x);
3838
void printArray(chpl_external_array * a);
3939
void globalLoadModel(void);

demos/video/chapel-webcam/main.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ int mirror() {
113113
const std::string windowName = "Webcam Feed";
114114
cv::namedWindow(windowName, cv::WINDOW_AUTOSIZE);
115115

116-
cv::Size frame_size;
117-
cv::Size new_frame_size;
116+
cv::Size original_frame_size;
117+
cv::Size processed_frame_size;
118118

119119
while (true) {
120120
// Capture a new frame from webcam
@@ -123,23 +123,20 @@ int mirror() {
123123
std::cerr << "Error: Empty frame captured.\n";
124124
break;
125125
}
126-
frame_size = frame.size();
127-
if (!acceleratorAvailable()) {
128-
const auto width = getCPUFrameWidth(frame_size.width);
129-
const auto height = getCPUFrameHeight(frame_size.height);
130-
new_frame_size = cv::Size(width, height);
131-
} else {
132-
new_frame_size = frame_size;
133-
}
134126

135-
cv::resize(frame, frame, new_frame_size);
127+
original_frame_size = frame.size();
128+
129+
const auto width = getScaledFrameWidth(original_frame_size.width);
130+
const auto height = getScaledFrameHeight(original_frame_size.height);
131+
processed_frame_size = cv::Size(width, height);
132+
cv::resize(frame, frame, processed_frame_size);
136133

137-
std::cout << "Frame size: " << frame.size() << std::endl;
138-
std::cout << "New frame size: " << new_frame_size << std::endl;
134+
// std::cout << "Frame size: " << frame.size() << std::endl;
135+
// std::cout << "New frame size: " << processed_frame_size << std::endl;
139136

140137
cv::Mat next_frame = new_frame(frame);
141138

142-
cv::resize(next_frame, next_frame, frame_size);
139+
cv::resize(next_frame, next_frame, original_frame_size);
143140

144141
// Display the captured frame
145142
cv::imshow(windowName, next_frame);

demos/video/chapel-webcam/smol.chpl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,39 @@ use Tensor;
22
use Layer;
33
import Utilities as util;
44

5-
config const cpuScaleFactor: real(32) = 0.2;
5+
config const cpuScale: real = 0.2;
6+
config const accelScale: real = 0.45;
7+
config const debugCPUOnly: bool = false;
8+
9+
10+
const cpuScaleFactor = cpuScale;
11+
const acceleratorScaleFactor = accelScale;
612

7-
writeln("CPU Scale Factor: ", cpuScaleFactor);
813

914
export proc acceleratorAvailable(): bool do
1015
return Bridge.acceleratorAvailable();
1116

12-
export proc getCPUFrameWidth(width: int): int do
13-
return Bridge.getCPUFrameWidth(width,cpuScaleFactor : real(32));
1417

15-
export proc getCPUFrameHeight(height: int): int do
16-
return Bridge.getCPUFrameHeight(height,cpuScaleFactor : real(32));
18+
export proc getScaledFrameWidth(width: int): int do
19+
if acceleratorAvailable() then
20+
return (width:real * acceleratorScaleFactor):int;
21+
else
22+
return (width:real * cpuScaleFactor):int;
23+
24+
export proc getScaledFrameHeight(height: int): int do
25+
if acceleratorAvailable() then
26+
return (height:real * acceleratorScaleFactor):int;
27+
else
28+
return (height:real * cpuScaleFactor):int;
29+
30+
31+
if debugCPUOnly then
32+
writeln("Debugging CPU only!");
33+
Bridge.debugCpuOnlyMode(debugCPUOnly);
34+
35+
writeln("CPU Scale Factor: ", cpuScaleFactor);
36+
writeln("Accelerator Scale Factor: ", acceleratorScaleFactor);
37+
writeln("Accelerator Available: ", acceleratorAvailable());
1738

1839

1940
export proc square(x: int): int {

lib/Bridge.chpl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,7 @@ module Bridge {
7474
extern "accelerator_available"
7575
proc acceleratorAvailable(): bool;
7676

77-
extern "get_cpu_frame_width"
78-
proc getCPUFrameWidth(width: int(64), scale_factor: real(32)): int(64);
79-
extern "get_cpu_frame_height"
80-
proc getCPUFrameHeight(height: int(64), scale_factor: real(32)): int(64);
81-
77+
extern "debug_cpu_only_mode" proc debugCpuOnlyMode(mode: bool): void;
8278

8379
extern proc convolve2d(
8480
in input: bridge_tensor_t,

0 commit comments

Comments
 (0)