-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathdataloader.py
More file actions
96 lines (73 loc) · 3.46 KB
/
dataloader.py
File metadata and controls
96 lines (73 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import cv2
import torch
import numpy as np
from torch.utils import data
import random
# from config import config
from utils.transforms import generate_random_crop_pos, random_crop_pad_to_shape, normalize
def random_mirror(rgb, gt, modal_x):
    """With probability 0.5, horizontally flip all three modalities together.

    The same coin flip is applied to the RGB image, the label map and the
    extra modality so they stay spatially aligned.
    """
    do_flip = random.random() >= 0.5
    if do_flip:
        rgb, gt, modal_x = (cv2.flip(arr, 1) for arr in (rgb, gt, modal_x))
    return rgb, gt, modal_x
def random_scale(rgb, gt, modal_x, scales):
    """Resize all three modalities by a factor drawn uniformly from `scales`.

    Returns the resized (rgb, gt, modal_x) plus the chosen scale factor.
    """
    factor = random.choice(scales)
    new_h = int(rgb.shape[0] * factor)
    new_w = int(rgb.shape[1] * factor)
    size = (new_w, new_h)  # cv2.resize takes (width, height)
    rgb = cv2.resize(rgb, size, interpolation=cv2.INTER_LINEAR)
    # Nearest-neighbour for the label map so class ids are never blended.
    gt = cv2.resize(gt, size, interpolation=cv2.INTER_NEAREST)
    modal_x = cv2.resize(modal_x, size, interpolation=cv2.INTER_LINEAR)
    return rgb, gt, modal_x, factor
class TrainPre(object):
    """Training-time preprocessing pipeline.

    Applies, in order: random horizontal mirror, optional random scaling,
    per-channel normalization, random crop/pad to the configured size, and
    HWC -> CHW transposition for the image-like inputs.
    """

    def __init__(self, config):
        self.config = config
        self.norm_mean = config.norm_mean
        self.norm_std = config.norm_std

    def __call__(self, rgb, gt, modal_x):
        rgb, gt, modal_x = random_mirror(rgb, gt, modal_x)

        scale_array = self.config.train_scale_array
        if scale_array is not None:
            rgb, gt, modal_x, _ = random_scale(rgb, gt, modal_x, scale_array)

        rgb = normalize(rgb, self.norm_mean, self.norm_std)
        modal_x = normalize(modal_x, self.norm_mean, self.norm_std)

        crop_size = (self.config.image_height, self.config.image_width)
        crop_pos = generate_random_crop_pos(rgb.shape[:2], crop_size)

        # Pad value 255 for labels (presumably the ignore index — confirm
        # against the loss config); 0 for image-like inputs.
        p_rgb, _ = random_crop_pad_to_shape(rgb, crop_pos, crop_size, 0)
        p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 255)
        p_modal_x, _ = random_crop_pad_to_shape(modal_x, crop_pos, crop_size, 0)

        # HWC -> CHW so tensors match PyTorch's channel-first convention.
        return p_rgb.transpose(2, 0, 1), p_gt, p_modal_x.transpose(2, 0, 1)
class ValPre(object):
    """Validation-time preprocessing: an identity transform.

    No augmentation is applied at evaluation; inputs are returned untouched.
    """

    def __call__(self, rgb, gt, modal_x):
        # Deliberate no-op: keep the same callable interface as TrainPre.
        return (rgb, gt, modal_x)
def get_train_loader(engine, dataset, config=None):
    """Build the training DataLoader and (optionally) a distributed sampler.

    Args:
        engine: training engine; ``engine.distributed`` and
            ``engine.world_size`` are read to configure distributed training.
        dataset: dataset class, instantiated as
            ``dataset(data_setting, "train", preprocess, length)``.
        config: experiment configuration object. Required in practice despite
            the ``None`` default — it is dereferenced unconditionally.

    Returns:
        (train_loader, train_sampler); ``train_sampler`` is None when not
        running distributed.
    """
    # NOTE: the original dict listed 'class_names' twice; the first entry was
    # dead (later duplicate key wins in a Python dict literal) — kept once.
    data_setting = {'rgb_root': config.rgb_root_folder,
                    'rgb_format': config.rgb_format,
                    'gt_root': config.gt_root_folder,
                    'gt_format': config.gt_format,
                    'transform_gt': config.gt_transform,
                    'x_root': config.x_root_folder,
                    'x_format': config.x_format,
                    'x_single_channel': config.x_is_single_channel,
                    'class_names': config.class_names,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}

    train_preprocess = TrainPre(config)
    # Dataset length = iterations per epoch * global batch size.
    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.batch_size * config.niters_per_epoch)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        # DistributedSampler does the shuffling, so the loader must not;
        # per-rank batch size is the global batch split across processes.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler