Skip to content

Commit 6c3b08a

Browse files
committed
Improving performance of style transfer demo.
1 parent a7329ae commit 6c3b08a

7 files changed

Lines changed: 324 additions & 15 deletions

File tree

bridge/.DS_Store

0 Bytes
Binary file not shown.

bridge/lib/bridge.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <bridge.h>
22

33
#include <torch/torch.h>
4+
// The ATen header directory is spelled "ATen"; the exact case matters on
// case-sensitive filesystems.
#include <ATen/ATen.h>
5+
46
#include <torch/script.h>
57

68
// #include <torch/script.h>
@@ -27,6 +29,9 @@
2729

2830

2931

32+
torch::NoGradGuard no_grad;
33+
torch::AutoGradMode enable_grad(false);
34+
3035
int bridge_tensor_elements(bridge_tensor_t &bt) {
3136
int size = 1;
3237
for (int i = 0; i < bt.dim; ++i) {
@@ -39,14 +44,14 @@ size_t bridge_tensor_size(bridge_tensor_t &bt) {
3944
return sizeof(float32_t) * bridge_tensor_elements(bt);
4045
}
4146

42-
void store_tensor(torch::Tensor &input, float32_t* dest) {
47+
// Copies all elements of `input` into the caller-provided buffer `dest`.
//
// `dest` must have room for at least `input.numel()` float32_t values.
// The copy is a raw memcpy from the tensor's data pointer, so the source and
// destination ranges must not overlap.
// NOTE(review): this reads the tensor's storage linearly — presumably callers
// pass a contiguous, host-resident tensor; confirm at call sites.
void store_tensor(at::Tensor &input, float32_t* dest) {
    const float32_t* src = input.data_ptr<float32_t>();
    const size_t byte_count = sizeof(float32_t) * input.numel();
    std::memcpy(dest, src, byte_count);
}
4853

49-
bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
54+
bridge_tensor_t torch_to_bridge(at::Tensor &tensor) {
5055
bridge_tensor_t result;
5156
result.created_by_c = true;
5257
result.dim = tensor.dim();
@@ -59,13 +64,13 @@ bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
5964
return result;
6065
}
6166

62-
torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
67+
// Builds a float tensor over the buffer described by `bt`.
//
// The shape is taken from the first `bt.dim` entries of `bt.sizes`. The tensor
// is created with torch::from_blob on `bt.data`, i.e. it refers to the bridge
// buffer rather than a fresh allocation, so `bt.data` has to stay valid for as
// long as the returned tensor is in use.
at::Tensor bridge_to_torch(bridge_tensor_t &bt) {
    const std::vector<int64_t> dims(bt.sizes, bt.sizes + bt.dim);
    return torch::from_blob(bt.data, torch::IntArrayRef(dims), torch::kFloat);
}
6772

68-
torch::Tensor bridge_to_torch(bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32) {
73+
at::Tensor bridge_to_torch(bridge_tensor_t &bt,torch::Device device, bool copy,torch::ScalarType dtype = torch::kFloat32) {
6974
std::vector<int64_t> sizes_vec(bt.sizes, bt.sizes + bt.dim);
7075
auto shape = torch::IntArrayRef(sizes_vec);
7176
auto t = torch::from_blob(bt.data, shape, torch::kFloat);
@@ -144,6 +149,10 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
144149
}
145150

146151

152+
#define DEVICE torch::kMPS
153+
#define DTYPE torch::kFloat16
154+
155+
147156
extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
148157

149158
std::cout << "Begin loading model from path: " << model_path << std::endl;
@@ -153,9 +162,8 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
153162
std::cout.flush();
154163

155164
try {
156-
157165
auto* module = new torch::jit::Module(torch::jit::load(path));
158-
module->to(torch::kMPS,torch::kFloat16,false);
166+
module->to(DEVICE,DTYPE,false);
159167
module->eval();
160168
std::cout << "Model loaded successfully!" << std::endl;
161169
std::cout.flush();
@@ -204,23 +212,23 @@ extern "C" bridge_pt_model_t load_model(const uint8_t* model_path) {
204212

205213

206214
// Runs one forward pass of a loaded TorchScript module on a bridge tensor.
//
// Parameters:
//   model              - handle whose `pt_module` points at a torch::jit::Module.
//   input              - bridge tensor holding the input data; it is copied to
//                        DEVICE and converted to DTYPE before inference.
//   is_vgg_based_model - when true, the model output is divided by 255.
//
// Returns a bridge tensor built from a float32 CPU copy of the model output.
bridge_tensor_t model_forward(bridge_pt_model_t model, bridge_tensor_t input, bool is_vgg_based_model) {
    // Autograd's grad mode is thread-local, so a guard constructed at
    // namespace scope only covers the thread that ran static initialization.
    // Holding a guard here keeps inference gradient-free regardless of which
    // thread the caller uses.
    torch::NoGradGuard no_grad_guard;

    auto tn_mps = bridge_to_torch(input, DEVICE, true, DTYPE);
    // Move the last axis to the front and add a leading axis of size 1
    // (presumably HWC image -> NCHW batch of one; confirm against callers).
    auto tn = tn_mps.permute({2, 0, 1}).unsqueeze(0).contiguous();

    std::vector<torch::jit::IValue> ins;
    ins.push_back(tn);

    auto* module = static_cast<torch::jit::Module*>(model.pt_module);
    auto o = module->forward(ins).toTensor();

    // Undo the layout change applied to the input: drop the leading axis and
    // move the first remaining axis to the end. A single contiguous() after
    // the permute is enough to materialize the final layout.
    auto tn_out = o.squeeze(0).permute({1, 2, 0}).contiguous();

    if (is_vgg_based_model) {
        // In-place division avoids allocating a second output-sized tensor.
        tn_out.div_(255.0);
    }

    // Blocking copy to host memory as float32 before handing data to the bridge.
    auto tn_out_cpu = tn_out.to(torch::kCPU, torch::kFloat32, false, true);

    return torch_to_bridge(tn_out_cpu);
}

demos/models/readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This folder contains the model architectures used in the demos.

demos/models/transformer_net.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import torch
2+
3+
4+
class TransformerNet(torch.nn.Module):
    """Convolutional image-to-image network.

    Structure: three convolution + instance-norm stages (the last two with
    stride 2), five residual blocks at 128 channels, two upsample-convolution
    stages, and a final 9x9 convolution back to 3 channels. Input and output
    both have 3 channels.
    """

    def __init__(self):
        super().__init__()

        # Downsampling convolutions, each followed by an affine instance norm.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)

        # Residual trunk.
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)

        # Upsampling stages and the output projection.
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)

        # Shared activation.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Apply the network to ``X`` and return the un-activated output of the last convolution."""
        # Downsample.
        features = self.relu(self.in1(self.conv1(X)))
        features = self.relu(self.in2(self.conv2(features)))
        features = self.relu(self.in3(self.conv3(features)))

        # Residual trunk.
        features = self.res1(features)
        features = self.res2(features)
        features = self.res3(features)
        features = self.res4(features)
        features = self.res5(features)

        # Upsample and project back to 3 channels.
        features = self.relu(self.in4(self.deconv1(features)))
        features = self.relu(self.in5(self.deconv2(features)))
        return self.deconv3(features)
42+
43+
44+
class ConvLayer(torch.nn.Module):
    """Reflection padding followed by a 2D convolution.

    The padding width is ``kernel_size // 2`` on every side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad_width = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(pad_width)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Pad ``x`` by reflection, then convolve it."""
        return self.conv2d(self.reflection_pad(x))
55+
56+
57+
class ResidualBlock(torch.nn.Module):
    """Residual block of two 3x3 conv + instance-norm stages with a skip connection.

    introduced in: https://arxiv.org/abs/1512.03385
    recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the two-stage transform of ``x`` plus ``x``; only the first stage is followed by ReLU."""
        branch = self.relu(self.in1(self.conv1(x)))
        branch = self.in2(self.conv2(branch))
        return branch + x
77+
78+
79+
class UpsampleConvLayer(torch.nn.Module):
    """UpsampleConvLayer
    Upsamples the input and then does a convolution. This method gives better results
    compared to ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size; reflection padding is ``kernel_size // 2``.
        stride: convolution stride.
        upsample: nearest-neighbour scale factor applied before padding and
            convolution. A falsy value (``None`` or ``0``) disables upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample):
        super(UpsampleConvLayer, self).__init__()
        # Honour the requested factor instead of a hard-coded 2. The upsampling
        # step stays a submodule (not a bare number) so forward() remains a
        # plain chain of module calls; it has no parameters, so the state dict
        # is unaffected.
        if upsample:
            self.upsample = torch.nn.Upsample(scale_factor=upsample, mode='nearest')
        else:
            self.upsample = torch.nn.Identity()
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` (if enabled), reflection-pad it, then convolve."""
        out = self.upsample(x)
        out = self.reflection_pad(out)
        out = self.conv2d(out)
        return out

demos/video/chapel-webcam/model.ipynb

Lines changed: 84 additions & 4 deletions
Large diffs are not rendered by default.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "22e96cc8",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch\n",
11+
"from torchvision.io import decode_image,read_image\n",
12+
"from torchvision import transforms\n",
13+
"from torchvision.transforms import functional as F\n",
14+
"from PIL import Image\n",
15+
"import matplotlib.pyplot as plt\n",
16+
"def show_image(image):\n",
17+
" # plt.imshow(transforms.ToPILImage()(image), interpolation=\"bicubic\")\n",
18+
" # # pil_image = transforms.ToPILImage()(image)\n",
19+
" # # pil_image.show()\n",
20+
" plt.imshow(image.detach().permute(1, 2, 0).cpu())\n",
21+
" plt.axis('off')\n",
22+
" plt.show()\n",
23+
"\n",
24+
"# img = decode_image('coast.jpeg', mode='RGB')\n",
25+
"# img = img.to(torch.float32) / 255.0\n",
26+
"\n",
27+
"pil_img = Image.open('coast.jpeg')\n",
28+
"img = F.to_tensor(pil_img)\n",
29+
"print(img.shape)\n",
30+
"show_image(img)"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"id": "335855a9",
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"device = torch.device('mps')\n",
41+
"stm = torch.jit.load('../style-transfer/models/exports/mps/mosaic_float16.pt', map_location=device)"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": null,
47+
"id": "9d32c34b",
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"stm.to(device)\n",
52+
"stm.eval()\n",
53+
"print(\"Model loaded.\")\n",
54+
"# help(stm)"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"id": "20e724b3",
61+
"metadata": {},
62+
"outputs": [],
63+
"source": [
64+
"in_img = img.to(device).to(torch.float16).unsqueeze(0).contiguous()\n",
65+
"in_img.shape\n",
66+
"x = in_img.squeeze(0).to(torch.float32)\n",
67+
"print(x.shape)\n",
68+
"show_image(x)\n"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": null,
74+
"id": "a6de87a4",
75+
"metadata": {},
76+
"outputs": [],
77+
"source": [
78+
"output = stm(in_img).detach()\n",
79+
"print(output.shape,output.dtype,output.device)\n",
80+
"torch.mps.empty_cache()"
81+
]
82+
},
83+
{
84+
"cell_type": "code",
85+
"execution_count": null,
86+
"id": "473c1852",
87+
"metadata": {},
88+
"outputs": [],
89+
"source": [
90+
"out_img = (output.squeeze(0) / 255.0).to(torch.float32)\n",
91+
"print(out_img.shape)\n",
92+
"show_image(out_img)"
93+
]
94+
}
95+
],
96+
"metadata": {
97+
"kernelspec": {
98+
"display_name": ".venv",
99+
"language": "python",
100+
"name": "python3"
101+
},
102+
"language_info": {
103+
"codemirror_mode": {
104+
"name": "ipython",
105+
"version": 3
106+
},
107+
"file_extension": ".py",
108+
"mimetype": "text/x-python",
109+
"name": "python",
110+
"nbconvert_exporter": "python",
111+
"pygments_lexer": "ipython3",
112+
"version": "3.12.9"
113+
}
114+
},
115+
"nbformat": 4,
116+
"nbformat_minor": 5
117+
}

demos/video/chapel-webcam/sobel.pt

0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)