Skip to content

Commit d3f27aa

Browse files
committed
添加对指定 class_names 的样本的选择
1 parent eba2277 commit d3f27aa

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

casia/feature.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,32 @@ def train_iter(self):
154154
def test_iter(self):
155155
'''返回 测试集的 (features, labels)'''
156156
for mpf in self.h5.iter_nodes('/test'):
157-
yield self._features(mpf), self._labels(mpf)
157+
yield self._features(mpf), self._labels(mpf)
158+
159+
160+
class SubCASIA(CASIAFeature):
161+
def __init__(self, class_names, hdf_path):
162+
'''casia 子集数据 MPF 特征处理工具
163+
通过给定的 class_names 获取 CASIA 的子集数据
164+
'''
165+
self.h5 = tb.open_file(hdf_path)
166+
self.class_names = class_names
167+
168+
def get_iter(self, super_iter):
169+
'''从 super_iter 获取包含 self.class_names 的迭代器'''
170+
for features, labels in super_iter:
171+
# 选择指定 class_names 的样本
172+
frame = DataFrame(features, labels)
173+
# 选择 frame.index 与 self.class_names 的交集
174+
frame = frame.loc[frame.index & self.class_names]
175+
features = frame.values
176+
labels = frame.index.values
177+
yield features, labels
178+
179+
def sub_train_iter(self):
180+
'''从 self.train_iter() 获取包含 self.class_names 的迭代器'''
181+
return self.get_iter(self.train_iter())
182+
183+
def sub_test_iter(self):
184+
'''从 self.test_iter() 获取包含 self.class_names 的迭代器'''
185+
return self.get_iter(self.test_iter())

0 commit comments

Comments
 (0)