Skip to content

Commit ef9de0c

Browse files
committed
Add scale factor and performance optimizations to style transfer demo.
1 parent 2de27e5 commit ef9de0c

8 files changed

Lines changed: 369 additions & 24 deletions

File tree

bridge/include/bridge.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input)
5151

5252
bridge_pt_model_t load_model(const uint8_t* model_path);
5353

54-
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input);
55-
54+
bool_t accelerator_available(void);
5655

56+
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input);
5757
bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model, bridge_tensor_t input);
5858

59+
uint64_t get_cpu_frame_width(uint64_t width, float32_t scale_factor);
60+
uint64_t get_cpu_frame_height(uint64_t height, float32_t scale_factor);
61+
5962
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);
6063
bridge_tensor_t imagenet_normalize(bridge_tensor_t input);
6164

bridge/lib/bridge.cpp

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,48 @@
2929

3030

3131

32+
// Globals
33+
34+
35+
torch::Device get_best_device();
36+
torch::ScalarType get_best_dtype();
37+
38+
auto best_device = get_best_device();
39+
auto best_dtype = get_best_dtype();
40+
3241
torch::NoGradGuard no_grad;
3342
torch::AutoGradMode enable_grad(false);
3443

44+
45+
46+
47+
48+
49+
torch::Device get_best_device() {
50+
if (torch::hasMPS()) {
51+
return torch::Device(torch::kMPS);
52+
} else if (torch::hasCUDA()) {
53+
return torch::Device(torch::kCUDA);
54+
} else {
55+
return torch::Device(torch::kCPU);
56+
}
57+
}
58+
59+
extern "C" bool_t accelerator_available() {
60+
return false;
61+
// return torch::hasMPS() || torch::hasCUDA();
62+
}
63+
64+
torch::ScalarType get_best_dtype() {
65+
if (torch::hasMPS()) {
66+
return torch::kFloat16;
67+
} else if (torch::hasCUDA()) {
68+
return torch::kFloat16;
69+
} else {
70+
return torch::kFloat32;
71+
}
72+
}
73+
3574
int bridge_tensor_elements(bridge_tensor_t &bt) {
3675
int size = 1;
3776
for (int i = 0; i < bt.dim; ++i) {
@@ -149,8 +188,6 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
149188
}
150189

151190

152-
#define DEVICE torch::kMPS
153-
#define DTYPE torch::kFloat16
154191

155192

156193
extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
@@ -163,7 +200,7 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
163200

164201
try {
165202
auto* module = new torch::jit::Module(torch::jit::load(path));
166-
module->to(DEVICE,DTYPE,false);
203+
module->to(best_device,best_dtype,false);
167204
module->eval();
168205
std::cout << "Model loaded successfully!" << std::endl;
169206
std::cout.flush();
@@ -183,24 +220,24 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
183220

184221

185222
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
186-
auto tn_mps = bridge_to_torch(input,DEVICE,true,DTYPE);
187-
tn_mps = tn_mps.permute({2, 0, 1}).contiguous();
188-
tn_mps.unsqueeze_(0);//.contiguous();
189-
// auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
223+
auto tn_mps = bridge_to_torch(input,best_device,true,best_dtype);
224+
// tn_mps = tn_mps.permute({2, 0, 1}).contiguous();
225+
// tn_mps.unsqueeze_(0);//.contiguous();
226+
auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();
190227

191228
std::vector<torch::jit::IValue> ins;
192-
ins.push_back(tn_mps);
229+
ins.push_back(tn);
193230

194231
auto* module = static_cast<torch::jit::Module*>(model.pt_module);
195232
auto o = module->forward(ins).toTensor();
196-
auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();
197-
// auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
233+
// auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();
234+
auto tn_out = o.squeeze(0).contiguous().permute({1, 2, 0}).contiguous();
198235

199236
if (is_vgg_based_model) {
200237
tn_out.div_(255.0);
201238
}
202239

203-
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,false);
240+
auto tn_out_cpu = tn_out.to(torch::kCPU,torch::kFloat32,false,true);
204241

205242
return torch_to_bridge(tn_out_cpu);
206243

@@ -214,6 +251,22 @@ extern "C" bridge_tensor_t model_forward_style_transfer(bridge_pt_model_t model,
214251
return model_forward(model, input, true);
215252
}
216253

254+
std::tuple<uint64_t, uint64_t> get_cpu_frame_size(uint64_t width, uint64_t height, float32_t scale_factor) {
255+
// if (best_device == torch::kMPS || best_device == torch::kCUDA)
256+
if (accelerator_available())
257+
return std::make_tuple(width, height);
258+
uint64_t new_width = static_cast<uint64_t>(width * scale_factor);
259+
uint64_t new_height = static_cast<uint64_t>(height * scale_factor);
260+
return std::make_tuple(new_width, new_height);
261+
}
262+
263+
extern "C" uint64_t get_cpu_frame_width(uint64_t width,float32_t scale_factor) {
264+
return std::get<0>(get_cpu_frame_size(width, 0, scale_factor));
265+
}
266+
extern "C" uint64_t get_cpu_frame_height(uint64_t height,float32_t scale_factor) {
267+
return std::get<1>(get_cpu_frame_size(0, height, scale_factor));
268+
}
269+
217270

218271
extern "C" void hello_world(void) {
219272
std::cout << "Hello from C++!" << std::endl;

demos/video/chapel-webcam/lib/smol.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ void chpl__init_ndarrayRandom(int64_t _ln,
3131
int32_t _fn);
3232
void chpl__init_smol(int64_t _ln,
3333
int32_t _fn);
34+
chpl_bool acceleratorAvailable(void);
35+
int64_t getCPUFrameWidth(int64_t width);
36+
int64_t getCPUFrameHeight(int64_t height);
3437
int64_t square(int64_t x);
3538
void printArray(chpl_external_array * a);
3639
void globalLoadModel(void);

demos/video/chapel-webcam/main.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,34 @@ int mirror() {
113113
const std::string windowName = "Webcam Feed";
114114
cv::namedWindow(windowName, cv::WINDOW_AUTOSIZE);
115115

116+
cv::Size frame_size;
117+
cv::Size new_frame_size;
118+
116119
while (true) {
117120
// Capture a new frame from webcam
118121
cap >> frame;
119122
if (frame.empty()) {
120123
std::cerr << "Error: Empty frame captured.\n";
121124
break;
122125
}
126+
frame_size = frame.size();
127+
if (!acceleratorAvailable()) {
128+
const auto width = getCPUFrameWidth(frame_size.width);
129+
const auto height = getCPUFrameHeight(frame_size.height);
130+
new_frame_size = cv::Size(width, height);
131+
} else {
132+
new_frame_size = frame_size;
133+
}
134+
135+
cv::resize(frame, frame, new_frame_size);
136+
137+
std::cout << "Frame size: " << frame.size() << std::endl;
138+
std::cout << "New frame size: " << new_frame_size << std::endl;
139+
123140
cv::Mat next_frame = new_frame(frame);
141+
142+
cv::resize(next_frame, next_frame, frame_size);
143+
124144
// Display the captured frame
125145
cv::imshow(windowName, next_frame);
126146

demos/video/chapel-webcam/model2.ipynb

Lines changed: 149 additions & 10 deletions
Large diffs are not rendered by default.

demos/video/chapel-webcam/smol.chpl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@ use Tensor;
22
use Layer;
33
import Utilities as util;
44

5+
config const cpuScaleFactor: real(32) = 0.2;
6+
7+
writeln("CPU Scale Factor: ", cpuScaleFactor);
8+
9+
export proc acceleratorAvailable(): bool do
10+
return Bridge.acceleratorAvailable();
11+
12+
export proc getCPUFrameWidth(width: int): int do
13+
return Bridge.getCPUFrameWidth(width,cpuScaleFactor : real(32));
14+
15+
export proc getCPUFrameHeight(height: int): int do
16+
return Bridge.getCPUFrameHeight(height,cpuScaleFactor : real(32));
17+
518

619
export proc square(x: int): int {
720
writeln(x, " * ", x, " = ", x * x);
@@ -35,7 +48,9 @@ const startTime = getTime();
3548

3649
// ../style-transfer/models/exports/mps/nature_oil_painting_ep4_bt4_sw1e10_cw_1e5_float32.pt
3750
// ../style-transfer/models/exports/mps/udnie_float32.pt
38-
config const modelPath: string = "../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt";
51+
// ../style-transfer/models/exports/mps/starry_ep3_bt4_sw1e11_cw_1e5_float32.pt // This is the one
52+
// ../style-transfer/models/exports/cpu/mosaic_float16.pt
53+
config const modelPath: string = "../style-transfer/models/exports/cpu/mosaic_float16.pt";
3954
var model : Bridge.bridge_pt_model_t;
4055

4156
var modelLayer : shared TorchModule(real(32))?;
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import torch
2+
3+
4+
class TransformerNet(torch.nn.Module):
5+
def __init__(self):
6+
super(TransformerNet, self).__init__()
7+
# Initial convolution layers
8+
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
9+
self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
10+
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
11+
self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
12+
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
13+
self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
14+
# Residual layers
15+
self.res1 = ResidualBlock(128)
16+
self.res2 = ResidualBlock(128)
17+
self.res3 = ResidualBlock(128)
18+
self.res4 = ResidualBlock(128)
19+
self.res5 = ResidualBlock(128)
20+
# Upsampling Layers
21+
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
22+
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
23+
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
24+
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
25+
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
26+
# Non-linearities
27+
self.relu = torch.nn.ReLU()
28+
29+
def forward(self, X):
30+
y = self.relu(self.in1(self.conv1(X)))
31+
y = self.relu(self.in2(self.conv2(y)))
32+
y = self.relu(self.in3(self.conv3(y)))
33+
y = self.res1(y)
34+
y = self.res2(y)
35+
y = self.res3(y)
36+
y = self.res4(y)
37+
y = self.res5(y)
38+
y = self.relu(self.in4(self.deconv1(y)))
39+
y = self.relu(self.in5(self.deconv2(y)))
40+
y = self.deconv3(y)
41+
return y
42+
43+
44+
class ConvLayer(torch.nn.Module):
45+
def __init__(self, in_channels, out_channels, kernel_size, stride):
46+
super(ConvLayer, self).__init__()
47+
reflection_padding = kernel_size // 2
48+
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
49+
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
50+
51+
def forward(self, x):
52+
out = self.reflection_pad(x)
53+
out = self.conv2d(out)
54+
return out
55+
56+
57+
class ResidualBlock(torch.nn.Module):
58+
"""ResidualBlock
59+
introduced in: https://arxiv.org/abs/1512.03385
60+
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
61+
"""
62+
63+
def __init__(self, channels):
64+
super(ResidualBlock, self).__init__()
65+
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
66+
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
67+
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
68+
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
69+
self.relu = torch.nn.ReLU()
70+
71+
def forward(self, x):
72+
residual = x
73+
out = self.relu(self.in1(self.conv1(x)))
74+
out = self.in2(self.conv2(out))
75+
out = out + residual
76+
return out
77+
78+
79+
class UpsampleConvLayer(torch.nn.Module):
80+
"""UpsampleConvLayer
81+
Upsamples the input and then does a convolution. This method gives better results
82+
compared to ConvTranspose2d.
83+
ref: http://distill.pub/2016/deconv-checkerboard/
84+
"""
85+
86+
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample):
87+
super(UpsampleConvLayer, self).__init__()
88+
# self.upsample = upsample
89+
self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
90+
reflection_padding = kernel_size // 2
91+
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
92+
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
93+
94+
def forward(self, x):
95+
x_in = x
96+
# print('upsample', self.upsample)
97+
# x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
98+
# if self.upsample:
99+
# x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
100+
out = self.upsample(x_in)
101+
out = self.reflection_pad(out)
102+
# out = self.reflection_pad(out.to(torch.float32)).to(x.dtype)
103+
out = self.conv2d(out)
104+
return out

lib/Bridge.chpl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ module Bridge {
7171
in model: bridge_pt_model_t,
7272
in input: bridge_tensor_t): bridge_tensor_t;
7373

74+
extern "accelerator_available"
75+
proc acceleratorAvailable(): bool;
76+
77+
extern "get_cpu_frame_width"
78+
proc getCPUFrameWidth(width: int(64), scale_factor: real(32)): int(64);
79+
extern "get_cpu_frame_height"
80+
proc getCPUFrameHeight(height: int(64), scale_factor: real(32)): int(64);
81+
7482

7583
extern proc convolve2d(
7684
in input: bridge_tensor_t,

0 commit comments

Comments
 (0)