Skip to content

Commit 5dd95fc

Browse files
sampathweb authored and tensorflower-gardener committed
Allow passing a custom cache_dir to tf.keras.datasets.load_data. This is helpful when the default location ~/.keras in home directory has limited disk space.
PiperOrigin-RevId: 713015638
1 parent 6f44991 commit 5dd95fc

21 files changed

+85
-28
lines changed

tf_keras/api/golden/v1/tensorflow.keras.datasets.boston_housing.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
5+
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
66
}
77
}

tf_keras/api/golden/v1/tensorflow.keras.datasets.cifar10.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[], varargs=None, keywords=None, defaults=None"
5+
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
66
}
77
}

tf_keras/api/golden/v1/tensorflow.keras.datasets.cifar100.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
5+
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
66
}
77
}

tf_keras/api/golden/v1/tensorflow.keras.datasets.fashion_mnist.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[], varargs=None, keywords=None, defaults=None"
5+
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
66
}
77
}

tf_keras/api/golden/v1/tensorflow.keras.datasets.imdb.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ tf_module {
66
}
77
member_method {
88
name: "load_data"
9-
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
9+
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
1010
}
1111
}

tf_keras/api/golden/v1/tensorflow.keras.datasets.mnist.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
5+
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
66
}
77
}

tf_keras/api/golden/v1/tensorflow.keras.datasets.reuters.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ tf_module {
1010
}
1111
member_method {
1212
name: "load_data"
13-
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
13+
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
1414
}
1515
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.boston_housing.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
5+
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
66
}
77
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.cifar10.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[], varargs=None, keywords=None, defaults=None"
5+
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
66
}
77
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.cifar100.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
5+
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
66
}
77
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.fashion_mnist.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[], varargs=None, keywords=None, defaults=None"
5+
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
66
}
77
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.imdb.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ tf_module {
66
}
77
member_method {
88
name: "load_data"
9-
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
9+
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
1010
}
1111
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.mnist.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
22
tf_module {
33
member_method {
44
name: "load_data"
5-
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
5+
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
66
}
77
}

tf_keras/api/golden/v2/tensorflow.keras.datasets.reuters.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ tf_module {
1010
}
1111
member_method {
1212
name: "load_data"
13-
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
13+
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
1414
}
1515
}

tf_keras/datasets/boston_housing.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
"""Boston housing price regression dataset."""
1616

17+
import os
18+
1719
import numpy as np
1820

1921
from tf_keras.utils.data_utils import get_file
@@ -23,7 +25,9 @@
2325

2426

2527
@keras_export("keras.datasets.boston_housing.load_data")
26-
def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
28+
def load_data(
29+
path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
30+
):
2731
"""Loads the Boston Housing dataset.
2832
2933
This is a dataset taken from the StatLib library which is maintained at
@@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
4347
[StatLib website](http://lib.stat.cmu.edu/datasets/boston).
4448
4549
Args:
46-
path: path where to cache the dataset locally
47-
(relative to `~/.keras/datasets`).
50+
path: path where to cache the dataset locally (relative to
51+
`~/.keras/datasets`).
4852
test_split: fraction of the data to reserve as test set.
49-
seed: Random seed for shuffling the data
50-
before computing the test split.
53+
seed: Random seed for shuffling the data before computing the test split.
54+
cache_dir: directory where to cache the dataset locally. When None,
55+
defaults to `~/.keras/datasets`.
5156
5257
Returns:
5358
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
6469
origin_folder = (
6570
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
6671
)
72+
if cache_dir:
73+
cache_dir = os.path.expanduser(cache_dir)
74+
os.makedirs(cache_dir, exist_ok=True)
6775
path = get_file(
6876
path,
6977
origin=origin_folder + "boston_housing.npz",
7078
file_hash=( # noqa: E501
7179
"f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
7280
),
81+
cache_dir=cache_dir,
7382
)
7483
with np.load(path, allow_pickle=True) as f:
7584
x = f["x"]

tf_keras/datasets/cifar10.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
@keras_export("keras.datasets.cifar10.load_data")
30-
def load_data():
30+
def load_data(cache_dir=None):
3131
"""Loads the CIFAR10 dataset.
3232
3333
This is a dataset of 50,000 32x32 color training images and 10,000 test
@@ -49,6 +49,10 @@ def load_data():
4949
| 8 | ship |
5050
| 9 | truck |
5151
52+
Args:
53+
cache_dir: directory where to cache the dataset locally. When None,
54+
defaults to `~/.keras/datasets`.
55+
5256
Returns:
5357
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
5458
@@ -78,13 +82,17 @@ def load_data():
7882
"""
7983
dirname = "cifar-10-batches-py"
8084
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
85+
if cache_dir:
86+
cache_dir = os.path.expanduser(cache_dir)
87+
os.makedirs(cache_dir, exist_ok=True)
8188
path = get_file(
8289
dirname,
8390
origin=origin,
8491
untar=True,
8592
file_hash=( # noqa: E501
8693
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
8794
),
95+
cache_dir=cache_dir,
8896
)
8997

9098
num_train_samples = 50000

tf_keras/datasets/cifar100.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
@keras_export("keras.datasets.cifar100.load_data")
30-
def load_data(label_mode="fine"):
30+
def load_data(label_mode="fine", cache_dir=None):
3131
"""Loads the CIFAR100 dataset.
3232
3333
This is a dataset of 50,000 32x32 color training images and
@@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
3939
label_mode: one of "fine", "coarse". If it is "fine" the category labels
4040
are the fine-grained labels, if it is "coarse" the output labels are the
4141
coarse-grained superclasses.
42+
cache_dir: directory where to cache the dataset locally. When None,
43+
defaults to `~/.keras/datasets`.
4244
4345
Returns:
4446
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -75,13 +77,17 @@ def load_data(label_mode="fine"):
7577

7678
dirname = "cifar-100-python"
7779
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
80+
if cache_dir:
81+
cache_dir = os.path.expanduser(cache_dir)
82+
os.makedirs(cache_dir, exist_ok=True)
7883
path = get_file(
7984
dirname,
8085
origin=origin,
8186
untar=True,
8287
file_hash=( # noqa: E501
8388
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
8489
),
90+
cache_dir=cache_dir,
8591
)
8692

8793
fpath = os.path.join(path, "train")

tf_keras/datasets/fashion_mnist.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
@keras_export("keras.datasets.fashion_mnist.load_data")
29-
def load_data():
29+
def load_data(cache_dir=None):
3030
"""Loads the Fashion-MNIST dataset.
3131
3232
This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
@@ -48,6 +48,10 @@ def load_data():
4848
| 8 | Bag |
4949
| 9 | Ankle boot |
5050
51+
Args:
52+
cache_dir: directory where to cache the dataset locally. When None,
53+
defaults to `~/.keras/datasets`.
54+
5155
Returns:
5256
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
5357
@@ -77,7 +81,6 @@ def load_data():
7781
The copyright for Fashion-MNIST is held by Zalando SE.
7882
Fashion-MNIST is licensed under the [MIT license](
7983
https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
80-
8184
"""
8285
dirname = os.path.join("datasets", "fashion-mnist")
8386
base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
@@ -87,10 +90,19 @@ def load_data():
8790
"t10k-labels-idx1-ubyte.gz",
8891
"t10k-images-idx3-ubyte.gz",
8992
]
90-
93+
if cache_dir:
94+
cache_dir = os.path.expanduser(cache_dir)
95+
os.makedirs(cache_dir, exist_ok=True)
9196
paths = []
9297
for fname in files:
93-
paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
98+
paths.append(
99+
get_file(
100+
fname,
101+
origin=base + fname,
102+
cache_dir=cache_dir,
103+
cache_subdir=dirname,
104+
)
105+
)
94106

95107
with gzip.open(paths[0], "rb") as lbpath:
96108
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

tf_keras/datasets/imdb.py

+8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""IMDB sentiment classification dataset."""
1616

1717
import json
18+
import os
1819

1920
import numpy as np
2021

@@ -36,6 +37,7 @@ def load_data(
3637
start_char=1,
3738
oov_char=2,
3839
index_from=3,
40+
cache_dir=None,
3941
**kwargs,
4042
):
4143
"""Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
@@ -73,6 +75,8 @@ def load_data(
7375
Words that were cut out because of the `num_words` or
7476
`skip_top` limits will be replaced with this character.
7577
index_from: int. Index actual words with this index and higher.
78+
cache_dir: directory where to cache the dataset locally. When None,
79+
defaults to `~/.keras/datasets`.
7680
**kwargs: Used for backwards compatibility.
7781
7882
Returns:
@@ -108,12 +112,16 @@ def load_data(
108112
origin_folder = (
109113
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
110114
)
115+
if cache_dir:
116+
cache_dir = os.path.expanduser(cache_dir)
117+
os.makedirs(cache_dir, exist_ok=True)
111118
path = get_file(
112119
path,
113120
origin=origin_folder + "imdb.npz",
114121
file_hash=( # noqa: E501
115122
"69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
116123
),
124+
cache_dir=cache_dir,
117125
)
118126
with np.load(path, allow_pickle=True) as f:
119127
x_train, labels_train = f["x_train"], f["y_train"]

tf_keras/datasets/mnist.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""MNIST handwritten digits dataset."""
16+
import os
1617

1718
import numpy as np
1819

@@ -23,7 +24,7 @@
2324

2425

2526
@keras_export("keras.datasets.mnist.load_data")
26-
def load_data(path="mnist.npz"):
27+
def load_data(path="mnist.npz", cache_dir=None):
2728
"""Loads the MNIST dataset.
2829
2930
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
@@ -32,8 +33,9 @@ def load_data(path="mnist.npz"):
3233
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
3334
3435
Args:
35-
path: path where to cache the dataset locally
36-
(relative to `~/.keras/datasets`).
36+
path: path where to cache the dataset locally relative to cache_dir.
37+
cache_dir: directory where to cache the dataset locally. When None,
38+
defaults to `~/.keras/datasets`.
3739
3840
Returns:
3941
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -72,12 +74,16 @@ def load_data(path="mnist.npz"):
7274
origin_folder = (
7375
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
7476
)
77+
if cache_dir:
78+
cache_dir = os.path.expanduser(cache_dir)
79+
os.makedirs(cache_dir, exist_ok=True)
7580
path = get_file(
7681
path,
7782
origin=origin_folder + "mnist.npz",
7883
file_hash=( # noqa: E501
7984
"731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
8085
),
86+
cache_dir=cache_dir,
8187
)
8288
with np.load(path, allow_pickle=True) as f:
8389
x_train, y_train = f["x_train"], f["y_train"]

0 commit comments

Comments (0)