This repository was archived by the owner on Mar 10, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathload.py
More file actions
131 lines (110 loc) · 4.36 KB
/
load.py
File metadata and controls
131 lines (110 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.keras import layers
def parse_imagenet_example(img_size, crop_to_aspect_ratio):
    """Build a parser for serialized ImageNet TFRecord examples.

    Args:
        img_size: optional target size, indexed as `(width, height)`. When
            falsy, decoded images are returned without resizing.
        crop_to_aspect_ratio: forwarded to `layers.Resizing` when `img_size`
            is provided.

    Returns:
        A function mapping one serialized example to an `(image, label)`
        pair, where `image` is decoded with 3 channels and `label` is a
        one-hot vector of depth 1000.
    """
    # Build the resizing layer once so that every parsed example shares it.
    resizing = (
        layers.Resizing(
            width=img_size[0],
            height=img_size[1],
            crop_to_aspect_ratio=crop_to_aspect_ratio,
        )
        if img_size
        else None
    )

    def apply(example):
        image_key, label_key = "image/encoded", "image/class/label"
        feature_spec = {
            image_key: tf.io.FixedLenFeature((), tf.string, ""),
            label_key: tf.io.FixedLenFeature([], tf.int64, -1),
        }
        features = tf.io.parse_single_example(example, feature_spec)

        # Image: scalar JPEG bytes -> 3-channel tensor, optionally resized.
        encoded = tf.reshape(features[image_key], shape=[])
        image = tf.io.decode_jpeg(encoded, channels=3)
        if resizing:
            image = resizing(image)

        # Label: the stored value is shifted down by one, then one-hot
        # encoded.
        raw_label = tf.reshape(features[label_key], shape=())
        class_index = tf.cast(raw_label, dtype=tf.int32) - 1
        return image, tf.one_hot(class_index, 1000)

    return apply
def load(
    split,
    tfrecord_path,
    batch_size=None,
    shuffle=True,
    shuffle_buffer=None,
    reshuffle_each_iteration=False,
    img_size=None,
    crop_to_aspect_ratio=True,
):
    """Loads the ImageNet dataset from TFRecords

    Usage:
    ```python
    dataset = keras_cv.datasets.imagenet.load(
        split="train",
        tfrecord_path="gs://my-bucket/imagenet-tfrecords",
        batch_size=64,
        img_size=(224, 224),
    )
    ```

    Args:
        split: the split to load. Must be one of "train" or "validation".
        tfrecord_path: the path to your preprocessed ImageNet TFRecords.
            See keras_cv/datasets/imagenet/README.md for preprocessing
            instructions.
        batch_size: how many instances to include in batches after loading.
            Should only be specified if img_size is specified (so that images
            can be resized to the same size before batching).
        shuffle: whether to shuffle the dataset, defaults to `True`. When
            `True`, either `batch_size` or `shuffle_buffer` must be provided.
        shuffle_buffer: the size of the buffer to use in shuffling. Defaults
            to `8 * batch_size` when not provided.
        reshuffle_each_iteration: whether to reshuffle the dataset on every
            epoch, defaults to `False`.
        img_size: the size to resize the images to, indexed as
            `(width, height)`. When None, this indicates that images should
            not be resized. Defaults to `None`.
        crop_to_aspect_ratio: passed to the resizing layer when `img_size` is
            specified, defaults to `True`.

    Returns:
        tf.data.Dataset containing ImageNet. Each entry is an
        `(image, label)` tuple where image is a Tensor of shape [H, W, 3]
        and label is a one-hot Tensor of shape [1000]. When `batch_size` is
        specified, entries are batched.

    Raises:
        ValueError: if `split` is not "train" or "validation", if
            `batch_size` is specified without `img_size`, or if
            `shuffle=True` and neither `batch_size` nor `shuffle_buffer` is
            provided.
    """
    # Validate all arguments up front so that mistakes surface before any
    # dataset pipeline is constructed.
    if split not in ("train", "validation"):
        # An unknown split used to fall through to the validation shard
        # count and only fail later with a missing-file error.
        raise ValueError(
            '`split` must be one of "train" or "validation". '
            f"Received: split={split!r}"
        )

    if batch_size is not None and img_size is None:
        raise ValueError(
            "Batching can only be performed if images are resized."
        )

    if shuffle and not batch_size and not shuffle_buffer:
        raise ValueError(
            "If `shuffle=True`, either a `batch_size` or `shuffle_buffer` "
            "must be provided to `keras_cv.datasets.imagenet.load().`"
        )

    # Shard files are named `<split>-<index>-of-<count>`, zero-padded to five
    # digits.
    num_splits = 1024 if split == "train" else 128
    filenames = [
        f"{tfrecord_path}/{split}-{i:05d}-of-{num_splits:05d}"
        for i in range(num_splits)
    ]

    dataset = tf.data.TFRecordDataset(
        filenames=filenames, num_parallel_reads=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        parse_imagenet_example(img_size, crop_to_aspect_ratio),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    if shuffle:
        # Fall back to a buffer of eight batches when no explicit size is
        # given.
        shuffle_buffer = shuffle_buffer or 8 * batch_size
        dataset = dataset.shuffle(
            shuffle_buffer, reshuffle_each_iteration=reshuffle_each_iteration
        )

    if batch_size is not None:
        dataset = dataset.batch(batch_size)

    return dataset.prefetch(tf.data.AUTOTUNE)