Skip to content

Commit ad00d6a

Browse files
committed
model conversion to onnx format
1 parent be92d81 commit ad00d6a

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
/.idea
33
*.pth
44
*.onnx
5+
saved_*.png

README.md

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Differences between original repository and fork:
44

55
* Compatibility with PyTorch >=2.0. (🔥)
6+
* Original pretrained models and converted ONNX models from GitHub [releases page](https://github.com/clibdev/colorization/releases). (🔥)
7+
* Model conversion to ONNX format using the [export.py](export.py) file. (🔥)
68
* Installation with updated [requirements.txt](requirements.txt) file.
79
* Additional command line options for specifying model weights in the [demo_release.py](demo_release.py) file.
810

@@ -12,8 +14,25 @@ Differences between original repository and fork:
1214
pip install -r requirements.txt
1315
```
1416

17+
# Pretrained models
18+
19+
| Name | Link |
20+
|-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
21+
| ECCV 16 | [PyTorch](https://github.com/clibdev/colorization/releases/latest/download/colorization_release_v2-9b330a0b.pth), [ONNX](https://github.com/clibdev/colorization/releases/latest/download/colorization_release_v2-9b330a0b.onnx) |
22+
| SIGGRAPH 17 | [PyTorch](https://github.com/clibdev/colorization/releases/latest/download/siggraph17-df00044c.pth), [ONNX](https://github.com/clibdev/colorization/releases/latest/download/siggraph17-df00044c.onnx) |
23+
1524
# Inference
1625

1726
```shell
1827
python demo_release.py --eccv16_weights colorization_release_v2-9b330a0b.pth --siggraph17_weights siggraph17-df00044c.pth -i imgs/ansel_adams3.jpg
1928
```
29+
30+
# Export to ONNX format
31+
32+
```shell
33+
pip install onnx
34+
```
35+
```shell
36+
python export.py --weights colorization_release_v2-9b330a0b.pth --net_type eccv16
37+
python export.py --weights siggraph17-df00044c.pth --net_type siggraph17
38+
```

export.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import argparse
2+
import os
3+
from colorizers import *
4+
5+
6+
if __name__ == '__main__':
7+
parser = argparse.ArgumentParser()
8+
9+
parser.add_argument('--weights', type=str, default='./colorization_release_v2-9b330a0b.pth', help='Weights path')
10+
parser.add_argument('--net_type', default='eccv16', type=str, help='The network architecture: eccv16 or siggraph17')
11+
parser.add_argument('--device', default='cpu', type=str, help='cuda:0 or cpu')
12+
args = parser.parse_args()
13+
14+
if not os.path.exists(args.weights):
15+
print('Cannot find weights: {0}'.format(args.weights))
16+
exit()
17+
18+
if args.net_type == 'eccv16':
19+
colorizer = eccv16(pretrained=True, weights=args.weights)
20+
elif args.net_type == 'siggraph17':
21+
colorizer = siggraph17(pretrained=True, weights=args.weights)
22+
else:
23+
print('Unsupported network type: {0}'.format(args.net_type))
24+
exit()
25+
26+
colorizer.eval()
27+
colorizer.to(args.device)
28+
29+
model_path = os.path.splitext(args.weights)[0] + '.onnx'
30+
print(model_path)
31+
32+
dummy_input = torch.randn(1, 1, 256, 256).to(args.device)
33+
torch.onnx.export(
34+
colorizer,
35+
dummy_input,
36+
model_path,
37+
verbose=False,
38+
input_names=['input'],
39+
output_names=['output'],
40+
opset_version=18
41+
)

requirements.txt

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
torch>=2.0.0
22
scikit-image>=0.21.0
33
matplotlib>=3.7.0
4+
5+
### Optional
6+
# onnx>=1.14.0

0 commit comments

Comments
 (0)