-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathtest_paired_image_dataset.py
70 lines (55 loc) · 2.18 KB
/
test_paired_image_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import math
import os
import torchvision.utils
from basicsr.data import build_dataloader, build_dataset
def main(mode='folder'):
    """Smoke-test the paired image dataset by dumping a few batches to disk.

    Builds the option dict for the requested storage mode, creates the
    dataset and dataloader via ``build_dataset`` / ``build_dataloader``, and
    saves the first six LQ and GT batches as image grids under ``tmp/``
    (``tmp/lq_000.png``, ``tmp/gt_000.png``, ...) for visual inspection.
    The paths of every saved batch are also printed.

    Args:
        mode: There are three modes: 'lmdb', 'folder', 'meta_info_file'.

    Raises:
        ValueError: If ``mode`` is not one of the three supported modes.
    """
    opt = {
        'dist': False,
        'phase': 'train',
        'name': 'DIV2K',
        'type': 'PairedImageDataset',
    }

    # Storage-specific options. All dataset paths are relative to the
    # current working directory.
    if mode == 'folder':
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
        opt['filename_tmpl'] = '{}'
        opt['io_backend'] = dict(type='disk')
    elif mode == 'meta_info_file':
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
        opt['meta_info_file'] = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'  # noqa:E501
        opt['filename_tmpl'] = '{}'
        opt['io_backend'] = dict(type='disk')
    elif mode == 'lmdb':
        opt['dataroot_gt'] = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'
        opt['dataroot_lq'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'  # noqa:E501
        opt['io_backend'] = dict(type='lmdb')
    else:
        # Fail fast, before touching the filesystem: otherwise the missing
        # dataroot/io_backend keys only surface later, inside dataset
        # construction, with a much less helpful error.
        raise ValueError(
            f'Unsupported mode: {mode!r}. '
            "Expected one of 'lmdb', 'folder', 'meta_info_file'.")

    # Options shared by every mode.
    opt['gt_size'] = 128
    opt['use_hflip'] = True
    opt['use_rot'] = True
    opt['num_worker_per_gpu'] = 2
    opt['batch_size_per_gpu'] = 16
    opt['scale'] = 4
    opt['dataset_enlarge_ratio'] = 1

    os.makedirs('tmp', exist_ok=True)

    dataset = build_dataset(opt)
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=opt['dist'], sampler=None)

    # Images per grid row; with a batch of 16 this gives a 4x4 grid.
    nrow = int(math.sqrt(opt['batch_size_per_gpu']))
    padding = 2 if opt['phase'] == 'train' else 0

    # Only the first few batches are needed for a visual check.
    num_batches = 6

    print('start...')
    for i, data in enumerate(data_loader):
        if i >= num_batches:
            break
        print(i)

        lq = data['lq']
        gt = data['gt']
        lq_path = data['lq_path']
        gt_path = data['gt_path']
        print(lq_path, gt_path)

        torchvision.utils.save_image(lq, f'tmp/lq_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
        torchvision.utils.save_image(gt, f'tmp/gt_{i:03d}.png', nrow=nrow, padding=padding, normalize=False)
# Script entry point: run the smoke test with the default mode only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()