@@ -154,4 +154,32 @@ def train_iter(self):
154
154
def test_iter (self ):
155
155
'''返回 测试集的 (features, labels)'''
156
156
for mpf in self .h5 .iter_nodes ('/test' ):
157
- yield self ._features (mpf ), self ._labels (mpf )
157
+ yield self ._features (mpf ), self ._labels (mpf )
158
+
159
+
160
+ class SubCASIA (CASIAFeature ):
161
+ def __init__ (self , class_names , hdf_path ):
162
+ '''casia 子集数据 MPF 特征处理工具
163
+ 通过给定的 class_names 获取 CASIA 的子集数据
164
+ '''
165
+ self .h5 = tb .open_file (hdf_path )
166
+ self .class_names = class_names
167
+
168
+ def get_iter (self , super_iter ):
169
+ '''从 super_iter 获取包含 self.class_names 的迭代器'''
170
+ for features , labels in super_iter :
171
+ # 选择指定 class_names 的样本
172
+ frame = DataFrame (features , labels )
173
+ # 选择 frame.index 与 self.class_names 的交集
174
+ frame = frame .loc [frame .index & self .class_names ]
175
+ features = frame .values
176
+ labels = frame .index .values
177
+ yield features , labels
178
+
179
+ def sub_train_iter (self ):
180
+ '''从 self.train_iter() 获取包含 self.class_names 的迭代器'''
181
+ return self .get_iter (self .train_iter ())
182
+
183
+ def sub_test_iter (self ):
184
+ '''从 self.test_iter() 获取包含 self.class_names 的迭代器'''
185
+ return self .get_iter (self .test_iter ())
0 commit comments