-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_utils.py
More file actions
122 lines (99 loc) · 3.54 KB
/
data_utils.py
File metadata and controls
122 lines (99 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def process_image(
    train=False,
    thumbnail_size=256,
    image_size=224,
    mean=None,
    std=None,
    rotation_angle=30,
):
    """
    Build the torchvision transform pipeline for training or inference.

    Both pipelines end with ``ToTensor`` followed by ``Normalize(mean, std)``.
    They differ only in how the image is brought to ``image_size``:

    * training: random rotation, random resized crop and random horizontal
      flip (data augmentation);
    * inference: deterministic resize of the shortest side to
      ``thumbnail_size`` followed by a center crop.

    Args:
        train (bool): build the augmenting training pipeline when True,
            the deterministic inference pipeline otherwise.
        thumbnail_size (int) - default 256: target size of the image's
            shortest side before center-cropping. Used only when
            ``train`` is False.
        image_size (int) - default 224: side length of the crop fed to the
            model. 224 is the input size most ``torchvision.models``
            networks are trained on.
        mean (sequence of float) - default (0.485, 0.456, 0.406): per-channel
            mean used for normalization (ImageNet statistics by default).
        std (sequence of float) - default (0.229, 0.224, 0.225): per-channel
            standard deviation used for normalization (ImageNet statistics
            by default).
        rotation_angle (int) - default 30: passed to ``RandomRotation`` as
            the maximum rotation in degrees. Used only when ``train`` is True.

    Returns:
        transforms.Compose: the composed transform to apply to each image.
    """
    # Fall back to ImageNet statistics when the caller gives none.
    norm_mean = (0.485, 0.456, 0.406) if mean is None else mean
    norm_std = (0.229, 0.224, 0.225) if std is None else std

    if train:
        # Random geometric augmentation so each epoch sees varied crops.
        steps = [
            transforms.RandomRotation(rotation_angle),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # Deterministic resize + crop so inference is repeatable.
        steps = [
            transforms.Resize(thumbnail_size),
            transforms.CenterCrop(image_size),
        ]

    # Common tail: convert to a tensor, then normalize per channel.
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(norm_mean, norm_std))
    return transforms.Compose(steps)
def get_datasets(data_dir_dict):
    """
    Build the training and validation image datasets.

    Each split is loaded with ``torchvision.datasets.ImageFolder`` from the
    directory stored under the same key in ``data_dir_dict``. The "train"
    split gets the augmenting transform from ``process_image(train=True)``.
    The "valid" split gets the deterministic ``process_image(train=False)``.

    Args:
        data_dir_dict (dict): mapping with keys "train" and "valid", each
            pointing to the root folder of that split's images.

    Returns:
        dict: ``{"train": ImageFolder, "valid": ImageFolder}``.
    """
    # Build the "train" split first, then "valid", so the returned dict
    # keeps the order "train", "valid".
    return {
        split: datasets.ImageFolder(
            data_dir_dict[split],
            transform=process_image(train=(split == "train")),
        )
        for split in ("train", "valid")
    }
def get_dataloaders(datasets_dict, batch_size):
    """
    Wrap the training and validation datasets in DataLoaders.

    Only the training loader shuffles. The validation loader iterates in
    dataset order, which has three effects:
      * evaluation runs are repeatable;
      * batch positions line up with dataset indices, so predictions can be
        matched back to samples;
      * no work is spent permuting data whose order does not affect the
        aggregate metrics.

    Args:
        datasets_dict (dict): mapping with keys "train" and "valid" holding
            the corresponding datasets.
        batch_size (int): number of samples per batch in both loaders.

    Returns:
        dict: ``{"train": DataLoader, "valid": DataLoader}``.
    """
    dataloaders = {
        "train": DataLoader(
            datasets_dict["train"], batch_size=batch_size, shuffle=True
        ),
        # Validation data must not be shuffled: order-preserving evaluation.
        "valid": DataLoader(
            datasets_dict["valid"], batch_size=batch_size, shuffle=False
        ),
    }
    return dataloaders