This repository was archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 436
Expand file tree
/
Copy pathdirectory_iterator.py
More file actions
173 lines (159 loc) · 7.54 KB
/
directory_iterator.py
File metadata and controls
173 lines (159 loc) · 7.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Utilities for real-time data augmentation on image data.
"""
import multiprocessing.pool
import os
import numpy as np
from .iterator import BatchFromFilesMixin, Iterator
from .utils import _list_valid_filenames_in_directory
class DirectoryIterator(BatchFromFilesMixin, Iterator):
    """Iterator capable of reading images from a directory on disk.

    # Arguments
        directory: string, path to the directory to read images from.
            Each subdirectory in this directory will be
            considered to contain images from one class,
            or alternatively you could specify class subdirectories
            via the `classes` argument.
        image_data_generator: Instance of `ImageDataGenerator`
            to use for random transformations and normalization.
        target_size: tuple of integers, dimensions to resize input images to.
        color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
            Color mode to read images.
        classes: Optional list of strings, names of subdirectories
            containing images from each class (e.g. `["dogs", "cats"]`).
            It will be computed automatically (sorted names of the
            subdirectories of `directory`) if not set or empty.
        class_mode: Mode for yielding the targets:
            `"binary"`: binary targets (if there are only two classes),
            `"categorical"`: categorical targets,
            `"sparse"`: integer targets,
            `"input"`: targets are images identical to input images (mainly
                used to work with autoencoders),
            `None`: no targets get yielded (only input images are yielded).
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
            If set to False, sorts the data in alphanumeric order.
        seed: Random seed for data shuffling.
        data_format: String, one of `channels_first`, `channels_last`.
        save_to_dir: Optional directory where to save the pictures
            being yielded, in a viewable format. This is useful
            for visualizing the random transformations being
            applied, for debugging purposes.
        save_prefix: String prefix to use for saving sample
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
        follow_links: Boolean, follow symbolic links to subdirectories
        subset: Subset of data (`"training"` or `"validation"`) if
            validation_split is set in ImageDataGenerator.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image.
            Supported methods are "nearest", "bilinear", and "bicubic".
            If PIL version 1.1.3 or newer is installed, "lanczos" is also
            supported. If PIL version 3.4.0 or newer is installed, "box" and
            "hamming" are also supported. By default, "nearest" is used.
        keep_aspect_ratio: Boolean, whether to resize images to a target size
            without aspect ratio distortion. The image is cropped in the center
            with target aspect ratio before resizing.
        dtype: Dtype to use for generated arrays.
        verbose: 1 or 0. Defaults to 1. 0=silent mode,
            1=prints total number of samples and classes.

    # Raises
        ValueError: if `class_mode` is not one of `allowed_class_modes`.
    """
    allowed_class_modes = {'categorical', 'binary', 'sparse', 'input', None}

    def __new__(cls, *args, **kwargs):
        """Create the instance, adding TF's `Sequence` to the bases if possible.

        When `tensorflow.keras.utils.Sequence` can be imported and is not
        already among `cls.__bases__`, it is appended to them. If the import
        fails, the class is left unchanged.
        """
        try:
            from tensorflow.keras.utils import Sequence as TFSequence
            if TFSequence not in cls.__bases__:
                cls.__bases__ = cls.__bases__ + (TFSequence,)
        except ImportError:
            # TensorFlow is optional here; proceed without the extra base.
            pass

        return super(DirectoryIterator, cls).__new__(cls)

    def __init__(self,
                 directory,
                 image_data_generator,
                 target_size=(256, 256),
                 color_mode='rgb',
                 classes=None,
                 class_mode='categorical',
                 batch_size=32,
                 shuffle=True,
                 seed=None,
                 data_format='channels_last',
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png',
                 follow_links=False,
                 subset=None,
                 interpolation='nearest',
                 keep_aspect_ratio=False,
                 dtype='float32',
                 verbose=1):
        """Index the images under `directory` and initialise the iterator.

        See the class docstring for the meaning of each argument.
        """
        super(DirectoryIterator, self).set_processing_attrs(
            image_data_generator,
            target_size,
            color_mode,
            data_format,
            save_to_dir,
            save_prefix,
            save_format,
            subset,
            interpolation,
            keep_aspect_ratio)
        self.directory = directory
        self.classes = classes
        if class_mode not in self.allowed_class_modes:
            raise ValueError('Invalid class_mode: {}; expected one of: {}'
                             .format(class_mode, self.allowed_class_modes))
        self.class_mode = class_mode
        self.dtype = dtype

        # First, count the number of samples and classes.
        self.samples = 0
        if not classes:
            # No explicit class list: every subdirectory is a class,
            # in sorted order.
            classes = [
                subdir for subdir in sorted(os.listdir(directory))
                if os.path.isdir(os.path.join(directory, subdir))
            ]
        self.num_classes = len(classes)
        self.class_indices = {
            class_name: index for index, class_name in enumerate(classes)
        }

        # Second, build an index of the images
        # in the different class subfolders.
        self.filenames = []
        per_dir_classes = []
        pool = multiprocessing.pool.ThreadPool()
        try:
            # One listing task per class directory.
            # NOTE(review): `white_list_formats` and `split` are presumably
            # set by `set_processing_attrs` above - confirm in the mixin.
            results = [
                pool.apply_async(
                    _list_valid_filenames_in_directory,
                    (os.path.join(directory, subdir),
                     self.white_list_formats,
                     self.split,
                     self.class_indices,
                     follow_links))
                for subdir in classes
            ]
            # Collect in submission order so labels and filenames stay
            # aligned with the order of `classes`.
            for res in results:
                dir_classes, dir_filenames = res.get()
                per_dir_classes.append(dir_classes)
                self.filenames += dir_filenames
        finally:
            # Always release the worker threads, even when a listing task
            # raised and `res.get()` re-raised its exception here.
            pool.close()
            pool.join()

        self.samples = len(self.filenames)
        self.classes = np.zeros((self.samples,), dtype='int32')
        offset = 0
        for dir_classes in per_dir_classes:
            self.classes[offset:offset + len(dir_classes)] = dir_classes
            offset += len(dir_classes)

        if verbose != 0:
            print('Found %d images belonging to %d classes.' %
                  (self.samples, self.num_classes))

        self._filepaths = [
            os.path.join(self.directory, fname) for fname in self.filenames
        ]
        super(DirectoryIterator, self).__init__(self.samples,
                                                batch_size,
                                                shuffle,
                                                seed)

    @property
    def filepaths(self):
        """List of `directory` joined with each entry of `filenames`."""
        return self._filepaths

    @property
    def labels(self):
        """The `classes` array (int32 class index per indexed file)."""
        return self.classes

    @property  # mixin needs this property to work
    def sample_weight(self):
        """Always `None`: no sample weights will be returned."""
        return None