Skip to content

Commit 31fce77

Browse files
committed
cpp-model-construction working for some reason.
1 parent 310fe1a commit 31fce77

25 files changed

Lines changed: 4375 additions & 3 deletions

demos/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ add_subdirectory(video)
33
# add_subdirectory(webcam_filter)
44

55
# add_subdirectory(torchtest)
6-
add_subdirectory(torchtest_bridge)
6+
add_subdirectory(torchtest_bridge)
7+

demos/video/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,6 @@ add_custom_command(
9999
"${CMAKE_CURRENT_SOURCE_DIR}/style-transfer/models"
100100
"$<TARGET_FILE_DIR:StyleTransfer>/style-transfer/models"
101101
COMMENT "NOT! Copying ${PROJECT_ROOT_DIR}/examples/vgg/images to $<TARGET_FILE_DIR:vgg>/images"
102-
)
102+
)
103+
104+
add_subdirectory(cpp-model-construction)
6.43 MB
Binary file not shown.

demos/video/chapel-webcam/model2.ipynb

Lines changed: 64 additions & 0 deletions
Large diffs are not rendered by default.

demos/video/chapel-webcam/model3.ipynb

Lines changed: 153 additions & 0 deletions
Large diffs are not rendered by default.

demos/video/chapel-webcam/net.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include "transformer_net.hpp"
2+
3+
int main() {
4+
torch::manual_seed(0);
5+
TransformerNet model;
6+
model->eval();
7+
8+
// dummy input
9+
auto input = torch::randn({1,3,256,256});
10+
auto output = model->forward(input);
11+
12+
std::cout << "Output shape: " << output.sizes() << "\n";
13+
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// transformer_net.h
2+
#pragma once
3+
#include <torch/torch.h>
4+
5+
//
6+
// --- ConvLayer -------------------------------------------------------------
7+
//
8+
struct ConvLayerImpl : torch::nn::Module {
9+
torch::nn::ReflectionPad2d reflection_pad{nullptr};
10+
torch::nn::Conv2d conv2d{nullptr};
11+
12+
ConvLayerImpl(int64_t in_channels,
13+
int64_t out_channels,
14+
int64_t kernel_size,
15+
int64_t stride)
16+
: reflection_pad(torch::nn::ReflectionPad2dOptions(kernel_size / 2)),
17+
conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size)
18+
.stride(stride))
19+
{
20+
register_module("reflection_pad", reflection_pad);
21+
register_module("conv2d", conv2d);
22+
}
23+
24+
torch::Tensor forward(torch::Tensor x) {
25+
x = reflection_pad->forward(x);
26+
x = conv2d->forward(x);
27+
return x;
28+
}
29+
};
30+
TORCH_MODULE(ConvLayer);
31+
32+
//
33+
// --- ResidualBlock ---------------------------------------------------------
34+
//
35+
struct ResidualBlockImpl : torch::nn::Module {
36+
ConvLayer conv1{nullptr};
37+
torch::nn::InstanceNorm2d in1{nullptr};
38+
ConvLayer conv2{nullptr};
39+
torch::nn::InstanceNorm2d in2{nullptr};
40+
torch::nn::ReLU relu{nullptr};
41+
42+
ResidualBlockImpl(int64_t channels)
43+
: conv1(ConvLayer(channels, channels, 3, 1)),
44+
in1(torch::nn::InstanceNorm2dOptions(channels).affine(true)),
45+
conv2(ConvLayer(channels, channels, 3, 1)),
46+
in2(torch::nn::InstanceNorm2dOptions(channels).affine(true)),
47+
relu(torch::nn::ReLUOptions(true))
48+
{
49+
register_module("conv1", conv1);
50+
register_module("in1", in1);
51+
register_module("conv2", conv2);
52+
register_module("in2", in2);
53+
register_module("relu", relu);
54+
}
55+
56+
torch::Tensor forward(torch::Tensor x) {
57+
auto residual = x;
58+
auto out = relu->forward(in1->forward(conv1->forward(x)));
59+
out = in2->forward(conv2->forward(out));
60+
return out + residual;
61+
}
62+
};
63+
TORCH_MODULE(ResidualBlock);
64+
65+
//
66+
// --- UpsampleConvLayer -----------------------------------------------------
67+
//
68+
struct UpsampleConvLayerImpl : torch::nn::Module {
69+
torch::nn::Upsample upsample{nullptr};
70+
torch::nn::ReflectionPad2d reflection_pad{nullptr};
71+
torch::nn::Conv2d conv2d{nullptr};
72+
73+
UpsampleConvLayerImpl(int64_t in_channels,
74+
int64_t out_channels,
75+
int64_t kernel_size,
76+
int64_t stride,
77+
int64_t upsample_scale)
78+
: upsample(torch::nn::UpsampleOptions()
79+
.scale_factor(std::vector<double>{(double)upsample_scale, (double)upsample_scale})
80+
.mode(torch::kNearest)),
81+
reflection_pad(torch::nn::ReflectionPad2dOptions(kernel_size / 2)),
82+
conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size)
83+
.stride(stride))
84+
{
85+
register_module("upsample", upsample);
86+
register_module("reflection_pad", reflection_pad);
87+
register_module("conv2d", conv2d);
88+
}
89+
90+
torch::Tensor forward(torch::Tensor x) {
91+
x = upsample->forward(x);
92+
x = reflection_pad->forward(x);
93+
x = conv2d->forward(x);
94+
return x;
95+
}
96+
};
97+
TORCH_MODULE(UpsampleConvLayer);
98+
99+
//
100+
// --- TransformerNet --------------------------------------------------------
101+
//
102+
struct TransformerNetImpl : torch::nn::Module {
103+
ConvLayer conv1{nullptr};
104+
torch::nn::InstanceNorm2d in1{nullptr};
105+
ConvLayer conv2{nullptr};
106+
torch::nn::InstanceNorm2d in2{nullptr};
107+
ConvLayer conv3{nullptr};
108+
torch::nn::InstanceNorm2d in3{nullptr};
109+
ResidualBlock res1{nullptr}, res2{nullptr}, res3{nullptr},
110+
res4{nullptr}, res5{nullptr};
111+
UpsampleConvLayer deconv1{nullptr}, deconv2{nullptr};
112+
torch::nn::InstanceNorm2d in4{nullptr}, in5{nullptr};
113+
ConvLayer deconv3{nullptr};
114+
torch::nn::ReLU relu{nullptr};
115+
116+
TransformerNetImpl()
117+
: conv1(ConvLayer( 3, 32, 9, 1)),
118+
in1 (torch::nn::InstanceNorm2dOptions(32).affine(true)),
119+
conv2(ConvLayer( 32, 64, 3, 2)),
120+
in2 (torch::nn::InstanceNorm2dOptions(64).affine(true)),
121+
conv3(ConvLayer( 64, 128, 3, 2)),
122+
in3 (torch::nn::InstanceNorm2dOptions(128).affine(true)),
123+
res1 (ResidualBlock(128)),
124+
res2 (ResidualBlock(128)),
125+
res3 (ResidualBlock(128)),
126+
res4 (ResidualBlock(128)),
127+
res5 (ResidualBlock(128)),
128+
deconv1(UpsampleConvLayer(128, 64, 3, 1, 2)),
129+
in4 (torch::nn::InstanceNorm2dOptions(64).affine(true)),
130+
deconv2(UpsampleConvLayer( 64, 32, 3, 1, 2)),
131+
in5 (torch::nn::InstanceNorm2dOptions(32).affine(true)),
132+
deconv3(ConvLayer( 32, 3, 9, 1)),
133+
relu (torch::nn::ReLUOptions(true))
134+
{
135+
register_module("conv1", conv1);
136+
register_module("in1", in1);
137+
register_module("conv2", conv2);
138+
register_module("in2", in2);
139+
register_module("conv3", conv3);
140+
register_module("in3", in3);
141+
register_module("res1", res1);
142+
register_module("res2", res2);
143+
register_module("res3", res3);
144+
register_module("res4", res4);
145+
register_module("res5", res5);
146+
register_module("deconv1", deconv1);
147+
register_module("in4", in4);
148+
register_module("deconv2", deconv2);
149+
register_module("in5", in5);
150+
register_module("deconv3", deconv3);
151+
register_module("relu", relu);
152+
}
153+
154+
torch::Tensor forward(torch::Tensor x) {
155+
x = relu->forward(in1->forward(conv1->forward(x)));
156+
x = relu->forward(in2->forward(conv2->forward(x)));
157+
x = relu->forward(in3->forward(conv3->forward(x)));
158+
x = res1->forward(x);
159+
x = res2->forward(x);
160+
x = res3->forward(x);
161+
x = res4->forward(x);
162+
x = res5->forward(x);
163+
x = relu->forward(in4->forward(deconv1->forward(x)));
164+
x = relu->forward(in5->forward(deconv2->forward(x)));
165+
x = deconv3->forward(x);
166+
return x;
167+
}
168+
};
169+
TORCH_MODULE(TransformerNet);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
include(CMakePrintHelpers)
2+
3+
# project(MyProject)
4+
# set(CMAKE_CXX_STANDARD 17)
5+
# list(APPEND CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/libtorch/share/cmake")
6+
find_package(Torch REQUIRED)
7+
# find_package(OpenCV REQUIRED)
8+
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++23 -lm -ldl")
9+
10+
cmake_print_variables(TORCH_LIBRARIES)
11+
cmake_print_variables(TORCH_INCLUDE_DIRS)
12+
cmake_print_variables(TORCH_INSTALL_PREFIX)
13+
cmake_print_variables(TORCH_CXX_FLAGS)
14+
cmake_print_variables(TORCH_LIBRARY)
15+
16+
add_executable(CPPModelConstruction ${CMAKE_CURRENT_SOURCE_DIR}/net.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transformer_net.hpp)
17+
18+
target_link_libraries(CPPModelConstruction ${TORCH_LIBRARIES})
19+
set_property(TARGET CPPModelConstruction PROPERTY CXX_STANDARD 23)
20+
21+
set_property(TARGET CPPModelConstruction PROPERTY
22+
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
23+
)
Binary file not shown.

0 commit comments

Comments
 (0)