Skip to content

Commit 7dad9ad

Browse files
author
xinetzone
committed
创建 MNIST 与 Cifar 的数据解析工具并将其打包
new file: custom/__init__.py new file: custom/cifar.py 解析 cifar 数据 new file: custom/genX.py 打包 MNIST 与 Cifar new file: custom/mnist.py 解析 mnist 数据
1 parent 979b73f commit 7dad9ad

File tree

4 files changed

+236
-0
lines changed

4 files changed

+236
-0
lines changed

custom/__init__.py

Whitespace-only changes.

custom/cifar.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
'''
2+
作者:xinetzone
3+
时间:2019/12/3
4+
'''
5+
import tarfile
6+
from pathlib import Path
7+
import pickle
8+
import time
9+
import numpy as np
10+
11+
12+
class Cifar:
13+
def __init__(self, root, namespace):
14+
"""CIFAR image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html
15+
Each sample is an image (in 3D NDArray) with shape (3, 32, 32).
16+
17+
参数
18+
=========
19+
meta : 保存了类别信息
20+
root : str, 数据根目录
21+
namespace : 'cifar-10' 或 'cifar-100'
22+
"""
23+
#super().__init__(*args, **kwds)
24+
#self.__dict__ = self
25+
self.root = Path(root)
26+
# 解压数据集到 root,并将解析后的数据载入内存
27+
self._load(namespace)
28+
29+
def _extractall(self, namespace):
30+
'''解压 tar 文件并返回路径
31+
32+
参数
33+
========
34+
tar_name:tar 文件名称
35+
'''
36+
tar_name = self.root / f'{namespace}-python.tar.gz'
37+
with tarfile.open(tar_name) as tar:
38+
tar.extractall(self.root) # 解压全部文件
39+
names = tar.getnames() # 获取解压后的文件所在目录
40+
return names
41+
42+
def _decode(self, path):
43+
'''载入二进制流到内存'''
44+
with open(path, 'rb') as fp: # 打开文件
45+
# 载入数据到内存
46+
data = pickle.load(fp, encoding='bytes')
47+
return data
48+
49+
def _load_cifar10(self, names):
50+
'''将解析后的 cifar10 数据载入内存'''
51+
# 获取数据根目录
52+
R = [self.root /
53+
name for name in names if (self.root / name).is_dir()][0]
54+
# 元数据信息
55+
meta = self._decode(list(R.glob('*.meta'))[0])
56+
# 训练集信息
57+
train = [self._decode(path) for path in R.glob('*_batch_*')]
58+
# 测试集信息
59+
test = [self._decode(path) for path in R.glob('*test*')][0]
60+
return meta, train, test
61+
62+
def _load_cifar100(self, names):
63+
'''将解析后的 cifar100 数据载入内存'''
64+
# 获取数据根目录
65+
R = [self.root /
66+
name for name in names if (self.root / name).is_dir()][0]
67+
# 元数据信息
68+
meta = self._decode(list(R.glob('*meta*'))[0])
69+
# 训练集信息
70+
train = [self._decode(path) for path in R.glob('*train*')][0]
71+
# 测试集信息
72+
test = [self._decode(path) for path in R.glob('*test*')][0]
73+
return meta, train, test
74+
75+
def _load(self, namespace):
76+
# 解压数据集到 root,并返回文件列表
77+
names = self._extractall(namespace)
78+
if namespace == 'cifar-10':
79+
self.meta, train, test = self._load_cifar10(names)
80+
self.trainX = np.concatenate(
81+
[x[b'data'] for x in train]).reshape(-1, 3, 32, 32)
82+
self.trainY = np.concatenate([x[b'labels'] for x in train])
83+
self.testX = np.array(test[b'data']).reshape(-1, 3, 32, 32)
84+
self.testY = np.array(test[b'labels'])
85+
elif namespace == 'cifar-100':
86+
self.meta, train, test = self._load_cifar100(names)
87+
self.trainX = np.array(train[b'data']).reshape(-1, 3, 32, 32)
88+
self.testX = np.array(test[b'data']).reshape(-1, 3, 32, 32)
89+
self.train_fine_labels = np.array(train[b'fine_labels'])
90+
self.train_coarse_labels = np.array(train[b'coarse_labels'])
91+
self.test_fine_labels = np.array(test[b'fine_labels'])
92+
self.test_coarse_labels = np.array(test[b'coarse_labels'])

custom/genX.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
'''
2+
作者:xinetzone
3+
时间:2019/12/3
4+
'''
5+
import tables as tb
6+
import numpy as np
7+
8+
from custom.cifar import Cifar
9+
from custom.mnist import MNIST
10+
11+
12+
class Bunch(dict):
13+
def __init__(self, root, *args, **kwargs):
14+
"""将数据 MNIST,Fashion MNIST,Cifar 10,Cifar 100 打包
15+
为 HDF5
16+
17+
参数
18+
=========
19+
root : 数据的根目录
20+
"""
21+
super().__init__(*args, **kwargs)
22+
self.__dict__ = self
23+
self.mnist = MNIST(root, 'mnist')
24+
self.fashion_mnist = MNIST(root, 'fashion-mnist')
25+
self.cifar10 = Cifar(root, 'cifar-10')
26+
self.cifar100 = Cifar(root, 'cifar-100')
27+
28+
def _change(self, img):
29+
'''将数据由 (num, channel, h, w) 转换为 (num, h, w, channel)'''
30+
return np.transpose(img, (0, 2, 3, 1))
31+
32+
def toHDF5(self, save_path):
33+
'''将数据打包为 HDF5 格式
34+
35+
参数
36+
===========
37+
save_path:数据保存的路径
38+
'''
39+
filters = tb.Filters(complevel=7, shuffle=False)
40+
with tb.open_file(f'{save_path}/X.h5', 'w', filters=filters, title='Xinet\'s dataset') as h5:
41+
for name in self:
42+
h5.create_group('/', name, title=name)
43+
if name in ['mnist', 'fashion_mnist']:
44+
h5.create_array(
45+
h5.root[name], 'trainX', self[name].train_data)
46+
h5.create_array(
47+
h5.root[name], 'trainY', self[name].train_label)
48+
h5.create_array(
49+
h5.root[name], 'testX', self[name].test_data)
50+
h5.create_array(
51+
h5.root[name], 'testY', self[name].test_label)
52+
elif name == 'cifar10':
53+
h5.create_array(
54+
h5.root[name], 'trainX', self._change(self[name].trainX))
55+
h5.create_array(h5.root[name], 'trainY', self[name].trainY)
56+
h5.create_array(
57+
h5.root[name], 'testX', self._change(self[name].testX))
58+
h5.create_array(h5.root[name], 'testY', self[name].testY)
59+
h5.create_array(h5.root[name], 'label_names', np.array(
60+
self[name].meta[b'label_names']))
61+
elif name == 'cifar100':
62+
h5.create_array(
63+
h5.root[name], 'trainX', self._change(self[name].trainX))
64+
h5.create_array(
65+
h5.root[name], 'testX', self._change(self[name].testX))
66+
h5.create_array(
67+
h5.root[name], 'train_coarse_labels', self[name].train_coarse_labels)
68+
h5.create_array(
69+
h5.root[name], 'test_coarse_labels', self[name].test_coarse_labels)
70+
h5.create_array(
71+
h5.root[name], 'train_fine_labels', self[name].train_fine_labels)
72+
h5.create_array(
73+
h5.root[name], 'test_fine_labels', self[name].test_fine_labels)
74+
h5.create_array(h5.root[name], 'coarse_label_names', np.array(
75+
self[name].meta[b'coarse_label_names']))
76+
h5.create_array(h5.root[name], 'fine_label_names', np.array(
77+
self[name].meta[b'fine_label_names']))

custom/mnist.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
'''
2+
作者:xinetzone
3+
时间:2019/12/3
4+
'''
5+
import struct
6+
from pathlib import Path
7+
import numpy as np
8+
import gzip
9+
10+
11+
class MNIST:
12+
def __init__(self, root, namespace):
13+
"""MNIST 与 FASGION-MNIST 数据解码工具
14+
1. (MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist) 下载后放置在 `mnist` 目录
15+
2. (A dataset of Zalando's article images consisting of fashion products,
16+
a drop-in replacement of the original MNIST dataset from https://github.com/zalandoresearch/fashion-mnist)
17+
数据下载放置在 `fashion-mnist` 目录
18+
19+
Each sample is an image (in 2D NDArray) with shape (28, 28).
20+
21+
22+
23+
参数
24+
========
25+
root : 数据根目录,如 'E:/Data/Zip/'
26+
namespace : 'mnist' or 'fashion-mnist'
27+
28+
实例属性
29+
========
30+
train_data:训练数据集图片
31+
train_label:训练数据集标签名称
32+
test_data:测试数据集图片
33+
test_label:测试数据集标签名称
34+
"""
35+
root = Path(root) / namespace
36+
self._name2array(root)
37+
38+
def _name2array(self, root):
39+
'''
40+
官方网站是以 `[offset][type][value][description]` 的格式封装数据的,
41+
因而我们使用 `struct.unpack`
42+
'''
43+
_train_data = root / 'train-images-idx3-ubyte.gz' # 训练数据集文件名
44+
_train_label = root / 'train-labels-idx1-ubyte.gz' # 训练数据集的标签文件名
45+
_test_data = root / 't10k-images-idx3-ubyte.gz' # 测试数据集文件名
46+
_test_label = root / 't10k-labels-idx1-ubyte.gz' # 测试数据集的标签文件名
47+
self.train_data = self.get_data(_train_data) # 获得训练数据集图片
48+
self.train_label = self.get_label(_train_label) # 获得训练数据集标签名称
49+
self.test_data = self.get_data(_test_data) # 获得测试数据集图片
50+
self.test_label = self.get_label(_test_label) # 获得测试数据集标签名称
51+
52+
def get_data(self, data):
53+
'''获取图像信息'''
54+
with gzip.open(data, 'rb') as fin:
55+
shape = struct.unpack(">IIII", fin.read(16))[1:]
56+
data = np.frombuffer(fin.read(), dtype=np.uint8)
57+
data = data.reshape(shape)
58+
return data
59+
60+
def get_label(self, label):
61+
'''获取标签信息'''
62+
with gzip.open(label, 'rb') as fin:
63+
struct.unpack(">II", fin.read(8)) # 参考数据集的网站,即 offset=8
64+
# 获得数据集的标签
65+
label = fin.read()
66+
label = np.frombuffer(label, dtype=np.uint8).astype(np.int32)
67+
return label

0 commit comments

Comments
 (0)