Skip to content

Commit fa60713

Browse files
committed
Loading module files in style tranfer works
:
1 parent af2b197 commit fa60713

32 files changed

Lines changed: 895 additions & 8 deletions

demos/video/CMakeLists.txt

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,19 @@ find_library(METAL Metal REQUIRED)
99
find_library(FOUNDATION Foundation REQUIRED)
1010

1111

12+
1213
add_executable(VidStreamer
13-
${CMAKE_CURRENT_SOURCE_DIR}/webcam_infer.cpp
14-
${CMAKE_CURRENT_SOURCE_DIR}/cvtool.hpp
15-
${CMAKE_CURRENT_SOURCE_DIR}/imageops.hpp
14+
${CMAKE_CURRENT_SOURCE_DIR}/webcam-capture/webcam_infer.cpp
1615
)
1716

1817
target_include_directories(VidStreamer
1918
PRIVATE
19+
${CMAKE_CURRENT_SOURCE_DIR}/include
2020
${LIBTORCH_DIR}/include
2121
${LIBTORCH_DIR}/include/torch/csrc/api/include
2222
)
2323

24-
target_link_directories(VidStreamer
25-
PRIVATE
26-
${LIBTORCH_DIR}/lib
27-
)
24+
target_link_directories(VidStreamer PRIVATE ${LIBTORCH_DIR}/lib)
2825

2926
target_link_libraries(VidStreamer
3027
PRIVATE
@@ -43,10 +40,58 @@ set_target_properties(VidStreamer PROPERTIES
4340
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
4441
)
4542

46-
4743
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
4844
target_compile_options(VidStreamer PRIVATE -Ofast -flto -ffast-math)
4945
target_link_options(VidStreamer PRIVATE -flto)
5046
endif()
5147

5248

49+
50+
51+
52+
53+
54+
add_executable(StyleTransfer
55+
${CMAKE_CURRENT_SOURCE_DIR}/style-transfer/style_transfer.cpp
56+
)
57+
58+
target_include_directories(StyleTransfer
59+
PRIVATE
60+
${CMAKE_CURRENT_SOURCE_DIR}/include
61+
${LIBTORCH_DIR}/include
62+
${LIBTORCH_DIR}/include/torch/csrc/api/include
63+
)
64+
65+
target_link_directories(StyleTransfer PRIVATE ${LIBTORCH_DIR}/lib)
66+
67+
target_link_libraries(StyleTransfer
68+
PRIVATE
69+
-ltorch
70+
-ltorch_cpu
71+
-lc10
72+
-ltorch_global_deps
73+
${OpenCV_LIBS}
74+
# ${TORCH_LIBRARIES}
75+
${ACCELERATE}
76+
${METAL}
77+
${FOUNDATION}
78+
)
79+
80+
set_target_properties(StyleTransfer PROPERTIES
81+
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
82+
)
83+
84+
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
85+
target_compile_options(StyleTransfer PRIVATE -Ofast -flto -ffast-math)
86+
target_link_options(StyleTransfer PRIVATE -flto)
87+
endif()
88+
89+
90+
add_custom_command(
91+
TARGET StyleTransfer
92+
POST_BUILD
93+
COMMAND ${CMAKE_COMMAND} -E copy_directory
94+
"${CMAKE_CURRENT_SOURCE_DIR}/style-transfer/models"
95+
"$<TARGET_FILE_DIR:StyleTransfer>/style-transfer/models"
96+
COMMENT "NOT! Copying ${PROJECT_ROOT_DIR}/examples/vgg/images to $<TARGET_FILE_DIR:vgg>/images"
97+
)
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "6e4d2e04",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 11,
16+
"id": "ec74c8a7",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"class MyModule(torch.nn.Module):\n",
21+
" def __init__(self, N, M):\n",
22+
" super(MyModule, self).__init__()\n",
23+
" self.linear = torch.nn.Linear(N, M)\n",
24+
"\n",
25+
" def forward(self, input):\n",
26+
" return self.linear(input)\n"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": null,
32+
"id": "180e54ac",
33+
"metadata": {},
34+
"outputs": [],
35+
"source": [
36+
"my_module = MyModule(10,20)\n",
37+
"# sm = torch.jit.script(my_module)\n",
38+
"sm = torch.jit.script(my_module)\n",
39+
"sm.save(\"models/my_module.pt\")"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"id": "d5e377e0",
46+
"metadata": {},
47+
"outputs": [],
48+
"source": []
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"id": "89e90304",
54+
"metadata": {},
55+
"outputs": [
56+
{
57+
"ename": "RuntimeError",
58+
"evalue": "Parent directory models does not exist.",
59+
"output_type": "error",
60+
"traceback": [
61+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
62+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
63+
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43msm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodels/my_module.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
64+
"File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/jit/_script.py:754\u001b[0m, in \u001b[0;36mRecursiveScriptModule.save\u001b[0;34m(self, f, **kwargs)\u001b[0m\n\u001b[1;32m 745\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21msave\u001b[39m(\u001b[38;5;28mself\u001b[39m, f, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 746\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Save with a file-like object.\u001b[39;00m\n\u001b[1;32m 747\u001b[0m \n\u001b[1;32m 748\u001b[0m \u001b[38;5;124;03m save(f, _extra_files={})\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 752\u001b[0m \u001b[38;5;124;03m DO NOT confuse these two functions when it comes to the 'f' parameter functionality.\u001b[39;00m\n\u001b[1;32m 753\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 754\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_c\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
65+
"\u001b[0;31mRuntimeError\u001b[0m: Parent directory models does not exist."
66+
]
67+
}
68+
],
69+
"source": [
70+
"# sm.save(\"models/my_module.pt\")"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"id": "d85b6e83",
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"m = torch.jit.load(\"models/my_module.pt\")"
81+
]
82+
},
83+
{
84+
"cell_type": "code",
85+
"execution_count": 17,
86+
"id": "7d6255fd",
87+
"metadata": {},
88+
"outputs": [
89+
{
90+
"data": {
91+
"text/plain": [
92+
"RecursiveScriptModule(\n",
93+
" original_name=MyModule\n",
94+
" (linear): RecursiveScriptModule(original_name=Linear)\n",
95+
")"
96+
]
97+
},
98+
"execution_count": 17,
99+
"metadata": {},
100+
"output_type": "execute_result"
101+
}
102+
],
103+
"source": [
104+
"m"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 20,
110+
"id": "0d8ff397",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"x = torch.randn(10)"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": 21,
120+
"id": "ffe62563",
121+
"metadata": {},
122+
"outputs": [
123+
{
124+
"data": {
125+
"text/plain": [
126+
"tensor([-0.5386, 0.6545, 0.4650, -0.3320, 0.2735, 0.2796, -0.4549, 0.2646,\n",
127+
" -0.9322, -0.3031, -0.3441, -0.3761, 0.6457, 0.6456, -0.2478, -0.2270,\n",
128+
" 0.8485, 0.9710, -0.0596, 0.6110], grad_fn=<ViewBackward0>)"
129+
]
130+
},
131+
"execution_count": 21,
132+
"metadata": {},
133+
"output_type": "execute_result"
134+
}
135+
],
136+
"source": [
137+
"m(x)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"id": "51739d61",
144+
"metadata": {},
145+
"outputs": [],
146+
"source": []
147+
}
148+
],
149+
"metadata": {
150+
"kernelspec": {
151+
"display_name": ".venv",
152+
"language": "python",
153+
"name": "python3"
154+
},
155+
"language_info": {
156+
"codemirror_mode": {
157+
"name": "ipython",
158+
"version": 3
159+
},
160+
"file_extension": ".py",
161+
"mimetype": "text/x-python",
162+
"name": "python",
163+
"nbconvert_exporter": "python",
164+
"pygments_lexer": "ipython3",
165+
"version": "3.12.9"
166+
}
167+
},
168+
"nbformat": 4,
169+
"nbformat_minor": 5
170+
}
4.43 KB
Binary file not shown.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include <torch/torch.h>
2+
#include <torch/script.h>
3+
#include <iostream>
4+
#include <fstream>
5+
#include <string>
6+
#include <vector>
7+
8+
torch::jit::Module load_model(const std::string& model_path) {
9+
std::cout << "Loading model from path: " << model_path << std::endl;
10+
torch::jit::Module module;
11+
try {
12+
// Deserialize the ScriptModule from a file using torch::jit::load().
13+
module = torch::jit::load(model_path);
14+
} catch (const c10::Error& e) {
15+
std::cerr << "error loading the model\n" << e.msg();
16+
}
17+
std::cout << "Model loaded successfully." << std::endl;
18+
return module;
19+
20+
}
21+
22+
torch::Tensor run_model(torch::jit::Module& module, const torch::Tensor& input) {
23+
std::vector<torch::jit::IValue> inputs;
24+
inputs.push_back(input);
25+
26+
std::cout << "Input tensor: " << input.sizes() << std::endl;
27+
auto output = module.forward(inputs).toTensor();
28+
std::cout << "Model output: " << output.sizes() << std::endl;
29+
return output;
30+
}
31+
32+
33+
int main() {
34+
// Load the model
35+
std::string model_path = "style-transfer/models/my_module.pt";
36+
torch::jit::Module module = load_model(model_path);
37+
38+
// Create a random input tensor
39+
torch::Tensor input = torch::randn({10});
40+
41+
// Run the model
42+
torch::Tensor output = run_model(module, input);
43+
44+
// Print the output tensor
45+
std::cout << "Output tensor: " << output.sizes() << std::endl;
46+
47+
return 0;
48+
}
File renamed without changes.

examples/.gitignore

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
saved_models/*
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# fast-neural-style :city_sunrise: :rocket:
2+
3+
This repository contains a pytorch implementation of an algorithm for artistic style transfer. The algorithm can be used to mix the content of an image with the style of another image. For example, here is a photograph of a door arch rendered in the style of a stained glass painting.
4+
5+
The model uses the method described in [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) along with [Instance Normalization](https://arxiv.org/pdf/1607.08022.pdf). The saved-models for examples shown in the README can be downloaded from [here](https://www.dropbox.com/s/lrvwfehqdcxoza8/saved_models.zip?dl=0).
6+
7+
<p align="center">
8+
<img src="images/style-images/mosaic.jpg" height="200px">
9+
<img src="images/content-images/amber.jpg" height="200px">
10+
<img src="images/output-images/amber-mosaic.jpg" height="440px">
11+
</p>
12+
13+
## Requirements
14+
15+
The program is written in Python, and uses [pytorch](http://pytorch.org/), [scipy](https://www.scipy.org). A GPU is not necessary, but can provide a significant speed up especially for training a new model. Regular sized images can be styled on a laptop or desktop using saved models.
16+
17+
## Usage
18+
19+
Stylize image
20+
21+
```
22+
python neural_style/neural_style.py eval --content-image </path/to/content/image> --model </path/to/saved/model> --output-image </path/to/output/image> --accel
23+
```
24+
25+
- `--content-image`: path to content image you want to stylize.
26+
- `--model`: saved model to be used for stylizing the image (eg: `mosaic.pth`)
27+
- `--output-image`: path for saving the output image.
28+
- `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image)
29+
- `--accel`: use accelerator
30+
31+
Train model
32+
33+
```bash
34+
python neural_style/neural_style.py train --dataset </path/to/train-dataset> --style-image </path/to/style/image> --save-model-dir </path/to/save-model/folder> --epochs 2 --accel
35+
```
36+
37+
There are several command line arguments, the important ones are listed below
38+
39+
- `--dataset`: path to training dataset, the path should point to a folder containing another folder with all the training images. I used COCO 2014 Training images dataset [80K/13GB] [(download)](https://cocodataset.org/#download).
40+
- `--style-image`: path to style-image.
41+
- `--save-model-dir`: path to folder where trained model will be saved.
42+
- `--accel`: use accelerator.
43+
44+
If `--accel` argument is given, pytorch will search for available hardware acceleration device and attempt to use it. This example is known to work on CUDA, MPS and XPU devices.
45+
46+
Refer to `neural_style/neural_style.py` for other command line arguments. For training new models you might have to tune the values of `--content-weight` and `--style-weight`. The mosaic style model shown above was trained with `--content-weight 1e5` and `--style-weight 1e10`. The remaining 3 models were also trained with similar order of weight parameters with slight variation in the `--style-weight` (`5e10` or `1e11`).
47+
48+
## Models
49+
50+
Models for the examples shown below can be downloaded from [here](https://www.dropbox.com/s/lrvwfehqdcxoza8/saved_models.zip?dl=0) or by running the script `download_saved_models.py`.
51+
52+
<div align='center'>
53+
<img src='images/content-images/amber.jpg' height="174px">
54+
</div>
55+
56+
<div align='center'>
57+
<img src='images/style-images/mosaic.jpg' height="174px">
58+
<img src='images/output-images/amber-mosaic.jpg' height="174px">
59+
<img src='images/output-images/amber-candy.jpg' height="174px">
60+
<img src='images/style-images/candy.jpg' height="174px">
61+
<br>
62+
<img src='images/style-images/rain-princess-cropped.jpg' height="174px">
63+
<img src='images/output-images/amber-rain-princess.jpg' height="174px">
64+
<img src='images/output-images/amber-udnie.jpg' height="174px">
65+
<img src='images/style-images/udnie.jpg' height="174px">
66+
</div>

0 commit comments

Comments
 (0)