Description
Dear all,
Thank you for the scikit-plot package!
I am trying to plot feature importances from a random forest model that is used inside a pipeline with preprocessing steps.
I am able to extract the classifier from the pipeline, which is needed for the plot, but I am struggling to get the feature names out of the classifier/pipe object. In the sample you provide, the feature_names are listed manually, which is easy for the iris dataset. But how can I get them automatically from the classifier/pipe object?
There seems to be a method .get_feature_names(), but it does not work on the pipe or the classifier object.
Would it be possible to use these feature names as the default? The plot does not seem very helpful without attribute names.
Thanks a lot for any hints! Please find my code below.
Define preprocessing pipeline
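from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, PowerTransformer

categorical_features = [
    'ped_alter',
    'ped_sprache'
]
# Categorical columns: fill missing values with the mode, then one-hot encode
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(categories='auto', handle_unknown='ignore'))])
numeric_features = [
    'n_y_ltm_tage',
    'akt_a_total',
    'akt_a_y1',
    'akt_a_y2',
]
# Numeric columns: fill missing values with the mean, then power-transform
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='mean')),
    ('transformer', PowerTransformer())])
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])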
Define classifier & modeling pipeline
clf = RandomForestClassifier(n_estimators=100, max_depth=2, n_jobs=4)
pipe = Pipeline(steps=[('preprocessor', preprocessor),
                       ('classifier', clf)])
Split DataFrame
y = df['target']
x = df.drop(['target'], axis=1)
Perform train/test split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
pipe.fit(x_train, y_train)
#--- save pipe as a pickle
#--- load pipe as a pickle
clf = pipe.steps[1][1]
attr_list = clf.????????????
attr_list = pipe.??????????????
import scikitplot as skplt
skplt.estimators.plot_feature_importances(clf,
                                          feature_names=attr_list,
                                          max_num_features=30,
                                          x_tick_rotation=45)