
scikit feature importance selection experiences

Scikit-learn has a mechanism to rank features (for classification) using extremely randomized trees.

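from sklearn.ensemble import ExtraTreesClassifier

# compute_importances=True was needed in old releases only; recent
# scikit-learn versions reject it and always expose feature_importances_ after fit()
forest = ExtraTreesClassifier(n_estimators=250,
                              random_state=0)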

My question is whether this method performs a "univariate" or a "multivariate" feature ranking. The univariate case is where individual features are compared to each other. I would appreciate some clarification here. Are there any other parameters I should try tuning? Any experiences and pitfalls with this ranking method are also appreciated.

The output of this ranking identifies feature numbers (5, 20, 7). I would like to check whether the feature number really corresponds to the row in the feature matrix. That is, does feature number 5 correspond to the sixth row in the feature matrix (counting from 0)?

I'm not an expert, but this is not univariate. The total feature importance is computed from the feature importances of the individual trees (by taking the mean, I think).

For each tree, the importances are computed from the impurity decrease of its splits.
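The averaging described above can be checked directly. The following is a minimal sketch on synthetic data (the dataset and the variable names `per_tree` and `mean_importance` are my own, not from the question); it compares the mean of the per-tree importances with the forest-level attribute. It assumes a reasonably recent scikit-learn, where each fitted tree in `estimators_` exposes its own `feature_importances_`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier

# Synthetic data, used only for illustration.
X, y = make_classification(n_samples=300, n_features=8, n_informative=3,
                           random_state=0)

forest = ExtraTreesClassifier(n_estimators=50, random_state=0).fit(X, y)

# Each fitted tree exposes its own impurity-based importances.
per_tree = np.array([tree.feature_importances_ for tree in forest.estimators_])

# Average over the trees and compare with the forest-level attribute.
mean_importance = per_tree.mean(axis=0)
print(np.allclose(mean_importance, forest.feature_importances_))
```

Because a split in a tree is only evaluated on the samples that reached that node through earlier splits, the importance of one feature depends on the other features, which is why the ranking is not univariate.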

I used this method and it seems to give good results, better in my view than the univariate method. However, I don't know of any technique for testing the results other than knowledge of the dataset.
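One way to see the difference from a univariate method is a toy problem where the label depends only on an interaction between two features. The sketch below is an assumption-laden illustration, not a benchmark: the XOR-style data is made up, and `f_classif` stands in for "the univariate method", which the answer does not name. On data like this the F-scores of the two interacting columns would typically be no more impressive than those of the noise columns, while the forest should give them most of the importance.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.feature_selection import f_classif

rng = np.random.RandomState(0)
X = rng.uniform(-1, 1, size=(2000, 4))
# The label depends only on the interaction of columns 0 and 1 (XOR of
# their signs); columns 2 and 3 are pure noise.
y = ((X[:, 0] > 0) ^ (X[:, 1] > 0)).astype(int)

# Univariate score: one ANOVA F-test per column, each taken in isolation.
f_scores, p_values = f_classif(X, y)

# Forest importances: each split is chosen given the splits above it.
forest = ExtraTreesClassifier(n_estimators=100, random_state=0).fit(X, y)
importances = forest.feature_importances_

for j in range(X.shape[1]):
    print("column %d: F-score = %.2f, forest importance = %.3f"
          % (j, f_scores[j], importances[j]))
```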

To order the features correctly, follow this example, modified slightly to use a pandas.DataFrame and its proper column names:

import numpy as np
import pandas
import matplotlib.pyplot as plt

from sklearn.ensemble import ExtraTreesClassifier

X = pandas.DataFrame(...)  # feature matrix: one row per sample, one column per feature
y = pandas.Series(...)     # class labels

# Build a forest and compute the feature importances
forest = ExtraTreesClassifier(n_estimators=250,
                              random_state=0)

forest.fit(X, y)

# Rescale so that the most important feature scores 100
feature_importance = forest.feature_importances_
feature_importance = 100.0 * (feature_importance / feature_importance.max())
# Column indices, most important first
sorted_idx = np.argsort(feature_importance)[::-1]

print("Feature importance:")
for i, (f, w) in enumerate(zip(X.columns[sorted_idx], feature_importance[sorted_idx]), start=1):
    print("%d) %s : %d" % (i, f, w))

# Horizontal bar chart of the top-ranked features
nb_to_display = 30
pos = np.arange(sorted_idx.shape[0]) + .5
plt.subplot(1, 2, 2)
plt.barh(pos[:nb_to_display], feature_importance[sorted_idx][:nb_to_display], align='center')
plt.yticks(pos[:nb_to_display], X.columns[sorted_idx][:nb_to_display])
plt.xlabel('Relative Importance')
plt.title('Variable Importance')
plt.show()
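On the question of what the feature numbers refer to: scikit-learn expects the feature matrix as (n_samples, n_features), so feature_importances_ has one entry per column, and feature number 5 is the sixth column rather than the sixth row. The following sketch checks this on synthetic data; the column names `feat_0` ... `feat_9` are hypothetical, and `shuffle=False` is used so that the informative features are known to sit in the first three columns.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier

# shuffle=False keeps the informative features in the first columns.
data, y = make_classification(n_samples=500, n_features=10, n_informative=3,
                              n_redundant=0, shuffle=False, random_state=0)
names = ["feat_%d" % j for j in range(data.shape[1])]  # hypothetical names
X = pd.DataFrame(data, columns=names)

forest = ExtraTreesClassifier(n_estimators=100, random_state=0).fit(X, y)
importances = forest.feature_importances_

# One importance per column of X, not per row.
print(importances.shape, X.shape)

# Column indices sorted by importance, most important first.
ranking = np.argsort(importances)[::-1]
for idx in ranking[:3]:
    print(idx, X.columns[idx])
```

If the data is stored with features in rows, it would need to be transposed before fitting for the indices to line up this way.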
