Skip to content

Commit c3a35db

Browse files
remo-helpdesilinguist
authored andcommitted
Add a test for the new names method.
1 parent 1bf86ec commit c3a35db

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/test_output.py

+36
Original file line numberDiff line numberDiff line change
@@ -1178,3 +1178,39 @@ def test_save_models_to_current_directory():
11781178
eq_(learner1.model_type, learner2.model_type)
11791179
eq_(learner1.model_params, learner2.model_params)
11801180
eq_(learner1.model_kwargs, learner2.model_kwargs)
1181+
1182+
1183+
def check_get_feature_names_out(selection):
1184+
# We want to make sure that this gives us the list of feature names AFTER
1185+
# the feature selector has done its work.
1186+
1187+
# we create some learners to test on
1188+
classifier_classes = ['LogisticRegression', 'DummyClassifier', 'LinearSVC']
1189+
classifiers = [Learner(i) for i in classifier_classes]
1190+
regressor_classes = ['LinearRegression', 'SVR', 'AdaBoostRegressor']
1191+
regressors = [Learner(i) for i in regressor_classes]
1192+
1193+
# we create some minimal training data for our learners
1194+
train_fs, _, _ = make_regression_data(num_examples=20, num_features=6)
1195+
1196+
# If the selection param is `True`, set a couple of feature to 0-values only,
1197+
# so the feature selector will remove them when training
1198+
if selection:
1199+
train_fs.features[:, [0, 3]] = 0
1200+
feature_names = ['f2', 'f3', 'f5', 'f6']
1201+
else:
1202+
feature_names = ['f1', 'f2', 'f3', 'f4', 'f5', 'f6']
1203+
1204+
# test whether we get the expected feature names from the
1205+
# `get_feature_names_out()` method
1206+
for classifier in classifiers:
1207+
classifier.train(train_fs, grid_search=False)
1208+
assert list(classifier.get_feature_names_out()) == feature_names
1209+
for regressor in regressors:
1210+
regressor.train(train_fs, grid_search=False)
1211+
assert list(regressor.get_feature_names_out()) == feature_names
1212+
1213+
1214+
def test_get_feature_names_out():
1215+
yield check_get_feature_names_out, True
1216+
yield check_get_feature_names_out, False

0 commit comments

Comments
 (0)