-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathdata_loader.py
89 lines (82 loc) · 2.89 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import importlib
import torch
from misc.utils import horovod
# Module-level Horovod handle shared by every loader built in this file.
# NOTE(review): ``horovod()`` comes from misc.utils; presumably it returns
# horovod.torch or a single-process stand-in - confirm there. init() runs at
# import time so that get_loader can later query hvd.size() / hvd.rank().
hvd = horovod()
hvd.init()
# ==================================================================#
# == LOADER ==#
# ==================================================================#
def _build_transform(mode_data, image_size, dataset, mode):
    """Build the torchvision preprocessing pipeline used by ``get_loader``.

    Args:
        mode_data: data variant name; ``'faces'`` selects plain resizing
            even in training mode.
        image_size: side length of the square output image.
        dataset: dataset name; ``'RafD'`` disables random horizontal flips.
        mode: split name; only ``'train'`` enables random augmentation.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` followed by
        ``Normalize`` with mean/std 0.5, i.e. ToTensor's [0, 1] range is
        mapped to [-1, 1].
    """
    # Image.ANTIALIAS was only an alias of Image.LANCZOS and was removed in
    # Pillow 10; LANCZOS is the same filter and exists in all Pillow versions.
    resample = Image.LANCZOS
    if mode_data == 'faces' or mode != 'train':
        steps = [
            transforms.Resize((image_size, image_size),
                              interpolation=resample)
        ]
    else:
        # Upscale by ~10% and then randomly crop back down to image_size.
        # (RafD/EmotionNet used to have a separate, byte-identical branch.)
        window = int(image_size / 10)
        steps = [
            transforms.Resize((image_size + window, image_size + window),
                              interpolation=resample),
            transforms.RandomResizedCrop(
                image_size, scale=(0.7, 1.0), ratio=(0.8, 1.2)),
        ]
    if dataset != 'RafD' and mode == 'train':
        steps.append(transforms.RandomHorizontalFlip())
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    return transforms.Compose(steps)


def get_loader(mode_data,
               image_size,
               batch_size,
               dataset='BP4D',
               mode='train',
               shuffling=False,
               num_workers=0,
               HOROVOD=False,
               **kwargs):
    """Build a ``DataLoader`` for one of the project's dataset classes.

    The dataset class is resolved dynamically: module ``datasets.<dataset>``
    must define a class named ``<dataset>``. It is instantiated as
    ``cls(image_size, mode_data, transform, mode, shuffling=..., verbose=...,
    **kwargs)``.

    Args:
        mode_data: data variant forwarded to the dataset; ``'faces'`` also
            disables the resize-and-random-crop augmentation.
        image_size: side length of the square images produced by the
            transform.
        batch_size: DataLoader batch size.
        dataset: dataset name (both the module and the class name under the
            ``datasets`` package).
        mode: split name. ``'train'`` enables augmentation, forces the
            dataset's ``shuffling`` flag, makes rank 0 verbose and, with more
            than one Horovod worker, shards data with ``DistributedSampler``.
        shuffling: forwarded to the dataset constructor (OR-ed with
            ``mode == 'train'``).
        num_workers: number of DataLoader worker processes.
        HOROVOD: unused here; kept so existing call sites keep working.
        **kwargs: extra keyword arguments for the dataset constructor.

    Returns:
        A ``torch.utils.data.DataLoader``. Except for distributed training it
        iterates the whole dataset in order (``shuffle=False``); presumably
        any shuffling is done by the dataset via its ``shuffling`` flag.
    """
    transform = _build_transform(mode_data, image_size, dataset, mode)
    dataset_cls = getattr(
        importlib.import_module('datasets.{}'.format(dataset)), dataset)
    data = dataset_cls(
        image_size,
        mode_data,
        transform,
        mode,
        shuffling=shuffling or mode == 'train',
        verbose=mode == 'train' and hvd.rank() == 0,
        **kwargs)

    # Only distributed *training* is sharded across workers. For evaluation
    # every worker must see the full dataset in order; the previous
    # DistributedSampler(num_replicas=1, rank=0) did cover all samples but
    # silently shuffled them (its default is shuffle=True), unlike the
    # single-process path.
    # NOTE(review): this function never calls sampler.set_epoch(); confirm
    # the training loop does, otherwise every epoch reuses one permutation.
    sampler = None
    if mode == 'train' and hvd.size() != 1:
        sampler = torch.utils.data.distributed.DistributedSampler(
            data, num_replicas=hvd.size(), rank=hvd.rank())
    return DataLoader(
        dataset=data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=sampler)