Skip to content

Commit 683d261

Browse files
author
xinetzone
committed
创建 X 数据的解析工具
new file: utils/__init__.py new file: utils/dataset.py 通用数据工具 new file: utils/xdecode.py 解析 X.h5 数据
1 parent 7dad9ad commit 683d261

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed

utils/__init__.py

Whitespace-only changes.

utils/dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pickle
2+
3+
4+
class Bunch(dict):
5+
def __init__(self, *args, **kwargs):
6+
super().__init__(*args, **kwargs)
7+
self.__dict__ = self
8+
9+
10+
def bunch2json(bunch, path):
11+
# bunch 序列化为 JSON
12+
with open(path, 'wb') as fp:
13+
pickle.dump(bunch, fp)
14+
15+
16+
def json2bunch(path):
17+
# JSON 反序列化为 bunch
18+
with open(path, 'rb') as fp:
19+
X = pickle.load(fp)
20+
return X

utils/xdecode.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import numpy as np
2+
import tables as tb
3+
4+
5+
class XDecode:
6+
def __init__(self, root='../datasome/X.h5'):
7+
'''解析数据 X
8+
使用完该实例记得关闭 self.h5.close()
9+
10+
实例属性
11+
========
12+
members 即 ['mnist', 'fashion_mnist', 'cifar100', 'cifar10']
13+
'''
14+
self.h5 = tb.open_file(root)
15+
self.members = self.h5.root.__members__
16+
17+
def summary(self):
18+
'''打印 X 数据的摘要信息'''
19+
print(self.h5)
20+
21+
def get_members(self, name):
22+
'''获得给定 name 的成员名称列表
23+
24+
参数
25+
=======
26+
name in ['mnist', 'fashion_mnist', 'cifar100', 'cifar10']
27+
'''
28+
return self.h5.root[name].__members__
29+
30+
def get_mnist_members(self):
31+
return self.h5.root.mnist.__members__
32+
33+
def get_trainX(self, name):
34+
'''获得给定 name 的 trainX
35+
36+
参数
37+
=======
38+
name in ['mnist', 'fashion_mnist', 'cifar100', 'cifar10']
39+
'''
40+
return self.h5.get_node(f'/{name}', 'trainX')
41+
42+
def get_testX(self, name):
43+
'''获得给定 name 的 testX
44+
45+
参数
46+
=======
47+
name in ['mnist', 'fashion_mnist', 'cifar100', 'cifar10']
48+
'''
49+
return self.h5.get_node(f'/{name}', 'testX')
50+
51+
def get_trainY(self, name):
52+
'''获得给定 name 的 trainY
53+
54+
参数
55+
=======
56+
name in ['mnist', 'fashion_mnist', 'cifar10']
57+
'''
58+
return self.h5.get_node(f'/{name}', 'trainY')
59+
60+
def get_testY(self, name):
61+
'''获得给定 name 的 testY
62+
63+
参数
64+
=======
65+
name in ['mnist', 'fashion_mnist', 'cifar10']
66+
'''
67+
return self.h5.get_node(f'/{name}', 'testY')
68+
69+
def get_train_coarse_labels(self):
70+
'''获得 Cifar100 训练集的粗标签'''
71+
return self.h5.get_node('/cifar100', 'train_coarse_labels')
72+
73+
def get_train_fine_labels(self):
74+
'''获得 Cifar100 训练集的细标签'''
75+
return self.h5.get_node('/cifar100', 'train_fine_labels')
76+
77+
def get_test_coarse_labels(self):
78+
'''获得 Cifar100 测试集的粗标签'''
79+
return self.h5.get_node('/cifar100', 'test_coarse_labels')
80+
81+
def get_test_fine_labels(self):
82+
'''获得 Cifar100 的测试集细标签'''
83+
return self.h5.get_node('/cifar100', 'test_fine_labels')
84+
85+
def get_coarse_label_names(self):
86+
'''获得 Cifar100 测试集的粗标签的名称'''
87+
label_names = self.h5.get_node('/cifar100', 'coarse_label_names')
88+
return np.asanyarray(label_names, "U")
89+
90+
def get_fine_label_names(self):
91+
'''获得 Cifar100 的测试集细标签的名称'''
92+
label_names = self.h5.get_node('/cifar100', 'fine_label_names')
93+
return np.asanyarray(label_names, "U")
94+
95+
def get_label_names(self, name):
96+
'''获得给定 name 的标签名称
97+
98+
参数
99+
======
100+
name in ['mnist', 'fashion_mnist', 'cifar10']
101+
'''
102+
if name == 'cifar10':
103+
label_names = self.h5.get_node('/cifar10', 'label_names')
104+
return np.asanyarray(label_names, "U")
105+
elif name == 'fashion_mnist':
106+
return [
107+
'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
108+
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
109+
]
110+
elif name == 'mnist':
111+
return np.arange(10)

0 commit comments

Comments
 (0)